From 9db4e104f7dace508f95607c8f4728e68e2f0b8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Tue, 4 Oct 2022 13:32:14 +0200 Subject: [PATCH 1/6] Use coarsening in Model rather than in DynEdge --- src/graphnet/models/gnn/dynedge.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index d381fe70e..e20b5c7ed 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -12,7 +12,6 @@ from torch_geometric.nn import EdgeConv from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum from graphnet.components.layers import DynEdgeConv -from graphnet.models.coarsening import Coarsening from graphnet.models.gnn.gnn import GNN from graphnet.models.utils import calculate_xyzt_homophily @@ -23,7 +22,6 @@ def __init__( self, nb_inputs: int, layer_size_scale: Optional[int] = 4, - node_pooling: Coarsening = None, ): """DynEdge model. @@ -32,10 +30,7 @@ nb_outputs (int): Number of output features. layer_size_scale (int, optional): Integer that scales the size of hidden layers. Defaults to 4. - node_pooling: A Coarsening module that pools the nodes before they are processed by the model. Defaults to None (no pooling). """ - # Node Pooling via Coarsening Module - self._coarsening = node_pooling # Architecture configuration c = layer_size_scale l1, l2, l3, l4, l5, l6 = ( From e4c5b3d4f02a50888e85bb2e9835033dd7c81ad9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3%B8gaard?= Date: Tue, 4 Oct 2022 14:52:01 +0200 Subject: [PATCH 2/6] Make features_subset sequence *or* list --- src/graphnet/components/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphnet/components/layers.py b/src/graphnet/components/layers.py index b0ce4477e..80c1f8c3f 100644 --- a/src/graphnet/components/layers.py +++ b/src/graphnet/components/layers.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence +from typing import Callable, List, Optional, Sequence, Union from torch.functional import Tensor @@ -13,7 +13,7 @@ def __init__( nn: Callable, aggr: str = "max", nb_neighbors: int = 8, - features_subset: Optional[Sequence] = None, + features_subset: Optional[Union[Sequence[int], List[int]]] = None, **kwargs, ): # Check(s) From 062970767652eb0f9b672218236836b83ab352bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3%B8gaard?= Date: Tue, 4 Oct 2022 14:56:12 +0200 Subject: [PATCH 3/6] Refactor DynEdge model --- src/graphnet/models/gnn/__init__.py | 3 +- src/graphnet/models/gnn/dynedge.py | 608 ++++++++--------------- src/graphnet/models/gnn/dynedge_jinst.py | 156 ++++++ 3 files changed, 375 insertions(+), 392 deletions(-) create mode 100644 src/graphnet/models/gnn/dynedge_jinst.py diff --git a/src/graphnet/models/gnn/__init__.py b/src/graphnet/models/gnn/__init__.py index c16d9c990..9c2cef2a0 100644 --- a/src/graphnet/models/gnn/__init__.py +++ b/src/graphnet/models/gnn/__init__.py @@ -1,2 +1,3 @@ -from .dynedge import DynEdge, DynEdge_V2, DynEdge_V3 from .convnet import ConvNet +from .dynedge import DynEdge +from .dynedge_jinst import DynEdgeJINST diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index e20b5c7ed..5c2de399f 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -5,9 +5,11 @@ Author: Rasmus Oersoe Email: ###@###.### """ -from typing import List, Optional, Union +from multiprocessing import pool +from pickle import GLOBAL +from typing 
import List, Optional, Tuple, Union import torch -from torch import Tensor +from torch import Tensor, LongTensor from torch_geometric.data import Data from torch_geometric.nn import EdgeConv from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum @@ -16,12 +18,26 @@ from graphnet.models.gnn.gnn import GNN from graphnet.models.utils import calculate_xyzt_homophily +GLOBAL_POOLINGS = { + "min": scatter_min, + "max": scatter_max, + "sum": scatter_sum, + "mean": scatter_mean, +} + class DynEdge(GNN): def __init__( self, - nb_inputs: int, - layer_size_scale: Optional[int] = 4, + nb_inputs, + *, + nb_neighbours: Optional[int] = 8, + features_subset: Optional[List[int]] = None, + dynedge_layer_sizes: Optional[List[Tuple[int]]] = None, + post_processing_layer_sizes: Optional[List[int]] = None, + readout_layer_sizes: Optional[List[int]] = None, + global_pooling_schemes: Optional[Union[str, List[str]]] = None, + add_global_variables_after_pooling: bool = False, ): """DynEdge model. @@ -31,234 +47,172 @@ def __init__( layer_size_scale (int, optional): Integer that scales the size of hidden layers. Defaults to 4. """ - # Architecture configuration - c = layer_size_scale - l1, l2, l3, l4, l5, l6 = ( - nb_inputs, - c * 16 * 2, - c * 32 * 2, - c * 42 * 2, - c * 32 * 2, - c * 16 * 2, - ) - - # Base class constructor - super().__init__(nb_inputs, l6) - - # Graph convolutional operations - features_subset = slice(0, 3) - nb_neighbors = 8 - - self.conv_add1 = DynEdgeConv( - torch.nn.Sequential( - torch.nn.Linear(l1 * 2, l2), - torch.nn.LeakyReLU(), - torch.nn.Linear(l2, l3), - torch.nn.LeakyReLU(), - ), - aggr="add", - nb_neighbors=nb_neighbors, - features_subset=features_subset, + # Latent feature subset for computing nearest neighbours in DynEdge. 
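+        # The default, `slice(0, 3)`, selects the first three node features (typically the x, y, z coordinates), so neighbours are found by spatial proximity.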
+        if features_subset is None: +            features_subset = slice(0, 3) + +        # DynEdge layer sizes +        if dynedge_layer_sizes is None: +            dynedge_layer_sizes = [ +                ( +                    128, +                    256, +                ), +                ( +                    256, +                    336, +                ), +                ( +                    256, +                    336, +                ), +                ( +                    256, +                    336, +                ), +            ] + +        assert isinstance(dynedge_layer_sizes, list) +        assert len(dynedge_layer_sizes) +        assert all(isinstance(sizes, tuple) for sizes in dynedge_layer_sizes) +        assert all(len(sizes) > 0 for sizes in dynedge_layer_sizes) +        assert all( +            all(size > 0 for size in sizes) for sizes in dynedge_layer_sizes ) -        self.conv_add2 = DynEdgeConv( -            torch.nn.Sequential( -                torch.nn.Linear(l3 * 2, l4), -                torch.nn.LeakyReLU(), -                torch.nn.Linear(l4, l3), -                torch.nn.LeakyReLU(), -            ), -            aggr="add", -            nb_neighbors=nb_neighbors, -            features_subset=features_subset, +        ) +        self._dynedge_layer_sizes = dynedge_layer_sizes + +        # Post-processing layer sizes +        if post_processing_layer_sizes is None: +            post_processing_layer_sizes = [ +                336, +                256, +            ] + +        assert isinstance(post_processing_layer_sizes, list) +        assert len(post_processing_layer_sizes) +        assert all(size > 0 for size in post_processing_layer_sizes) + +        self._post_processing_layer_sizes = post_processing_layer_sizes + +        # Read-out layer sizes +        if readout_layer_sizes is None: +            readout_layer_sizes = [ +                128, +            ] + +        assert isinstance(readout_layer_sizes, list) +        assert len(readout_layer_sizes) +        assert all(size > 0 for size in readout_layer_sizes) + +        self._readout_layer_sizes = readout_layer_sizes + +        # Global pooling scheme(s) +        if isinstance(global_pooling_schemes, str): +            global_pooling_schemes = [global_pooling_schemes] + +        if isinstance(global_pooling_schemes, list): +            for pooling_scheme in global_pooling_schemes: +                assert ( +                    pooling_scheme in GLOBAL_POOLINGS +                ), f"Global pooling scheme {pooling_scheme} not supported." +        else: +            assert global_pooling_schemes is None + +        self._global_pooling_schemes = global_pooling_schemes + +        if add_global_variables_after_pooling: +            assert self._global_pooling_schemes, ( +                "No global pooling schemes were requested, so cannot add global" +                " variables after pooling." 
+ ) + self._add_global_variables_after_pooling = ( + add_global_variables_after_pooling ) - self.conv_add3 = DynEdgeConv( - torch.nn.Sequential( - torch.nn.Linear(l3 * 2, l4), - torch.nn.LeakyReLU(), - torch.nn.Linear(l4, l3), - torch.nn.LeakyReLU(), - ), - aggr="add", - nb_neighbors=nb_neighbors, - features_subset=features_subset, - ) - - self.conv_add4 = DynEdgeConv( - torch.nn.Sequential( - torch.nn.Linear(l3 * 2, l4), - torch.nn.LeakyReLU(), - torch.nn.Linear(l4, l3), - torch.nn.LeakyReLU(), - ), - aggr="add", - nb_neighbors=nb_neighbors, - features_subset=features_subset, - ) + # Base class constructor + super().__init__(nb_inputs, self._layer_sizes[-1]) + + # Common layer(s) + self._activation = torch.nn.LeakyReLU() + nb_global_variables = 5 + nb_inputs + + # Convolutional operations + nb_input_features = nb_inputs + if not self._add_global_variables_after_pooling: + nb_input_features += nb_global_variables + + self._conv_layers = [] + for sizes in self._dynedge_layer_sizes: + layers = [] + for nb_in, nb_out in zip([nb_input_features] + sizes[:-1], sizes): + layers.append(torch.nn.Linear(nb_in, nb_out)) + layers.append(self._activation) + + conv_layer = DynEdgeConv( + torch.nn.Sequential(*layers), + aggr="add", + nb_neighbors=nb_neighbours, + features_subset=features_subset, + ) + self._conv_layers.append(conv_layer) # Post-processing operations - self.nn1 = torch.nn.Linear(l3 * 4 + l1, l4) - self.nn2 = torch.nn.Linear(l4, l5) - self.nn3 = torch.nn.Linear(4 * l5 + 5, l6) - self.lrelu = torch.nn.LeakyReLU() - - def forward(self, data: Data) -> Tensor: - """Model forward pass. - - Args: - data (Data): Graph of input features. - - Returns: - Tensor: Model output. - """ - if self._coarsening is not None: - data = self._coarsening(data) - # Convenience variables - x, edge_index, batch = data.x, data.edge_index, data.batch - - # Calculate homophily (scalar variables) - h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch) - - a, edge_index = self.conv_add1(x, edge_index, batch) - b, edge_index = self.conv_add2(a, edge_index, batch) - c, edge_index = self.conv_add3(b, edge_index, batch) - d, edge_index = self.conv_add4(c, edge_index, batch) - - # Skip-cat - x = torch.cat((x, a, b, c, d), dim=1) - - # Post-processing - x = self.nn1(x) - x = self.lrelu(x) - x = self.nn2(x) - - # Aggregation across nodes - a, _ = scatter_max(x, batch, dim=0) - b, _ = scatter_min(x, batch, dim=0) - c = scatter_sum(x, batch, dim=0) - d = scatter_mean(x, batch, dim=0) - - # Concatenate aggregations and scalar features - x = torch.cat( - ( - a, - b, - c, - d, - h_t.reshape(-1, 1), - h_x.reshape(-1, 1), - h_y.reshape(-1, 1), - h_z.reshape(-1, 1), - data.n_pulses.reshape(-1, 1), - ), - dim=1, - ) - - # Read-out - x = self.lrelu(x) - x = self.nn3(x) - - x = self.lrelu(x) - - return x - - -class DynEdge_V2(GNN): - def __init__(self, nb_inputs, layer_size_scale=4): - """DynEdge model. - - Args: - nb_inputs (int): Number of input features. - nb_outputs (int): Number of output features. - layer_size_scale (int, optional): Integer that scales the size of - hidden layers. Defaults to 4. 
- """ - - # Architecture configuration - c = layer_size_scale - l1, l2, l3, l4, l5, l6 = ( - nb_inputs * 2 + 5, - c * 16 * 2, - c * 32 * 2, - c * 42 * 2, - c * 32 * 2, - c * 16 * 2, - ) - - # Base class constructor - super().__init__(nb_inputs, l6) - - # Graph convolutional operations - features_subset = slice(0, 3) - nb_neighbors = 8 - - self.conv_add1 = DynEdgeConv( - torch.nn.Sequential( - torch.nn.Linear(l1 * 2, l2), - torch.nn.LeakyReLU(), - torch.nn.Linear(l2, l3), - torch.nn.LeakyReLU(), - ), - aggr="add", - nb_neighbors=nb_neighbors, - features_subset=features_subset, + nb_latent_features = ( + nb_out * len(self._dynedge_layer_sizes) + nb_input_features ) - self.conv_add2 = DynEdgeConv( - torch.nn.Sequential( - torch.nn.Linear(l3 * 2, l4), - torch.nn.LeakyReLU(), - torch.nn.Linear(l4, l3), - torch.nn.LeakyReLU(), - ), - aggr="add", - nb_neighbors=nb_neighbors, - features_subset=features_subset, - ) + post_processing_layers = [] + for nb_in, nb_out in zip( + [nb_latent_features] + self._readout_layer_sizes[:-1], + self._readout_layer_sizes, + ): + post_processing_layers.append(torch.nn.Linear(nb_in, nb_out)) + post_processing_layers.append(self._activation) - self.conv_add3 = DynEdgeConv( - torch.nn.Sequential( - torch.nn.Linear(l3 * 2, l4), - torch.nn.LeakyReLU(), - torch.nn.Linear(l4, l3), - torch.nn.LeakyReLU(), - ), - aggr="add", - nb_neighbors=nb_neighbors, - features_subset=features_subset, - ) + self._post_processing = torch.nn.Sequential(*post_processing_layers) - self.conv_add4 = DynEdgeConv( - torch.nn.Sequential( - torch.nn.Linear(l3 * 2, l4), - torch.nn.LeakyReLU(), - torch.nn.Linear(l4, l3), - torch.nn.LeakyReLU(), - ), - aggr="add", - nb_neighbors=nb_neighbors, - features_subset=features_subset, + # Read-out operations + nb_poolings = ( + len(self._global_pooling_schemes) if global_pooling_schemes else 1 ) - - # Post-processing operations - self.nn1 = torch.nn.Linear(l3 * 4 + l1, l4) - self.nn2 = torch.nn.Linear(l4, l5) - self.nn3 = torch.nn.Linear(3 * l5 + 0, l6) # 4*l5 + 5 - self.lrelu = torch.nn.LeakyReLU() - - def forward(self, data: Data) -> Tensor: - """Model forward pass. - - Args: - data (Data): Graph of input features. - - Returns: - Tensor: Model output. - """ - - # Convenience variables - x, edge_index, batch = data.x, data.edge_index, data.batch + nb_latent_features = nb_out * nb_poolings + if self._add_global_variables_after_pooling: + nb_latent_features += nb_global_variables + + readout_layers = [] + for nb_in, nb_out in zip( + [nb_latent_features] + self._readout_layer_sizes[:-1], + self._readout_layer_sizes, + ): + readout_layers.append(torch.nn.Linear(nb_in, nb_out)) + readout_layers.append(self._activation) + + self._readout = torch.nn.Sequential(*readout_layers) + + def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor: + """Perform global pooling.""" + assert self._global_pooling_schemes + pooled = [] + for pooling_scheme in self._global_pooling_schemes: + pooling_fn = GLOBAL_POOLINGS[pooling_scheme] + pooled_x = pooling_fn(x, batch=batch, dim=0) + if isinstance(pooled_x, tuple) and len(pooled_x) == 2: + # `scatter_{min,max}`, which return also an argument, vs. 
+ # `scatter_{mean,sum}` + pooled_x, _ = pooled_x + pooled.append(pooled_x) + + return torch.cat(pooled, dim=1) + + def _calculate_global_variables( + self, + x: Tensor, + edge_index: LongTensor, + batch: LongTensor, + *additional_attributes: Tensor, + ) -> Tensor: + """Calculate global variables""" # Calculate homophily (scalar variables) h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch) @@ -267,153 +221,26 @@ def forward(self, data: Data) -> Tensor: global_means = scatter_mean(x, batch, dim=0) # Add global variables - distribute = ( - batch.unsqueeze(dim=1) == torch.unique(batch).unsqueeze(dim=0) - ).type(torch.float) - global_variables = torch.cat( [ global_means, - torch.log10(data.n_pulses).unsqueeze(dim=1), h_x, h_y, h_z, h_t, - ], - dim=1, - ) - - global_variables_distributed = torch.sum( - distribute.unsqueeze(dim=2) * global_variables.unsqueeze(dim=0), - dim=1, - ) - - x = torch.cat((x, global_variables_distributed), dim=1) - - a, edge_index = self.conv_add1(x, edge_index, batch) - b, edge_index = self.conv_add2(a, edge_index, batch) - c, edge_index = self.conv_add3(b, edge_index, batch) - d, edge_index = self.conv_add4(c, edge_index, batch) - - # Skip-cat - x = torch.cat((x, a, b, c, d), dim=1) - - # Post-processing - x = self.nn1(x) - x = self.lrelu(x) - x = self.nn2(x) - - # Aggregation across nodes - a, _ = scatter_max(x, batch, dim=0) - b, _ = scatter_min(x, batch, dim=0) - c = scatter_mean(x, batch, dim=0) - - # Concatenate aggregations and scalar features - x = torch.cat( - ( - a, - b, - c, - ), + ] + + [attr.unsqueeze(dim=1) for attr in additional_attributes], dim=1, ) - # Read-out - x = self.lrelu(x) - x = self.nn3(x) - - x = self.lrelu(x) - - return x - - -class DynEdge_V3(GNN): - def __init__(self, nb_inputs, layer_size_scale=4): - """DynEdge model. - Args: - nb_inputs (int): Number of input features. - nb_outputs (int): Number of output features. - layer_size_scale (int, optional): Integer that scales the size of - hidden layers. Defaults to 4. 
- """ - - # Architecture configuration - c = layer_size_scale - l1, l2, l3, l4, l5, l6 = ( - nb_inputs * 2 + 5, - c * 16 * 2, - c * 32 * 2, - c * 42 * 2, - c * 32 * 2, - c * 16 * 2, - ) - - # Base class constructor - super().__init__(nb_inputs, l6) - - # Graph convolutional operations - features_subset = slice(0, 3) - nb_neighbors = 8 - - self.conv_add1 = DynEdgeConv( - torch.nn.Sequential( - torch.nn.Linear(l1 * 2, l2), - torch.nn.LeakyReLU(), - torch.nn.Linear(l2, l3), - torch.nn.LeakyReLU(), - ), - aggr="add", - nb_neighbors=nb_neighbors, - features_subset=features_subset, - ) - - self.conv_add2 = DynEdgeConv( - torch.nn.Sequential( - torch.nn.Linear(l3 * 2, l4), - torch.nn.LeakyReLU(), - torch.nn.Linear(l4, l3), - torch.nn.LeakyReLU(), - ), - aggr="add", - nb_neighbors=nb_neighbors, - features_subset=features_subset, - ) - - self.conv_add3 = DynEdgeConv( - torch.nn.Sequential( - torch.nn.Linear(l3 * 2, l4), - torch.nn.LeakyReLU(), - torch.nn.Linear(l4, l3), - torch.nn.LeakyReLU(), - ), - aggr="add", - nb_neighbors=nb_neighbors, - features_subset=features_subset, - ) - - self.conv_add4 = DynEdgeConv( - torch.nn.Sequential( - torch.nn.Linear(l3 * 2, l4), - torch.nn.LeakyReLU(), - torch.nn.Linear(l4, l3), - torch.nn.LeakyReLU(), - ), - aggr="add", - nb_neighbors=nb_neighbors, - features_subset=features_subset, - ) - - # Post-processing operations - self.nn1 = torch.nn.Linear(l3 * 4 + l1, l4) - self.nn2 = torch.nn.Linear(l4, l5) - # self.nn3 = torch.nn.Linear(3*l5 + 0,l6) # 4*l5 + 5 - self.nn3 = torch.nn.Linear(l5, l6) - self.lrelu = torch.nn.LeakyReLU() + return global_variables def forward(self, data: Data) -> Tensor: """Model forward pass. + Args: data (Data): Graph of input features. + Returns: Tensor: Model output. """ @@ -421,53 +248,52 @@ def forward(self, data: Data) -> Tensor: # Convenience variables x, edge_index, batch = data.x, data.edge_index, data.batch - # Calculate homophily (scalar variables) - h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch) - - # Calculate mean features - global_means = scatter_mean(x, batch, dim=0) - - # Add global variables - distribute = ( - batch.unsqueeze(dim=1) == torch.unique(batch).unsqueeze(dim=0) - ).type(torch.float) - - global_variables = torch.cat( - [ - global_means, - torch.log10(data.n_pulses).unsqueeze(dim=1), - h_x, - h_y, - h_z, - h_t, - ], - dim=1, + global_variables = self._calculate_global_variables( + x, + edge_index, + batch, + torch.log10(data.n_pulses), ) - global_variables_distributed = torch.sum( - distribute.unsqueeze(dim=2) * global_variables.unsqueeze(dim=0), - dim=1, - ) + # Distribute global variables out to each node + if not self._add_global_variables_after_pooling: + distribute = ( + batch.unsqueeze(dim=1) == torch.unique(batch).unsqueeze(dim=0) + ).type(torch.float) + + global_variables_distributed = torch.sum( + distribute.unsqueeze(dim=2) + * global_variables.unsqueeze(dim=0), + dim=1, + ) - x = torch.cat((x, global_variables_distributed), dim=1) + x = torch.cat((x, global_variables_distributed), dim=1) - a, edge_index = self.conv_add1(x, edge_index, batch) - b, edge_index = self.conv_add2(a, edge_index, batch) - c, edge_index = self.conv_add3(b, edge_index, batch) - d, edge_index = self.conv_add4(c, edge_index, batch) + # DynEdge-convolutions + skip_connections = [x] + for conv_layer in self._conv_layers: + x, edge_index = conv_layer(x, edge_index, batch) + skip_connections.append(x) # Skip-cat - x = torch.cat((x, a, b, c, d), dim=1) + x = torch.cat(skip_connections, dim=1) # Post-processing - x 
= self.nn1(x) -        x = self.lrelu(x) -        x = self.nn2(x) +        x = self._post_processing(x) + +        # (Optional) Global pooling +        if self._global_pooling_schemes: +            x = self._global_pooling(x, batch=batch) +            if self._add_global_variables_after_pooling: +                x = torch.cat( +                    [ +                        x, +                        global_variables, +                    ], +                    dim=1, +                ) # Read-out -        x = self.lrelu(x) -        x = self.nn3(x) - -        x = self.lrelu(x) +        x = self._readout(x) return x diff --git a/src/graphnet/models/gnn/dynedge_jinst.py b/src/graphnet/models/gnn/dynedge_jinst.py new file mode 100644 index 000000000..60897267d --- /dev/null +++ b/src/graphnet/models/gnn/dynedge_jinst.py @@ -0,0 +1,156 @@ +"""Implementation of the exact DynEdge architecture used in [2209.03042]. + +Author: Rasmus Oersoe +""" +from typing import Optional + +import torch +from torch import Tensor +from torch_geometric.data import Data +from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum + +from graphnet.components.layers import DynEdgeConv +from graphnet.models.gnn.gnn import GNN +from graphnet.models.utils import calculate_xyzt_homophily + + +class DynEdgeJINST(GNN): +    def __init__( +        self, +        nb_inputs: int, +        layer_size_scale: Optional[int] = 4, +    ): +        """DynEdge model. +        Args: +            nb_inputs (int): Number of input features. +            nb_outputs (int): Number of output features. +            layer_size_scale (int, optional): Integer that scales the size of +                hidden layers. Defaults to 4. +        """ +        # Architecture configuration +        c = layer_size_scale +        l1, l2, l3, l4, l5, l6 = ( +            nb_inputs, +            c * 16 * 2, +            c * 32 * 2, +            c * 42 * 2, +            c * 32 * 2, +            c * 16 * 2, +        ) + +        # Base class constructor +        super().__init__(nb_inputs, l6) + +        # Graph convolutional operations +        features_subset = slice(0, 3) +        nb_neighbors = 8 + +        self.conv_add1 = DynEdgeConv( +            torch.nn.Sequential( +                torch.nn.Linear(l1 * 2, l2), +                torch.nn.LeakyReLU(), +                torch.nn.Linear(l2, l3), +                torch.nn.LeakyReLU(), +            ), +            aggr="add", +            nb_neighbors=nb_neighbors, +            features_subset=features_subset, +        ) + +        self.conv_add2 = DynEdgeConv( +            torch.nn.Sequential( +                torch.nn.Linear(l3 * 2, l4), +                torch.nn.LeakyReLU(), +                torch.nn.Linear(l4, l3), +                torch.nn.LeakyReLU(), +            ), +            aggr="add", +            nb_neighbors=nb_neighbors, +            features_subset=features_subset, +        ) + +        self.conv_add3 = DynEdgeConv( +            torch.nn.Sequential( +                torch.nn.Linear(l3 * 2, l4), +                torch.nn.LeakyReLU(), +                torch.nn.Linear(l4, l3), +                torch.nn.LeakyReLU(), +            ), +            aggr="add", +            nb_neighbors=nb_neighbors, +            features_subset=features_subset, +        ) + +        self.conv_add4 = DynEdgeConv( +            torch.nn.Sequential( +                torch.nn.Linear(l3 * 2, l4), +                torch.nn.LeakyReLU(), +                torch.nn.Linear(l4, l3), +                torch.nn.LeakyReLU(), +            ), +            aggr="add", +            nb_neighbors=nb_neighbors, +            features_subset=features_subset, +        ) + +        # Post-processing operations +        self.nn1 = torch.nn.Linear(l3 * 4 + l1, l4) +        self.nn2 = torch.nn.Linear(l4, l5) +        self.nn3 = torch.nn.Linear(4 * l5 + 5, l6) +        self.lrelu = torch.nn.LeakyReLU() + +    def forward(self, data: Data) -> Tensor: +        """Model forward pass. +        Args: +            data (Data): Graph of input features. +        Returns: +            Tensor: Model output. 
+ """ + # Convenience variables + x, edge_index, batch = data.x, data.edge_index, data.batch + + # Calculate homophily (scalar variables) + h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch) + + a, edge_index = self.conv_add1(x, edge_index, batch) + b, edge_index = self.conv_add2(a, edge_index, batch) + c, edge_index = self.conv_add3(b, edge_index, batch) + d, edge_index = self.conv_add4(c, edge_index, batch) + + # Skip-cat + x = torch.cat((x, a, b, c, d), dim=1) + + # Post-processing + x = self.nn1(x) + x = self.lrelu(x) + x = self.nn2(x) + + # Aggregation across nodes + a, _ = scatter_max(x, batch, dim=0) + b, _ = scatter_min(x, batch, dim=0) + c = scatter_sum(x, batch, dim=0) + d = scatter_mean(x, batch, dim=0) + + # Concatenate aggregations and scalar features + x = torch.cat( + ( + a, + b, + c, + d, + h_t.reshape(-1, 1), + h_x.reshape(-1, 1), + h_y.reshape(-1, 1), + h_z.reshape(-1, 1), + data.n_pulses.reshape(-1, 1), + ), + dim=1, + ) + + # Read-out + x = self.lrelu(x) + x = self.nn3(x) + + x = self.lrelu(x) + + return x From e704b4e50ae8c24348002cb4fac0692e025577bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Tue, 4 Oct 2022 16:39:00 +0200 Subject: [PATCH 4/6] Update DynEdge --- examples/train_model.py | 1 + src/graphnet/models/gnn/dynedge.py | 65 +++++++++++++++++++----------- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/examples/train_model.py b/examples/train_model.py index f2a0f417e..991231c0b 100644 --- a/examples/train_model.py +++ b/examples/train_model.py @@ -93,6 +93,7 @@ def main(): ) gnn = DynEdge( nb_inputs=detector.nb_outputs, + global_pooling_schemes=["min", "max", "mean", "sum"], ) task = EnergyReconstruction( hidden_size=gnn.nb_outputs, diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index 5c2de399f..95340df56 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -59,16 +59,16 @@ def __init__( 256, ), ( - 256, 336, + 256, ), ( - 256, 336, + 256, ), ( - 256, 336, + 256, ), ] @@ -131,42 +131,59 @@ def __init__( ) # Base class constructor - super().__init__(nb_inputs, self._layer_sizes[-1]) + super().__init__(nb_inputs, self._readout_layer_sizes[-1]) - # Common layer(s) + # Remaining member variables() self._activation = torch.nn.LeakyReLU() - nb_global_variables = 5 + nb_inputs + self._nb_inputs = nb_inputs + self._nb_global_variables = 5 + nb_inputs + self._nb_neighbours = nb_neighbours + self._features_subset = features_subset + + self._construct_layers() + + def _construct_layers(self): + """Construct layers (torch.nn.Modules).""" # Convolutional operations - nb_input_features = nb_inputs + nb_input_features = self._nb_inputs if not self._add_global_variables_after_pooling: - nb_input_features += nb_global_variables + nb_input_features += self._nb_global_variables - self._conv_layers = [] + self._conv_layers = torch.nn.ModuleList() + nb_latent_features = nb_input_features for sizes in self._dynedge_layer_sizes: layers = [] - for nb_in, nb_out in zip([nb_input_features] + sizes[:-1], sizes): + layer_sizes = [nb_latent_features] + list(sizes) + for ix, (nb_in, nb_out) in enumerate( + zip(layer_sizes[:-1], layer_sizes[1:]) + ): + if ix == 0: + nb_in *= 2 layers.append(torch.nn.Linear(nb_in, nb_out)) layers.append(self._activation) conv_layer = DynEdgeConv( torch.nn.Sequential(*layers), aggr="add", - nb_neighbors=nb_neighbours, - features_subset=features_subset, + nb_neighbors=self._nb_neighbours, + 
features_subset=self._features_subset, ) self._conv_layers.append(conv_layer) +            nb_latent_features = nb_out +         # Post-processing operations nb_latent_features = ( -            nb_out * len(self._dynedge_layer_sizes) + nb_input_features +            sum(sizes[-1] for sizes in self._dynedge_layer_sizes) +            + nb_input_features ) post_processing_layers = [] -        for nb_in, nb_out in zip( -            [nb_latent_features] + self._readout_layer_sizes[:-1], -            self._readout_layer_sizes, -        ): +        layer_sizes = [nb_latent_features] + list( +            self._post_processing_layer_sizes +        ) +        for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]): post_processing_layers.append(torch.nn.Linear(nb_in, nb_out)) post_processing_layers.append(self._activation) @@ -174,17 +191,17 @@ def __init__( # Read-out operations nb_poolings = ( -            len(self._global_pooling_schemes) if global_pooling_schemes else 1 +            len(self._global_pooling_schemes) +            if self._global_pooling_schemes +            else 1 ) nb_latent_features = nb_out * nb_poolings if self._add_global_variables_after_pooling: -            nb_latent_features += nb_global_variables +            nb_latent_features += self._nb_global_variables readout_layers = [] -        for nb_in, nb_out in zip( -            [nb_latent_features] + self._readout_layer_sizes[:-1], -            self._readout_layer_sizes, -        ): +        layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes) +        for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]): readout_layers.append(torch.nn.Linear(nb_in, nb_out)) readout_layers.append(self._activation) @@ -196,7 +213,7 @@ def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor: pooled = [] for pooling_scheme in self._global_pooling_schemes: pooling_fn = GLOBAL_POOLINGS[pooling_scheme] -            pooled_x = pooling_fn(x, batch=batch, dim=0) +            pooled_x = pooling_fn(x, index=batch, dim=0) if isinstance(pooled_x, tuple) and len(pooled_x) == 2: # `scatter_{min,max}`, which return also an argument, vs. # `scatter_{mean,sum}` From 6b622741161d43655cea7e04bc2bd16e70d59076 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3%B8gaard?= Date: Thu, 6 Oct 2022 13:29:46 +0200 Subject: [PATCH 5/6] Update docstring --- src/graphnet/models/gnn/dynedge.py | 40 +++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index 95340df56..c8e63ec38 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -29,7 +29,7 @@ class DynEdge(GNN): def __init__( self, -        nb_inputs, +        nb_inputs: int, *, nb_neighbours: Optional[int] = 8, features_subset: Optional[List[int]] = None, @@ -39,13 +39,41 @@ def __init__( global_pooling_schemes: Optional[Union[str, List[str]]] = None, add_global_variables_after_pooling: bool = False, ): -        """DynEdge model. +        """DynEdge (dynamical edge convolutional) model. Args: -            nb_inputs (int): Number of input features. -            nb_outputs (int): Number of output features. -            layer_size_scale (int, optional): Integer that scales the size of -                hidden layers. Defaults to 4. +            nb_inputs (int): Number of input features on each node. +            nb_neighbours (Optional[int], optional): Number of neighbours to +                use in the k-nearest neighbour clustering, which is performed +                after each (dynamical) edge convolution. Defaults to 8. +            features_subset (Optional[List[int]], optional): The subset of +                latent features on each node that are used as metric dimensions +                when performing the k-nearest neighbours clustering. Defaults +                to slice(0,3). 
+            dynedge_layer_sizes (Optional[List[Tuple[int]]], optional): The +                layer sizes, or latent feature dimensions, used in the +                `DynEdgeConv` layers. Each entry in `dynedge_layer_sizes` +                corresponds to a single `DynEdgeConv` layer; the integers in +                the corresponding tuple correspond to the layer sizes in the +                multi-layer perceptron (MLP) that is applied within each +                `DynEdgeConv` layer. That is, a list of size-two tuples means +                that all `DynEdgeConv` layers contain a two-layer MLP. +                Defaults to [(128, 256), (336, 256), (336, 256), (336, 256)]. +            post_processing_layer_sizes (Optional[List[int]], optional): Hidden +                layer sizes in the MLP following the skip-concatenation of the +                outputs of each `DynEdgeConv` layer. Defaults to [336, 256]. +            readout_layer_sizes (Optional[List[int]], optional): Hidden layer +                sizes in the MLP following the post-processing _and_ optional +                global pooling. As these are the last layers in the model, the +                final read-out layer yields the output of the `DynEdge` +                model. Defaults to [128,]. +            global_pooling_schemes (Optional[Union[str, List[str]]], optional): +                The list of global pooling schemes to use. Options are: "min", +                "max", "mean", and "sum". Defaults to None. +            add_global_variables_after_pooling (bool, optional): Whether to add +                global variables after global pooling. The alternative is to +                add (distribute) them to the individual nodes before any +                convolutional operations. Defaults to False. """ # Latent feature subset for computing nearest neighbours in DynEdge. if features_subset is None: From 39c9148b8878879b8955a3a05851ef6cdf237943 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3%B8gaard?= Date: Thu, 6 Oct 2022 13:35:36 +0200 Subject: [PATCH 6/6] Fix import --- src/graphnet/models/gnn/dynedge_jinst.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/models/gnn/dynedge_jinst.py b/src/graphnet/models/gnn/dynedge_jinst.py index 60897267d..82a94e9b5 100644 --- a/src/graphnet/models/gnn/dynedge_jinst.py +++ b/src/graphnet/models/gnn/dynedge_jinst.py @@ -9,7 +9,7 @@ from torch_geometric.data import Data from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum -from graphnet.components.layers import DynEdgeConv +from graphnet.models.components.layers import DynEdgeConv from graphnet.models.gnn.gnn import GNN from graphnet.models.utils import calculate_xyzt_homophily
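
Usage sketch: after this patch series, the refactored `DynEdge` is configured entirely through keyword arguments, as in the `examples/train_model.py` change in patch 4/6. The snippet below is a minimal illustration, not part of the patches; the input-feature count (7) and the custom layer sizes are hypothetical values chosen for the example.

    from graphnet.models.gnn import DynEdge

    # Default architecture, aggregating node features per graph with all four
    # supported pooling schemes.
    gnn = DynEdge(
        nb_inputs=7,  # hypothetical number of input features per node
        global_pooling_schemes=["min", "max", "mean", "sum"],
    )

    # The architecture hyperparameters are now exposed explicitly, e.g. a
    # smaller variant with two DynEdgeConv blocks (each a two-layer MLP) and a
    # single pooling scheme, which may be passed as a plain string:
    gnn_small = DynEdge(
        nb_inputs=7,
        nb_neighbours=8,
        dynedge_layer_sizes=[(128, 256), (336, 256)],
        post_processing_layer_sizes=[336, 256],
        readout_layer_sizes=[128],
        global_pooling_schemes="mean",
    )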