diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index 9539fc444..1b49cd901 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -84,7 +84,6 @@ def __init__( super().__init__() self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled) - self.aux_emb = nn.Embedding(2, seq_length // 2) self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled) if n_features < 4: @@ -93,7 +92,7 @@ def __init__( f"{n_features} features." ) elif n_features >= 6: - + self.aux_emb = nn.Embedding(2, seq_length // 2) hidden_dim = 6 * seq_length else: hidden_dim = int((n_features + 0.5) * seq_length)