Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add meta class to handle config saving #591

Merged
merged 9 commits into from
Sep 23, 2023
11 changes: 8 additions & 3 deletions src/graphnet/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from graphnet.utilities.config import (
Configurable,
DatasetConfig,
save_dataset_config,
DatasetConfigSaverABCMeta,
)
from graphnet.utilities.config.parsing import traverse_and_apply
from graphnet.utilities.logging import Logger
Expand Down Expand Up @@ -85,7 +85,13 @@ def parse_graph_definition(cfg: dict) -> GraphDefinition:
return graph_definition


class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
class Dataset(
Logger,
Configurable,
torch.utils.data.Dataset,
ABC,
metaclass=DatasetConfigSaverABCMeta,
):
"""Base Dataset class for reading from any intermediate file format."""

# Class method(s)
Expand Down Expand Up @@ -188,7 +194,6 @@ def _resolve_graphnet_paths(
.replace("${GRAPHNET}", GRAPHNET_ROOT_DIR)
)

@save_dataset_config
def __init__(
self,
path: Union[str, List[str]],
Expand Down
3 changes: 0 additions & 3 deletions src/graphnet/models/coarsening.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
std_pool_x,
)
from graphnet.models import Model
from graphnet.utilities.config import save_model_config

# Utility method(s)
from torch_geometric.utils import degree
Expand Down Expand Up @@ -63,7 +62,6 @@ class Coarsening(Model):
"sum": (sum_pool, sum_pool_x),
}

