Skip to content

Commit

Permalink
Add ForecastingModel.supports_probabilistic_prediction (#2259) (#2269)
Browse files Browse the repository at this point in the history
* Remove unnessesary `pass` statements

* Rename ForecastingModel_is_probabilistic to supports_probabilistic_prediction, rearrange some documentation

* Remove redundant overrides

* Reformat

* Add CHANGELOG entry

---------

Co-authored-by: Dennis Bader <[email protected]>
  • Loading branch information
felixdivo and dennisbader authored Mar 12, 2024
1 parent 2264cca commit 7986348
Show file tree
Hide file tree
Showing 27 changed files with 55 additions and 53 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

### For users of the library:
**Improved**
- Improvements to `ForecastingModel`:
- Renamed the private `_is_probabilistic` property to a public `supports_probabilistic_prediction`. [#2269](https://github.com/unit8co/darts/pull/2269) by [Felix Divo](https://github.com/felixdivo).

**Fixed**

Expand Down
2 changes: 1 addition & 1 deletion darts/explainability/shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def __init__(
test_stationarity=True,
)

if model._is_probabilistic:
if model.supports_probabilistic_prediction:
logger.warning(
"The model is probabilistic, but num_samples=1 will be used for explainability."
)
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _predict(
return self._build_forecast_series(forecast)

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return True

@property
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/catboost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def _likelihood_components_names(
return None

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return self.likelihood is not None

@property
Expand Down
4 changes: 0 additions & 4 deletions darts/models/forecasting/croston.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,3 @@ def min_train_series_length(self) -> int:
@property
def _supports_range_index(self) -> bool:
return True

@property
def _is_probabilistic(self) -> bool:
return False
17 changes: 13 additions & 4 deletions darts/models/forecasting/ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def __init__(
raise_if(
train_num_samples is not None
and train_num_samples > 1
and all([not m._is_probabilistic for m in forecasting_models]),
and all(
[not m.supports_probabilistic_prediction for m in forecasting_models]
),
"`train_num_samples` is greater than 1 but the `RegressionEnsembleModel` "
"contains only deterministic `forecasting_models`.",
logger,
Expand Down Expand Up @@ -261,7 +263,9 @@ def _make_multiple_predictions(
future_covariates=(
future_covariates if model.supports_future_covariates else None
),
num_samples=num_samples if model._is_probabilistic else 1,
num_samples=(
num_samples if model.supports_probabilistic_prediction else 1
),
predict_likelihood_parameters=predict_likelihood_parameters,
)
for model in self.forecasting_models
Expand Down Expand Up @@ -432,7 +436,12 @@ def output_chunk_length(self) -> Optional[int]:

@property
def _models_are_probabilistic(self) -> bool:
return all([model._is_probabilistic for model in self.forecasting_models])
return all(
[
model.supports_probabilistic_prediction
for model in self.forecasting_models
]
)

@property
def _models_same_likelihood(self) -> bool:
Expand Down Expand Up @@ -480,7 +489,7 @@ def supports_likelihood_parameter_prediction(self) -> bool:
)

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return self._models_are_probabilistic

@property
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/exponential_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def supports_multivariate(self) -> bool:
return False

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return True

@property
Expand Down
17 changes: 7 additions & 10 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,10 @@ def _supports_range_index(self) -> bool:
return True

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
"""
Checks if the forecasting model supports probabilistic predictions.
Checks if the forecasting model with this configuration supports probabilistic predictions.
By default, returns False. Needs to be overwritten by models that do support
probabilistic predictions.
"""
Expand All @@ -204,7 +205,9 @@ def _is_probabilistic(self) -> bool:
def _supports_non_retrainable_historical_forecasts(self) -> bool:
"""
Checks if the forecasting model supports historical forecasts without retraining
the model. By default, returns False. Needs to be overwritten by models that do
the model.
By default, returns False. Needs to be overwritten by models that do
support historical forecasts without retraining.
"""
return False
Expand Down Expand Up @@ -250,7 +253,6 @@ def supports_transferrable_series_prediction(self) -> bool:
"""
Whether the model supports prediction for any input `series`.
"""
pass

@property
def uses_past_covariates(self) -> bool:
Expand Down Expand Up @@ -347,7 +349,7 @@ def predict(
logger=logger,
)

if not self._is_probabilistic and num_samples > 1:
if not self.supports_probabilistic_prediction and num_samples > 1:
raise_log(
ValueError(
"`num_samples > 1` is only supported for probabilistic models."
Expand Down Expand Up @@ -488,7 +490,6 @@ def extreme_lags(
>>> model.extreme_lags
(-10, 6, None, None, 4, 6, 0)
"""
pass

@property
def _training_sample_time_index_length(self) -> int:
Expand Down Expand Up @@ -1870,7 +1871,6 @@ def _model_encoder_settings(
Must return Tuple (input_chunk_length, output_chunk_length, takes_past_covariates, takes_future_covariates,
lags_past_covariates, lags_future_covariates).
"""
pass

@classmethod
def _sample_params(model_class, params, n_random_samples):
Expand Down Expand Up @@ -2481,7 +2481,6 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non
"""Fits/trains the model on the provided series.
DualCovariatesModels must implement the fit logic in this method.
"""
pass

def predict(
self,
Expand Down Expand Up @@ -2575,7 +2574,6 @@ def _predict(
"""Forecasts values for a certain number of time steps after the end of the series.
DualCovariatesModels must implement the predict logic in this method.
"""
pass

@property
def _model_encoder_settings(
Expand Down Expand Up @@ -2778,7 +2776,6 @@ def _predict(
"""Forecasts values for a certain number of time steps after the end of the series.
TransferableFutureCovariatesLocalForecastingModel must implement the predict logic in this method.
"""
pass

@property
def supports_transferrable_series_prediction(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/global_baseline_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def min_train_series_length(self) -> int:
def supports_likelihood_parameter_prediction(self) -> bool:
return False

def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return False

@property
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/kalman_forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,5 @@ def supports_multivariate(self) -> bool:
return True

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return True
2 changes: 1 addition & 1 deletion darts/models/forecasting/lgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def _predict_and_sample(
)

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return self.likelihood is not None

@property
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/linear_regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,5 +305,5 @@ def _predict_and_sample(
)

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return self.likelihood is not None
2 changes: 1 addition & 1 deletion darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def set_mc_dropout(self, active: bool):
module.mc_dropout_enabled = active

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return self.likelihood is not None or len(self._get_mc_dropout_modules()) > 0

def _produce_predict_output(self, x: Tuple) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/prophet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def supports_multivariate(self) -> bool:
return False

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return True

def _stochastic_samples(self, predict_df, n_samples) -> np.ndarray:
Expand Down
8 changes: 5 additions & 3 deletions darts/models/forecasting/regression_ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ def _make_multiple_historical_forecasts(
),
forecast_horizon=model.output_chunk_length,
stride=model.output_chunk_length,
num_samples=num_samples if model._is_probabilistic else 1,
num_samples=(
num_samples if model.supports_probabilistic_prediction else 1
),
start=-start_hist_forecasts,
start_format="position",
retrain=False,
Expand Down Expand Up @@ -486,9 +488,9 @@ def supports_multivariate(self) -> bool:
)

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
"""
A RegressionEnsembleModel is probabilistic if its regression
model is probabilistic (ensembling layer)
"""
return self.regression_model._is_probabilistic
return self.regression_model.supports_probabilistic_prediction
2 changes: 1 addition & 1 deletion darts/models/forecasting/sf_auto_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,5 @@ def _supports_range_index(self) -> bool:
return True

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return True
4 changes: 0 additions & 4 deletions darts/models/forecasting/sf_auto_ces.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,3 @@ def min_train_series_length(self) -> int:
@property
def _supports_range_index(self) -> bool:
return True

@property
def _is_probabilistic(self) -> bool:
return False
2 changes: 1 addition & 1 deletion darts/models/forecasting/sf_auto_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,5 +164,5 @@ def _supports_range_index(self) -> bool:
return True

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return True
2 changes: 1 addition & 1 deletion darts/models/forecasting/sf_auto_theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,5 @@ def _supports_range_index(self) -> bool:
return True

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return True
2 changes: 1 addition & 1 deletion darts/models/forecasting/tbats_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def supports_multivariate(self) -> bool:
return False

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return True

@property
Expand Down
4 changes: 2 additions & 2 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,9 +2051,9 @@ def output_chunk_shift(self) -> int:
)

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return (
self.model._is_probabilistic
self.model.supports_probabilistic_prediction
if self.model_created
else True # all torch models can be probabilistic (via Dropout)
)
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/varima.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def min_train_series_length(self) -> int:
return 30

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return True

@property
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def _predict_and_sample(
)

@property
def _is_probabilistic(self) -> bool:
def supports_probabilistic_prediction(self) -> bool:
return self.likelihood is not None

@property
Expand Down
2 changes: 1 addition & 1 deletion darts/tests/models/forecasting/test_TFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def helper_fit_predict(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=(100 if model._is_probabilistic else 1),
num_samples=(100 if model.supports_probabilistic_prediction else 1),
)

if isinstance(y_hat, TimeSeries):
Expand Down
2 changes: 1 addition & 1 deletion darts/tests/models/forecasting/test_ensemble_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_stochastic_naive_ensemble(self):

# only probabilistic forecasting models
naive_ensemble_proba = NaiveEnsembleModel([model_proba_1, model_proba_2])
assert naive_ensemble_proba._is_probabilistic
assert naive_ensemble_proba.supports_probabilistic_prediction

naive_ensemble_proba.fit(self.series1 + self.series2)
# by default, only 1 sample
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def test_covariates(self, config):
)

# when model is fit using 1 training and 1 covariate series, time series args are optional
if model._is_probabilistic:
if model.supports_probabilistic_prediction:
return
model = model_cls(
input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs
Expand Down Expand Up @@ -661,7 +661,7 @@ def test_same_result_with_different_n_jobs(self, config):
model.fit(multiple_ts)

# safe random state for two successive identical predictions
if model._is_probabilistic:
if model.supports_probabilistic_prediction:
random_state = deepcopy(model._random_instance)
else:
random_state = None
Expand Down
Loading

0 comments on commit 7986348

Please sign in to comment.