From 91c7087e757c5ef3afc53b5fbaccda83dec6b71a Mon Sep 17 00:00:00 2001 From: Alicja Krzeminska-Sciga <110606089+alicjakrzeminska@users.noreply.github.com> Date: Sat, 16 Mar 2024 12:05:51 +0100 Subject: [PATCH] Add optional inverse transform in historical forecast (#2267) * Add optional inverse transform in historical forecast * Update variables names and docstrings * Move the inverse transform to InvertibleDataTransformer * Fix single element list * Update docstrings * Move the inverse transform of list of lists to inverse_transform method * make invertible transformers act on list of lists of series * add tests * update changelog --------- Co-authored-by: dennisbader --- CHANGELOG.md | 6 +- .../transformers/base_data_transformer.py | 118 +++++++++++------- .../transformers/fittable_data_transformer.py | 72 +++++++---- .../invertible_data_transformer.py | 61 ++++++--- .../test_invertible_data_transformer.py | 90 +++++++++++++ ...st_invertible_fittable_data_transformer.py | 84 +++++++++++++ 6 files changed, 345 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 752f90120f..8710d28754 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co ### For users of the library: **Improved** -- Improvements to `ForecastingModel`: - - Renamed the private `_is_probabilistic` property to a public `supports_probabilistic_prediction`. [#2269](https://github.com/unit8co/darts/pull/2269) by [Felix Divo](https://github.com/felixdivo). +- Improvements to `ForecastingModel`: [#2269](https://github.com/unit8co/darts/pull/2269) by [Felix Divo](https://github.com/felixdivo). + - Renamed the private `_is_probabilistic` property to a public `supports_probabilistic_prediction`. +- Improvements to `DataTransformer`: [#2267](https://github.com/unit8co/darts/pull/2267) by [Alicja Krzeminska-Sciga](https://github.com/alicjakrzeminska). 
+ - `InvertibleDataTransformer` now supports parallelized inverse transformation for `series` being a list of lists of `TimeSeries` (`Sequence[Sequence[TimeSeries]]`). This `series` type represents for example the output from `historical_forecasts()` when using multiple series. **Fixed** - Fixed type hint warning "Unexpected argument" when calling `historical_forecasts()` caused by the `_with_sanity_checks` decorator. The type hinting is now properly configured to expect any input arguments and return the output type of the method for which the sanity checks are performed for. [#2286](https://github.com/unit8co/darts/pull/2286) by [Dennis Bader](https://github.com/dennisbader). diff --git a/darts/dataprocessing/transformers/base_data_transformer.py b/darts/dataprocessing/transformers/base_data_transformer.py index 4ba79fbd81..1d4b8dfdf2 100644 --- a/darts/dataprocessing/transformers/base_data_transformer.py +++ b/darts/dataprocessing/transformers/base_data_transformer.py @@ -4,13 +4,13 @@ """ from abc import ABC, abstractmethod -from typing import Any, Generator, List, Mapping, Optional, Sequence, Union +from typing import Any, Generator, Iterable, List, Mapping, Optional, Sequence, Union import numpy as np import xarray as xr from darts import TimeSeries -from darts.logging import get_logger, raise_if, raise_if_not +from darts.logging import get_logger, raise_log from darts.utils import _build_tqdm_iterator, _parallel_apply logger = get_logger(__name__) @@ -168,7 +168,8 @@ def set_verbose(self, value: bool): value New verbosity status """ - raise_if_not(isinstance(value, bool), "Verbosity status must be a boolean.") + if not isinstance(value, bool): + raise_log(ValueError("Verbosity status must be a boolean."), logger=logger) self._verbose = value @@ -180,8 +181,8 @@ def set_n_jobs(self, value: int): value New n_jobs value. Set to `-1` for using all the available cores. 
""" - - raise_if_not(isinstance(value, int), "n_jobs must be an integer") + if not isinstance(value, int): + raise_log(ValueError("n_jobs must be an integer"), logger=logger) self._n_jobs = value @staticmethod @@ -314,9 +315,11 @@ def transform( if isinstance(series, TimeSeries): input_series = [series] data = [series] + transformer_selector = [0] else: input_series = series data = series + transformer_selector = range(len(series)) if self._mask_components: data = [ @@ -327,7 +330,7 @@ def transform( kwargs["component_mask"] = component_mask input_iterator = _build_tqdm_iterator( - zip(data, self._get_params(n_timeseries=len(data))), + zip(data, self._get_params(transformer_selector=transformer_selector)), verbose=self._verbose, desc=desc, total=len(data), @@ -350,7 +353,7 @@ def transform( ) def _get_params( - self, n_timeseries: int + self, transformer_selector: Iterable ) -> Generator[Mapping[str, Any], None, None]: """ Creates generator of dictionaries containing fixed parameter values @@ -359,11 +362,11 @@ def _get_params( parallel jobs. Called by `transform` and `inverse_transform`, if `Transformer` does *not* inherit from `FittableTransformer`. 
""" - self._check_fixed_params(n_timeseries) + self._check_fixed_params(transformer_selector) - def params_generator(n_timeseries, fixed_params, parallel_params): + def params_generator(transformer_selector, fixed_params, parallel_params): fixed_params_copy = fixed_params.copy() - for i in range(n_timeseries): + for i in transformer_selector: for key in parallel_params: fixed_params_copy[key] = fixed_params[key][i] if fixed_params_copy: @@ -373,21 +376,35 @@ def params_generator(n_timeseries, fixed_params, parallel_params): yield params return None - return params_generator(n_timeseries, self._fixed_params, self._parallel_params) + return params_generator( + transformer_selector, self._fixed_params, self._parallel_params + ) - def _check_fixed_params(self, n_timeseries: int) -> None: + def _check_fixed_params(self, transformer_selector: Iterable) -> None: """ Raises `ValueError` if `self._parallel_params` specifies a `key` in `self._fixed_params` that should be distributed, but - `len(self._fixed_params[key])` does not equal `n_timeseries`. + `len(self._fixed_params[key])` does not equal to the number of time series + (the maximum value + 1 from `transformer_selector`). """ for key in self._parallel_params: - raise_if( - n_timeseries > len(self._fixed_params[key]), - f"{n_timeseries} TimeSeries were provided " - f"but only {len(self._fixed_params[key])} {key} values " - f"were specified upon initialising {self.name}.", - ) + n_timeseries_ = max(transformer_selector) + 1 + if n_timeseries_ > len(self._fixed_params[key]): + raise_log( + ValueError( + f"{n_timeseries_} TimeSeries were provided " + f"but only {len(self._fixed_params[key])} {key} values " + f"were specified upon initialising {self.name}." 
+ ), + logger=logger, + ) + elif n_timeseries_ < len(self._fixed_params[key]): + logger.warning( + f"Only {n_timeseries_} TimeSeries were provided " + f"which is lower than the number of {key} values " + f"(n={len(self._fixed_params[key])}) that were specified " + f"upon initialising {self.name}." + ) return None @staticmethod @@ -418,16 +435,22 @@ def apply_component_mask( if component_mask is None: masked = series.copy() if return_ts else series.all_values() else: - raise_if_not( - isinstance(component_mask, np.ndarray) and component_mask.dtype == bool, - f"`component_mask` must be a boolean `np.ndarray`, not a {type(component_mask)}.", - logger, - ) - raise_if_not( - series.width == len(component_mask), - "mismatch between number of components in `series` and length of `component_mask`", - logger, - ) + if not ( + isinstance(component_mask, np.ndarray) and component_mask.dtype == bool + ): + raise_log( + ValueError( + f"`component_mask` must be a boolean `np.ndarray`, not a {type(component_mask)}." 
+ ), + logger=logger, + ) + if not series.width == len(component_mask): + raise_log( + ValueError( + "mismatch between number of components in `series` and length of `component_mask`" + ), + logger=logger, + ) masked = series.all_values(copy=False)[:, component_mask, :] if return_ts: # Remove masked components from coords: @@ -469,16 +492,22 @@ def unapply_component_mask( if component_mask is None: unmasked = vals else: - raise_if_not( - isinstance(component_mask, np.ndarray) and component_mask.dtype == bool, - "If `component_mask` is given, must be a boolean np.ndarray`", - logger, - ) - raise_if_not( - series.width == len(component_mask), - "mismatch between number of components in `series` and length of `component_mask`", - logger, - ) + if not ( + isinstance(component_mask, np.ndarray) and component_mask.dtype == bool + ): + raise_log( + ValueError( + "If `component_mask` is given, must be a boolean np.ndarray`" + ), + logger=logger, + ) + if not series.width == len(component_mask): + raise_log( + ValueError( + "mismatch between number of components in `series` and length of `component_mask`" + ), + logger=logger, + ) unmasked = series.all_values() if isinstance(vals, TimeSeries): unmasked[:, component_mask, :] = vals.all_values() @@ -560,10 +589,13 @@ def unstack_samples( if series is not None: n_samples = series.n_samples else: - raise_if( - all(x is None for x in [n_timesteps, n_samples]), - "Must specify either `n_timesteps`, `n_samples`, or `series`.", - ) + if all(x is None for x in [n_timesteps, n_samples]): + raise_log( + ValueError( + "Must specify either `n_timesteps`, `n_samples`, or `series`." 
+ ), + logger=logger, + ) n_components = vals.shape[-1] if n_timesteps is not None: reshaped_vals = vals.reshape(n_timesteps, -1, n_components) diff --git a/darts/dataprocessing/transformers/fittable_data_transformer.py b/darts/dataprocessing/transformers/fittable_data_transformer.py index e037d3ad40..654ef24338 100644 --- a/darts/dataprocessing/transformers/fittable_data_transformer.py +++ b/darts/dataprocessing/transformers/fittable_data_transformer.py @@ -4,12 +4,12 @@ """ from abc import abstractmethod -from typing import Any, Generator, List, Mapping, Optional, Sequence, Union +from typing import Any, Generator, Iterable, List, Mapping, Optional, Sequence, Union import numpy as np from darts import TimeSeries -from darts.logging import get_logger, raise_if, raise_if_not +from darts.logging import get_logger, raise_log from darts.utils import _build_tqdm_iterator, _parallel_apply from .base_data_transformer import BaseDataTransformer @@ -256,8 +256,10 @@ def fit( if isinstance(series, TimeSeries): data = [series] + transformer_selector = [0] else: data = series + transformer_selector = range(len(series)) if self._mask_components: data = [ @@ -267,7 +269,9 @@ def fit( else: kwargs["component_mask"] = component_mask - params_iterator = self._get_params(n_timeseries=len(data), calling_fit=True) + params_iterator = self._get_params( + transformer_selector=transformer_selector, calling_fit=True + ) fit_iterator = ( zip(data, params_iterator) if not self._global_fit @@ -315,7 +319,7 @@ def fit_transform( ).transform(series, *args, component_mask=component_mask, **kwargs) def _get_params( - self, n_timeseries: int, calling_fit: bool = False + self, transformer_selector: Iterable, calling_fit: bool = False ) -> Generator[Mapping[str, Any], None, None]: """ Overrides `_get_params` of `BaseDataTransformer`. Creates generator of dictionaries containing @@ -327,14 +331,18 @@ def _get_params( `transform` and `inverse_transform`. 
""" # Call `_check_fixed_params` of `BaseDataTransformer`: - self._check_fixed_params(n_timeseries) - fitted_params = self._get_fitted_params(n_timeseries, calling_fit) + self._check_fixed_params(transformer_selector) + fitted_params = self._get_fitted_params(transformer_selector, calling_fit) def params_generator( - n_jobs, fixed_params, fitted_params, parallel_params, global_fit + transformer_selector_, + fixed_params, + fitted_params, + parallel_params, + global_fit, ): fixed_params_copy = fixed_params.copy() - for i in range(n_jobs): + for i in transformer_selector_: for key in parallel_params: fixed_params_copy[key] = fixed_params[key][i] params = {} @@ -348,37 +356,53 @@ def params_generator( params = None yield params - n_jobs = n_timeseries if not (calling_fit and self._global_fit) else 1 + transformer_selector_ = ( + transformer_selector if not (calling_fit and self._global_fit) else [0] + ) return params_generator( - n_jobs, + transformer_selector_, self._fixed_params, fitted_params, self._parallel_params, self._global_fit, ) - def _get_fitted_params(self, n_timeseries: int, calling_fit: bool) -> Sequence[Any]: + def _get_fitted_params( + self, transformer_selector: Iterable, calling_fit: bool + ) -> Sequence[Any]: """ Returns `self._fitted_params` if `calling_fit = False`, otherwise returns an empty tuple. If `calling_fit = False`, also checks that `self._fitted_params`, which is a - sequence of values, contains exactly `n_timeseries` values; if not, a `ValueError` is thrown. + sequence of values, contains exactly `transformer_selector` values; if not, a `ValueError` is thrown. """ if not calling_fit: - raise_if_not( - self._fit_called, - ("Must call `fit` before calling `transform`/`inverse_transform`."), - ) + if not self._fit_called: + raise_log( + ValueError( + "Must call `fit` before calling `transform`/`inverse_transform`." 
+ ), + logger=logger, + ) fitted_params = self._fitted_params else: fitted_params = tuple() if not self._global_fit and fitted_params: - raise_if( - n_timeseries > len(fitted_params), - ( - f"{n_timeseries} TimeSeries were provided " - f"but only {len(fitted_params)} TimeSeries " - f"were specified upon training {self.name}." - ), - ) + n_timeseries_ = max(transformer_selector) + 1 + if n_timeseries_ > len(fitted_params): + raise_log( + ValueError( + f"{n_timeseries_} TimeSeries were provided " + f"but only {len(fitted_params)} TimeSeries " + f"were specified upon training {self.name}." + ), + logger=logger, + ) + elif n_timeseries_ < len(fitted_params): + logger.warning( + f"Only {n_timeseries_} TimeSeries (lists) were provided " + f"which is lower than the number of series (n={len(fitted_params)}) " + f"used to fit {self.name}. This can result in a mismatch between the " + f"series and the underlying transformers." + ) return fitted_params diff --git a/darts/dataprocessing/transformers/invertible_data_transformer.py b/darts/dataprocessing/transformers/invertible_data_transformer.py index fbd9e0e61a..ecf22b0261 100644 --- a/darts/dataprocessing/transformers/invertible_data_transformer.py +++ b/darts/dataprocessing/transformers/invertible_data_transformer.py @@ -9,7 +9,7 @@ import numpy as np from darts import TimeSeries -from darts.logging import get_logger, raise_if_not +from darts.logging import get_logger, raise_log from darts.utils import _build_tqdm_iterator, _parallel_apply from .base_data_transformer import BaseDataTransformer @@ -245,14 +245,14 @@ def ts_inverse_transform( def inverse_transform( self, - series: Union[TimeSeries, Sequence[TimeSeries]], + series: Union[TimeSeries, Sequence[TimeSeries], Sequence[Sequence[TimeSeries]]], *args, component_mask: Optional[np.array] = None, **kwargs, - ) -> Union[TimeSeries, List[TimeSeries]]: + ) -> Union[TimeSeries, List[TimeSeries], List[List[TimeSeries]]]: """Inverse transforms a (sequence of) series by calling 
the user-implemented `ts_inverse_transform` method. - In case a sequence is passed as input data, this function takes care of parallelising the + In case a sequence or list of lists is passed as input data, this function takes care of parallelising the transformation of multiple series in the sequence at the same time. Additionally, if the `mask_components` attribute was set to `True` when instantiating `InvertibleDataTransformer`, then any provided `component_mask`s will be automatically applied to each input `TimeSeries`; @@ -263,7 +263,14 @@ def inverse_transform( Parameters ---------- series - the (sequence of) series be inverse-transformed. + The series to inverse-transform. + If a single `TimeSeries`, returns a single series. + If a sequence of `TimeSeries`, returns a list of series. The series should be in the same order as the + sequence used to fit the transformer. + If a list of lists of `TimeSeries`, returns a list of lists of series. This can for example be the output + of `ForecastingModel.historical_forecasts()` when using multiple series. Each inner list should contain + `TimeSeries` related to the same series. The order of inner lists should be the same as the sequence used + to fit the transformer. args Additional positional arguments for the :func:`ts_inverse_transform()` method component_mask : Optional[np.ndarray] = None @@ -274,7 +281,7 @@ def inverse_transform( Returns ------- - Union[TimeSeries, List[TimeSeries]] + Union[TimeSeries, List[TimeSeries], List[List[TimeSeries]]] Inverse transformed data. Notes @@ -295,22 +302,35 @@ def inverse_transform( `component_masks` will be passed as a keyword argument `ts_inverse_transform`; the user can then manually specify how the `component_mask` should be applied to each series. 
""" - if hasattr(self, "_fit_called"): - raise_if_not( - self._fit_called, - "fit() must have been called before inverse_transform()", - logger, + if hasattr(self, "_fit_called") and not self._fit_called: + raise_log( + ValueError("fit() must have been called before inverse_transform()"), + logger=logger, ) desc = f"Inverse ({self._name})" # Take note of original input for unmasking purposes: + called_with_single_series = False + called_with_sequence_series = False if isinstance(series, TimeSeries): input_series = [series] data = [series] - else: + transformer_selector = [0] + called_with_single_series = True + elif isinstance(series[0], TimeSeries): # Sequence[TimeSeries] input_series = series data = series + transformer_selector = range(len(series)) + called_with_sequence_series = True + else: # Sequence[Sequence[TimeSeries]] + input_series = [] + data = [] + transformer_selector = [] + for idx, series_list in enumerate(series): + input_series.extend(series_list) + data.extend(series_list) + transformer_selector += [idx] * len(series_list) if self._mask_components: data = [ @@ -321,10 +341,10 @@ def inverse_transform( kwargs["component_mask"] = component_mask input_iterator = _build_tqdm_iterator( - zip(data, self._get_params(n_timeseries=len(data))), + zip(data, self._get_params(transformer_selector=transformer_selector)), verbose=self._verbose, desc=desc, - total=len(data), + total=len(transformer_selector), ) transformed_data = _parallel_apply( @@ -343,6 +363,13 @@ def inverse_transform( ) transformed_data = unmasked - return ( - transformed_data[0] if isinstance(series, TimeSeries) else transformed_data - ) + if called_with_single_series: + return transformed_data[0] + elif called_with_sequence_series: + return transformed_data + else: + cum_len = np.cumsum([0] + [len(s_) for s_ in series]) + return [ + transformed_data[cum_len[i] : cum_len[i + 1]] + for i in range(len(cum_len) - 1) + ] diff --git 
a/darts/tests/dataprocessing/transformers/test_invertible_data_transformer.py b/darts/tests/dataprocessing/transformers/test_invertible_data_transformer.py index 71163eb928..ca9a9f0a01 100644 --- a/darts/tests/dataprocessing/transformers/test_invertible_data_transformer.py +++ b/darts/tests/dataprocessing/transformers/test_invertible_data_transformer.py @@ -262,6 +262,96 @@ def test_input_transformed_multiple_series(self): assert inv_1 == test_input_1 assert inv_2 == test_input_2 + def test_input_transformed_list_of_lists_of_series(self): + """ + Tests that `inverse_transform` correctly regroups its output + when `series` is a list of lists of `TimeSeries`. Covers + inner lists of different lengths, empty inner lists, and + more inner lists than series used during `transform`; the + regrouped output must match the structure of the input + list of lists. + """ + test_input_1 = constant_timeseries(value=1, length=10) + test_input_2 = constant_timeseries(value=2, length=11) + + # Don't have different params for different jobs: + mock = self.DataTransformerMock(scale=2, translation=10, parallel_params=False) + (transformed_1, transformed_2) = mock.transform((test_input_1, test_input_2)) + # 2 * 1 + 10 = 12 + assert transformed_1 == constant_timeseries(value=12, length=10) + # 2 * 2 + 10 = 14 + assert transformed_2 == constant_timeseries(value=14, length=11) + + # list of lists of series must get input back + inv = mock.inverse_transform([[transformed_1], [transformed_2]]) + assert len(inv) == 2 + assert all( + isinstance(series_list, list) and len(series_list) == 1 + for series_list in inv + ) + assert all( + isinstance(series, TimeSeries) + for series_list in inv + for series in series_list + ) + assert inv[0][0] == test_input_1 + assert inv[1][0] == test_input_2 + + # one inner list is longer than the others, must get input back + inv = mock.inverse_transform([[transformed_1, transformed_1], [transformed_2]]) + assert len(inv) == 2 + assert len(inv[0]) == 2 and 
len(inv[1]) == 1 + assert all(isinstance(series_list, list) for series_list in inv) + assert all( + isinstance(series, TimeSeries) + for series_list in inv + for series in series_list + ) + assert inv[0][0] == test_input_1 + assert inv[0][1] == test_input_1 + assert inv[1][0] == test_input_2 + + # different types of Sequences, must get input back + inv = mock.inverse_transform(((transformed_1, transformed_1), (transformed_2,))) + assert len(inv) == 2 + assert len(inv[0]) == 2 and len(inv[1]) == 1 + assert all(isinstance(series_list, list) for series_list in inv) + assert all( + isinstance(series, TimeSeries) + for series_list in inv + for series in series_list + ) + assert inv[0][0] == test_input_1 + assert inv[0][1] == test_input_1 + assert inv[1][0] == test_input_2 + + # one list of lists is empty, returns empty list as well + inv = mock.inverse_transform([[], [transformed_2, transformed_2]]) + assert len(inv) == 2 + assert len(inv[0]) == 0 and len(inv[1]) == 2 + assert all(isinstance(series_list, list) for series_list in inv) + assert all(isinstance(series, TimeSeries) for series in inv[1]) + assert inv[1][0] == test_input_2 + assert inv[1][1] == test_input_2 + + # more list of lists than used during transform works + inv = mock.inverse_transform( + [[transformed_1], [transformed_2], [transformed_2]] + ) + assert len(inv) == 3 + assert all( + isinstance(series_list, list) and len(series_list) == 1 + for series_list in inv + ) + assert all( + isinstance(series, TimeSeries) + for series_list in inv + for series in series_list + ) + assert inv[0][0] == test_input_1 + assert inv[1][0] == test_input_2 + assert inv[2][0] == test_input_2 + def test_input_transformed_multiple_samples(self): """ Tests that `stack_samples` and `unstack_samples` correctly diff --git a/darts/tests/dataprocessing/transformers/test_invertible_fittable_data_transformer.py b/darts/tests/dataprocessing/transformers/test_invertible_fittable_data_transformer.py index b699dd47bb..cdae6fdb86 100644 
--- a/darts/tests/dataprocessing/transformers/test_invertible_fittable_data_transformer.py +++ b/darts/tests/dataprocessing/transformers/test_invertible_fittable_data_transformer.py @@ -1,6 +1,7 @@ from typing import Any, Mapping, Sequence, Union import numpy as np +import pytest from darts import TimeSeries from darts.dataprocessing.transformers.fittable_data_transformer import ( @@ -293,6 +294,89 @@ def test_input_transformed_multiple_series(self): assert inv_1 == test_input_1 assert inv_2 == test_input_2 + def test_input_transformed_list_of_lists_of_series(self): + """ + Tests that `inverse_transform` of a fitted transformer + correctly regroups its output when `series` is a list of + lists of `TimeSeries`. Covers inner lists of different + lengths, empty inner lists, and checks that providing + more inner lists than series used during fitting raises + an error. + """ + test_input_1 = constant_timeseries(value=1, length=10) + test_input_2 = constant_timeseries(value=2, length=11) + + # Don't have different params for different jobs: + mock = self.DataTransformerMock(scale=2, translation=10, parallel_params=False) + (transformed_1, transformed_2) = mock.fit_transform( + (test_input_1, test_input_2) + ) + # 2 * 1 + 10 = 12 + assert transformed_1 == constant_timeseries(value=12, length=10) + # 2 * 2 + 10 = 14 + assert transformed_2 == constant_timeseries(value=14, length=11) + + # list of lists of series must get input back + inv = mock.inverse_transform([[transformed_1], [transformed_2]]) + assert len(inv) == 2 + assert all( + isinstance(series_list, list) and len(series_list) == 1 + for series_list in inv + ) + assert all( + isinstance(series, TimeSeries) + for series_list in inv + for series in series_list + ) + assert inv[0][0] == test_input_1 + assert inv[1][0] == test_input_2 + + # one inner list is longer than the others, must get input back + inv = mock.inverse_transform([[transformed_1, transformed_1], [transformed_2]]) + assert len(inv) == 
2 + assert len(inv[0]) == 2 and len(inv[1]) == 1 + assert all(isinstance(series_list, list) for series_list in inv) + assert all( + isinstance(series, TimeSeries) + for series_list in inv + for series in series_list + ) + assert inv[0][0] == test_input_1 + assert inv[0][1] == test_input_1 + assert inv[1][0] == test_input_2 + + # different types of Sequences, must get input back + inv = mock.inverse_transform(((transformed_1, transformed_1), (transformed_2,))) + assert len(inv) == 2 + assert len(inv[0]) == 2 and len(inv[1]) == 1 + assert all(isinstance(series_list, list) for series_list in inv) + assert all( + isinstance(series, TimeSeries) + for series_list in inv + for series in series_list + ) + assert inv[0][0] == test_input_1 + assert inv[0][1] == test_input_1 + assert inv[1][0] == test_input_2 + + # one list of lists is empty, returns empty list as well + inv = mock.inverse_transform([[], [transformed_2, transformed_2]]) + assert len(inv) == 2 + assert len(inv[0]) == 0 and len(inv[1]) == 2 + assert all(isinstance(series_list, list) for series_list in inv) + assert all(isinstance(series, TimeSeries) for series in inv[1]) + assert inv[1][0] == test_input_2 + assert inv[1][1] == test_input_2 + + # more list of lists than used during transform, raises error + with pytest.raises(ValueError) as err: + _ = mock.inverse_transform( + [[transformed_1], [transformed_2], [transformed_2]] + ) + assert str(err.value).startswith( + "3 TimeSeries were provided but only 2 TimeSeries were specified" + ) + def test_input_transformed_multiple_samples(self): """ Tests that `stack_samples` and `unstack_samples` correctly
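Editor's note on the patch above: the core bookkeeping added to `InvertibleDataTransformer.inverse_transform` is to flatten the list of lists, tag each inner series with its group index via `transformer_selector` (so per-series fitted/fixed parameters line up), and then slice the flat results back into the original shape using cumulative lengths. This can be sketched in plain Python without darts; `inverse_transform_grouped` and the arithmetic inverse below are illustrative stand-ins for the patch's logic, not darts code:

```python
from itertools import accumulate
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def inverse_transform_grouped(
    grouped: Sequence[Sequence[T]], fn: Callable[[T, int], T]
) -> List[List[T]]:
    """Flatten `grouped`, apply `fn(item, group_idx)` to each item, then
    regroup the results to match the input's list-of-lists structure."""
    # `transformer_selector` maps each flattened item to its group index,
    # so per-group parameters can be matched to every inner series:
    data, transformer_selector = [], []
    for idx, group in enumerate(grouped):
        data.extend(group)
        transformer_selector += [idx] * len(group)
    transformed = [fn(x, i) for x, i in zip(data, transformer_selector)]
    # Slice the flat results back using cumulative group lengths,
    # mirroring the `cum_len` logic in the patch:
    cum_len = [0] + list(accumulate(len(g) for g in grouped))
    return [transformed[cum_len[i] : cum_len[i + 1]] for i in range(len(grouped))]


# Hypothetical inverse of "scale by 2, translate by 10" from the tests:
inv = inverse_transform_grouped([[12.0, 12.0], [14.0]], lambda x, i: (x - 10) / 2)
print(inv)  # [[1.0, 1.0], [2.0]]
```

Note that empty inner lists survive the round trip (an empty slice is produced for them), which is exactly the behaviour the new tests assert.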