Skip to content

Commit

Permalink
Fixed small bug in embedding.py
Browse files · Browse the repository at this point in the history
  • Loading branch information
mobra7 committed Oct 28, 2024
1 parent c54d66e commit a587e18
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/graphnet/models/components/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,11 @@ def forward(
seq_length: Tensor,
) -> Tensor:
"""Forward pass."""
if max(self.mapping)+1 > x.shape[2]:
mapping_max = max(i for i in self.mapping if i is not None)+1
if mapping_max > x.shape[2]:
raise IndexError(f"Fourier mapping does not fit given data."
f"Feature space of data is too small (size {x.shape[2]}),"
f"given fourier mapping requires at least {max(self.mapping) + 1}.")
f"given fourier mapping requires at least {mapping_max}.")

length = torch.log10(seq_length.to(dtype=x.dtype))
embeddings = [self.sin_emb(4096 * x[:, :, self.mapping[:3]]).flatten(-2)] # Position
Expand Down

0 comments on commit a587e18

Please sign in to comment.