Merge branch 'master' into feat/doc_autoreg_pc
madtoinou authored Nov 6, 2023
2 parents 42f54d1 + 2d0233f commit 23d3eea
Showing 4 changed files with 37 additions and 7 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -15,7 +15,9 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- 🚀🚀 Optimized `historical_forecasts()` for pre-trained `TorchForecastingModel` running up to 20 times faster than before! [#2013](https://github.com/unit8co/darts/pull/2013) by [Dennis Bader](https://github.com/dennisbader).
- Added callback `darts.utils.callbacks.TFMProgressBar` to customize at which model stages to display the progress bar. [#2020](https://github.com/unit8co/darts/pull/2020) by [Dennis Bader](https://github.com/dennisbader).
- Improvements to documentation:
- Adapted the example notebooks to properly apply data transformers and avoid look-ahead bias. [#2020](https://github.com/unit8co/darts/pull/2020) by [Samriddhi Singh](https://github.com/SimTheGreat).
- Improvements to Regression Models:
- `XGBModel` now leverages XGBoost's native Quantile Regression support that was released in version 2.0.0 for improved probabilistic forecasts. [#2051](https://github.com/unit8co/darts/pull/2051) by [Dennis Bader](https://github.com/dennisbader).

**Fixed**
- Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou).
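The `XGBModel` quantile entry above can be exercised with a few lines of darts code. A minimal sketch, assuming xgboost >= 2.0.0 is installed; the synthetic series and parameter values are illustrative, not taken from the repository:

```python
import numpy as np
from darts import TimeSeries
from darts.models import XGBModel

# synthetic signal, purely for illustration
values = np.sin(np.linspace(0, 12 * np.pi, 240)) + 5.0
series = TimeSeries.from_values(values)

# With xgboost >= 2.0.0, likelihood="quantile" now uses the native
# "reg:quantileerror" objective (one fitted regressor per quantile).
model = XGBModel(
    lags=24,
    output_chunk_length=12,
    likelihood="quantile",
    quantiles=[0.1, 0.5, 0.9],
)
model.fit(series)

# probabilistic forecast sampled from the fitted quantile models
pred = model.predict(n=12, num_samples=200)
print(pred.quantile_timeseries(0.9).values().shape)
```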
15 changes: 11 additions & 4 deletions darts/explainability/shap_explainer.py
@@ -375,7 +375,7 @@ def summary_plot(
num_samples: Optional[int] = None,
plot_type: Optional[str] = "dot",
**kwargs,
):
) -> Dict[int, Dict[str, shap.Explanation]]:
"""
Display a shap plot summary for each horizon and each component dimension of the target.
This method reuses the initial background data as foreground (potentially sampled) to give a general importance
@@ -395,6 +395,12 @@
for the sake of performance.
plot_type
Optionally, specify which of the shap library plot type to use. Can be one of ``'dot', 'bar', 'violin'``.
Returns
-------
shaps_
A nested dictionary {horizon : {component : shap.Explanation}} containing the raw Explanations for all
the horizons and components.
"""

horizons, target_components = self._process_horizons_and_targets(
@@ -421,6 +427,7 @@
plot_type=plot_type,
**kwargs,
)
return shaps_

def force_plot_from_ts(
self,
@@ -613,7 +620,7 @@ def __init__(

def shap_explanations(
self,
foreground_X,
foreground_X: pd.DataFrame,
horizons: Optional[Sequence[int]] = None,
target_components: Optional[Sequence[str]] = None,
) -> Dict[int, Dict[str, shap.Explanation]]:
Expand Down Expand Up @@ -735,8 +742,8 @@ def _create_regression_model_shap_X(
target_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
n_samples=None,
train=False,
n_samples: Optional[int] = None,
train: bool = False,
) -> pd.DataFrame:
"""
Creates the shap format input for regression models.
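With the change above, `summary_plot()` now also returns the raw `shap.Explanation` objects instead of only drawing the figures. A hedged usage sketch, assuming `model` is an already-fitted `RegressionModel` and `shap` is installed; the variable names are illustrative:

```python
from darts.explainability import ShapExplainer

# `model` is assumed to be a fitted RegressionModel (e.g. an XGBModel)
explainer = ShapExplainer(model)

# New behavior: the nested dict {horizon: {component: shap.Explanation}}
# is returned; show=False (forwarded to shap) suppresses the figures.
shap_values = explainer.summary_plot(show=False)

# horizons are 1-indexed, components are keyed by name
first_component = model.training_series.components[0]
explanation = shap_values[1][first_component]
print(explanation.values.shape)  # (num_foreground_samples, num_features)
```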
17 changes: 15 additions & 2 deletions darts/models/forecasting/xgboost.py
@@ -25,6 +25,10 @@

logger = get_logger(__name__)

# Check whether we are running xgboost >= 2.0.0 for quantile regression
tokens = xgb.__version__.split(".")
xgb_200_or_above = int(tokens[0]) >= 2


def xgb_quantile_loss(labels: np.ndarray, preds: np.ndarray, quantile: float):
"""Custom loss function for XGBoost to compute quantile loss gradient.
@@ -184,8 +188,12 @@ def encode_year(idx):
if likelihood in {"poisson"}:
self.kwargs["objective"] = f"count:{likelihood}"
elif likelihood == "quantile":
if xgb_200_or_above:
# leverage built-in Quantile Regression
self.kwargs["objective"] = "reg:quantileerror"
self.quantiles, self._median_idx = self._prepare_quantiles(quantiles)
self._model_container = self._get_model_container()

self._rng = np.random.default_rng(seed=random_state) # seed for sampling

super().__init__(
@@ -250,12 +258,17 @@ def fit(
)
]

# TODO: XGBRegressor supports multi quantile regression which we could leverage in the future
# see https://xgboost.readthedocs.io/en/latest/python/examples/quantile_regression.html
if self.likelihood == "quantile":
# empty model container in case of multiple calls to fit, e.g. when backtesting
self._model_container.clear()
for quantile in self.quantiles:
obj_func = partial(xgb_quantile_loss, quantile=quantile)
self.kwargs["objective"] = obj_func
if xgb_200_or_above:
self.kwargs["quantile_alpha"] = quantile
else:
objective = partial(xgb_quantile_loss, quantile=quantile)
self.kwargs["objective"] = objective
self.model = xgb.XGBRegressor(**self.kwargs)

super().fit(
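The diff above gates the quantile objective on the installed xgboost version: with 2.0.0 and newer it sets `reg:quantileerror` plus a per-quantile `quantile_alpha`, otherwise it falls back to the custom `xgb_quantile_loss` objective. A standalone sketch of the same pattern using xgboost directly; the data and the simple pinball-loss fallback below are illustrative stand-ins, not the darts implementation:

```python
import numpy as np
import xgboost as xgb

# same version check as in the diff
xgb_200_or_above = int(xgb.__version__.split(".")[0]) >= 2

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.3, size=500)

quantile = 0.9
if xgb_200_or_above:
    # native quantile regression, available since xgboost 2.0.0
    model = xgb.XGBRegressor(objective="reg:quantileerror", quantile_alpha=quantile)
else:
    # fallback: custom pinball-loss gradient/hessian (constant hessian surrogate)
    def pinball_objective(y_true, y_pred):
        errors = y_true - y_pred
        grad = np.where(errors < 0, 1.0 - quantile, -quantile)
        hess = np.ones_like(y_true)
        return grad, hess

    model = xgb.XGBRegressor(objective=pinball_objective)

model.fit(X, y)
print(model.predict(X[:5]))  # predictions of the 0.9-quantile
```

The darts code additionally keeps one fitted regressor per quantile in its model container, which this sketch omits.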
8 changes: 8 additions & 0 deletions darts/tests/explainability/test_shap_explainer.py
@@ -596,6 +596,14 @@ def test_plot(self):
"power",
)

# Check the dimensions of returned values
dict_shap_values = shap_explain.summary_plot(show=False)
# One nested dict per horizon
assert len(dict_shap_values) == m_0.output_chunk_length
# Size of nested dict matches the number of components
for i in range(1, m_0.output_chunk_length + 1):
assert len(dict_shap_values[i]) == self.target_ts.width

# Wrong component name
with pytest.raises(ValueError):
shap_explain.summary_plot(horizons=[1], target_components=["test"])
