From 5d79ce7749a72df8b8979e7a5d5e9bc2366fc69d Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Fri, 26 Jan 2024 14:07:58 +0900 Subject: [PATCH] embedding small change + docstring --- src/graphnet/models/components/embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index de29a6fe5..da39716a5 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -25,7 +25,7 @@ def __init__(self, dim: int = 16, m: int = 10000) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply learnable forward pass to the layer.""" device = x.device - half_dim = self.dim + half_dim = self.dim // 2 emb = torch.log(torch.tensor(self.m, device=device)) / half_dim emb = torch.exp(torch.arange(half_dim, device=device) * (-emb)) emb = x[..., None] * emb[None, ...]