diff --git a/CHANGELOG.md b/CHANGELOG.md index a0944ca835..843492d0b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** - `TimeSeries` with a `RangeIndex` starting in the negative start are now supported by `historical_forecasts`. [#1866](https://github.com/unit8co/darts/pull/1866) by [Antoine Madrona](https://github.com/madtoinou). - Added a new argument `start_format` to `historical_forecasts()`, `backtest()` and `gridsearch` that allows to use an integer `start` either as the index position or index value/label for `series` indexed with a `pd.RangeIndex`. [#1866](https://github.com/unit8co/darts/pull/1866) by [Antoine Madrona](https://github.com/madtoinou). +- Added `RINorm` (Reversible Instance Norm) as an input normalization option for all `TorchForecastingModel` except `RNNModel`. Activate it with model creation parameter `use_reversible_instance_norm`. [#1969](https://github.com/unit8co/darts/pull/1969) by [Dennis Bader](https://github.com/dennisbader). - Reduced the size of the Darts docker image `unit8/darts:latest`, and included all optional models as well as dev requirements. [#1878](https://github.com/unit8co/darts/pull/1878) by [Alex Colpitts](https://github.com/alexcolpitts96). **Fixed** @@ -60,7 +61,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Other improvements: - Improved static covariates column naming when using `StaticCovariatesTransformer` with a `sklearn.preprocessing.OneHotEncoder`. [#1863](https://github.com/unit8co/darts/pull/1863) by [Anne de Vries](https://github.com/anne-devries). - Added `MSTL` (Season-Trend decomposition using LOESS for multiple seasonalities) as a `method` option for `extract_trend_and_seasonality()`. [#1879](https://github.com/unit8co/darts/pull/1879) by [Alex Colpitts](https://github.com/alexcolpitts96). - - Added `RINorm` (Reversible Instance Norm) as a new input normalization option for `TorchForecastingModel`. So far only `TiDEModel` supports it with model creation parameter `user_reversible_instance_norm`. [#1865](https://github.com/unit8co/darts/issues/1856) by [Alex Colpitts](https://github.com/alexcolpitts96). + - Added `RINorm` (Reversible Instance Norm) as a new input normalization option for `TorchForecastingModel`. So far only `TiDEModel` supports it with model creation parameter `use_reversible_instance_norm`. [#1865](https://github.com/unit8co/darts/issues/1856) by [Alex Colpitts](https://github.com/alexcolpitts96). - Improvements to `TimeSeries.plot()`: custom axes are now properly supported with parameter `ax`. Axis is now returned for downstream tasks. [#1916](https://github.com/unit8co/darts/pull/1916) by [Dennis Bader](https://github.com/dennisbader). **Fixed** diff --git a/README.md b/README.md index 0cf949e46f..24ddb14b31 100644 --- a/README.md +++ b/README.md @@ -229,8 +229,8 @@ on bringing more models and features. | [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | [TBATS paper](https://robjhyndman.com/papers/ComplexSeasonality.pdf) | 🟩 🟥 | 🟥 🟥 🟥 | 🟩 🟥 | 🟥 | | [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | [Theta](https://robjhyndman.com/papers/Theta.pdf) & [4 Theta](https://github.com/Mcompetitions/M4-methods/blob/master/4Theta%20method.R) | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 | | [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | 🟩 🟥 | 🟥 🟥 🟥 | 🟩 🟥 | 🟥 | -| [Prophet](file:///Users/dennisbader/projects/unit8/darts/docs/build/html/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) (see [install notes](https://github.com/unit8co/darts/blob/master/INSTALL.md#enabling-support-for-facebook-prophet)) | [Prophet repo](https://github.com/facebook/prophet) | 🟩 🟥 | 🟥 🟩 🟥 | 🟩 🟥 | 🟥 | -| [FFT](file:///Users/dennisbader/projects/unit8/darts/docs/build/html/generated_api/darts.models.forecasting.fft.html#darts.models.forecasting.fft.FFT) (Fast Fourier Transform) | | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 | +| [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) (see [install notes](https://github.com/unit8co/darts/blob/master/INSTALL.md#enabling-support-for-facebook-prophet)) | [Prophet repo](https://github.com/facebook/prophet) | 🟩 🟥 | 🟥 🟩 🟥 | 🟩 🟥 | 🟥 | +| [FFT](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.fft.html#darts.models.forecasting.fft.FFT) (Fast Fourier Transform) | | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 | | [KalmanForecaster](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.kalman_forecaster.html#darts.models.forecasting.kalman_forecaster.KalmanForecaster) using the Kalman filter and N4SID for system identification | [N4SID paper](https://people.duke.edu/~hpgavin/SystemID/References/VanOverschee-Automatica-1994.pdf) | 🟩 🟩 | 🟥 🟩 🟥 | 🟩 🟥 | 🟥 | | [Croston](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.croston.html#darts.models.forecasting.croston.Croston) method | | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 | | **Regression Models**
([GlobalForecastingModel](https://unit8co.github.io/darts/userguide/covariates.html#global-forecasting-models-gfms)) | | | | | | diff --git a/darts/models/components/statsforecast_utils.py b/darts/models/components/statsforecast_utils.py index e399aecf79..05a77d79ac 100644 --- a/darts/models/components/statsforecast_utils.py +++ b/darts/models/components/statsforecast_utils.py @@ -1,8 +1,3 @@ -""" -StatsForecast utils ------------ -""" - import numpy as np # In a normal distribution, 68.27 percentage of values lie within one standard deviation of the mean diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index 3ff52a3878..db5fa7281d 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -9,7 +9,10 @@ import torch.nn as nn from darts.logging import get_logger, raise_if_not -from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule +from darts.models.forecasting.pl_forecasting_module import ( + PLPastCovariatesModule, + io_processor, +) from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel logger = get_logger(__name__) @@ -101,6 +104,7 @@ def __init__( last = feature self.fc = nn.Sequential(*feats) + @io_processor def forward(self, x_in: Tuple): x, _ = x_in # data is of size (batch_size, input_chunk_length, input_size) @@ -194,6 +198,9 @@ def __init__( to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + use_reversible_instance_norm + Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [1]_. + It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs @@ -299,6 +306,11 @@ def encode_year(idx): show_warnings whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of your forecasting use case. Default: ``False``. + + References + ---------- + .. [1] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against + Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p """ super().__init__(**self._extract_torch_model_params(**self.model_params)) diff --git a/darts/models/forecasting/dlinear.py b/darts/models/forecasting/dlinear.py index 1151ac1992..d10f0d0d46 100644 --- a/darts/models/forecasting/dlinear.py +++ b/darts/models/forecasting/dlinear.py @@ -9,7 +9,10 @@ import torch.nn as nn from darts.logging import raise_if -from darts.models.forecasting.pl_forecasting_module import PLMixedCovariatesModule +from darts.models.forecasting.pl_forecasting_module import ( + PLMixedCovariatesModule, + io_processor, +) from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel MixedCovariatesTrainTensorType = Tuple[ @@ -155,6 +158,7 @@ def _create_linear_layer(in_dim, out_dim): layer_in_dim_static_cov, layer_out_dim ) + @io_processor def forward( self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] ): @@ -295,6 +299,9 @@ def __init__( to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + use_reversible_instance_norm + Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_. + It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs @@ -405,6 +412,8 @@ def encode_year(idx): ---------- .. [1] Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2022). Are Transformers Effective for Time Series Forecasting?. arXiv preprint arXiv:2205.13504. + .. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against + Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p """ super().__init__(**self._extract_torch_model_params(**self.model_params)) diff --git a/darts/models/forecasting/nbeats.py b/darts/models/forecasting/nbeats.py index 029804bfa8..b49f3f08f9 100644 --- a/darts/models/forecasting/nbeats.py +++ b/darts/models/forecasting/nbeats.py @@ -11,7 +11,10 @@ import torch.nn as nn from darts.logging import get_logger, raise_if_not, raise_log -from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule +from darts.models.forecasting.pl_forecasting_module import ( + PLPastCovariatesModule, + io_processor, +) from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel from darts.utils.torch import MonteCarloDropout @@ -490,6 +493,7 @@ def __init__( self.stacks_list[-1].blocks[-1].backcast_linear_layer.requires_grad_(False) self.stacks_list[-1].blocks[-1].backcast_g.requires_grad_(False) + @io_processor def forward(self, x_in: Tuple): x, _ = x_in @@ -616,6 +620,9 @@ def __init__( to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + use_reversible_instance_norm + Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_. + It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs @@ -725,6 +732,8 @@ def encode_year(idx): References ---------- .. [1] https://openreview.net/forum?id=r1ecqn4YwB + .. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against + Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p """ super().__init__(**self._extract_torch_model_params(**self.model_params)) diff --git a/darts/models/forecasting/nhits.py b/darts/models/forecasting/nhits.py index d8238ec061..7429d593cd 100644 --- a/darts/models/forecasting/nhits.py +++ b/darts/models/forecasting/nhits.py @@ -11,7 +11,10 @@ import torch.nn.functional as F from darts.logging import get_logger, raise_if_not -from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule +from darts.models.forecasting.pl_forecasting_module import ( + PLPastCovariatesModule, + io_processor, +) from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel from darts.utils.torch import MonteCarloDropout @@ -417,6 +420,7 @@ def __init__( # on this params (the last block backcast is not part of the final output of the net). self.stacks_list[-1].blocks[-1].backcast_linear_layer.requires_grad_(False) + @io_processor def forward(self, x_in: Tuple): x, _ = x_in @@ -552,6 +556,9 @@ def __init__( to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + use_reversible_instance_norm + Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_. + It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs @@ -662,6 +669,8 @@ def encode_year(idx): ---------- .. [1] C. Challu et al. "N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting", https://arxiv.org/abs/2201.12886 + .. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against + Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p """ super().__init__(**self._extract_torch_model_params(**self.model_params)) diff --git a/darts/models/forecasting/nlinear.py b/darts/models/forecasting/nlinear.py index 35fa456e3d..88f6665391 100644 --- a/darts/models/forecasting/nlinear.py +++ b/darts/models/forecasting/nlinear.py @@ -9,7 +9,10 @@ import torch.nn as nn from darts.logging import raise_if -from darts.models.forecasting.pl_forecasting_module import PLMixedCovariatesModule +from darts.models.forecasting.pl_forecasting_module import ( + PLMixedCovariatesModule, + io_processor, +) from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel @@ -106,6 +109,7 @@ def _create_linear_layer(in_dim, out_dim): layer_in_dim_static_cov, layer_out_dim ) + @io_processor def forward( self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] ): @@ -246,6 +250,9 @@ def __init__( to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + use_reversible_instance_norm + Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_. + It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs @@ -354,6 +361,8 @@ def encode_year(idx): ---------- .. [1] Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2022). Are Transformers Effective for Time Series Forecasting?. arXiv preprint arXiv:2205.13504. + .. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against + Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p """ super().__init__(**self._extract_torch_model_params(**self.model_params)) diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py index f4706b6dbd..821e35745c 100644 --- a/darts/models/forecasting/pl_forecasting_module.py +++ b/darts/models/forecasting/pl_forecasting_module.py @@ -12,6 +12,7 @@ from joblib import Parallel, delayed from darts.logging import get_logger, raise_if, raise_log +from darts.models.components.layer_norm_variants import RINorm from darts.timeseries import TimeSeries from darts.utils.likelihood_models import Likelihood from darts.utils.timeseries_generation import _build_forecast_series @@ -24,6 +25,43 @@ pl_160_or_above = int(tokens[0]) > 1 or int(tokens[0]) == 1 and int(tokens[1]) >= 6 +def io_processor(forward): + """Applies some input / output processing to PLForecastingModule.forward. + Note that this wrapper must be added to each of PLForecastinModule's subclasses forward methods. + Here is an example how to add the decorator: + + ```python + @io_processor + def forward(self, *args, **kwargs) + pass + ``` + + Applies + ------- + Reversible Instance Normalization + normalizes batch input target features, and inverse transform the forward output back to the original scale + """ + + def forward_wrapper(self, *args, **kwargs): + if not self.use_reversible_instance_norm: + return forward(self, *args, **kwargs) + + # x is input batch tuple which by definition has the past features in the first element starting with the + # first n target features + x: Tuple = args[0][0] + # apply reversible instance normalization + x[:, :, : self.n_targets] = self.rin(x[:, :, : self.n_targets]) + # run the forward pass + out = forward(self, *((x, *args[0][1:]), *args[1:]), **kwargs) + # inverse transform target output back to original scale; by definition the first output + if isinstance(out, tuple): + return self.rin.inverse(out[0]), *out[1:] + else: + return self.rin.inverse(out) + + return forward_wrapper + + class PLForecastingModule(pl.LightningModule, ABC): @abstractmethod def __init__( @@ -40,6 +78,7 @@ def __init__( optimizer_kwargs: Optional[Dict] = None, lr_scheduler_cls: Optional[torch.optim.lr_scheduler._LRScheduler] = None, lr_scheduler_kwargs: Optional[Dict] = None, + use_reversible_instance_norm: bool = False, ) -> None: """ PyTorch Lightning-based Forecasting Module. @@ -84,6 +123,14 @@ def __init__( to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + use_reversible_instance_norm + Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [1]_. + It is only applied to the features of the target series and not the covariates. + + References + ---------- + .. [1] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against + Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p """ super().__init__() @@ -107,6 +154,9 @@ def __init__( # saved in checkpoint to be able to instantiate a model without calling fit_from_dataset self.train_sample_shape = train_sample_shape + self.n_targets = ( + train_sample_shape[0][1] if train_sample_shape is not None else 1 + ) # persist optimiser and LR scheduler parameters self.optimizer_cls = optimizer_cls @@ -121,6 +171,13 @@ def __init__( self.train_metrics = torch_metrics.clone(prefix="train_") self.val_metrics = torch_metrics.clone(prefix="val_") + # reversible instance norm + self.use_reversible_instance_norm = use_reversible_instance_norm + if use_reversible_instance_norm: + self.rin = RINorm(input_dim=self.n_targets) + else: + self.rin = None + # initialize prediction parameters self.pred_n: Optional[int] = None self.pred_num_samples: Optional[int] = None diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index cbb205a38e..7db46a1345 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -9,7 +9,10 @@ import torch.nn as nn from darts.logging import get_logger, raise_if_not -from darts.models.forecasting.pl_forecasting_module import PLDualCovariatesModule +from darts.models.forecasting.pl_forecasting_module import ( + PLDualCovariatesModule, + io_processor, +) from darts.models.forecasting.torch_forecasting_model import DualCovariatesTorchModel from darts.timeseries import TimeSeries from darts.utils.data import DualCovariatesShiftedDataset, TrainingDataset @@ -86,6 +89,7 @@ def __init__( # The RNN module needs a linear layer V that transforms hidden states into outputs, individually self.V = nn.Linear(hidden_dim, target_size * nr_params) + @io_processor def forward( self, x_in: Tuple, h: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -390,12 +394,15 @@ def encode_year(idx): # create copy of model parameters model_kwargs = {key: val for key, val in self.model_params.items()} - if model_kwargs.get("output_chunk_length") is not None: - logger.warning( - "ignoring user defined `output_chunk_length`. RNNModel uses a fixed `output_chunk_length=1`." - ) - - model_kwargs["output_chunk_length"] = 1 + for kwarg, default_value in zip( + ["output_chunk_length", "use_reversible_instance_norm"], [1, False] + ): + if model_kwargs.get(kwarg) is not None: + logger.warning( + f"ignoring user defined `{kwarg}`. RNNModel uses a fixed " + f"`{kwarg}={default_value}`." + ) + model_kwargs[kwarg] = default_value super().__init__(**self._extract_torch_model_params(**model_kwargs)) diff --git a/darts/models/forecasting/tcn_model.py b/darts/models/forecasting/tcn_model.py index ed643cb0fe..a75e04eac8 100644 --- a/darts/models/forecasting/tcn_model.py +++ b/darts/models/forecasting/tcn_model.py @@ -11,7 +11,10 @@ import torch.nn.functional as F from darts.logging import get_logger, raise_if_not -from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule +from darts.models.forecasting.pl_forecasting_module import ( + PLPastCovariatesModule, + io_processor, +) from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel from darts.timeseries import TimeSeries from darts.utils.data import PastCovariatesShiftedDataset @@ -231,6 +234,7 @@ def __init__( self.res_blocks_list.append(res_block) self.res_blocks = nn.ModuleList(self.res_blocks_list) + @io_processor def forward(self, x_in: Tuple): x, _ = x_in # data is of size (batch_size, input_chunk_length, input_size) @@ -317,6 +321,9 @@ def __init__( to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + use_reversible_instance_norm + Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_. + It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs @@ -427,6 +434,8 @@ def encode_year(idx): References ---------- .. [1] https://arxiv.org/abs/1803.01271 + .. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against + Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p """ raise_if_not( diff --git a/darts/models/forecasting/tft_model.py b/darts/models/forecasting/tft_model.py index f1cf5f99e8..646670dbb8 100644 --- a/darts/models/forecasting/tft_model.py +++ b/darts/models/forecasting/tft_model.py @@ -15,7 +15,10 @@ from darts.logging import get_logger, raise_if, raise_if_not, raise_log from darts.models.components import glu_variants, layer_norm_variants from darts.models.components.glu_variants import GLU_FFN -from darts.models.forecasting.pl_forecasting_module import PLMixedCovariatesModule +from darts.models.forecasting.pl_forecasting_module import ( + PLMixedCovariatesModule, + io_processor, +) from darts.models.forecasting.tft_submodels import ( _GateAddNorm, _GatedResidualNetwork, @@ -448,6 +451,7 @@ def get_attention_mask_future( ) return mask + @io_processor def forward( self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] ) -> torch.Tensor: @@ -764,6 +768,9 @@ def __init__( to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + use_reversible_instance_norm + Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [3]_. + It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs @@ -873,7 +880,9 @@ def encode_year(idx): References ---------- .. [1] https://arxiv.org/pdf/1912.09363.pdf - ..[2] Shazeer, Noam, "GLU Variants Improve Transformer", 2020. arVix https://arxiv.org/abs/2002.05202. + .. [2] Shazeer, Noam, "GLU Variants Improve Transformer", 2020. arVix https://arxiv.org/abs/2002.05202. + .. [3] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against + Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p """ model_kwargs = {key: val for key, val in self.model_params.items()} if likelihood is None and loss_fn is None: diff --git a/darts/models/forecasting/tide_model.py b/darts/models/forecasting/tide_model.py index cd5e9e8038..7ae8f9ff1f 100644 --- a/darts/models/forecasting/tide_model.py +++ b/darts/models/forecasting/tide_model.py @@ -9,8 +9,10 @@ import torch.nn as nn from darts.logging import get_logger -from darts.models.components.layer_norm_variants import RINorm -from darts.models.forecasting.pl_forecasting_module import PLMixedCovariatesModule +from darts.models.forecasting.pl_forecasting_module import ( + PLMixedCovariatesModule, + io_processor, +) from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel MixedCovariatesTrainTensorType = Tuple[ @@ -77,7 +79,6 @@ def __init__( temporal_decoder_hidden: int, temporal_width: int, use_layer_norm: bool, - use_reversible_instance_norm: bool, dropout: float, **kwargs, ): @@ -109,8 +110,6 @@ def __init__( The width of the future covariate embedding space. use_layer_norm Whether to use layer normalization in the Residual Blocks. - use_reversible_instance_norm - Whether to use reversible instance normalization. dropout Dropout probability **kwargs @@ -141,7 +140,6 @@ def __init__( self.hidden_size = hidden_size self.temporal_decoder_hidden = temporal_decoder_hidden self.use_layer_norm = use_layer_norm - self.use_reversible_instance_norm = use_reversible_instance_norm self.dropout = dropout self.temporal_width = temporal_width @@ -225,11 +223,7 @@ def __init__( self.input_chunk_length, self.output_chunk_length * self.nr_params ) - if self.use_reversible_instance_norm: - self.rin = RINorm(input_dim=output_dim) - else: - self.rin = None - + @io_processor def forward( self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] ) -> torch.Tensor: @@ -250,9 +244,6 @@ def forward( # x_static_covariates has shape (batch_size, static_cov_dim) x, x_future_covariates, x_static_covariates = x_in - if self.use_reversible_instance_norm: - x[:, :, : self.output_dim] = self.rin(x[:, :, : self.output_dim]) - x_lookback = x[:, :, : self.output_dim] # future covariates need to be extracted from x and stacked with historical future covariates @@ -328,10 +319,6 @@ def forward( ) # skip.view(temporal_decoded.shape) y = y.view(-1, self.output_chunk_length, self.output_dim, self.nr_params) - - if self.use_reversible_instance_norm: - y = self.rin.inverse(y) - return y @@ -347,7 +334,6 @@ def __init__( temporal_width: int = 4, temporal_decoder_hidden: int = 32, use_layer_norm: bool = False, - use_reversible_instance_norm: bool = False, dropout: float = 0.1, use_static_covariates: bool = True, **kwargs, @@ -389,9 +375,6 @@ def __init__( The width of the layers in the temporal decoder. use_layer_norm Whether to use layer normalization in the residual blocks. - use_reversible_instance_norm - Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_. - It is only applied to the features of the target series and not the covariates. dropout The dropout probability to be used in fully connected layers. This is compatible with Monte Carlo dropout at inference time for model uncertainty estimation (enabled with ``mc_dropout=True`` at @@ -421,6 +404,9 @@ def __init__( to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + use_reversible_instance_norm + Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_. + It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs @@ -549,7 +535,6 @@ def encode_year(idx): self._considers_static_covariates = use_static_covariates self.use_layer_norm = use_layer_norm - self.use_reversible_instance_norm = use_reversible_instance_norm self.dropout = dropout def _create_model( @@ -601,7 +586,6 @@ def _create_model( temporal_width=self.temporal_width, temporal_decoder_hidden=self.temporal_decoder_hidden, use_layer_norm=self.use_layer_norm, - use_reversible_instance_norm=self.use_reversible_instance_norm, dropout=self.dropout, **self.pl_module_params, ) diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index 22e37f8c60..ec445a4a32 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -16,7 +16,10 @@ CustomFeedForwardDecoderLayer, CustomFeedForwardEncoderLayer, ) -from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule +from darts.models.forecasting.pl_forecasting_module import ( + PLPastCovariatesModule, + io_processor, +) from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel logger = get_logger(__name__) @@ -290,6 +293,7 @@ def _create_transformer_inputs(self, data): return src, tgt + @io_processor def forward(self, x_in: Tuple): data, _ = x_in # Here we create 'src' and 'tgt', the inputs for the encoder and decoder @@ -405,6 +409,9 @@ def __init__( to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + use_reversible_instance_norm + Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [3]_. + It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs @@ -516,7 +523,9 @@ def encode_year(idx): .. [1] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin, "Attention Is All You Need", 2017. In Advances in Neural Information Processing Systems, pages 6000-6010. https://arxiv.org/abs/1706.03762. - ..[2] Shazeer, Noam, "GLU Variants Improve Transformer", 2020. arVix https://arxiv.org/abs/2002.05202. + .. [2] Shazeer, Noam, "GLU Variants Improve Transformer", 2020. arVix https://arxiv.org/abs/2002.05202. + .. [3] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against + Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p Notes ----- diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index c9e12be408..400899b76a 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -26,13 +26,44 @@ MetricCollection, ) - from darts.models import DLinearModel, RNNModel + from darts.models import ( + BlockRNNModel, + DLinearModel, + NBEATSModel, + NHiTSModel, + NLinearModel, + RNNModel, + TCNModel, + TFTModel, + TiDEModel, + TransformerModel, + ) + from darts.models.components.layer_norm_variants import RINorm from darts.utils.likelihood_models import ( GaussianLikelihood, LaplaceLikelihood, Likelihood, ) + kwargs = { + "input_chunk_length": 10, + "output_chunk_length": 1, + "n_epochs": 1, + "pl_trainer_kwargs": {"fast_dev_run": True, **tfm_kwargs["pl_trainer_kwargs"]}, + } + models = [ + (BlockRNNModel, kwargs), + (DLinearModel, kwargs), + (NBEATSModel, kwargs), + (NHiTSModel, kwargs), + (NLinearModel, kwargs), + (RNNModel, {"training_length": 2, **kwargs}), + (TCNModel, kwargs), + (TFTModel, {"add_relative_index": 2, **kwargs}), + (TiDEModel, kwargs), + (TransformerModel, kwargs), + ] + TORCH_AVAILABLE = True except ImportError: logger.warning("Torch not available. RNN tests will be skipped.") @@ -1384,6 +1415,38 @@ def test_encoders(self, tmpdir_fn): _ = model.predict(n=n, future_covariates=fc) _ = model.predict(n=n, past_covariates=pc, future_covariates=fc) + @pytest.mark.parametrize("model_config", models) + def test_rin(self, model_config): + model_cls, kwargs = model_config + model_no_rin = model_cls(use_reversible_instance_norm=False, **kwargs) + model_rin = model_cls(use_reversible_instance_norm=True, **kwargs) + + # univariate no RIN + model_no_rin.fit(self.series) + assert not model_no_rin.model.use_reversible_instance_norm + assert model_no_rin.model.rin is None + + # univariate with RIN + model_rin.fit(self.series) + if issubclass(model_cls, RNNModel): + # RNNModel will not use RIN + assert not model_rin.model.use_reversible_instance_norm + assert model_rin.model.rin is None + return + else: + assert model_rin.model.use_reversible_instance_norm + assert isinstance(model_rin.model.rin, RINorm) + assert model_rin.model.rin.input_dim == self.series.n_components + # multivariate with RIN + model_rin_mv = model_rin.untrained_model() + model_rin_mv.fit(self.multivariate_series) + assert model_rin_mv.model.use_reversible_instance_norm + assert isinstance(model_rin_mv.model.rin, RINorm) + assert ( + model_rin_mv.model.rin.input_dim + == self.multivariate_series.n_components + ) + def helper_equality_encoders( self, first_encoders: Dict[str, Any], second_encoders: Dict[str, Any] ): diff --git a/darts/tests/models/forecasting/test_transformer_model.py b/darts/tests/models/forecasting/test_transformer_model.py index 3b4958411c..8ece59c09d 100644 --- a/darts/tests/models/forecasting/test_transformer_model.py +++ b/darts/tests/models/forecasting/test_transformer_model.py @@ -38,6 +38,7 @@ class TestTransformerModel: input_size=1, input_chunk_length=1, output_chunk_length=1, + train_sample_shape=((1, 1),), output_size=1, nr_params=1, d_model=512,