Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
JanFidor committed Jul 23, 2023
1 parent 30ccd78 commit fd13830
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions darts/models/forecasting/transformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,15 +289,16 @@ def _permute_transformer_inputs(self, data):

def forward(self, x_in: Tuple):
"""
When teacher forcing, x_in = (past_target + past_covariates, static_covariates, future_targets)
When inference, x_in = (past_target + past_covariates, static_covariates)
During training (teacher forcing) x_in = tuple(past_target + past_covariates, static_covariates, future_targets)
During inference x_in = tuple(past_target + past_covariates, static_covariates)
"""
data = x_in[0]

pad_size = (0, self.input_size - self.target_size)

# start token consists only of target series, past covariates are substituted with 0 padding
start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size])
start_token_padded = F.pad(start_token, pad_size)

if len(x_in) == 3:
src, _, tgt = x_in
src = self._permute_transformer_inputs(src)
Expand Down Expand Up @@ -340,7 +341,7 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor):
# Here we change the data format
# from (1, batch_size, output_chunk_length * output_size)
# to (batch_size, output_chunk_length, output_size, nr_params)
predictions = out.permute(1, 0, 2)
predictions = self._permute_transformer_inputs(out)
predictions = predictions.view(batch_size, -1, self.target_size, self.nr_params)

return predictions
Expand All @@ -357,13 +358,11 @@ def _produce_train_output(self, input_batch: Tuple):
Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for
training.
Parameters: if len(inp
# print([x.shape if x is not None else x for x in train_batch], "TRAIN")ut_batch) != 4:
a = 1
----------
a = 1
Parameters:
input_batch
``(past_target, past_covariates, static_covariates)``
``(past_target, past_covariates, static_covariates, future_target)`` during training
``(past_target, past_covariates, static_covariates)`` during validation (not teacher forced)
"""

past_target, past_covariates, static_covariates = input_batch[:3]
Expand All @@ -375,7 +374,7 @@ def _produce_train_output(self, input_batch: Tuple):
static_covariates,
]

# add future targets when teacher forcing
# add future targets when training (teacher forcing)
if len(input_batch) == 4:
inpt.append(input_batch[-1])
return self(inpt)
Expand Down

0 comments on commit fd13830

Please sign in to comment.