Skip to content

Commit

Permalink
fix: only pass the supported argument to GlobalForecastingModel.predi…
Browse files Browse the repository at this point in the history
…ct()
  • Loading branch information
madtoinou committed Nov 3, 2023
1 parent 37fba40 commit 2f780cf
Showing 1 changed file with 6 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,13 @@ def _optimized_historical_forecasts(
for cls in model.__class__.__mro__
if cls.__name__ == "TorchForecastingModel"
][0]
gfm_kwargs = {
k: v
for k, v in predict_kwargs.items()
if k in ["num_samples", "predict_likelihood_parameters"]
}
super(tfm_cls, model).predict(
forecast_horizon,
series,
past_covariates,
future_covariates,
**predict_kwargs,
forecast_horizon, series, past_covariates, future_covariates, **gfm_kwargs
)

dataset = model._build_inference_dataset(
Expand Down

0 comments on commit 2f780cf

Please sign in to comment.