Skip to content

Commit

Permalink
fix: update hierarchy for single transform window_transform (#2207)
Browse files Browse the repository at this point in the history
* fix: update hierarchy for single transform window_transform

* update changelog

* update changelog

* fix: using set to check overlap

* fix: corrected logic to update the hierarchy after window_transform

* fix: hierarchy can be conserved when applying non-overlapping transforms

* feat: add new argument, improve logic

* feat: adding tests

* fix: expected argument match docstring in resample()

* fix: addressing review comments

* fix: linting issue

* fix: linting

* linting

* update changelog and rename keep_old_names to keep_names

---------

Co-authored-by: dennisbader <[email protected]>
  • Loading branch information
madtoinou and dennisbader authored Feb 24, 2024
1 parent bf51476 commit 24ae0e1
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Added option to exclude some `group_cols` from being added as static covariates when using `TimeSeries.from_group_dataframe()` with parameter `drop_group_cols`.
- Improvements to `TorchForecastingModel`:
- Added support for additional lr scheduler configuration parameters for more control ("interval", "frequency", "monitor", "strict", "name"). [#2218](https://github.com/unit8co/darts/pull/2218) by [Dennis Bader](https://github.com/dennisbader).
- Improvements to `WindowTransformer` and `window_transform`:
- Added argument `keep_names` to indicate whether the original component names should be kept. [#2207](https://github.com/unit8co/darts/pull/2207) by [Antoine Madrona](https://github.com/madtoinou).

**Fixed**
- Fixed a bug when calling `window_transform` on a `TimeSeries` with a hierarchy. The hierarchy is now only preserved for single transformations applied to all components, or removed otherwise. [#2207](https://github.com/unit8co/darts/pull/2207) by [Antoine Madrona](https://github.com/madtoinou).
- Fixed a bug in probabilistic `LinearRegressionModel.fit()`, where the `model` attribute was not pointing to all underlying estimators. [#2205](https://github.com/unit8co/darts/pull/2205) by [Antoine Madrona](https://github.com/madtoinou).
- Raise an error in `RegressionEnsembleModel` when the `regression_model` was created with `multi_models=False` (not supported). [#2205](https://github.com/unit8co/darts/pull/2205) by [Antoine Madrona](https://github.com/madtoinou).
- Fixed a bug in `coefficient_of_variation()` with `intersect=True`, where the coefficient was not computed on the intersection. [#2202](https://github.com/unit8co/darts/pull/2202) by [Antoine Madrona](https://github.com/madtoinou).
Expand Down
8 changes: 7 additions & 1 deletion darts/dataprocessing/transformers/window_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
forecasting_safe: Optional[bool] = True,
keep_non_transformed: Optional[bool] = False,
include_current: Optional[bool] = True,
keep_names: Optional[bool] = False,
name: str = "WindowTransformer",
n_jobs: int = 1,
verbose: bool = False,
Expand Down Expand Up @@ -123,11 +124,15 @@ def __init__(
keep_non_transformed
``False`` to return the transformed components only, ``True`` to return all original components along
the transformed ones. Default is ``False``.
the transformed ones. Default is ``False``. If the series has a hierarchy, must be set to ``False``.
include_current
``True`` to include the current time step in the window, ``False`` to exclude it. Default is ``True``.
keep_names
Whether the transformed components should keep the original component names or not. Must be set to ``False``
if `keep_non_transformed = True` or if several transformations are applied to the same components.
name
A specific name for the transformer.
Expand All @@ -147,6 +152,7 @@ def __init__(
self.treat_na = treat_na
self.forecasting_safe = forecasting_safe
self.include_current = include_current
self.keep_names = keep_names
super().__init__(name, n_jobs, verbose)

@staticmethod
Expand Down
186 changes: 186 additions & 0 deletions darts/tests/dataprocessing/transformers/test_window_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,30 @@
from darts.dataprocessing.transformers import Mapper, WindowTransformer


def helper_generate_ts_hierarchy(length: int):
    """Build a constant three-component series ("A", "B", "C") with a hierarchy.

    The components hold the values 5, 3 and 2 respectively over ``length``
    steps, and the hierarchy maps both "B" and "C" to "A".
    """
    component_levels = (5, 3, 2)
    # one constant float column per component, stacked side by side
    stacked = np.stack(
        [np.ones(length) * level for level in component_levels],
        axis=1,
    )
    component_to_group = {"B": "A", "C": "A"}
    return TimeSeries.from_values(
        values=stacked, columns=["A", "B", "C"], hierarchy=component_to_group
    )


class TestTimeSeriesWindowTransform:

times = pd.date_range("20130101", "20130110")
Expand Down Expand Up @@ -128,6 +152,49 @@ def test_ts_windowtransf_input_dictionary(self):
} # forecating_safe=True vs center=True
self.series_univ_det.window_transform(transforms=window_transformations)

# keep_names and overlapping transforms
with pytest.raises(ValueError) as err:
window_transformations = [
{
"function": "mean",
"mode": "rolling",
"window": 3,
"components": self.series_multi_det.components[:1],
},
{
"function": "median",
"mode": "rolling",
"window": 3,
"components": self.series_multi_det.components,
},
]
self.series_multi_det.window_transform(
transforms=window_transformations, keep_names=True
)
assert str(err.value) == (
"Cannot keep the original component names as some transforms are overlapping "
"(applied to the same components). Set `keep_names` to `False`."
)

# keep_names and keep_non_transformed
with pytest.raises(ValueError) as err:
window_transformations = [
{
"function": "mean",
"mode": "rolling",
"window": 3,
"components": self.series_multi_det.components[:1],
},
]
self.series_multi_det.window_transform(
transforms=window_transformations,
keep_names=True,
keep_non_transformed=True,
)
assert str(err.value) == (
"`keep_names = True` and `keep_non_transformed = True` cannot be used together."
)

def test_ts_windowtransf_output_series(self):
# univariate deterministic input
transforms = {"function": "sum", "mode": "rolling", "window": 1}
Expand Down Expand Up @@ -462,6 +529,98 @@ def test_include_current(self):
)
assert transformed_ts == expected_transformed_series

@pytest.mark.parametrize(
    "transforms",
    [
        {"function": "median", "mode": "rolling", "window": 3},
        {
            "function": "mean",
            "mode": "expanding",
            "window": 2,
            "components": ["A", "B", "C"],
        },
    ],
)
def test_ts_windowtransf_hierarchy(self, transforms):
    """Check that a single transform covering every component keeps the hierarchy.

    Two ways of targeting all components are exercised:
    - implicitly (no "components" key)
    - explicitly (all component names listed)
    """
    series = helper_generate_ts_hierarchy(10)

    # default behaviour: components are renamed with the transform prefix,
    # and the hierarchy must follow the new names
    renamed = series.window_transform(transforms=transforms)
    prefix = "_".join(
        [transforms["mode"], transforms["function"], str(transforms["window"])]
    )
    prefix = prefix + "_"
    expected_hierarchy = {
        prefix + child: [prefix + "A"] for child in ["B", "C"]
    }
    assert renamed.hierarchy == expected_hierarchy

    # keep_names=True: names, and therefore the hierarchy, stay untouched
    kept = series.window_transform(transforms=transforms, keep_names=True)
    assert kept.hierarchy == series.hierarchy == {"C": ["A"], "B": ["A"]}

@pytest.mark.parametrize(
    "transforms",
    [
        # single transform restricted to one component
        {"function": "median", "mode": "rolling", "window": 3, "components": ["B"]},
        # two transforms, each applied to every component
        [
            {"function": "mean", "mode": "expanding", "window": 2},
            {"function": "median", "mode": "rolling", "window": 3},
        ],
        # two transforms sharing component "C"
        [
            {
                "function": "median",
                "mode": "rolling",
                "window": 3,
                "components": ["B", "C"],
            },
            {
                "function": "sum",
                "mode": "rolling",
                "window": 5,
                "components": ["A", "C"],
            },
        ],
    ],
)
def test_ts_windowtransf_drop_hierarchy(self, transforms):
    """Check that the hierarchy is dropped whenever it cannot be carried over:

    - the transform does not cover all the components
    - several transforms are applied to all the components
    - two transforms target overlapping components
    """
    series = helper_generate_ts_hierarchy(10)
    transformed = series.window_transform(transforms=transforms)
    assert transformed.hierarchy is None

def test_ts_windowtransf_hierarchy_wrong_args(self):
    """A series with a hierarchy must reject ``keep_non_transformed=True``."""
    series = helper_generate_ts_hierarchy(10)
    rolling_sum = {"function": "sum", "mode": "rolling", "window": 3}

    # mixing original and transformed components makes the hierarchy ambiguous
    with pytest.raises(ValueError):
        series.window_transform(
            transforms=rolling_sum,
            keep_non_transformed=True,
        )


class TestWindowTransformer:

Expand Down Expand Up @@ -579,3 +738,30 @@ def times_five(x):
transformed_series = pipeline.fit_transform(series_1)

assert transformed_series == expected_transformed_series

def test_transformer_hierarchy(self):
    """Check hierarchy handling when going through ``WindowTransformer``."""
    series = helper_generate_ts_hierarchy(10)
    transform = {"function": "median", "mode": "rolling", "window": 3}

    # default: components get the transform prefix, hierarchy is renamed too
    renamed = WindowTransformer(transforms=[transform]).transform(series)
    prefix = "_".join(
        [transform["mode"], transform["function"], str(transform["window"])]
    )
    prefix = prefix + "_"
    expected_hierarchy = {
        prefix + child: [prefix + "A"] for child in ["B", "C"]
    }
    assert renamed.hierarchy == expected_hierarchy

    # keep_names=True: original names and hierarchy are preserved
    kept = WindowTransformer(transforms=transform, keep_names=True).transform(
        series
    )
    assert kept.hierarchy == series.hierarchy == {"C": ["A"], "B": ["A"]}
86 changes: 80 additions & 6 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3238,7 +3238,7 @@ def resample(self, freq: str, method: str = "pad", **kwargs) -> Self:
# TODO: check
if method == "pad":
new_xa = resample.pad()
elif method == "bfill":
elif method in ["bfill", "backfill"]:
new_xa = resample.backfill()
else:
raise_log(ValueError(f"Unknown method: {method}"), logger)
Expand Down Expand Up @@ -3360,6 +3360,7 @@ def window_transform(
forecasting_safe: Optional[bool] = True,
keep_non_transformed: Optional[bool] = False,
include_current: Optional[bool] = True,
keep_names: Optional[bool] = False,
) -> Self:
"""
Applies a moving/rolling, expanding or exponentially weighted window transformation over this ``TimeSeries``.
Expand Down Expand Up @@ -3458,11 +3459,15 @@ def window_transform(
keep_non_transformed
``False`` to return the transformed components only, ``True`` to return all original components along
the transformed ones. Default is ``False``.
the transformed ones. Default is ``False``. If the series has a hierarchy, must be set to ``False``.
include_current
``True`` to include the current time step in the window, ``False`` to exclude it. Default is ``True``.
keep_names
Whether the transformed components should keep the original component names or not. Must be set to ``False``
if `keep_non_transformed = True` or if several transformations are applied to the same components.
Returns
-------
TimeSeries
Expand Down Expand Up @@ -3611,6 +3616,53 @@ def _get_kwargs(transformation, forecasting_safe):
if isinstance(transforms, dict):
transforms = [transforms]

# check if some transformations are applied to the same components
overlapping_transforms = False
transformed_components = set()
for tr in transforms:
if not isinstance(tr, dict):
raise_log(
ValueError("Every entry in `transforms` must be a dictionary"),
logger,
)
tr_comps = set(tr["components"] if "components" in tr else self.components)
if len(transformed_components.intersection(tr_comps)) > 0:
overlapping_transforms = True
transformed_components = transformed_components.union(tr_comps)

if keep_names and overlapping_transforms:
raise_log(
ValueError(
"Cannot keep the original component names as some transforms are overlapping "
"(applied to the same components). Set `keep_names` to `False`."
),
logger,
)

# actually, this could be allowed to allow transformation "in place"?
# keep_non_transformed can be changed to False/ignored if the transforms are not partial
if keep_names and keep_non_transformed:
raise_log(
ValueError(
"`keep_names = True` and `keep_non_transformed = True` cannot be used together."
),
logger,
)

partial_transforms = transformed_components != set(self.components)
new_hierarchy = None
convert_hierarchy = False
comp_names_map = dict()
if self.hierarchy:
# the partial_transform covers for scenario keep_non_transformed = True
if len(transforms) > 1 or partial_transforms:
logger.warning(
"The hierarchy cannot be retained, either because there is more than one transform or "
"because the transform is not applied to all the components of the series."
)
else:
convert_hierarchy = True

raise_if_not(
all([isinstance(tr, dict) for tr in transforms]),
"`transforms` must be a non-empty dictionary or a non-empty list of dictionaries.",
Expand Down Expand Up @@ -3688,9 +3740,22 @@ def _get_kwargs(transformation, forecasting_safe):
f"{'_'+str(min_periods) if min_periods>1 else ''}"
)

new_columns.extend(
[f"{name_prefix}_{comp_name}" for comp_name in comps_to_transform]
)
if keep_names:
new_columns.extend(comps_to_transform)
else:
names_w_prefix = [
f"{name_prefix}_{comp_name}" for comp_name in comps_to_transform
]
new_columns.extend(names_w_prefix)
if convert_hierarchy:
comp_names_map.update(
{
c_name: new_c_name
for c_name, new_c_name in zip(
comps_to_transform, names_w_prefix
)
}
)

# track how many NaN rows are added by each transformation on each transformed column
# NaNs would appear only if user changes "min_periods" to else than 1, if not,
Expand Down Expand Up @@ -3745,14 +3810,23 @@ def _get_kwargs(transformation, forecasting_safe):
# revert dataframe to TimeSeries
new_index = original_index.__class__(resulting_transformations.index)

if convert_hierarchy:
if keep_names:
new_hierarchy = self.hierarchy
else:
new_hierarchy = {
comp_names_map[k]: [comp_names_map[old_name] for old_name in v]
for k, v in self.hierarchy.items()
}

transformed_time_series = TimeSeries.from_times_and_values(
times=new_index,
values=resulting_transformations.values.reshape(
len(new_index), -1, n_samples
),
columns=new_columns,
static_covariates=self.static_covariates,
hierarchy=self.hierarchy,
hierarchy=new_hierarchy,
)

return transformed_time_series
Expand Down

0 comments on commit 24ae0e1

Please sign in to comment.