diff --git a/darts/tests/utils/historical_forecasts/test_historical_forecasts.py b/darts/tests/utils/historical_forecasts/test_historical_forecasts.py index 9e54638192..66dc2f8542 100644 --- a/darts/tests/utils/historical_forecasts/test_historical_forecasts.py +++ b/darts/tests/utils/historical_forecasts/test_historical_forecasts.py @@ -2715,7 +2715,7 @@ def helper_get_model_params( self, model_cls, series: dict, output_chunk_length: int ) -> dict: model_params = {} - if model_cls in [NLinearModel]: + if TORCH_AVAILABLE and issubclass(model_cls, NLinearModel): model_params["input_chunk_length"] = 5 model_params["output_chunk_length"] = output_chunk_length model_params["n_epochs"] = 1 @@ -2724,7 +2724,7 @@ def helper_get_model_params( **model_params, **tfm_kwargs, } - elif model_cls in [LinearRegressionModel]: + elif issubclass(model_cls, LinearRegressionModel): model_params["lags"] = 5 model_params["output_chunk_length"] = output_chunk_length if "past_covariates" in series: