Update template metrics based on Joe feedback
chrishalcrow committed Sep 10, 2024
1 parent 5074a4c commit 689c633
Showing 2 changed files with 88 additions and 26 deletions.
61 changes: 37 additions & 24 deletions src/spikeinterface/postprocessing/template_metrics.py
@@ -65,7 +65,7 @@ class ComputeTemplateMetrics(AnalyzerExtension):
include_multi_channel_metrics : bool, default: False
Whether to compute multi-channel metrics
delete_existing_metrics : bool, default: False
If True, deletes any quality_metrics attached to the `sorting_analyzer`
If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged.
metrics_kwargs : dict
Additional arguments to pass to the metric functions. Including:
* recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7
@@ -116,6 +116,7 @@ def _set_params(
delete_existing_metrics=False,
**other_kwargs,
):

import pandas as pd

# TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory()
@@ -135,6 +136,10 @@ def _set_params(
if include_multi_channel_metrics:
metric_names += get_multi_channel_template_metric_names()

# `run` cannot take parameters, so need to find another way to pass this
self.delete_existing_metrics = delete_existing_metrics
self.metric_names = metric_names

if metrics_kwargs is None:
metrics_kwargs_ = _default_function_kwargs.copy()
if len(other_kwargs) > 0:
@@ -145,29 +150,24 @@ def _set_params(
metrics_kwargs_ = _default_function_kwargs.copy()
metrics_kwargs_.update(metrics_kwargs)

all_metric_names = metric_names
tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if delete_existing_metrics is False and tm_extension is not None:
existing_metric_names = tm_extension.params["metric_names"]
existing_params = tm_extension.params["metrics_kwargs"]

existing_params = tm_extension.params["metrics_kwargs"]
# checks that existing metrics were calculated using the same params
if existing_params != metrics_kwargs_:
warnings.warn(
"The parameters used to calculate the previous template metrics are different than those used now. Deleting previous template metrics..."
)
self.sorting_analyzer.get_extension("template_metrics").data["metrics"] = pd.DataFrame(
index=self.sorting_analyzer.unit_ids
)
self.sorting_analyzer.get_extension("template_metrics").params["metric_names"] = []
tm_extension.params["metric_names"] = []
existing_metric_names = []
else:
existing_metric_names = tm_extension.params["metric_names"]

for metric_name in existing_metric_names:
if metric_name not in metric_names:
all_metric_names.append(metric_name)
metric_names = list(set(existing_metric_names + metric_names))

params = dict(
metric_names=[str(name) for name in np.unique(all_metric_names)],
metric_names=metric_names,
sparsity=sparsity,
peak_sign=peak_sign,
upsampling_factor=int(upsampling_factor),
@@ -185,6 +185,7 @@ def _merge_extension_data(
):
import pandas as pd

metric_names = self.params["metric_names"]
old_metrics = self.data["metrics"]

all_unit_ids = new_sorting_analyzer.unit_ids
@@ -193,19 +194,20 @@ def _merge_extension_data(
metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)

metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
metrics.loc[new_unit_ids, :] = self._compute_metrics(
new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs
)

new_data = dict(metrics=metrics)
return new_data

def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs):
def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs):
"""
Compute template metrics.
"""
import pandas as pd
from scipy.signal import resample_poly

metric_names = self.params["metric_names"]
sparsity = self.params["sparsity"]
peak_sign = self.params["peak_sign"]
upsampling_factor = self.params["upsampling_factor"]
@@ -308,19 +310,30 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
template_metrics.at[index, metric_name] = value
return template_metrics

def _run(self, delete_existing_metrics=False, verbose=False):
self.data["metrics"] = self._compute_metrics(
sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose
)
def _run(self, verbose=False):

delete_existing_metrics = self.delete_existing_metrics
metric_names = self.metric_names

existing_metrics = []
tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if delete_existing_metrics is False and tm_extension is not None:
if (
delete_existing_metrics is False
and tm_extension is not None
and tm_extension.data.get("metrics") is not None
):
existing_metrics = tm_extension.params["metric_names"]

for metric_name in existing_metrics:
if metric_name not in self.data["metrics"]:
metric_data = tm_extension.get_data()[metric_name]
self.data["metrics"][metric_name] = metric_data
# compute the metrics which have been specified by the user
computed_metrics = self._compute_metrics(
sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names
)

# append the metrics which were previously computed
for metric_name in set(existing_metrics).difference(metric_names):
computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name]

self.data["metrics"] = computed_metrics

def _get_data(self):
return self.data["metrics"]
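A minimal usage sketch of the behaviour changed above. The toy data and analyzer setup are assumptions for illustration and are not part of this commit; the `compute` calls mirror the tests below.

```python
import spikeinterface.full as si

# Toy data purely for illustration (not part of this commit).
recording, sorting = si.generate_ground_truth_recording(durations=[10.0], num_units=5)
analyzer = si.create_sorting_analyzer(sorting, recording)
analyzer.compute(["random_spikes", "templates"])

# Compute all single-channel template metrics.
analyzer.compute("template_metrics")

# Recompute only "exp_decay". With delete_existing_metrics=False (the default)
# and unchanged metrics_kwargs, the previously computed metrics are kept and
# merged with the newly requested one.
analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})

# With delete_existing_metrics=True the old metrics are discarded first,
# so only "half_width" remains afterwards.
analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True)

print(analyzer.get_extension("template_metrics").get_data().columns.tolist())
```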
53 changes: 51 additions & 2 deletions src/spikeinterface/postprocessing/tests/test_template_metrics.py
@@ -1,18 +1,22 @@
from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite
from spikeinterface.postprocessing import ComputeTemplateMetrics
import pytest
import csv


def test_compute_new_template_metrics(small_sorting_analyzer):
"""
Computes template metrics then computes a subset of quality metrics, and checks
that the old quality metrics are not deleted.
Computes template metrics then computes a subset of template metrics, and checks
that the old template metrics are not deleted.
Then computes template metrics with new parameters and checks that old metrics
are deleted.
"""

# calculate all template metrics
small_sorting_analyzer.compute("template_metrics")

# calculate just exp_decay - this should not delete the previously calculated metrics
small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})

template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
@@ -33,6 +37,51 @@ def test_compute_new_template_metrics(small_sorting_analyzer):
assert small_sorting_analyzer.get_extension("quality_metrics") is None


def test_save_template_metrics(small_sorting_analyzer, create_cache_folder):
"""
Computes template metrics in binary folder format. Then computes subsets of template
metrics and checks if they are saved correctly.
"""

from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func

small_sorting_analyzer.compute("template_metrics")

cache_folder = create_cache_folder
output_folder = cache_folder / "sorting_analyzer"

folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder)
template_metrics_filename = output_folder / "extensions" / "template_metrics" / "metrics.csv"

with open(template_metrics_filename) as metrics_file:
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in list(_single_channel_metric_name_to_func.keys()):
assert metric_name in metric_names

folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False)

with open(template_metrics_filename) as metrics_file:
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in list(_single_channel_metric_name_to_func.keys()):
assert metric_name in metric_names

folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True)

with open(template_metrics_filename) as metrics_file:
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in list(_single_channel_metric_name_to_func.keys()):
if metric_name == "half_width":
assert metric_name in metric_names
else:
assert metric_name not in metric_names


class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite):

@pytest.mark.parametrize(
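For reference, a hedged sketch of how the saved metrics could be inspected outside the test suite. It reuses the `output_folder` path from `test_save_template_metrics` above and assumes pandas is available; it is not part of this commit.

```python
import pandas as pd

# Path used by the binary_folder format, as in the test above.
metrics_path = output_folder / "extensions" / "template_metrics" / "metrics.csv"

saved = pd.read_csv(metrics_path, index_col=0)
print(saved.columns.tolist())      # names of the metrics that survived the recompute
print(saved.loc[saved.index[0]])   # metric values for the first unit
```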
