Skip to content

Commit

Permalink
feat: increased the number of parameters handled by GlobalForecasting…
Browse files Browse the repository at this point in the history
…Models._fit_wrapper
  • Loading branch information
madtoinou committed Nov 3, 2023
1 parent a3817e7 commit 7353db2
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 26 deletions.
102 changes: 76 additions & 26 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@

import numpy as np
import pandas as pd
from sklearn.multioutput import MultiOutputRegressor

from darts import metrics
from darts.dataprocessing.encoders import SequentialEncoder
from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.models.utils import _check_kwargs_keys
from darts.timeseries import TimeSeries
from darts.utils import _build_tqdm_iterator, _parallel_apply, _with_sanity_checks
from darts.utils.historical_forecasts.utils import (
Expand Down Expand Up @@ -316,6 +318,7 @@ def _fit_wrapper(
series: TimeSeries,
past_covariates: Optional[TimeSeries],
future_covariates: Optional[TimeSeries],
**kwargs,
):
self.fit(series)

Expand All @@ -329,12 +332,20 @@ def _predict_wrapper(
verbose: bool = False,
predict_likelihood_parameters: bool = False,
num_loader_workers: int = 0,
batch_size: Optional[int] = None,
n_jobs: int = 1,
roll_size: Optional[int] = None,
mc_dropout: bool = False,
) -> TimeSeries:
kwargs = dict()
if self.supports_likelihood_parameter_prediction:
kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters
if getattr(self, "trainer_params", False):
kwargs["num_loader_workers"] = num_loader_workers
kwargs["batch_size"] = batch_size
kwargs["n_jobs"] = n_jobs
kwargs["roll_size"] = roll_size
kwargs["mc_dropout"] = mc_dropout
return self.predict(n, num_samples=num_samples, verbose=verbose, **kwargs)

@property
Expand Down Expand Up @@ -846,21 +857,17 @@ def retrain_func(
"val_past_covariates",
"val_future_covariates",
]
fit_invalid_args = set(fit_invalid_args).intersection(set(fit_kwargs.keys()))
if len(fit_invalid_args) > 0:
raise_log(
f"`fit_kwargs` cannot contain the following parameters : {list(fit_invalid_args)}.",
logger,
)
_check_kwargs_keys(
param_name="fit_kwargs",
kwargs_dict=fit_kwargs,
invalid_keys=fit_invalid_args,
)
predict_invalid_args = forbiden_args + ["n", "trainer"]
predict_invalid_args = set(predict_invalid_args).intersection(
set(predict_kwargs.keys())
_check_kwargs_keys(
param_name="predict_kwargs",
kwargs_dict=predict_kwargs,
invalid_keys=predict_invalid_args,
)
if len(predict_invalid_args) > 0:
raise_log(
f"`predict_kwargs` cannot contain the following parameters : {list(predict_invalid_args)}.",
logger,
)

series = series2seq(series)
past_covariates = series2seq(past_covariates)
Expand Down Expand Up @@ -1330,6 +1337,7 @@ def gridsearch(
verbose=False,
n_jobs: int = 1,
n_random_samples: Optional[Union[int, float]] = None,
fit_kwargs: Optional[Dict[str, Any]] = None,
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple["ForecastingModel", Dict[str, Any], float]:
"""
Expand Down Expand Up @@ -1479,12 +1487,27 @@ def gridsearch(
logger,
)

if fit_kwargs is None:
fit_kwargs = dict()
if predict_kwargs is None:
predict_kwargs = dict()
raise_if(
"num_samples" in predict_kwargs,
"`num_samples = 1` cannot be modified using `predict_kwargs`.",
logger,

forbiden_args = ["series", "past_covariates", "future_covariates"]
fit_invalid_args = forbiden_args + [
"val_series",
"val_past_covariates",
"val_future_covariates",
]
_check_kwargs_keys(
param_name="fit_kwargs",
kwargs_dict=fit_kwargs,
invalid_keys=fit_invalid_args,
)
predict_invalid_args = forbiden_args + ["n", "trainer", "num_samples"]
_check_kwargs_keys(
param_name="predict_kwargs",
kwargs_dict=predict_kwargs,
invalid_keys=predict_invalid_args,
)

# compute all hyperparameter combinations from selection
Expand Down Expand Up @@ -1513,7 +1536,12 @@ def _evaluate_combination(param_combination) -> float:

model = model_class(**param_combination_dict)
if use_fitted_values: # fitted value mode
model._fit_wrapper(series, past_covariates, future_covariates)
model._fit_wrapper(
series,
past_covariates,
future_covariates,
**fit_kwargs,
)
fitted_values = TimeSeries.from_times_and_values(
series.time_index, model.fitted_values
)
Expand All @@ -1536,7 +1564,9 @@ def _evaluate_combination(param_combination) -> float:
predict_kwargs=predict_kwargs,
)
else: # split mode
model._fit_wrapper(series, past_covariates, future_covariates)
model._fit_wrapper(
series, past_covariates, future_covariates, **fit_kwargs
)
pred = model._predict_wrapper(
len(val_series),
series,
Expand Down Expand Up @@ -2209,7 +2239,6 @@ def predict(
num_samples: int = 1,
verbose: bool = False,
predict_likelihood_parameters: bool = False,
num_loader_workers: int = 0,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
"""Forecasts values for `n` time steps after the end of the series.
Expand Down Expand Up @@ -2303,12 +2332,20 @@ def _predict_wrapper(
verbose: bool = False,
predict_likelihood_parameters: bool = False,
num_loader_workers: int = 0,
batch_size: Optional[int] = None,
n_jobs: int = 1,
roll_size: Optional[int] = None,
mc_dropout: bool = False,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
kwargs = dict()
if self.supports_likelihood_parameter_prediction:
kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters
if getattr(self, "trainer_params", False):
kwargs["num_loader_workers"] = num_loader_workers
kwargs["batch_size"] = batch_size
kwargs["n_jobs"] = n_jobs
kwargs["roll_size"] = roll_size
kwargs["mc_dropout"] = mc_dropout
return self.predict(
n,
series,
Expand All @@ -2324,13 +2361,31 @@ def _fit_wrapper(
series: Union[TimeSeries, Sequence[TimeSeries]],
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
max_samples_per_ts: Optional[int] = None,
n_jobs_multioutput_wrapper: Optional[int] = None,
trainer=None,
verbose: Optional[bool] = None,
epochs: int = 0,
num_loader_workers: int = 0,
):
"""Propagate the supported parameters to the underlying model fit() method"""
kwargs = dict()
if getattr(self, "trainer_params", False):
kwargs["trainer"] = trainer
kwargs["epochs"] = epochs
kwargs["verbose"] = verbose
kwargs["num_loader_workers"] = num_loader_workers
kwargs["max_samples_per_ts"] = max_samples_per_ts
elif isinstance(self, MultiOutputRegressor):
kwargs["n_jobs_multioutput_wrapper"] = n_jobs_multioutput_wrapper
kwargs["max_samples_per_ts"] = max_samples_per_ts
self.fit(
series=series,
past_covariates=past_covariates if self.supports_past_covariates else None,
future_covariates=future_covariates
if self.supports_future_covariates
else None,
**kwargs,
)

@property
Expand Down Expand Up @@ -2538,6 +2593,7 @@ def _fit_wrapper(
series: TimeSeries,
past_covariates: Optional[TimeSeries],
future_covariates: Optional[TimeSeries],
**kwargs,
):
self.fit(series, future_covariates=future_covariates)

Expand All @@ -2550,13 +2606,10 @@ def _predict_wrapper(
num_samples: int,
verbose: bool = False,
predict_likelihood_parameters: bool = False,
num_loader_workers: int = 0,
) -> TimeSeries:
kwargs = dict()
if self.supports_likelihood_parameter_prediction:
kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters
if getattr(self, "trainer_params", False):
kwargs["num_loader_workers"] = num_loader_workers
return self.predict(
n,
future_covariates=future_covariates,
Expand Down Expand Up @@ -2770,13 +2823,10 @@ def _predict_wrapper(
num_samples: int,
verbose: bool = False,
predict_likelihood_parameters: bool = False,
num_loader_workers: int = 0,
) -> TimeSeries:
kwargs = dict()
if self.supports_likelihood_parameter_prediction:
kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters
if getattr(self, "trainer_params", False):
kwargs["num_loader_workers"] = num_loader_workers
return self.predict(
n=n,
series=series,
Expand Down
14 changes: 14 additions & 0 deletions darts/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Dict, List

from darts.logging import get_logger, raise_log

logger = get_logger(__name__)
Expand All @@ -20,3 +22,15 @@ def __init__(self, module_name: str, warn: bool = True):

def __call__(self, *args, **kwargs):
raise_log(ImportError(self.error_message), logger=logger)


def _check_kwargs_keys(
param_name: str, kwargs_dict: Dict[str, Any], invalid_keys: List[str]
):
"""Check if the dictionary contain any of the invalid key"""
invalid_args_passed = set(invalid_keys).intersection(set(kwargs_dict.keys()))
if len(invalid_args_passed) > 0:
raise_log(
f"`{param_name}` can't contain the following parameters : {list(invalid_args_passed)}.",
logger,
)

0 comments on commit 7353db2

Please sign in to comment.