From 233d5981189dabe622ddf3b70dc51d1e65e0974e Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 4 Dec 2024 11:18:28 +0000 Subject: [PATCH] Refactor to take advantage of unified quality and template metrics --- .../curation/model_based_curation.py | 78 +++++++--------- .../curation/train_manual_curation.py | 90 +++++++++---------- 2 files changed, 71 insertions(+), 97 deletions(-) diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py index a8e9420e2c..5f3f705219 100644 --- a/src/spikeinterface/curation/model_based_curation.py +++ b/src/spikeinterface/curation/model_based_curation.py @@ -150,53 +150,37 @@ def _check_params_for_classification(self, enforce_metric_params=False, model_in Path to model_info.json provenance file """ - quality_metrics_extension = self.sorting_analyzer.get_extension("quality_metrics") - template_metrics_extension = self.sorting_analyzer.get_extension("template_metrics") - - if quality_metrics_extension is not None: - - model_quality_metrics_params = model_info["metric_params"]["quality_metric_params"] - quality_metrics_params = quality_metrics_extension.params["metric_params"] - - inconsistent_metrics = [] - for metric in model_quality_metrics_params["metric_names"]: - if metric not in model_quality_metrics_params["metric_params"]: - inconsistent_metrics += metric - else: - if quality_metrics_params[metric] != model_quality_metrics_params["metric_params"][metric]: - warning_message = f"Quality metric params for {metric} do not match those used to train the model. Parameters can be found in the 'model_info.json' file." - if enforce_metric_params is True: - raise Exception(warning_message) - else: - warnings.warn(warning_message) - - if len(inconsistent_metrics) > 0: - warning_message = ( - f"Parameters used to compute metrics {inconsistent_metrics}, used to train this model, are unknown." - ) - if enforce_metric_params is True: - raise Exception(warning_message) - else: - warnings.warn(warning_message) - - if template_metrics_extension is not None: - - model_template_metrics_params = model_info["metric_params"]["template_metric_params"]["metric_params"] - template_metrics_params = template_metrics_extension.params["metric_params"] - - if template_metrics_params == {}: - warning_message = "Parameters used to compute template metrics, used to train this model, are unknown." - if enforce_metric_params is True: - raise Exception(warning_message) - else: - warnings.warn(warning_message) - - if template_metrics_params != model_template_metrics_params: - warning_message = "Template metrics params do not match those used to train model. Parameters can be found in the 'model_info.json' file." - if enforce_metric_params is True: - raise Exception(warning_message) - else: - warnings.warn(warning_message) + extension_names = ["quality_metrics", "template_metrics"] + + metric_extensions = [self.sorting_analyzer.get_extension(extension_name) for extension_name in extension_names] + + for metric_extension, extension_name in zip(metric_extensions, extension_names): + + # remove the 's' at the end of the extension name + extension_name = extension_name[:-1] + if metric_extension is not None: + + model_metric_params = model_info["metric_params"][extension_name + "_params"] + metric_params = metric_extension.params["metric_params"] + + inconsistent_metrics = [] + for metric in model_metric_params["metric_names"]: + if metric not in model_metric_params["metric_params"]: + inconsistent_metrics += metric + else: + if metric_params[metric] != model_metric_params["metric_params"][metric]: + warning_message = f"{extension_name} params for {metric} do not match those used to train the model. Parameters can be found in the 'model_info.json' file." + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + if len(inconsistent_metrics) > 0: + warning_message = f"Parameters used to compute metrics {inconsistent_metrics}, used to train this model, are unknown." + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) def _export_to_phy(self, classified_units): """Export the classified units to Phy as cluster_prediction.tsv file""" diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py index 7183e86a65..0bcdd5dac1 100644 --- a/src/spikeinterface/curation/train_manual_curation.py +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -237,28 +237,32 @@ def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params): conflicting_metrics = self._check_metrics_parameters(analyzers, enforce_metric_params) self.metrics_params = {} - if analyzers[0].has_extension("quality_metrics") is True: - self.metrics_params["quality_metric_params"] = analyzers[0].extensions["quality_metrics"].params - - # Only save metric params which are 1) consistent and 2) exist in metric_names - qm_names = self.metrics_params["quality_metric_params"]["metric_names"] - consistent_metrics = list(set(qm_names).difference(set(conflicting_metrics))) - consistent_metric_params = { - metric: analyzers[0].extensions["quality_metrics"].params["metric_params"][metric] - for metric in consistent_metrics - } - self.metrics_params["quality_metric_params"]["metric_params"] = consistent_metric_params - - if analyzers[0].has_extension("template_metrics") is True: - self.metrics_params["template_metric_params"] = deepcopy(analyzers[0].extensions["template_metrics"].params) - if "template_metrics" in conflicting_metrics: - self.metrics_params["template_metric_params"] = {} + + extension_names = ["quality_metrics", "template_metrics"] + metric_extensions = [analyzers[0].get_extension(extension_name) for extension_name in extension_names] + + for metric_extension, extension_name in zip(metric_extensions, extension_names): + + # remove the 's' at the end of the extension name + extension_name = extension_name[:-1] + if metric_extension is not None: + self.metrics_params[extension_name + "_params"] = metric_extension.params + + # Only save metric params which are 1) consistent and 2) exist in metric_names + metric_names = metric_extension.params["metric_names"] + consistent_metrics = list(set(metric_names).difference(set(conflicting_metrics))) + consistent_metric_params = { + metric: metric_extension.params["metric_params"][metric] for metric in consistent_metrics + } + self.metrics_params[extension_name + "_params"]["metric_params"] = consistent_metric_params self.process_test_data_for_classification() def _check_metrics_parameters(self, analyzers, enforce_metric_params): """Checks that the metrics of each analyzer have been calcualted using the same parameters""" + extension_names = ["quality_metrics", "template_metrics"] + conflicting_metrics = [] for analyzer_index_1, analyzer_1 in enumerate(analyzers): for analyzer_index_2, analyzer_2 in enumerate(analyzers): @@ -267,29 +271,20 @@ def _check_metrics_parameters(self, analyzers, enforce_metric_params): continue else: - qm_params_1 = {} - qm_params_2 = {} - tm_params_1 = {} - tm_params_2 = {} + metric_params_1 = {} + metric_params_2 = {} - if analyzer_1.has_extension("quality_metrics") is True: - qm_params_1 = analyzer_1.extensions["quality_metrics"].params["metric_params"] - if analyzer_2.has_extension("quality_metrics") is True: - qm_params_2 = analyzer_2.extensions["quality_metrics"].params["metric_params"] - if analyzer_1.has_extension("template_metrics") is True: - tm_params_1 = analyzer_1.extensions["template_metrics"].params["metric_params"] - if analyzer_2.has_extension("template_metrics") is True: - tm_params_2 = analyzer_2.extensions["template_metrics"].params["metric_params"] + for extension_name in extension_names: + if (extension_1 := analyzer_1.get_extension(extension_name)) is not None: + metric_params_1.update(extension_1.params["metric_params"]) + if (extension_2 := analyzer_2.get_extension(extension_name)) is not None: + metric_params_2.update(extension_2.params["metric_params"]) conflicting_metrics_between_1_2 = [] # check quality metrics params - for metric, params_1 in qm_params_1.items(): - if params_1 != qm_params_2.get(metric): + for metric, params_1 in metric_params_1.items(): + if params_1 != metric_params_2.get(metric): conflicting_metrics_between_1_2.append(metric) - # check template metric params - for metric, params_1 in tm_params_1.items(): - if params_1 != tm_params_2.get(metric): - conflicting_metrics_between_1_2.append("template_metrics") conflicting_metrics += conflicting_metrics_between_1_2 @@ -737,28 +732,23 @@ def _get_computed_metrics(sorting_analyzer): def try_to_get_metrics_from_analyzer(sorting_analyzer): - quality_metrics = None - template_metrics = None + extension_names = ["quality_metrics", "template_metrics"] + metric_extensions = [sorting_analyzer.get_extension(extension_name) for extension_name in extension_names] - # Try to get metrics if available - try: - quality_metrics = sorting_analyzer.get_extension("quality_metrics").get_data() - except: - pass - - try: - template_metrics = sorting_analyzer.get_extension("template_metrics").get_data() - except: - pass - - # Check if at least one of the metrics is available - if quality_metrics is None and template_metrics is None: + if any(metric_extensions) is False: raise ValueError( "At least one of quality metrics or template metrics must be computed before classification.", "Compute both using `sorting_analyzer.compute('quality_metrics', 'template_metrics')", ) - return quality_metrics, template_metrics + metric_extensions_data = [] + for metric_extension in metric_extensions: + try: + metric_extensions_data.append(metric_extension.get_data()) + except: + metric_extensions_data.append(None) + + return metric_extensions_data def set_default_search_kwargs(search_kwargs):