Add optional inverse transform in historical forecast (#2267)
* Add optional inverse transform in historical forecast

* Update variable names and docstrings

* Move the inverse transform to InvertibleDataTransformer

* Fix single element list

* Update docstrings

* Move the inverse transform of list of lists to inverse_transform method

* make invertible transformers act on list of lists of series

* add tests

* update changelog

---------

Co-authored-by: dennisbader <[email protected]>
alicjakrzeminska and dennisbader authored Mar 16, 2024
1 parent d764bc4 commit 91c7087
Showing 6 changed files with 345 additions and 86 deletions.
6 changes: 4 additions & 2 deletions CHANGELOG.md
@@ -9,8 +9,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co…

### For users of the library:
**Improved**
- Improvements to `ForecastingModel`:
- Renamed the private `_is_probabilistic` property to a public `supports_probabilistic_prediction`. [#2269](https://github.com/unit8co/darts/pull/2269) by [Felix Divo](https://github.com/felixdivo).
- Improvements to `ForecastingModel`: [#2269](https://github.com/unit8co/darts/pull/2269) by [Felix Divo](https://github.com/felixdivo).
- Renamed the private `_is_probabilistic` property to a public `supports_probabilistic_prediction`.
- Improvements to `DataTransformer`: [#2267](https://github.com/unit8co/darts/pull/2267) by [Alicja Krzeminska-Sciga](https://github.com/alicjakrzeminska).
- `InvertibleDataTransformer` now supports parallelized inverse transformation for `series` being a list of lists of `TimeSeries` (`Sequence[Sequence[TimeSeries]]`). This `series` type represents for example the output from `historical_forecasts()` when using multiple series.

**Fixed**
- Fixed type hint warning "Unexpected argument" when calling `historical_forecasts()` caused by the `_with_sanity_checks` decorator. The type hinting is now properly configured to expect any input arguments and return the output type of the method for which the sanity checks are performed. [#2286](https://github.com/unit8co/darts/pull/2286) by [Dennis Bader](https://github.com/dennisbader).
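A minimal sketch of the new usage described in the changelog entry above, assuming a darts version that includes this change; the model, generated series, and parameter values are illustrative choices, not part of the commit:

```python
from darts.dataprocessing.transformers import Scaler
from darts.models import LinearRegressionModel
from darts.utils.timeseries_generation import linear_timeseries, sine_timeseries

# Two independent series; the Scaler keeps one fitted state per series.
series = [linear_timeseries(length=100), sine_timeseries(length=100)]

scaler = Scaler()
series_scaled = scaler.fit_transform(series)

model = LinearRegressionModel(lags=12)
model.fit(series_scaled)

# With multiple input series and `last_points_only=False`, the output is a
# list of lists of TimeSeries: one inner list of forecasts per input series.
hist_fc = model.historical_forecasts(
    series_scaled,
    start=0.5,
    forecast_horizon=3,
    last_points_only=False,
)

# New in this PR: the nested structure can be passed directly to
# `inverse_transform()`, which maps each inner list back through the scaler
# state fitted on the corresponding input series.
hist_fc_unscaled = scaler.inverse_transform(hist_fc)
```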
118 changes: 75 additions & 43 deletions darts/dataprocessing/transformers/base_data_transformer.py
@@ -4,13 +4,13 @@
"""

from abc import ABC, abstractmethod
from typing import Any, Generator, List, Mapping, Optional, Sequence, Union
from typing import Any, Generator, Iterable, List, Mapping, Optional, Sequence, Union

import numpy as np
import xarray as xr

from darts import TimeSeries
from darts.logging import get_logger, raise_if, raise_if_not
from darts.logging import get_logger, raise_log
from darts.utils import _build_tqdm_iterator, _parallel_apply

logger = get_logger(__name__)
Expand Down Expand Up @@ -168,7 +168,8 @@ def set_verbose(self, value: bool):
value
New verbosity status
"""
raise_if_not(isinstance(value, bool), "Verbosity status must be a boolean.")
if not isinstance(value, bool):
raise_log(ValueError("Verbosity status must be a boolean."), logger=logger)

self._verbose = value

@@ -180,8 +181,8 @@ def set_n_jobs(self, value: int):
value
New n_jobs value. Set to `-1` for using all the available cores.
"""

raise_if_not(isinstance(value, int), "n_jobs must be an integer")
if not isinstance(value, int):
raise_log(ValueError("n_jobs must be an integer"), logger=logger)
self._n_jobs = value

@staticmethod
@@ -314,9 +315,11 @@ def transform(
if isinstance(series, TimeSeries):
input_series = [series]
data = [series]
transformer_selector = [0]
else:
input_series = series
data = series
transformer_selector = range(len(series))

if self._mask_components:
data = [
@@ -327,7 +330,7 @@ def transform(
kwargs["component_mask"] = component_mask

input_iterator = _build_tqdm_iterator(
zip(data, self._get_params(n_timeseries=len(data))),
zip(data, self._get_params(transformer_selector=transformer_selector)),
verbose=self._verbose,
desc=desc,
total=len(data),
@@ -350,7 +353,7 @@
)

def _get_params(
self, n_timeseries: int
self, transformer_selector: Iterable
) -> Generator[Mapping[str, Any], None, None]:
"""
Creates generator of dictionaries containing fixed parameter values
@@ -359,11 +362,11 @@ def _get_params(
parallel jobs. Called by `transform` and `inverse_transform`,
if `Transformer` does *not* inherit from `FittableTransformer`.
"""
self._check_fixed_params(n_timeseries)
self._check_fixed_params(transformer_selector)

def params_generator(n_timeseries, fixed_params, parallel_params):
def params_generator(transformer_selector, fixed_params, parallel_params):
fixed_params_copy = fixed_params.copy()
for i in range(n_timeseries):
for i in transformer_selector:
for key in parallel_params:
fixed_params_copy[key] = fixed_params[key][i]
if fixed_params_copy:
@@ -373,21 +376,35 @@ def params_generator(n_timeseries, fixed_params, parallel_params):
yield params
return None

return params_generator(n_timeseries, self._fixed_params, self._parallel_params)
return params_generator(
transformer_selector, self._fixed_params, self._parallel_params
)

def _check_fixed_params(self, n_timeseries: int) -> None:
def _check_fixed_params(self, transformer_selector: Iterable) -> None:
"""
Raises `ValueError` if `self._parallel_params` specifies a `key` in
`self._fixed_params` that should be distributed, but
`len(self._fixed_params[key])` does not equal `n_timeseries`.
`len(self._fixed_params[key])` does not equal the number of time series
(the maximum value in `transformer_selector` plus one).
"""
for key in self._parallel_params:
raise_if(
n_timeseries > len(self._fixed_params[key]),
f"{n_timeseries} TimeSeries were provided "
f"but only {len(self._fixed_params[key])} {key} values "
f"were specified upon initialising {self.name}.",
)
n_timeseries_ = max(transformer_selector) + 1
if n_timeseries_ > len(self._fixed_params[key]):
raise_log(
ValueError(
f"{n_timeseries_} TimeSeries were provided "
f"but only {len(self._fixed_params[key])} {key} values "
f"were specified upon initialising {self.name}."
),
logger=logger,
)
elif n_timeseries_ < len(self._fixed_params[key]):
logger.warning(
f"Only {n_timeseries_} TimeSeries were provided "
f"which is lower than the number of {key} values "
f"(n={len(self._fixed_params[key])}) that were specified "
f"upon initialising {self.name}."
)
return None

@staticmethod
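The `transformer_selector` introduced above replaces the plain `n_timeseries` count: rather than always iterating over `0..n-1`, the parameter generator can be driven by an arbitrary collection of series indices, which is what allows nested (list-of-lists) inputs to be routed to the right per-series parameters. A self-contained toy version of that selection logic, simplified from the `params_generator` shown above (the real generator additionally wraps the values under a `"fixed"` key, and the parameter names below are made up):

```python
# Toy, simplified re-implementation of the selector-driven generator.
def params_generator(transformer_selector, fixed_params, parallel_params):
    fixed_params_copy = fixed_params.copy()
    for i in transformer_selector:
        # "Parallel" keys hold one value per series; pick the i-th one so
        # series i gets its own parameter. Other keys are shared.
        for key in parallel_params:
            fixed_params_copy[key] = fixed_params[key][i]
        yield dict(fixed_params_copy)


fixed_params = {"lmbda": [0.1, 0.5, 0.9], "method": "boxcox"}
parallel_params = {"lmbda"}

# Selecting series 0 and 2 yields their respective `lmbda` values.
print(list(params_generator([0, 2], fixed_params, parallel_params)))
# [{'lmbda': 0.1, 'method': 'boxcox'}, {'lmbda': 0.9, 'method': 'boxcox'}]
```

This also matches the new behaviour of `_check_fixed_params` above, which only warns (instead of raising) when fewer series than configured parameter values are provided.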
@@ -418,16 +435,22 @@ def apply_component_mask(
if component_mask is None:
masked = series.copy() if return_ts else series.all_values()
else:
raise_if_not(
isinstance(component_mask, np.ndarray) and component_mask.dtype == bool,
f"`component_mask` must be a boolean `np.ndarray`, not a {type(component_mask)}.",
logger,
)
raise_if_not(
series.width == len(component_mask),
"mismatch between number of components in `series` and length of `component_mask`",
logger,
)
if not (
isinstance(component_mask, np.ndarray) and component_mask.dtype == bool
):
raise_log(
ValueError(
f"`component_mask` must be a boolean `np.ndarray`, not a {type(component_mask)}."
),
logger=logger,
)
if not series.width == len(component_mask):
raise_log(
ValueError(
"mismatch between number of components in `series` and length of `component_mask`"
),
logger=logger,
)
masked = series.all_values(copy=False)[:, component_mask, :]
if return_ts:
# Remove masked components from coords:
@@ -469,16 +492,22 @@ def unapply_component_mask(
if component_mask is None:
unmasked = vals
else:
raise_if_not(
isinstance(component_mask, np.ndarray) and component_mask.dtype == bool,
"If `component_mask` is given, must be a boolean np.ndarray`",
logger,
)
raise_if_not(
series.width == len(component_mask),
"mismatch between number of components in `series` and length of `component_mask`",
logger,
)
if not (
isinstance(component_mask, np.ndarray) and component_mask.dtype == bool
):
raise_log(
ValueError(
"If `component_mask` is given, must be a boolean np.ndarray`"
),
logger=logger,
)
if not series.width == len(component_mask):
raise_log(
ValueError(
"mismatch between number of components in `series` and length of `component_mask`"
),
logger=logger,
)
unmasked = series.all_values()
if isinstance(vals, TimeSeries):
unmasked[:, component_mask, :] = vals.all_values()
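The `apply_component_mask` / `unapply_component_mask` helpers above back the `component_mask` argument accepted by the transformers: a boolean array with one entry per component that restricts the transformation to the selected components. A small usage sketch, assuming a two-component series built from the timeseries generation utilities (illustrative, not from the commit):

```python
import numpy as np

from darts.dataprocessing.transformers import Scaler
from darts.utils.timeseries_generation import linear_timeseries

# A two-component series; only the first component should be scaled.
series = linear_timeseries(length=50, column_name="a").stack(
    linear_timeseries(length=50, start_value=10, end_value=20, column_name="b")
)

# The mask must be a boolean np.ndarray with one entry per component,
# otherwise the checks shown above raise a ValueError.
mask = np.array([True, False])

scaler = Scaler()
scaled = scaler.fit_transform(series, component_mask=mask)
restored = scaler.inverse_transform(scaled, component_mask=mask)
```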
@@ -560,10 +589,13 @@ def unstack_samples(
if series is not None:
n_samples = series.n_samples
else:
raise_if(
all(x is None for x in [n_timesteps, n_samples]),
"Must specify either `n_timesteps`, `n_samples`, or `series`.",
)
if all(x is None for x in [n_timesteps, n_samples]):
raise_log(
ValueError(
"Must specify either `n_timesteps`, `n_samples`, or `series`."
),
logger=logger,
)
n_components = vals.shape[-1]
if n_timesteps is not None:
reshaped_vals = vals.reshape(n_timesteps, -1, n_components)
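The `unstack_samples` change above only touches the argument validation, but the surrounding reshape is what the helper is about: recovering a 3-D array from a 2-D array of stacked sample values. A tiny numpy illustration of the reshape visible in the snippet (the sizes are made up, and darts may additionally reorder axes to match `TimeSeries.all_values()`):

```python
import numpy as np

# Hypothetical sizes: 4 timesteps, 3 samples, 2 components. The stacked
# representation has one row per (timestep, sample) pair.
n_timesteps, n_samples, n_components = 4, 3, 2
stacked = np.arange(n_timesteps * n_samples * n_components, dtype=float).reshape(
    n_timesteps * n_samples, n_components
)

# Mirrors the `vals.reshape(n_timesteps, -1, n_components)` branch above:
# the sample dimension is recovered implicitly from the row count.
unstacked = stacked.reshape(n_timesteps, -1, n_components)
assert unstacked.shape == (n_timesteps, n_samples, n_components)
```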
72 changes: 48 additions & 24 deletions darts/dataprocessing/transformers/fittable_data_transformer.py
@@ -4,12 +4,12 @@
"""

from abc import abstractmethod
from typing import Any, Generator, List, Mapping, Optional, Sequence, Union
from typing import Any, Generator, Iterable, List, Mapping, Optional, Sequence, Union

import numpy as np

from darts import TimeSeries
from darts.logging import get_logger, raise_if, raise_if_not
from darts.logging import get_logger, raise_log
from darts.utils import _build_tqdm_iterator, _parallel_apply

from .base_data_transformer import BaseDataTransformer
@@ -256,8 +256,10 @@ def fit(

if isinstance(series, TimeSeries):
data = [series]
transformer_selector = [0]
else:
data = series
transformer_selector = range(len(series))

if self._mask_components:
data = [
@@ -267,7 +269,9 @@
else:
kwargs["component_mask"] = component_mask

params_iterator = self._get_params(n_timeseries=len(data), calling_fit=True)
params_iterator = self._get_params(
transformer_selector=transformer_selector, calling_fit=True
)
fit_iterator = (
zip(data, params_iterator)
if not self._global_fit
@@ -315,7 +319,7 @@ def fit_transform(
).transform(series, *args, component_mask=component_mask, **kwargs)

def _get_params(
self, n_timeseries: int, calling_fit: bool = False
self, transformer_selector: Iterable, calling_fit: bool = False
) -> Generator[Mapping[str, Any], None, None]:
"""
Overrides `_get_params` of `BaseDataTransformer`. Creates generator of dictionaries containing
@@ -327,14 +331,18 @@
`transform` and `inverse_transform`.
"""
# Call `_check_fixed_params` of `BaseDataTransformer`:
self._check_fixed_params(n_timeseries)
fitted_params = self._get_fitted_params(n_timeseries, calling_fit)
self._check_fixed_params(transformer_selector)
fitted_params = self._get_fitted_params(transformer_selector, calling_fit)

def params_generator(
n_jobs, fixed_params, fitted_params, parallel_params, global_fit
transformer_selector_,
fixed_params,
fitted_params,
parallel_params,
global_fit,
):
fixed_params_copy = fixed_params.copy()
for i in range(n_jobs):
for i in transformer_selector_:
for key in parallel_params:
fixed_params_copy[key] = fixed_params[key][i]
params = {}
@@ -348,37 +356,53 @@ def params_generator(
params = None
yield params

n_jobs = n_timeseries if not (calling_fit and self._global_fit) else 1
transformer_selector_ = (
transformer_selector if not (calling_fit and self._global_fit) else [0]
)

return params_generator(
n_jobs,
transformer_selector_,
self._fixed_params,
fitted_params,
self._parallel_params,
self._global_fit,
)

def _get_fitted_params(self, n_timeseries: int, calling_fit: bool) -> Sequence[Any]:
def _get_fitted_params(
self, transformer_selector: Iterable, calling_fit: bool
) -> Sequence[Any]:
"""
Returns `self._fitted_params` if `calling_fit = False`, otherwise returns an empty
tuple. If `calling_fit = False`, also checks that `self._fitted_params`, which is a
sequence of values, contains exactly `n_timeseries` values; if not, a `ValueError` is thrown.
sequence of values, contains one value for each index in `transformer_selector`; if not, a `ValueError` is thrown.
"""
if not calling_fit:
raise_if_not(
self._fit_called,
("Must call `fit` before calling `transform`/`inverse_transform`."),
)
if not self._fit_called:
raise_log(
ValueError(
"Must call `fit` before calling `transform`/`inverse_transform`."
),
logger=logger,
)
fitted_params = self._fitted_params
else:
fitted_params = tuple()
if not self._global_fit and fitted_params:
raise_if(
n_timeseries > len(fitted_params),
(
f"{n_timeseries} TimeSeries were provided "
f"but only {len(fitted_params)} TimeSeries "
f"were specified upon training {self.name}."
),
)
n_timeseries_ = max(transformer_selector) + 1
if n_timeseries_ > len(fitted_params):
raise_log(
ValueError(
f"{n_timeseries_} TimeSeries were provided "
f"but only {len(fitted_params)} TimeSeries "
f"were specified upon training {self.name}."
),
logger=logger,
)
elif n_timeseries_ < len(fitted_params):
logger.warning(
f"Only {n_timeseries_} TimeSeries (lists) were provided "
f"which is lower than the number of series (n={len(fitted_params)}) "
f"used to fit {self.name}. This can result in a mismatch between the "
f"series and the underlying transformers."
)
return fitted_params
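In `_get_params` above, the selector collapses to a single entry (`[0]`) when `calling_fit` and `self._global_fit` are both true: a globally fitted transformer keeps one shared set of fitted parameters no matter how many series it sees. A hedged sketch of what that means for a user-facing transformer, assuming `Scaler` exposes the `global_fit` flag of `FittableDataTransformer`:

```python
from darts.dataprocessing.transformers import Scaler
from darts.utils.timeseries_generation import linear_timeseries

series = [
    linear_timeseries(length=50, start_value=0, end_value=1),
    linear_timeseries(length=50, start_value=0, end_value=100),
]

# Per-series fit (default): one fitted scaler state per input series,
# so both series end up spanning the full [0, 1] range.
local_scaler = Scaler(global_fit=False)
local_scaled = local_scaler.fit_transform(series)

# Global fit: a single scaler state fitted on all series jointly, so the
# first series keeps its much smaller relative magnitude after scaling.
global_scaler = Scaler(global_fit=True)
global_scaled = global_scaler.fit_transform(series)
```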