diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index fef35bfc59..15b8c85e38 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -65,7 +65,7 @@ class ComputeTemplateMetrics(AnalyzerExtension): include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False - If True, deletes any quality_metrics attached to the `sorting_analyzer` + If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged. metrics_kwargs : dict Additional arguments to pass to the metric functions. Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 @@ -116,6 +116,7 @@ def _set_params( delete_existing_metrics=False, **other_kwargs, ): + import pandas as pd # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() @@ -135,6 +136,10 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() + # `run` cannot take parameters, so need to find another way to pass this + self.delete_existing_metrics = delete_existing_metrics + self.metric_names = metric_names + if metrics_kwargs is None: metrics_kwargs_ = _default_function_kwargs.copy() if len(other_kwargs) > 0: @@ -145,29 +150,24 @@ def _set_params( metrics_kwargs_ = _default_function_kwargs.copy() metrics_kwargs_.update(metrics_kwargs) - all_metric_names = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_metric_names = tm_extension.params["metric_names"] - existing_params = tm_extension.params["metrics_kwargs"] + existing_params = tm_extension.params["metrics_kwargs"] # checks that existing metrics were calculated using the same params if existing_params != metrics_kwargs_: warnings.warn( "The parameters used to calculate the previous template metrics are different than those used now. Deleting previous template metrics..." ) - self.sorting_analyzer.get_extension("template_metrics").data["metrics"] = pd.DataFrame( - index=self.sorting_analyzer.unit_ids - ) - self.sorting_analyzer.get_extension("template_metrics").params["metric_names"] = [] + tm_extension.params["metric_names"] = [] existing_metric_names = [] + else: + existing_metric_names = tm_extension.params["metric_names"] - for metric_name in existing_metric_names: - if metric_name not in metric_names: - all_metric_names.append(metric_name) + metric_names = list(set(existing_metric_names + metric_names)) params = dict( - metric_names=[str(name) for name in np.unique(all_metric_names)], + metric_names=metric_names, sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), @@ -185,6 +185,7 @@ def _merge_extension_data( ): import pandas as pd + metric_names = self.params["metric_names"] old_metrics = self.data["metrics"] all_unit_ids = new_sorting_analyzer.unit_ids @@ -193,19 +194,20 @@ def _merge_extension_data( metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] - metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs) + metrics.loc[new_unit_ids, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs + ) new_data = dict(metrics=metrics) return new_data - def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs): + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute template metrics. """ import pandas as pd from scipy.signal import resample_poly - metric_names = self.params["metric_names"] sparsity = self.params["sparsity"] peak_sign = self.params["peak_sign"] upsampling_factor = self.params["upsampling_factor"] @@ -308,19 +310,30 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job template_metrics.at[index, metric_name] = value return template_metrics - def _run(self, delete_existing_metrics=False, verbose=False): - self.data["metrics"] = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose - ) + def _run(self, verbose=False): + + delete_existing_metrics = self.delete_existing_metrics + metric_names = self.metric_names + existing_metrics = [] tm_extension = self.sorting_analyzer.get_extension("template_metrics") - if delete_existing_metrics is False and tm_extension is not None: + if ( + delete_existing_metrics is False + and tm_extension is not None + and tm_extension.data.get("metrics") is not None + ): existing_metrics = tm_extension.params["metric_names"] - for metric_name in existing_metrics: - if metric_name not in self.data["metrics"]: - metric_data = tm_extension.get_data()[metric_name] - self.data["metrics"][metric_name] = metric_data + # compute the metrics which have been specified by the user + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names + ) + + # append the metrics which were previously computed + for metric_name in set(existing_metrics).difference(metric_names): + computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] + + self.data["metrics"] = computed_metrics def _get_data(self): return self.data["metrics"] diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index f444e12c36..1fa2ac638c 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,18 +1,22 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeTemplateMetrics import pytest +import csv def test_compute_new_template_metrics(small_sorting_analyzer): """ - Computes template metrics then computes a subset of quality metrics, and checks - that the old quality metrics are not deleted. + Computes template metrics then computes a subset of template metrics, and checks + that the old template metrics are not deleted. Then computes template metrics with new parameters and checks that old metrics are deleted. """ + # calculate all template metrics small_sorting_analyzer.compute("template_metrics") + + # calculate just exp_decay - this should not delete the previously calculated metrics small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") @@ -33,6 +37,51 @@ def test_compute_new_template_metrics(small_sorting_analyzer): assert small_sorting_analyzer.get_extension("quality_metrics") is None +def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): + """ + Computes template metrics in binary folder format. Then computes subsets of template + metrics and checks if they are saved correctly. + """ + + from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func + + small_sorting_analyzer.compute("template_metrics") + + cache_folder = create_cache_folder + output_folder = cache_folder / "sorting_analyzer" + + folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) + template_metrics_filename = output_folder / "extensions" / "template_metrics" / "metrics.csv" + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in list(_single_channel_metric_name_to_func.keys()): + assert metric_name in metric_names + + folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False) + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in list(_single_channel_metric_name_to_func.keys()): + assert metric_name in metric_names + + folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True) + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in list(_single_channel_metric_name_to_func.keys()): + if metric_name == "half_width": + assert metric_name in metric_names + else: + assert metric_name not in metric_names + + class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize(