add DataLoader related parameters to fit() and predict() (#2295)
* add `torch.utils.data.DataLoader` related parameters to `fit()` and `predict()` of `TorchForecastingModel`

* update CHANGELOG.md

* replace specific dataloader arguments with dataloader_kwargs

* - allow setting all params
- add predefined defaults

* fix wrong indentation

* - allow setting all params for predict
- add breaking change to CHANGELOG.md

* improve docs

* add unittests

* update test

---------

Co-authored-by: Bohdan Bilonoh <[email protected]>
Co-authored-by: dennisbader <[email protected]>
3 people authored Jun 3, 2024
1 parent a0cc279 commit a4ed8b1
Showing 3 changed files with 150 additions and 58 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -39,7 +39,9 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- New method `with_times_and_values()`, which returns a new series with a new time index and new values but with identical columns and metadata as the series called from (static covariates, hierarchy).
- New method `slice_intersect_times()`, which returns the sliced time index of a series, where the index has been intersected with another series.
- Method `with_values()` now also acts on array-like `values` rather than only on numpy arrays.

+ - Improvements to `TorchForecastingModel`: [#2295](https://github.com/unit8co/darts/pull/2295) by [Bohdan Bilonoh](https://github.com/BohdanBilonoh).
+   - Added a `dataloader_kwargs` parameter to `fit*()`, `predict*()`, and `lr_find()` for more control over the PyTorch `DataLoader` setup.
+   - 🔴 Removed parameter `num_loader_workers` from `fit*()`, `predict*()`, and `lr_find()`. You can now set it through the `dataloader_kwargs` dict (see the sketch below).
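
A minimal migration sketch (the model, series, and epoch count are illustrative, not from this commit; any `TorchForecastingModel` subclass accepts the new argument the same way):

```python
from darts.models import RNNModel
from darts.utils.timeseries_generation import linear_timeseries

series = linear_timeseries(length=100)
model = RNNModel(input_chunk_length=12, model="RNN", n_epochs=1)

# before this release: model.fit(series, num_loader_workers=4)
# now: route any DataLoader option through `dataloader_kwargs` instead
model.fit(series, dataloader_kwargs={"num_workers": 4})
preds = model.predict(n=10, dataloader_kwargs={"num_workers": 4})
```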

**Fixed**
- Fixed a bug when using a `RegressionModel` (that supports validation series) with a validation set, and encoders and/or component-specific lags, where the encodings and component specific lags were not added to the set. [#2383](https://github.com/unit8co/darts/pull/2383) by [Dennis Bader](https://github.com/dennisbader).
127 changes: 70 additions & 57 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -660,7 +660,7 @@ def fit(
verbose: Optional[bool] = None,
epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
- num_loader_workers: int = 0,
+ dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> "TorchForecastingModel":
"""Fit/train the model on one or multiple series.
@@ -714,11 +714,12 @@ def fit(
large number of training samples. This parameter upper-bounds the number of training samples per time
series (taking only the most recent samples in each series). Leaving to None does not apply any
upper bound.
- num_loader_workers
-     Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
-     both for the training and validation loaders (if any).
-     A larger number of workers can sometimes increase performance, but can also incur extra overheads
-     and increase memory usage, as more batches are loaded in parallel.
+ dataloader_kwargs
+     Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instances for the
+     training and validation datasets. For more information on `DataLoader`, check out `this link
+     <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_.
+     By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory")
+     for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.
Returns
-------
@@ -743,7 +744,7 @@ def fit(
verbose=verbose,
epochs=epochs,
max_samples_per_ts=max_samples_per_ts,
- num_loader_workers=num_loader_workers,
+ dataloader_kwargs=dataloader_kwargs,
)
# call super fit only if user is actually fitting the model
super().fit(
@@ -765,7 +766,7 @@ def _setup_for_fit_from_dataset(
verbose: Optional[bool] = None,
epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
- num_loader_workers: int = 0,
+ dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[
Tuple[
Sequence[TimeSeries],
@@ -778,7 +779,7 @@ def _setup_for_fit_from_dataset(
Optional[pl.Trainer],
Optional[bool],
int,
- int,
+ Optional[Dict[str, Any]],
],
]:
"""This method acts on `TimeSeries` inputs. It performs sanity checks, and sets up / returns the datasets and
@@ -864,7 +865,7 @@ def _setup_for_fit_from_dataset(
trainer,
verbose,
epochs,
- num_loader_workers,
+ dataloader_kwargs,
)
return series_input, fit_from_ds_params

@@ -876,7 +877,7 @@ def fit_from_dataset(
trainer: Optional[pl.Trainer] = None,
verbose: Optional[bool] = None,
epochs: int = 0,
- num_loader_workers: int = 0,
+ dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> "TorchForecastingModel":
"""
Train the model with a specific :class:`darts.utils.data.TrainingDataset` instance.
@@ -909,11 +910,12 @@ def fit_from_dataset(
epochs
If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs``
was provided to the model constructor.
- num_loader_workers
-     Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
-     both for the training and validation loaders (if any).
-     A larger number of workers can sometimes increase performance, but can also incur extra overheads
-     and increase memory usage, as more batches are loaded in parallel.
+ dataloader_kwargs
+     Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instances for the
+     training and validation datasets. For more information on `DataLoader`, check out `this link
+     <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_.
+     By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory")
+     for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.
Returns
-------
@@ -927,7 +929,7 @@ def fit_from_dataset(
trainer=trainer,
verbose=verbose,
epochs=epochs,
- num_loader_workers=num_loader_workers,
+ dataloader_kwargs=dataloader_kwargs,
)
)
return self
@@ -939,7 +941,7 @@ def _setup_for_train(
trainer: Optional[pl.Trainer] = None,
verbose: Optional[bool] = None,
epochs: int = 0,
- num_loader_workers: int = 0,
+ dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[pl.Trainer, PLForecastingModule, DataLoader, Optional[DataLoader]]:
"""This method acts on `TrainingDataset` inputs. It performs sanity checks, and sets up / returns the trainer,
model, and dataset loaders required for training the model with `_train()`.
@@ -996,28 +998,30 @@ def _setup_for_train(

# Setting drop_last to False makes the model see each sample at least once, and guarantees the presence of at
# least one batch no matter the chosen batch size
+ dataloader_kwargs = dict(
+     {
+         "batch_size": self.batch_size,
+         "shuffle": True,
+         "pin_memory": True,
+         "drop_last": False,
+         "collate_fn": self._batch_collate_fn,
+     },
+     **(dataloader_kwargs or dict()),
+ )

train_loader = DataLoader(
train_dataset,
- batch_size=self.batch_size,
- shuffle=True,
- num_workers=num_loader_workers,
- pin_memory=True,
- drop_last=False,
- collate_fn=self._batch_collate_fn,
+ **dataloader_kwargs,
)

# Prepare validation data
dataloader_kwargs["shuffle"] = False
val_loader = (
None
if val_dataset is None
else DataLoader(
val_dataset,
- batch_size=self.batch_size,
- shuffle=False,
- num_workers=num_loader_workers,
- pin_memory=True,
- drop_last=False,
- collate_fn=self._batch_collate_fn,
+ **dataloader_kwargs,
)
)

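The `dict(defaults, **overrides)` construction above is what lets user-supplied keys take precedence over Darts' defaults. A standalone sketch of that merge semantics (all values illustrative):

```python
# Keys passed via ** override duplicates in the positional mapping.
defaults = {"batch_size": 32, "shuffle": True, "pin_memory": True}
user_kwargs = {"batch_size": 64, "num_workers": 4}

merged = dict(defaults, **(user_kwargs or {}))
assert merged == {
    "batch_size": 64,    # overridden by the user
    "shuffle": True,     # default kept
    "pin_memory": True,  # default kept
    "num_workers": 4,    # added by the user
}
```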
@@ -1082,7 +1086,7 @@ def lr_find(
verbose: Optional[bool] = None,
epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
- num_loader_workers: int = 0,
+ dataloader_kwargs: Optional[Dict[str, Any]] = None,
min_lr: float = 1e-08,
max_lr: float = 1,
num_training: int = 100,
@@ -1154,11 +1158,12 @@ def lr_find(
large number of training samples. This parameter upper-bounds the number of training samples per time
series (taking only the most recent samples in each series). Leaving to None does not apply any
upper bound.
- num_loader_workers
-     Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
-     both for the training and validation loaders (if any).
-     A larger number of workers can sometimes increase performance, but can also incur extra overheads
-     and increase memory usage, as more batches are loaded in parallel.
+ dataloader_kwargs
+     Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instances for the
+     training and validation datasets. For more information on `DataLoader`, check out `this link
+     <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_.
+     By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory")
+     for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.
min_lr
minimum learning rate to investigate
max_lr
@@ -1190,7 +1195,7 @@ def lr_find(
verbose=verbose,
epochs=epochs,
max_samples_per_ts=max_samples_per_ts,
- num_loader_workers=num_loader_workers,
+ dataloader_kwargs=dataloader_kwargs,
)
trainer, model, train_loader, val_loader = self._setup_for_train(*params)
return Tuner(trainer).lr_find(
@@ -1219,7 +1224,7 @@ def predict(
n_jobs: int = 1,
roll_size: Optional[int] = None,
num_samples: int = 1,
- num_loader_workers: int = 0,
+ dataloader_kwargs: Optional[Dict[str, Any]] = None,
mc_dropout: bool = False,
predict_likelihood_parameters: bool = False,
show_warnings: bool = True,
@@ -1281,11 +1286,12 @@ def predict(
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1
for deterministic models.
- num_loader_workers
-     Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
-     for the inference/prediction dataset loaders (if any).
-     A larger number of workers can sometimes increase performance, but can also incur extra overheads
-     and increase memory usage, as more batches are loaded in parallel.
+ dataloader_kwargs
+     Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instance for the
+     inference/prediction dataset. For more information on `DataLoader`, check out `this link
+     <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_.
+     By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory")
+     for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.
mc_dropout
Optionally, enable Monte Carlo dropout for predictions using neural network based models.
This allows Bayesian approximation by specifying an implicit prior over learned models.
Expand Down Expand Up @@ -1369,7 +1375,7 @@ def predict(
n_jobs=n_jobs,
roll_size=roll_size,
num_samples=num_samples,
- num_loader_workers=num_loader_workers,
+ dataloader_kwargs=dataloader_kwargs,
mc_dropout=mc_dropout,
predict_likelihood_parameters=predict_likelihood_parameters,
)
@@ -1387,7 +1393,7 @@ def predict_from_dataset(
n_jobs: int = 1,
roll_size: Optional[int] = None,
num_samples: int = 1,
- num_loader_workers: int = 0,
+ dataloader_kwargs: Optional[Dict[str, Any]] = None,
mc_dropout: bool = False,
predict_likelihood_parameters: bool = False,
) -> Sequence[TimeSeries]:
@@ -1428,11 +1434,12 @@ def predict_from_dataset(
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1
for deterministic models.
- num_loader_workers
-     Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
-     for the inference/prediction dataset loaders (if any).
-     A larger number of workers can sometimes increase performance, but can also incur extra overheads
-     and increase memory usage, as more batches are loaded in parallel.
+ dataloader_kwargs
+     Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instance for the
+     inference/prediction dataset. For more information on `DataLoader`, check out `this link
+     <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_.
+     By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory")
+     for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.
mc_dropout
Optionally, enable Monte Carlo dropout for predictions using neural network based models.
This allows Bayesian approximation by specifying an implicit prior over learned models.
@@ -1487,14 +1494,20 @@ def predict_from_dataset(
mc_dropout=mc_dropout,
)

+ dataloader_kwargs = dict(
+     {
+         "batch_size": batch_size,
+         "pin_memory": True,
+         "drop_last": False,
+         "collate_fn": self._batch_collate_fn,
+     },
+     **(dataloader_kwargs or {}),
+     **{"shuffle": False},
+ )

pred_loader = DataLoader(
input_series_dataset,
- batch_size=batch_size,
- shuffle=False,
- num_workers=num_loader_workers,
- pin_memory=True,
- drop_last=False,
- collate_fn=self._batch_collate_fn,
+ **dataloader_kwargs,
)

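Note the merge order in the block above: `**{"shuffle": False}` is unpacked last, so prediction batches are never shuffled, and because Python rejects duplicate keyword keys across `**` unpackings in a call, a user-supplied `"shuffle"` key would raise a `TypeError` rather than take effect. A small sketch of that ordering (values illustrative):

```python
# The trailing ** unpacking pins shuffle off for inference.
defaults = {"batch_size": 32, "pin_memory": True}
user_kwargs = {"num_workers": 2}

merged = dict(defaults, **(user_kwargs or {}), **{"shuffle": False})
assert merged == {
    "batch_size": 32,
    "pin_memory": True,
    "num_workers": 2,
    "shuffle": False,  # always forced off for prediction
}

# With user_kwargs = {"shuffle": True}, the same call would raise:
# TypeError: dict() got multiple values for keyword argument 'shuffle'
```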
# set up trainer. use user supplied trainer or create a new trainer from scratch
Expand All @@ -1503,7 +1516,7 @@ def predict_from_dataset(
)

# prediction output comes as nested list: list of predicted `TimeSeries` for each batch.
- predictions = self.trainer.predict(self.model, pred_loader)
+ predictions = self.trainer.predict(model=self.model, dataloaders=pred_loader)
# flatten and return
return [ts for batch in predictions for ts in batch]

77 changes: 77 additions & 0 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -24,6 +24,7 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers.logger import DummyLogger
from pytorch_lightning.tuner.lr_finder import _LRFinder
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchmetrics import (
MeanAbsoluteError,
MeanAbsolutePercentageError,
@@ -1474,6 +1475,82 @@ def test_val_set(self, model_config):
        with patch("pytorch_lightning.Trainer.fit") as fit_patch:
            self.helper_check_val_set(*model_config, fit_patch)

    def test_dataloader_kwargs_setup(self):
        train_series, val_series = self.series[:-40], self.series[-40:]
        model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)

        with patch("pytorch_lightning.Trainer.fit") as fit_patch:
            model.fit(train_series, val_series=val_series)
            assert "train_dataloaders" in fit_patch.call_args.kwargs
            assert "val_dataloaders" in fit_patch.call_args.kwargs

            train_dl = fit_patch.call_args.kwargs["train_dataloaders"]
            assert isinstance(train_dl, DataLoader)
            val_dl = fit_patch.call_args.kwargs["val_dataloaders"]
            assert isinstance(val_dl, DataLoader)

            dl_defaults = {
                "batch_size": model.batch_size,
                "pin_memory": True,
                "drop_last": False,
                "collate_fn": model._batch_collate_fn,
            }
            assert all([getattr(train_dl, k) == v for k, v in dl_defaults.items()])
            # shuffle=True gives random sampler
            assert isinstance(train_dl.sampler, RandomSampler)

            assert all([getattr(val_dl, k) == v for k, v in dl_defaults.items()])
            # shuffle=False gives sequential sampler
            assert isinstance(val_dl.sampler, SequentialSampler)

            # check that overwriting the dataloader kwargs works
            dl_custom = dict(dl_defaults, **{"batch_size": 50, "drop_last": True})
            model.fit(train_series, val_series=val_series, dataloader_kwargs=dl_custom)
            train_dl = fit_patch.call_args.kwargs["train_dataloaders"]
            val_dl = fit_patch.call_args.kwargs["val_dataloaders"]
            assert all([getattr(train_dl, k) == v for k, v in dl_custom.items()])
            assert all([getattr(val_dl, k) == v for k, v in dl_custom.items()])

        with patch("pytorch_lightning.Trainer.predict") as pred_patch:
            # calling predict with the patch will raise an error, but we only need to
            # check the dataloader setup
            with pytest.raises(Exception):
                model.predict(n=1)
            assert "dataloaders" in pred_patch.call_args.kwargs
            pred_dl = pred_patch.call_args.kwargs["dataloaders"]
            assert isinstance(pred_dl, DataLoader)
            assert all([getattr(pred_dl, k) == v for k, v in dl_defaults.items()])
            # shuffle=False gives sequential sampler
            assert isinstance(pred_dl.sampler, SequentialSampler)

            # check that overwriting the dataloader kwargs works
            with pytest.raises(Exception):
                model.predict(n=1, dataloader_kwargs=dl_custom)
            pred_dl = pred_patch.call_args.kwargs["dataloaders"]
            assert all([getattr(pred_dl, k) == v for k, v in dl_custom.items()])

    def test_dataloader_kwargs_fit_predict(self):
        train_series, val_series = self.series[:-40], self.series[-40:]
        model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)

        model.fit(
            train_series,
            val_series=val_series,
            dataloader_kwargs={"batch_size": 100, "shuffle": False},
        )

        # check same results with default batch size (32) and custom batch size
        preds_default = model.predict(
            n=2,
            series=[train_series, val_series],
        )
        preds_custom = model.predict(
            n=2,
            series=[train_series, val_series],
            dataloader_kwargs={"batch_size": 100},
        )
        assert preds_default == preds_custom

    def helper_check_val_set(self, model_cls, model_kwargs, fit_patch):
        # naive models don't call the Trainer
        if issubclass(model_cls, _GlobalNaiveModel):
