From fee5f5be28042011608f013949a7e63382f8993a Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:05:51 +0000 Subject: [PATCH] Write metric names for csv mode, and don't check params if not in model_info --- src/spikeinterface/curation/model_based_curation.py | 8 +++++--- src/spikeinterface/curation/train_manual_curation.py | 3 +++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py index ba88f9e5f9..26a7e4d724 100644 --- a/src/spikeinterface/curation/model_based_curation.py +++ b/src/spikeinterface/curation/model_based_curation.py @@ -109,7 +109,8 @@ def predict_labels( probabilities = np.max(probabilities, axis=1) if isinstance(label_conversion, dict): - if set(predictions) != set(label_conversion.keys()): + + if set(predictions).issubset(set(label_conversion.keys())) is False: raise ValueError("Labels in predictions do not match those in label_conversion") predictions = [label_conversion[label] for label in predictions] @@ -161,9 +162,10 @@ def _check_params_for_classification(self, enforce_metric_params=False, model_in # remove the 's' at the end of the extension name extension_name = extension_name[:-1] - if metric_extension is not None: + model_metric_params = model_info["metric_params"].get(extension_name + "_params") + + if metric_extension is not None and model_metric_params is not None: - model_metric_params = model_info["metric_params"][extension_name + "_params"] metric_params = metric_extension.params["metric_params"] inconsistent_metrics = [] diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py index 0edc9cf913..cd2904587e 100644 --- a/src/spikeinterface/curation/train_manual_curation.py +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -301,6 +301,9 @@ def _check_metrics_parameters(self, analyzers, enforce_metric_params): def load_and_preprocess_csv(self, paths): self._load_data_files(paths) self.process_test_data_for_classification() + self.metrics_params = {} + for metric_name in self.metric_names: + self.metrics_params[metric_name] = {} def process_test_data_for_classification(self): """