Skip to content

Commit

Permalink
Doc/data leak (#2020)
Browse files Browse the repository at this point in the history
* changes

* deleating logs

* update notebooks

* add progress bar callback

* update first notebook with new progress bar

* update remaining notebooks

---------

Co-authored-by: simsalabim1 <[email protected]>
Co-authored-by: dennisbader <[email protected]>
  • Loading branch information
3 people authored Oct 13, 2023
1 parent 2f7c9db commit f3bdbcf
Show file tree
Hide file tree
Showing 8 changed files with 1,172 additions and 511 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
[Full Changelog](https://github.com/unit8co/darts/compare/0.26.0...master)

### For users of the library:
- Improvements to `TorchForecastingModel`:
- Added callback `darts.utils.callbacks.TFMProgressBar` to customize at which model stages to display the progress bar. [#2020](https://github.com/unit8co/darts/pull/2020) by [Dennis Bader](https://github.com/dennisbader).
- Improvements to documentation:
- Adapted the example notebooks to properly apply data transformers and avoid look-ahead bias. [#2020](https://github.com/unit8co/darts/pull/2020) by [Samriddhi Singh](https://github.com/SimTheGreat).
### For developers of the library:

## [0.26.0](https://github.com/unit8co/darts/tree/0.26.0) (2023-09-16)
Expand Down
22 changes: 16 additions & 6 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ProgressBar
from torch import Tensor
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -510,7 +511,11 @@ def _setup_trainer(
return trainer

trainer_params = {key: val for key, val in self.trainer_params.items()}
if verbose is not None:
has_progress_bar = any(
[isinstance(cb, ProgressBar) for cb in trainer_params.get("callbacks", [])]
)
# we ignore `verbose` if `trainer` has a progress bar, to avoid errors from lightning
if verbose is not None and not has_progress_bar:
trainer_params["enable_model_summary"] = (
verbose if model.epochs_trained == 0 else False
)
Expand Down Expand Up @@ -653,7 +658,8 @@ def fit(
Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom ``trainer`` will
override Darts' default trainer.
verbose
Optionally, whether to print progress.
Optionally, whether to print the progress. Ignored if there is a `ProgressBar` callback in
`pl_trainer_kwargs`.
epochs
If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs``
was provided to the model constructor.
Expand Down Expand Up @@ -883,7 +889,8 @@ def fit_from_dataset(
Optionally, a custom PyTorch-Lightning Trainer object to perform prediction. Using a custom `trainer` will
override Darts' default trainer.
verbose
Optionally, whether to print progress.
Optionally, whether to print the progress. Ignored if there is a `ProgressBar` callback in
`pl_trainer_kwargs`.
epochs
If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs``
was provided to the model constructor.
Expand Down Expand Up @@ -1125,7 +1132,8 @@ def lr_find(
Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom ``trainer`` will
override Darts' default trainer.
verbose
Optionally, whether to print progress.
Optionally, whether to print the progress. Ignored if there is a `ProgressBar` callback in
`pl_trainer_kwargs`.
epochs
If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs``
was provided to the model constructor.
Expand Down Expand Up @@ -1249,7 +1257,8 @@ def predict(
batch_size
Size of batches during prediction. Defaults to the models' training ``batch_size`` value.
verbose
Optionally, whether to print progress.
Optionally, whether to print the progress. Ignored if there is a `ProgressBar` callback in
`pl_trainer_kwargs`.
n_jobs
The number of jobs to run in parallel. ``-1`` means using all processors. Defaults to ``1``.
roll_size
Expand Down Expand Up @@ -1386,7 +1395,8 @@ def predict_from_dataset(
batch_size
Size of batches during prediction. Defaults to the models ``batch_size`` value.
verbose
Optionally, whether to print progress.
Optionally, whether to print the progress. Ignored if there is a `ProgressBar` callback in
`pl_trainer_kwargs`.
n_jobs
The number of jobs to run in parallel. ``-1`` means using all processors. Defaults to ``1``.
roll_size
Expand Down
102 changes: 102 additions & 0 deletions darts/utils/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import sys

from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm


class TFMProgressBar(TQDMProgressBar):
def __init__(
self,
enable_sanity_check_bar: bool = True,
enable_train_bar: bool = True,
enable_validation_bar: bool = True,
enable_prediction_bar: bool = True,
enable_train_bar_only: bool = False,
**kwargs
):
"""Darts' Progress Bar for `TorchForecastingModels`.
Allows to customize for which model stages (sanity checks, training, validation, prediction) to display a
progress bar.
This class is a PyTorch Lightning `Callback` and can be passed to the `TorchForecastingModel` constructor
through the `pl_trainer_kwargs` parameter.
Examples
--------
>>> from darts.models import NBEATSModel
>>> from darts.utils.callbacks import TFMProgressBar
>>> # only display the training bar and not the validation, prediction, and sanity check bars
>>> prog_bar = TFMProgressBar(enable_train_bar_only=True)
>>> model = NBEATSModel(1, 1, pl_trainer_kwargs={"callbacks": [prog_bar]})
Parameters
----------
enable_sanity_check_bar
Whether to enable to progress bar for sanity checks.
enable_train_bar
Whether to enable to progress bar for training.
enable_validation_bar
Whether to enable to progress bar for validation.
enable_prediction_bar
Whether to enable to progress bar for prediction.
enable_train_bar_only
Whether to disable all progress bars except the bar for training.
**kwargs
Arguments passed to the PyTorch Lightning's `TQDMProgressBar
<https://scikit-learn.org/stable/glossary.html#term-random_state>`_.
"""
super().__init__(**kwargs)
self.enable_sanity_check_bar = enable_sanity_check_bar
self.enable_train_bar = enable_train_bar
self.enable_validation_bar = enable_validation_bar
self.enable_prediction_bar = enable_prediction_bar
self.enable_train_bar_only = enable_train_bar_only

def init_sanity_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for the validation sanity run."""
return Tqdm(
desc=self.sanity_check_description,
position=(2 * self.process_position),
disable=not self.enable_sanity_check_bar or self.enable_train_bar_only,
leave=False,
dynamic_ncols=True,
file=sys.stdout,
)

def init_predict_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for predicting."""
return Tqdm(
desc=self.predict_description,
position=(2 * self.process_position),
disable=not self.enable_prediction_bar or self.enable_train_bar_only,
leave=True,
dynamic_ncols=True,
file=sys.stdout,
smoothing=0,
)

def init_train_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for training."""
return Tqdm(
desc=self.train_description,
position=(2 * self.process_position),
disable=not self.enable_train_bar,
leave=True,
dynamic_ncols=True,
file=sys.stdout,
smoothing=0,
)

def init_validation_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for validation."""
# The train progress bar doesn't exist in `trainer.validate()`
has_main_bar = self.trainer.state.fn != "validate"
return Tqdm(
desc=self.validation_description,
position=(2 * self.process_position + has_main_bar),
disable=not self.enable_validation_bar or self.enable_train_bar_only,
leave=not has_main_bar,
dynamic_ncols=True,
file=sys.stdout,
)
422 changes: 291 additions & 131 deletions examples/01-multi-time-series-and-covariates.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit f3bdbcf

Please sign in to comment.