diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index d8061e24ae..d17082cb08 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -289,15 +289,16 @@ def _permute_transformer_inputs(self, data): def forward(self, x_in: Tuple): """ - When teacher forcing, x_in = (past_target + past_covariates, static_covariates, future_targets) - When inference, x_in = (past_target + past_covariates, static_covariates) - + During training (teacher forcing) x_in = tuple(past_target + past_covariates, static_covariates, future_targets) + During inference x_in = tuple(past_target + past_covariates, static_covariates) """ data = x_in[0] - pad_size = (0, self.input_size - self.target_size) + + # start token consists only of target series, past covariates are substituted with 0 padding start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size]) start_token_padded = F.pad(start_token, pad_size) + if len(x_in) == 3: src, _, tgt = x_in src = self._permute_transformer_inputs(src) @@ -340,7 +341,7 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor): # Here we change the data format # from (1, batch_size, output_chunk_length * output_size) # to (batch_size, output_chunk_length, output_size, nr_params) - predictions = out.permute(1, 0, 2) + predictions = self._permute_transformer_inputs(out) predictions = predictions.view(batch_size, -1, self.target_size, self.nr_params) return predictions @@ -357,13 +358,11 @@ def _produce_train_output(self, input_batch: Tuple): Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for training. - Parameters: if len(inp - # print([x.shape if x is not None else x for x in train_batch], "TRAIN")ut_batch) != 4: - a = 1 - ---------- - a = 1 + Parameters: input_batch - ``(past_target, past_covariates, static_covariates)`` + ``(past_target, past_covariates, static_covariates, future_target)`` during training + + ``(past_target, past_covariates, static_covariates)`` during validation (not teacher forced) """ past_target, past_covariates, static_covariates = input_batch[:3] @@ -375,7 +374,7 @@ def _produce_train_output(self, input_batch: Tuple): static_covariates, ] - # add future targets when teacher forcing + # add future targets when training (teacher forcing) if len(input_batch) == 4: inpt.append(input_batch[-1]) return self(inpt)