diff --git a/src/graphnet/models/transformer/iseecube.py b/src/graphnet/models/transformer/iseecube.py index e77aecbd4..68279d399 100644 --- a/src/graphnet/models/transformer/iseecube.py +++ b/src/graphnet/models/transformer/iseecube.py @@ -31,7 +31,7 @@ def __init__( max_rel_pos: int = 256, num_register_tokens: int = 3, scaled_emb: bool = False, - n_features: int = 6, + fourier_mapping: list = [0, 1, 2, 3, 4, 5], ): """Construct `ISeeCube`. @@ -46,7 +46,9 @@ def __init__( max_rel_pos: Maximum relative position for relative position bias. num_register_tokens: The number of register tokens. scaled_emb: Whether to scale the sinusoidal positional embeddings. - n_features: The number of features in the input data. + fourier_mapping: Mapping of the data to [x,y,z,time,charge, + auxiliary] for the FourierEncoder. Use None for missing + features. """ super().__init__(seq_length, hidden_dim) self.fourier_ext = FourierEncoder( @@ -54,7 +56,7 @@ def __init__( mlp_dim=mlp_dim, output_dim=hidden_dim, scaled=scaled_emb, - n_features=n_features, + mapping=fourier_mapping, ) self.pos_embedding = nn.Parameter( torch.empty(1, seq_length, hidden_dim).normal_(std=0.02),