diff --git a/model_pytorch.py b/model_pytorch.py index 61bc066..b4af6e9 100644 --- a/model_pytorch.py +++ b/model_pytorch.py @@ -165,6 +165,7 @@ def __init__(self, cfg, vocab=40990, n_ctx=512): def forward(self, x): x = x.view(-1, x.size(-2), x.size(-1)) e = self.embed(x) + # Add the position information to the input embeddings h = e.sum(dim=2) for block in self.h: h = block(h) diff --git a/train.py b/train.py index e5ced09..b51dcb5 100644 --- a/train.py +++ b/train.py @@ -32,6 +32,7 @@ def transform_roc(X1, X2, X3): xmb[i, 1, :l13, 0] = x13 mmb[i, 0, :l12] = 1 mmb[i, 1, :l13] = 1 + # Position information that is added to the input embeddings in the TransformerModel xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx) return xmb, mmb