diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index 6a61866154..d8061e24ae 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -295,26 +295,28 @@ def forward(self, x_in: Tuple): """ data = x_in[0] - start_token = self._permute_transformer_inputs(data[:, -1:, :]) + pad_size = (0, self.input_size - self.target_size) + start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size]) + start_token_padded = F.pad(start_token, pad_size) if len(x_in) == 3: src, _, tgt = x_in src = self._permute_transformer_inputs(src) tgt_permuted = self._permute_transformer_inputs(tgt) - tgt_padded = F.pad(tgt_permuted, (0, self.input_size - self.target_size)) - tgt = torch.cat([start_token, tgt_padded], dim=0) + tgt_padded = F.pad(tgt_permuted, pad_size) + tgt = torch.cat([start_token_padded, tgt_padded], dim=0) return self._prediction_step(src, tgt)[:, :-1, :, :] data, _ = x_in src = self._permute_transformer_inputs(data) - tgt = start_token + tgt = start_token_padded predictions = [] for _ in range(self.output_chunk_length): pred = self._prediction_step(src, tgt)[:, -1, :, :] predictions.append(pred) tgt = torch.cat( - [tgt, pred.mean(dim=2).unsqueeze(dim=0)], + [tgt, F.pad(pred.mean(dim=2).unsqueeze(dim=0), pad_size)], dim=0, ) # take average of quantiles return torch.stack(predictions, dim=1)