diff --git a/src/utils/tokenizer.py b/src/utils/tokenizer.py index 90af3ad..6357410 100644 --- a/src/utils/tokenizer.py +++ b/src/utils/tokenizer.py @@ -101,8 +101,8 @@ def forward(self, x, mask=None): x = self.conv_layers(x) x = x.transpose(1, 3).squeeze(1) if mask is not None: - mask = self.forward_mask(mask).unsqueeze(-1).float() - x = x * mask + mask = self.forward_mask(mask) + x = x * mask.unsqueeze(-1).float() return x, mask @staticmethod