From 169f8fd43bd72da9483a0df543373f6ac00beb9b Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 11:56:59 +0900 Subject: [PATCH 1/3] add isolated nodes --- src/graphnet/models/graphs/__init__.py | 2 +- src/graphnet/models/graphs/graphs.py | 38 ++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index ea5066307..6974cdddc 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -7,4 +7,4 @@ from .graph_definition import GraphDefinition -from .graphs import KNNGraph +from .graphs import KNNGraph, IsolatedNodes diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index d486bba0a..bdb442f09 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -54,3 +54,41 @@ def __init__( perturbation_dict=perturbation_dict, seed=seed, ) + + +class IsolatedNodes(GraphDefinition): + """A Graph representation where each node is isolated.""" + + def __init__( + self, + detector: Detector, + node_definition: NodeDefinition = None, + input_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, + ) -> None: + """Construct isolated nodes graph representation. + + Args: + detector: Detector that represents your data. + node_definition: Definition of nodes in the graph. + input_feature_names: Name of input feature columns. + dtype: data type for node features. + perturbation_dict: Dictionary mapping a feature name to a standard + deviation according to which the values for this + feature should be randomly perturbed. Defaults + to None. + seed: seed or Generator used to randomly sample perturbations. + Defaults to None. + """ + # Base class constructor + super().__init__( + detector=detector, + node_definition=node_definition or NodesAsPulses(), + edge_definition=None, + dtype=dtype, + input_feature_names=input_feature_names, + perturbation_dict=perturbation_dict, + seed=seed, + ) From c9eadcc5ec9fddae85bff9962f15195719aa75e2 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Mon, 20 May 2024 14:09:42 +0900 Subject: [PATCH 2/3] docstring update --- src/graphnet/models/components/embedding.py | 2 +- src/graphnet/models/graphs/__init__.py | 2 +- src/graphnet/models/graphs/graphs.py | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index e97ca90e7..1b49cd901 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -84,7 +84,6 @@ def __init__( super().__init__() self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled) - self.aux_emb = nn.Embedding(2, seq_length // 2) self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled) if n_features < 4: @@ -93,6 +92,7 @@ def __init__( f"{n_features} features." ) elif n_features >= 6: + self.aux_emb = nn.Embedding(2, seq_length // 2) hidden_dim = 6 * seq_length else: hidden_dim = int((n_features + 0.5) * seq_length) diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index 6974cdddc..a07d1308d 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -7,4 +7,4 @@ from .graph_definition import GraphDefinition -from .graphs import KNNGraph, IsolatedNodes +from .graphs import KNNGraph, EdgelessGraph diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index bdb442f09..0289b943d 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -56,8 +56,11 @@ def __init__( ) -class IsolatedNodes(GraphDefinition): - """A Graph representation where each node is isolated.""" +class EdgelessGraph(GraphDefinition): + """A Data representation without edge assignment. + + I.e the resulting representation is created without an EdgeDefinition. + """ def __init__( self, From 96c7c695ed7cb93dd69b258c49cb4d70712c3e1c Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 21 May 2024 10:49:10 +0900 Subject: [PATCH 3/3] revert embedding changes --- src/graphnet/models/components/embedding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index 1b49cd901..9539fc444 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -84,6 +84,7 @@ def __init__( super().__init__() self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled) + self.aux_emb = nn.Embedding(2, seq_length // 2) self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled) if n_features < 4: @@ -92,7 +93,7 @@ def __init__( f"{n_features} features." ) elif n_features >= 6: - self.aux_emb = nn.Embedding(2, seq_length // 2) + hidden_dim = 6 * seq_length else: hidden_dim = int((n_features + 0.5) * seq_length)