forked from graphnet-team/graphnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request graphnet-team#660 from ArturoLlorente/icemix_kaggl…
…e_solution Icemix kaggle solution
- Loading branch information
Showing
22 changed files
with
935 additions
and
42 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified
BIN
-6.78 KB
(78%)
data/tests/parquet/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.parquet
Binary file not shown.
Binary file modified
BIN
+0 Bytes
(100%)
data/tests/sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db
Binary file not shown.
Binary file modified
BIN
+4 KB
(100%)
...ie_step4_140028_000998_first_5_frames/upgrade_genie_step4_140028_000998_first_5_frames.db
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,143 @@ | ||
"""Classes for performing embedding of input data.""" | ||
import torch | ||
import torch.nn as nn | ||
from torch.functional import Tensor | ||
|
||
from pytorch_lightning import LightningModule | ||
|
||
class SinusoidalPosEmb(torch.nn.Module): | ||
"""Sinusoidal positional embedding layer. | ||
|
||
class SinusoidalPosEmb(LightningModule): | ||
"""Sinusoidal positional embeddings module. | ||
This module is from the kaggle competition 2nd place solution (see | ||
arXiv:2310.15674): It performs what is called Fourier encoding or it's used | ||
in the Attention is all you need arXiv:1706.03762. It can be seen as a soft | ||
digitization of the input data | ||
""" | ||
|
||
def __init__(self, dim: int = 16, n_freq: int = 10000) -> None: | ||
def __init__( | ||
self, | ||
dim: int = 16, | ||
n_freq: int = 10000, | ||
scaled: bool = False, | ||
): | ||
"""Construct `SinusoidalPosEmb`. | ||
Args: | ||
dim: Embedding dimension. | ||
n_freq: Number of frequencies. | ||
scaled: Whether or not to scale the output. | ||
""" | ||
super().__init__() | ||
if dim % 2 != 0: | ||
raise ValueError(f"dim has to be even. Got: {dim}") | ||
self.scale = ( | ||
nn.Parameter(torch.ones(1) * dim**-0.5) if scaled else 1.0 | ||
) | ||
self.dim = dim | ||
self.n_freq = n_freq | ||
self.n_freq = torch.Tensor([n_freq]) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
"""Apply learnable forward pass to the layer.""" | ||
def forward(self, x: Tensor) -> Tensor: | ||
"""Forward pass.""" | ||
device = x.device | ||
half_dim = self.dim // 2 | ||
emb = torch.log(torch.tensor(self.n_freq, device=device)) / half_dim | ||
half_dim = self.dim / 2 | ||
emb = torch.log(self.n_freq.to(device=device)) / half_dim | ||
emb = torch.exp(torch.arange(half_dim, device=device) * (-emb)) | ||
emb = x[..., None] * emb[None, ...] | ||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1) | ||
return emb | ||
emb = x.unsqueeze(-1) * emb.unsqueeze(0) | ||
emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1) | ||
return emb * self.scale | ||
|
||
|
||
class FourierEncoder(LightningModule): | ||
"""Fourier encoder module. | ||
This module incorporates sinusoidal positional embeddings and auxiliary | ||
embeddings to process input sequences and produce meaningful | ||
representations. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
seq_length: int = 128, | ||
output_dim: int = 384, | ||
scaled: bool = False, | ||
): | ||
"""Construct `FourierEncoder`. | ||
Args: | ||
seq_length: Dimensionality of the base sinusoidal positional | ||
embeddings. | ||
output_dim: Output dimensionality of the final projection. | ||
scaled: Whether or not to scale the embeddings. | ||
""" | ||
super().__init__() | ||
self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled) | ||
self.aux_emb = nn.Embedding(2, seq_length // 2) | ||
self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled) | ||
self.projection = nn.Sequential( | ||
nn.Linear(6 * seq_length, 6 * seq_length), | ||
nn.LayerNorm(6 * seq_length), | ||
nn.GELU(), | ||
nn.Linear(6 * seq_length, output_dim), | ||
) | ||
|
||
def forward( | ||
self, | ||
x: Tensor, | ||
seq_length: Tensor, | ||
) -> Tensor: | ||
"""Forward pass.""" | ||
length = torch.log10(seq_length.to(dtype=x.dtype)) | ||
x = torch.cat( | ||
[ | ||
self.sin_emb(4096 * x[:, :, :3]).flatten(-2), # pos | ||
self.sin_emb(1024 * x[:, :, 4]), # charge | ||
self.sin_emb(4096 * x[:, :, 3]), # time | ||
self.aux_emb(x[:, :, 5].long()), # auxiliary | ||
self.sin_emb2(length) | ||
.unsqueeze(1) | ||
.expand(-1, max(seq_length), -1), | ||
], | ||
-1, | ||
) | ||
x = self.projection(x) | ||
return x | ||
|
||
|
||
class SpacetimeEncoder(LightningModule): | ||
"""Spacetime encoder module.""" | ||
|
||
def __init__( | ||
self, | ||
seq_length: int = 32, | ||
): | ||
"""Construct `SpacetimeEncoder`. | ||
This module calculates space-time interval between each pair of events | ||
and generates sinusoidal positional embeddings to be added to input | ||
sequences. | ||
Args: | ||
seq_length: Dimensionality of the sinusoidal positional embeddings. | ||
""" | ||
super().__init__() | ||
self.sin_emb = SinusoidalPosEmb(dim=seq_length) | ||
self.projection = nn.Linear(seq_length, seq_length) | ||
|
||
def forward( | ||
self, | ||
x: Tensor, | ||
# Lmax: Optional[int] = None, | ||
) -> Tensor: | ||
"""Forward pass.""" | ||
pos = x[:, :, :3] | ||
time = x[:, :, 3] | ||
spacetime_interval = (pos[:, :, None] - pos[:, None, :]).pow(2).sum( | ||
-1 | ||
) - ((time[:, :, None] - time[:, None, :]) * (3e4 / 500 * 3e-1)).pow(2) | ||
four_distance = torch.sign(spacetime_interval) * torch.sqrt( | ||
torch.abs(spacetime_interval) | ||
) | ||
sin_emb = self.sin_emb(1024 * four_distance.clip(-4, 4)) | ||
rel_attn = self.projection(sin_emb) | ||
return rel_attn |
Oops, something went wrong.