Feature/transformer refactorisation #1915

Open · wants to merge 18 commits into base: master
Changes from 8 commits
107 changes: 90 additions & 17 deletions darts/models/forecasting/transformer_model.py
@@ -8,6 +8,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.models.components import glu_variants, layer_norm_variants
@@ -183,6 +184,7 @@ def __init__(
self.target_size = output_size
self.nr_params = nr_params
self.target_length = self.output_chunk_length
self.d_model = d_model

self.encoder = nn.Linear(input_size, d_model)
self.positional_encoding = _PositionalEncoding(
@@ -276,47 +278,118 @@ def __init__(
custom_decoder=custom_decoder,
)

self.decoder = nn.Linear(
d_model, self.target_length * self.target_size * self.nr_params
)
self.decoder = nn.Linear(d_model, self.target_size * self.nr_params)

def _create_transformer_inputs(self, data):
def _permute_transformer_inputs(self, data):
# '_TimeSeriesSequentialDataset' stores time series in the
# (batch_size, input_chunk_length, input_size) format. PyTorch's nn.Transformer
# module needs it in the (input_chunk_length, batch_size, input_size) format.
# Therefore, the first two dimensions need to be swapped.
src = data.permute(1, 0, 2)
tgt = src[-1:, :, :]

return src, tgt
return data.permute(1, 0, 2)
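
For illustration, a minimal standalone sketch of this permutation, using made-up shapes (batch_size=32, input_chunk_length=12, input_size=3):

import torch

data = torch.randn(32, 12, 3)     # (batch_size, input_chunk_length, input_size) from the dataset
permuted = data.permute(1, 0, 2)  # (input_chunk_length, batch_size, input_size) for nn.Transformer
print(permuted.shape)             # torch.Size([12, 32, 3])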

def forward(self, x_in: Tuple):
"""
During training (teacher forcing) x_in = tuple(past_target + past_covariates, static_covariates, future_targets)
During inference x_in = tuple(past_target + past_covariates, static_covariates)
"""
data = x_in[0]
pad_size = (0, self.input_size - self.target_size)

# the start token consists only of the target series; past covariates are substituted with zero padding
Review comment (Collaborator): Why would you not include the past covariates in the start token?
start_token = self._permute_transformer_inputs(data[:, -1:, : self.target_size])
start_token_padded = F.pad(start_token, pad_size)

if len(x_in) == 3:
src, _, tgt = x_in
src = self._permute_transformer_inputs(src)
tgt_permuted = self._permute_transformer_inputs(tgt)
tgt_padded = F.pad(tgt_permuted, pad_size)
tgt = torch.cat([start_token_padded, tgt_padded], dim=0)
return self._prediction_step(src, tgt)[:, :-1, :, :]

data, _ = x_in
# Here we create 'src' and 'tgt', the inputs for the encoder and decoder
# side of the Transformer architecture
src, tgt = self._create_transformer_inputs(data)

src = self._permute_transformer_inputs(data)
tgt = start_token_padded

predictions = []
for _ in range(self.target_length):
pred = self._prediction_step(src, tgt)[:, -1, :, :]
predictions.append(pred)
tgt = torch.cat(
[tgt, F.pad(pred.mean(dim=2).unsqueeze(dim=0), pad_size)],
dim=0,
) # take average of quantiles
return torch.stack(predictions, dim=1)
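
For reference, a rough standalone sketch of how the zero-padded start token used above is built; the shapes (batch_size=32, input_chunk_length=12, input_size=5, target_size=2) are illustrative assumptions, not values from the PR:

import torch
import torch.nn.functional as F

batch_size, input_size, target_size = 32, 5, 2
data = torch.randn(batch_size, 12, input_size)  # past target + past covariates
pad_size = (0, input_size - target_size)        # right-pad the feature dimension

# last observed target values, permuted to (1, batch_size, target_size)
start_token = data[:, -1:, :target_size].permute(1, 0, 2)
# zero-pad the covariate slots so the token matches the encoder's input width
start_token_padded = F.pad(start_token, pad_size)
print(start_token_padded.shape)                 # torch.Size([1, 32, 5])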

def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor):
target_length = tgt.shape[0]
device, tensor_type = src.device, src.dtype
# "math.sqrt(self.d_model)" is the embedding scaling factor
# see section 3.4 in 'Attention Is All You Need' by Vaswani et al. (2017)
src = self.encoder(src) * math.sqrt(self.input_size)
src = self.positional_encoding(src)
src = self.encoder(src) * math.sqrt(self.d_model)
tgt = self.encoder(tgt) * math.sqrt(self.d_model)

tgt = self.encoder(tgt) * math.sqrt(self.input_size)
src = self.positional_encoding(src)
tgt = self.positional_encoding(tgt)

x = self.transformer(src=src, tgt=tgt)
tgt_mask = torch.triu(
torch.full(
(target_length, target_length),
float("-inf"),
device=device,
dtype=tensor_type,
),
diagonal=1,
)

x = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)
out = self.decoder(x)

# Here we change the data format
# from (target_length, batch_size, target_size * nr_params)
# to (batch_size, target_length, target_size, nr_params)
predictions = out[0, :, :]
predictions = self._permute_transformer_inputs(out)
predictions = predictions.view(
-1, self.target_length, self.target_size, self.nr_params
-1, target_length, self.target_size, self.nr_params
)

return predictions
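
For reference, a small self-contained example of the causal mask constructed above (with an assumed target_length of 4); entries above the diagonal are -inf, so each decoder position can only attend to itself and earlier positions:

import torch

target_length = 4
tgt_mask = torch.triu(
    torch.full((target_length, target_length), float("-inf")),
    diagonal=1,
)
print(tgt_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])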

# Allow teacher forcing
def training_step(self, train_batch, batch_idx) -> torch.Tensor:
"""performs the training step"""
train_batch = list(train_batch)
future_targets = train_batch[-1]
train_batch.append(future_targets)
return super().training_step(train_batch, batch_idx)
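
To make the effect of this override concrete, a hedged sketch with placeholder tensors (shapes are illustrative); it assumes the base training_step keeps using the last batch element as the loss target, so appending a copy of the future targets hands them to _produce_train_output as the teacher-forced decoder input:

import torch

# placeholder batch from a PastCovariatesSequentialDataset (illustrative shapes)
past_target = torch.randn(32, 12, 2)
past_covariates = torch.randn(32, 12, 3)
static_covariates = None
future_target = torch.randn(32, 6, 2)

train_batch = [past_target, past_covariates, static_covariates, future_target]
train_batch.append(train_batch[-1])  # extra copy of the future targets for teacher forcing
# train_batch is now (past_target, past_covariates, static_covariates, future_target, future_target)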

def _produce_train_output(self, input_batch: Tuple):
"""
Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for
training.

Parameters:
input_batch
``(past_target, past_covariates, static_covariates, future_target)`` during training

``(past_target, past_covariates, static_covariates)`` during validation (not teacher forced)
"""

past_target, past_covariates, static_covariates = input_batch[:3]
# Currently all our PastCovariates models require past target and covariates concatenated
inpt = [
torch.cat([past_target, past_covariates], dim=2)
if past_covariates is not None
else past_target,
static_covariates,
]

# add future targets when training (teacher forcing)
if len(input_batch) == 4:
inpt.append(input_batch[-1])
return self(inpt)
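
A small illustrative sketch of the concatenation performed above, with assumed shapes (batch_size=32, input_chunk_length=12, 2 target components, 3 past covariates):

import torch

past_target = torch.randn(32, 12, 2)
past_covariates = torch.randn(32, 12, 3)

# concatenate along the feature dimension -> (32, 12, 5), i.e. the encoder's input_size
model_input = torch.cat([past_target, past_covariates], dim=2)
print(model_input.shape)  # torch.Size([32, 12, 5])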


class TransformerModel(PastCovariatesTorchModel):
def __init__(