diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py index 2d4090fcff..1c3b6de34a 100644 --- a/src/spikeinterface/curation/train_manual_curation.py +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -95,7 +95,7 @@ class CurationModelTrainer: If True, useful information is printed during training. search_kwargs : dict or None, default: None Keyword arguments passed to `BayesSearchCV` or `RandomizedSearchCV` from `sklearn`. If None, use - `search_kwargs = {'cv': 3, 'scoring': 'balanced_accuracy', 'n_iter': 25}`. + `search_kwargs = {'cv': 5, 'scoring': 'balanced_accuracy', 'n_iter': 25}`. Attributes ---------- @@ -544,7 +544,8 @@ def _evaluate( ) test_accuracies, models = zip(*results) - self.test_accuracies_df = pd.DataFrame(test_accuracies).sort_values("accuracy", ascending=False) + scoring_method = self.search_kwargs.get("scoring") + self.test_accuracies_df = pd.DataFrame(test_accuracies).sort_values(scoring_method, ascending=False) best_model_id = int(self.test_accuracies_df.iloc[0]["model_id"]) best_model, best_imputer, best_scaler = models[best_model_id] @@ -598,8 +599,6 @@ def _train_and_evaluate( model, param_space = self.get_classifier_search_space(classifier.__class__.__name__) print("search kwargs:", search_kwargs, flush=True) try: - print("now trying the classifier search...") - from skopt import BayesSearchCV model = BayesSearchCV( @@ -614,7 +613,7 @@ def _train_and_evaluate( print("BayesSearchCV from scikit-optimize not available, using GridSearchCV") from sklearn.model_selection import RandomizedSearchCV - model = RandomizedSearchCV(model, param_space, n_jobs=self.n_jobs, **search_kwargs, verbose=5) + model = RandomizedSearchCV(model, param_space, n_jobs=self.n_jobs, **search_kwargs) model.fit(X_train_scaled, y_train) y_pred = model.predict(X_test_scaled) @@ -625,7 +624,7 @@ def _train_and_evaluate( "classifier name": classifier.__class__.__name__, "imputation_strategy": imputation_strategy, "scaling_strategy": scaler, - "accuracy": balanced_acc, + "balanced_accuracy": balanced_acc, "precision": precision, "recall": recall, "model_id": model_id, @@ -790,7 +789,7 @@ def set_default_search_kwargs(search_kwargs): search_kwargs = {} if search_kwargs.get("cv") is None: - search_kwargs["cv"] = 3 + search_kwargs["cv"] = 5 if search_kwargs.get("scoring") is None: search_kwargs["scoring"] = "balanced_accuracy" if search_kwargs.get("n_iter") is None: