diff --git a/darts/tests/models/forecasting/test_regression_models.py b/darts/tests/models/forecasting/test_regression_models.py index 1fecded0f4..5601f8f2d4 100644 --- a/darts/tests/models/forecasting/test_regression_models.py +++ b/darts/tests/models/forecasting/test_regression_models.py @@ -179,14 +179,10 @@ class TestRegressionModels: LinearRegressionModel, likelihood="poisson", random_state=42 ) PoissonXGBModel = partialclass( - XGBModel, - likelihood="poisson", - random_state=42, + XGBModel, likelihood="poisson", random_state=42, tree_method="exact" ) QuantileXGBModel = partialclass( - XGBModel, - likelihood="quantile", - random_state=42, + XGBModel, likelihood="quantile", random_state=42, tree_method="exact" ) # targets for poisson regression must be positive, so we exclude them for some tests models.extend( @@ -1165,10 +1161,24 @@ def test_multioutput_validation(self): lags = 4 models = [ - XGBModel(lags=lags, output_chunk_length=1, multi_models=True), - XGBModel(lags=lags, output_chunk_length=1, multi_models=False), - XGBModel(lags=lags, output_chunk_length=2, multi_models=True), - XGBModel(lags=lags, output_chunk_length=2, multi_models=False), + XGBModel( + lags=lags, output_chunk_length=1, multi_models=True, tree_method="exact" + ), + XGBModel( + lags=lags, + output_chunk_length=1, + multi_models=False, + tree_method="exact", + ), + XGBModel( + lags=lags, output_chunk_length=2, multi_models=True, tree_method="exact" + ), + XGBModel( + lags=lags, + output_chunk_length=2, + multi_models=False, + tree_method="exact", + ), ] if lgbm_available: models += [