From 6957e74ad9dcee84c93e3009c852b082f137c870 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 29 May 2024 12:34:51 +0100 Subject: [PATCH 1/3] Remove toy_example from test codebase (2) --- .../comparison/tests/test_templatecomparison.py | 9 ++++----- .../extractors/tests/test_mdaextractors.py | 5 +++-- .../extractors/tests/test_shybridextractors.py | 5 +++-- src/spikeinterface/preprocessing/tests/test_motion.py | 7 +++++-- .../qualitymetrics/tests/test_metrics_functions.py | 1 - 5 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/comparison/tests/test_templatecomparison.py b/src/spikeinterface/comparison/tests/test_templatecomparison.py index adffb258d6..6777b60f1f 100644 --- a/src/spikeinterface/comparison/tests/test_templatecomparison.py +++ b/src/spikeinterface/comparison/tests/test_templatecomparison.py @@ -3,8 +3,7 @@ from pathlib import Path import numpy as np -from spikeinterface.core import create_sorting_analyzer -from spikeinterface.extractors import toy_example +from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording from spikeinterface.comparison import compare_templates, compare_multiple_templates @@ -27,9 +26,9 @@ def test_compare_multiple_templates(): duration = 60 num_channels = 8 - rec, sort = toy_example(duration=duration, num_segments=1, num_channels=num_channels) - # rec = rec.save(folder=test_dir / "rec") - # sort = sort.save(folder=test_dir / "sort") + rec, sort = generate_ground_truth_recording( + durations=[duration], num_channels=num_channels + ) # split recording in 3 equal slices fs = rec.get_sampling_frequency() diff --git a/src/spikeinterface/extractors/tests/test_mdaextractors.py b/src/spikeinterface/extractors/tests/test_mdaextractors.py index 8c8bc0aa8c..6440e575d5 100644 --- a/src/spikeinterface/extractors/tests/test_mdaextractors.py +++ b/src/spikeinterface/extractors/tests/test_mdaextractors.py @@ -1,7 +1,8 @@ import pytest from pathlib import Path from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal -from spikeinterface.extractors import toy_example, MdaRecordingExtractor, MdaSortingExtractor +from spikeinterface.core import generate_ground_truth_recording +from spikeinterface.extractors import MdaRecordingExtractor, MdaSortingExtractor if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "extractors" @@ -10,7 +11,7 @@ def test_mda_extractors(): - rec, sort = toy_example(num_segments=1, num_units=10) + rec, sort = generate_ground_truth_recording(durations=[10.0], num_units=10) MdaRecordingExtractor.write_recording(rec, cache_folder / "mdatest") rec_mda = MdaRecordingExtractor(cache_folder / "mdatest") diff --git a/src/spikeinterface/extractors/tests/test_shybridextractors.py b/src/spikeinterface/extractors/tests/test_shybridextractors.py index eed64bdbba..a0164fd119 100644 --- a/src/spikeinterface/extractors/tests/test_shybridextractors.py +++ b/src/spikeinterface/extractors/tests/test_shybridextractors.py @@ -1,7 +1,8 @@ import pytest from pathlib import Path +from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal -from spikeinterface.extractors import toy_example, SHYBRIDRecordingExtractor, SHYBRIDSortingExtractor +from spikeinterface.extractors import SHYBRIDRecordingExtractor, SHYBRIDSortingExtractor if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "extractors" @@ -11,7 +12,7 @@ @pytest.mark.skipif(True, reason="SHYBRID only tested locally") def test_shybrid_extractors(): - rec, sort = toy_example(num_segments=1, num_units=10) + rec, sort = generate_ground_truth_recording(durations=[10.0], num_units=10) SHYBRIDSortingExtractor.write_sorting(sort, cache_folder / "shybridtest") sort_shybrid = SHYBRIDSortingExtractor( diff --git a/src/spikeinterface/preprocessing/tests/test_motion.py b/src/spikeinterface/preprocessing/tests/test_motion.py index 7cea531bb4..c498957401 100644 --- a/src/spikeinterface/preprocessing/tests/test_motion.py +++ b/src/spikeinterface/preprocessing/tests/test_motion.py @@ -3,7 +3,7 @@ import shutil - +from spikeinterface.core import generate_ground_truth_recording from spikeinterface.preprocessing import correct_motion, load_motion_info from spikeinterface.extractors import toy_example @@ -19,7 +19,10 @@ def test_estimate_and_correct_motion(): - rec, sorting = toy_example(num_segments=1, duration=30.0, num_units=10, num_channels=12) + rec, sorting = generate_ground_truth_recording( + durations=[30.0], num_units=10, num_channels=12 + ) + print(rec) folder = cache_folder / "estimate_and_correct_motion" diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 79fbab8893..88908d05c5 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -11,7 +11,6 @@ synthesize_random_firings, ) -# from spikeinterface.extractors.toy_example import toy_example from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions from spikeinterface.qualitymetrics import calculate_pc_metrics From ed5f90c42ece149299ddff6da33deb80ac0887ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 May 2024 11:47:22 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../comparison/tests/test_templatecomparison.py | 4 +--- src/spikeinterface/preprocessing/tests/test_motion.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/comparison/tests/test_templatecomparison.py b/src/spikeinterface/comparison/tests/test_templatecomparison.py index 6777b60f1f..f4b3a14f0d 100644 --- a/src/spikeinterface/comparison/tests/test_templatecomparison.py +++ b/src/spikeinterface/comparison/tests/test_templatecomparison.py @@ -26,9 +26,7 @@ def test_compare_multiple_templates(): duration = 60 num_channels = 8 - rec, sort = generate_ground_truth_recording( - durations=[duration], num_channels=num_channels - ) + rec, sort = generate_ground_truth_recording(durations=[duration], num_channels=num_channels) # split recording in 3 equal slices fs = rec.get_sampling_frequency() diff --git a/src/spikeinterface/preprocessing/tests/test_motion.py b/src/spikeinterface/preprocessing/tests/test_motion.py index c498957401..8aee96ae56 100644 --- a/src/spikeinterface/preprocessing/tests/test_motion.py +++ b/src/spikeinterface/preprocessing/tests/test_motion.py @@ -19,9 +19,7 @@ def test_estimate_and_correct_motion(): - rec, sorting = generate_ground_truth_recording( - durations=[30.0], num_units=10, num_channels=12 - ) + rec, sorting = generate_ground_truth_recording(durations=[30.0], num_units=10, num_channels=12) print(rec) From 567ff64f734a001c25f11e3571d3c641bf859150 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 29 May 2024 13:40:27 +0100 Subject: [PATCH 3/3] Update generate ground truth to generate recording --- src/spikeinterface/preprocessing/tests/test_motion.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_motion.py b/src/spikeinterface/preprocessing/tests/test_motion.py index 8aee96ae56..ea4611b372 100644 --- a/src/spikeinterface/preprocessing/tests/test_motion.py +++ b/src/spikeinterface/preprocessing/tests/test_motion.py @@ -3,10 +3,9 @@ import shutil -from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.preprocessing import correct_motion, load_motion_info +from spikeinterface.core import generate_recording -from spikeinterface.extractors import toy_example +from spikeinterface.preprocessing import correct_motion, load_motion_info import numpy as np @@ -19,7 +18,7 @@ def test_estimate_and_correct_motion(): - rec, sorting = generate_ground_truth_recording(durations=[30.0], num_units=10, num_channels=12) + rec = generate_recording(durations=[30.0], num_channels=12) print(rec)