diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 062b0bd76b..5f4c1e904b 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -136,9 +136,6 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - # `run` cannot take parameters, so need to find another way to pass this - metric_names_to_compute = metric_names - if metrics_kwargs is None: metrics_kwargs_ = _default_function_kwargs.copy() if len(other_kwargs) > 0: @@ -149,6 +146,7 @@ def _set_params( metrics_kwargs_ = _default_function_kwargs.copy() metrics_kwargs_.update(metrics_kwargs) + metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: @@ -164,9 +162,9 @@ def _set_params( existing_metric_names = tm_extension.params["metric_names"] existing_metric_names_propogated = [ - metric_name for metric_name in existing_metric_names if metric_name not in metric_names_to_compute + metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute ] - metric_names = metric_names_to_compute + existing_metric_names_propogated + metric_names = metrics_to_compute + existing_metric_names_propogated params = dict( metric_names=metric_names, @@ -175,7 +173,7 @@ def _set_params( upsampling_factor=int(upsampling_factor), metrics_kwargs=metrics_kwargs_, delete_existing_metrics=delete_existing_metrics, - metric_names_to_compute=metric_names_to_compute, + metrics_to_compute=metrics_to_compute, ) return params @@ -317,7 +315,12 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri def _run(self, verbose=False): delete_existing_metrics = self.params["delete_existing_metrics"] - metric_names_to_compute = self.params["metric_names_to_compute"] + metrics_to_compute = self.params["metrics_to_compute"] + + # compute the metrics which have been specified by the user + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metrics_to_compute + ) existing_metrics = [] tm_extension = self.sorting_analyzer.get_extension("template_metrics") @@ -328,13 +331,8 @@ def _run(self, verbose=False): ): existing_metrics = tm_extension.params["metric_names"] - # compute the metrics which have been specified by the user - computed_metrics = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names_to_compute - ) - # append the metrics which were previously computed - for metric_name in set(existing_metrics).difference(metric_names_to_compute): + for metric_name in set(existing_metrics).difference(metrics_to_compute): computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] self.data["metrics"] = computed_metrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 31353df724..1c7483212a 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -11,7 +11,12 @@ from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from .quality_metric_list import compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names +from .quality_metric_list import ( + compute_pc_metrics, + _misc_metric_name_to_func, + _possible_pc_metric_names, + compute_name_to_column_names, +) from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params @@ -32,7 +37,7 @@ class ComputeQualityMetrics(AnalyzerExtension): skip_pc_metrics : bool, default: False If True, PC metrics computation is skipped. delete_existing_metrics : bool, default: False - If True, deletes any quality_metrics attached to the `sorting_analyzer` + If True, any quality metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept. Returns ------- @@ -81,21 +86,24 @@ def _set_params( if "peak_sign" in qm_params_[k] and peak_sign is not None: qm_params_[k]["peak_sign"] = peak_sign - all_metric_names = metric_names + metrics_to_compute = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") if delete_existing_metrics is False and qm_extension is not None: - existing_params = qm_extension.params - for metric_name in existing_params["metric_names"]: - if metric_name not in metric_names: - all_metric_names.append(metric_name) - qm_params_[metric_name] = existing_params["qm_params"][metric_name] + + existing_metric_names = qm_extension.params["metric_names"] + existing_metric_names_propogated = [ + metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute + ] + metric_names = metrics_to_compute + existing_metric_names_propogated params = dict( - metric_names=[str(name) for name in np.unique(all_metric_names)], + metric_names=metric_names, peak_sign=peak_sign, seed=seed, qm_params=qm_params_, skip_pc_metrics=skip_pc_metrics, + delete_existing_metrics=delete_existing_metrics, + metrics_to_compute=metrics_to_compute, ) return params @@ -123,11 +131,11 @@ def _merge_extension_data( new_data = dict(metrics=metrics) return new_data - def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs): + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute quality metrics. """ - metric_names = self.params["metric_names"] + qm_params = self.params["qm_params"] # sparsity = self.params["sparsity"] seed = self.params["seed"] @@ -203,17 +211,35 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job return metrics - def _run(self, verbose=False, delete_existing_metrics=False, **job_kwargs): - self.data["metrics"] = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, **job_kwargs + def _run(self, verbose=False, **job_kwargs): + + metrics_to_compute = self.params["metrics_to_compute"] + delete_existing_metrics = self.params["delete_existing_metrics"] + + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, + unit_ids=None, + verbose=verbose, + metric_names=metrics_to_compute, + **job_kwargs, ) + existing_metrics = [] qm_extension = self.sorting_analyzer.get_extension("quality_metrics") - if delete_existing_metrics is False and qm_extension is not None: - existing_metrics = qm_extension.get_data() - for metric_name, metric_data in existing_metrics.items(): - if metric_name not in self.data["metrics"]: - self.data["metrics"][metric_name] = metric_data + if ( + delete_existing_metrics is False + and qm_extension is not None + and qm_extension.data.get("metrics") is not None + ): + existing_metrics = qm_extension.params["metric_names"] + + # append the metrics which were previously computed + for metric_name in set(existing_metrics).difference(metrics_to_compute): + # some metrics names produce data columns with other names. This deals with that. + for column_name in compute_name_to_column_names[metric_name]: + computed_metrics[column_name] = qm_extension.data["metrics"][column_name] + + self.data["metrics"] = computed_metrics def _get_data(self): return self.data["metrics"] diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 140ad87a8b..375dd320ae 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -53,3 +53,29 @@ "drift": compute_drift_metrics, "sd_ratio": compute_sd_ratio, } + +# a dict converting the name of the metric for computation to the output of that computation +compute_name_to_column_names = { + "num_spikes": ["num_spikes"], + "firing_rate": ["firing_rate"], + "presence_ratio": ["presence_ratio"], + "snr": ["snr"], + "isi_violation": ["isi_violations_ratio", "isi_violations_count"], + "rp_violation": ["rp_violations", "rp_contamination"], + "sliding_rp_violation": ["sliding_rp_violation"], + "amplitude_cutoff": ["amplitude_cutoff"], + "amplitude_median": ["amplitude_median"], + "amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"], + "synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"], + "firing_range": ["firing_range"], + "drift": ["drift_ptp", "drift_std", "drift_mad"], + "sd_ratio": ["sd_ratio"], + "isolation_distance": ["isolation_distance"], + "l_ratio": ["l_ratio"], + "d_prime": ["d_prime"], + "nearest_neighbor": ["nn_hit_rate", "nn_miss_rate"], + "nn_isolation": ["nn_isolation", "nn_unit_id"], + "nn_noise_overlap": ["nn_noise_overlap"], + "silhouette": ["silhouette"], + "silhouette_full": ["silhouette_full"], +} diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index e34c15c936..77909798a3 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -2,6 +2,7 @@ from pathlib import Path import numpy as np from copy import deepcopy +import csv from spikeinterface.core import ( NumpySorting, synthetize_spike_train_bad_isi, @@ -42,6 +43,7 @@ compute_quality_metrics, ) + from spikeinterface.core.basesorting import minimum_spike_dtype @@ -60,6 +62,12 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): "firing_range": {"bin_size_s": 1}, } + small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) + qm_extension = small_sorting_analyzer.get_extension("quality_metrics") + calculated_metrics = list(qm_extension.get_data().keys()) + + assert calculated_metrics == ["snr"] + small_sorting_analyzer.compute( {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}} ) @@ -68,18 +76,22 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") # Check old metrics are not deleted and the new one is added to the data and metadata - assert list(quality_metric_extension.get_data().keys()) == [ - "amplitude_cutoff", - "firing_range", - "presence_ratio", - "snr", - ] - assert list(quality_metric_extension.params.get("metric_names")) == [ - "amplitude_cutoff", - "firing_range", - "presence_ratio", - "snr", - ] + assert set(list(quality_metric_extension.get_data().keys())) == set( + [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + ) + assert set(list(quality_metric_extension.params.get("metric_names"))) == set( + [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + ) # check that, when parameters are changed, the data and metadata are updated old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) @@ -106,6 +118,92 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): assert small_sorting_analyzer.get_extension("quality_metrics") is None +def test_metric_names_in_same_order(small_sorting_analyzer): + """ + Computes sepecified quality metrics and checks order is propogated. + """ + specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"] + small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names) + qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys() + for i in range(3): + assert specified_metric_names[i] == qm_keys[i] + + +def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): + """ + Computes quality metrics in binary folder format. Then computes subsets of quality + metrics and checks if they are saved correctly. + """ + + # can't use _misc_metric_name_to_func as some functions compute several qms + # e.g. isi_violation and synchrony + quality_metrics = [ + "num_spikes", + "firing_rate", + "presence_ratio", + "snr", + "isi_violations_ratio", + "isi_violations_count", + "rp_contamination", + "rp_violations", + "sliding_rp_violation", + "amplitude_cutoff", + "amplitude_median", + "amplitude_cv_median", + "amplitude_cv_range", + "sync_spike_2", + "sync_spike_4", + "sync_spike_8", + "firing_range", + "drift_ptp", + "drift_std", + "drift_mad", + "sd_ratio", + "isolation_distance", + "l_ratio", + "d_prime", + "silhouette", + "nn_hit_rate", + "nn_miss_rate", + ] + + small_sorting_analyzer.compute("quality_metrics") + + cache_folder = create_cache_folder + output_folder = cache_folder / "sorting_analyzer" + + folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) + quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv" + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + if metric_name == "snr": + assert metric_name in metric_names + else: + assert metric_name not in metric_names + + def test_unit_structure_in_output(small_sorting_analyzer): qm_params = {