From 6cbe4dbe9c805aac3fb3ed29a0f50f4623eeab9c Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 10 Sep 2024 18:25:30 +0100 Subject: [PATCH] Improve tests for template metrics --- .../tests/test_template_metrics.py | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 1fa2ac638c..8aaad8ffbc 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -3,6 +3,10 @@ import pytest import csv +from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func + +template_metrics = list(_single_channel_metric_name_to_func.keys()) + def test_compute_new_template_metrics(small_sorting_analyzer): """ @@ -13,28 +17,38 @@ def test_compute_new_template_metrics(small_sorting_analyzer): are deleted. """ + # calculate just exp_decay + small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) + template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") + + assert "exp_decay" in list(template_metric_extension.get_data().keys()) + assert "half_width" not in list(template_metric_extension.get_data().keys()) + # calculate all template metrics small_sorting_analyzer.compute("template_metrics") - - # calculate just exp_decay - this should not delete the previously calculated metrics + # calculate just exp_decay - this should not delete any other metrics small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) - template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") - # Check old metrics are not deleted and the new one is added to the data and metadata - assert "exp_decay" in list(template_metric_extension.get_data().keys()) - assert "half_width" in list(template_metric_extension.get_data().keys()) + set(template_metrics) == set(template_metric_extension.get_data().keys()) - # check that, when parameters are changed, the old metrics are deleted + # calculate just exp_decay with delete_existing_metrics small_sorting_analyzer.compute( - {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}} + {"template_metrics": {"metric_names": ["exp_decay"], "delete_existing_metrics": True}} ) - template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") + computed_metric_names = template_metric_extension.get_data().keys() - assert "half_width" not in list(template_metric_extension.get_data().keys()) + for metric_name in template_metrics: + if metric_name == "exp_decay": + assert metric_name in computed_metric_names + else: + assert metric_name not in computed_metric_names - assert small_sorting_analyzer.get_extension("quality_metrics") is None + # check that, when parameters are changed, the old metrics are deleted + small_sorting_analyzer.compute( + {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}} + ) def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): @@ -43,8 +57,6 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): metrics and checks if they are saved correctly. """ - from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func - small_sorting_analyzer.compute("template_metrics") cache_folder = create_cache_folder @@ -57,7 +69,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in list(_single_channel_metric_name_to_func.keys()): + for metric_name in template_metrics: assert metric_name in metric_names folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False) @@ -66,7 +78,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in list(_single_channel_metric_name_to_func.keys()): + for metric_name in template_metrics: assert metric_name in metric_names folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True) @@ -75,7 +87,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in list(_single_channel_metric_name_to_func.keys()): + for metric_name in template_metrics: if metric_name == "half_width": assert metric_name in metric_names else: