Skip to content

Commit

Permalink
fix bug in create lagged component names without target lags (#2576)
Browse files Browse the repository at this point in the history
* fix bug in create lagged component names without target lags

* update changelog

* clean up diffs
  • Loading branch information
dennisbader authored Nov 2, 2024
1 parent 0b9efd0 commit c116405
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Fixed**

- Fixed a bug when using `darts.utils.data.tabularization.create_lagged_component_names()` with target `lags=None`, that did not return any lagged target label component names. [#2576](https://github.com/unit8co/darts/pull/2576) by [Dennis Bader](https://github.com/dennisbader).

**Dependencies**

### For developers of the library:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2252,7 +2252,9 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
None,
None,
False,
1,
["no_static_target_lag-2", "no_static_target_lag-1"],
["no_static_target_hrz0"],
),
# target with static covariate (but don't use them in feature names)
(
Expand All @@ -2263,12 +2265,19 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
None,
None,
False,
2,
[
"static_0_target_lag-4",
"static_1_target_lag-4",
"static_0_target_lag-1",
"static_1_target_lag-1",
],
[
"static_0_target_hrz0",
"static_1_target_hrz0",
"static_0_target_hrz1",
"static_1_target_hrz1",
],
),
# target with static covariate (acting on global target components)
(
Expand All @@ -2279,13 +2288,18 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
None,
None,
True,
1,
[
"static_0_target_lag-4",
"static_1_target_lag-4",
"static_0_target_lag-1",
"static_1_target_lag-1",
"dummy_statcov_target_global_components",
],
[
"static_0_target_hrz0",
"static_1_target_hrz0",
],
),
# target with static covariate (component specific)
(
Expand All @@ -2296,6 +2310,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
None,
None,
True,
1,
[
"static_0_target_lag-4",
"static_1_target_lag-4",
Expand All @@ -2304,6 +2319,10 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
"dummy_statcov_target_static_0",
"dummy_statcov_target_static_1",
],
[
"static_0_target_hrz0",
"static_1_target_hrz0",
],
),
# target with static covariate (component specific & multivariate)
(
Expand All @@ -2314,6 +2333,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
None,
None,
True,
1,
[
"static_0_target_lag-4",
"static_1_target_lag-4",
Expand All @@ -2324,6 +2344,10 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
"dummy1_statcov_target_static_0",
"dummy1_statcov_target_static_1",
],
[
"static_0_target_hrz0",
"static_1_target_hrz0",
],
),
# target + past
(
Expand All @@ -2334,13 +2358,15 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
[-1],
None,
False,
1,
[
"no_static_target_lag-4",
"no_static_target_lag-3",
"past_0_pastcov_lag-1",
"past_1_pastcov_lag-1",
"past_2_pastcov_lag-1",
],
["no_static_target_hrz0"],
),
# target + future
(
Expand All @@ -2351,6 +2377,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
None,
[3],
False,
1,
[
"no_static_target_lag-2",
"no_static_target_lag-1",
Expand All @@ -2359,6 +2386,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
"future_2_futcov_lag3",
"future_3_futcov_lag3",
],
["no_static_target_hrz0"],
),
# past + future
(
Expand All @@ -2369,6 +2397,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
[-1],
[2],
False,
1,
[
"past_0_pastcov_lag-1",
"past_1_pastcov_lag-1",
Expand All @@ -2378,6 +2407,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
"future_2_futcov_lag2",
"future_3_futcov_lag2",
],
["no_static_target_hrz0"],
),
# target with static (not used) + past + future
(
Expand All @@ -2388,6 +2418,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
[-1],
[2],
False,
1,
[
"static_0_target_lag-2",
"static_1_target_lag-2",
Expand All @@ -2401,6 +2432,10 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
"future_2_futcov_lag2",
"future_3_futcov_lag2",
],
[
"static_0_target_hrz0",
"static_1_target_hrz0",
],
),
# multiple series with same components names, including past/future covariates
(
Expand All @@ -2411,6 +2446,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
[-1],
[2],
False,
1,
[
"static_0_target_lag-3",
"static_1_target_lag-3",
Expand All @@ -2422,6 +2458,10 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
"future_2_futcov_lag2",
"future_3_futcov_lag2",
],
[
"static_0_target_hrz0",
"static_1_target_hrz0",
],
),
# multiple series with different components will use the first series as reference
(
Expand All @@ -2435,6 +2475,7 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
[-1],
[2],
False,
1,
[
"static_0_target_lag-2",
"static_1_target_lag-2",
Expand All @@ -2448,6 +2489,10 @@ def test_lagged_training_data_unspecified_lag_or_series_warning(self):
"future_2_futcov_lag2",
"future_3_futcov_lag2",
],
[
"static_0_target_hrz0",
"static_1_target_hrz0",
],
),
],
)
Expand All @@ -2466,10 +2511,12 @@ def test_create_lagged_component_names(self, config):
lags_pc,
lags_fc,
use_static_cov,
ocl,
expected_lagged_features,
expected_lagged_labels,
) = config
# lags as list
created_lagged_features, _ = create_lagged_component_names(
created_lagged_features, created_lagged_labels = create_lagged_component_names(
target_series=ts_tg,
past_covariates=ts_pc,
future_covariates=ts_fc,
Expand All @@ -2478,6 +2525,7 @@ def test_create_lagged_component_names(self, config):
lags_future_covariates=lags_fc,
concatenate=False,
use_static_covariates=use_static_cov,
output_chunk_length=ocl,
)

# converts lags to dictionary format
Expand All @@ -2490,18 +2538,23 @@ def test_create_lagged_component_names(self, config):
lags_fc,
)