@save_model_config
def __init__(
self,
reduce: str = "avg",
Expand Down Expand Up @@ -198,7 +196,6 @@ def _add_inc_dict(self, original: Data, pooled: Data) -> Data:
class AttributeCoarsening(Coarsening):
"""Coarsen pulses based on specified attributes."""

@save_model_config
def __init__(
self,
attributes: List[str],
Expand Down
2 changes: 0 additions & 2 deletions src/graphnet/models/detector/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@

from graphnet.models import Model
from graphnet.utilities.decorators import final
from graphnet.utilities.config import save_model_config


class Detector(Model):
"""Base class for all detector-specific read-ins in graphnet."""

@save_model_config
def __init__(self) -> None:
"""Construct `Detector`."""
# Base class constructor
Expand Down
2 changes: 0 additions & 2 deletions src/graphnet/models/gnn/convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
from torch_geometric.nn import TAGConv, global_add_pool, global_max_pool
from torch_geometric.data import Data

from graphnet.utilities.config import save_model_config
from graphnet.models.gnn.gnn import GNN


class ConvNet(GNN):
"""ConvNet (convolutional network) model."""

@save_model_config
def __init__(
self,
nb_inputs: int,
Expand Down
2 changes: 0 additions & 2 deletions src/graphnet/models/gnn/dynedge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum

from graphnet.models.components.layers import DynEdgeConv
from graphnet.utilities.config import save_model_config
from graphnet.models.gnn.gnn import GNN
from graphnet.models.utils import calculate_xyzt_homophily

Expand All @@ -22,7 +21,6 @@
class DynEdge(GNN):
"""DynEdge (dynamical edge convolutional) model."""

@save_model_config
def __init__(
self,
nb_inputs: int,
Expand Down
2 changes: 0 additions & 2 deletions src/graphnet/models/gnn/dynedge_jinst.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum

from graphnet.models.components.layers import DynEdgeConv
from graphnet.utilities.config import save_model_config
from graphnet.models.gnn.gnn import GNN
from graphnet.models.utils import calculate_xyzt_homophily


class DynEdgeJINST(GNN):
"""DynEdge (dynamical edge convolutional) model used in [2209.03042]."""

@save_model_config
def __init__(
self,
nb_inputs: int,
Expand Down
2 changes: 0 additions & 2 deletions src/graphnet/models/gnn/dynedge_kaggle_tito.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum

from graphnet.models.components.layers import DynTrans
from graphnet.utilities.config import save_model_config
from graphnet.models.gnn.gnn import GNN
from graphnet.models.utils import calculate_xyzt_homophily

Expand All @@ -33,7 +32,6 @@
class DynEdgeTITO(GNN):
"""DynEdge (dynamical edge convolutional) model."""

@save_model_config
def __init__(
self,
nb_inputs: int,
Expand Down
2 changes: 0 additions & 2 deletions src/graphnet/models/gnn/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
from torch_geometric.data import Data

from graphnet.models import Model
from graphnet.utilities.config import save_model_config


class GNN(Model):
"""Base class for all core GNN models in graphnet."""

@save_model_config
def __init__(self, nb_inputs: int, nb_outputs: int) -> None:
"""Construct `GNN`."""
# Base class constructor
Expand Down
4 changes: 0 additions & 4 deletions src/graphnet/models/graphs/edges/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from torch_geometric.nn import knn_graph, radius_graph
from torch_geometric.data import Data

from graphnet.utilities.config import save_model_config
from graphnet.models.utils import calculate_distance_matrix
from graphnet.models import Model

Expand Down Expand Up @@ -48,7 +47,6 @@ def _construct_edges(self, graph: Data) -> Data:
class KNNEdges(EdgeDefinition): # pylint: disable=too-few-public-methods
"""Builds edges from the k-nearest neighbours."""

@save_model_config
def __init__(
self,
nb_nearest_neighbours: int,
Expand Down Expand Up @@ -85,7 +83,6 @@ def _construct_edges(self, graph: Data) -> Data:
class RadialEdges(EdgeDefinition):
"""Builds graph from a sphere of chosen radius centred at each node."""

@save_model_config
def __init__(
self,
radius: float,
Expand Down Expand Up @@ -126,7 +123,6 @@ class EuclideanEdges(EdgeDefinition): # pylint: disable=too-few-public-methods
See https://arxiv.org/pdf/1809.06166.pdf.
"""

@save_model_config
def __init__(
self,
sigma: float,
Expand Down
3 changes: 0 additions & 3 deletions src/graphnet/models/graphs/graph_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from torch_geometric.data import Data
import numpy as np

from graphnet.utilities.config import save_model_config

from graphnet.models.detector import Detector
from .edges import EdgeDefinition
from .nodes import NodeDefinition
Expand All @@ -22,7 +20,6 @@
class GraphDefinition(Model):
"""An Abstract class to create graph definitions from."""

@save_model_config
def __init__(
self,
detector: Detector,
Expand Down
2 changes: 0 additions & 2 deletions src/graphnet/models/graphs/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import List, Optional
import torch

from graphnet.utilities.config import save_model_config
from .graph_definition import GraphDefinition
from graphnet.models.detector import Detector
from graphnet.models.graphs.edges import EdgeDefinition, KNNEdges
Expand All @@ -13,7 +12,6 @@
class KNNGraph(GraphDefinition):
"""A Graph representation where Edges are drawn to nearest neighbours."""

@save_model_config
def __init__(
self,
detector: Detector,
Expand Down
2 changes: 0 additions & 2 deletions src/graphnet/models/graphs/nodes/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
from torch_geometric.data import Data

from graphnet.utilities.decorators import final
from graphnet.utilities.config import save_model_config
from graphnet.models import Model


class NodeDefinition(Model): # pylint: disable=too-few-public-methods
"""Base class for graph building."""

@save_model_config
def __init__(self) -> None:
"""Construct `Detector`."""
# Base class constructor
Expand Down
10 changes: 8 additions & 2 deletions src/graphnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@
from torch_geometric.data import Data

from graphnet.utilities.logging import Logger
from graphnet.utilities.config import Configurable, ModelConfig
from graphnet.utilities.config import (
Configurable,
ModelConfig,
ModelConfigSaverABC,
)
from graphnet.training.callbacks import ProgressBar


class Model(Logger, Configurable, LightningModule, ABC):
class Model(
Logger, Configurable, LightningModule, ABC, metaclass=ModelConfigSaverABC
):
"""Base class for all models in graphnet."""

@abstractmethod
Expand Down
2 changes: 0 additions & 2 deletions src/graphnet/models/standard_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from torch_geometric.data import Data
import pandas as pd

from graphnet.utilities.config import save_model_config
from graphnet.models.graphs import GraphDefinition
from graphnet.models.gnn.gnn import GNN
from graphnet.models.model import Model
Expand All @@ -24,7 +23,6 @@ class StandardModel(Model):
model (detector read-in, GNN architecture, and task-specific read-outs).
"""

@save_model_config
def __init__(
self,
*,
Expand Down
3 changes: 0 additions & 3 deletions src/graphnet/models/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from graphnet.training.loss_functions import LossFunction # type: ignore[attr-defined]

from graphnet.models import Model
from graphnet.utilities.config import save_model_config
from graphnet.utilities.decorators import final


Expand All @@ -39,7 +38,6 @@ def default_prediction_labels(self) -> List[str]:
"""Return default prediction labels."""
return self._default_prediction_labels

@save_model_config
def __init__(
self,
*,
Expand Down Expand Up @@ -264,7 +262,6 @@ def _validate_and_set_transforms(
class IdentityTask(Task):
"""Identity, or trivial, task."""

@save_model_config
def __init__(
self,
nb_outputs: int,
Expand Down
3 changes: 0 additions & 3 deletions src/graphnet/training/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@
softplus,
)

from graphnet.utilities.config import save_model_config
from graphnet.models.model import Model
from graphnet.utilities.decorators import final


class LossFunction(Model):
"""Base class for loss functions in `graphnet`."""

@save_model_config
def __init__(self, **kwargs: Any) -> None:
"""Construct `LossFunction`, saving model config."""
super().__init__(**kwargs)
Expand Down Expand Up @@ -120,7 +118,6 @@ class CrossEntropyLoss(LossFunction):
(0, num_classes - 1).
"""

@save_model_config
def __init__(
self,
options: Union[int, List[Any], Dict[Any, int]],
Expand Down
6 changes: 3 additions & 3 deletions src/graphnet/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
def collate_fn(graphs: List[Data]) -> Batch:
"""Remove graphs with less than two DOM hits.

Should not occur in "production.
Should not occur in "production".
"""
graphs = [g for g in graphs if g.n_pulses > 1]
return Batch.from_data_list(graphs)
Expand All @@ -32,7 +32,7 @@ def collate_fn(graphs: List[Data]) -> Batch:
def make_dataloader(
db: str,
pulsemaps: Union[str, List[str]],
graph_definition: Optional[GraphDefinition],
graph_definition: GraphDefinition,
features: List[str],
truth: List[str],
*,
Expand Down Expand Up @@ -92,7 +92,7 @@ def make_dataloader(
# @TODO: Remove in favour of DataLoader{,.from_dataset_config}
def make_train_validation_dataloader(
db: str,
graph_definition: Optional[GraphDefinition],
graph_definition: GraphDefinition,
selection: Optional[List[int]],
pulsemaps: Union[str, List[str]],
features: List[str],
Expand Down
14 changes: 12 additions & 2 deletions src/graphnet/utilities/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
"""Modules for configuration files for use across `graphnet`."""

from .configurable import Configurable
from .dataset_config import DatasetConfig, save_dataset_config
from .model_config import ModelConfig, save_model_config
from .dataset_config import (
DatasetConfig,
DatasetConfigSaverMeta,
DatasetConfigSaverABCMeta,
save_dataset_config,
)
from .model_config import (
ModelConfig,
ModelConfigSaverMeta,
ModelConfigSaverABC,
save_model_config,
)
from .training_config import TrainingConfig
Loading