diff --git a/src/train.py b/src/train.py index 0b1a646..a2180fe 100644 --- a/src/train.py +++ b/src/train.py @@ -221,6 +221,7 @@ def train( objective='binary:logistic', # For binary classification n_jobs=-1, random_state=42, + tree_method='gpu_hist', # Use GPU optimized histogram algorithm ) # clfs = [('Naive Bayes',nb_model),('Logistic Regression',lr_model),('Random Forest',rf_model)]