created_lagged_features_dict_lags, _ = create_lagged_component_names(
target_series=ts_tg,
past_covariates=ts_pc,
future_covariates=ts_fc,
lags=lags_as_dict["target"],
lags_past_covariates=lags_as_dict["past"],
lags_future_covariates=lags_as_dict["future"],
concatenate=False,
use_static_covariates=use_static_cov,
created_lagged_features_dict_lags, created_lagged_labels_dict_lags = (
create_lagged_component_names(
target_series=ts_tg,
past_covariates=ts_pc,
future_covariates=ts_fc,
lags=lags_as_dict["target"],
lags_past_covariates=lags_as_dict["past"],
lags_future_covariates=lags_as_dict["future"],
concatenate=False,
use_static_covariates=use_static_cov,
output_chunk_length=ocl,
)
)
assert expected_lagged_features == created_lagged_features
assert expected_lagged_features == created_lagged_features_dict_lags
assert expected_lagged_labels == created_lagged_labels
assert expected_lagged_labels == created_lagged_labels_dict_lags

@pytest.mark.parametrize(
"config",
Expand Down
20 changes: 12 additions & 8 deletions darts/utils/data/tabularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,10 +859,21 @@ def create_lagged_component_names(
[lags, lags_past_covariates, lags_future_covariates],
["target", "pastcov", "futcov"],
):
if variate is None or variate_lags is None:
if variate is None:
continue

components = get_single_series(variate).components.tolist()
# target labels
if variate_type == "target":
label_feature_names = [
f"{name}_target_hrz{lag}"
for lag in range(output_chunk_length)
for name in components
]

if variate_lags is None:
continue

if isinstance(variate_lags, dict):
if "default_lags" in variate_lags:
raise_log(
Expand Down Expand Up @@ -894,13 +905,6 @@ def create_lagged_component_names(
for name in components
]

if variate_type == "target" and lags:
label_feature_names = [
f"{name}_target_hrz{lag}"
for lag in range(output_chunk_length)
for name in components
]

# static covariates
if use_static_covariates:
static_covs = get_single_series(target_series).static_covariates
Expand Down

0 comments on commit c116405

Please sign in to comment.