Skip to content

Commit

Permalink
fix bug in shapexplainer with native multioutput support and explaini… (
Browse files Browse the repository at this point in the history
#2428)

* fix bug in shapexplainer with native multioutput support and explaining only selected target components

* update changelog
  • Loading branch information
dennisbader authored Jun 26, 2024
1 parent c498405 commit 5c8b366
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Fixed**
- Fixed a bug when using `historical_forecasts()` with a pre-trained `RegressionModel` that has no target lags `lags=None` but uses static covariates. [#2426](https://github.com/unit8co/darts/pull/2426) by [Dennis Bader](https://github.com/dennisbader).
- Fixed a bug with `xgboost>=2.1.0`, where multi output regression was not properly handled. [#2426](https://github.com/unit8co/darts/pull/2426) by [Dennis Bader](https://github.com/dennisbader).
- Fixed a bug when using `ShapExplainer.explain()` with some selected `target_components` and a regression model that natively supports multi output regression: The target components were not properly mapped. [#2428](https://github.com/unit8co/darts/pull/2428) by [Dennis Bader](https://github.com/dennisbader).

**Dependencies**

Expand Down
2 changes: 1 addition & 1 deletion darts/explainability/shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def shap_explanations(
shap_explanation_tmp = self.explainers(foreground_X)
for h in horizons:
tmp_n = {}
for t_idx, t in enumerate(target_components):
for t_idx, t in enumerate(self.target_components):
if t not in target_components:
continue
if not self.single_output:
Expand Down
36 changes: 28 additions & 8 deletions darts/tests/explainability/test_shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,13 @@ def test_creation(self):

# Good type of explainers
shap_explain = ShapExplainer(m)
assert isinstance(
shap_explain.explainers.explainers[0][0], shap.explainers.Tree
)
if isinstance(m, XGBModel):
# since xgboost > 2.1.0, model supports native multi output regression
assert isinstance(shap_explain.explainers.explainers, shap.explainers.Tree)
else:
assert isinstance(
shap_explain.explainers.explainers[0][0], shap.explainers.Tree
)

# Linear model - also not a MultiOutputRegressor
m = LinearRegressionModel(
Expand Down Expand Up @@ -266,9 +270,12 @@ def test_creation(self):
future_covariates=self.fut_cov_ts,
)
shap_explain = ShapExplainer(m)
assert isinstance(
shap_explain.explainers.explainers[0][0], shap.explainers.Tree
)
if isinstance(m, XGBModel):
assert isinstance(shap_explain.explainers.explainers, shap.explainers.Tree)
else:
assert isinstance(
shap_explain.explainers.explainers[0][0], shap.explainers.Tree
)

# Bad choice of shap explainer
with pytest.raises(ValueError):
Expand Down Expand Up @@ -709,13 +716,26 @@ def test_shap_explanation_object_validity(self):
shap.Explanation,
)

def test_shap_selected_components(self):
model_cls = LightGBMModel if lgbm_available else XGBModel
@pytest.mark.parametrize(
"config",
[
(XGBModel, {}),
(
LightGBMModel if lgbm_available else XGBModel,
{"likelihood": "quantile", "quantiles": [0.5]},
),
],
)
def test_shap_selected_components(self, config):
"""Test selected components with and without Darts' MultiOutputRegressor"""
model_cls, model_kwargs = config
# model_cls = XGBModel
model = model_cls(
lags=4,
lags_past_covariates=2,
lags_future_covariates=[1],
output_chunk_length=1,
**model_kwargs,
)
model.fit(
series=self.target_ts,
Expand Down

0 comments on commit 5c8b366

Please sign in to comment.