Commit
Merge pull request #660 from ArturoLlorente/icemix_kaggle_solution
Icemix kaggle solution
ArturoLlorente authored Mar 8, 2024
2 parents e572efc + 8483de4 commit 68ab6cc
Showing 22 changed files with 935 additions and 42 deletions.
Binary file modified data/geometry_tables/icecube/icecube86.parquet
Binary file modified data/geometry_tables/icecube/icecube_upgrade.parquet
Binary file added data/ice_properties/ice_transparency.parquet
132 changes: 121 additions & 11 deletions src/graphnet/models/components/embedding.py
@@ -1,33 +1,143 @@
"""Classes for performing embedding of input data."""
import torch
import torch.nn as nn
from torch.functional import Tensor

from pytorch_lightning import LightningModule

class SinusoidalPosEmb(torch.nn.Module):
"""Sinusoidal positional embedding layer.

class SinusoidalPosEmb(LightningModule):
"""Sinusoidal positional embeddings module.
This module is from the kaggle competition 2nd place solution (see
arXiv:2310.15674): It performs what is called Fourier encoding or it's used
in the Attention is all you need arXiv:1706.03762. It can be seen as a soft
digitization of the input data
"""

def __init__(self, dim: int = 16, n_freq: int = 10000) -> None:
def __init__(
self,
dim: int = 16,
n_freq: int = 10000,
scaled: bool = False,
):
"""Construct `SinusoidalPosEmb`.
Args:
dim: Embedding dimension.
n_freq: Number of frequencies.
scaled: Whether or not to scale the output.
"""
super().__init__()
if dim % 2 != 0:
raise ValueError(f"dim has to be even. Got: {dim}")
self.scale = (
nn.Parameter(torch.ones(1) * dim**-0.5) if scaled else 1.0
)
self.dim = dim
self.n_freq = n_freq
self.n_freq = torch.Tensor([n_freq])

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply learnable forward pass to the layer."""
def forward(self, x: Tensor) -> Tensor:
"""Forward pass."""
device = x.device
half_dim = self.dim // 2
emb = torch.log(torch.tensor(self.n_freq, device=device)) / half_dim
half_dim = self.dim / 2
emb = torch.log(self.n_freq.to(device=device)) / half_dim
emb = torch.exp(torch.arange(half_dim, device=device) * (-emb))
emb = x[..., None] * emb[None, ...]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
return emb * self.scale


class FourierEncoder(LightningModule):
"""Fourier encoder module.
This module incorporates sinusoidal positional embeddings and auxiliary
embeddings to process input sequences and produce meaningful
representations.
"""

def __init__(
self,
seq_length: int = 128,
output_dim: int = 384,
scaled: bool = False,
):
"""Construct `FourierEncoder`.
Args:
seq_length: Dimensionality of the base sinusoidal positional
embeddings.
output_dim: Output dimensionality of the final projection.
scaled: Whether or not to scale the embeddings.
"""
super().__init__()
self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled)
self.aux_emb = nn.Embedding(2, seq_length // 2)
self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled)
self.projection = nn.Sequential(
nn.Linear(6 * seq_length, 6 * seq_length),
nn.LayerNorm(6 * seq_length),
nn.GELU(),
nn.Linear(6 * seq_length, output_dim),
)

def forward(
self,
x: Tensor,
seq_length: Tensor,
) -> Tensor:
"""Forward pass."""
length = torch.log10(seq_length.to(dtype=x.dtype))
x = torch.cat(
[
self.sin_emb(4096 * x[:, :, :3]).flatten(-2), # pos
self.sin_emb(1024 * x[:, :, 4]), # charge
self.sin_emb(4096 * x[:, :, 3]), # time
self.aux_emb(x[:, :, 5].long()), # auxiliary
self.sin_emb2(length)
.unsqueeze(1)
.expand(-1, max(seq_length), -1),
],
-1,
)
x = self.projection(x)
return x


class SpacetimeEncoder(LightningModule):
"""Spacetime encoder module."""

def __init__(
self,
seq_length: int = 32,
):
"""Construct `SpacetimeEncoder`.
This module calculates space-time interval between each pair of events
and generates sinusoidal positional embeddings to be added to input
sequences.
Args:
seq_length: Dimensionality of the sinusoidal positional embeddings.
"""
super().__init__()
self.sin_emb = SinusoidalPosEmb(dim=seq_length)
self.projection = nn.Linear(seq_length, seq_length)

def forward(
self,
x: Tensor,
# Lmax: Optional[int] = None,
) -> Tensor:
"""Forward pass."""
pos = x[:, :, :3]
time = x[:, :, 3]
spacetime_interval = (pos[:, :, None] - pos[:, None, :]).pow(2).sum(
-1
) - ((time[:, :, None] - time[:, None, :]) * (3e4 / 500 * 3e-1)).pow(2)
four_distance = torch.sign(spacetime_interval) * torch.sqrt(
torch.abs(spacetime_interval)
)
sin_emb = self.sin_emb(1024 * four_distance.clip(-4, 4))
rel_attn = self.projection(sin_emb)
return rel_attn
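
As a quick illustration (not part of this commit), the new FourierEncoder could be exercised on dummy pulse data roughly as follows. The feature ordering (x, y, z, time, charge, auxiliary) is inferred from the indexing in FourierEncoder.forward; the batch size, pulse count, and random values below are assumptions made purely for the sketch.

import torch

from graphnet.models.components.embedding import FourierEncoder

batch_size, n_pulses = 2, 10

# Dummy pulse features: three position columns, time, charge, and a 0/1
# auxiliary flag (the auxiliary column is fed to an nn.Embedding lookup,
# so it must hold integer-like values).
x = torch.rand(batch_size, n_pulses, 6)
x[:, :, 5] = torch.randint(0, 2, (batch_size, n_pulses)).float()

# Number of pulses per event; here both events are fully populated.
seq_length = torch.tensor([n_pulses, n_pulses])

encoder = FourierEncoder(seq_length=128, output_dim=384)
out = encoder(x, seq_length)
print(out.shape)  # expected: torch.Size([2, 10, 384])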
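
Likewise, a minimal sketch (again not part of the commit) of SpacetimeEncoder, whose output, judging by the rel_attn name, appears intended as a pairwise relative-attention bias. Only the (x, y, z, time) columns are used by its forward; the shapes below are assumptions for illustration.

import torch

from graphnet.models.components.embedding import SpacetimeEncoder

batch_size, n_pulses = 2, 10

# Dummy (x, y, z, time) features for each pulse.
x = torch.rand(batch_size, n_pulses, 4)

encoder = SpacetimeEncoder(seq_length=32)
rel_attn = encoder(x)  # one bias vector per pulse pair
print(rel_attn.shape)  # expected: torch.Size([2, 10, 10, 32])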