Skip to content

Commit

Permalink
Fix LOCO RF
Browse files Browse the repository at this point in the history
  • Loading branch information
achamma723 committed Jul 12, 2024
1 parent 9131ebb commit 389ec5d
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions hidimstat/importance_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def compute_loco(X, y, ntree=100, problem_type="regression", use_dnn=True, seed=
)
else:
if problem_type == "classification":
clf_rf_full = RandomForestClassifier(n_estimators=ntree, random_state=seed)
clf_rf_full = GridSearchCV(
RandomForestClassifier(n_estimators=ntree, random_state=seed),
param_grid=[{"max_depth": [2, 5, 10]}],
cv=5,
)
else:
clf_rf_full = GridSearchCV(
RandomForestRegressor(n_estimators=ntree, random_state=seed),
Expand Down Expand Up @@ -90,12 +94,16 @@ def compute_loco(X, y, ntree=100, problem_type="regression", use_dnn=True, seed=
)
else:
if problem_type == "classification":
clf_rf_retrain = RandomForestClassifier(
n_estimators=ntree, random_state=seed
clf_rf_retrain = GridSearchCV(
RandomForestClassifier(n_estimators=ntree, random_state=seed),
param_grid=[{"max_depth": [2, 5, 10]}],
cv=5,
)
else:
clf_rf_retrain = RandomForestRegressor(
n_estimators=ntree, random_state=seed
clf_rf_retrain = GridSearchCV(
RandomForestRegressor(n_estimators=ntree, random_state=seed),
param_grid=[{"max_depth": [2, 5, 10]}],
cv=5,
)

print(f"Processing col: {col+1}")
Expand Down

0 comments on commit 389ec5d

Please sign in to comment.