diff --git a/darts/explainability/shap_explainer.py b/darts/explainability/shap_explainer.py index 143ea0d8b9..26978bb00a 100644 --- a/darts/explainability/shap_explainer.py +++ b/darts/explainability/shap_explainer.py @@ -375,7 +375,7 @@ def summary_plot( num_samples: Optional[int] = None, plot_type: Optional[str] = "dot", **kwargs, - ): + ) -> Dict[int, Dict[str, shap.Explanation]]: """ Display a shap plot summary for each horizon and each component dimension of the target. This method reuses the initial background data as foreground (potentially sampled) to give a general importance @@ -395,6 +395,12 @@ def summary_plot( for the sake of performance. plot_type Optionally, specify which of the shap library plot type to use. Can be one of ``'dot', 'bar', 'violin'``. + + Returns + ------- + shaps_ + A nested dictionary {horizon : {component : shap.Explaination}} containing the raw Explanations for all + the horizons and components. """ horizons, target_components = self._process_horizons_and_targets( @@ -421,6 +427,7 @@ def summary_plot( plot_type=plot_type, **kwargs, ) + return shaps_ def force_plot_from_ts( self, @@ -613,7 +620,7 @@ def __init__( def shap_explanations( self, - foreground_X, + foreground_X: pd.DataFrame, horizons: Optional[Sequence[int]] = None, target_components: Optional[Sequence[str]] = None, ) -> Dict[int, Dict[str, shap.Explanation]]: @@ -735,8 +742,8 @@ def _create_regression_model_shap_X( target_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]], past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]], future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]], - n_samples=None, - train=False, + n_samples: Optional[int] = None, + train: bool = False, ) -> pd.DataFrame: """ Creates the shap format input for regression models. diff --git a/darts/tests/explainability/test_shap_explainer.py b/darts/tests/explainability/test_shap_explainer.py index a5e950adb4..dfc3773af2 100644 --- a/darts/tests/explainability/test_shap_explainer.py +++ b/darts/tests/explainability/test_shap_explainer.py @@ -596,6 +596,14 @@ def test_plot(self): "power", ) + # Check the dimensions of returned values + dict_shap_values = shap_explain.summary_plot(show=False) + # One nested dict per horizon + assert len(dict_shap_values) == m_0.output_chunk_length + # Size of nested dict match number of component + for i in range(1, m_0.output_chunk_length + 1): + assert len(dict_shap_values[i]) == self.target_ts.width + # Wrong component name with pytest.raises(ValueError): shap_explain.summary_plot(horizons=[1], target_components=["test"])