Skip to content

Commit

Permalink
use gpu_hist for older xgb versions
Browse files Browse the repository at this point in the history
  • Loading branch information
rishic3 authored Nov 19, 2024
1 parent 74427fe commit 05cd103
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,12 @@ def get_gpu_id(task_context: TaskContext) -> int:
Xy_train_qdm = xgb.QuantileDMatrix(X_train, y_train) # Precompute Quantile DMatrix to avoid repeated quantization every trial.

def objective(trial):
params = ({
params = {
"objective": "reg:squarederror",
"verbosity": 0,
"tree_method": "hist",
"tree_method": "gpu_hist",
"device": f"cuda:{gpu_id}",
})
}
params.update(hyperparams.to_dict(trial))

if "max_bins" in params:
Expand Down

0 comments on commit 05cd103

Please sign in to comment.