Improve tests for template metrics
chrishalcrow committed Sep 10, 2024
1 parent 689c633 commit 6cbe4db
Showing 1 changed file with 28 additions and 16 deletions.
src/spikeinterface/postprocessing/tests/test_template_metrics.py (44 changes: 28 additions & 16 deletions)
@@ -3,6 +3,10 @@
import pytest
import csv

from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func

template_metrics = list(_single_channel_metric_name_to_func.keys())


def test_compute_new_template_metrics(small_sorting_analyzer):
"""
@@ -13,28 +17,38 @@ def test_compute_new_template_metrics(small_sorting_analyzer):
are deleted.
"""

# calculate just exp_decay
small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")

assert "exp_decay" in list(template_metric_extension.get_data().keys())
assert "half_width" not in list(template_metric_extension.get_data().keys())

# calculate all template metrics
small_sorting_analyzer.compute("template_metrics")

# calculate just exp_decay - this should not delete the previously calculated metrics
# calculate just exp_decay - this should not delete any other metrics
small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})

template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")

# Check old metrics are not deleted and the new one is added to the data and metadata
assert "exp_decay" in list(template_metric_extension.get_data().keys())
assert "half_width" in list(template_metric_extension.get_data().keys())
set(template_metrics) == set(template_metric_extension.get_data().keys())

# check that, when parameters are changed, the old metrics are deleted
# calculate just exp_decay with delete_existing_metrics
small_sorting_analyzer.compute(
{"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}}
{"template_metrics": {"metric_names": ["exp_decay"], "delete_existing_metrics": True}}
)

template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
computed_metric_names = template_metric_extension.get_data().keys()

assert "half_width" not in list(template_metric_extension.get_data().keys())
for metric_name in template_metrics:
if metric_name == "exp_decay":
assert metric_name in computed_metric_names
else:
assert metric_name not in computed_metric_names

assert small_sorting_analyzer.get_extension("quality_metrics") is None
# check that, when parameters are changed, the old metrics are deleted
small_sorting_analyzer.compute(
{"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}}
)


def test_save_template_metrics(small_sorting_analyzer, create_cache_folder):
@@ -43,8 +57,6 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder):
metrics and checks if they are saved correctly.
"""

from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func

small_sorting_analyzer.compute("template_metrics")

cache_folder = create_cache_folder
@@ -57,7 +69,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder):
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in list(_single_channel_metric_name_to_func.keys()):
for metric_name in template_metrics:
assert metric_name in metric_names

folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False)
@@ -66,7 +78,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder):
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in list(_single_channel_metric_name_to_func.keys()):
for metric_name in template_metrics:
assert metric_name in metric_names

folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True)
@@ -75,7 +87,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder):
saved_metrics = csv.reader(metrics_file)
metric_names = next(saved_metrics)

for metric_name in list(_single_channel_metric_name_to_func.keys()):
for metric_name in template_metrics:
if metric_name == "half_width":
assert metric_name in metric_names
else:
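For reference, here is a rough, self-contained sketch of the SortingAnalyzer.compute pattern these tests exercise. The compute/get_extension/get_data calls and the metric_names and delete_existing_metrics arguments are taken from the diff above; the analyzer setup via generate_ground_truth_recording and create_sorting_analyzer, and the prerequisite random_spikes/templates extensions, are assumptions about the surrounding spikeinterface API rather than part of this commit.

# Rough sketch (not part of the commit): build a small analyzer and compute template metrics.
import spikeinterface.full as si

# Assumed setup; the tests use a pre-built `small_sorting_analyzer` fixture instead.
recording, sorting = si.generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
analyzer = si.create_sorting_analyzer(sorting, recording)

# template_metrics depends on templates, which depend on sampled spikes.
analyzer.compute(["random_spikes", "templates"])

# Compute one metric, then all metrics; recomputing a single metric is expected
# to keep the previously computed ones (what test_compute_new_template_metrics checks).
analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
analyzer.compute("template_metrics")

# With delete_existing_metrics=True, only the requested metric should remain.
analyzer.compute("template_metrics", metric_names=["exp_decay"], delete_existing_metrics=True)

metrics = analyzer.get_extension("template_metrics").get_data()
print(list(metrics.keys()))  # per the test above, expected to contain only "exp_decay"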
