diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index 005ab3a8b..b648c2632 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -120,10 +120,11 @@ def forward( seq_length: Tensor, ) -> Tensor: """Forward pass.""" - if max(self.mapping)+1 > x.shape[2]: + mapping_max = max(i for i in self.mapping if i is not None)+1 + if mapping_max > x.shape[2]: raise IndexError(f"Fourier mapping does not fit given data." f"Feature space of data is too small (size {x.shape[2]})," - f"given fourier mapping requires at least {max(self.mapping) + 1}.") + f"given fourier mapping requires at least {mapping_max}.") length = torch.log10(seq_length.to(dtype=x.dtype)) embeddings = [self.sin_emb(4096 * x[:, :, self.mapping[:3]]).flatten(-2)] # Position