From 2f780cf1f74611585e49227532305b523fbc4230 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 3 Nov 2023 17:38:58 +0100 Subject: [PATCH] fix: only pass the supported argument to GlobalForecastingModel.predict() --- .../optimized_historical_forecasts_torch.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py b/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py index 56bb89fdcf..ac2f403f48 100644 --- a/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py +++ b/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py @@ -94,12 +94,13 @@ def _optimized_historical_forecasts( for cls in model.__class__.__mro__ if cls.__name__ == "TorchForecastingModel" ][0] + gfm_kwargs = { + k: v + for k, v in predict_kwargs.items() + if k in ["num_samples", "predict_likelihood_parameters"] + } super(tfm_cls, model).predict( - forecast_horizon, - series, - past_covariates, - future_covariates, - **predict_kwargs, + forecast_horizon, series, past_covariates, future_covariates, **gfm_kwargs ) dataset = model._build_inference_dataset(