Skip to content

Commit

Permalink
add isolated nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed May 14, 2024
1 parent 6c88786 commit 169f8fd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/graphnet/models/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@


from .graph_definition import GraphDefinition
from .graphs import KNNGraph
from .graphs import KNNGraph, IsolatedNodes
38 changes: 38 additions & 0 deletions src/graphnet/models/graphs/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,41 @@ def __init__(
perturbation_dict=perturbation_dict,
seed=seed,
)


class IsolatedNodes(GraphDefinition):
"""A Graph representation where each node is isolated."""

def __init__(
self,
detector: Detector,
node_definition: NodeDefinition = None,
input_feature_names: Optional[List[str]] = None,
dtype: Optional[torch.dtype] = torch.float,
perturbation_dict: Optional[Dict[str, float]] = None,
seed: Optional[Union[int, Generator]] = None,
) -> None:
"""Construct isolated nodes graph representation.
Args:
detector: Detector that represents your data.
node_definition: Definition of nodes in the graph.
input_feature_names: Name of input feature columns.
dtype: data type for node features.
perturbation_dict: Dictionary mapping a feature name to a standard
deviation according to which the values for this
feature should be randomly perturbed. Defaults
to None.
seed: seed or Generator used to randomly sample perturbations.
Defaults to None.
"""
# Base class constructor
super().__init__(
detector=detector,
node_definition=node_definition or NodesAsPulses(),
edge_definition=None,
dtype=dtype,
input_feature_names=input_feature_names,
perturbation_dict=perturbation_dict,
seed=seed,
)

0 comments on commit 169f8fd

Please sign in to comment.