From 9cf8ad3865fcd145ffadb8bea3836dfb22705b17 Mon Sep 17 00:00:00 2001
From: "askerosted@gmail.com" <askerosted@gmail.com>
Date: Wed, 13 Mar 2024 14:05:11 +0900
Subject: [PATCH 1/2] automatic ensemble creation from list

---
 src/graphnet/data/dataloader.py | 44 +++++++++++++++++++++++++++------
 1 file changed, 36 insertions(+), 8 deletions(-)

diff --git a/src/graphnet/data/dataloader.py b/src/graphnet/data/dataloader.py
index 1ded6fa37..8e73d4733 100644
--- a/src/graphnet/data/dataloader.py
+++ b/src/graphnet/data/dataloader.py
@@ -5,7 +5,7 @@
 import torch.utils.data
 from torch_geometric.data import Batch, Data
 
-from graphnet.data.dataset import Dataset
+from graphnet.data.dataset import Dataset, EnsembleDataset
 from graphnet.utilities.config import DatasetConfig
 
 
@@ -63,16 +63,33 @@ def from_dataset_config(
                 "`shuffle` is automatically inferred from the selection name, "
                 "and thus should not specified as an argument."
             )
-            datasets = Dataset.from_config(config)
+
+            if isinstance(config.path, list):
+                datasets: Union[Dict[str, Dataset], Dict[str, EnsembleDataset]] = {}  # type: ignore
+                dataset_col: Dict[str, list] = {}
+                for key in config.selection.keys():
+                    dataset_col[key] = []
+                save_path = config.path.copy()
+                for path in config.path:
+                    config.path = path
+                    tmp_dataset: Dict[str, Dataset] = Dataset.from_config(
+                        config
+                    )
+                    for key in config.selection.keys():
+                        dataset_col[key].append(tmp_dataset[key])
+                config.path = save_path
+                for key in config.selection.keys():
+                    datasets[key] = EnsembleDataset(dataset_col[key])
+            else:
+                datasets = Dataset.from_config(config)
             assert isinstance(datasets, dict)
             data_loaders: Dict[str, DataLoader] = {}
-            for name, dataset in datasets.items():
+            for name, dataset_item in datasets.items():
                 data_loaders[name] = cls(
-                    dataset,
+                    dataset_item,
                     shuffle=do_shuffle(name),
                     **kwargs,
                 )
-
             return data_loaders
 
         else:
@@ -80,6 +97,17 @@ def from_dataset_config(
                 "When passing a `DatasetConfig` with a single selections, you "
                 "need to specify `shuffle` as an argument."
            )
-            dataset = Dataset.from_config(config)
-            assert isinstance(dataset, Dataset)
-            return cls(dataset, **kwargs)
+            if isinstance(config.path, list):
+                dataset_list: List[Any] = []
+                save_path = config.path.copy()
+                for path in config.path:
+                    config.path = path
+                    dataset_list.append(Dataset.from_config(config))
+                ensembleset = EnsembleDataset(dataset_list)
+                config.path = save_path
+                assert isinstance(ensembleset, EnsembleDataset)
+                return cls(ensembleset, **kwargs)
+            else:
+                dataset = Dataset.from_config(config)
+                assert isinstance(dataset, Dataset)
+                return cls(dataset, **kwargs)
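
A brief, hypothetical usage sketch of the behaviour added in PATCH 1/2 (not part of the patch itself). It assumes a `DatasetConfig` YAML at the illustrative path `dataset_config.yml` whose `path` field is a list of input databases and whose `selection` is a dict of named selections such as "train" and "validation"; with a list of paths, each returned loader now wraps an `EnsembleDataset` built from all paths:

    from graphnet.data.dataloader import DataLoader
    from graphnet.utilities.config import DatasetConfig

    # Hypothetical config: `path` is a list of files, `selection` is a dict.
    config = DatasetConfig.load("dataset_config.yml")

    # One DataLoader per named selection; `shuffle` is inferred from the
    # selection name, so it must not be passed here.
    loaders = DataLoader.from_dataset_config(
        config,
        batch_size=16,
        num_workers=2,
    )
    train_loader = loaders["train"]  # assumes a selection named "train"
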
From ad172236c7a1d5e937e829ba1c5c848ea99b4141 Mon Sep 17 00:00:00 2001
From: "askerosted@gmail.com" <askerosted@gmail.com>
Date: Tue, 14 May 2024 11:05:06 +0900
Subject: [PATCH 2/2] update embedding to main

---
 src/graphnet/models/components/embedding.py | 70 +++++++++++++++------
 1 file changed, 51 insertions(+), 19 deletions(-)

diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py
index 40145ad1a..e97ca90e7 100644
--- a/src/graphnet/models/components/embedding.py
+++ b/src/graphnet/models/components/embedding.py
@@ -3,6 +3,8 @@
 import torch.nn as nn
 from torch.functional import Tensor
 
+from typing import Optional
+
 from pytorch_lightning import LightningModule
 
 
@@ -53,34 +55,60 @@ class FourierEncoder(LightningModule):
 
     This module incorporates sinusoidal positional embeddings and auxiliary
     embeddings to process input sequences and produce meaningful
-    representations.
+    representations. The module assumes that the input data is in the format
+    (x, y, z, time, charge, auxiliary), with the first four features being
+    mandatory.
     """
 
     def __init__(
         self,
         seq_length: int = 128,
+        mlp_dim: Optional[int] = None,
         output_dim: int = 384,
         scaled: bool = False,
+        n_features: int = 6,
     ):
         """Construct `FourierEncoder`.
 
         Args:
             seq_length: Dimensionality of the base sinusoidal positional
                 embeddings.
-            output_dim: Output dimensionality of the final projection.
+            mlp_dim (Optional): Size of the hidden latent space of the MLP.
+                If not given, `mlp_dim` is set automatically to a multiple of
+                `seq_length` (consistent with the 2nd-place solution),
+                depending on `n_features`.
+            output_dim: Dimension of the output (i.e. the number of columns).
             scaled: Whether or not to scale the embeddings.
+            n_features: The number of features in the input data.
         """
         super().__init__()
+
         self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled)
         self.aux_emb = nn.Embedding(2, seq_length // 2)
         self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled)
-        self.projection = nn.Sequential(
-            nn.Linear(6 * seq_length, 6 * seq_length),
-            nn.LayerNorm(6 * seq_length),
+
+        if n_features < 4:
+            raise ValueError(
+                f"At least x, y, z and time of the DOM are required. Got only "
+                f"{n_features} features."
+            )
+        elif n_features >= 6:
+            hidden_dim = 6 * seq_length
+        else:
+            hidden_dim = int((n_features + 0.5) * seq_length)
+
+        if mlp_dim is None:
+            mlp_dim = hidden_dim
+
+        self.mlp = nn.Sequential(
+            nn.Linear(hidden_dim, mlp_dim),
+            nn.LayerNorm(mlp_dim),
             nn.GELU(),
-            nn.Linear(6 * seq_length, output_dim),
+            nn.Linear(mlp_dim, output_dim),
         )
 
+        self.n_features = n_features
+
     def forward(
         self,
         x: Tensor,
@@ -88,19 +116,23 @@ def forward(
     ) -> Tensor:
         """Forward pass."""
         length = torch.log10(seq_length.to(dtype=x.dtype))
-        x = torch.cat(
-            [
-                self.sin_emb(4096 * x[:, :, :3]).flatten(-2),  # pos
-                self.sin_emb(1024 * x[:, :, 4]),  # charge
-                self.sin_emb(4096 * x[:, :, 3]),  # time
-                self.aux_emb(x[:, :, 5].long()),  # auxiliary
-                self.sin_emb2(length)
-                .unsqueeze(1)
-                .expand(-1, max(seq_length), -1),
-            ],
-            -1,
-        )
-        x = self.projection(x)
+        embeddings = [self.sin_emb(4096 * x[:, :, :3]).flatten(-2)]  # Position
+
+        if self.n_features >= 5:
+            embeddings.append(self.sin_emb(1024 * x[:, :, 4]))  # Charge
+
+        embeddings.append(self.sin_emb(4096 * x[:, :, 3]))  # Time
+
+        if self.n_features >= 6:
+            embeddings.append(self.aux_emb(x[:, :, 5].long()))  # Auxiliary
+
+        embeddings.append(
+            self.sin_emb2(length).unsqueeze(1).expand(-1, max(seq_length), -1)
+        )  # Length
+
+        x = torch.cat(embeddings, -1)
+        x = self.mlp(x)
+
         return x
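
For reference, a minimal, hypothetical usage sketch of the updated `FourierEncoder` (not part of the patch). The input tensor below is random and purely illustrative; features are assumed ordered as (x, y, z, time), i.e. `n_features=4`, so the charge and auxiliary embeddings are skipped and the hidden size becomes `int(4.5 * seq_length)`:

    import torch

    from graphnet.models.components.embedding import FourierEncoder

    # Encoder for pulses carrying only x, y, z and time.
    encoder = FourierEncoder(seq_length=128, output_dim=384, n_features=4)

    x = torch.rand(2, 100, 4)             # (batch, pulses, features), made up
    seq_length = torch.tensor([100, 80])  # number of pulses per event
    out = encoder(x, seq_length)          # -> shape (2, 100, 384)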