Feat/sample weight torch (#2410)
* integrate sample weights into torch datasets part 1

* extract sample weights per sample

* compute loss with sample weights

* add sample weight support to likelihoods

* add sample weights to horizon based ds

* dynamically adapt loss function to work with or without sample weights

* add support for val sample weight

* clean up datasets

* fix dataset tests

* add torch sample weight tests

* probabilistic torch sample weight tests

* add torch dataset tests

* update changelog

* fix failing tests

* update changelog

* apply suggestions from PR review

* refactor common weights logic
dennisbader authored Jun 17, 2024
1 parent 6835c36 commit b532a80
Showing 19 changed files with 1,395 additions and 338 deletions.
6 changes: 3 additions & 3 deletions CHANGELOG.md
@@ -9,8 +9,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

### For users of the library:
**Improved**
- 🚀🚀 Improvements to `GlobalForecastingModel`: [#2404](https://github.com/unit8co/darts/pull/2404) and [#2410](https://github.com/unit8co/darts/pull/2410) by [Anton Ragot](https://github.com/AntonRagot) and [Dennis Bader](https://github.com/dennisbader).
- Added parameters `sample_weight` and `val_sample_weight` to `fit()` to apply weights to each observation, output step, and target component in the training and evaluation sets. Supported by both deterministic and probabilistic models. The weights can either be `TimeSeries` themselves, or the built-in weight generators `"linear"` and `"exponential"` decay. When given as `TimeSeries`, they are handled identically to covariates (e.g. pass multiple weight series with multiple target series; extraction of the relevant time frame is handled automatically for you).
- Improvements to the Anomaly Detection Module through a major refactor. The refactor brings significant performance optimizations for most processes, along with improvements to the API, consistency, reliability, and documentation. Some of these necessary changes come at the cost of breaking changes: [#1477](https://github.com/unit8co/darts/pull/1477) by [Dennis Bader](https://github.com/dennisbader), [Samuele Giuliano Piazzetta](https://github.com/piaz97), [Antoine Madrona](https://github.com/madtoinou), [Julien Herzen](https://github.com/hrzn), [Julien Adda](https://github.com/julien12234).
- 🚀🚀 Added an example notebook that showcases how to use Darts for Time Series Anomaly Detection.
- Added a new dataset for anomaly detection with the number of taxi passengers in New York from 2014 to 2015.
- `FittableWindowScorer` (KMeans, PyOD, and Wasserstein scorers) now accepts any of Darts' "per time step" metrics as difference function `diff_fn`.
- `ForecastingAnomalyModel` is now much faster thanks to optimized historical forecasts that generate the prediction input for the scorers. We also added more control over the historical forecast generation through additional parameters in all model methods.
@@ -35,8 +37,6 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- renamed params `actual_anomalies` to `anomalies`, and `binary_pred_anomalies` to `pred_anomalies`
- `darts.ad.utils.show_anomalies_from_scores`:
- renamed params `series` to `actual_series`, `actual_anomalies` to `anomalies`, `model_output` to `pred_series`, and `anomaly_scores` to `pred_scores`
- Improvements to `RegressionModel` : [#2404](https://github.com/unit8co/darts/pull/2404) by [Anton Ragot](https://github.com/AntonRagot) and [Dennis Bader](https://github.com/dennisbader).
- Added parameters `sample_weight` and `val_sample_weight` to `fit()` to apply weights to each observation with the corresponding output step, and target component in the training and evaluation set.
- Improvements to `TimeSeries` : [#1477](https://github.com/unit8co/darts/pull/1477) by [Dennis Bader](https://github.com/dennisbader).
- New method `with_times_and_values()`, which returns a new series with a new time index and new values, but with the same columns and metadata (static covariates, hierarchy) as the series it is called on.
- New method `slice_intersect_times()`, which returns the sliced time index of a series, where the index has been intersected with another series.
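The built-in `"linear"` and `"exponential"` weight generators mentioned in the changelog entry above can be sketched in plain Python. This is an illustrative assumption about their semantics (more recent observations receive larger weights), not the exact Darts implementation:

```python
def linear_weights(n: int) -> list:
    # linearly increasing weights from 1/n to 1; the newest observation weighs most
    return [(i + 1) / n for i in range(n)]


def exponential_weights(n: int, base: float = 2.0) -> list:
    # exponentially decaying weights going back in time; the newest observation has weight 1
    return [base ** (i - (n - 1)) for i in range(n)]


print(linear_weights(4))       # [0.25, 0.5, 0.75, 1.0]
print(exponential_weights(4))  # [0.125, 0.25, 0.5, 1.0]
```

In both schemes the weights are later broadcast across output steps and target components, matching the per-observation weighting the changelog describes.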
2 changes: 2 additions & 0 deletions darts/models/forecasting/global_baseline_models.py
@@ -253,6 +253,7 @@ def _build_train_dataset(
target: Sequence[TimeSeries],
past_covariates: Optional[Sequence[TimeSeries]],
future_covariates: Optional[Sequence[TimeSeries]],
sample_weight: Optional[Sequence[TimeSeries]],
max_samples_per_ts: Optional[int],
) -> MixedCovariatesTrainingDataset:
return MixedCovariatesSequentialDataset(
@@ -264,6 +265,7 @@ def _build_train_dataset(
output_chunk_shift=self.output_chunk_shift,
max_samples_per_ts=max_samples_per_ts,
use_static_covariates=self.uses_static_covariates,
sample_weight=sample_weight,
)


44 changes: 33 additions & 11 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -2,6 +2,7 @@
This file contains abstract classes for deterministic and probabilistic PyTorch Lightning Modules
"""

import copy
from abc import ABC, abstractmethod
from functools import wraps
from typing import Any, Dict, Optional, Sequence, Tuple, Union
@@ -161,6 +162,13 @@ def __init__(

# define the loss function
self.criterion = loss_fn
self.train_criterion = copy.deepcopy(loss_fn)
self.val_criterion = copy.deepcopy(loss_fn)
        # reduction will be set to `"none"` when calling `TFM.fit()` with sample weights;
        # the original reduction is restored in method `on_fit_end()`
self.train_criterion_reduction: Optional[str] = None
self.val_criterion_reduction: Optional[str] = None

# by default models are deterministic (i.e. not probabilistic)
self.likelihood = likelihood

@@ -212,11 +220,11 @@ def forward(self, *args, **kwargs) -> Any:

def training_step(self, train_batch, batch_idx) -> torch.Tensor:
"""performs the training step"""
output = self._produce_train_output(train_batch[:-1])
target = train_batch[
-1
] # By convention target is always the last element returned by datasets
loss = self._compute_loss(output, target)
# by convention, the last two elements are sample weights and future target
output = self._produce_train_output(train_batch[:-2])
sample_weight = train_batch[-2]
target = train_batch[-1]
loss = self._compute_loss(output, target, self.train_criterion, sample_weight)
self.log(
"train_loss",
loss,
@@ -229,9 +237,11 @@ def training_step(self, train_batch, batch_idx) -> torch.Tensor:

def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
"""performs the validation step"""
output = self._produce_train_output(val_batch[:-1])
# the last two elements are sample weights and future target
output = self._produce_train_output(val_batch[:-2])
sample_weight = val_batch[-2]
target = val_batch[-1]
loss = self._compute_loss(output, target)
loss = self._compute_loss(output, target, self.val_criterion, sample_weight)
self.log(
"val_loss",
loss,
@@ -242,6 +252,15 @@ def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
self._update_metrics(output, target, self.val_metrics)
return loss

def on_fit_end(self) -> None:
# revert the loss function reduction change when sample weights were used
if self.train_criterion_reduction is not None:
self.train_criterion.reduction = self.train_criterion_reduction
self.train_criterion_reduction = None
if self.val_criterion_reduction is not None:
self.val_criterion.reduction = self.val_criterion_reduction
self.val_criterion_reduction = None

def on_train_epoch_end(self):
self._compute_metrics(self.train_metrics)

@@ -364,14 +383,17 @@ def set_predict_parameters(
self.predict_likelihood_parameters = predict_likelihood_parameters
self.pred_mc_dropout = mc_dropout

def _compute_loss(self, output, target):
def _compute_loss(self, output, target, criterion, sample_weight):
# output is of shape (batch_size, n_timesteps, n_components, n_params)
if self.likelihood:
return self.likelihood.compute_loss(output, target)
loss = self.likelihood.compute_loss(output, target, sample_weight)
else:
# If there's no likelihood, nr_params=1, and we need to squeeze out the
# last dimension of model output, for properly computing the loss.
return self.criterion(output.squeeze(dim=-1), target)
loss = criterion(output.squeeze(dim=-1), target)
if sample_weight is not None:
loss = (loss * sample_weight).mean()
return loss

def _update_metrics(self, output, target, metrics):
if not len(metrics):
@@ -511,7 +533,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint["train_sample_shape"] = self.train_sample_shape
# we must save the loss to properly restore it when resuming training
checkpoint["loss_fn"] = self.criterion
# we must save the metrics to continue outputing them when resuming training
# we must save the metrics to continue logging them when resuming training
checkpoint["torch_metrics_train"] = self.train_metrics
checkpoint["torch_metrics_val"] = self.val_metrics

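The core mechanic in `_compute_loss` above is to let the criterion return per-element losses (`reduction="none"`), scale them by the sample weights, and only then take the mean. A minimal plain-Python sketch of that flow, with a hypothetical `weighted_mse` standing in for a `torch.nn` criterion:

```python
def weighted_mse(output, target, sample_weight=None):
    # element-wise squared errors, as a criterion with reduction="none" would return
    losses = [(o - t) ** 2 for o, t in zip(output, target)]
    if sample_weight is not None:
        # mirrors `loss = (loss * sample_weight).mean()` in the module above
        losses = [l * w for l, w in zip(losses, sample_weight)]
    return sum(losses) / len(losses)


print(weighted_mse([1.0, 2.0], [0.0, 0.0]))              # 2.5
print(weighted_mse([1.0, 2.0], [0.0, 0.0], [1.0, 0.5]))  # 1.5
```

With `sample_weight=None` this collapses to a plain mean, which is why the same code path serves fits with and without weights.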
2 changes: 2 additions & 0 deletions darts/models/forecasting/rnn_model.py
@@ -563,6 +563,7 @@ def _build_train_dataset(
target: Sequence[TimeSeries],
past_covariates: Optional[Sequence[TimeSeries]],
future_covariates: Optional[Sequence[TimeSeries]],
sample_weight: Optional[Sequence[TimeSeries]],
max_samples_per_ts: Optional[int],
) -> DualCovariatesShiftedDataset:
return DualCovariatesShiftedDataset(
@@ -572,6 +573,7 @@ def _build_train_dataset(
shift=1,
max_samples_per_ts=max_samples_per_ts,
use_static_covariates=self.uses_static_covariates,
sample_weight=sample_weight,
)

def _verify_train_dataset_type(self, train_dataset: TrainingDataset):
2 changes: 2 additions & 0 deletions darts/models/forecasting/tcn_model.py
@@ -535,6 +535,7 @@ def _build_train_dataset(
target: Sequence[TimeSeries],
past_covariates: Optional[Sequence[TimeSeries]],
future_covariates: Optional[Sequence[TimeSeries]],
sample_weight: Optional[Sequence[TimeSeries]],
max_samples_per_ts: Optional[int],
) -> PastCovariatesShiftedDataset:
return PastCovariatesShiftedDataset(
@@ -544,4 +545,5 @@
shift=self.output_chunk_length + self.output_chunk_shift,
max_samples_per_ts=max_samples_per_ts,
use_static_covariates=self.uses_static_covariates,
sample_weight=sample_weight,
)
2 changes: 2 additions & 0 deletions darts/models/forecasting/tft_model.py
@@ -1159,6 +1159,7 @@ def _build_train_dataset(
target: Sequence[TimeSeries],
past_covariates: Optional[Sequence[TimeSeries]],
future_covariates: Optional[Sequence[TimeSeries]],
sample_weight: Optional[Sequence[TimeSeries]],
max_samples_per_ts: Optional[int],
) -> MixedCovariatesSequentialDataset:
raise_if(
@@ -1179,6 +1180,7 @@ output_chunk_shift=self.output_chunk_shift,
output_chunk_shift=self.output_chunk_shift,
max_samples_per_ts=max_samples_per_ts,
use_static_covariates=self.uses_static_covariates,
sample_weight=sample_weight,
)

def _verify_train_dataset_type(self, train_dataset: TrainingDataset):
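All the training datasets touched by this commit emit samples with the same tuple convention that `training_step` and `validation_step` rely on: the sample weight is the second-to-last element and the future target is the last. A hedged sketch of that unpacking (names and stand-in values are illustrative, not the Darts dataset API):

```python
def unpack_batch(batch):
    """Split a training batch into model inputs, sample weight, and target.

    By convention, the last two elements are (sample_weight, future_target);
    everything before them is fed to the model.
    """
    model_inputs = batch[:-2]
    sample_weight = batch[-2]
    target = batch[-1]
    return model_inputs, sample_weight, target


# stand-ins for the tensors a dataset would emit
batch = ("past_target", "past_cov", "future_cov", [1.0, 0.5], [3.0, 4.0])
inputs, weight, target = unpack_batch(batch)
print(len(inputs), weight, target)  # 3 [1.0, 0.5] [3.0, 4.0]
```

Keeping the weight at a fixed position lets the same `training_step` handle every dataset type without inspecting which covariates are present.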