Skip to content

Commit

Permalink
fix: gridsearch parameters check
Browse files Browse the repository at this point in the history
  • Loading branch information
madtoinou committed Nov 14, 2024
1 parent 8514feb commit 367829d
Showing 1 changed file with 29 additions and 15 deletions.
44 changes: 29 additions & 15 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1775,12 +1775,33 @@ def gridsearch(
)
)

if "model" in parameters:
valid_model_list = isinstance(parameters["model"], list)
valid_nested_params = parameters["model"].get(
"wrapped_model_class"
) and all(
isinstance(params, (list, np.ndarray))
for p_name, params in parameters["model"].items()
if p_name != "wrapped_model_class"
)
if not (valid_model_list or valid_nested_params):
raise_log(
ValueError(
"The 'model' entry in `parameters` must either be a list of instantiated models or "
"a dictionary containing as keys hyperparameter names, and as values lists of values "
"plus a 'wrapped_model_class': model_cls item.",
logger,
)
)

if not all(
isinstance(params, (list, np.ndarray)) for params in parameters.values()
isinstance(params, (list, np.ndarray))
for p_name, params in parameters.items()
if p_name != "model"
):
raise_log(
ValueError(
"Every value in the `parameters` dictionary should be a list or a np.ndarray."
"Every hyper-parameter value in the `parameters` dictionary should be a list or a np.ndarray."
),
logger,
)
Expand Down Expand Up @@ -1812,19 +1833,12 @@ def gridsearch(

if "model" in parameters:
# Ask if model has been passed as a dictionary. This implies that the arguments
# of the wrapped model must be passed to the grid. If 'model' is passed as a
# list of instances of scikit-learn models, the behavior should work like
# any argument passed to the Darts model."
if isinstance(parameters["model"], dict):
if "model_class" not in parameters["model"]:
raise_log(
ValueError(
"When the 'model' key is set as a dictionary, it must contain "
"the 'model_class' key, which represents the class of the model "
"to be wrapped."
)
)
wrapped_model_class = parameters["model"].pop("model_class")
# of the wrapped model must be passed to the grid.
if (
isinstance(parameters["model"], dict)
and "wrapped_model_class" in parameters["model"]
):
wrapped_model_class = parameters["model"].pop("wrapped_model_class")
# Create a flat dictionary by adding a suffix to the arguments of the wrapped model in
# order to distinguish them from the other arguments of the Darts model
parameters.update({
Expand Down

0 comments on commit 367829d

Please sign in to comment.