diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1cbe774182..1ca87278c5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -39,7 +39,9 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 - New method `with_times_and_values()`, which returns a new series with a new time index and new values but with identical columns and metadata as the series called from (static covariates, hierarchy).
 - New method `slice_intersect_times()`, which returns the sliced time index of a series, where the index has been intersected with another series.
 - Method `with_values()` now also acts on array-like `values` rather than only on numpy arrays.
-
+- Improvements to `TorchForecastingModel`: [#2295](https://github.com/unit8co/darts/pull/2295) by [Bohdan Bilonoh](https://github.com/BohdanBilonoh).
+  - Added a `dataloader_kwargs` parameter to `fit*()`, `predict*()`, and `lr_find()` for more control over the PyTorch `DataLoader` setup.
+  - 🔴 Removed parameter `num_loader_workers` from `fit*()`, `predict*()`, and `lr_find()`. You can now set the number of workers through the `dataloader_kwargs` dict.
 
 **Fixed**
 - Fixed a bug when using a `RegressionModel` (that supports validation series) with a validation set, and encoders and/or component-specific lags, where the encodings and component specific lags were not added to the set. [#2383](https://github.com/unit8co/darts/pull/2383) by [Dennis Bader](https://github.com/dennisbader).
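Below is an illustrative migration sketch for the breaking change listed above; it is not part of the patch. The dataset and model (`AirPassengersDataset`, `DLinearModel`) are arbitrary placeholders, and any `TorchForecastingModel` behaves the same way:

```python
from darts.datasets import AirPassengersDataset
from darts.models import DLinearModel

series = AirPassengersDataset().load().astype("float32")
model = DLinearModel(input_chunk_length=12, output_chunk_length=4, n_epochs=1)

# Previously: model.fit(series, num_loader_workers=4)  -- removed by this change.
# Now, DataLoader options such as the worker count go through `dataloader_kwargs`.
model.fit(series, dataloader_kwargs={"num_workers": 4})
pred = model.predict(n=4, dataloader_kwargs={"num_workers": 4})
```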
diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py
index 955e8fc2de..a501415a61 100644
--- a/darts/models/forecasting/torch_forecasting_model.py
+++ b/darts/models/forecasting/torch_forecasting_model.py
@@ -660,7 +660,7 @@ def fit(
         verbose: Optional[bool] = None,
         epochs: int = 0,
         max_samples_per_ts: Optional[int] = None,
-        num_loader_workers: int = 0,
+        dataloader_kwargs: Optional[Dict[str, Any]] = None,
     ) -> "TorchForecastingModel":
         """Fit/train the model on one or multiple series.
 
@@ -714,11 +714,12 @@ def fit(
             large number of training samples. This parameter upper-bounds the number of training samples per time
             series (taking only the most recent samples in each series). Leaving to None does not apply any
             upper bound.
-        num_loader_workers
-            Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
-            both for the training and validation loaders (if any).
-            A larger number of workers can sometimes increase performance, but can also incur extra overheads
-            and increase memory usage, as more batches are loaded in parallel.
+        dataloader_kwargs
+            Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instances for the
+            training and validation datasets. For more information on `DataLoader`, check out `this link
+            <https://pytorch.org/docs/stable/data.html>`_.
+            By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory")
+            for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.
 
         Returns
         -------
@@ -743,7 +744,7 @@ def fit(
             verbose=verbose,
             epochs=epochs,
             max_samples_per_ts=max_samples_per_ts,
-            num_loader_workers=num_loader_workers,
+            dataloader_kwargs=dataloader_kwargs,
         )
         # call super fit only if user is actually fitting the model
         super().fit(
@@ -765,7 +766,7 @@ def _setup_for_fit_from_dataset(
         verbose: Optional[bool] = None,
         epochs: int = 0,
         max_samples_per_ts: Optional[int] = None,
-        num_loader_workers: int = 0,
+        dataloader_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[
         Tuple[
             Sequence[TimeSeries],
@@ -778,7 +779,7 @@ def _setup_for_fit_from_dataset(
             Optional[pl.Trainer],
             Optional[bool],
             int,
-            int,
+            Optional[Dict[str, Any]],
         ],
     ]:
         """This method acts on `TimeSeries` inputs. It performs sanity checks, and sets up / returns the datasets and
@@ -864,7 +865,7 @@ def _setup_for_fit_from_dataset(
             trainer,
             verbose,
             epochs,
-            num_loader_workers,
+            dataloader_kwargs,
         )
         return series_input, fit_from_ds_params
 
@@ -876,7 +877,7 @@ def fit_from_dataset(
         trainer: Optional[pl.Trainer] = None,
         verbose: Optional[bool] = None,
         epochs: int = 0,
-        num_loader_workers: int = 0,
+        dataloader_kwargs: Optional[Dict[str, Any]] = None,
     ) -> "TorchForecastingModel":
         """
         Train the model with a specific :class:`darts.utils.data.TrainingDataset` instance.
@@ -909,11 +910,12 @@ def fit_from_dataset(
         epochs
             If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs``
             was provided to the model constructor.
-        num_loader_workers
-            Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
-            both for the training and validation loaders (if any).
-            A larger number of workers can sometimes increase performance, but can also incur extra overheads
-            and increase memory usage, as more batches are loaded in parallel.
+        dataloader_kwargs
+            Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instances for the
+            training and validation datasets. For more information on `DataLoader`, check out `this link
+            <https://pytorch.org/docs/stable/data.html>`_.
+            By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory")
+            for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.
 
         Returns
         -------
@@ -927,7 +929,7 @@ def fit_from_dataset(
                 trainer=trainer,
                 verbose=verbose,
                 epochs=epochs,
-                num_loader_workers=num_loader_workers,
+                dataloader_kwargs=dataloader_kwargs,
             )
         )
         return self
@@ -939,7 +941,7 @@ def _setup_for_train(
         trainer: Optional[pl.Trainer] = None,
         verbose: Optional[bool] = None,
         epochs: int = 0,
-        num_loader_workers: int = 0,
+        dataloader_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[pl.Trainer, PLForecastingModule, DataLoader, Optional[DataLoader]]:
         """This method acts on `TrainingDataset` inputs. It performs sanity checks, and sets up / returns the trainer,
         model, and dataset loaders required for training the model with `_train()`.
@@ -996,28 +998,30 @@ def _setup_for_train(
 
         # Setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at
         # least one batch no matter the chosen batch size
+        dataloader_kwargs = dict(
+            {
+                "batch_size": self.batch_size,
+                "shuffle": True,
+                "pin_memory": True,
+                "drop_last": False,
+                "collate_fn": self._batch_collate_fn,
+            },
+            **(dataloader_kwargs or dict()),
+        )
+
         train_loader = DataLoader(
             train_dataset,
-            batch_size=self.batch_size,
-            shuffle=True,
-            num_workers=num_loader_workers,
-            pin_memory=True,
-            drop_last=False,
-            collate_fn=self._batch_collate_fn,
+            **dataloader_kwargs,
         )
 
         # Prepare validation data
+        dataloader_kwargs["shuffle"] = False
         val_loader = (
             None
             if val_dataset is None
             else DataLoader(
                 val_dataset,
-                batch_size=self.batch_size,
-                shuffle=False,
-                num_workers=num_loader_workers,
-                pin_memory=True,
-                drop_last=False,
-                collate_fn=self._batch_collate_fn,
+                **dataloader_kwargs,
             )
         )
 
@@ -1082,7 +1086,7 @@ def lr_find(
         verbose: Optional[bool] = None,
         epochs: int = 0,
         max_samples_per_ts: Optional[int] = None,
-        num_loader_workers: int = 0,
+        dataloader_kwargs: Optional[Dict[str, Any]] = None,
         min_lr: float = 1e-08,
         max_lr: float = 1,
         num_training: int = 100,
@@ -1154,11 +1158,12 @@ def lr_find(
             large number of training samples. This parameter upper-bounds the number of training samples per time
             series (taking only the most recent samples in each series). Leaving to None does not apply any
             upper bound.
-        num_loader_workers
-            Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
-            both for the training and validation loaders (if any).
-            A larger number of workers can sometimes increase performance, but can also incur extra overheads
-            and increase memory usage, as more batches are loaded in parallel.
+        dataloader_kwargs
+            Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instances for the
+            training and validation datasets. For more information on `DataLoader`, check out `this link
+            <https://pytorch.org/docs/stable/data.html>`_.
+            By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory")
+            for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.
         min_lr
             minimum learning rate to investigate
         max_lr
@@ -1190,7 +1195,7 @@ def lr_find(
             verbose=verbose,
             epochs=epochs,
             max_samples_per_ts=max_samples_per_ts,
-            num_loader_workers=num_loader_workers,
+            dataloader_kwargs=dataloader_kwargs,
         )
         trainer, model, train_loader, val_loader = self._setup_for_train(*params)
         return Tuner(trainer).lr_find(
@@ -1219,7 +1224,7 @@ def predict(
         n_jobs: int = 1,
         roll_size: Optional[int] = None,
         num_samples: int = 1,
-        num_loader_workers: int = 0,
+        dataloader_kwargs: Optional[Dict[str, Any]] = None,
         mc_dropout: bool = False,
         predict_likelihood_parameters: bool = False,
         show_warnings: bool = True,
@@ -1281,11 +1286,12 @@ def predict(
         num_samples
             Number of times a prediction is sampled from a probabilistic model. Should be left set to 1
             for deterministic models.
-        num_loader_workers
-            Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
-            for the inference/prediction dataset loaders (if any).
-            A larger number of workers can sometimes increase performance, but can also incur extra overheads
-            and increase memory usage, as more batches are loaded in parallel.
+        dataloader_kwargs
+            Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instance for the
+            inference/prediction dataset. For more information on `DataLoader`, check out `this link
+            <https://pytorch.org/docs/stable/data.html>`_.
+            By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory")
+            for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.
         mc_dropout
             Optionally, enable monte carlo dropout for predictions using neural network based models. This allows
             bayesian approximation by specifying an implicit prior over learned models.
@@ -1369,7 +1375,7 @@ def predict(
             n_jobs=n_jobs,
             roll_size=roll_size,
             num_samples=num_samples,
-            num_loader_workers=num_loader_workers,
+            dataloader_kwargs=dataloader_kwargs,
             mc_dropout=mc_dropout,
             predict_likelihood_parameters=predict_likelihood_parameters,
         )
@@ -1387,7 +1393,7 @@ def predict_from_dataset(
         n_jobs: int = 1,
         roll_size: Optional[int] = None,
         num_samples: int = 1,
-        num_loader_workers: int = 0,
+        dataloader_kwargs: Optional[Dict[str, Any]] = None,
         mc_dropout: bool = False,
         predict_likelihood_parameters: bool = False,
     ) -> Sequence[TimeSeries]:
@@ -1428,11 +1434,12 @@ def predict_from_dataset(
         num_samples
             Number of times a prediction is sampled from a probabilistic model. Should be left set to 1
             for deterministic models.
-        num_loader_workers
-            Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
-            for the inference/prediction dataset loaders (if any).
-            A larger number of workers can sometimes increase performance, but can also incur extra overheads
-            and increase memory usage, as more batches are loaded in parallel.
+        dataloader_kwargs
+            Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instance for the
+            inference/prediction dataset. For more information on `DataLoader`, check out `this link
+            <https://pytorch.org/docs/stable/data.html>`_.
+            By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory")
+            for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.
         mc_dropout
             Optionally, enable monte carlo dropout for predictions using neural network based models. This allows
             bayesian approximation by specifying an implicit prior over learned models.
@@ -1487,14 +1494,20 @@ def predict_from_dataset(
             mc_dropout=mc_dropout,
         )
 
+        dataloader_kwargs = dict(
+            {
+                "batch_size": batch_size,
+                "pin_memory": True,
+                "drop_last": False,
+                "collate_fn": self._batch_collate_fn,
+            },
+            **(dataloader_kwargs or {}),
+            **{"shuffle": False},
+        )
+
         pred_loader = DataLoader(
             input_series_dataset,
-            batch_size=batch_size,
-            shuffle=False,
-            num_workers=num_loader_workers,
-            pin_memory=True,
-            drop_last=False,
-            collate_fn=self._batch_collate_fn,
+            **dataloader_kwargs,
         )
 
         # set up trainer. use user supplied trainer or create a new trainer from scratch
@@ -1503,7 +1516,7 @@ def predict_from_dataset(
         )
 
         # prediction output comes as nested list: list of predicted `TimeSeries` for each batch.
-        predictions = self.trainer.predict(self.model, pred_loader)
+        predictions = self.trainer.predict(model=self.model, dataloaders=pred_loader)
 
         # flatten and return
         return [ts for batch in predictions for ts in batch]
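As a side note on the merge order used in `predict_from_dataset()` above: the Darts defaults are applied first, user-supplied `dataloader_kwargs` override them, and `"shuffle"` is forced last for inference (the training setup instead mutates `dataloader_kwargs["shuffle"]` for the validation loader). A minimal plain-Python sketch with illustrative values:

```python
# Sketch of the dict-merge precedence used for the prediction DataLoader.
# Values are illustrative; the real defaults also include collate_fn.
darts_defaults = {"batch_size": 32, "pin_memory": True, "drop_last": False}
user_kwargs = {"batch_size": 128, "num_workers": 4}  # e.g. passed via `dataloader_kwargs`

merged = dict(darts_defaults, **(user_kwargs or {}), **{"shuffle": False})
print(merged)
# {'batch_size': 128, 'pin_memory': True, 'drop_last': False, 'num_workers': 4, 'shuffle': False}
```

Because `"shuffle"` is passed again as an explicit keyword, a user dict that also contains `"shuffle"` would raise a `TypeError` (duplicate keyword argument) at prediction time.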
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py
index a3cb7d6c9c..0b80e9ce63 100644
--- a/darts/tests/models/forecasting/test_torch_forecasting_model.py
+++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -24,6 +24,7 @@
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.loggers.logger import DummyLogger
 from pytorch_lightning.tuner.lr_finder import _LRFinder
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torchmetrics import (
     MeanAbsoluteError,
     MeanAbsolutePercentageError,
@@ -1474,6 +1475,82 @@ def test_val_set(self, model_config):
         with patch("pytorch_lightning.Trainer.fit") as fit_patch:
             self.helper_check_val_set(*model_config, fit_patch)
 
+    def test_dataloader_kwargs_setup(self):
+        train_series, val_series = self.series[:-40], self.series[-40:]
+        model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
+
+        with patch("pytorch_lightning.Trainer.fit") as fit_patch:
+            model.fit(train_series, val_series=val_series)
+            assert "train_dataloaders" in fit_patch.call_args.kwargs
+            assert "val_dataloaders" in fit_patch.call_args.kwargs
+
+            train_dl = fit_patch.call_args.kwargs["train_dataloaders"]
+            assert isinstance(train_dl, DataLoader)
+            val_dl = fit_patch.call_args.kwargs["val_dataloaders"]
+            assert isinstance(val_dl, DataLoader)
+
+            dl_defaults = {
+                "batch_size": model.batch_size,
+                "pin_memory": True,
+                "drop_last": False,
+                "collate_fn": model._batch_collate_fn,
+            }
+            assert all([getattr(train_dl, k) == v for k, v in dl_defaults.items()])
+            # shuffle=True gives random sampler
+            assert isinstance(train_dl.sampler, RandomSampler)
+
+            assert all([getattr(val_dl, k) == v for k, v in dl_defaults.items()])
+            # shuffle=False gives sequential sampler
+            assert isinstance(val_dl.sampler, SequentialSampler)
+
+            # check that overwriting the dataloader kwargs works
+            dl_custom = dict(dl_defaults, **{"batch_size": 50, "drop_last": True})
+            model.fit(train_series, val_series=val_series, dataloader_kwargs=dl_custom)
+            train_dl = fit_patch.call_args.kwargs["train_dataloaders"]
+            val_dl = fit_patch.call_args.kwargs["val_dataloaders"]
+            assert all([getattr(train_dl, k) == v for k, v in dl_custom.items()])
+            assert all([getattr(val_dl, k) == v for k, v in dl_custom.items()])
+
+        with patch("pytorch_lightning.Trainer.predict") as pred_patch:
+            # calling predict with the patch will raise an error, but we only need to
+            # check the dataloader setup
+            with pytest.raises(Exception):
+                model.predict(n=1)
+            assert "dataloaders" in pred_patch.call_args.kwargs
+            pred_dl = pred_patch.call_args.kwargs["dataloaders"]
+            assert isinstance(pred_dl, DataLoader)
+            assert all([getattr(pred_dl, k) == v for k, v in dl_defaults.items()])
+            # shuffle=False gives sequential sampler
+            assert isinstance(pred_dl.sampler, SequentialSampler)
+
+            # check that overwriting the dataloader kwargs works
+            with pytest.raises(Exception):
+                model.predict(n=1, dataloader_kwargs=dl_custom)
+            pred_dl = pred_patch.call_args.kwargs["dataloaders"]
+            assert all([getattr(pred_dl, k) == v for k, v in dl_custom.items()])
+
+    def test_dataloader_kwargs_fit_predict(self):
+        train_series, val_series = self.series[:-40], self.series[-40:]
+        model = RNNModel(12, "RNN", 10, 10, random_state=42, **tfm_kwargs)
+
+        model.fit(
+            train_series,
+            val_series=val_series,
+            dataloader_kwargs={"batch_size": 100, "shuffle": False},
+        )
+
+        # check same results with default batch size (32) and custom batch size
+        preds_default = model.predict(
+            n=2,
+            series=[train_series, val_series],
+        )
+        preds_custom = model.predict(
+            n=2,
+            series=[train_series, val_series],
+            dataloader_kwargs={"batch_size": 100},
+        )
+        assert preds_default == preds_custom
+
     def helper_check_val_set(self, model_cls, model_kwargs, fit_patch):
         # naive models don't call the Trainer
         if issubclass(model_cls, _GlobalNaiveModel):
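Finally, a hedged end-to-end sketch of the behaviour the new tests exercise; the dataset, model, and hyperparameters below are placeholders rather than values taken from the test suite. Overriding loader settings such as the batch size changes training/inference throughput and memory use, not the forecasts themselves:

```python
from darts.datasets import AirPassengersDataset
from darts.models import RNNModel

series = AirPassengersDataset().load().astype("float32")
model = RNNModel(input_chunk_length=12, n_epochs=1, random_state=42)

# Train with a larger batch size and without shuffling, mirroring the new test.
model.fit(series, dataloader_kwargs={"batch_size": 100, "shuffle": False})

# The prediction batch size only changes how samples are batched, not the result.
pred_default = model.predict(n=12)
pred_custom = model.predict(n=12, dataloader_kwargs={"batch_size": 100})
assert pred_default == pred_custom
```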