Skip to content

Commit

Permalink
fix: force tree_method argument since xgboost changed the default val…
Browse files Browse the repository at this point in the history
…ue to hist (#1990)
  • Loading branch information
madtoinou authored Sep 12, 2023
1 parent 74ed2bb commit 626ac36
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions darts/tests/models/forecasting/test_regression_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,10 @@ class TestRegressionModels:
LinearRegressionModel, likelihood="poisson", random_state=42
)
PoissonXGBModel = partialclass(
XGBModel,
likelihood="poisson",
random_state=42,
XGBModel, likelihood="poisson", random_state=42, tree_method="exact"
)
QuantileXGBModel = partialclass(
XGBModel,
likelihood="quantile",
random_state=42,
XGBModel, likelihood="quantile", random_state=42, tree_method="exact"
)
# targets for poisson regression must be positive, so we exclude them for some tests
models.extend(
Expand Down Expand Up @@ -1165,10 +1161,24 @@ def test_multioutput_validation(self):
lags = 4

models = [
XGBModel(lags=lags, output_chunk_length=1, multi_models=True),
XGBModel(lags=lags, output_chunk_length=1, multi_models=False),
XGBModel(lags=lags, output_chunk_length=2, multi_models=True),
XGBModel(lags=lags, output_chunk_length=2, multi_models=False),
XGBModel(
lags=lags, output_chunk_length=1, multi_models=True, tree_method="exact"
),
XGBModel(
lags=lags,
output_chunk_length=1,
multi_models=False,
tree_method="exact",
),
XGBModel(
lags=lags, output_chunk_length=2, multi_models=True, tree_method="exact"
),
XGBModel(
lags=lags,
output_chunk_length=2,
multi_models=False,
tree_method="exact",
),
]
if lgbm_available:
models += [
Expand Down

0 comments on commit 626ac36

Please sign in to comment.