From 389ec5d84b30d03801283184116e93840192517a Mon Sep 17 00:00:00 2001 From: achamma723 Date: Sat, 13 Jul 2024 00:41:11 +0200 Subject: [PATCH] Fix LOCO RF --- hidimstat/importance_functions.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/hidimstat/importance_functions.py b/hidimstat/importance_functions.py index f8fb4e8..5baec5d 100644 --- a/hidimstat/importance_functions.py +++ b/hidimstat/importance_functions.py @@ -47,7 +47,11 @@ def compute_loco(X, y, ntree=100, problem_type="regression", use_dnn=True, seed= ) else: if problem_type == "classification": - clf_rf_full = RandomForestClassifier(n_estimators=ntree, random_state=seed) + clf_rf_full = GridSearchCV( + RandomForestClassifier(n_estimators=ntree, random_state=seed), + param_grid=[{"max_depth": [2, 5, 10]}], + cv=5, + ) else: clf_rf_full = GridSearchCV( RandomForestRegressor(n_estimators=ntree, random_state=seed), @@ -90,12 +94,16 @@ def compute_loco(X, y, ntree=100, problem_type="regression", use_dnn=True, seed= ) else: if problem_type == "classification": - clf_rf_retrain = RandomForestClassifier( - n_estimators=ntree, random_state=seed + clf_rf_retrain = GridSearchCV( + RandomForestClassifier(n_estimators=ntree, random_state=seed), + param_grid=[{"max_depth": [2, 5, 10]}], + cv=5, ) else: - clf_rf_retrain = RandomForestRegressor( - n_estimators=ntree, random_state=seed + clf_rf_retrain = GridSearchCV( + RandomForestRegressor(n_estimators=ntree, random_state=seed), + param_grid=[{"max_depth": [2, 5, 10]}], + cv=5, ) print(f"Processing col: {col+1}")