Commit: Update quality metrics
chrishalcrow committed Sep 11, 2024
1 parent 4589c6e commit accf40a
Showing 4 changed files with 192 additions and 44 deletions.
24 changes: 11 additions & 13 deletions src/spikeinterface/postprocessing/template_metrics.py
@@ -136,9 +136,6 @@ def _set_params(
if include_multi_channel_metrics:
metric_names += get_multi_channel_template_metric_names()

# `run` cannot take parameters, so need to find another way to pass this
metric_names_to_compute = metric_names

if metrics_kwargs is None:
metrics_kwargs_ = _default_function_kwargs.copy()
if len(other_kwargs) > 0:
@@ -149,6 +146,7 @@ def _set_params(
metrics_kwargs_ = _default_function_kwargs.copy()
metrics_kwargs_.update(metrics_kwargs)

metrics_to_compute = metric_names
tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if delete_existing_metrics is False and tm_extension is not None:

@@ -164,9 +162,9 @@ def _set_params(
existing_metric_names = tm_extension.params["metric_names"]

existing_metric_names_propogated = [
metric_name for metric_name in existing_metric_names if metric_name not in metric_names_to_compute
metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute
]
metric_names = metric_names_to_compute + existing_metric_names_propogated
metric_names = metrics_to_compute + existing_metric_names_propogated

params = dict(
metric_names=metric_names,
@@ -175,7 +173,7 @@ def _set_params(
upsampling_factor=int(upsampling_factor),
metrics_kwargs=metrics_kwargs_,
delete_existing_metrics=delete_existing_metrics,
metric_names_to_compute=metric_names_to_compute,
metrics_to_compute=metrics_to_compute,
)

return params
@@ -317,7 +315,12 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
def _run(self, verbose=False):

delete_existing_metrics = self.params["delete_existing_metrics"]
metric_names_to_compute = self.params["metric_names_to_compute"]
metrics_to_compute = self.params["metrics_to_compute"]

# compute the metrics which have been specified by the user
computed_metrics = self._compute_metrics(
sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metrics_to_compute
)

existing_metrics = []
tm_extension = self.sorting_analyzer.get_extension("template_metrics")
@@ -328,13 +331,8 @@ def _run(self, verbose=False):
):
existing_metrics = tm_extension.params["metric_names"]

# compute the metrics which have been specified by the user
computed_metrics = self._compute_metrics(
sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names_to_compute
)

# append the metrics which were previously computed
for metric_name in set(existing_metrics).difference(metric_names_to_compute):
for metric_name in set(existing_metrics).difference(metrics_to_compute):
computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name]

self.data["metrics"] = computed_metrics
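In short, `_set_params` now stores both the merged `metric_names` (newly requested metrics plus any carried over) and the subset that is actually recomputed (`metrics_to_compute`), and `_run` computes the requested metrics before appending the previously computed ones. A minimal usage sketch of the resulting behaviour, assuming `analyzer` is a SortingAnalyzer with templates already computed (the variable name and the chosen metrics are illustrative):

# compute one template metric, then another one without deleting the first
analyzer.compute("template_metrics", metric_names=["half_width"])
analyzer.compute("template_metrics", metric_names=["peak_to_valley"], delete_existing_metrics=False)

tm = analyzer.get_extension("template_metrics")
# expected to contain both "peak_to_valley" and the carried-over "half_width"
print(tm.get_data().columns)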
64 changes: 45 additions & 19 deletions src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -11,7 +11,12 @@
from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension


from .quality_metric_list import compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names
from .quality_metric_list import (
compute_pc_metrics,
_misc_metric_name_to_func,
_possible_pc_metric_names,
compute_name_to_column_names,
)
from .misc_metrics import _default_params as misc_metrics_params
from .pca_metrics import _default_params as pca_metrics_params

@@ -32,7 +37,7 @@ class ComputeQualityMetrics(AnalyzerExtension):
skip_pc_metrics : bool, default: False
If True, PC metrics computation is skipped.
delete_existing_metrics : bool, default: False
If True, deletes any quality_metrics attached to the `sorting_analyzer`
If True, any quality metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept.
Returns
-------
@@ -81,21 +86,24 @@ def _set_params(
if "peak_sign" in qm_params_[k] and peak_sign is not None:
qm_params_[k]["peak_sign"] = peak_sign

all_metric_names = metric_names
metrics_to_compute = metric_names
qm_extension = self.sorting_analyzer.get_extension("quality_metrics")
if delete_existing_metrics is False and qm_extension is not None:
existing_params = qm_extension.params
for metric_name in existing_params["metric_names"]:
if metric_name not in metric_names:
all_metric_names.append(metric_name)
qm_params_[metric_name] = existing_params["qm_params"][metric_name]

existing_metric_names = qm_extension.params["metric_names"]
existing_metric_names_propogated = [
metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute
]
metric_names = metrics_to_compute + existing_metric_names_propogated

params = dict(
metric_names=[str(name) for name in np.unique(all_metric_names)],
metric_names=metric_names,
peak_sign=peak_sign,
seed=seed,
qm_params=qm_params_,
skip_pc_metrics=skip_pc_metrics,
delete_existing_metrics=delete_existing_metrics,
metrics_to_compute=metrics_to_compute,
)

return params
@@ -123,11 +131,11 @@ def _merge_extension_data(
new_data = dict(metrics=metrics)
return new_data

def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs):
def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs):
"""
Compute quality metrics.
"""
metric_names = self.params["metric_names"]

qm_params = self.params["qm_params"]
# sparsity = self.params["sparsity"]
seed = self.params["seed"]
@@ -203,17 +211,35 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job

return metrics

def _run(self, verbose=False, delete_existing_metrics=False, **job_kwargs):
self.data["metrics"] = self._compute_metrics(
sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, **job_kwargs
def _run(self, verbose=False, **job_kwargs):

metrics_to_compute = self.params["metrics_to_compute"]
delete_existing_metrics = self.params["delete_existing_metrics"]

computed_metrics = self._compute_metrics(
sorting_analyzer=self.sorting_analyzer,
unit_ids=None,
verbose=verbose,
metric_names=metrics_to_compute,
**job_kwargs,
)

existing_metrics = []
qm_extension = self.sorting_analyzer.get_extension("quality_metrics")
if delete_existing_metrics is False and qm_extension is not None:
existing_metrics = qm_extension.get_data()
for metric_name, metric_data in existing_metrics.items():
if metric_name not in self.data["metrics"]:
self.data["metrics"][metric_name] = metric_data
if (
delete_existing_metrics is False
and qm_extension is not None
and qm_extension.data.get("metrics") is not None
):
existing_metrics = qm_extension.params["metric_names"]

# append the metrics which were previously computed
for metric_name in set(existing_metrics).difference(metrics_to_compute):
# some metrics names produce data columns with other names. This deals with that.
for column_name in compute_name_to_column_names[metric_name]:
computed_metrics[column_name] = qm_extension.data["metrics"][column_name]

self.data["metrics"] = computed_metrics

def _get_data(self):
return self.data["metrics"]
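The quality-metrics extension now mirrors that pattern: `metrics_to_compute` records what the user asked for, `_run` computes only those, and previously computed metrics are appended column by column using `compute_name_to_column_names` (added below). A sketch of the intended behaviour, assuming an existing `analyzer` (names are illustrative; compare the new `test_save_quality_metrics` test):

analyzer.compute("quality_metrics", metric_names=["snr"])
analyzer.compute("quality_metrics", metric_names=["firing_rate"], delete_existing_metrics=False)
# expected: the metrics DataFrame now holds both "snr" and "firing_rate"

analyzer.compute("quality_metrics", metric_names=["firing_rate"], delete_existing_metrics=True)
# expected: only "firing_rate" remains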
26 changes: 26 additions & 0 deletions src/spikeinterface/qualitymetrics/quality_metric_list.py
@@ -53,3 +53,29 @@
"drift": compute_drift_metrics,
"sd_ratio": compute_sd_ratio,
}

# a dict converting the name of the metric for computation to the output of that computation
compute_name_to_column_names = {
"num_spikes": ["num_spikes"],
"firing_rate": ["firing_rate"],
"presence_ratio": ["presence_ratio"],
"snr": ["snr"],
"isi_violation": ["isi_violations_ratio", "isi_violations_count"],
"rp_violation": ["rp_violations", "rp_contamination"],
"sliding_rp_violation": ["sliding_rp_violation"],
"amplitude_cutoff": ["amplitude_cutoff"],
"amplitude_median": ["amplitude_median"],
"amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"],
"synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"],
"firing_range": ["firing_range"],
"drift": ["drift_ptp", "drift_std", "drift_mad"],
"sd_ratio": ["sd_ratio"],
"isolation_distance": ["isolation_distance"],
"l_ratio": ["l_ratio"],
"d_prime": ["d_prime"],
"nearest_neighbor": ["nn_hit_rate", "nn_miss_rate"],
"nn_isolation": ["nn_isolation", "nn_unit_id"],
"nn_noise_overlap": ["nn_noise_overlap"],
"silhouette": ["silhouette"],
"silhouette_full": ["silhouette_full"],
}
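This mapping is needed because one computed metric can produce several output columns, so the carry-over loop in `_run` copies each column individually rather than assuming a one-to-one match with the metric name. A quick illustration (the import path follows this file's location in the package):

from spikeinterface.qualitymetrics.quality_metric_list import compute_name_to_column_names

print(compute_name_to_column_names["isi_violation"])  # ['isi_violations_ratio', 'isi_violations_count']
print(compute_name_to_column_names["drift"])          # ['drift_ptp', 'drift_std', 'drift_mad']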
122 changes: 110 additions & 12 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -2,6 +2,7 @@
from pathlib import Path
import numpy as np
from copy import deepcopy
import csv
from spikeinterface.core import (
NumpySorting,
synthetize_spike_train_bad_isi,
@@ -42,6 +43,7 @@
compute_quality_metrics,
)


from spikeinterface.core.basesorting import minimum_spike_dtype


@@ -60,6 +62,12 @@ def test_compute_new_quality_metrics(small_sorting_analyzer):
"firing_range": {"bin_size_s": 1},
}

small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}})
qm_extension = small_sorting_analyzer.get_extension("quality_metrics")
calculated_metrics = list(qm_extension.get_data().keys())

assert calculated_metrics == ["snr"]

small_sorting_analyzer.compute(
{"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}}
)
@@ -68,18 +76,22 @@
quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics")

# Check old metrics are not deleted and the new one is added to the data and metadata
assert list(quality_metric_extension.get_data().keys()) == [
"amplitude_cutoff",
"firing_range",
"presence_ratio",
"snr",
]
assert list(quality_metric_extension.params.get("metric_names")) == [
"amplitude_cutoff",
"firing_range",
"presence_ratio",
"snr",
]
assert set(list(quality_metric_extension.get_data().keys())) == set(
[
"amplitude_cutoff",
"firing_range",
"presence_ratio",
"snr",
]
)
assert set(list(quality_metric_extension.params.get("metric_names"))) == set(
[
"amplitude_cutoff",
"firing_range",
"presence_ratio",
"snr",
]
)

# check that, when parameters are changed, the data and metadata are updated
old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values)
@@ -106,6 +118,92 @@ def test_compute_new_quality_metrics(small_sorting_analyzer):
assert small_sorting_analyzer.get_extension("quality_metrics") is None


def test_metric_names_in_same_order(small_sorting_analyzer):
"""
Computes specified quality metrics and checks that their order is propagated.
"""
specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"]
small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names)
qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys()
for i in range(3):
assert specified_metric_names[i] == qm_keys[i]


def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
"""
Computes quality metrics in binary folder format. Then computes subsets of quality
metrics and checks if they are saved correctly.
"""

# can't use _misc_metric_name_to_func as some functions compute several qms
# e.g. isi_violation and synchrony
quality_metrics = [
"num_spikes",
"firing_rate",
"presence_ratio",
"snr",
"isi_violations_ratio",
"isi_violations_count",
"rp_contamination",
"rp_violations",
"sliding_rp_violation",
"amplitude_cutoff",
"amplitude_median",
"amplitude_cv_median",
"amplitude_cv_range",
"sync_spike_2",
"sync_spike_4",
"sync_spike_8",
"firing_range",
"drift_ptp",
"drift_std",
"drift_mad",
"sd_ratio",
"isolation_distance",
"l_ratio",
"d_prime",
"silhouette",
"nn_hit_rate",
"nn_miss_rate",
]

small_sorting_analyzer.compute("quality_metrics")

cache_folder = create_cache_folder
output_folder = cache_folder / "sorting_analyzer"

folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder)
quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv"

with open(quality_metrics_filename) as metrics_file:
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in quality_metrics:
assert metric_name in metric_names

folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False)

with open(quality_metrics_filename) as metrics_file:
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in quality_metrics:
assert metric_name in metric_names

folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True)

with open(quality_metrics_filename) as metrics_file:
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in quality_metrics:
if metric_name == "snr":
assert metric_name in metric_names
else:
assert metric_name not in metric_names


def test_unit_structure_in_output(small_sorting_analyzer):

qm_params = {
