diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index ea5066307..6974cdddc 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -7,4 +7,4 @@ from .graph_definition import GraphDefinition -from .graphs import KNNGraph +from .graphs import KNNGraph, IsolatedNodes diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index d486bba0a..bdb442f09 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -54,3 +54,41 @@ def __init__( perturbation_dict=perturbation_dict, seed=seed, ) + + +class IsolatedNodes(GraphDefinition): + """A Graph representation where each node is isolated.""" + + def __init__( + self, + detector: Detector, + node_definition: NodeDefinition = None, + input_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, + ) -> None: + """Construct isolated nodes graph representation. + + Args: + detector: Detector that represents your data. + node_definition: Definition of nodes in the graph. + input_feature_names: Name of input feature columns. + dtype: data type for node features. + perturbation_dict: Dictionary mapping a feature name to a standard + deviation according to which the values for this + feature should be randomly perturbed. Defaults + to None. + seed: seed or Generator used to randomly sample perturbations. + Defaults to None. + """ + # Base class constructor + super().__init__( + detector=detector, + node_definition=node_definition or NodesAsPulses(), + edge_definition=None, + dtype=dtype, + input_feature_names=input_feature_names, + perturbation_dict=perturbation_dict, + seed=seed, + )