add rin to all torch models (#1969)
* add rin to all torch models

* fix failing transformer test

* update CHANGELOG.md

* update model docs

* ignore RIN for RNNModel

* apply suggestions from PR review

* fix model links in readme

* remove statsforecast utils from generated API Reference

* Update CHANGELOG.md

* add RIN reference link to PLForecastingModule
dennisbader authored Sep 2, 2023
1 parent 52ac181 commit fecb99d
Showing 16 changed files with 233 additions and 50 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -13,6 +13,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Improved**
- `TimeSeries` with a `RangeIndex` starting in the negative start are now supported by `historical_forecasts`. [#1866](https://github.com/unit8co/darts/pull/1866) by [Antoine Madrona](https://github.com/madtoinou).
- Added a new argument `start_format` to `historical_forecasts()`, `backtest()` and `gridsearch` that allows to use an integer `start` either as the index position or index value/label for `series` indexed with a `pd.RangeIndex`. [#1866](https://github.com/unit8co/darts/pull/1866) by [Antoine Madrona](https://github.com/madtoinou).
- Added `RINorm` (Reversible Instance Norm) as an input normalization option for all `TorchForecastingModel` except `RNNModel`. Activate it with model creation parameter `use_reversible_instance_norm`. [#1969](https://github.com/unit8co/darts/pull/1969) by [Dennis Bader](https://github.com/dennisbader).
- Reduced the size of the Darts docker image `unit8/darts:latest`, and included all optional models as well as dev requirements. [#1878](https://github.com/unit8co/darts/pull/1878) by [Alex Colpitts](https://github.com/alexcolpitts96).

**Fixed**
@@ -60,7 +61,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Other improvements:
- Improved static covariates column naming when using `StaticCovariatesTransformer` with a `sklearn.preprocessing.OneHotEncoder`. [#1863](https://github.com/unit8co/darts/pull/1863) by [Anne de Vries](https://github.com/anne-devries).
- Added `MSTL` (Season-Trend decomposition using LOESS for multiple seasonalities) as a `method` option for `extract_trend_and_seasonality()`. [#1879](https://github.com/unit8co/darts/pull/1879) by [Alex Colpitts](https://github.com/alexcolpitts96).
- Added `RINorm` (Reversible Instance Norm) as a new input normalization option for `TorchForecastingModel`. So far only `TiDEModel` supports it with model creation parameter `user_reversible_instance_norm`. [#1865](https://github.com/unit8co/darts/issues/1856) by [Alex Colpitts](https://github.com/alexcolpitts96).
- Added `RINorm` (Reversible Instance Norm) as a new input normalization option for `TorchForecastingModel`. So far only `TiDEModel` supports it with model creation parameter `use_reversible_instance_norm`. [#1865](https://github.com/unit8co/darts/issues/1856) by [Alex Colpitts](https://github.com/alexcolpitts96).
- Improvements to `TimeSeries.plot()`: custom axes are now properly supported with parameter `ax`. Axis is now returned for downstream tasks. [#1916](https://github.com/unit8co/darts/pull/1916) by [Dennis Bader](https://github.com/dennisbader).

**Fixed**
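The changelog entry above names the new model creation parameter `use_reversible_instance_norm`. A minimal sketch of how it might be passed to one of the affected models follows. `NBEATSModel` is one of the models touched by this commit; the chunk lengths and the helper name `build_nbeats` are illustrative assumptions, and the Darts import is deferred so the snippet only needs Darts when the helper is actually called.

```python
# Keyword arguments for a torch-based Darts model (anything except RNNModel,
# which this commit excludes). The chunk lengths are arbitrary example values.
model_kwargs = {
    "input_chunk_length": 24,
    "output_chunk_length": 12,
    "use_reversible_instance_norm": True,  # parameter named in the changelog entry
}


def build_nbeats(**kwargs):
    """Hypothetical helper: create an NBEATSModel with the given arguments.

    Darts is imported lazily, so it is only required when this is called
    (and it must be a version that includes this change).
    """
    from darts.models import NBEATSModel

    return NBEATSModel(**kwargs)


# Usage, assuming a Darts `TimeSeries` named `series` (not defined here):
# model = build_nbeats(**model_kwargs)
# model.fit(series)
# forecast = model.predict(n=12)
```

The parameter is left out of the defaults here on purpose: the changelog describes it as an option that has to be activated at model creation.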
4 changes: 2 additions & 2 deletions README.md
@@ -229,8 +229,8 @@ on bringing more models and features.
| [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | [TBATS paper](https://robjhyndman.com/papers/ComplexSeasonality.pdf) | 🟩 🟥 | 🟥 🟥 🟥 | 🟩 🟥 | 🟥 |
| [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | [Theta](https://robjhyndman.com/papers/Theta.pdf) & [4 Theta](https://github.com/Mcompetitions/M4-methods/blob/master/4Theta%20method.R) | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 |
| [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | 🟩 🟥 | 🟥 🟥 🟥 | 🟩 🟥 | 🟥 |
| [Prophet](file:///Users/dennisbader/projects/unit8/darts/docs/build/html/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) (see [install notes](https://github.com/unit8co/darts/blob/master/INSTALL.md#enabling-support-for-facebook-prophet)) | [Prophet repo](https://github.com/facebook/prophet) | 🟩 🟥 | 🟥 🟩 🟥 | 🟩 🟥 | 🟥 |
| [FFT](file:///Users/dennisbader/projects/unit8/darts/docs/build/html/generated_api/darts.models.forecasting.fft.html#darts.models.forecasting.fft.FFT) (Fast Fourier Transform) | | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 |
| [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) (see [install notes](https://github.com/unit8co/darts/blob/master/INSTALL.md#enabling-support-for-facebook-prophet)) | [Prophet repo](https://github.com/facebook/prophet) | 🟩 🟥 | 🟥 🟩 🟥 | 🟩 🟥 | 🟥 |
| [FFT](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.fft.html#darts.models.forecasting.fft.FFT) (Fast Fourier Transform) | | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 |
| [KalmanForecaster](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.kalman_forecaster.html#darts.models.forecasting.kalman_forecaster.KalmanForecaster) using the Kalman filter and N4SID for system identification | [N4SID paper](https://people.duke.edu/~hpgavin/SystemID/References/VanOverschee-Automatica-1994.pdf) | 🟩 🟩 | 🟥 🟩 🟥 | 🟩 🟥 | 🟥 |
| [Croston](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.croston.html#darts.models.forecasting.croston.Croston) method | | 🟩 🟥 | 🟥 🟥 🟥 | 🟥 🟥 | 🟥 |
| **Regression Models**<br/>([GlobalForecastingModel](https://unit8co.github.io/darts/userguide/covariates.html#global-forecasting-models-gfms)) | | | | | |
5 changes: 0 additions & 5 deletions darts/models/components/statsforecast_utils.py
@@ -1,8 +1,3 @@
"""
StatsForecast utils
-----------
"""

import numpy as np

# In a normal distribution, 68.27 percentage of values lie within one standard deviation of the mean
14 changes: 13 additions & 1 deletion darts/models/forecasting/block_rnn_model.py
@@ -9,7 +9,10 @@
import torch.nn as nn

from darts.logging import get_logger, raise_if_not
from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
from darts.models.forecasting.pl_forecasting_module import (
    PLPastCovariatesModule,
    io_processor,
)
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel

logger = get_logger(__name__)
@@ -101,6 +104,7 @@ def __init__(
last = feature
self.fc = nn.Sequential(*feats)

@io_processor
def forward(self, x_in: Tuple):
x, _ = x_in
# data is of size (batch_size, input_chunk_length, input_size)
@@ -194,6 +198,9 @@ def __init__(
to using a constant learning rate. Default: ``None``.
lr_scheduler_kwargs
Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
use_reversible_instance_norm
Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [1]_.
It is only applied to the features of the target series and not the covariates.
batch_size
Number of time series (input and output sequences) used in each training pass. Default: ``32``.
n_epochs
@@ -299,6 +306,11 @@ def encode_year(idx):
show_warnings
whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of
your forecasting use case. Default: ``False``.
References
----------
.. [1] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
"""
super().__init__(**self._extract_torch_model_params(**self.model_params))
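
The hunks above add an `@io_processor` decorator to `forward`, but the decorator's body is not part of this excerpt. As a rough illustration of the pattern only, and assuming it normalizes the input before the wrapped `forward` runs and reverses that on the output, a plain-Python toy could look like this. `io_processor_sketch` and `ZeroForecaster` are hypothetical names, and the real decorator works on tensors rather than lists.

```python
from functools import wraps
from statistics import fmean, pstdev


def io_processor_sketch(forward):
    """Toy stand-in: normalize the input, run forward, undo it on the output."""

    @wraps(forward)
    def wrapper(self, x):
        if not self.use_reversible_instance_norm:
            return forward(self, x)
        mean = fmean(x)
        std = pstdev(x) + 1e-5  # small constant avoids division by zero
        out = forward(self, [(v - mean) / std for v in x])
        # Map the forecast back to the scale of the input window.
        return [o * std + mean for o in out]

    return wrapper


class ZeroForecaster:
    """Hypothetical model that always predicts 0.0 in whatever scale it sees."""

    def __init__(self, output_chunk_length, use_reversible_instance_norm):
        self.output_chunk_length = output_chunk_length
        self.use_reversible_instance_norm = use_reversible_instance_norm

    @io_processor_sketch
    def forward(self, x):
        return [0.0] * self.output_chunk_length


window = [10.0, 12.0, 14.0]
plain = ZeroForecaster(2, use_reversible_instance_norm=False).forward(window)
with_rin = ZeroForecaster(2, use_reversible_instance_norm=True).forward(window)
```

With the flag off, the toy forecast stays at zero. With it on, the same zero prediction is mapped back to the mean level of the input window.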

11 changes: 10 additions & 1 deletion darts/models/forecasting/dlinear.py
@@ -9,7 +9,10 @@
import torch.nn as nn

from darts.logging import raise_if
from darts.models.forecasting.pl_forecasting_module import PLMixedCovariatesModule
from darts.models.forecasting.pl_forecasting_module import (
    PLMixedCovariatesModule,
    io_processor,
)
from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel

MixedCovariatesTrainTensorType = Tuple[
@@ -155,6 +158,7 @@ def _create_linear_layer(in_dim, out_dim):
layer_in_dim_static_cov, layer_out_dim
)

@io_processor
def forward(
self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
):
@@ -295,6 +299,9 @@ def __init__(
to using a constant learning rate. Default: ``None``.
lr_scheduler_kwargs
Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
use_reversible_instance_norm
Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_.
It is only applied to the features of the target series and not the covariates.
batch_size
Number of time series (input and output sequences) used in each training pass. Default: ``32``.
n_epochs
@@ -405,6 +412,8 @@ def encode_year(idx):
----------
.. [1] Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2022).
Are Transformers Effective for Time Series Forecasting?. arXiv preprint arXiv:2205.13504.
.. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
"""
super().__init__(**self._extract_torch_model_params(**self.model_params))
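
The docstring above says the normalization is applied only to the target features and not to the covariates. A simplified plain-Python sketch of that idea follows, assuming per-window mean and standard deviation and leaving out the learnable affine parameters described in the RINorm paper. The function names and the `(time, features)` list layout with targets in the leading columns are assumptions for illustration, not Darts code.

```python
from statistics import fmean, pvariance


def rin_targets_only(window, n_targets, eps=1e-5):
    """Normalize the first `n_targets` columns of a (time, features) window.

    Returns the normalized window and the per-target (mean, std) pairs
    needed to reverse the transformation later.
    """
    stats = []
    for j in range(n_targets):
        column = [row[j] for row in window]
        stats.append((fmean(column), (pvariance(column) + eps) ** 0.5))

    normalized = [
        [
            (value - stats[j][0]) / stats[j][1] if j < n_targets else value
            for j, value in enumerate(row)
        ]
        for row in window
    ]
    return normalized, stats


def rin_inverse(forecast, stats):
    """Map target-only forecast rows back to the original scale."""
    return [
        [value * std + mean for value, (mean, std) in zip(row, stats)]
        for row in forecast
    ]


# One target column (level around 110) and one binary covariate column.
window = [[100.0, 1.0], [110.0, 0.0], [120.0, 1.0]]
normalized, stats = rin_targets_only(window, n_targets=1)
restored = rin_inverse([[row[0]] for row in normalized], stats)
```

Keeping the statistics from the input window is what makes the step reversible: the same mean and standard deviation are reused to bring the model output back to the original scale.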

11 changes: 10 additions & 1 deletion darts/models/forecasting/nbeats.py
@@ -11,7 +11,10 @@
import torch.nn as nn

from darts.logging import get_logger, raise_if_not, raise_log
from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
from darts.models.forecasting.pl_forecasting_module import (
    PLPastCovariatesModule,
    io_processor,
)
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
from darts.utils.torch import MonteCarloDropout

@@ -490,6 +493,7 @@ def __init__(
self.stacks_list[-1].blocks[-1].backcast_linear_layer.requires_grad_(False)
self.stacks_list[-1].blocks[-1].backcast_g.requires_grad_(False)

@io_processor
def forward(self, x_in: Tuple):
x, _ = x_in

@@ -616,6 +620,9 @@ def __init__(
to using a constant learning rate. Default: ``None``.
lr_scheduler_kwargs
Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
use_reversible_instance_norm
Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_.
It is only applied to the features of the target series and not the covariates.
batch_size
Number of time series (input and output sequences) used in each training pass. Default: ``32``.
n_epochs
@@ -725,6 +732,8 @@ def encode_year(idx):
References
----------
.. [1] https://openreview.net/forum?id=r1ecqn4YwB
.. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
"""
super().__init__(**self._extract_torch_model_params(**self.model_params))

11 changes: 10 additions & 1 deletion darts/models/forecasting/nhits.py
@@ -11,7 +11,10 @@
import torch.nn.functional as F

from darts.logging import get_logger, raise_if_not
from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
from darts.models.forecasting.pl_forecasting_module import (
    PLPastCovariatesModule,
    io_processor,
)
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
from darts.utils.torch import MonteCarloDropout

@@ -417,6 +420,7 @@ def __init__(
# on this params (the last block backcast is not part of the final output of the net).
self.stacks_list[-1].blocks[-1].backcast_linear_layer.requires_grad_(False)

@io_processor
def forward(self, x_in: Tuple):
x, _ = x_in

@@ -552,6 +556,9 @@ def __init__(
to using a constant learning rate. Default: ``None``.
lr_scheduler_kwargs
Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
use_reversible_instance_norm
Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_.
It is only applied to the features of the target series and not the covariates.
batch_size
Number of time series (input and output sequences) used in each training pass. Default: ``32``.
n_epochs
@@ -662,6 +669,8 @@ def encode_year(idx):
----------
.. [1] C. Challu et al. "N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting",
https://arxiv.org/abs/2201.12886
.. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
"""
super().__init__(**self._extract_torch_model_params(**self.model_params))

11 changes: 10 additions & 1 deletion darts/models/forecasting/nlinear.py
@@ -9,7 +9,10 @@
import torch.nn as nn

from darts.logging import raise_if
from darts.models.forecasting.pl_forecasting_module import PLMixedCovariatesModule
from darts.models.forecasting.pl_forecasting_module import (
    PLMixedCovariatesModule,
    io_processor,
)
from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel


@@ -106,6 +109,7 @@ def _create_linear_layer(in_dim, out_dim):
layer_in_dim_static_cov, layer_out_dim
)

@io_processor
def forward(
self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
):
@@ -246,6 +250,9 @@ def __init__(
to using a constant learning rate. Default: ``None``.
lr_scheduler_kwargs
Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
use_reversible_instance_norm
Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_.
It is only applied to the features of the target series and not the covariates.
batch_size
Number of time series (input and output sequences) used in each training pass. Default: ``32``.
n_epochs
@@ -354,6 +361,8 @@ def encode_year(idx):
----------
.. [1] Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2022).
Are Transformers Effective for Time Series Forecasting?. arXiv preprint arXiv:2205.13504.
.. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
"""
super().__init__(**self._extract_torch_model_params(**self.model_params))
