diff --git a/examples/train_model.py b/examples/train_model.py
index 656292ffe..c30fb49c6 100644
--- a/examples/train_model.py
+++ b/examples/train_model.py
@@ -93,6 +93,7 @@ def main():
     )
     gnn = DynEdge(
         nb_inputs=detector.nb_outputs,
+        global_pooling_schemes=["min", "max", "mean", "sum"],
     )
     task = EnergyReconstruction(
         hidden_size=gnn.nb_outputs,
diff --git a/src/graphnet/models/components/layers.py b/src/graphnet/models/components/layers.py
index b0ce4477e..80c1f8c3f 100644
--- a/src/graphnet/models/components/layers.py
+++ b/src/graphnet/models/components/layers.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Sequence
+from typing import Callable, Optional, Sequence, Union
 
 from torch.functional import Tensor
 
@@ -13,7 +13,7 @@ def __init__(
         nn: Callable,
         aggr: str = "max",
         nb_neighbors: int = 8,
-        features_subset: Optional[Sequence] = None,
+        features_subset: Optional[Union[Sequence[int], slice]] = None,
         **kwargs,
     ):
         # Check(s)
diff --git a/src/graphnet/models/gnn/__init__.py b/src/graphnet/models/gnn/__init__.py
index c16d9c990..9c2cef2a0 100644
--- a/src/graphnet/models/gnn/__init__.py
+++ b/src/graphnet/models/gnn/__init__.py
@@ -1,2 +1,3 @@
-from .dynedge import DynEdge, DynEdge_V2, DynEdge_V3
 from .convnet import ConvNet
+from .dynedge import DynEdge
+from .dynedge_jinst import DynEdgeJINST
diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py
index c4de255b2..74647aa3c 100644
--- a/src/graphnet/models/gnn/dynedge.py
+++ b/src/graphnet/models/gnn/dynedge.py
@@ -1,269 +1,255 @@
-"""Implementation of the DynEdge GNN model architecture.
+"""Implementation of the DynEdge GNN model architecture."""
+from typing import List, Optional, Tuple, Union
 
-[Description of what this architecture does.]
-
-Author: Rasmus Oersoe
-Email: ###@###.###
-"""
-from typing import List, Optional, Union
 import torch
-from torch import Tensor
+from torch import Tensor, LongTensor
 from torch_geometric.data import Data
-from torch_geometric.nn import EdgeConv
 from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum
 
-from graphnet.models.components.layers import DynEdgeConv
-from graphnet.models.coarsening import Coarsening
+from graphnet.models.components.layers import DynEdgeConv
 from graphnet.models.gnn.gnn import GNN
 from graphnet.models.utils import calculate_xyzt_homophily
 
+GLOBAL_POOLINGS = {
+    "min": scatter_min,
+    "max": scatter_max,
+    "sum": scatter_sum,
+    "mean": scatter_mean,
+}
+
 
 class DynEdge(GNN):
     def __init__(
         self,
         nb_inputs: int,
-        layer_size_scale: Optional[int] = 4,
-        node_pooling: Coarsening = None,
+        *,
+        nb_neighbours: int = 8,
+        features_subset: Optional[Union[List[int], slice]] = None,
+        dynedge_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
+        post_processing_layer_sizes: Optional[List[int]] = None,
+        readout_layer_sizes: Optional[List[int]] = None,
+        global_pooling_schemes: Optional[Union[str, List[str]]] = None,
+        add_global_variables_after_pooling: bool = False,
    ):
-        """DynEdge model.
+        """DynEdge (dynamical edge convolutional) model.
 
         Args:
-            nb_inputs (int): Number of input features.
-            nb_outputs (int): Number of output features.
-            layer_size_scale (int, optional): Integer that scales the size of
-                hidden layers. Defaults to 4.
-            node_pooling: A Coarsening module that pools the nodes before they are processed by the model. Defaults to None (no pooling).
+            nb_inputs (int): Number of input features on each node.
+            nb_neighbours (int, optional): Number of neighbours to be used
+                in the k-nearest neighbour clustering which is performed
+                after each (dynamical) edge convolution. Defaults to 8.
+            features_subset (Optional[Union[List[int], slice]], optional):
+                The subset of latent features on each node that is used as
+                metric dimensions when performing the k-nearest neighbours
+                clustering. Defaults to slice(0,3).
+            dynedge_layer_sizes (Optional[List[Tuple[int, ...]]], optional):
+                The layer sizes, or latent feature dimensions, used in the
+                `DynEdgeConv` layers. Each entry in `dynedge_layer_sizes`
+                corresponds to a single `DynEdgeConv` layer; the integers in
+                the corresponding tuple correspond to the layer sizes in the
+                multi-layer perceptron (MLP) that is applied within each
+                `DynEdgeConv` layer. That is, a list of size-two tuples means
+                that all `DynEdgeConv` layers contain a two-layer MLP.
+                Defaults to [(128, 256), (336, 256), (336, 256), (336, 256)].
+            post_processing_layer_sizes (Optional[List[int]], optional):
+                Hidden layer sizes in the MLP following the
+                skip-concatenation of the outputs of each `DynEdgeConv`
+                layer. Defaults to [336, 256].
+            readout_layer_sizes (Optional[List[int]], optional): Hidden layer
+                sizes in the MLP following the post-processing _and_ optional
+                global pooling. As this is the last layer(s) in the model, the
+                last layer in the read-out yields the output of the `DynEdge`
+                model. Defaults to [128].
+            global_pooling_schemes (Optional[Union[str, List[str]]], optional):
+                The list of global pooling schemes to use. Options are: "min",
+                "max", "mean", and "sum". Defaults to None.
+            add_global_variables_after_pooling (bool, optional): Whether to
+                add global variables after global pooling. The alternative is
+                to add (distribute) them to the individual nodes before any
+                convolutional operations. Defaults to False.
         """
-        # Node Pooling via Coarsening Module
-        self._coarsening = node_pooling
         # Architecture configuration
-        c = layer_size_scale
-        l1, l2, l3, l4, l5, l6 = (
-            nb_inputs,
-            c * 16 * 2,
-            c * 32 * 2,
-            c * 42 * 2,
-            c * 32 * 2,
-            c * 16 * 2,
+        # Latent feature subset for computing nearest neighbours in DynEdge.
+        if features_subset is None:
+            features_subset = slice(0, 3)
+
+        # DynEdge layer sizes
+        if dynedge_layer_sizes is None:
+            dynedge_layer_sizes = [
+                (
+                    128,
+                    256,
+                ),
+                (
+                    336,
+                    256,
+                ),
+                (
+                    336,
+                    256,
+                ),
+                (
+                    336,
+                    256,
+                ),
+            ]
+
+        assert isinstance(dynedge_layer_sizes, list)
+        assert len(dynedge_layer_sizes)
+        assert all(isinstance(sizes, tuple) for sizes in dynedge_layer_sizes)
+        assert all(len(sizes) > 0 for sizes in dynedge_layer_sizes)
+        assert all(
+            all(size > 0 for size in sizes) for sizes in dynedge_layer_sizes
         )
 
-        # Base class constructor
-        super().__init__(nb_inputs, l6)
-
-        # Graph convolutional operations
-        features_subset = slice(0, 3)
-        nb_neighbors = 8
-
-        self.conv_add1 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l1 * 2, l2),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l2, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
-        )
-
-        self.conv_add2 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l3 * 2, l4),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l4, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
+        self._dynedge_layer_sizes = dynedge_layer_sizes
+
+        # Post-processing layer sizes
+        if post_processing_layer_sizes is None:
+            post_processing_layer_sizes = [
+                336,
+                256,
+            ]
+
+        assert isinstance(post_processing_layer_sizes, list)
+        assert len(post_processing_layer_sizes)
+        assert all(size > 0 for size in post_processing_layer_sizes)
+
+        self._post_processing_layer_sizes = post_processing_layer_sizes
+
+        # Read-out layer sizes
+        if readout_layer_sizes is None:
+            readout_layer_sizes = [
+                128,
+            ]
+
+        assert isinstance(readout_layer_sizes, list)
+        assert len(readout_layer_sizes)
+        assert all(size > 0 for size in readout_layer_sizes)
+
+        self._readout_layer_sizes = readout_layer_sizes
+
+        # Global pooling scheme(s)
+        if isinstance(global_pooling_schemes, str):
+            global_pooling_schemes = [global_pooling_schemes]
+
+        if isinstance(global_pooling_schemes, list):
+            for pooling_scheme in global_pooling_schemes:
+                assert (
+                    pooling_scheme in GLOBAL_POOLINGS
+                ), f"Global pooling scheme {pooling_scheme} not supported."
+        else:
+            assert global_pooling_schemes is None
+
+        self._global_pooling_schemes = global_pooling_schemes
+
+        if add_global_variables_after_pooling:
+            assert self._global_pooling_schemes, (
+                "No global pooling schemes were requested, so cannot add"
+                " global variables after pooling."
+            )
+        self._add_global_variables_after_pooling = (
+            add_global_variables_after_pooling
         )
 
-        self.conv_add3 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l3 * 2, l4),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l4, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
-        )
-
-        self.conv_add4 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l3 * 2, l4),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l4, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
-        )
+        # Base class constructor
+        super().__init__(nb_inputs, self._readout_layer_sizes[-1])
+
+        # Remaining member variables
+        self._activation = torch.nn.LeakyReLU()
+        self._nb_inputs = nb_inputs
+        self._nb_global_variables = 5 + nb_inputs
+        self._nb_neighbours = nb_neighbours
+        self._features_subset = features_subset
+
+        self._construct_layers()
+
+    def _construct_layers(self):
+        """Construct layers (torch.nn.Modules)."""
+
+        # Convolutional operations
+        nb_input_features = self._nb_inputs
+        if not self._add_global_variables_after_pooling:
+            nb_input_features += self._nb_global_variables
+
+        self._conv_layers = torch.nn.ModuleList()
+        nb_latent_features = nb_input_features
+        for sizes in self._dynedge_layer_sizes:
+            layers = []
+            layer_sizes = [nb_latent_features] + list(sizes)
+            for ix, (nb_in, nb_out) in enumerate(
+                zip(layer_sizes[:-1], layer_sizes[1:])
+            ):
+                if ix == 0:
+                    # `EdgeConv`-style MLPs take concatenated
+                    # (x_i, x_j - x_i) features, doubling the input size.
+                    nb_in *= 2
+                layers.append(torch.nn.Linear(nb_in, nb_out))
+                layers.append(self._activation)
+
+            conv_layer = DynEdgeConv(
+                torch.nn.Sequential(*layers),
+                aggr="add",
+                nb_neighbors=self._nb_neighbours,
+                features_subset=self._features_subset,
+            )
+            self._conv_layers.append(conv_layer)
+
+            nb_latent_features = nb_out  # Size of the last MLP layer
 
         # Post-processing operations
-        self.nn1 = torch.nn.Linear(l3 * 4 + l1, l4)
-        self.nn2 = torch.nn.Linear(l4, l5)
-        self.nn3 = torch.nn.Linear(4 * l5 + 5, l6)
-        self.lrelu = torch.nn.LeakyReLU()
-
-    def forward(self, data: Data) -> Tensor:
-        """Model forward pass.
-
-        Args:
-            data (Data): Graph of input features.
-
-        Returns:
-            Tensor: Model output.
-        """
-        if self._coarsening is not None:
-            data = self._coarsening(data)
-        # Convenience variables
-        x, edge_index, batch = data.x, data.edge_index, data.batch
-
-        # Calculate homophily (scalar variables)
-        h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)
-
-        a, edge_index = self.conv_add1(x, edge_index, batch)
-        b, edge_index = self.conv_add2(a, edge_index, batch)
-        c, edge_index = self.conv_add3(b, edge_index, batch)
-        d, edge_index = self.conv_add4(c, edge_index, batch)
-
-        # Skip-cat
-        x = torch.cat((x, a, b, c, d), dim=1)
-
-        # Post-processing
-        x = self.nn1(x)
-        x = self.lrelu(x)
-        x = self.nn2(x)
-
-        # Aggregation across nodes
-        a, _ = scatter_max(x, batch, dim=0)
-        b, _ = scatter_min(x, batch, dim=0)
-        c = scatter_sum(x, batch, dim=0)
-        d = scatter_mean(x, batch, dim=0)
-
-        # Concatenate aggregations and scalar features
-        x = torch.cat(
-            (
-                a,
-                b,
-                c,
-                d,
-                h_t.reshape(-1, 1),
-                h_x.reshape(-1, 1),
-                h_y.reshape(-1, 1),
-                h_z.reshape(-1, 1),
-                data.n_pulses.reshape(-1, 1),
-            ),
-            dim=1,
-        )
-
-        # Read-out
-        x = self.lrelu(x)
-        x = self.nn3(x)
-
-        x = self.lrelu(x)
-
-        return x
-
-
-class DynEdge_V2(GNN):
-    def __init__(self, nb_inputs, layer_size_scale=4):
-        """DynEdge model.
-
-        Args:
-            nb_inputs (int): Number of input features.
-            nb_outputs (int): Number of output features.
-            layer_size_scale (int, optional): Integer that scales the size of
-                hidden layers. Defaults to 4.
-        """
-
-        # Architecture configuration
-        c = layer_size_scale
-        l1, l2, l3, l4, l5, l6 = (
-            nb_inputs * 2 + 5,
-            c * 16 * 2,
-            c * 32 * 2,
-            c * 42 * 2,
-            c * 32 * 2,
-            c * 16 * 2,
-        )
-
-        # Base class constructor
-        super().__init__(nb_inputs, l6)
-
-        # Graph convolutional operations
-        features_subset = slice(0, 3)
-        nb_neighbors = 8
-
-        self.conv_add1 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l1 * 2, l2),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l2, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
+        nb_latent_features = (
+            sum(sizes[-1] for sizes in self._dynedge_layer_sizes)
+            + nb_input_features
         )
 
-        self.conv_add2 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l3 * 2, l4),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l4, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
+        post_processing_layers = []
+        layer_sizes = [nb_latent_features] + list(
+            self._post_processing_layer_sizes
         )
+        for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
+            post_processing_layers.append(torch.nn.Linear(nb_in, nb_out))
+            post_processing_layers.append(self._activation)
 
-        self.conv_add3 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l3 * 2, l4),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l4, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
-        )
+        self._post_processing = torch.nn.Sequential(*post_processing_layers)
 
-        self.conv_add4 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l3 * 2, l4),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l4, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
+        # Read-out operations
+        nb_poolings = (
+            len(self._global_pooling_schemes)
+            if self._global_pooling_schemes
+            else 1
         )
-
-        # Post-processing operations
-        self.nn1 = torch.nn.Linear(l3 * 4 + l1, l4)
-        self.nn2 = torch.nn.Linear(l4, l5)
-        self.nn3 = torch.nn.Linear(3 * l5 + 0, l6)  # 4*l5 + 5
-        self.lrelu = torch.nn.LeakyReLU()
-
-    def forward(self, data: Data) -> Tensor:
-        """Model forward pass.
-
-        Args:
-            data (Data): Graph of input features.
-
-        Returns:
-            Tensor: Model output.
-        """
-
-        # Convenience variables
-        x, edge_index, batch = data.x, data.edge_index, data.batch
+        # `nb_out` is the output size of the last post-processing layer.
+        nb_latent_features = nb_out * nb_poolings
+        if self._add_global_variables_after_pooling:
+            nb_latent_features += self._nb_global_variables
+
+        readout_layers = []
+        layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes)
+        for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
+            readout_layers.append(torch.nn.Linear(nb_in, nb_out))
+            readout_layers.append(self._activation)
+
+        self._readout = torch.nn.Sequential(*readout_layers)
+
+    def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor:
+        """Perform global pooling."""
+        assert self._global_pooling_schemes
+        pooled = []
+        for pooling_scheme in self._global_pooling_schemes:
+            pooling_fn = GLOBAL_POOLINGS[pooling_scheme]
+            pooled_x = pooling_fn(x, index=batch, dim=0)
+            if isinstance(pooled_x, tuple) and len(pooled_x) == 2:
+                # `scatter_{min,max}` return a (values, arg-indices) tuple,
+                # unlike `scatter_{mean,sum}`
+                pooled_x, _ = pooled_x
+            pooled.append(pooled_x)
+
+        return torch.cat(pooled, dim=1)
+
+    def _calculate_global_variables(
+        self,
+        x: Tensor,
+        edge_index: LongTensor,
+        batch: LongTensor,
+        *additional_attributes: Tensor,
+    ) -> Tensor:
+        """Calculate global variables."""
 
         # Calculate homophily (scalar variables)
         h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)
 
@@ -272,153 +258,26 @@ def forward(self, data: Data) -> Tensor:
         global_means = scatter_mean(x, batch, dim=0)
 
         # Add global variables
-        distribute = (
-            batch.unsqueeze(dim=1) == torch.unique(batch).unsqueeze(dim=0)
-        ).type(torch.float)
-
         global_variables = torch.cat(
             [
                 global_means,
-                torch.log10(data.n_pulses).unsqueeze(dim=1),
                 h_x,
                 h_y,
                 h_z,
                 h_t,
-            ],
-            dim=1,
-        )
-
-        global_variables_distributed = torch.sum(
-            distribute.unsqueeze(dim=2) * global_variables.unsqueeze(dim=0),
-            dim=1,
-        )
-
-        x = torch.cat((x, global_variables_distributed), dim=1)
-
-        a, edge_index = self.conv_add1(x, edge_index, batch)
-        b, edge_index = self.conv_add2(a, edge_index, batch)
-        c, edge_index = self.conv_add3(b, edge_index, batch)
-        d, edge_index = self.conv_add4(c, edge_index, batch)
-
-        # Skip-cat
-        x = torch.cat((x, a, b, c, d), dim=1)
-
-        # Post-processing
-        x = self.nn1(x)
-        x = self.lrelu(x)
-        x = self.nn2(x)
-
-        # Aggregation across nodes
-        a, _ = scatter_max(x, batch, dim=0)
-        b, _ = scatter_min(x, batch, dim=0)
-        c = scatter_mean(x, batch, dim=0)
-
-        # Concatenate aggregations and scalar features
-        x = torch.cat(
-            (
-                a,
-                b,
-                c,
-            ),
+            ]
+            + [attr.unsqueeze(dim=1) for attr in additional_attributes],
             dim=1,
         )
 
-        # Read-out
-        x = self.lrelu(x)
-        x = self.nn3(x)
-
-        x = self.lrelu(x)
-
-        return x
-
-
-class DynEdge_V3(GNN):
-    def __init__(self, nb_inputs, layer_size_scale=4):
-        """DynEdge model.
-        Args:
-            nb_inputs (int): Number of input features.
-            nb_outputs (int): Number of output features.
-            layer_size_scale (int, optional): Integer that scales the size of
-                hidden layers. Defaults to 4.
-        """
-
-        # Architecture configuration
-        c = layer_size_scale
-        l1, l2, l3, l4, l5, l6 = (
-            nb_inputs * 2 + 5,
-            c * 16 * 2,
-            c * 32 * 2,
-            c * 42 * 2,
-            c * 32 * 2,
-            c * 16 * 2,
-        )
-
-        # Base class constructor
-        super().__init__(nb_inputs, l6)
-
-        # Graph convolutional operations
-        features_subset = slice(0, 3)
-        nb_neighbors = 8
-
-        self.conv_add1 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l1 * 2, l2),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l2, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
-        )
-
-        self.conv_add2 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l3 * 2, l4),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l4, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
-        )
-
-        self.conv_add3 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l3 * 2, l4),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l4, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
-        )
-
-        self.conv_add4 = DynEdgeConv(
-            torch.nn.Sequential(
-                torch.nn.Linear(l3 * 2, l4),
-                torch.nn.LeakyReLU(),
-                torch.nn.Linear(l4, l3),
-                torch.nn.LeakyReLU(),
-            ),
-            aggr="add",
-            nb_neighbors=nb_neighbors,
-            features_subset=features_subset,
-        )
-
-        # Post-processing operations
-        self.nn1 = torch.nn.Linear(l3 * 4 + l1, l4)
-        self.nn2 = torch.nn.Linear(l4, l5)
-        # self.nn3 = torch.nn.Linear(3*l5 + 0,l6) # 4*l5 + 5
-        self.nn3 = torch.nn.Linear(l5, l6)
-        self.lrelu = torch.nn.LeakyReLU()
+        return global_variables
 
     def forward(self, data: Data) -> Tensor:
         """Model forward pass.
+
         Args:
             data (Data): Graph of input features.
+
         Returns:
             Tensor: Model output.
         """
@@ -426,53 +285,52 @@ def forward(self, data: Data) -> Tensor:
 
         # Convenience variables
         x, edge_index, batch = data.x, data.edge_index, data.batch
 
-        # Calculate homophily (scalar variables)
-        h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)
-
-        # Calculate mean features
-        global_means = scatter_mean(x, batch, dim=0)
-
-        # Add global variables
-        distribute = (
-            batch.unsqueeze(dim=1) == torch.unique(batch).unsqueeze(dim=0)
-        ).type(torch.float)
-
-        global_variables = torch.cat(
-            [
-                global_means,
-                torch.log10(data.n_pulses).unsqueeze(dim=1),
-                h_x,
-                h_y,
-                h_z,
-                h_t,
-            ],
-            dim=1,
+        global_variables = self._calculate_global_variables(
+            x,
+            edge_index,
+            batch,
+            torch.log10(data.n_pulses),
         )
 
-        global_variables_distributed = torch.sum(
-            distribute.unsqueeze(dim=2) * global_variables.unsqueeze(dim=0),
-            dim=1,
-        )
+        # Distribute global variables out to each node
+        if not self._add_global_variables_after_pooling:
+            distribute = (
+                batch.unsqueeze(dim=1) == torch.unique(batch).unsqueeze(dim=0)
+            ).type(torch.float)
+
+            global_variables_distributed = torch.sum(
+                distribute.unsqueeze(dim=2)
+                * global_variables.unsqueeze(dim=0),
+                dim=1,
+            )
 
-        x = torch.cat((x, global_variables_distributed), dim=1)
+            x = torch.cat((x, global_variables_distributed), dim=1)
 
-        a, edge_index = self.conv_add1(x, edge_index, batch)
-        b, edge_index = self.conv_add2(a, edge_index, batch)
-        c, edge_index = self.conv_add3(b, edge_index, batch)
-        d, edge_index = self.conv_add4(c, edge_index, batch)
+        # DynEdge-convolutions
+        skip_connections = [x]
+        for conv_layer in self._conv_layers:
+            x, edge_index = conv_layer(x, edge_index, batch)
+            skip_connections.append(x)
 
         # Skip-cat
-        x = torch.cat((x, a, b, c, d), dim=1)
+        x = torch.cat(skip_connections, dim=1)
 
         # Post-processing
-        x = self.nn1(x)
-        x = self.lrelu(x)
-        x = self.nn2(x)
+        x = self._post_processing(x)
+
+        # (Optional) Global pooling
+        if self._global_pooling_schemes:
+            x = self._global_pooling(x, batch=batch)
+            if self._add_global_variables_after_pooling:
+                x = torch.cat(
+                    [
+                        x,
+                        global_variables,
+                    ],
+                    dim=1,
+                )
 
         # Read-out
-        x = self.lrelu(x)
-        x = self.nn3(x)
-
-        x = self.lrelu(x)
+        x = self._readout(x)
 
         return x
diff --git a/src/graphnet/models/gnn/dynedge_jinst.py b/src/graphnet/models/gnn/dynedge_jinst.py
new file mode 100644
index 000000000..82a94e9b5
--- /dev/null
+++ b/src/graphnet/models/gnn/dynedge_jinst.py
@@ -0,0 +1,156 @@
+"""Implementation of the exact DynEdge architecture used in [2209.03042].
+
+Author: Rasmus Oersoe
+"""
+from typing import Optional
+
+import torch
+from torch import Tensor
+from torch_geometric.data import Data
+from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum
+
+from graphnet.models.components.layers import DynEdgeConv
+from graphnet.models.gnn.gnn import GNN
+from graphnet.models.utils import calculate_xyzt_homophily
+
+
+class DynEdgeJINST(GNN):
+    def __init__(
+        self,
+        nb_inputs: int,
+        layer_size_scale: Optional[int] = 4,
+    ):
+        """DynEdge model.
+
+        Args:
+            nb_inputs (int): Number of input features.
+            layer_size_scale (int, optional): Integer that scales the size of
+                hidden layers. Defaults to 4.
+        """
+        # Architecture configuration
+        c = layer_size_scale
+        l1, l2, l3, l4, l5, l6 = (
+            nb_inputs,
+            c * 16 * 2,
+            c * 32 * 2,
+            c * 42 * 2,
+            c * 32 * 2,
+            c * 16 * 2,
+        )
+
+        # Base class constructor
+        super().__init__(nb_inputs, l6)
+
+        # Graph convolutional operations
+        features_subset = slice(0, 3)
+        nb_neighbors = 8
+
+        self.conv_add1 = DynEdgeConv(
+            torch.nn.Sequential(
+                torch.nn.Linear(l1 * 2, l2),
+                torch.nn.LeakyReLU(),
+                torch.nn.Linear(l2, l3),
+                torch.nn.LeakyReLU(),
+            ),
+            aggr="add",
+            nb_neighbors=nb_neighbors,
+            features_subset=features_subset,
+        )
+
+        self.conv_add2 = DynEdgeConv(
+            torch.nn.Sequential(
+                torch.nn.Linear(l3 * 2, l4),
+                torch.nn.LeakyReLU(),
+                torch.nn.Linear(l4, l3),
+                torch.nn.LeakyReLU(),
+            ),
+            aggr="add",
+            nb_neighbors=nb_neighbors,
+            features_subset=features_subset,
+        )
+
+        self.conv_add3 = DynEdgeConv(
+            torch.nn.Sequential(
+                torch.nn.Linear(l3 * 2, l4),
+                torch.nn.LeakyReLU(),
+                torch.nn.Linear(l4, l3),
+                torch.nn.LeakyReLU(),
+            ),
+            aggr="add",
+            nb_neighbors=nb_neighbors,
+            features_subset=features_subset,
+        )
+
+        self.conv_add4 = DynEdgeConv(
+            torch.nn.Sequential(
+                torch.nn.Linear(l3 * 2, l4),
+                torch.nn.LeakyReLU(),
+                torch.nn.Linear(l4, l3),
+                torch.nn.LeakyReLU(),
+            ),
+            aggr="add",
+            nb_neighbors=nb_neighbors,
+            features_subset=features_subset,
+        )
+
+        # Post-processing operations
+        self.nn1 = torch.nn.Linear(l3 * 4 + l1, l4)
+        self.nn2 = torch.nn.Linear(l4, l5)
+        self.nn3 = torch.nn.Linear(4 * l5 + 5, l6)
+        self.lrelu = torch.nn.LeakyReLU()
+
+    def forward(self, data: Data) -> Tensor:
+        """Model forward pass.
+
+        Args:
+            data (Data): Graph of input features.
+
+        Returns:
+            Tensor: Model output.
+        """
+        # Convenience variables
+        x, edge_index, batch = data.x, data.edge_index, data.batch
+
+        # Calculate homophily (scalar variables)
+        h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)
+
+        a, edge_index = self.conv_add1(x, edge_index, batch)
+        b, edge_index = self.conv_add2(a, edge_index, batch)
+        c, edge_index = self.conv_add3(b, edge_index, batch)
+        d, edge_index = self.conv_add4(c, edge_index, batch)
+
+        # Skip-cat
+        x = torch.cat((x, a, b, c, d), dim=1)
+
+        # Post-processing
+        x = self.nn1(x)
+        x = self.lrelu(x)
+        x = self.nn2(x)
+
+        # Aggregation across nodes
+        a, _ = scatter_max(x, batch, dim=0)
+        b, _ = scatter_min(x, batch, dim=0)
+        c = scatter_sum(x, batch, dim=0)
+        d = scatter_mean(x, batch, dim=0)
+
+        # Concatenate aggregations and scalar features
+        x = torch.cat(
+            (
+                a,
+                b,
+                c,
+                d,
+                h_t.reshape(-1, 1),
+                h_x.reshape(-1, 1),
+                h_y.reshape(-1, 1),
+                h_z.reshape(-1, 1),
+                data.n_pulses.reshape(-1, 1),
+            ),
+            dim=1,
+        )
+
+        # Read-out
+        x = self.lrelu(x)
+        x = self.nn3(x)
+
+        x = self.lrelu(x)
+
+        return x
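
Usage sketch (illustrative): after this refactoring, the architecture is configured through constructor arguments rather than hard-coded, and the exact JINST-paper variant lives in its own class. The value nb_inputs=7 below is an assumed example; in examples/train_model.py it comes from detector.nb_outputs.

    from graphnet.models.gnn import DynEdge, DynEdgeJINST

    # Configurable DynEdge with all four global pooling schemes, matching
    # the call added to examples/train_model.py above.
    gnn = DynEdge(
        nb_inputs=7,
        global_pooling_schemes=["min", "max", "mean", "sum"],
    )

    # Exact architecture from [2209.03042], kept as a separate class for
    # reproducibility.
    gnn_jinst = DynEdgeJINST(nb_inputs=7)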
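The Linear shapes that _construct_layers builds can be checked by hand. A minimal, dependency-free sketch that mirrors its size bookkeeping for the default configuration (assuming, for illustration, 7 input node features, global variables distributed to the nodes, and all four pooling schemes):

    # Global variables: h_x, h_y, h_z, h_t, log10(n_pulses), plus the
    # per-feature means, i.e. 5 + nb_inputs, as in the constructor.
    nb_inputs = 7
    nb_node_features = nb_inputs + (5 + nb_inputs)  # globals added per node

    dynedge_layer_sizes = [(128, 256), (336, 256), (336, 256), (336, 256)]

    latent = nb_node_features
    for sizes in dynedge_layer_sizes:
        dims = [latent] + list(sizes)
        for ix, (nb_in, nb_out) in enumerate(zip(dims[:-1], dims[1:])):
            if ix == 0:
                nb_in *= 2  # EdgeConv MLPs see concatenated (x_i, x_j - x_i)
            print(f"Linear({nb_in}, {nb_out})")  # first conv: Linear(38, 128)
        latent = sizes[-1]

    # Skip-concatenation feeds the post-processing MLP:
    print(sum(s[-1] for s in dynedge_layer_sizes) + nb_node_features)  # 1043

    # With four pooling schemes, the read-out MLP sees 4 * 256 = 1024 inputs.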