Skip to content

Commit

Permalink
add 0 padding to targets, to compensate for missing past covariates
Browse files Browse the repository at this point in the history
  • Loading branch information
JanFidor committed Jul 23, 2023
1 parent 492c3e5 commit 30ccd78
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions darts/models/forecasting/transformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,26 +295,28 @@ def forward(self, x_in: Tuple):
"""
data = x_in[0]

start_token = self._permute_transformer_inputs(data[:, -1:, :])
pad_size = (0, self.input_size - self.target_size)
start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size])
start_token_padded = F.pad(start_token, pad_size)
if len(x_in) == 3:
src, _, tgt = x_in
src = self._permute_transformer_inputs(src)
tgt_permuted = self._permute_transformer_inputs(tgt)
tgt_padded = F.pad(tgt_permuted, (0, self.input_size - self.target_size))
tgt = torch.cat([start_token, tgt_padded], dim=0)
tgt_padded = F.pad(tgt_permuted, pad_size)
tgt = torch.cat([start_token_padded, tgt_padded], dim=0)
return self._prediction_step(src, tgt)[:, :-1, :, :]

data, _ = x_in

src = self._permute_transformer_inputs(data)
tgt = start_token
tgt = start_token_padded

predictions = []
for _ in range(self.output_chunk_length):
pred = self._prediction_step(src, tgt)[:, -1, :, :]
predictions.append(pred)
tgt = torch.cat(
[tgt, pred.mean(dim=2).unsqueeze(dim=0)],
[tgt, F.pad(pred.mean(dim=2).unsqueeze(dim=0), pad_size)],
dim=0,
) # take average of quantiles
return torch.stack(predictions, dim=1)
Expand Down

0 comments on commit 30ccd78

Please sign in to comment.