From 30ccd785fcce5ab5aba9d5d735d5c6c26005c780 Mon Sep 17 00:00:00 2001
From: Jan Fidor
Date: Sun, 23 Jul 2023 15:12:36 +0200
Subject: [PATCH] add 0 padding to targets, to compensate for missing past covariates

---
 darts/models/forecasting/transformer_model.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py
index 6a61866154..d8061e24ae 100644
--- a/darts/models/forecasting/transformer_model.py
+++ b/darts/models/forecasting/transformer_model.py
@@ -295,26 +295,28 @@ def forward(self, x_in: Tuple):
         """
         data = x_in[0]
-        start_token = self._permute_transformer_inputs(data[:, -1:, :])
+        pad_size = (0, self.input_size - self.target_size)
+        start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size])
+        start_token_padded = F.pad(start_token, pad_size)
 
         if len(x_in) == 3:
             src, _, tgt = x_in
             src = self._permute_transformer_inputs(src)
             tgt_permuted = self._permute_transformer_inputs(tgt)
-            tgt_padded = F.pad(tgt_permuted, (0, self.input_size - self.target_size))
-            tgt = torch.cat([start_token, tgt_padded], dim=0)
+            tgt_padded = F.pad(tgt_permuted, pad_size)
+            tgt = torch.cat([start_token_padded, tgt_padded], dim=0)
             return self._prediction_step(src, tgt)[:, :-1, :, :]
 
         data, _ = x_in
         src = self._permute_transformer_inputs(data)
-        tgt = start_token
+        tgt = start_token_padded
 
         predictions = []
         for _ in range(self.output_chunk_length):
             pred = self._prediction_step(src, tgt)[:, -1, :, :]
             predictions.append(pred)
             tgt = torch.cat(
-                [tgt, pred.mean(dim=2).unsqueeze(dim=0)],
+                [tgt, F.pad(pred.mean(dim=2).unsqueeze(dim=0), pad_size)],
                 dim=0,
             )  # take average of quantiles
 
         return torch.stack(predictions, dim=1)
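
For reference, the standalone sketch below (not part of the patch) illustrates the F.pad behaviour the change relies on: with pad_size = (0, input_size - target_size), the last (feature) dimension of a target-only tensor is zero-filled on the right so it matches the feature width the encoder sees. The sizes and tensor shapes used here are illustrative assumptions, not values taken from Darts.

    import torch
    import torch.nn.functional as F

    # Illustrative sizes (assumptions, not values from Darts):
    input_size = 5   # target features + past covariate features fed to the encoder
    target_size = 2  # target features only, which is all the decoder produces

    # Same padding spec as in the patch: pad the last (feature) dimension
    # on the right with input_size - target_size zeros.
    pad_size = (0, input_size - target_size)

    # A stand-in "start token": one time step of target-only features,
    # shaped (time=1, batch=4, features=target_size) in time-first layout.
    start_token = torch.randn(1, 4, target_size)

    start_token_padded = F.pad(start_token, pad_size)
    print(start_token_padded.shape)  # torch.Size([1, 4, 5])
    # The padded covariate slots are all zeros:
    print(start_token_padded[..., target_size:].abs().sum())  # tensor(0.)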