diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index d387d08ef7..35deb35130 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -31,10 +31,12 @@ import numpy as np import pandas as pd +from sklearn.multioutput import MultiOutputRegressor from darts import metrics from darts.dataprocessing.encoders import SequentialEncoder from darts.logging import get_logger, raise_if, raise_if_not, raise_log +from darts.models.utils import _check_kwargs_keys from darts.timeseries import TimeSeries from darts.utils import _build_tqdm_iterator, _parallel_apply, _with_sanity_checks from darts.utils.historical_forecasts.utils import ( @@ -316,6 +318,7 @@ def _fit_wrapper( series: TimeSeries, past_covariates: Optional[TimeSeries], future_covariates: Optional[TimeSeries], + **kwargs, ): self.fit(series) @@ -329,12 +332,20 @@ def _predict_wrapper( verbose: bool = False, predict_likelihood_parameters: bool = False, num_loader_workers: int = 0, + batch_size: Optional[int] = None, + n_jobs: int = 1, + roll_size: Optional[int] = None, + mc_dropout: bool = False, ) -> TimeSeries: kwargs = dict() if self.supports_likelihood_parameter_prediction: kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters if getattr(self, "trainer_params", False): kwargs["num_loader_workers"] = num_loader_workers + kwargs["batch_size"] = batch_size + kwargs["n_jobs"] = n_jobs + kwargs["roll_size"] = roll_size + kwargs["mc_dropout"] = mc_dropout return self.predict(n, num_samples=num_samples, verbose=verbose, **kwargs) @property @@ -846,21 +857,17 @@ def retrain_func( "val_past_covariates", "val_future_covariates", ] - fit_invalid_args = set(fit_invalid_args).intersection(set(fit_kwargs.keys())) - if len(fit_invalid_args) > 0: - raise_log( - f"`fit_kwargs` cannot contain the following parameters : {list(fit_invalid_args)}.", - logger, - ) + _check_kwargs_keys( + 
param_name="fit_kwargs", + kwargs_dict=fit_kwargs, + invalid_keys=fit_invalid_args, + ) predict_invalid_args = forbiden_args + ["n", "trainer"] - predict_invalid_args = set(predict_invalid_args).intersection( - set(predict_kwargs.keys()) + _check_kwargs_keys( + param_name="predict_kwargs", + kwargs_dict=predict_kwargs, + invalid_keys=predict_invalid_args, ) - if len(predict_invalid_args) > 0: - raise_log( - f"`predict_kwargs` cannot contain the following parameters : {list(predict_invalid_args)}.", - logger, - ) series = series2seq(series) past_covariates = series2seq(past_covariates) @@ -1330,6 +1337,7 @@ def gridsearch( verbose=False, n_jobs: int = 1, n_random_samples: Optional[Union[int, float]] = None, + fit_kwargs: Optional[Dict[str, Any]] = None, predict_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple["ForecastingModel", Dict[str, Any], float]: """ @@ -1479,12 +1487,27 @@ def gridsearch( logger, ) + if fit_kwargs is None: + fit_kwargs = dict() if predict_kwargs is None: predict_kwargs = dict() - raise_if( - "num_samples" in predict_kwargs, - "`num_samples = 1` cannot be modified using `predict_kwargs`.", - logger, + + forbiden_args = ["series", "past_covariates", "future_covariates"] + fit_invalid_args = forbiden_args + [ + "val_series", + "val_past_covariates", + "val_future_covariates", + ] + _check_kwargs_keys( + param_name="fit_kwargs", + kwargs_dict=fit_kwargs, + invalid_keys=fit_invalid_args, + ) + predict_invalid_args = forbiden_args + ["n", "trainer", "num_samples"] + _check_kwargs_keys( + param_name="predict_kwargs", + kwargs_dict=predict_kwargs, + invalid_keys=predict_invalid_args, ) # compute all hyperparameter combinations from selection @@ -1513,7 +1536,12 @@ def _evaluate_combination(param_combination) -> float: model = model_class(**param_combination_dict) if use_fitted_values: # fitted value mode - model._fit_wrapper(series, past_covariates, future_covariates) + model._fit_wrapper( + series, + past_covariates, + future_covariates, + 
**fit_kwargs, + ) fitted_values = TimeSeries.from_times_and_values( series.time_index, model.fitted_values ) @@ -1536,7 +1564,9 @@ def _evaluate_combination(param_combination) -> float: predict_kwargs=predict_kwargs, ) else: # split mode - model._fit_wrapper(series, past_covariates, future_covariates) + model._fit_wrapper( + series, past_covariates, future_covariates, **fit_kwargs + ) pred = model._predict_wrapper( len(val_series), series, @@ -2209,7 +2239,6 @@ def predict( num_samples: int = 1, verbose: bool = False, predict_likelihood_parameters: bool = False, - num_loader_workers: int = 0, ) -> Union[TimeSeries, Sequence[TimeSeries]]: """Forecasts values for `n` time steps after the end of the series. @@ -2303,12 +2332,20 @@ def _predict_wrapper( verbose: bool = False, predict_likelihood_parameters: bool = False, num_loader_workers: int = 0, + batch_size: Optional[int] = None, + n_jobs: int = 1, + roll_size: Optional[int] = None, + mc_dropout: bool = False, ) -> Union[TimeSeries, Sequence[TimeSeries]]: kwargs = dict() if self.supports_likelihood_parameter_prediction: kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters if getattr(self, "trainer_params", False): kwargs["num_loader_workers"] = num_loader_workers + kwargs["batch_size"] = batch_size + kwargs["n_jobs"] = n_jobs + kwargs["roll_size"] = roll_size + kwargs["mc_dropout"] = mc_dropout return self.predict( n, series, @@ -2324,13 +2361,31 @@ def _fit_wrapper( series: Union[TimeSeries, Sequence[TimeSeries]], past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]], future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]], + max_samples_per_ts: Optional[int] = None, + n_jobs_multioutput_wrapper: Optional[int] = None, + trainer=None, + verbose: Optional[bool] = None, + epochs: int = 0, + num_loader_workers: int = 0, ): + """Propagate the supported parameters to the underlying model fit() method""" + kwargs = dict() + if getattr(self, "trainer_params", False): + 
kwargs["trainer"] = trainer + kwargs["epochs"] = epochs + kwargs["verbose"] = verbose + kwargs["num_loader_workers"] = num_loader_workers + kwargs["max_samples_per_ts"] = max_samples_per_ts + elif isinstance(self, MultiOutputRegressor): + kwargs["n_jobs_multioutput_wrapper"] = n_jobs_multioutput_wrapper + kwargs["max_samples_per_ts"] = max_samples_per_ts self.fit( series=series, past_covariates=past_covariates if self.supports_past_covariates else None, future_covariates=future_covariates if self.supports_future_covariates else None, + **kwargs, ) @property @@ -2538,6 +2593,7 @@ def _fit_wrapper( series: TimeSeries, past_covariates: Optional[TimeSeries], future_covariates: Optional[TimeSeries], + **kwargs, ): self.fit(series, future_covariates=future_covariates) @@ -2550,13 +2606,10 @@ def _predict_wrapper( num_samples: int, verbose: bool = False, predict_likelihood_parameters: bool = False, - num_loader_workers: int = 0, ) -> TimeSeries: kwargs = dict() if self.supports_likelihood_parameter_prediction: kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters - if getattr(self, "trainer_params", False): - kwargs["num_loader_workers"] = num_loader_workers return self.predict( n, future_covariates=future_covariates, @@ -2770,13 +2823,10 @@ def _predict_wrapper( num_samples: int, verbose: bool = False, predict_likelihood_parameters: bool = False, - num_loader_workers: int = 0, ) -> TimeSeries: kwargs = dict() if self.supports_likelihood_parameter_prediction: kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters - if getattr(self, "trainer_params", False): - kwargs["num_loader_workers"] = num_loader_workers return self.predict( n=n, series=series, diff --git a/darts/models/utils.py b/darts/models/utils.py index 8d0d0d11ea..2196cb3e0d 100644 --- a/darts/models/utils.py +++ b/darts/models/utils.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, List + from darts.logging import get_logger, raise_log logger = get_logger(__name__) @@ -20,3 
+22,15 @@ def __init__(self, module_name: str, warn: bool = True): def __call__(self, *args, **kwargs): raise_log(ImportError(self.error_message), logger=logger) + + +def _check_kwargs_keys( + param_name: str, kwargs_dict: Dict[str, Any], invalid_keys: List[str] +): + """Check if the dictionary contains any of the invalid keys""" + invalid_args_passed = set(invalid_keys).intersection(set(kwargs_dict.keys())) + if len(invalid_args_passed) > 0: + raise_log( + f"`{param_name}` can't contain the following parameters : {list(invalid_args_passed)}.", + logger, + )