diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index 1a79019f96..17bfb81d49 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -2,7 +2,7 @@ import shutil from pathlib import Path -from spikeinterface import WaveformExtractor, load_extractor, extract_waveforms, generate_recording, generate_sorting +from spikeinterface import load_extractor, extract_waveforms, load_waveforms, generate_recording, generate_sorting from spikeinterface.core import ( get_template_amplitudes, @@ -33,25 +33,23 @@ def setup_module(): sorting.set_property("group", [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) sorting = sorting.save(folder=cache_folder / "toy_sort") - we = WaveformExtractor.create(recording, sorting, cache_folder / "toy_waveforms") - we.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) - we.run_extract_waveforms(n_jobs=1, chunk_size=30000) + we = extract_waveforms(recording, sorting, cache_folder / "toy_waveforms") def test_get_template_amplitudes(): - we = WaveformExtractor.load(cache_folder / "toy_waveforms") + we = load_waveforms(cache_folder / "toy_waveforms") peak_values = get_template_amplitudes(we) print(peak_values) def test_get_template_extremum_channel(): - we = WaveformExtractor.load(cache_folder / "toy_waveforms") + we = load_waveforms(cache_folder / "toy_waveforms") extremum_channels_ids = get_template_extremum_channel(we, peak_sign="both") print(extremum_channels_ids) def test_get_template_extremum_channel_peak_shift(): - we = WaveformExtractor.load(cache_folder / "toy_waveforms") + we = load_waveforms(cache_folder / "toy_waveforms") shifts = get_template_extremum_channel_peak_shift(we, peak_sign="neg") print(shifts) @@ -72,7 +70,7 @@ def test_get_template_extremum_channel_peak_shift(): def test_get_template_extremum_amplitude(): - we = WaveformExtractor.load(cache_folder / "toy_waveforms") + we = load_waveforms(cache_folder / "toy_waveforms") extremum_channels_ids = get_template_extremum_amplitude(we, peak_sign="both") print(extremum_channels_ids) @@ -85,4 +83,3 @@ def test_get_template_extremum_amplitude(): test_get_template_extremum_channel() test_get_template_extremum_channel_peak_shift() test_get_template_extremum_amplitude() - test_get_template_channel_sparsity() diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 204f796c0e..501fd8cc79 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -388,14 +388,17 @@ def test_unfiltered_extraction(): shutil.rmtree(wf_folder) we = WaveformExtractor.create(recording, sorting, wf_folder, mode=mode, allow_unfiltered=True) - we.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) - + ms_before = 2.0 + ms_after = 3.0 + max_spikes_per_unit = 500 + num_samples = int((ms_before + ms_after) * sampling_frequency / 1000.0) + we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit) we.run_extract_waveforms(n_jobs=1, chunk_size=30000) we.run_extract_waveforms(n_jobs=4, chunk_size=30000, progress_bar=True) wfs = we.get_waveforms(0) - assert wfs.shape[0] <= 500 - assert wfs.shape[1:] == (210, num_channels) + assert wfs.shape[0] <= max_spikes_per_unit + assert wfs.shape[1:] == (num_samples, num_channels) wfs, sampled_index = we.get_waveforms(0, with_index=True) @@ -406,18 +409,18 @@ def test_unfiltered_extraction(): wfs = we.get_waveforms(0) template = we.get_template(0) - assert template.shape == (210, 2) + assert template.shape == (num_samples, 2) templates = we.get_all_templates() - assert templates.shape == (num_units, 210, num_channels) + assert templates.shape == (num_units, num_samples, num_channels) wf_std = we.get_template(0, mode="std") - assert wf_std.shape == (210, num_channels) + assert wf_std.shape == (num_samples, num_channels) wfs_std = we.get_all_templates(mode="std") - assert wfs_std.shape == (num_units, 210, num_channels) + assert wfs_std.shape == (num_units, num_samples, num_channels) wf_segment = we.get_template_segment(unit_id=0, segment_index=0) - assert wf_segment.shape == (210, num_channels) - assert wf_segment.shape == (210, num_channels) + assert wf_segment.shape == (num_samples, num_channels) + assert wf_segment.shape == (num_samples, num_channels) def test_portability(): diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 50e2ecdb57..b539bbd5d4 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -54,8 +54,6 @@ def setUp(self): recording, sorting, cache_folder / "toy_waveforms_1seg", - ms_before=3.0, - ms_after=4.0, max_spikes_per_unit=500, sparse=False, n_jobs=1, @@ -90,8 +88,6 @@ def setUp(self): recording, sorting, cache_folder / "toy_waveforms_2seg", - ms_before=3.0, - ms_after=4.0, max_spikes_per_unit=500, sparse=False, n_jobs=1, @@ -115,8 +111,6 @@ def setUp(self): sorting, mode="memory", sparse=False, - ms_before=3.0, - ms_after=4.0, max_spikes_per_unit=500, n_jobs=1, chunk_size=30000, diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 977beca210..73bbee611b 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -210,9 +210,7 @@ def test_peak_sign(self): # invert recording rec_inv = scale(rec, gain=-1.0) - we_inv = WaveformExtractor.create(rec_inv, sort, self.cache_folder / "toy_waveforms_inv") - we_inv.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=None) - we_inv.run_extract_waveforms(n_jobs=1, chunk_size=30000) + we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv") # compute amplitudes _ = compute_spike_amplitudes(we, peak_sign="neg")