From 57e1703cf3d0935061ab1dd89d9f1a00075ef8ff Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 27 Feb 2024 18:08:59 +0100 Subject: [PATCH 001/136] Initial discussion Charlie and Sam to make a Motion object --- src/spikeinterface/preprocessing/motion.py | 6 +- .../sortingcomponents/motion_estimation.py | 7 ++ .../sortingcomponents/motion_interpolation.py | 9 ++ .../sortingcomponents/motion_utils.py | 98 +++++++++++++++++++ .../tests/test_motiopn_utils.py | 3 + 5 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 src/spikeinterface/sortingcomponents/motion_utils.py create mode 100644 src/spikeinterface/sortingcomponents/tests/test_motiopn_utils.py diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 1b182a6436..26ebe9ee38 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -293,6 +293,8 @@ def correct_motion( Optional output if `output_motion_info=True` """ + # TODO : Use motion object + # local import are important because "sortingcomponents" is not important by default from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods from spikeinterface.sortingcomponents.peak_selection import select_peaks @@ -401,7 +403,8 @@ def correct_motion( if folder is not None: (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - + + # TODO save Motion np.save(folder / "temporal_bins.npy", temporal_bins) np.save(folder / "motion.npy", motion) if spatial_bins is not None: @@ -413,6 +416,7 @@ def correct_motion( run_times=run_times, peaks=peaks, peak_locations=peak_locations, + # TODO use Motion temporal_bins=temporal_bins, spatial_bins=spatial_bins, motion=motion, diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index ef3a39bed1..4c65f8f44b 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -15,6 +15,9 @@ from .tools import make_multi_method_doc + + + def estimate_motion( recording, peaks, @@ -182,12 +185,16 @@ def estimate_motion( non_rigid_window_centers = spatial_bin_edges[:-1] + bin_um / 2 motion = motion @ non_rigid_windows + + # TODO : add Motion object here if output_extra_check: return motion, temporal_bins, non_rigid_window_centers, extra_check else: return motion, temporal_bins, non_rigid_window_centers + + class DecentralizedRegistration: """ Method developed by the Paninski's group from Columbia university: diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index f71ae0304d..05e9073c1b 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -22,9 +22,11 @@ def correct_motion_on_peaks( peaks, peak_locations, sampling_frequency, + # TODO use add Motion motion, temporal_bins, spatial_bins, + ### direction="y", ): """ @@ -74,9 +76,11 @@ def interpolate_motion_on_traces( traces, times, channel_locations, + # TODO : add Motion object here motion, temporal_bins, spatial_bins, + ### direction=1, channel_inds=None, spatial_interpolation_method="kriging", @@ -132,6 +136,8 @@ def interpolate_motion_on_traces( # inperpolation kernel will be the same per temporal bin for bin_ind in np.unique(bin_inds): + # TODO use # TODO : add Motion.get_displacement_at_time_and_depth() 
instead + # Step 1 : channel motion if spatial_bins.shape[0] == 1: # rigid motion : same motion for all channels @@ -364,9 +370,12 @@ def __init__( self, parent_recording_segment, channel_locations, + # TODO : add Motion object here motion, temporal_bins, spatial_bins, + ### + direction, spatial_interpolation_method, spatial_interpolation_kwargs, diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py new file mode 100644 index 0000000000..87a7350e7d --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -0,0 +1,98 @@ +import numpy as np + + + +class Motion: + """ + Motion of the tissue relative the probe. + + Parameters + ---------- + + displacement: numpy array 2d or list of + Motion estimate in um. + Shape (temporal bins, spatial bins) + motion.shape[0] = temporal_bins.shape[0] + motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) + temporal_bins_s: numpy.array 1d or list of + temporal bins (bin center) + spatial_bins_um: numpy.array 1d + Windows center. + spatial_bins_um.shape[0] == displacement.shape[1] + If rigid then spatial_bins_um.shape[0] == 1 + + """ + def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y"): + if isinstance(displacement, np.ndarray): + self.displacement = [displacement] + assert isinstance(temporal_bins_s, np.ndarray) + self.temporal_bins_s = [temporal_bins_s] + else: + assert isinstance(displacement, (list, tuple)) + self.displacement = displacement + self.temporal_bins_s = temporal_bins_s + + assert isinstance(spatial_bins_um, np.ndarray) + self.spatial_bins_um = spatial_bins_um + + self.num_segments = len(self.displacement) + self.interpolator = None + + self.direction = direction + self.dim = ["x", "y", "z"].index(direction) + + def make_interpolators(self): + from scipy.interpolate import RegularGridInterpolator2D + self.interpolator = [ + RegularGridInterpolator2D((self.spatial_bins_um, self.temporal_bins_s[j]), self.displacement[j]) + for j in range(self.num_segments) + ] + self.temporal_bounds = [(t[0], t[-1]) for t in self.temporal_bins_s] + self.spatial_bounds = (self.spatial_bins_um.min(), self.spatial_bins_um.max()) + + def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_index=None): + """ + + + Parameters + ---------- + times_s: np.array + + + locations_um: np.array + + segment_index: + + """ + if self.interpolator is None: + self.make_interpolators() + + if segment_index is None: + if self.num_segments == 1: + segment_index = 0 + else: + raise ValueError("Several segment need segment_index=") + + if locations_um.ndim == 1: + locations_um = locations_um + else: + locations_um = locations_um[:, self.dim] + times_s = np.clip(times_s, *self.temporal_bounds[segment_index]) + positions = np.clip(positions, *self.spatial_bounds) + points = np.stack([positions, times_s], axis=1) + + return self.interpolator[segment_index](points) + + def to_dict(self): + return dict( + displacement=self.displacement, + temporal_bins_s=self.temporal_bins_s, + spatial_bins_um=self.spatial_bins_um, + ) + + def save(self): + pass + + @classmethod + def load(cls): + pass diff --git a/src/spikeinterface/sortingcomponents/tests/test_motiopn_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motiopn_utils.py new file mode 100644 index 0000000000..9efd26a3d5 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_motiopn_utils.py @@ -0,0 +1,3 @@ + + +# TODO Motion Make some test \ No newline at end of file From 
49a346ab567435977a63df725026a13cb558ebc1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 30 Apr 2024 14:16:47 +0200 Subject: [PATCH 002/136] Add option to set recording --- src/spikeinterface/core/sortinganalyzer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 85ea9b8438..dd5695860c 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -572,6 +572,11 @@ def load_from_zarr(cls, folder, recording=None): return sorting_analyzer + def set_recording(self, recording): + if self._recording is not None: + raise ValueError("Recording is already set") + self._recording = recording + def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingAnalyzer": """ Internal used by both save_as(), copy() and select_units() which are more or less the same. From dae78136620a9fd4ce3f2f0ce9dfd931dcb3677a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 14 May 2024 15:32:47 +0200 Subject: [PATCH 003/136] Add recording attributes check, docs, and warning --- src/spikeinterface/core/recording_tools.py | 32 ++++++++++++++ src/spikeinterface/core/sortinganalyzer.py | 30 +++++++++++-- .../core/tests/test_sortinganalyzer.py | 44 +++++++++++++------ 3 files changed, 90 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 2f228432b0..6b3de3ec98 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -910,3 +910,35 @@ def get_rec_attributes(recording): dtype=recording.get_dtype(), ) return rec_attributes + + +def check_recording_attributes_match(recording1, recording2_attributes, skip_properties=True): + """ + Check if two recordings have the same attributes + + Parameters + ---------- + recording1 : BaseRecording + The first recording object + recording2 : BaseRecording + The second recording object + + Returns + ------- + bool + True if the recordings have the same attributes + """ + recording1_attributes = get_rec_attributes(recording1) + recording1_attributes["probegroup"] = recording1.get_probegroup() + recording2_attributes = deepcopy(recording2_attributes) + if skip_properties: + recording1_attributes.pop("properties") + recording2_attributes.pop("properties") + return ( + np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]) + and recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"] + and recording1_attributes["num_channels"] == recording2_attributes["num_channels"] + and recording1_attributes["num_samples"] == recording2_attributes["num_samples"] + and recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"] + and recording1_attributes["dtype"] == recording2_attributes["dtype"] + ) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 9dfbabd729..be1c1d1fec 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -20,7 +20,7 @@ from .basesorting import BaseSorting from .base import load_extractor -from .recording_tools import check_probe_do_not_overlap, get_rec_attributes +from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, check_recording_attributes_match from .core_tools import check_json, retrieve_importing_provenance from .job_tools import 
split_job_kwargs from .numpyextractors import NumpySorting @@ -588,9 +588,33 @@ def load_from_zarr(cls, folder, recording=None): return sorting_analyzer - def set_recording(self, recording): + def set_temporary_recording(self, recording: BaseRecording): + """ + Sets a temporary recording object. This function can be useful to temporarily set + a "cached" recording object that is not saved in the SortingAnalyzer object to speed up + computations. Upon reloading, the SortingAnalyzer object will try to reload the recording + from the original location in a lazy way. + + + Parameters + ---------- + recording : BaseRecording + The recording object to set as temporary recording. + + Raises + ------ + ValueError + _description_ + """ + # check that recording is compatible + assert check_recording_attributes_match( + recording, self.rec_attributes, skip_properties=True + ), "Recording attributes do not match." + assert np.array_equal( + recording.get_channel_locations(), self.get_channel_locations() + ), "Recording channel locations do not match." if self._recording is not None: - raise ValueError("Recording is already set") + warnings.warn("SortingAnalyzer recording is already set. This will overwrite the current recording.") self._recording = recording def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingAnalyzer": diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 66b670d956..bdce31c5b2 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -15,7 +15,7 @@ import numpy as np -def get_dataset(): +def _get_dataset(): recording, sorting = generate_ground_truth_recording( durations=[30.0], sampling_frequency=16000.0, @@ -28,8 +28,13 @@ def get_dataset(): return recording, sorting -def test_SortingAnalyzer_memory(tmp_path): - recording, sorting = get_dataset() +@pytest.fixture(scope="module") +def get_dataset(): + return _get_dataset() + + +def test_SortingAnalyzer_memory(tmp_path, get_dataset): + recording, sorting = get_dataset sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) @@ -48,8 +53,8 @@ def test_SortingAnalyzer_memory(tmp_path): assert not sorting_analyzer.return_scaled -def test_SortingAnalyzer_binary_folder(tmp_path): - recording, sorting = get_dataset() +def test_SortingAnalyzer_binary_folder(tmp_path, get_dataset): + recording, sorting = get_dataset folder = tmp_path / "test_SortingAnalyzer_binary_folder" if folder.exists(): @@ -78,8 +83,8 @@ def test_SortingAnalyzer_binary_folder(tmp_path): _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) -def test_SortingAnalyzer_zarr(tmp_path): - recording, sorting = get_dataset() +def test_SortingAnalyzer_zarr(tmp_path, get_dataset): + recording, sorting = get_dataset folder = tmp_path / "test_SortingAnalyzer_zarr.zarr" if folder.exists(): @@ -99,10 +104,21 @@ def test_SortingAnalyzer_zarr(tmp_path): ) -def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): +def test_SortingAnalyzer_tmp_recording(get_dataset): + recording, sorting = get_dataset + recording_cached = recording.save(mode="memory") - print() - print(sorting_analyzer) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + 
sorting_analyzer.set_temporary_recording(recording_cached) + + recording_sliced = recording.channel_slice(recording.channel_ids[:-1]) + + # wrong channels + with pytest.raises(AssertionError): + sorting_analyzer.set_temporary_recording(recording_sliced) + + +def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): register_result_extension(DummyAnalyzerExtension) @@ -257,8 +273,10 @@ def test_extension(): if __name__ == "__main__": tmp_path = Path("test_SortingAnalyzer") - test_SortingAnalyzer_memory(tmp_path) - test_SortingAnalyzer_binary_folder(tmp_path) - test_SortingAnalyzer_zarr(tmp_path) + dataset = _get_dataset() + test_SortingAnalyzer_memory(tmp_path, dataset) + test_SortingAnalyzer_binary_folder(tmp_path, dataset) + test_SortingAnalyzer_zarr(tmp_path, dataset) + test_SortingAnalyzer_tmp_recording(dataset) test_extension() test_extension_params() From b036cf340174f5525923430f43de6fe01615b458 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 14 May 2024 16:23:39 +0200 Subject: [PATCH 004/136] thank you Zach! --- src/spikeinterface/core/recording_tools.py | 7 +++---- src/spikeinterface/core/sortinganalyzer.py | 7 +------ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 6b3de3ec98..0f9fa028f3 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -912,7 +912,7 @@ def get_rec_attributes(recording): return rec_attributes -def check_recording_attributes_match(recording1, recording2_attributes, skip_properties=True): +def check_recording_attributes_match(recording1, recording2_attributes, skip_properties=True) -> bool: """ Check if two recordings have the same attributes @@ -920,8 +920,8 @@ def check_recording_attributes_match(recording1, recording2_attributes, skip_pro ---------- recording1 : BaseRecording The first recording object - recording2 : BaseRecording - The second recording object + recording2_attributes : dict + The recording attributes to test against Returns ------- @@ -929,7 +929,6 @@ def check_recording_attributes_match(recording1, recording2_attributes, skip_pro True if the recordings have the same attributes """ recording1_attributes = get_rec_attributes(recording1) - recording1_attributes["probegroup"] = recording1.get_probegroup() recording2_attributes = deepcopy(recording2_attributes) if skip_properties: recording1_attributes.pop("properties") diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index be1c1d1fec..a122933ecc 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -600,11 +600,6 @@ def set_temporary_recording(self, recording: BaseRecording): ---------- recording : BaseRecording The recording object to set as temporary recording. - - Raises - ------ - ValueError - _description_ """ # check that recording is compatible assert check_recording_attributes_match( @@ -614,7 +609,7 @@ def set_temporary_recording(self, recording: BaseRecording): recording.get_channel_locations(), self.get_channel_locations() ), "Recording channel locations do not match." if self._recording is not None: - warnings.warn("SortingAnalyzer recording is already set. This will overwrite the current recording.") + warnings.warn("SortingAnalyzer recording is already set. 
" "The current recording is temporarily replaced.") self._recording = recording def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingAnalyzer": From 294fa26fcb2bf3dac4d84f5ffc0040fa413c6103 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 15 May 2024 10:24:31 -0600 Subject: [PATCH 005/136] remove unused imports ensure integer --- src/spikeinterface/core/recording_tools.py | 2 +- src/spikeinterface/postprocessing/amplitude_scalings.py | 2 +- src/spikeinterface/sortingcomponents/matching/naive.py | 2 +- src/spikeinterface/sortingcomponents/peak_detection.py | 1 - 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 2f228432b0..3bcb91cc23 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -702,7 +702,7 @@ def get_chunk_with_margin( case zero padding is used, in the second case np.pad is called with mod="reflect". """ - length = rec_segment.get_num_samples() + length = int(rec_segment.get_num_samples()) if channel_indices is None: channel_indices = slice(None) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index e2dcdd8e5a..57a97be16e 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core import ChannelSparsity, get_chunk_with_margin +from spikeinterface.core import ChannelSparsity from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs from spikeinterface.core.template_tools import get_template_extremum_channel diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index c172e90fd8..0dc71d789b 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -4,7 +4,7 @@ import numpy as np -from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks +from spikeinterface.core import get_noise_levels, get_channel_distances, get_random_data_chunks from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive from spikeinterface.core.template import Templates diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 508a033c41..a67f2ef674 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -26,7 +26,6 @@ ) from spikeinterface.postprocessing.unit_localization import get_convolution_weights -from ..core import get_chunk_with_margin from .tools import make_multi_method_doc From 1df89594865caae348b300797c8b1e6c27236407 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 15 May 2024 10:43:34 -0600 Subject: [PATCH 006/136] segment sum --- src/spikeinterface/core/segmentutils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index c3881cc1f8..75fd874f78 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -181,8 +181,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): if i0 == i1: 
#  one segment - rec_seg = self.parent_segments[i0] - seg_start = self.cumsum_length[i0] + rec_seg = int(self.parent_segments[i0]) + seg_start = int(self.cumsum_length[i0]) # Cum sum length is a numpy array traces = rec_seg.get_traces(start_frame - seg_start, end_frame - seg_start, channel_indices) else: #  several segments @@ -192,8 +192,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): # limit case continue - rec_seg = self.parent_segments[i] - seg_start = self.cumsum_length[i] + rec_seg = int(self.parent_segments[i]) + seg_start = int(self.cumsum_length[i]) if i == i0: # first traces_chunk = rec_seg.get_traces(start_frame - seg_start, None, channel_indices) From 378c6c1c0f3dbfe0b683fad1db549eed8b66deda Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 15 May 2024 10:45:24 -0600 Subject: [PATCH 007/136] segment sum --- src/spikeinterface/core/segmentutils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index 75fd874f78..9bc53c11f1 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -156,7 +156,8 @@ def __init__(self, parent_segments, sampling_frequency, ignore_times=True): BaseRecordingSegment.__init__(self, **time_kwargs) self.parent_segments = parent_segments self.all_length = [rec_seg.get_num_samples() for rec_seg in self.parent_segments] - self.cumsum_length = np.cumsum([0] + self.all_length) + cumulative_sum_numpy = np.cumsum([0] + self.all_length) # We need to cast to int for overflow concerns + self.cumsum_length = [int(samples_till_segment for samples_till_segment in cumulative_sum_numpy)] self.total_length = int(np.sum(self.all_length)) def get_num_samples(self): @@ -181,8 +182,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): if i0 == i1: #  one segment - rec_seg = int(self.parent_segments[i0]) - seg_start = int(self.cumsum_length[i0]) # Cum sum length is a numpy array + rec_seg = self.parent_segments[i0] + seg_start = self.cumsum_length[i0] traces = rec_seg.get_traces(start_frame - seg_start, end_frame - seg_start, channel_indices) else: #  several segments @@ -192,8 +193,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): # limit case continue - rec_seg = int(self.parent_segments[i]) - seg_start = int(self.cumsum_length[i]) + rec_seg = self.parent_segments[i] + seg_start = self.cumsum_length[i] if i == i0: # first traces_chunk = rec_seg.get_traces(start_frame - seg_start, None, channel_indices) From af29f412b2aef688a3a4b799f0d963620709391a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 15 May 2024 11:25:27 -0600 Subject: [PATCH 008/136] fix at the root --- src/spikeinterface/core/segmentutils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index 9bc53c11f1..959b7f8c43 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -156,8 +156,7 @@ def __init__(self, parent_segments, sampling_frequency, ignore_times=True): BaseRecordingSegment.__init__(self, **time_kwargs) self.parent_segments = parent_segments self.all_length = [rec_seg.get_num_samples() for rec_seg in self.parent_segments] - cumulative_sum_numpy = np.cumsum([0] + self.all_length) # We need to cast to int for overflow concerns - self.cumsum_length = [int(samples_till_segment for samples_till_segment in cumulative_sum_numpy)] + 
self.cumsum_length = [0] + [sum(self.all_length[: i + 1]) for i in range(len(self.all_length))] self.total_length = int(np.sum(self.all_length)) def get_num_samples(self): From 0866e3da057ba24bd6b852e501680c608a3c3ccd Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 20 May 2024 09:48:35 +0200 Subject: [PATCH 009/136] wip Motion object --- .../sortingcomponents/motion_estimation.py | 37 +++++++----------- .../sortingcomponents/motion_utils.py | 34 +++++++++++++---- .../sortingcomponents/tests/common.py | 2 + .../tests/test_motion_estimation.py | 30 +++++++-------- .../tests/test_motion_utils.py | 38 +++++++++++++++++++ .../tests/test_motiopn_utils.py | 3 -- 6 files changed, 95 insertions(+), 49 deletions(-) create mode 100644 src/spikeinterface/sortingcomponents/tests/test_motion_utils.py delete mode 100644 src/spikeinterface/sortingcomponents/tests/test_motiopn_utils.py diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index 4c65f8f44b..6925c7aede 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -13,7 +13,7 @@ HAVE_TORCH = False from .tools import make_multi_method_doc - +from .motion_utils import Motion @@ -109,19 +109,8 @@ def estimate_motion( Returns ------- - motion: numpy array 2d - Motion estimate in um. - Shape (temporal bins, spatial bins) - motion.shape[0] = temporal_bins.shape[0] - motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) - If upsample_to_histogram_bin, motion.shape[1] corresponds to spatial - bins given by bin_um. - temporal_bins: numpy.array 1d - temporal bins (bin center) - spatial_bins: numpy.array 1d - Windows center. - spatial_bins.shape[0] == motion.shape[1] - If rigid then spatial_bins.shape[0] == 1 + motion: Motion object + The motion object. extra_check: dict Optional output if `output_extra_check=True` This dict contain histogram, pairwise_displacement usefull for ploting. 
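A minimal usage sketch of the new return convention (hand-building the Motion object with illustrative values, mirroring the test added further down, instead of running the full peak-detection pipeline):

    import numpy as np
    from spikeinterface.sortingcomponents.motion_utils import Motion

    temporal_bins_s = np.arange(0.0, 10.0, 1.0)
    spatial_bins_um = np.array([100.0, 200.0])
    displacement = np.zeros((temporal_bins_s.shape[0], spatial_bins_um.shape[0]))
    displacement[:, :] = np.linspace(-20, 20, temporal_bins_s.shape[0])[:, np.newaxis]

    # what estimate_motion() now returns instead of the old
    # (motion, temporal_bins, spatial_bins) tuple
    motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y")
    shifts = motion.get_displacement_at_time_and_depth(np.array([2.0, 4.4]), np.array([120.0, 150.0]))
    assert shifts.shape == (2,)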
@@ -152,7 +141,7 @@ def estimate_motion( # run method method_class = estimate_motion_methods[method] - motion, temporal_bins = method_class.run( + motion_array, temporal_bins = method_class.run( recording, peaks, peak_locations, @@ -168,29 +157,31 @@ def estimate_motion( ) # replace nan by zeros - motion[np.isnan(motion)] = 0 + motion_array[np.isnan(motion_array)] = 0 if post_clean: - motion = clean_motion_vector( - motion, temporal_bins, bin_duration_s, speed_threshold=speed_threshold, sigma_smooth_s=sigma_smooth_s + motion_array = clean_motion_vector( + motion_array, temporal_bins, bin_duration_s, speed_threshold=speed_threshold, sigma_smooth_s=sigma_smooth_s ) if upsample_to_histogram_bin is None: upsample_to_histogram_bin = not rigid if upsample_to_histogram_bin: - extra_check["motion"] = motion + extra_check["motion_array"] = motion_array extra_check["non_rigid_window_centers"] = non_rigid_window_centers non_rigid_windows = np.array(non_rigid_windows) non_rigid_windows /= non_rigid_windows.sum(axis=0, keepdims=True) non_rigid_window_centers = spatial_bin_edges[:-1] + bin_um / 2 - motion = motion @ non_rigid_windows + motion_array = motion_array @ non_rigid_windows + + # TODO handle multi segment + motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction) - # TODO : add Motion object here if output_extra_check: - return motion, temporal_bins, non_rigid_window_centers, extra_check + return motion, extra_check else: - return motion, temporal_bins, non_rigid_window_centers + return motion diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index 87a7350e7d..448aefdb9e 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -11,9 +11,11 @@ class Motion: displacement: numpy array 2d or list of Motion estimate in um. - Shape (temporal bins, spatial bins) - motion.shape[0] = temporal_bins.shape[0] - motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) + List is the number of segment. 
+ For each semgent : + * shape (temporal bins, spatial bins) + * motion.shape[0] = temporal_bins.shape[0] + * motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) temporal_bins_s: numpy.array 1d or list of temporal bins (bin center) spatial_bins_um: numpy.array 1d @@ -42,9 +44,9 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" self.dim = ["x", "y", "z"].index(direction) def make_interpolators(self): - from scipy.interpolate import RegularGridInterpolator2D + from scipy.interpolate import RegularGridInterpolator self.interpolator = [ - RegularGridInterpolator2D((self.spatial_bins_um, self.temporal_bins_s[j]), self.displacement[j]) + RegularGridInterpolator((self.temporal_bins_s[j], self.spatial_bins_um), self.displacement[j]) for j in range(self.num_segments) ] self.temporal_bounds = [(t[0], t[-1]) for t in self.temporal_bins_s] @@ -72,14 +74,17 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde segment_index = 0 else: raise ValueError("Several segment need segment_index=") + + times_s = np.asarray(times_s) + locations_um = np.asarray(times_s) if locations_um.ndim == 1: locations_um = locations_um else: locations_um = locations_um[:, self.dim] times_s = np.clip(times_s, *self.temporal_bounds[segment_index]) - positions = np.clip(positions, *self.spatial_bounds) - points = np.stack([positions, times_s], axis=1) + locations_um = np.clip(locations_um, *self.spatial_bounds) + points = np.stack([times_s, locations_um,], axis=1) return self.interpolator[segment_index](points) @@ -91,8 +96,23 @@ def to_dict(self): ) def save(self): + # TODO pass @classmethod def load(cls): + # TODO pass + + def __eq__(self, other): + + for segment_index in range(self.num_segments): + if not np.allclose(self.displacement[segment_index], other.displacement[segment_index]): + return False + if not np.allclose(self.temporal_bins_s[segment_index], other.temporal_bins_s[segment_index]): + return False + + if not np.allclose(self.spatial_bins_um, other.spatial_bins_um): + return False + + return True diff --git a/src/spikeinterface/sortingcomponents/tests/common.py b/src/spikeinterface/sortingcomponents/tests/common.py index aacd7576fb..a711b67bda 100644 --- a/src/spikeinterface/sortingcomponents/tests/common.py +++ b/src/spikeinterface/sortingcomponents/tests/common.py @@ -3,6 +3,7 @@ from spikeinterface.core import generate_ground_truth_recording + def make_dataset(): # this replace the MEArec 10s file for testing recording, sorting = generate_ground_truth_recording( @@ -22,3 +23,4 @@ def make_dataset(): seed=2205, ) return recording, sorting + diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index 36d2d34f4d..3519c66228 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -161,30 +161,28 @@ def test_estimate_motion(): ) kwargs.update(cases_kwargs) - motion, temporal_bins, spatial_bins, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs) + motion, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs) motions[name] = motion - assert temporal_bins.shape[0] == motion.shape[0] - assert spatial_bins.shape[0] == motion.shape[1] - if cases_kwargs["rigid"]: - assert motion.shape[1] == 1 + assert motion.displacement[0].shape[1] == 1 else: - assert motion.shape[1] > 1 + assert 
motion.displacement[0].shape[1] > 1 - # Test saving to disk - corrected_rec = InterpolateMotionRecording( - recording, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate" - ) - rec_folder = cache_folder / (name.replace("/", "").replace(" ", "_") + "_recording") - if rec_folder.exists(): - shutil.rmtree(rec_folder) - corrected_rec.save(folder=rec_folder) + # # Test saving to disk + # corrected_rec = InterpolateMotionRecording( + # recording, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate" + # ) + # rec_folder = cache_folder / (name.replace("/", "").replace(" ", "_") + "_recording") + # if rec_folder.exists(): + # shutil.rmtree(rec_folder) + # corrected_rec.save(folder=rec_folder) if DEBUG: fig, ax = plt.subplots() - ax.plot(temporal_bins, motion) + seg_index = 0 + ax.plot(motion.temporal_bins_s[0], motion.displacement[seg_index]) # motion_histogram = extra_check['motion_histogram'] # spatial_hist_bins = extra_check['spatial_hist_bin_edges'] @@ -205,7 +203,7 @@ def test_estimate_motion(): # same params with differents engine should be the same motion0, motion1 = motions["rigid / decentralized / torch"], motions["rigid / decentralized / numpy"] - assert (motion0 == motion1).all() + assert (motion0 == motion1) motion0, motion1 = ( motions["rigid / decentralized / torch / time_horizon_s"], diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py new file mode 100644 index 0000000000..dc826ce773 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -0,0 +1,38 @@ + + +# TODO Motion Make some test + +import pytest +import numpy as np + +from spikeinterface.sortingcomponents.motion_utils import Motion + + + + +def test_Motion(): + + temporal_bins_s = np.arange(0., 10., 1.) + spatial_bins_um = np.array([100., 200.]) + + displacement = np.zeros((temporal_bins_s.shape[0], spatial_bins_um.shape[0])) + displacement[:, :] = np.linspace(-20, 20, temporal_bins_s.shape[0])[:, np.newaxis] + + motion = Motion( + displacement, temporal_bins_s, spatial_bins_um, direction="y" + ) + + motion2 = Motion(**motion.to_dict()) + assert motion == motion2 + + displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, ], [120., 80., 150.]) + # print(displacement) + assert displacement.shape[0] == 3 + # check clip + assert displacement[2] == 20. 
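    # A symmetric lower-bound check could be added here (a hedged sketch, not
    # part of this patch, assuming query times keep being clipped to the
    # fitted temporal bounds as in the upper-bound assert above):
    displacement_low = motion.get_displacement_at_time_and_depth([-5.0], [120.0])
    assert displacement_low[0] == -20.0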
+
+
+
+if __name__ == "__main__":
+    test_Motion()
\ No newline at end of file
diff --git a/src/spikeinterface/sortingcomponents/tests/test_motiopn_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motiopn_utils.py
deleted file mode 100644
index 9efd26a3d5..0000000000
--- a/src/spikeinterface/sortingcomponents/tests/test_motiopn_utils.py
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-# TODO Motion Make some test
\ No newline at end of file

From e816f9383aad5b2350d9186d441bd344db1823d3 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Mon, 20 May 2024 15:12:37 +0200
Subject: [PATCH 010/136] WIP : refactor with Motion object

---
 .../sortingcomponents/motion_interpolation.py | 221 ++++++++----------
 .../sortingcomponents/motion_utils.py         |  31 +++
 .../tests/test_motion_interpolation.py        |  45 ++--
 .../tests/test_motion_utils.py                |   1 +
 4 files changed, 157 insertions(+), 141 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py
index 05e9073c1b..b209cb31bc 100644
--- a/src/spikeinterface/sortingcomponents/motion_interpolation.py
+++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py
@@ -11,23 +11,11 @@
 from spikeinterface.preprocessing import get_spatial_interpolation_kernel
 
 
-# try:
-#     import numba
-#     HAVE_NUMBA = True
-# except ImportError:
-#     HAVE_NUMBA = False
-
-
 def correct_motion_on_peaks(
     peaks,
     peak_locations,
     sampling_frequency,
-    # TODO use add Motion
     motion,
-    temporal_bins,
-    spatial_bins,
-    ###
-    direction="y",
 ):
     """
     Given the output of estimate_motion(), apply inverse motion on peak locations.
@@ -40,13 +28,8 @@ def correct_motion_on_peaks(
         peaks location vector
     sampling_frequency: np.array
         sampling_frequency of the recording
-    motion: np.array 2D
-        motion.shape[0] equal temporal_bins.shape[0]
-        motion.shape[1] equal 1 when "rigid" motion equal temporal_bins.shape[0] when "non-rigid"
-    temporal_bins: np.array
-        Temporal bins in second.
-    spatial_bins: np.array
-        Bins for non-rigid motion. If spatial_bins.sahpe[0] == 1 then rigid motion is used.
+    motion: Motion
+        The motion object.
    Returns
    -------
    corrected_peak_locations: np.array
        Motion-corrected peak locations
    """
     corrected_peak_locations = peak_locations.copy()
-    spike_times = peaks["sample_index"] / sampling_frequency
-    if spatial_bins.shape[0] == 1:
-        # rigid motion interpolation 1D
-        f = scipy.interpolate.interp1d(temporal_bins, motion[:, 0], bounds_error=False, fill_value="extrapolate")
-        shift = f(spike_times)
-        corrected_peak_locations[direction] -= shift
-    else:
-        # non rigid motion = interpolation 2D
-        f = scipy.interpolate.RegularGridInterpolator(
-            (temporal_bins, spatial_bins), motion, method="linear", bounds_error=False, fill_value=None
-        )
-        shift = f(np.c_[spike_times, peak_locations[direction]])
-        corrected_peak_locations[direction] -= shift
+    for segment_index in range(motion.num_segments):
+        i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1])
 
-    return corrected_peak_locations
+        # TODO delegate times to recording object
+        spike_times = peaks["sample_index"][i0:i1] / sampling_frequency
+        spike_locs = peak_locations[motion.direction][i0:i1]
 
+        spike_displacement = motion.get_displacement_at_time_and_depth(spike_times, spike_locs, segment_index=segment_index)
+
+        corrected_peak_locations[i0:i1][motion.direction] -= spike_displacement
+
+    return corrected_peak_locations
+
 
 def interpolate_motion_on_traces(
     traces,
     times,
     channel_locations,
-    # TODO : add Motion object here
     motion,
-    temporal_bins,
-    spatial_bins,
-    ###
-    direction=1,
+    segment_index=None,
     channel_inds=None,
     spatial_interpolation_method="kriging",
     spatial_interpolation_kwargs={},
@@ -97,16 +73,10 @@ def interpolate_motion_on_traces(
         Trace snippet (num_samples, num_channels)
     channel_location: np.array 2d
         Channel location with shape (n, 2) or (n, 3)
-    motion: np.array 2D
-        motion.shape[0] equal temporal_bins.shape[0]
-        motion.shape[1] equal 1 when "rigid" motion
-        equal temporal_bins.shape[0] when "none rigid"
-    temporal_bins: np.array
-        Temporal bins in second.
-    spatial_bins: None or np.array
-        Bins for non-rigid motion. If None, rigid motion is used
-    direction: int in (0, 1, 2)
-        Dimension of shift in channel_locations.
+    motion: Motion
+        The motion object.
+    segment_index: int or None
+        The segment index.
     channel_inds: None or list
         If not None, interpolate only a subset of channels.
spatial_interpolation_method: "idw" | "kriging", default: "kriging" @@ -125,33 +95,56 @@ def interpolate_motion_on_traces( # assert HAVE_NUMBA assert times.shape[0] == traces.shape[0] + if segment_index is None: + if motion.num_segments == 1: + segment_index = 0 + else: + raise ValueError("Several segment need segment_index=") + if channel_inds is None: traces_corrected = np.zeros(traces.shape, dtype=traces.dtype) else: channel_inds = np.asarray(channel_inds) traces_corrected = np.zeros((traces.shape[0], channel_inds.size), dtype=traces.dtype) + + total_num_chans = channel_locations.shape[0] + + # TODO give optional possibility to have smaler times bins than the motion with interpolation + # this would remove the need of _get_closest_ind and searchsorted + # TODO delegate times to recording, at the moment this is 0 based # regroup times by closet temporal_bins - bin_inds = _get_closest_ind(temporal_bins, times) + bin_inds = _get_closest_ind(motion.temporal_bins_s[segment_index], times) # inperpolation kernel will be the same per temporal bin for bin_ind in np.unique(bin_inds): - # TODO use # TODO : add Motion.get_displacement_at_time_and_depth() instead - # Step 1 : channel motion - if spatial_bins.shape[0] == 1: - # rigid motion : same motion for all channels - channel_motions = motion[bin_ind, 0] - else: - # non rigid : interpolation channel motion for this temporal bin - f = scipy.interpolate.interp1d( - spatial_bins, motion[bin_ind, :], kind="linear", axis=0, bounds_error=False, fill_value="extrapolate" - ) - locs = channel_locations[:, direction] - channel_motions = f(locs) + bin_time = motion.temporal_bins_s[segment_index][bin_ind] + + channel_motions = motion.get_displacement_at_time_and_depth( + np.full(total_num_chans, bin_time), + channel_locations[motion.dim], + segment_index=segment_index + ) channel_locations_moved = channel_locations.copy() - channel_locations_moved[:, direction] += channel_motions - # channel_locations_moved[:, direction] -= channel_motions + channel_locations_moved[:, motion.dim] += channel_motions + + # # TODO use # TODO : add Motion.get_displacement_at_time_and_depth() instead + + # # Step 1 : channel motion + # if spatial_bins.shape[0] == 1: + # # rigid motion : same motion for all channels + # channel_motions = motion[bin_ind, 0] + # else: + # # non rigid : interpolation channel motion for this temporal bin + # f = scipy.interpolate.interp1d( + # spatial_bins, motion[bin_ind, :], kind="linear", axis=0, bounds_error=False, fill_value="extrapolate" + # ) + # locs = channel_locations[:, direction] + # channel_motions = f(locs) + # channel_locations_moved = channel_locations.copy() + # channel_locations_moved[:, direction] += channel_motions + # # channel_locations_moved[:, direction] -= channel_motions if channel_inds is not None: channel_locations_moved = channel_locations_moved[channel_inds] @@ -164,8 +157,16 @@ def interpolate_motion_on_traces( **spatial_interpolation_kwargs, ) - i0 = np.searchsorted(bin_inds, bin_ind, side="left") - i1 = np.searchsorted(bin_inds, bin_ind, side="right") + # keep this for DEBUG + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # ax.matshow(drift_kernel) + # ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}") + # plt.show() + + # i0 = np.searchsorted(bin_inds, bin_ind, side="left") + # i1 = np.searchsorted(bin_inds, bin_ind, side="right") + i0, i1 = np.searchsorted(bin_inds, [bin_ind, bin_ind + 1], side="left") # here we use a simple np.matmul even if dirft_kernel can be 
super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing @@ -226,16 +227,8 @@ class InterpolateMotionRecording(BasePreprocessor): ---------- recording: Recording The parent recording. - motion: np.array 2D - The motion signal obtained with `estimate_motion()` - motion.shape[0] must correspond to temporal_bins.shape[0] - motion.shape[1] is 1 when "rigid" motion and spatial_bins.shape[0] when "non-rigid" - temporal_bins: np.array - Temporal bins in second. - spatial_bins: None or np.array - Bins for non-rigid motion. If None, rigid motion is used - direction: 0 | 1 | 2, default: 1 - Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z) + motion: Motion + The motion object spatial_interpolation_method: "kriging" | "idw" | "nearest", default: "kriging" The spatial interpolation method used to interpolate the channel locations. See `spikeinterface.preprocessing.get_spatial_interpolation_kernel()` for more details. @@ -269,49 +262,55 @@ def __init__( self, recording, motion, - temporal_bins, - spatial_bins, - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=1, num_closest=3, ): - assert recording.get_num_segments() == 1, "correct_motion() is only available for single-segment recordings" + # assert recording.get_num_segments() == 1, "correct_motion() is only available for single-segment recordings" - # force as arrays - temporal_bins = np.asarray(temporal_bins) - motion = np.asarray(motion) - spatial_bins = np.asarray(spatial_bins) + # # force as arrays + # temporal_bins = np.asarray(temporal_bins) + # motion = np.asarray(motion) + # spatial_bins = np.asarray(spatial_bins) channel_locations = recording.get_channel_locations() - assert channel_locations.ndim >= direction, ( - f"'direction' {direction} not available. " f"Channel locations have {channel_locations.ndim} dimensions." + assert channel_locations.ndim >= motion.dim, ( + f"'direction' {motion.direction} not available. " f"Channel locations have {channel_locations.ndim} dimensions." 
) spatial_interpolation_kwargs = dict(sigma_um=sigma_um, p=p, num_closest=num_closest) if border_mode == "remove_channels": - locs = channel_locations[:, direction] - l0, l1 = np.min(channel_locations[:, direction]), np.max(channel_locations[:, direction]) + locs = channel_locations[:, motion.dim] + l0, l1 = np.min(locs), np.max(locs) # compute max and min motion (with interpolation) - # and check if channels are inside + # and check if channels are inside for all segment channel_inside = np.ones(locs.shape[0], dtype="bool") - for operator in (np.max, np.min): - if spatial_bins.shape[0] == 1: - best_motions = operator(motion[:, 0]) - else: - # non rigid : interpolation channel motion for this temporal bin - f = scipy.interpolate.interp1d( - spatial_bins, - operator(motion[:, :], axis=0), - kind="linear", - axis=0, - bounds_error=False, - fill_value="extrapolate", + for operator, arg_operator in ((np.max,np.argmax), (np.min, np.argmin)): + for segment_index in range(recording.get_num_segments()): + ind = arg_operator(operator(motion.displacement[segment_index], axis=1)) + bin_time = motion.temporal_bins_s[segment_index][ind] + best_motions = motion.get_displacement_at_time_and_depth( + np.full(locs.shape[0], bin_time), locs, segment_index=segment_index ) - best_motions = f(locs) - channel_inside &= ((locs + best_motions) >= l0) & ((locs + best_motions) <= l1) + channel_inside &= ((locs + best_motions) >= l0) & ((locs + best_motions) <= l1) + + + # if spatial_bins.shape[0] == 1: + # best_motions = operator(motion[:, 0]) + # else: + # # non rigid : interpolation channel motion for this temporal bin + # f = scipy.interpolate.interp1d( + # spatial_bins, + # operator(motion[:, :], axis=0), + # kind="linear", + # axis=0, + # bounds_error=False, + # fill_value="extrapolate", + # ) + # best_motions = f(locs) + # channel_inside &= ((locs + best_motions) >= l0) & ((locs + best_motions) <= l1) (channel_inds,) = np.nonzero(channel_inside) channel_ids = recording.channel_ids[channel_inds] @@ -342,9 +341,6 @@ def __init__( parent_segment, channel_locations, motion, - temporal_bins, - spatial_bins, - direction, spatial_interpolation_method, spatial_interpolation_kwargs, channel_inds, @@ -354,9 +350,6 @@ def __init__( self._kwargs = dict( recording=recording, motion=motion, - temporal_bins=temporal_bins, - spatial_bins=spatial_bins, - direction=direction, border_mode=border_mode, spatial_interpolation_method=spatial_interpolation_method, sigma_um=sigma_um, @@ -370,13 +363,7 @@ def __init__( self, parent_recording_segment, channel_locations, - # TODO : add Motion object here motion, - temporal_bins, - spatial_bins, - ### - - direction, spatial_interpolation_method, spatial_interpolation_kwargs, channel_inds, @@ -384,9 +371,6 @@ def __init__( BasePreprocessorSegment.__init__(self, parent_recording_segment) self.channel_locations = channel_locations self.motion = motion - self.temporal_bins = temporal_bins - self.spatial_bins = spatial_bins - self.direction = direction self.spatial_interpolation_method = spatial_interpolation_method self.spatial_interpolation_kwargs = spatial_interpolation_kwargs self.channel_inds = channel_inds @@ -417,9 +401,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): times, self.channel_locations, self.motion, - self.temporal_bins, - self.spatial_bins, - direction=self.direction, channel_inds=self.channel_inds, spatial_interpolation_method=self.spatial_interpolation_method, spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, diff --git 
a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index 448aefdb9e..6bf10372d5 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -2,6 +2,25 @@ + +# @charlie @sam +# here TODO list for motion object +# * simple test for Motion: DONE +# * save/load Motion +# * make better test for Motion object with save/load +# * propagate to estimate_motion : DONE +# * handle multi segment in estimate_motion(): maybe in another PR +# * propagate to motion_interpolation.py: +# * propagate to preprocessing/correct_motion() +# * generate drifting signals for test estimate_motion and interpolate_motion +# * uncomment assert in test_estimate_motion (aka debug torch vs numpy diff) +# * delegate times to recording object in +# * estimate motion +# * correct_motion_on_peaks() +# * interpolate_motion_on_traces() + + + class Motion: """ Motion of the tissue relative the probe. @@ -42,6 +61,18 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" self.direction = direction self.dim = ["x", "y", "z"].index(direction) + + def __repr__(self): + nbins = self.spatial_bins_um.shape[0] + if nbins == 1: + rigid_txt = "rigid" + else: + rigid_txt = f"non-rigid - {nbins} spatial bins" + + interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] + txt = f"Motion {rigid_txt} - interval {interval_s}s -{self.num_segments} segments" + return txt + def make_interpolators(self): from scipy.interpolate import RegularGridInterpolator diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index cc3434b782..47f61f9ad6 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -4,6 +4,7 @@ from spikeinterface import download_dataset +from spikeinterface.sortingcomponents.motion_utils import Motion from spikeinterface.sortingcomponents.motion_interpolation import ( correct_motion_on_peaks, interpolate_motion_on_traces, @@ -20,21 +21,25 @@ def make_fake_motion(rec): - # make a fake motion vector + # make a fake motion object duration = rec.get_total_duration() locs = rec.get_channel_locations() temporal_bins = np.arange(0.5, duration - 0.49, 0.5) spatial_bins = np.arange(locs[:, 1].min(), locs[:, 1].max(), 100) - motion = np.zeros((temporal_bins.size, spatial_bins.size)) - motion[:, :] = np.linspace(-30, 30, temporal_bins.size)[:, None] + displacament = np.zeros((temporal_bins.size, spatial_bins.size)) + displacament[:, :] = np.linspace(-30, 30, temporal_bins.size)[:, None] - return motion, temporal_bins, spatial_bins + motion = Motion([displacament], [temporal_bins], spatial_bins, direction="y") + + return motion def test_correct_motion_on_peaks(): rec, sorting = make_dataset() peaks = sorting.to_spike_vector() - motion, temporal_bins, spatial_bins = make_fake_motion(rec) + print(peaks.dtype) + motion = make_fake_motion(rec) + # print(motion) # fake locations peak_locations = np.zeros((peaks.size), dtype=[("x", "float32"), ("y", "float")]) @@ -44,24 +49,24 @@ def test_correct_motion_on_peaks(): peak_locations, rec.sampling_frequency, motion, - temporal_bins, - spatial_bins, - direction="y", ) # print(corrected_peak_locations) assert np.any(corrected_peak_locations["y"] != 0) # import matplotlib.pyplot as plt # fig, ax = plt.subplots() - # 
ax.plot(times[peaks['sample_index']], corrected_peak_locations['y']) - # ax.plot(temporal_bins, motion[:, 1]) + # segment_index = 0 + # times = rec.get_times(segment_index=segment_index) + # ax.scatter(times[peaks['sample_index']], corrected_peak_locations['y']) + # ax.plot(motion.temporal_bins_s[segment_index], motion.displacement[segment_index][:, 1]) # plt.show() + def test_interpolate_motion_on_traces(): rec, sorting = make_dataset() - motion, temporal_bins, spatial_bins = make_fake_motion(rec) + motion = make_fake_motion(rec) channel_locations = rec.get_channel_locations() @@ -74,12 +79,10 @@ def test_interpolate_motion_on_traces(): times, channel_locations, motion, - temporal_bins, - spatial_bins, - direction=1, channel_inds=None, spatial_interpolation_method=method, - spatial_interpolation_kwargs={}, + # spatial_interpolation_kwargs={}, + spatial_interpolation_kwargs={"force_extrapolate": True}, ) assert traces.shape == traces_corrected.shape assert traces.dtype == traces_corrected.dtype @@ -87,15 +90,15 @@ def test_interpolate_motion_on_traces(): def test_InterpolateMotionRecording(): rec, sorting = make_dataset() - motion, temporal_bins, spatial_bins = make_fake_motion(rec) + motion = make_fake_motion(rec) - rec2 = InterpolateMotionRecording(rec, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate") + rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_extrapolate") assert rec2.channel_ids.size == 32 - rec2 = InterpolateMotionRecording(rec, motion, temporal_bins, spatial_bins, border_mode="force_zeros") + rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_zeros") assert rec2.channel_ids.size == 32 - rec2 = InterpolateMotionRecording(rec, motion, temporal_bins, spatial_bins, border_mode="remove_channels") + rec2 = InterpolateMotionRecording(rec, motion, border_mode="remove_channels") assert rec2.channel_ids.size == 24 for ch_id in (0, 1, 14, 15, 16, 17, 30, 31): assert ch_id not in rec2.channel_ids @@ -116,6 +119,6 @@ def test_InterpolateMotionRecording(): if __name__ == "__main__": - test_correct_motion_on_peaks() - test_interpolate_motion_on_traces() + # test_correct_motion_on_peaks() + # test_interpolate_motion_on_traces() test_InterpolateMotionRecording() diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py index dc826ce773..289e3bfe57 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -21,6 +21,7 @@ def test_Motion(): motion = Motion( displacement, temporal_bins_s, spatial_bins_um, direction="y" ) + print(motion) motion2 = Motion(**motion.to_dict()) assert motion == motion2 From f53b824b389d3c3ea03a6253acb4c779a6a4a522 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Sun, 26 May 2024 22:03:04 +0200 Subject: [PATCH 011/136] Propagate Motion object to preprocessing. 
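A minimal sketch of the round trip this commit enables (the synthetic recording and the folder name are illustrative assumptions, mirroring sortingcomponents/tests/common.py):

    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.preprocessing import correct_motion, load_motion_info

    rec, _ = generate_ground_truth_recording(durations=[30.0], num_channels=32, seed=2205)
    rec_corrected, motion_info = correct_motion(rec, folder="motion_folder", output_motion_info=True)

    # the Motion object is persisted under motion_folder/motion and reloaded here
    motion_info2 = load_motion_info("motion_folder")
    assert motion_info2["motion"] == motion_info["motion"]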
--- src/spikeinterface/preprocessing/motion.py | 28 ++++----- .../preprocessing/tests/test_motion.py | 5 +- .../sortingcomponents/motion_utils.py | 63 ++++++++++++++++--- .../tests/test_motion_utils.py | 29 +++++++-- 4 files changed, 92 insertions(+), 33 deletions(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 3956cfc17d..a5300ccadc 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -69,7 +69,7 @@ weight_with_amplitude=False, ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, "nonrigid_fast_and_accurate": { @@ -128,7 +128,7 @@ weight_with_amplitude=False, ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, # This preset is a super fast rigid estimation with center of mass @@ -153,7 +153,7 @@ rigid=True, ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, # This preset try to mimic kilosort2.5 motion estimator @@ -187,7 +187,7 @@ win_shape="rect", ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="force_extrapolate", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="force_extrapolate", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, # empty preset @@ -380,22 +380,17 @@ def correct_motion( np.save(folder / "peak_locations.npy", peak_locations) t0 = time.perf_counter() - motion, temporal_bins, spatial_bins = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs) + motion = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs) t1 = time.perf_counter() run_times["estimate_motion"] = t1 - t0 recording_corrected = InterpolateMotionRecording( - recording, motion, temporal_bins, spatial_bins, **interpolate_motion_kwargs + recording, motion, **interpolate_motion_kwargs ) if folder is not None: (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - - # TODO save Motion - np.save(folder / "temporal_bins.npy", temporal_bins) - np.save(folder / "motion.npy", motion) - if spatial_bins is not None: - np.save(folder / "spatial_bins.npy", spatial_bins) + motion.save(folder / "motion") if output_motion_info: motion_info = dict( @@ -403,9 +398,6 @@ def correct_motion( run_times=run_times, peaks=peaks, peak_locations=peak_locations, - # TODO use Motion - temporal_bins=temporal_bins, - spatial_bins=spatial_bins, motion=motion, ) return recording_corrected, motion_info @@ -424,6 +416,8 @@ def correct_motion( def load_motion_info(folder): + from spikeinterface.sortingcomponents.motion_utils import Motion + folder = Path(folder) motion_info = {} @@ -434,11 +428,13 @@ def load_motion_info(folder): with open(folder / "run_times.json") as f: motion_info["run_times"] = json.load(f) - array_names = ("peaks", "peak_locations", "temporal_bins", "spatial_bins", "motion") + array_names = ("peaks", "peak_locations") for name in array_names: if (folder / f"{name}.npy").exists(): motion_info[name] = np.load(folder / 
f"{name}.npy") else: motion_info[name] = None + + motion_info["motion"] = Motion.load(folder / "motion") return motion_info diff --git a/src/spikeinterface/preprocessing/tests/test_motion.py b/src/spikeinterface/preprocessing/tests/test_motion.py index 7cea531bb4..c2b8d0024e 100644 --- a/src/spikeinterface/preprocessing/tests/test_motion.py +++ b/src/spikeinterface/preprocessing/tests/test_motion.py @@ -25,6 +25,7 @@ def test_estimate_and_correct_motion(): folder = cache_folder / "estimate_and_correct_motion" if folder.exists(): shutil.rmtree(folder) + rec_corrected = correct_motion(rec, folder=folder) print(rec_corrected) @@ -33,5 +34,5 @@ def test_estimate_and_correct_motion(): if __name__ == "__main__": - print(correct_motion.__doc__) - # test_estimate_and_correct_motion() + # print(correct_motion.__doc__) + test_estimate_and_correct_motion() diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index 6bf10372d5..ddb6c2d8ae 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -1,23 +1,30 @@ +import json +from pathlib import Path import numpy as np +import spikeinterface +from spikeinterface.core.core_tools import check_json + # @charlie @sam # here TODO list for motion object # * simple test for Motion: DONE -# * save/load Motion -# * make better test for Motion object with save/load +# * save/load Motion DONE +# * make simple test for Motion object with save/load DONE # * propagate to estimate_motion : DONE # * handle multi segment in estimate_motion(): maybe in another PR -# * propagate to motion_interpolation.py: -# * propagate to preprocessing/correct_motion() +# * propagate to motion_interpolation.py: ALMOST DONE +# * propagate to preprocessing/correct_motion(): # * generate drifting signals for test estimate_motion and interpolate_motion # * uncomment assert in test_estimate_motion (aka debug torch vs numpy diff) # * delegate times to recording object in # * estimate motion # * correct_motion_on_peaks() # * interpolate_motion_on_traces() +# update plot_motion() dans widget +# @@ -126,14 +133,50 @@ def to_dict(self): spatial_bins_um=self.spatial_bins_um, ) - def save(self): - # TODO - pass + def save(self, folder): + folder = Path(folder) + + folder.mkdir(exist_ok=False, parents=True) + + info_file = folder / f"spikeinterface_info.json" + info = dict( + version=spikeinterface.__version__, + dev_mode=spikeinterface.DEV_MODE, + object="Motion", + num_segments=self.num_segments, + direction=self.direction, + ) + with open(info_file, mode="w") as f: + json.dump(check_json(info), f, indent=4) + + np.save(folder / "spatial_bins_um.npy", self.spatial_bins_um) + + for segment_index in range(self.num_segments): + np.save(folder / f"displacement_seg{segment_index}.npy", self.displacement[segment_index]) + np.save(folder / f"temporal_bins_s_seg{segment_index}.npy", self.temporal_bins_s[segment_index]) @classmethod - def load(cls): - # TODO - pass + def load(cls, folder): + folder = Path(folder) + + info_file = folder / f"spikeinterface_info.json" + if not info_file.exists(): + raise IOError("Motion.load(folder) : the folder do not contain Motion") + + with open(info_file, "r") as f: + info = json.load(f) + if info["object"] != "Motion": + raise IOError("Motion.load(folder) : the folder do not contain Motion") + + direction = info["direction"] + spatial_bins_um = np.load(folder / "spatial_bins_um.npy") + displacement = [] + temporal_bins_s = [] + for 
segment_index in range(info["num_segments"]): + displacement.append(np.load(folder / f"displacement_seg{segment_index}.npy")) + temporal_bins_s.append(np.load(folder / f"temporal_bins_s_seg{segment_index}.npy")) + + return cls(displacement, temporal_bins_s, spatial_bins_um, direction=direction) def __eq__(self, other): diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py index 289e3bfe57..289a8a12cb 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -1,13 +1,15 @@ - - -# TODO Motion Make some test - import pytest import numpy as np +import pickle +from pathlib import Path +import shutil from spikeinterface.sortingcomponents.motion_utils import Motion - +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "sortingcomponents" +else: + cache_folder = Path("cache_folder") / "sortingcomponents" def test_Motion(): @@ -23,9 +25,18 @@ def test_Motion(): ) print(motion) + # serialize with pickle before interpolation fit + motion2 = pickle.loads(pickle.dumps(motion)) + assert motion2.interpolator == None + # serialize with pickle after interpolation fit + motion.make_interpolators() + motion2 = pickle.loads(pickle.dumps(motion)) + + # to/from dict motion2 = Motion(**motion.to_dict()) assert motion == motion2 + # do interpolate displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, ], [120., 80., 150.]) # print(displacement) assert displacement.shape[0] == 3 @@ -33,6 +44,14 @@ def test_Motion(): assert displacement[2] == 20. + # save/load to folder + folder = cache_folder / "motion_saved" + if folder.exists(): + shutil.rmtree(folder) + motion.save(folder) + motion2 = Motion.load(folder) + assert motion == motion2 + if __name__ == "__main__": From fd80fd6596faedd4f11c37e8418e29ac01343466 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 28 May 2024 08:20:32 +0200 Subject: [PATCH 012/136] updata todo --- src/spikeinterface/sortingcomponents/motion_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index ddb6c2d8ae..6a033b6818 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -16,15 +16,15 @@ # * propagate to estimate_motion : DONE # * handle multi segment in estimate_motion(): maybe in another PR # * propagate to motion_interpolation.py: ALMOST DONE -# * propagate to preprocessing/correct_motion(): +# * propagate to preprocessing/correct_motion(): ALMOST DONE # * generate drifting signals for test estimate_motion and interpolate_motion # * uncomment assert in test_estimate_motion (aka debug torch vs numpy diff) # * delegate times to recording object in # * estimate motion # * correct_motion_on_peaks() # * interpolate_motion_on_traces() +# propagate to benchmark estimate motion # update plot_motion() dans widget -# From 597a094ba99293ba24594cb5871d00caed0b14f8 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 29 May 2024 16:06:30 +0100 Subject: [PATCH 013/136] Add grid option to Motion with a test --- .../sortingcomponents/motion_utils.py | 134 ++++++++++++------ .../tests/test_motion_utils.py | 20 ++- 2 files changed, 103 insertions(+), 51 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py 
b/src/spikeinterface/sortingcomponents/motion_utils.py index 6a033b6818..71cde08689 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -1,13 +1,10 @@ import json from pathlib import Path -import numpy as np - +import numpy as np import spikeinterface from spikeinterface.core.core_tools import check_json - - # @charlie @sam # here TODO list for motion object # * simple test for Motion: DONE @@ -27,18 +24,16 @@ # update plot_motion() dans widget - class Motion: """ Motion of the tissue relative the probe. Parameters ---------- - displacement: numpy array 2d or list of Motion estimate in um. List is the number of segment. - For each semgent : + For each semgent : * shape (temporal bins, spatial bins) * motion.shape[0] = temporal_bins.shape[0] * motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) @@ -48,9 +43,12 @@ class Motion: Windows center. spatial_bins_um.shape[0] == displacement.shape[1] If rigid then spatial_bins_um.shape[0] == 1 - + interpolation_method : str + How to determine the displacement between bin centers? See the docs + for scipy.interpolate.RegularGridInterpolator for options. """ - def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y"): + + def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y", interpolation_method="linear"): if isinstance(displacement, np.ndarray): self.displacement = [displacement] assert isinstance(temporal_bins_s, np.ndarray) @@ -64,47 +62,68 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" self.spatial_bins_um = spatial_bins_um self.num_segments = len(self.displacement) - self.interpolator = None - + self.interpolators = None + self.interpolation_method = interpolation_method + self.direction = direction self.dim = ["x", "y", "z"].index(direction) - + self.check_properties() + + def check_properties(self): + assert all(d.ndim == 2 for d in self.displacement) + assert all(t.ndim == 1 for t in self.temporal_bins_s) + assert all(self.spatial_bins_um.shape == (d.shape[1],) for d in self.displacement) + def __repr__(self): nbins = self.spatial_bins_um.shape[0] if nbins == 1: rigid_txt = "rigid" else: rigid_txt = f"non-rigid - {nbins} spatial bins" - + interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] - txt = f"Motion {rigid_txt} - interval {interval_s}s -{self.num_segments} segments" + txt = f"Motion {rigid_txt} - interval {interval_s}s - {self.num_segments} segments" return txt - def make_interpolators(self): from scipy.interpolate import RegularGridInterpolator - self.interpolator = [ - RegularGridInterpolator((self.temporal_bins_s[j], self.spatial_bins_um), self.displacement[j]) + + self.interpolators = [ + RegularGridInterpolator( + (self.temporal_bins_s[j], self.spatial_bins_um), self.displacement[j], method=self.interpolation_method + ) for j in range(self.num_segments) ] self.temporal_bounds = [(t[0], t[-1]) for t in self.temporal_bins_s] self.spatial_bounds = (self.spatial_bins_um.min(), self.spatial_bins_um.max()) - - def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_index=None): - """ + def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_index=None, grid=False): + """Evaluate the motion estimate at times and positions + + Evaluate the motion estimate, returning the (linearly interpolated) estimated displacement + at the given times and locations. 
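+
+        Out-of-range queries are clamped: times and depths outside the fitted bins
+        are clipped to the nearest bin center, so extrapolation is flat rather than linear.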
         Parameters
         ----------
         times_s: np.array
-
-        locations_um: np.array
-
-        segment_index:
-
+        locations_um: np.array
+            Either this is a one-dimensional array (a vector of positions along self.dimension), or
+            else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1.
+        segment_index: int, optional
+        grid : bool
+            If grid=False, the default, then times_s and locations_um should have the same one-dimensional
+            shape, and the returned displacement[i] is the displacement at time times_s[i] and location
+            locations_um[i].
+            If grid=True, times_s and locations_um determine a grid of positions to evaluate the displacement.
+            Then the returned displacement[i, j] is the displacement at depth locations_um[i] and time times_s[j].
+
+        Returns
+        -------
+        displacement : np.array
+            A displacement per input location, of shape times_s.shape if grid=False and (locations_um.size, times_s.size)
+            if grid=True.
         """
-        if self.interpolator is None:
+        if self.interpolators is None:
             self.make_interpolators()

         if segment_index is None:
@@ -112,30 +131,49 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde
                 segment_index = 0
             else:
                 raise ValueError("Multiple segments: an explicit segment_index is required")
-
+
         times_s = np.asarray(times_s)
-        locations_um = np.asarray(times_s)
+        locations_um = np.asarray(locations_um)

         if locations_um.ndim == 1:
             locations_um = locations_um
-        else:
+        elif locations_um.ndim == 2:
             locations_um = locations_um[:, self.dim]
-        times_s = np.clip(times_s, *self.temporal_bounds[segment_index])
-        locations_um = np.clip(locations_um, *self.spatial_bounds)
-        points = np.stack([times_s, locations_um,], axis=1)
+        else:
+            raise ValueError("locations_um must be a 1d or 2d array")
+
+        times_s = times_s.clip(*self.temporal_bounds[segment_index])
+        locations_um = locations_um.clip(*self.spatial_bounds)
+
+        if grid:
+            # construct a grid over which to evaluate the displacement
+            locations_um, times_s = np.meshgrid(locations_um, times_s, indexing="ij")
+            out_shape = times_s.shape
+            locations_um = locations_um.ravel()
+            times_s = times_s.ravel()
+        else:
+            # usual case: input is a point cloud
+            assert locations_um.shape == times_s.shape
+            assert times_s.ndim == 1
+            out_shape = times_s.shape
+
+        points = np.column_stack((times_s, locations_um))
+        displacement = self.interpolators[segment_index](points)
+        # reshape to grid domain shape if necessary
+        displacement = displacement.reshape(out_shape)

-        return self.interpolator[segment_index](points)
+        return displacement

     def to_dict(self):
         return dict(
             displacement=self.displacement,
             temporal_bins_s=self.temporal_bins_s,
             spatial_bins_um=self.spatial_bins_um,
+            interpolation_method=self.interpolation_method,
         )

     def save(self, folder):
         folder = Path(folder)
-
         folder.mkdir(exist_ok=False, parents=True)

         info_file = folder / f"spikeinterface_info.json"
@@ -145,6 +183,7 @@ def save(self, folder):
             object="Motion",
             num_segments=self.num_segments,
             direction=self.direction,
+            interpolation_method=self.interpolation_method,
         )
         with open(info_file, mode="w") as f:
             json.dump(check_json(info), f, indent=4)
@@ -160,33 +199,40 @@ def load(cls, folder):
         folder = Path(folder)

         info_file = folder / f"spikeinterface_info.json"
+        err_msg = f"Motion.load(folder): the folder {folder} does not contain a Motion object."
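+        # expected on-disk layout, as written by save(): spikeinterface_info.json,
+        # spatial_bins_um.npy, plus displacement_seg{i}.npy and temporal_bins_s_seg{i}.npy
+        # per segment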
if not info_file.exists(): - raise IOError("Motion.load(folder) : the folder do not contain Motion") - + raise IOError(err_msg) + with open(info_file, "r") as f: info = json.load(f) - if info["object"] != "Motion": - raise IOError("Motion.load(folder) : the folder do not contain Motion") + if "object" not in info or info["object"] != "Motion": + raise IOError(err_msg) direction = info["direction"] + interpolation_method = info["interpolation_method"] spatial_bins_um = np.load(folder / "spatial_bins_um.npy") displacement = [] temporal_bins_s = [] for segment_index in range(info["num_segments"]): displacement.append(np.load(folder / f"displacement_seg{segment_index}.npy")) temporal_bins_s.append(np.load(folder / f"temporal_bins_s_seg{segment_index}.npy")) - - return cls(displacement, temporal_bins_s, spatial_bins_um, direction=direction) - def __eq__(self, other): + return cls( + displacement, + temporal_bins_s, + spatial_bins_um, + direction=direction, + interpolation_method=interpolation_method, + ) + def __eq__(self, other): for segment_index in range(self.num_segments): if not np.allclose(self.displacement[segment_index], other.displacement[segment_index]): return False if not np.allclose(self.temporal_bins_s[segment_index], other.temporal_bins_s[segment_index]): return False - + if not np.allclose(self.spatial_bins_um, other.spatial_bins_um): return False - + return True diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py index 289a8a12cb..a170245d7d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -1,9 +1,9 @@ -import pytest -import numpy as np import pickle -from pathlib import Path import shutil +from pathlib import Path +import numpy as np +import pytest from spikeinterface.sortingcomponents.motion_utils import Motion if hasattr(pytest, "global_test_folder"): @@ -27,22 +27,29 @@ def test_Motion(): # serialize with pickle before interpolation fit motion2 = pickle.loads(pickle.dumps(motion)) - assert motion2.interpolator == None + assert motion2.interpolator is None # serialize with pickle after interpolation fit motion.make_interpolators() + assert motion2.interpolator is not None motion2 = pickle.loads(pickle.dumps(motion)) + assert motion2.interpolator is not None # to/from dict motion2 = Motion(**motion.to_dict()) assert motion == motion2 + assert motion2.interpolator is None # do interpolate - displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, ], [120., 80., 150.]) + displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11], [120., 80., 150.]) # print(displacement) assert displacement.shape[0] == 3 # check clip assert displacement[2] == 20. + # interpolate grid + displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, 15, 19], [150., 80.], grid=True) + assert displacement.shape == (2, 5) + assert displacement[0, 2] == 20. 
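+    # with grid=True the displacement is evaluated on the outer product of the two
+    # axes: entry [i, j] is the displacement at depth locations_um[i] and time
+    # times_s[j], hence the (2, 5) shape checked above
+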
# save/load to folder folder = cache_folder / "motion_saved" @@ -53,6 +60,5 @@ def test_Motion(): assert motion == motion2 - if __name__ == "__main__": - test_Motion() \ No newline at end of file + test_Motion() From e6678cbe99333254ae614560241908fd4d11745c Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 29 May 2024 16:41:19 +0100 Subject: [PATCH 014/136] Dtype handling in interpolation; Flexible time bins; fast time bins logic --- .../sortingcomponents/motion_interpolation.py | 232 ++++++++++-------- 1 file changed, 128 insertions(+), 104 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index b209cb31bc..4a5a2b0c47 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -1,14 +1,12 @@ from __future__ import annotations import numpy as np -import scipy.interpolate -from tqdm import tqdm - -import scipy.spatial - from spikeinterface.core.core_tools import define_function_from_class -from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.preprocessing import get_spatial_interpolation_kernel +from spikeinterface.preprocessing.basepreprocessor import ( + BasePreprocessor, BasePreprocessorSegment) + +from .filter import fix_dtype def correct_motion_on_peaks( @@ -39,7 +37,7 @@ def correct_motion_on_peaks( corrected_peak_locations = peak_locations.copy() for segment_index in range(motion.num_segments): - i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) + i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) # TODO delegate times to recording object spike_times = peaks["sample_index"][i0:i1] / sampling_frequency @@ -50,7 +48,7 @@ def correct_motion_on_peaks( corrected_peak_locations[i0:i1][motion.direction] -= spike_displacement return corrected_peak_locations - + def interpolate_motion_on_traces( traces, @@ -59,6 +57,7 @@ def interpolate_motion_on_traces( motion, segment_index=None, channel_inds=None, + interpolation_time_bin_centers_s=None, spatial_interpolation_method="kriging", spatial_interpolation_kwargs={}, ): @@ -71,6 +70,8 @@ def interpolate_motion_on_traces( ---------- traces : np.array Trace snippet (num_samples, num_channels) + times : np.array + Sample times in seconds for the frames of the traces snippet channel_location: np.array 2d Channel location with shape (n, 2) or (n, 3) motion: Motion @@ -79,6 +80,9 @@ def interpolate_motion_on_traces( The segment index. channel_inds: None or list If not None, interpolate only a subset of channels. + interpolation_time_bin_centers_s : None or np.array + Manually specify the time bins which the interpolation happens + in for this segment. If None, these are the motion estimate's time bins. 
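+        Bins finer than the motion estimate's are allowed: each visited bin costs one
+        extra spatial kernel, in exchange for a higher temporal resolution correction.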
     spatial_interpolation_method: "idw" | "kriging", default: "kriging"
         The spatial interpolation method used to interpolate the channel locations:
             * idw : Inverse Distance Weighing
             * kriging : kilosort2.5 like

     Returns
     -------
-    channel_motions: np.array
-        Shift over time by channel
-        Shape (times.shape[0], channel_location.shape[0])
+    traces_corrected: np.array
+        Motion-corrected trace snippet, (num_samples, num_channels)
     """
     # assert HAVE_NUMBA
     assert times.shape[0] == traces.shape[0]
@@ -106,53 +109,45 @@ def interpolate_motion_on_traces(
     else:
         channel_inds = np.asarray(channel_inds)
         traces_corrected = np.zeros((traces.shape[0], channel_inds.size), dtype=traces.dtype)
-
-    total_num_chans = channel_locations.shape[0]
-    # TODO give optional possibility to have smaler times bins than the motion with interpolation
-    # this would remove the need of _get_closest_ind and searchsorted
+    total_num_chans = channel_locations.shape[0]

-    # TODO delegate times to recording, at the moment this is 0 based
-    # regroup times by closet temporal_bins
-    bin_inds = _get_closest_ind(motion.temporal_bins_s[segment_index], times)
+    # -- determine the blocks of frames that will land in the same interpolation time bin
+    time_bins = interpolation_time_bin_centers_s
+    if time_bins is None:
+        time_bins = motion.temporal_bins_s[segment_index]
+    bin_s = time_bins[1] - time_bins[0]
+    bins_start = time_bins[0] - 0.5 * bin_s
+    # nearest bin center for each frame
+    bin_inds = ((times - bins_start) // bin_s).astype(int)
+    # the time bins may not cover the whole set of times in the recording,
+    # so we need to clip these indices to the valid range
+    np.clip(bin_inds, 0, time_bins.size - 1, out=bin_inds)
+
+    # -- which interpolation bins are actually visited by this chunk of frames?
+    bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1)

     # interpolation kernel will be the same per temporal bin
-    for bin_ind in np.unique(bin_inds):
-
-        bin_time = motion.temporal_bins_s[segment_index][bin_ind]
-
+    interp_times = np.empty(total_num_chans)
+    current_start_index = 0
+    for bin_ind in bins_here:
+        bin_time = time_bins[bin_ind]
+        interp_times.fill(bin_time)
         channel_motions = motion.get_displacement_at_time_and_depth(
-            np.full(total_num_chans, bin_time),
+            interp_times,
             channel_locations[:, motion.dim],
-            segment_index=segment_index
+            segment_index=segment_index,
         )
         channel_locations_moved = channel_locations.copy()
         channel_locations_moved[:, motion.dim] += channel_motions
-        # # TODO use  # TODO : add Motion.get_displacement_at_time_and_depth() instead
-
-        # # Step 1 : channel motion
-        # if spatial_bins.shape[0] == 1:
-        #     # rigid motion : same motion for all channels
-        #     channel_motions = motion[bin_ind, 0]
-        # else:
-        #     # non rigid : interpolation channel motion for this temporal bin
-        #     f = scipy.interpolate.interp1d(
-        #         spatial_bins, motion[bin_ind, :], kind="linear", axis=0, bounds_error=False, fill_value="extrapolate"
-        #     )
-        #     locs = channel_locations[:, direction]
-        #     channel_motions = f(locs)
-        # channel_locations_moved = channel_locations.copy()
-        # channel_locations_moved[:, direction] += channel_motions
-        # # channel_locations_moved[:, direction] -= channel_motions
-
         if channel_inds is not None:
             channel_locations_moved = channel_locations_moved[channel_inds]

         drift_kernel = get_spatial_interpolation_kernel(
             channel_locations,
             channel_locations_moved,
-            dtype="float32",
+            dtype=traces.dtype,
             method=spatial_interpolation_method,
             **spatial_interpolation_kwargs,
         )
@@ -164,20 +159,21 @@ def interpolate_motion_on_traces(
         # ax.set_title(f"bin_ind {bin_ind} -
{bin_time}s - {spatial_interpolation_method}") # plt.show() - # i0 = np.searchsorted(bin_inds, bin_ind, side="left") - # i1 = np.searchsorted(bin_inds, bin_ind, side="right") - i0, i1 = np.searchsorted(bin_inds, [bin_ind, bin_ind + 1], side="left") + # quickly find the end of this bin, which is also the start of the next + next_start_index = current_start_index + np.searchsorted(bin_inds[current_start_index:], bin_ind + 1, side="left") + in_bin = slice(current_start_index, next_start_index) # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) - traces_corrected[i0:i1] = traces[i0:i1] @ drift_kernel + np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin]) + current_start_index = next_start_index return traces_corrected # if HAVE_NUMBA: -# # @numba.jit(parallel=False) +# # @numba.jit(parallel=False) # @numba.jit(parallel=True) # def my_sparse_dot(data_in, data_out, sparse_chans, weights): # """ @@ -192,7 +188,7 @@ def interpolate_motion_on_traces( # num_samples = data_in.shape[0] # num_chan_out = data_out.shape[1] # num_sparse = sparse_chans.shape[1] -# # for sample_index in range(num_samples): +# # for sample_index in range(num_samples): # for sample_index in numba.prange(num_samples): # for out_chan in range(num_chan_out): # v = 0 @@ -219,9 +215,18 @@ def _get_closest_ind(array, values): class InterpolateMotionRecording(BasePreprocessor): """ - Recording that corrects motion on-the-fly given a motion vector estimation (rigid or non-rigid). - This internally applies a spatial interpolation on the original traces after reversing the motion. - `estimate_motion()` must be called before this to estimate the motion vector. + Interpolate the input recording's traces to correct for motion, according to the + motion estimate object `motion`. The interpolation is carried out "lazily" / on the fly + by applying a spatial interpolation on the original traces to estimate their values + at the positions of the probe's channels after being shifted inversely to the motion. + + To get a Motion object, use `interpolate_motion()`. + + By default, each frame is spatially interpolated by the motion at the nearest motion + estimation time bin -- in other words, the temporal resolution of the motion correction + is the same as the motion estimation's. However, this behavior can be changed by setting + `interpolation_time_bin_centers_s` or `interpolation_time_bin_size_s` below. In that case, + the motion estimate will be interpolated to match the interpolation time bins. Parameters ---------- @@ -245,10 +250,22 @@ class InterpolateMotionRecording(BasePreprocessor): Number of closest channels used by "idw" method for interpolation. border_mode: "remove_channels" | "force_extrapolate" | "force_zeros", default: "remove_channels" Control how channels are handled on border: - * "remove_channels": remove channels on the border, the recording has less channels * "force_extrapolate": keep all channel and force extrapolation (can lead to strange signal) * "force_zeros": keep all channel but set zeros when outside (force_extrapolate=False) + interpolation_time_bin_centers_s: np.array or list of np.array, optional + Spatially interpolate each frame according to the displacement estimate at its closest + bin center in this array. If not supplied, this is set to the motion estimate's time bin + centers. 
If it's supplied, the motion estimate is interpolated to these bin centers. + If you have a multi-segment recording, pass a list of these, one per segment. + interpolation_time_bin_size_s: float, optional + Similar to the previous argument: interpolation_time_bin_centers_s will be constructed + by bins spaced by interpolation_time_bin_size_s. This is ignored if interpolation_time_bin_centers_s + is supplied. + dtype : str or np.dtype, optional + Interpolation needs to convert to a floating dtype. If dtype is supplied, that will be used. + If the input recording is already floating and dtype=None, then its dtype is used by default. + If the input recording is integer, then float32 is used by default. Returns ------- @@ -267,14 +284,12 @@ def __init__( sigma_um=20.0, p=1, num_closest=3, + interpolation_time_bin_centers_s=None, + interpolation_time_bin_size_s=None, + dtype=None, ): # assert recording.get_num_segments() == 1, "correct_motion() is only available for single-segment recordings" - # # force as arrays - # temporal_bins = np.asarray(temporal_bins) - # motion = np.asarray(motion) - # spatial_bins = np.asarray(spatial_bins) - channel_locations = recording.get_channel_locations() assert channel_locations.ndim >= motion.dim, ( f"'direction' {motion.direction} not available. " f"Channel locations have {channel_locations.ndim} dimensions." @@ -284,35 +299,21 @@ def __init__( locs = channel_locations[:, motion.dim] l0, l1 = np.min(locs), np.max(locs) - # compute max and min motion (with interpolation) - # and check if channels are inside for all segment + # check if channels stay inside the probe extents for all segments channel_inside = np.ones(locs.shape[0], dtype="bool") - for operator, arg_operator in ((np.max,np.argmax), (np.min, np.argmin)): - for segment_index in range(recording.get_num_segments()): - ind = arg_operator(operator(motion.displacement[segment_index], axis=1)) - bin_time = motion.temporal_bins_s[segment_index][ind] - best_motions = motion.get_displacement_at_time_and_depth( - np.full(locs.shape[0], bin_time), locs, segment_index=segment_index - ) - channel_inside &= ((locs + best_motions) >= l0) & ((locs + best_motions) <= l1) - - - # if spatial_bins.shape[0] == 1: - # best_motions = operator(motion[:, 0]) - # else: - # # non rigid : interpolation channel motion for this temporal bin - # f = scipy.interpolate.interp1d( - # spatial_bins, - # operator(motion[:, :], axis=0), - # kind="linear", - # axis=0, - # bounds_error=False, - # fill_value="extrapolate", - # ) - # best_motions = f(locs) - # channel_inside &= ((locs + best_motions) >= l0) & ((locs + best_motions) <= l1) - - (channel_inds,) = np.nonzero(channel_inside) + for segment_index in range(recording.get_num_segments()): + # evaluate the positions of all channels over all time bins + channel_locations = motion.get_displacement_at_time_and_depth( + times_s=motion.temporal_bins_s[segment_index], + locations_um=locs, + grid=True, + ) + # check if these remain inside of the probe + seg_inside = channel_locations.clip(l0, l1) == channel_locations + seg_inside = seg_inside.all(axis=1) + channel_inside &= seg_inside + + channel_inds = np.flatnonzero(channel_inside) channel_ids = recording.channel_ids[channel_inds] spatial_interpolation_kwargs["force_extrapolate"] = False elif border_mode == "force_extrapolate": @@ -326,7 +327,10 @@ def __init__( else: raise ValueError("Wrong border_mode") - BasePreprocessor.__init__(self, recording, channel_ids=channel_ids) + if dtype is None and recording.dtype.kind != "f": + dtype = 
"float32" + dtype_ = fix_dtype(recording, dtype) + BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) if border_mode == "remove_channels": # change the wiring of the probe @@ -336,7 +340,23 @@ def __init__( contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") self.set_property("contact_vector", contact_vector) - for parent_segment in recording._recording_segments: + # handle manual interpolation_time_bin_centers_s + # the case where interpolation_time_bin_size_s is set is handled per-segment below + if interpolation_time_bin_centers_s is None: + if interpolation_time_bin_size_s is None: + interpolation_time_bin_centers_s = motion.temporal_bins_s + + for segment_index, parent_segment in enumerate(recording._recording_segments): + # finish the per-segment part of the time bin logic + if interpolation_time_bin_centers_s is None: + # in this case, interpolation_time_bin_size_s is set. + s_end = parent_segment.get_num_samples() + t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) + halfbin = interpolation_time_bin_size_s / 2. + segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) + else: + segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index] + rec_segment = InterpolateMotionRecordingSegment( parent_segment, channel_locations, @@ -344,6 +364,9 @@ def __init__( spatial_interpolation_method, spatial_interpolation_kwargs, channel_inds, + segment_index, + segment_interpolation_time_bins_s, + dtype=dtype_, ) self.add_recording_segment(rec_segment) @@ -355,6 +378,8 @@ def __init__( sigma_um=sigma_um, p=p, num_closest=num_closest, + interpolation_time_bin_centers_s=interpolation_time_bin_centers_s, + dtype=dtype_.str, ) @@ -367,49 +392,48 @@ def __init__( spatial_interpolation_method, spatial_interpolation_kwargs, channel_inds, + segment_index, + interpolation_time_bin_centers_s, + dtype="float32", ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.channel_locations = channel_locations - self.motion = motion self.spatial_interpolation_method = spatial_interpolation_method self.spatial_interpolation_kwargs = spatial_interpolation_kwargs self.channel_inds = channel_inds + self.segment_index = segment_index + self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s + self.dtype = dtype def get_traces(self, start_frame, end_frame, channel_indices): - if self.time_vector is not None: + if self.has_time_vector(): raise NotImplementedError( - "time_vector for InterpolateMotionRecording do not work because temporal_bins start from 0" + "InterpolateMotionRecording does not yet support recordings with time_vectors." 
) - # times = np.asarray(self.time_vector[start_frame:end_frame]) if start_frame is None: start_frame = 0 if end_frame is None: end_frame = self.get_num_samples() - times = np.arange(end_frame - start_frame, dtype="float64") - times /= self.sampling_frequency - t0 = start_frame / self.sampling_frequency - # if self.t_start is not None: - # t0 = t0 + self.t_start - times += t0 - + times = self.parent_recording_segment.sample_index_to_time(np.arange(start_frame, end_frame)) traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices=slice(None)) - - trace2 = interpolate_motion_on_traces( + traces = traces.astype(self.dtype) + traces = interpolate_motion_on_traces( traces, times, self.channel_locations, self.motion, channel_inds=self.channel_inds, - spatial_interpolation_method=self.spatial_interpolation_method, spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, + interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s, + segment_index=self.segment_index, ) if channel_indices is not None: - trace2 = trace2[:, channel_indices] + traces = traces[:, channel_indices] - return trace2 + return traces -interpolate_motion = define_function_from_class(source_class=InterpolateMotionRecording, name="correct_motion") +interpolate_motion = define_function_from_class(source_class=InterpolateMotionRecording, name="interpolate_motion") From 47bb85b3027aa94538714ba06d434a0ee12b369b Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 29 May 2024 16:41:42 +0100 Subject: [PATCH 015/136] Small docs and cleaning --- .../sortingcomponents/motion_estimation.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index 6925c7aede..3a8b75f8b3 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -1,9 +1,13 @@ from __future__ import annotations -import numpy as np from tqdm.auto import tqdm, trange + +import numpy as np import scipy.interpolate +from .motion_utils import Motion +from .tools import make_multi_method_doc + try: import torch import torch.nn.functional as F @@ -12,11 +16,6 @@ except ImportError: HAVE_TORCH = False -from .tools import make_multi_method_doc -from .motion_utils import Motion - - - def estimate_motion( recording, @@ -59,7 +58,7 @@ def estimate_motion( **histogram section** direction: "x" | "y" | "z", default: "y" - Dimension on which the motion is estimated + Dimension on which the motion is estimated. "y" is depth along the probe. 
bin_duration_s: float, default: 10 Bin duration in second bin_um: float, default: 10 @@ -157,7 +156,7 @@ def estimate_motion( ) # replace nan by zeros - motion_array[np.isnan(motion_array)] = 0 + np.nan_to_num(motion_array, copy=False) if post_clean: motion_array = clean_motion_vector( @@ -177,15 +176,12 @@ def estimate_motion( # TODO handle multi segment motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction) - if output_extra_check: return motion, extra_check else: return motion - - class DecentralizedRegistration: """ Method developed by the Paninski's group from Columbia university: @@ -339,7 +335,7 @@ def run( extra_check["spatial_hist_bin_edges"] = spatial_hist_bin_edges # temporal bins are bin center - temporal_bins = temporal_hist_bin_edges[:-1] + bin_duration_s // 2.0 + temporal_bins = 0.5 * (temporal_hist_bin_edges[1:] + temporal_hist_bin_edges[:-1]) motion = np.zeros((temporal_bins.size, len(non_rigid_windows)), dtype=np.float64) windows_iter = non_rigid_windows @@ -822,7 +818,6 @@ def compute_pairwise_displacement( """ Compute pairwise displacement """ - from scipy import sparse from scipy import linalg assert conv_engine in ("torch", "numpy"), f"'conv_engine' must be 'torch' or 'numpy'" From 5f8d49ba8f99a49570e135714c1a763b6f2c37c0 Mon Sep 17 00:00:00 2001 From: r_pr Date: Thu, 30 May 2024 13:17:11 +0200 Subject: [PATCH 016/136] WIP: Proposal of format to hold the manual curation information Took 1 hour 17 minutes --- .../curation/curation_format.py | 34 +++ .../curation/tests/test_curation_format.py | 225 ++++++++++++++++++ 2 files changed, 259 insertions(+) create mode 100644 src/spikeinterface/curation/curation_format.py create mode 100644 src/spikeinterface/curation/tests/test_curation_format.py diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py new file mode 100644 index 0000000000..ef10fb2c74 --- /dev/null +++ b/src/spikeinterface/curation/curation_format.py @@ -0,0 +1,34 @@ +from itertools import combinations + + +def validate_curation_dict(curation_dict): + """ + Validate that the curation dictionary given as parameter complies with the format + + Parameters + ---------- + curation_dict : dict + + + Returns + ------- + + """ + + unit_set = set(curation_dict['unit_ids']) + labeled_unit_set = set([lbl['unit_id'] for lbl in curation_dict['manual_labels']]) + merged_units_set = set(sum(curation_dict['merged_unit_groups'], [])) + removed_units_set = set(curation_dict['removed_units']) + if not labeled_unit_set.issubset(unit_set): + raise ValueError("Some labeled units are not in the unit list") + if not merged_units_set.issubset(unit_set): + raise ValueError("Some merged units are not in the unit list") + if not removed_units_set.issubset(unit_set): + raise ValueError("Some removed units are not in the unit list") + all_merging_groups = [set(group) for group in curation_dict['merged_unit_groups']] + for gp_1, gp_2 in combinations(all_merging_groups, 2): + if len(gp_1.intersection(gp_2)) != 0: + raise ValueError("Some units belong to multiple merge groups") + if len(removed_units_set.intersection(merged_units_set)) != 0: + raise ValueError("Some units were merged and deleted") + return True \ No newline at end of file diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py new file mode 100644 index 0000000000..0470161220 --- /dev/null +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ 
-0,0 +1,225 @@ +from spikeinterface.curation.curation_format import validate_curation_dict +import pytest + + +"""example = { + 'unit_ids': List[str, int], + 'labels_definition': { + 'category_key1': + {'name': str, + 'labels': List[str], + 'auto_eclusive': bool} + }, + 'manual_labels': [ + {'unit_id': str or int, + 'label_category_key': str, + 'label_category_value': list or str + } + ], + 'merged_unit_groups': List[List[unit_ids]], # one cell goes into at most one list + 'removed_units': List[unit_ids] # Can not be in the merged_units +} +""" + +valid_int = { + 'unit_ids': [1, 2, 3, 6, 10, 14, 20, 31, 42], + 'labels_definition': { + 'quality': + {'name': 'quality', + 'labels': ['good', 'noise', 'MUA', 'artifact'], + 'auto_eclusive': True}, + 'experimental': + {'name': 'experimental', + 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], + 'auto_eclusive': False} + }, + 'manual_labels': [ + {'unit_id': 1, + 'label_category_key': 'quality', + 'label_category_value': 'good' + }, + {'unit_id': 2, + 'label_category_key': 'quality', + 'label_category_value': 'noise' + }, + {'unit_id': 2, + 'label_category_key': 'experimental', + 'label_category_value': ['chronic', 'headfixed'] + }, + ], + 'merged_unit_groups': [[3, 6], [10, 14, 20]], # one cell goes into at most one list + 'removed_units': [31, 42] # Can not be in the merged_units +} + + +valid_str = { + 'unit_ids': ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"], + 'labels_definition': { + 'quality': + {'name': 'quality', + 'labels': ['good', 'noise', 'MUA', 'artifact'], + 'auto_eclusive': True}, + 'experimental': + {'name': 'experimental', + 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], + 'auto_eclusive': False} + }, + 'manual_labels': [ + {'unit_id': "u1", + 'label_category_key': 'quality', + 'label_category_value': 'good' + }, + {'unit_id': "u2", + 'label_category_key': 'quality', + 'label_category_value': 'noise' + }, + {'unit_id': "u2", + 'label_category_key': 'experimental', + 'label_category_value': ['chronic', 'headfixed'] + }, + ], + 'merged_unit_groups': [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list + 'removed_units': ["u31", "u42"] # Can not be in the merged_units +} + +# This is a failure example +duplicate_merge = { + 'unit_ids': [1, 2, 3, 6, 10, 14, 20, 31, 42], + 'labels_definition': { + 'quality': + {'name': 'quality', + 'labels': ['good', 'noise', 'MUA', 'artifact'], + 'auto_eclusive': True}, + 'experimental': + {'name': 'experimental', + 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], + 'auto_eclusive': False} + }, + 'manual_labels': [ + {'unit_id': 1, + 'label_category_key': 'quality', + 'label_category_value': 'good' + }, + {'unit_id': 2, + 'label_category_key': 'quality', + 'label_category_value': 'noise' + }, + {'unit_id': 2, + 'label_category_key': 'experimental', + 'label_category_value': ['chronic', 'headfixed'] + }, + ], + 'merged_unit_groups': [[3, 6, 10], [10, 14, 20]], # one cell goes into at most one list + 'removed_units': [31, 42] # Can not be in the merged_units +} + + +# This is a failure example +merged_and_removed = { + 'unit_ids': [1, 2, 3, 6, 10, 14, 20, 31, 42], + 'labels_definition': { + 'quality': + {'name': 'quality', + 'labels': ['good', 'noise', 'MUA', 'artifact'], + 'auto_eclusive': True}, + 'experimental': + {'name': 'experimental', + 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], + 'auto_eclusive': False} + }, + 'manual_labels': [ + {'unit_id': 1, + 'label_category_key': 'quality', + 
'label_category_value': 'good' + }, + {'unit_id': 2, + 'label_category_key': 'quality', + 'label_category_value': 'noise' + }, + {'unit_id': 2, + 'label_category_key': 'experimental', + 'label_category_value': ['chronic', 'headfixed'] + }, + ], + 'merged_unit_groups': [[3, 6], [10, 14, 20]], # one cell goes into at most one list + 'removed_units': [3, 31, 42] # Can not be in the merged_units +} + + +unknown_merged_unit = { + 'unit_ids': [1, 2, 3, 6, 10, 14, 20, 31, 42], + 'labels_definition': { + 'quality': + {'name': 'quality', + 'labels': ['good', 'noise', 'MUA', 'artifact'], + 'auto_eclusive': True}, + 'experimental': + {'name': 'experimental', + 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], + 'auto_eclusive': False} + }, + 'manual_labels': [ + {'unit_id': 1, + 'label_category_key': 'quality', + 'label_category_value': 'good' + }, + {'unit_id': 2, + 'label_category_key': 'quality', + 'label_category_value': 'noise' + }, + {'unit_id': 2, + 'label_category_key': 'experimental', + 'label_category_value': ['chronic', 'headfixed'] + }, + ], + 'merged_unit_groups': [[3, 6, 99], [10, 14, 20]], # one cell goes into at most one list + 'removed_units': [31, 42] # Can not be in the merged_units +} + + +unknown_removed_unit = { + 'unit_ids': [1, 2, 3, 6, 10, 14, 20, 31, 42], + 'labels_definition': { + 'quality': + {'name': 'quality', + 'labels': ['good', 'noise', 'MUA', 'artifact'], + 'auto_eclusive': True}, + 'experimental': + {'name': 'experimental', + 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], + 'auto_eclusive': False} + }, + 'manual_labels': [ + {'unit_id': 1, + 'label_category_key': 'quality', + 'label_category_value': 'good' + }, + {'unit_id': 2, + 'label_category_key': 'quality', + 'label_category_value': 'noise' + }, + {'unit_id': 2, + 'label_category_key': 'experimental', + 'label_category_value': ['chronic', 'headfixed'] + }, + ], + 'merged_unit_groups': [[3, 6], [10, 14, 20]], # one cell goes into at most one list + 'removed_units': [31, 42, 99] # Can not be in the merged_units +} + + +def test_curation_format_validation(): + assert validate_curation_dict(valid_int) + assert validate_curation_dict(valid_str) + with pytest.raises(ValueError): + # Raised because duplicated merged units + validate_curation_dict(duplicate_merge) + with pytest.raises(ValueError): + # Raised because Some units belong to multiple merge groups" + validate_curation_dict(merged_and_removed) + with pytest.raises(ValueError): + # Some merged units are not in the unit list + validate_curation_dict(unknown_merged_unit) + with pytest.raises(ValueError): + # Raise beecause Some removed units are not in the unit list + validate_curation_dict(unknown_removed_unit) From fa2493b17ca57fb8eb72787d87742dddab8204a1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 May 2024 11:20:12 +0000 Subject: [PATCH 017/136] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/curation_format.py | 12 +- .../curation/tests/test_curation_format.py | 234 +++++++----------- 2 files changed, 90 insertions(+), 156 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index ef10fb2c74..9b9c862cb6 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -15,20 +15,20 @@ def validate_curation_dict(curation_dict): """ - unit_set = 
set(curation_dict['unit_ids']) - labeled_unit_set = set([lbl['unit_id'] for lbl in curation_dict['manual_labels']]) - merged_units_set = set(sum(curation_dict['merged_unit_groups'], [])) - removed_units_set = set(curation_dict['removed_units']) + unit_set = set(curation_dict["unit_ids"]) + labeled_unit_set = set([lbl["unit_id"] for lbl in curation_dict["manual_labels"]]) + merged_units_set = set(sum(curation_dict["merged_unit_groups"], [])) + removed_units_set = set(curation_dict["removed_units"]) if not labeled_unit_set.issubset(unit_set): raise ValueError("Some labeled units are not in the unit list") if not merged_units_set.issubset(unit_set): raise ValueError("Some merged units are not in the unit list") if not removed_units_set.issubset(unit_set): raise ValueError("Some removed units are not in the unit list") - all_merging_groups = [set(group) for group in curation_dict['merged_unit_groups']] + all_merging_groups = [set(group) for group in curation_dict["merged_unit_groups"]] for gp_1, gp_2 in combinations(all_merging_groups, 2): if len(gp_1.intersection(gp_2)) != 0: raise ValueError("Some units belong to multiple merge groups") if len(removed_units_set.intersection(merged_units_set)) != 0: raise ValueError("Some units were merged and deleted") - return True \ No newline at end of file + return True diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 0470161220..6a6686f676 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -22,189 +22,123 @@ """ valid_int = { - 'unit_ids': [1, 2, 3, 6, 10, 14, 20, 31, 42], - 'labels_definition': { - 'quality': - {'name': 'quality', - 'labels': ['good', 'noise', 'MUA', 'artifact'], - 'auto_eclusive': True}, - 'experimental': - {'name': 'experimental', - 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], - 'auto_eclusive': False} + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "labels_definition": { + "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True}, + "experimental": { + "name": "experimental", + "labels": ["acute", "chronic", "headfixed", "freelymoving"], + "auto_eclusive": False, + }, }, - 'manual_labels': [ - {'unit_id': 1, - 'label_category_key': 'quality', - 'label_category_value': 'good' - }, - {'unit_id': 2, - 'label_category_key': 'quality', - 'label_category_value': 'noise' - }, - {'unit_id': 2, - 'label_category_key': 'experimental', - 'label_category_value': ['chronic', 'headfixed'] - }, + "manual_labels": [ + {"unit_id": 1, "label_category_key": "quality", "label_category_value": "good"}, + {"unit_id": 2, "label_category_key": "quality", "label_category_value": "noise"}, + {"unit_id": 2, "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]}, ], - 'merged_unit_groups': [[3, 6], [10, 14, 20]], # one cell goes into at most one list - 'removed_units': [31, 42] # Can not be in the merged_units + "merged_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list + "removed_units": [31, 42], # Can not be in the merged_units } valid_str = { - 'unit_ids': ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"], - 'labels_definition': { - 'quality': - {'name': 'quality', - 'labels': ['good', 'noise', 'MUA', 'artifact'], - 'auto_eclusive': True}, - 'experimental': - {'name': 'experimental', - 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], - 
'auto_eclusive': False} + "unit_ids": ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"], + "labels_definition": { + "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True}, + "experimental": { + "name": "experimental", + "labels": ["acute", "chronic", "headfixed", "freelymoving"], + "auto_eclusive": False, + }, }, - 'manual_labels': [ - {'unit_id': "u1", - 'label_category_key': 'quality', - 'label_category_value': 'good' - }, - {'unit_id': "u2", - 'label_category_key': 'quality', - 'label_category_value': 'noise' - }, - {'unit_id': "u2", - 'label_category_key': 'experimental', - 'label_category_value': ['chronic', 'headfixed'] - }, + "manual_labels": [ + {"unit_id": "u1", "label_category_key": "quality", "label_category_value": "good"}, + {"unit_id": "u2", "label_category_key": "quality", "label_category_value": "noise"}, + {"unit_id": "u2", "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]}, ], - 'merged_unit_groups': [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list - 'removed_units': ["u31", "u42"] # Can not be in the merged_units + "merged_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list + "removed_units": ["u31", "u42"], # Can not be in the merged_units } # This is a failure example duplicate_merge = { - 'unit_ids': [1, 2, 3, 6, 10, 14, 20, 31, 42], - 'labels_definition': { - 'quality': - {'name': 'quality', - 'labels': ['good', 'noise', 'MUA', 'artifact'], - 'auto_eclusive': True}, - 'experimental': - {'name': 'experimental', - 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], - 'auto_eclusive': False} + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "labels_definition": { + "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True}, + "experimental": { + "name": "experimental", + "labels": ["acute", "chronic", "headfixed", "freelymoving"], + "auto_eclusive": False, + }, }, - 'manual_labels': [ - {'unit_id': 1, - 'label_category_key': 'quality', - 'label_category_value': 'good' - }, - {'unit_id': 2, - 'label_category_key': 'quality', - 'label_category_value': 'noise' - }, - {'unit_id': 2, - 'label_category_key': 'experimental', - 'label_category_value': ['chronic', 'headfixed'] - }, + "manual_labels": [ + {"unit_id": 1, "label_category_key": "quality", "label_category_value": "good"}, + {"unit_id": 2, "label_category_key": "quality", "label_category_value": "noise"}, + {"unit_id": 2, "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]}, ], - 'merged_unit_groups': [[3, 6, 10], [10, 14, 20]], # one cell goes into at most one list - 'removed_units': [31, 42] # Can not be in the merged_units + "merged_unit_groups": [[3, 6, 10], [10, 14, 20]], # one cell goes into at most one list + "removed_units": [31, 42], # Can not be in the merged_units } # This is a failure example merged_and_removed = { - 'unit_ids': [1, 2, 3, 6, 10, 14, 20, 31, 42], - 'labels_definition': { - 'quality': - {'name': 'quality', - 'labels': ['good', 'noise', 'MUA', 'artifact'], - 'auto_eclusive': True}, - 'experimental': - {'name': 'experimental', - 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], - 'auto_eclusive': False} + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "labels_definition": { + "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True}, + "experimental": { + "name": "experimental", + "labels": 
["acute", "chronic", "headfixed", "freelymoving"], + "auto_eclusive": False, + }, }, - 'manual_labels': [ - {'unit_id': 1, - 'label_category_key': 'quality', - 'label_category_value': 'good' - }, - {'unit_id': 2, - 'label_category_key': 'quality', - 'label_category_value': 'noise' - }, - {'unit_id': 2, - 'label_category_key': 'experimental', - 'label_category_value': ['chronic', 'headfixed'] - }, + "manual_labels": [ + {"unit_id": 1, "label_category_key": "quality", "label_category_value": "good"}, + {"unit_id": 2, "label_category_key": "quality", "label_category_value": "noise"}, + {"unit_id": 2, "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]}, ], - 'merged_unit_groups': [[3, 6], [10, 14, 20]], # one cell goes into at most one list - 'removed_units': [3, 31, 42] # Can not be in the merged_units + "merged_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list + "removed_units": [3, 31, 42], # Can not be in the merged_units } unknown_merged_unit = { - 'unit_ids': [1, 2, 3, 6, 10, 14, 20, 31, 42], - 'labels_definition': { - 'quality': - {'name': 'quality', - 'labels': ['good', 'noise', 'MUA', 'artifact'], - 'auto_eclusive': True}, - 'experimental': - {'name': 'experimental', - 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], - 'auto_eclusive': False} + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "labels_definition": { + "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True}, + "experimental": { + "name": "experimental", + "labels": ["acute", "chronic", "headfixed", "freelymoving"], + "auto_eclusive": False, + }, }, - 'manual_labels': [ - {'unit_id': 1, - 'label_category_key': 'quality', - 'label_category_value': 'good' - }, - {'unit_id': 2, - 'label_category_key': 'quality', - 'label_category_value': 'noise' - }, - {'unit_id': 2, - 'label_category_key': 'experimental', - 'label_category_value': ['chronic', 'headfixed'] - }, + "manual_labels": [ + {"unit_id": 1, "label_category_key": "quality", "label_category_value": "good"}, + {"unit_id": 2, "label_category_key": "quality", "label_category_value": "noise"}, + {"unit_id": 2, "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]}, ], - 'merged_unit_groups': [[3, 6, 99], [10, 14, 20]], # one cell goes into at most one list - 'removed_units': [31, 42] # Can not be in the merged_units + "merged_unit_groups": [[3, 6, 99], [10, 14, 20]], # one cell goes into at most one list + "removed_units": [31, 42], # Can not be in the merged_units } unknown_removed_unit = { - 'unit_ids': [1, 2, 3, 6, 10, 14, 20, 31, 42], - 'labels_definition': { - 'quality': - {'name': 'quality', - 'labels': ['good', 'noise', 'MUA', 'artifact'], - 'auto_eclusive': True}, - 'experimental': - {'name': 'experimental', - 'labels': ['acute', 'chronic', 'headfixed', 'freelymoving'], - 'auto_eclusive': False} + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "labels_definition": { + "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True}, + "experimental": { + "name": "experimental", + "labels": ["acute", "chronic", "headfixed", "freelymoving"], + "auto_eclusive": False, + }, }, - 'manual_labels': [ - {'unit_id': 1, - 'label_category_key': 'quality', - 'label_category_value': 'good' - }, - {'unit_id': 2, - 'label_category_key': 'quality', - 'label_category_value': 'noise' - }, - {'unit_id': 2, - 'label_category_key': 'experimental', - 'label_category_value': ['chronic', 'headfixed'] 
-    },
+    "manual_labels": [
+        {"unit_id": 1, "label_category_key": "quality", "label_category_value": "good"},
+        {"unit_id": 2, "label_category_key": "quality", "label_category_value": "noise"},
+        {"unit_id": 2, "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]},
     ],
-    'merged_unit_groups': [[3, 6], [10, 14, 20]], # one cell goes into at most one list
-    'removed_units': [31, 42, 99] # Can not be in the merged_units
+    "merged_unit_groups": [[3, 6], [10, 14, 20]],  # one cell goes into at most one list
+    "removed_units": [31, 42, 99],  # Can not be in the merged_units
 }

From 102116361b56310420d1c5508dbffa3fd37522a4 Mon Sep 17 00:00:00 2001
From: r_pr
Date: Thu, 30 May 2024 15:40:53 +0200
Subject: [PATCH 018/136] Feature: Conversion from sortingview format to the new proposed format

Took 1 hour 7 minutes
---
 .../curation/curation_format.py            | 76 +++++++++++++++++++
 .../curation/tests/test_curation_format.py |  6 ++
 2 files changed, 82 insertions(+)

diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py
index 9b9c862cb6..1857b970a6 100644
--- a/src/spikeinterface/curation/curation_format.py
+++ b/src/spikeinterface/curation/curation_format.py
@@ -15,6 +15,7 @@ def validate_curation_dict(curation_dict):
     """
 
+    supported_versions = {1}
     unit_set = set(curation_dict["unit_ids"])
     labeled_unit_set = set([lbl["unit_id"] for lbl in curation_dict["manual_labels"]])
     merged_units_set = set(sum(curation_dict["merged_unit_groups"], []))
@@ -31,4 +32,79 @@ def validate_curation_dict(curation_dict):
         raise ValueError("Some units belong to multiple merge groups")
     if len(removed_units_set.intersection(merged_units_set)) != 0:
         raise ValueError("Some units were merged and deleted")
+    if curation_dict["format_version"] not in supported_versions:
+        raise ValueError(f"Format version ({curation_dict['format_version']}) not supported. "
+                         f"Only {supported_versions} are valid")
+    # Check the labels exclusivity
+    for lbl in curation_dict["manual_labels"]:
+        lbl_key = lbl["label_category_key"]
+        is_exclusive = curation_dict["labels_definition"][lbl_key]["auto_eclusive"]
+        if is_exclusive and not isinstance(lbl["label_category_value"], str):
+            raise ValueError(f"{lbl_key} are mutually exclusive labels. {lbl['label_category_value']} is invalid")
+        elif not is_exclusive and not isinstance(lbl["label_category_value"], list):
+            raise ValueError(f"{lbl_key} are not mutually exclusive labels. "
+                             f"{lbl['label_category_value']} should be a list")
     return True
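A worked illustration of the exclusivity check above (not part of the patch; the entry is made up, and the key names match this point in the series): with "quality" declared auto_eclusive=True, a manual label such as

    {"unit_id": 1, "label_category_key": "quality", "label_category_value": ["good", "noise"]}

fails the isinstance(..., str) branch and raises ValueError, while a non-exclusive category accepts only a list and rejects a bare string.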
+
+
+def convert_from_sortingview(sortingview_dict, destination_format=1):
+    """
+    Converts the sortingview curation format into a curation dictionary
+    A couple of caveats:
+    * The list of units is not available in the original sortingview dictionary. We set it to None
+    * Labels cannot be mutually exclusive.
+    * Labels have no category, so we regroup them under the "all_labels" category
+
+    Parameters
+    ----------
+    sortingview_dict : dict
+        Dictionary containing the curation information from sortingview
+    destination_format : int
+        Version of the format to use.
+        Defaults to 1
+
+    Returns
+    -------
+    curation_dict: dict
+        A curation dictionary
+    """
+    merge_groups = sortingview_dict["mergeGroups"]
+    merged_units = sum(merge_groups, [])
+    if len(merged_units) > 0:
+        unit_id_type = int if isinstance(merged_units[0], int) else str
+    else:
+        unit_id_type = str
+    all_units = []
+    all_labels = []
+    manual_labels = []
+    general_cat = "all_labels"
+    for unit_id, l_labels in sortingview_dict["labelsByUnit"].items():
+        all_labels.extend(l_labels)
+        u_id = unit_id_type(unit_id)
+        all_units.append(u_id)
+        manual_labels.append({'unit_id': u_id, "label_category_key": general_cat,
+                              "label_category_value": l_labels})
+    labels_def = {"all_labels":
+                      {"name": "all_labels",
+                       "labels": all_labels,
+                       "auto_eclusive": False}}
+
+    curation_dict = {"unit_ids": None,
+                     "labels_definition": labels_def,
+                     "manual_labels": manual_labels,
+                     "merged_unit_groups": merge_groups,
+                     "removed_units": [],
+                     "format_version": destination_format}
+
+    return curation_dict
+
+
+if __name__ == "__main__":
+    import json
+    with open("src/spikeinterface/curation/tests/sv-sorting-curation-str.json") as jf:
+        sv_curation = json.load(jf)
+    cur_d = convert_from_sortingview(sortingview_dict=sv_curation)
+
+
+
+
diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py
index 6a6686f676..2263ced95c 100644
--- a/src/spikeinterface/curation/tests/test_curation_format.py
+++ b/src/spikeinterface/curation/tests/test_curation_format.py
@@ -38,6 +38,7 @@
     ],
     "merged_unit_groups": [[3, 6], [10, 14, 20]],  # one cell goes into at most one list
     "removed_units": [31, 42],  # Can not be in the merged_units
+    "format_version": 1
 }
@@ -58,6 +59,7 @@
     ],
     "merged_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]],  # one cell goes into at most one list
     "removed_units": ["u31", "u42"],  # Can not be in the merged_units
+    "format_version": 1
 }
 
 # This is a failure example
@@ -78,6 +80,7 @@
     ],
     "merged_unit_groups": [[3, 6, 10], [10, 14, 20]],  # one cell goes into at most one list
     "removed_units": [31, 42],  # Can not be in the merged_units
+    "format_version": 1
 }
@@ -99,6 +102,7 @@
     ],
     "merged_unit_groups": [[3, 6], [10, 14, 20]],  # one cell goes into at most one list
     "removed_units": [3, 31, 42],  # Can not be in the merged_units
+    "format_version": 1
 }
@@ -119,6 +123,7 @@
     ],
     "merged_unit_groups": [[3, 6, 99], [10, 14, 20]],  # one cell goes into at most one list
     "removed_units": [31, 42],  # Can not be in the merged_units
+    "format_version": 1
 }
@@ -139,6 +144,7 @@
     ],
     "merged_unit_groups": [[3, 6], [10, 14, 20]],  # one cell goes into at most one list
     "removed_units": [31, 42, 99],  # Can not be in the merged_units
+    "format_version": 1
 }
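A minimal usage sketch of the converter added above (the payload is illustrative, not a file from the repository; key names match the patch-018 state of the format):

    from spikeinterface.curation.curation_format import convert_from_sortingview

    sv_dict = {
        "mergeGroups": [[1, 2]],
        "labelsByUnit": {"1": ["accept"], "3": ["noise", "artifact"]},
    }
    cur = convert_from_sortingview(sv_dict)
    assert cur["unit_ids"] is None  # sortingview does not store the unit list
    assert cur["merged_unit_groups"] == [[1, 2]]
    assert cur["manual_labels"][0] == {
        "unit_id": 1, "label_category_key": "all_labels", "label_category_value": ["accept"]
    }
    # Note that validate_curation_dict() cannot run on this result as-is:
    # it calls set(curation_dict["unit_ids"]), and unit_ids is None here.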
From 8094053712b3736dc4c703ba7dd2fc499135cc80 Mon Sep 17 00:00:00 2001
From: r_pr
Date: Thu, 30 May 2024 15:54:16 +0200
Subject: [PATCH 019/136] Renaming curation dictionary keys

Took 10 minutes
---
 .../curation/curation_format.py            | 22 ++---
 .../curation/tests/test_curation_format.py | 94 +++++++++----------
 2 files changed, 58 insertions(+), 58 deletions(-)

diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py
index 1857b970a6..a086d3e963 100644
--- a/src/spikeinterface/curation/curation_format.py
+++ b/src/spikeinterface/curation/curation_format.py
@@ -37,13 +37,13 @@ def validate_curation_dict(curation_dict):
                          f"Only {supported_versions} are valid")
     # Check the labels exclusivity
     for lbl in curation_dict["manual_labels"]:
-        lbl_key = lbl["label_category_key"]
-        is_exclusive = curation_dict["labels_definition"][lbl_key]["auto_eclusive"]
-        if is_exclusive and not isinstance(lbl["label_category_value"], str):
-            raise ValueError(f"{lbl_key} are mutually exclusive labels. {lbl['label_category_value']} is invalid")
-        elif not is_exclusive and not isinstance(lbl["label_category_value"], list):
+        lbl_key = lbl["label_category"]
+        is_exclusive = curation_dict["label_definitions"][lbl_key]["auto_exclusive"]
+        if is_exclusive and not isinstance(lbl["labels"], str):
+            raise ValueError(f"{lbl_key} are mutually exclusive labels. {lbl['labels']} is invalid")
+        elif not is_exclusive and not isinstance(lbl["labels"], list):
             raise ValueError(f"{lbl_key} are not mutually exclusive labels. "
-                             f"{lbl['label_category_value']} should be a list")
+                             f"{lbl['labels']} should be a list")
     return True
@@ -82,15 +82,15 @@ def convert_from_sortingview(sortingview_dict, destination_format=1):
         all_labels.extend(l_labels)
         u_id = unit_id_type(unit_id)
         all_units.append(u_id)
-        manual_labels.append({'unit_id': u_id, "label_category_key": general_cat,
-                              "label_category_value": l_labels})
+        manual_labels.append({'unit_id': u_id, "label_category": general_cat,
+                              "labels": l_labels})
     labels_def = {"all_labels":
                       {"name": "all_labels",
-                       "labels": all_labels,
-                       "auto_eclusive": False}}
+                       "label_options": all_labels,
+                       "auto_exclusive": False}}
 
     curation_dict = {"unit_ids": None,
-                     "labels_definition": labels_def,
+                     "label_definitions": labels_def,
                      "manual_labels": manual_labels,
                      "merged_unit_groups": merge_groups,
                      "removed_units": [],
diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py
index 2263ced95c..a626538148 100644
--- a/src/spikeinterface/curation/tests/test_curation_format.py
+++ b/src/spikeinterface/curation/tests/test_curation_format.py
@@ -4,16 +4,16 @@
 
 """example = {
     'unit_ids': List[str, int],
-    'labels_definition': {
+    'label_definitions': {
         'category_key1':
             {'name': str,
-             'labels': List[str],
-             'auto_eclusive': bool}
+             'label_options': List[str],
+             'auto_exclusive': bool}
    },
    'manual_labels': [
        {'unit_id': str or int,
-        'label_category_key': str,
-        'label_category_value': list or str
+        'label_category': str,
+        'labels': list or str
        }
    ],
    'merged_unit_groups': List[List[unit_ids]],  # one cell goes into at most one list
@@ -23,18 +23,18 @@
 
 valid_int = {
     "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42],
-    "labels_definition": {
-        "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True},
+    "label_definitions": {
+        "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True},
         "experimental": {
             "name": "experimental",
-            "labels": ["acute", "chronic", "headfixed", "freelymoving"],
-            "auto_eclusive": False,
+            "label_options": ["acute", "chronic", "headfixed", "freelymoving"],
+            "auto_exclusive": False,
         },
     },
     "manual_labels": [
-        {"unit_id": 1, "label_category_key": "quality", "label_category_value": "good"},
-        {"unit_id": 2, "label_category_key": "quality", "label_category_value": "noise"},
-        {"unit_id": 2, "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]},
+        {"unit_id": 1, "label_category": "quality", "labels": "good"},
+        {"unit_id": 2, "label_category": "quality", "labels": "noise"},
+        {"unit_id": 2, "label_category": "experimental", "labels": ["chronic", "headfixed"]},
     ],
     "merged_unit_groups": [[3, 6], [10, 14, 20]],  # one cell goes into at most one list
     "removed_units":
[31, 42], # Can not be in the merged_units @@ -44,18 +44,18 @@ valid_str = { "unit_ids": ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"], - "labels_definition": { - "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True}, + "label_definitions": { + "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True}, "experimental": { "name": "experimental", - "labels": ["acute", "chronic", "headfixed", "freelymoving"], - "auto_eclusive": False, + "label_options": ["acute", "chronic", "headfixed", "freelymoving"], + "auto_exclusive": False, }, }, "manual_labels": [ - {"unit_id": "u1", "label_category_key": "quality", "label_category_value": "good"}, - {"unit_id": "u2", "label_category_key": "quality", "label_category_value": "noise"}, - {"unit_id": "u2", "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]}, + {"unit_id": "u1", "label_category": "quality", "labels": "good"}, + {"unit_id": "u2", "label_category": "quality", "labels": "noise"}, + {"unit_id": "u2", "label_category": "experimental", "labels": ["chronic", "headfixed"]}, ], "merged_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list "removed_units": ["u31", "u42"], # Can not be in the merged_units @@ -65,18 +65,18 @@ # This is a failure example duplicate_merge = { "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], - "labels_definition": { - "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True}, + "label_definitions": { + "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True}, "experimental": { "name": "experimental", - "labels": ["acute", "chronic", "headfixed", "freelymoving"], - "auto_eclusive": False, + "label_options": ["acute", "chronic", "headfixed", "freelymoving"], + "auto_exclusive": False, }, }, "manual_labels": [ - {"unit_id": 1, "label_category_key": "quality", "label_category_value": "good"}, - {"unit_id": 2, "label_category_key": "quality", "label_category_value": "noise"}, - {"unit_id": 2, "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]}, + {"unit_id": 1, "label_category": "quality", "labels": "good"}, + {"unit_id": 2, "label_category": "quality", "labels": "noise"}, + {"unit_id": 2, "label_category": "experimental", "labels": ["chronic", "headfixed"]}, ], "merged_unit_groups": [[3, 6, 10], [10, 14, 20]], # one cell goes into at most one list "removed_units": [31, 42], # Can not be in the merged_units @@ -87,18 +87,18 @@ # This is a failure example merged_and_removed = { "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], - "labels_definition": { - "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True}, + "label_definitions": { + "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True}, "experimental": { "name": "experimental", - "labels": ["acute", "chronic", "headfixed", "freelymoving"], - "auto_eclusive": False, + "label_options": ["acute", "chronic", "headfixed", "freelymoving"], + "auto_exclusive": False, }, }, "manual_labels": [ - {"unit_id": 1, "label_category_key": "quality", "label_category_value": "good"}, - {"unit_id": 2, "label_category_key": "quality", "label_category_value": "noise"}, - {"unit_id": 2, "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]}, + {"unit_id": 1, 
"label_category": "quality", "labels": "good"}, + {"unit_id": 2, "label_category": "quality", "labels": "noise"}, + {"unit_id": 2, "label_category": "experimental", "labels": ["chronic", "headfixed"]}, ], "merged_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list "removed_units": [3, 31, 42], # Can not be in the merged_units @@ -108,18 +108,18 @@ unknown_merged_unit = { "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], - "labels_definition": { - "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True}, + "label_definitions": { + "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True}, "experimental": { "name": "experimental", - "labels": ["acute", "chronic", "headfixed", "freelymoving"], - "auto_eclusive": False, + "label_options": ["acute", "chronic", "headfixed", "freelymoving"], + "auto_exclusive": False, }, }, "manual_labels": [ - {"unit_id": 1, "label_category_key": "quality", "label_category_value": "good"}, - {"unit_id": 2, "label_category_key": "quality", "label_category_value": "noise"}, - {"unit_id": 2, "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]}, + {"unit_id": 1, "label_category": "quality", "labels": "good"}, + {"unit_id": 2, "label_category": "quality", "labels": "noise"}, + {"unit_id": 2, "label_category": "experimental", "labels": ["chronic", "headfixed"]}, ], "merged_unit_groups": [[3, 6, 99], [10, 14, 20]], # one cell goes into at most one list "removed_units": [31, 42], # Can not be in the merged_units @@ -129,18 +129,18 @@ unknown_removed_unit = { "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], - "labels_definition": { - "quality": {"name": "quality", "labels": ["good", "noise", "MUA", "artifact"], "auto_eclusive": True}, + "label_definitions": { + "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True}, "experimental": { "name": "experimental", - "labels": ["acute", "chronic", "headfixed", "freelymoving"], - "auto_eclusive": False, + "label_options": ["acute", "chronic", "headfixed", "freelymoving"], + "auto_exclusive": False, }, }, "manual_labels": [ - {"unit_id": 1, "label_category_key": "quality", "label_category_value": "good"}, - {"unit_id": 2, "label_category_key": "quality", "label_category_value": "noise"}, - {"unit_id": 2, "label_category_key": "experimental", "label_category_value": ["chronic", "headfixed"]}, + {"unit_id": 1, "label_category": "quality", "labels": "good"}, + {"unit_id": 2, "label_category": "quality", "labels": "noise"}, + {"unit_id": 2, "label_category": "experimental", "labels": ["chronic", "headfixed"]}, ], "merged_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list "removed_units": [31, 42, 99], # Can not be in the merged_units From 5983194a372541380ccbac3031d8a0cd7a13b625 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 May 2024 13:54:45 +0000 Subject: [PATCH 020/136] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/curation_format.py | 37 ++++++++----------- .../curation/tests/test_curation_format.py | 12 +++--- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index a086d3e963..43c7181baf 100644 --- a/src/spikeinterface/curation/curation_format.py +++ 
b/src/spikeinterface/curation/curation_format.py
@@ -33,8 +33,9 @@ def validate_curation_dict(curation_dict):
     if len(removed_units_set.intersection(merged_units_set)) != 0:
         raise ValueError("Some units were merged and deleted")
     if curation_dict["format_version"] not in supported_versions:
-        raise ValueError(f"Format version ({curation_dict['format_version']}) not supported. "
-                         f"Only {supported_versions} are valid")
+        raise ValueError(
+            f"Format version ({curation_dict['format_version']}) not supported. " f"Only {supported_versions} are valid"
+        )
     # Check the labels exclusivity
     for lbl in curation_dict["manual_labels"]:
         lbl_key = lbl["label_category"]
@@ -42,8 +43,7 @@ def validate_curation_dict(curation_dict):
         if is_exclusive and not isinstance(lbl["labels"], str):
             raise ValueError(f"{lbl_key} are mutually exclusive labels. {lbl['labels']} is invalid")
         elif not is_exclusive and not isinstance(lbl["labels"], list):
-            raise ValueError(f"{lbl_key} are not mutually exclusive labels. "
-                             f"{lbl['labels']} should be a list")
+            raise ValueError(f"{lbl_key} are not mutually exclusive labels. " f"{lbl['labels']} should be a list")
     return True
@@ -82,29 +82,24 @@ def convert_from_sortingview(sortingview_dict, destination_format=1):
         all_labels.extend(l_labels)
         u_id = unit_id_type(unit_id)
         all_units.append(u_id)
-        manual_labels.append({'unit_id': u_id, "label_category": general_cat,
-                              "labels": l_labels})
-    labels_def = {"all_labels":
-                      {"name": "all_labels",
-                       "label_options": all_labels,
-                       "auto_exclusive": False}}
-
-    curation_dict = {"unit_ids": None,
-                     "label_definitions": labels_def,
-                     "manual_labels": manual_labels,
-                     "merged_unit_groups": merge_groups,
-                     "removed_units": [],
-                     "format_version": destination_format}
+        manual_labels.append({"unit_id": u_id, "label_category": general_cat, "labels": l_labels})
+    labels_def = {"all_labels": {"name": "all_labels", "label_options": all_labels, "auto_exclusive": False}}
+
+    curation_dict = {
+        "unit_ids": None,
+        "label_definitions": labels_def,
+        "manual_labels": manual_labels,
+        "merged_unit_groups": merge_groups,
+        "removed_units": [],
+        "format_version": destination_format,
+    }
 
     return curation_dict
 
 
 if __name__ == "__main__":
     import json
+
     with open("src/spikeinterface/curation/tests/sv-sorting-curation-str.json") as jf:
         sv_curation = json.load(jf)
     cur_d = convert_from_sortingview(sortingview_dict=sv_curation)
-
-
-
-
diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py
index a626538148..92fc963cef 100644
--- a/src/spikeinterface/curation/tests/test_curation_format.py
+++ b/src/spikeinterface/curation/tests/test_curation_format.py
@@ -38,7 +38,7 @@
     ],
     "merged_unit_groups": [[3, 6], [10, 14, 20]],  # one cell goes into at most one list
     "removed_units": [31, 42],  # Can not be in the merged_units
-    "format_version": 1
+    "format_version": 1,
 }
@@ -59,7 +59,7 @@
     ],
     "merged_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]],  # one cell goes into at most one list
     "removed_units": ["u31", "u42"],  # Can not be in the merged_units
-    "format_version": 1
+    "format_version": 1,
 }
@@ -80,7 +80,7 @@
     ],
     "merged_unit_groups": [[3, 6, 10], [10, 14, 20]],  # one cell goes into at most one list
     "removed_units": [31, 42],  # Can not be in the merged_units
-    "format_version": 1
+    "format_version": 1,
 }
@@ -102,7 +102,7 @@
     ],
     "merged_unit_groups": [[3, 6], [10, 14, 20]],  # one cell goes into at most one list
     "removed_units": [3, 31, 42],  # Can not be in the
merged_units - "format_version": 1 + "format_version": 1, } @@ -123,7 +123,7 @@ ], "merged_unit_groups": [[3, 6, 99], [10, 14, 20]], # one cell goes into at most one list "removed_units": [31, 42], # Can not be in the merged_units - "format_version": 1 + "format_version": 1, } @@ -144,7 +144,7 @@ ], "merged_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list "removed_units": [31, 42, 99], # Can not be in the merged_units - "format_version": 1 + "format_version": 1, } From d8ff88993e7a5b3de6e955361387f9fbb8f80be6 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 30 May 2024 15:46:46 +0100 Subject: [PATCH 021/136] Dtype handling --- .../sortingcomponents/motion_interpolation.py | 34 +++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 4a5a2b0c47..b080e098f8 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -43,7 +43,9 @@ def correct_motion_on_peaks( spike_times = peaks["sample_index"][i0:i1] / sampling_frequency spike_locs = peak_locations[motion.direction][i0:i1] - spike_displacement = motion.get_displacement_at_time_and_depth(spike_times, spike_locs, segment_index=segment_index) + spike_displacement = motion.get_displacement_at_time_and_depth( + spike_times, spike_locs, segment_index=segment_index + ) corrected_peak_locations[i0:i1][motion.direction] -= spike_displacement @@ -60,6 +62,7 @@ def interpolate_motion_on_traces( interpolation_time_bin_centers_s=None, spatial_interpolation_method="kriging", spatial_interpolation_kwargs={}, + dtype=None, ): """ Apply inverse motion with spatial interpolation on traces. @@ -98,6 +101,11 @@ def interpolate_motion_on_traces( # assert HAVE_NUMBA assert times.shape[0] == traces.shape[0] + if dtype is None: + dtype = traces.dtype + if dtype.kind != "f": + raise ValueError(f"Can't interpolate traces of dtype {traces.dtype}.") + if segment_index is None: if motion.num_segments == 1: segment_index = 0 @@ -147,7 +155,7 @@ def interpolate_motion_on_traces( drift_kernel = get_spatial_interpolation_kernel( channel_locations, channel_locations_moved, - dtype=traces.dtype, + dtype=dtype, method=spatial_interpolation_method, **spatial_interpolation_kwargs, ) @@ -160,7 +168,9 @@ def interpolate_motion_on_traces( # plt.show() # quickly find the end of this bin, which is also the start of the next - next_start_index = current_start_index + np.searchsorted(bin_inds[current_start_index:], bin_ind + 1, side="left") + next_start_index = current_start_index + np.searchsorted( + bin_inds[current_start_index:], bin_ind + 1, side="left" + ) in_bin = slice(current_start_index, next_start_index) # here we use a simple np.matmul even if dirft_kernel can be super sparse. @@ -292,7 +302,8 @@ def __init__( channel_locations = recording.get_channel_locations() assert channel_locations.ndim >= motion.dim, ( - f"'direction' {motion.direction} not available. " f"Channel locations have {channel_locations.ndim} dimensions." + f"'direction' {motion.direction} not available. " + f"Channel locations have {channel_locations.ndim} dimensions." 
) spatial_interpolation_kwargs = dict(sigma_um=sigma_um, p=p, num_closest=num_closest) if border_mode == "remove_channels": @@ -327,8 +338,13 @@ def __init__( else: raise ValueError("Wrong border_mode") - if dtype is None and recording.dtype.kind != "f": - dtype = "float32" + if dtype is None: + if recording.dtype.kind == "f": + dtype = recording.dtype + else: + raise ValueError( + f"Can't interpolate traces of recording with non-floating dtype={recording.dtype=}.") + dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) @@ -352,7 +368,7 @@ def __init__( # in this case, interpolation_time_bin_size_s is set. s_end = parent_segment.get_num_samples() t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) - halfbin = interpolation_time_bin_size_s / 2. + halfbin = interpolation_time_bin_size_s / 2.0 segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) else: segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index] @@ -407,9 +423,7 @@ def __init__( def get_traces(self, start_frame, end_frame, channel_indices): if self.has_time_vector(): - raise NotImplementedError( - "InterpolateMotionRecording does not yet support recordings with time_vectors." - ) + raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.") if start_frame is None: start_frame = 0 From 39f264313d099607addfe95662663b63a3adb8e6 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 30 May 2024 15:48:30 +0100 Subject: [PATCH 022/136] Dtype handling --- .../sortingcomponents/motion_interpolation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index b080e098f8..229152f4fa 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -103,8 +103,10 @@ def interpolate_motion_on_traces( if dtype is None: dtype = traces.dtype - if dtype.kind != "f": - raise ValueError(f"Can't interpolate traces of dtype {traces.dtype}.") + if dtype.kind != "f": + raise ValueError(f"Can't interpolate_motion with dtype {traces.dtype}.") + if traces.dtype != dtype: + traces = traces.astype(dtype) if segment_index is None: if motion.num_segments == 1: From 5f1b2a02cae462f0e273b7a74cd649b6a69f70b8 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 30 May 2024 15:48:58 +0100 Subject: [PATCH 023/136] Dtype handling --- src/spikeinterface/sortingcomponents/motion_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 229152f4fa..b9e11bc9bc 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -104,7 +104,7 @@ def interpolate_motion_on_traces( if dtype is None: dtype = traces.dtype if dtype.kind != "f": - raise ValueError(f"Can't interpolate_motion with dtype {traces.dtype}.") + raise ValueError(f"Can't interpolate_motion with dtype {dtype}.") if traces.dtype != dtype: traces = traces.astype(dtype) From 33e39e4da406d00700111be8f4804d5902cc9297 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 30 May 2024 16:28:11 +0100 Subject: [PATCH 024/136] Add a simple correctness 
test --- .../preprocessing/preprocessing_tools.py | 4 +- .../sortingcomponents/motion_interpolation.py | 16 ++++--- .../tests/test_motion_interpolation.py | 42 ++++++++++++------- 3 files changed, 40 insertions(+), 22 deletions(-) diff --git a/src/spikeinterface/preprocessing/preprocessing_tools.py b/src/spikeinterface/preprocessing/preprocessing_tools.py index c0b80c349b..942478fd71 100644 --- a/src/spikeinterface/preprocessing/preprocessing_tools.py +++ b/src/spikeinterface/preprocessing/preprocessing_tools.py @@ -80,7 +80,7 @@ def get_spatial_interpolation_kernel( elif method == "idw": distances = scipy.spatial.distance.cdist(source_location, target_location, metric="euclidean") - interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype="float64") + interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype=dtype) for c in range(target_location.shape[0]): ind_sorted = np.argsort(distances[:, c]) chan_closest = ind_sorted[:num_closest] @@ -97,7 +97,7 @@ def get_spatial_interpolation_kernel( elif method == "nearest": distances = scipy.spatial.distance.cdist(source_location, target_location, metric="euclidean") - interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype="float64") + interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype=dtype) for c in range(target_location.shape[0]): ind_closest = np.argmin(distances[:, c]) interpolation_kernel[ind_closest, c] = 1.0 diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index b9e11bc9bc..cbc24c83c3 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -6,7 +6,7 @@ from spikeinterface.preprocessing.basepreprocessor import ( BasePreprocessor, BasePreprocessorSegment) -from .filter import fix_dtype +from ..preprocessing.filter import fix_dtype def correct_motion_on_peaks( @@ -126,10 +126,11 @@ def interpolate_motion_on_traces( time_bins = interpolation_time_bin_centers_s if time_bins is None: time_bins = motion.temporal_bins_s[segment_index] - bin_s = time_bins[1] - time_bins + bin_s = time_bins[1] - time_bins[0] bins_start = time_bins[0] - 0.5 * bin_s # nearest bin center for each frame? 
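    # A worked instance of the lookup below (illustration, not part of the
    # patch): for bin centers time_bins = [0.5, 1.5, 2.5] (s), bin_s = 1.0 and
    # bins_start = 0.0, so a frame at t = 1.9 s gives (1.9 - 0.0) // 1.0 = 1.0,
    # i.e. the second bin; np.clip then keeps the indices in the valid range.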
bin_inds = (times - bins_start) // bin_s + bin_inds = bin_inds.astype(int) # the time bins may not cover the whole set of times in the recording, # so we need to clip these indices to the valid range np.clip(bin_inds, 0, time_bins.size, out=bin_inds) @@ -145,7 +146,7 @@ def interpolate_motion_on_traces( interp_times.fill(bin_time) channel_motions = motion.get_displacement_at_time_and_depth( interp_times, - channel_locations[motion.dim], + channel_locations[:, motion.dim], segment_index=segment_index, ) channel_locations_moved = channel_locations.copy() @@ -316,13 +317,14 @@ def __init__( channel_inside = np.ones(locs.shape[0], dtype="bool") for segment_index in range(recording.get_num_segments()): # evaluate the positions of all channels over all time bins - channel_locations = motion.get_displacement_at_time_and_depth( + channel_displacements = motion.get_displacement_at_time_and_depth( times_s=motion.temporal_bins_s[segment_index], locations_um=locs, grid=True, ) + channel_locations_moved = locs[:, None] + channel_displacements # check if these remain inside of the probe - seg_inside = channel_locations.clip(l0, l1) == channel_locations + seg_inside = channel_locations_moved.clip(l0, l1) == channel_locations_moved seg_inside = seg_inside.all(axis=1) channel_inside &= seg_inside @@ -422,9 +424,10 @@ def __init__( self.segment_index = segment_index self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s self.dtype = dtype + self.motion = motion def get_traces(self, start_frame, end_frame, channel_indices): - if self.has_time_vector(): + if self.time_vector is not None: raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.") if start_frame is None: @@ -441,6 +444,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): self.channel_locations, self.motion, channel_inds=self.channel_inds, + spatial_interpolation_method=self.spatial_interpolation_method, spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s, segment_index=self.segment_index, diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index 47f61f9ad6..b97040a740 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -1,19 +1,15 @@ -import pytest from pathlib import Path -import numpy as np +import numpy as np +import pytest +import spikeinterface.core as sc from spikeinterface import download_dataset - -from spikeinterface.sortingcomponents.motion_utils import Motion from spikeinterface.sortingcomponents.motion_interpolation import ( - correct_motion_on_peaks, - interpolate_motion_on_traces, - InterpolateMotionRecording, -) - + InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, + interpolate_motion_on_traces) +from spikeinterface.sortingcomponents.motion_utils import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset - if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "sortingcomponents" else: @@ -26,10 +22,10 @@ def make_fake_motion(rec): locs = rec.get_channel_locations() temporal_bins = np.arange(0.5, duration - 0.49, 0.5) spatial_bins = np.arange(locs[:, 1].min(), locs[:, 1].max(), 100) - displacament = np.zeros((temporal_bins.size, spatial_bins.size)) - displacament[:, :] = 
np.linspace(-30, 30, temporal_bins.size)[:, None] + displacement = np.zeros((temporal_bins.size, spatial_bins.size)) + displacement[:, :] = np.linspace(-30, 30, temporal_bins.size)[:, None] - motion = Motion([displacament], [temporal_bins], spatial_bins, direction="y") + motion = Motion([displacement], [temporal_bins], spatial_bins, direction="y") return motion @@ -62,7 +58,6 @@ def test_correct_motion_on_peaks(): # plt.show() - def test_interpolate_motion_on_traces(): rec, sorting = make_dataset() @@ -88,6 +83,24 @@ def test_interpolate_motion_on_traces(): assert traces.dtype == traces_corrected.dtype +def test_interpolation_simple(): + # a recording where a 1 moves at 1 chan per second. 30 chans 10 frames. + # there will be 9 chans of drift, so we add 9 chans of padding to the bottom + nt = nc0 = 10 # these need to be the same for this test + nc1 = nc0 + nc0 - 1 + traces = np.zeros((nt, nc1), dtype="float32") + traces[:, :nc0] = np.eye(nc0) + rec = sc.NumpyRecording(traces, sampling_frequency=1) + rec.set_dummy_probe_from_locations(np.c_[np.zeros(nc1), np.arange(nc1)]) + + true_motion = Motion(np.arange(nt)[:, None], 0.5 + np.arange(nt), np.zeros(1)) + rec_corrected = interpolate_motion(rec, true_motion, spatial_interpolation_method="nearest") + traces_corrected = rec_corrected.get_traces() + assert traces_corrected.shape == (nc0, nc0) + assert np.array_equal(traces_corrected[:, 0], np.ones(nt)) + assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1))) + + def test_InterpolateMotionRecording(): rec, sorting = make_dataset() motion = make_fake_motion(rec) @@ -121,4 +134,5 @@ def test_InterpolateMotionRecording(): if __name__ == "__main__": # test_correct_motion_on_peaks() # test_interpolate_motion_on_traces() + test_interpolation_simple() test_InterpolateMotionRecording() From 221afde68b3f22587a5483e7b642804c8f7599d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 May 2024 14:45:13 +0000 Subject: [PATCH 025/136] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/motion.py | 6 ++---- .../preprocessing/tests/test_motion.py | 2 +- .../sortingcomponents/motion_interpolation.py | 6 ++---- .../sortingcomponents/tests/common.py | 2 -- .../tests/test_motion_estimation.py | 2 +- .../tests/test_motion_interpolation.py | 7 +++++-- .../sortingcomponents/tests/test_motion_utils.py | 16 +++++++--------- 7 files changed, 18 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index a5300ccadc..8b89e2f545 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -384,9 +384,7 @@ def correct_motion( t1 = time.perf_counter() run_times["estimate_motion"] = t1 - t0 - recording_corrected = InterpolateMotionRecording( - recording, motion, **interpolate_motion_kwargs - ) + recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs) if folder is not None: (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") @@ -434,7 +432,7 @@ def load_motion_info(folder): motion_info[name] = np.load(folder / f"{name}.npy") else: motion_info[name] = None - + motion_info["motion"] = Motion.load(folder / "motion") return motion_info diff --git a/src/spikeinterface/preprocessing/tests/test_motion.py b/src/spikeinterface/preprocessing/tests/test_motion.py index 
d678b2d565..f42a64b90b 100644 --- a/src/spikeinterface/preprocessing/tests/test_motion.py +++ b/src/spikeinterface/preprocessing/tests/test_motion.py @@ -25,7 +25,7 @@ def test_estimate_and_correct_motion(): folder = cache_folder / "estimate_and_correct_motion" if folder.exists(): shutil.rmtree(folder) - + rec_corrected = correct_motion(rec, folder=folder) print(rec_corrected) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index cbc24c83c3..d0bbbddd71 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -3,8 +3,7 @@ import numpy as np from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.preprocessing import get_spatial_interpolation_kernel -from spikeinterface.preprocessing.basepreprocessor import ( - BasePreprocessor, BasePreprocessorSegment) +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..preprocessing.filter import fix_dtype @@ -346,8 +345,7 @@ def __init__( if recording.dtype.kind == "f": dtype = recording.dtype else: - raise ValueError( - f"Can't interpolate traces of recording with non-floating dtype={recording.dtype=}.") + raise ValueError(f"Can't interpolate traces of recording with non-floating dtype={recording.dtype=}.") dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) diff --git a/src/spikeinterface/sortingcomponents/tests/common.py b/src/spikeinterface/sortingcomponents/tests/common.py index 84d532d3aa..01e4445a13 100644 --- a/src/spikeinterface/sortingcomponents/tests/common.py +++ b/src/spikeinterface/sortingcomponents/tests/common.py @@ -3,7 +3,6 @@ from spikeinterface.core import generate_ground_truth_recording - def make_dataset(): # this replace the MEArec 10s file for testing recording, sorting = generate_ground_truth_recording( @@ -23,4 +22,3 @@ def make_dataset(): seed=2205, ) return recording, sorting - diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index e842d876a2..945aa6a09e 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -200,7 +200,7 @@ def test_estimate_motion(): # same params with differents engine should be the same motion0, motion1 = motions["rigid / decentralized / torch"], motions["rigid / decentralized / numpy"] - assert (motion0 == motion1) + assert motion0 == motion1 motion0, motion1 = ( motions["rigid / decentralized / torch / time_horizon_s"], diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index b97040a740..1de0337ec0 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -5,8 +5,11 @@ import spikeinterface.core as sc from spikeinterface import download_dataset from spikeinterface.sortingcomponents.motion_interpolation import ( - InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, - interpolate_motion_on_traces) + InterpolateMotionRecording, + correct_motion_on_peaks, + interpolate_motion, + interpolate_motion_on_traces, +) from 
spikeinterface.sortingcomponents.motion_utils import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py index a170245d7d..8a62ef324b 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -14,15 +14,13 @@ def test_Motion(): - temporal_bins_s = np.arange(0., 10., 1.) - spatial_bins_um = np.array([100., 200.]) + temporal_bins_s = np.arange(0.0, 10.0, 1.0) + spatial_bins_um = np.array([100.0, 200.0]) displacement = np.zeros((temporal_bins_s.shape[0], spatial_bins_um.shape[0])) displacement[:, :] = np.linspace(-20, 20, temporal_bins_s.shape[0])[:, np.newaxis] - motion = Motion( - displacement, temporal_bins_s, spatial_bins_um, direction="y" - ) + motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y") print(motion) # serialize with pickle before interpolation fit @@ -40,16 +38,16 @@ def test_Motion(): assert motion2.interpolator is None # do interpolate - displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11], [120., 80., 150.]) + displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11], [120.0, 80.0, 150.0]) # print(displacement) assert displacement.shape[0] == 3 # check clip - assert displacement[2] == 20. + assert displacement[2] == 20.0 # interpolate grid - displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, 15, 19], [150., 80.], grid=True) + displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, 15, 19], [150.0, 80.0], grid=True) assert displacement.shape == (2, 5) - assert displacement[0, 2] == 20. + assert displacement[0, 2] == 20.0 # save/load to folder folder = cache_folder / "motion_saved" From 39fd14555c51a2e2510616d88953c039dcc27b76 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 31 May 2024 11:29:37 -0400 Subject: [PATCH 026/136] Motion est/tests --- src/spikeinterface/preprocessing/motion.py | 8 +-- .../preprocessing/tests/test_motion.py | 9 +-- .../sortingcomponents/motion_utils.py | 8 +-- .../tests/test_motion_estimation.py | 55 ++++++++----------- .../tests/test_motion_utils.py | 8 +-- 5 files changed, 36 insertions(+), 52 deletions(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 8b89e2f545..9af21a76f2 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -1,15 +1,13 @@ from __future__ import annotations +import json import time from pathlib import Path import numpy as np -import json -import copy - -from spikeinterface.core import get_noise_levels, fix_job_kwargs -from spikeinterface.core.job_tools import _shared_job_kwargs_doc +from spikeinterface.core import fix_job_kwargs, get_noise_levels from spikeinterface.core.core_tools import SIJsonEncoder +from spikeinterface.core.job_tools import _shared_job_kwargs_doc motion_options_preset = { # This preset should be the most acccurate diff --git a/src/spikeinterface/preprocessing/tests/test_motion.py b/src/spikeinterface/preprocessing/tests/test_motion.py index f42a64b90b..2f045b7a68 100644 --- a/src/spikeinterface/preprocessing/tests/test_motion.py +++ b/src/spikeinterface/preprocessing/tests/test_motion.py @@ -1,14 +1,11 @@ -import pytest -from pathlib import Path - import shutil +from pathlib import Path +import numpy as np +import pytest from 
spikeinterface.core import generate_recording - from spikeinterface.preprocessing import correct_motion, load_motion_info -import numpy as np - if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "preprocessing" else: diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index 71cde08689..9537b5bf1c 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -12,10 +12,10 @@ # * make simple test for Motion object with save/load DONE # * propagate to estimate_motion : DONE # * handle multi segment in estimate_motion(): maybe in another PR -# * propagate to motion_interpolation.py: ALMOST DONE -# * propagate to preprocessing/correct_motion(): ALMOST DONE -# * generate drifting signals for test estimate_motion and interpolate_motion -# * uncomment assert in test_estimate_motion (aka debug torch vs numpy diff) +# * propagate to motion_interpolation.py: DONE +# * propagate to preprocessing/correct_motion(): DONE +# * generate drifting signals for test estimate_motion and interpolate_motion: SIMPLE ONE DONE? +# * uncomment assert in test_estimate_motion (aka debug torch vs numpy diff): DONE # * delegate times to recording object in # * estimate motion # * correct_motion_on_peaks() diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index 945aa6a09e..87534ec1bf 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -1,18 +1,15 @@ -import pytest -from pathlib import Path import shutil +from pathlib import Path import numpy as np - -from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.sortingcomponents.motion_estimation import estimate_motion - - -from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording +import pytest from spikeinterface.core.node_pipeline import ExtractDenseWaveforms - -from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass - +from spikeinterface.sortingcomponents.motion_estimation import estimate_motion +from spikeinterface.sortingcomponents.motion_interpolation import \ + InterpolateMotionRecording +from spikeinterface.sortingcomponents.peak_detection import detect_peaks +from spikeinterface.sortingcomponents.peak_localization import \ + LocalizeCenterOfMass from spikeinterface.sortingcomponents.tests.common import make_dataset if hasattr(pytest, "global_test_folder"): @@ -199,33 +196,25 @@ def test_estimate_motion(): plt.show() # same params with differents engine should be the same - motion0, motion1 = motions["rigid / decentralized / torch"], motions["rigid / decentralized / numpy"] + motion0 = motions["rigid / decentralized / torch"] + motion1 = motions["rigid / decentralized / numpy"] assert motion0 == motion1 - motion0, motion1 = ( - motions["rigid / decentralized / torch / time_horizon_s"], - motions["rigid / decentralized / numpy / time_horizon_s"], - ) - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["rigid / decentralized / torch / time_horizon_s"] + motion1 = motions["rigid / decentralized / numpy / time_horizon_s"], + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - motion0, 
motion1 = motions["non-rigid / decentralized / torch"], motions["non-rigid / decentralized / numpy"] - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["non-rigid / decentralized / torch"] + motion1 = motions["non-rigid / decentralized / numpy"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - motion0, motion1 = ( - motions["non-rigid / decentralized / torch / time_horizon_s"], - motions["non-rigid / decentralized / numpy / time_horizon_s"], - ) - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["non-rigid / decentralized / torch / time_horizon_s"] + motion1 = motions["non-rigid / decentralized / numpy / time_horizon_s"], + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - motion0, motion1 = ( - motions["non-rigid / decentralized / torch / spatial_prior"], - motions["non-rigid / decentralized / numpy / spatial_prior"], - ) - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["non-rigid / decentralized / torch / spatial_prior"] + motion1 = motions["non-rigid / decentralized / numpy / spatial_prior"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) if __name__ == "__main__": diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py index 8a62ef324b..2fbbea0a25 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -25,17 +25,17 @@ def test_Motion(): # serialize with pickle before interpolation fit motion2 = pickle.loads(pickle.dumps(motion)) - assert motion2.interpolator is None + assert motion2.interpolators is None # serialize with pickle after interpolation fit motion.make_interpolators() - assert motion2.interpolator is not None + assert motion2.interpolators is not None motion2 = pickle.loads(pickle.dumps(motion)) - assert motion2.interpolator is not None + assert motion2.interpolators is not None # to/from dict motion2 = Motion(**motion.to_dict()) assert motion == motion2 - assert motion2.interpolator is None + assert motion2.interpolators is None # do interpolate displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11], [120.0, 80.0, 150.0]) From 63b851c6842c61c448d151d3049e77faaaebd75e Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 31 May 2024 11:38:18 -0400 Subject: [PATCH 027/136] Delegate to sample_index_to_time() in estimation --- .../sortingcomponents/motion_estimation.py | 28 +++++++++---------- .../sortingcomponents/motion_utils.py | 2 +- .../tests/test_motion_estimation.py | 4 +-- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index 3a8b75f8b3..bede0a19bb 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -683,16 +683,15 @@ def make_2d_motion_histogram( spatial_bin_edges 1d array with spatial bin edges """ - fs = recording.get_sampling_frequency() - num_samples = recording.get_num_samples(segment_index=0) - bin_sample_size = int(bin_duration_s * fs) - sample_bin_edges = np.arange(0, num_samples + bin_sample_size, 
bin_sample_size)
-    temporal_bin_edges = sample_bin_edges / fs
+    n_samples = recording.get_num_samples()
+    mint_s = recording.sample_index_to_time(0)
+    maxt_s = recording.sample_index_to_time(n_samples)
+    temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s)
     if spatial_bin_edges is None:
         spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um)
 
     arr = np.zeros((peaks.size, 2), dtype="float64")
-    arr[:, 0] = peaks["sample_index"]
+    arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"])
     arr[:, 1] = peak_locations[direction]
 
     if weight_with_amplitude:
@@ -700,11 +699,11 @@
     else:
         weights = None
 
-    motion_histogram, edges = np.histogramdd(arr, bins=(sample_bin_edges, spatial_bin_edges), weights=weights)
+    motion_histogram, edges = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges), weights=weights)
 
     # average amplitude in each bin
     if weight_with_amplitude:
-        bin_counts, _ = np.histogramdd(arr, bins=(sample_bin_edges, spatial_bin_edges))
+        bin_counts, _ = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges))
         bin_counts[bin_counts == 0] = 1
         motion_histogram = motion_histogram / bin_counts
 
@@ -759,11 +758,10 @@
     spatial_bin_edges
         1d array with spatial bin edges
     """
-    fs = recording.get_sampling_frequency()
-    num_samples = recording.get_num_samples(segment_index=0)
-    bin_sample_size = int(bin_duration_s * fs)
-    sample_bin_edges = np.arange(0, num_samples + bin_sample_size, bin_sample_size)
-    temporal_bin_edges = sample_bin_edges / fs
+    n_samples = recording.get_num_samples()
+    mint_s = recording.sample_index_to_time(0)
+    maxt_s = recording.sample_index_to_time(n_samples)
+    temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s)
     if spatial_bin_edges is None:
         spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um)
 
@@ -778,14 +776,14 @@
     )
 
     arr = np.zeros((peaks.size, 3), dtype="float64")
-    arr[:, 0] = peaks["sample_index"]
+    arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"])
     arr[:, 1] = peak_locations[direction]
     arr[:, 2] = abs_peaks_log_norm
 
     motion_histograms, edges = np.histogramdd(
         arr,
         bins=(
-            sample_bin_edges,
+            temporal_bin_edges,
             spatial_bin_edges,
             amplitude_bin_edges,
         ),
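A small sketch of the time-bin edges computed above (illustration, not part of the patch), assuming a recording whose sample_index_to_time() maps sample indices to seconds with an offset:

    import numpy as np

    fs, t0, n_samples, bin_duration_s = 100.0, 5.0, 1000, 2.0

    def sample_index_to_time(i):  # stand-in for the recording method
        return t0 + i / fs

    mint_s = sample_index_to_time(0)          # 5.0
    maxt_s = sample_index_to_time(n_samples)  # 15.0
    edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s)
    # edges -> [ 5.  7.  9. 11. 13. 15.]; unlike the old sample-count-based
    # edges, the bins follow the recording's own clock even when it starts
    # at t > 0.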
diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py
index 9537b5bf1c..0f19c2a2de 100644
--- a/src/spikeinterface/sortingcomponents/motion_utils.py
+++ b/src/spikeinterface/sortingcomponents/motion_utils.py
@@ -17,7 +17,7 @@
 # * generate drifting signals for test estimate_motion and interpolate_motion: SIMPLE ONE DONE?
 # * uncomment assert in test_estimate_motion (aka debug torch vs numpy diff): DONE
 # * delegate times to recording object in
-# * estimate motion
+# * estimate motion: DONE
 # * correct_motion_on_peaks()
 # * interpolate_motion_on_traces()
diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py
index 87534ec1bf..7eea4e0bdd 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py
@@ -201,7 +201,7 @@
     assert motion0 == motion1
 
     motion0 = motions["rigid / decentralized / torch / time_horizon_s"]
-    motion1 = motions["rigid / decentralized / numpy / time_horizon_s"],
+    motion1 = motions["rigid / decentralized / numpy / time_horizon_s"]
     np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement)
 
     motion0 = motions["non-rigid / decentralized / torch"]
@@ -209,7 +209,7 @@
     np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement)
 
     motion0 = motions["non-rigid / decentralized / torch / time_horizon_s"]
-    motion1 = motions["non-rigid / decentralized / numpy / time_horizon_s"],
+    motion1 = motions["non-rigid / decentralized / numpy / time_horizon_s"]
     np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement)
 
     motion0 = motions["non-rigid / decentralized / torch / spatial_prior"]

From d99d05bbd68e7582c917b6281674b6c5975ca3b9 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 31 May 2024 15:39:52 +0000
Subject: [PATCH 028/136] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../sortingcomponents/tests/test_motion_estimation.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py
index 7eea4e0bdd..d916102376 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py
@@ -5,11 +5,9 @@
 import pytest
 from spikeinterface.core.node_pipeline import ExtractDenseWaveforms
 from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
-from spikeinterface.sortingcomponents.motion_interpolation import \
-    InterpolateMotionRecording
+from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording
 from spikeinterface.sortingcomponents.peak_detection import detect_peaks
-from spikeinterface.sortingcomponents.peak_localization import \
-    LocalizeCenterOfMass
+from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass
 from spikeinterface.sortingcomponents.tests.common import make_dataset
 
 if hasattr(pytest, "global_test_folder"):

From d7b6a598a7e99b64053eaea1b686c9e81f2d1427 Mon Sep 17 00:00:00 2001
From: Charlie Windolf
Date: Fri, 31 May 2024 12:04:51 -0400
Subject: [PATCH 029/136] Add a test of time bin changing at interpolation time
---
 .../sortingcomponents/motion_interpolation.py | 10 +++++---
 .../tests/test_motion_estimation.py           |  6 ++---
 .../tests/test_motion_interpolation.py        | 23 +++++++++++++++----
 3 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py
b/src/spikeinterface/sortingcomponents/motion_interpolation.py index d0bbbddd71..889e89446d 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -3,7 +3,8 @@ import numpy as np from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.preprocessing import get_spatial_interpolation_kernel -from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from spikeinterface.preprocessing.basepreprocessor import ( + BasePreprocessor, BasePreprocessorSegment) from ..preprocessing.filter import fix_dtype @@ -285,7 +286,7 @@ class InterpolateMotionRecording(BasePreprocessor): Recording after motion correction """ - name = "correct_motion" + name = "interpolate_motion" def __init__( self, @@ -299,6 +300,7 @@ def __init__( interpolation_time_bin_centers_s=None, interpolation_time_bin_size_s=None, dtype=None, + **spatial_interpolation_kwargs, ): # assert recording.get_num_segments() == 1, "correct_motion() is only available for single-segment recordings" @@ -307,7 +309,9 @@ def __init__( f"'direction' {motion.direction} not available. " f"Channel locations have {channel_locations.ndim} dimensions." ) - spatial_interpolation_kwargs = dict(sigma_um=sigma_um, p=p, num_closest=num_closest) + spatial_interpolation_kwargs = dict( + sigma_um=sigma_um, p=p, num_closest=num_closest, **spatial_interpolation_kwargs + ) if border_mode == "remove_channels": locs = channel_locations[:, motion.dim] l0, l1 = np.min(locs), np.max(locs) diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index d916102376..88908c5cc4 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -1,13 +1,12 @@ -import shutil from pathlib import Path import numpy as np import pytest from spikeinterface.core.node_pipeline import ExtractDenseWaveforms from spikeinterface.sortingcomponents.motion_estimation import estimate_motion -from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass +from spikeinterface.sortingcomponents.peak_localization import \ + LocalizeCenterOfMass from spikeinterface.sortingcomponents.tests.common import make_dataset if hasattr(pytest, "global_test_folder"): @@ -153,7 +152,6 @@ def test_estimate_motion(): ) kwargs.update(cases_kwargs) - job_kwargs = dict(progress_bar=False) motion, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs) motions[name] = motion diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index 1de0337ec0..ffed3e72fc 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -5,11 +5,8 @@ import spikeinterface.core as sc from spikeinterface import download_dataset from spikeinterface.sortingcomponents.motion_interpolation import ( - InterpolateMotionRecording, - correct_motion_on_peaks, - interpolate_motion, - interpolate_motion_on_traces, -) + InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, + 
interpolate_motion_on_traces) from spikeinterface.sortingcomponents.motion_utils import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -103,6 +100,22 @@ def test_interpolation_simple(): assert np.array_equal(traces_corrected[:, 0], np.ones(nt)) assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1))) + # let's try a new version where we interpolate too slowly + rec_corrected = interpolate_motion( + rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2 + ) + traces_corrected = rec_corrected.get_traces() + assert traces_corrected.shape == (nc0, nc0) + # what happens with nearest here? + # well... due to rounding towards the nearest even number, the motion (which at + # these time bin centers is 0.5, 2.5, 4.5, ...) flips the signal's nearest + # neighbor back and forth between the first and second channels + assert np.all(traces_corrected[::2, 0] == 1) + assert np.all(traces_corrected[1::2, 0] == 0) + assert np.all(traces_corrected[1::2, 1] == 1) + assert np.all(traces_corrected[::2, 1] == 0) + assert np.all(traces_corrected[:, 2:] == 0) + def test_InterpolateMotionRecording(): rec, sorting = make_dataset() From 9907372a2934e71a29a6b641b56a431fa9c1340b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 May 2024 16:05:19 +0000 Subject: [PATCH 030/136] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/motion_interpolation.py | 3 +-- .../sortingcomponents/tests/test_motion_estimation.py | 3 +-- .../sortingcomponents/tests/test_motion_interpolation.py | 7 +++++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 889e89446d..1a827a7b5b 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -3,8 +3,7 @@ import numpy as np from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.preprocessing import get_spatial_interpolation_kernel -from spikeinterface.preprocessing.basepreprocessor import ( - BasePreprocessor, BasePreprocessorSegment) +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..preprocessing.filter import fix_dtype diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index 88908c5cc4..7c25bc8923 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -5,8 +5,7 @@ from spikeinterface.core.node_pipeline import ExtractDenseWaveforms from spikeinterface.sortingcomponents.motion_estimation import estimate_motion from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.sortingcomponents.peak_localization import \ - LocalizeCenterOfMass +from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass from spikeinterface.sortingcomponents.tests.common import make_dataset if hasattr(pytest, "global_test_folder"): diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index 
ffed3e72fc..3870517d5a 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -5,8 +5,11 @@ import spikeinterface.core as sc from spikeinterface import download_dataset from spikeinterface.sortingcomponents.motion_interpolation import ( - InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, - interpolate_motion_on_traces) + InterpolateMotionRecording, + correct_motion_on_peaks, + interpolate_motion, + interpolate_motion_on_traces, +) from spikeinterface.sortingcomponents.motion_utils import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset From 3387592c61663f1e4eed2f097521b5f18011d890 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 31 May 2024 12:28:11 -0400 Subject: [PATCH 031/136] Update correct_motion_on_peaks to take a recording and delegate to sample_index_to_time --- .../benchmark/benchmark_motion_estimation.py | 16 +++++++--------- .../sortingcomponents/motion_interpolation.py | 9 ++++----- .../tests/test_motion_interpolation.py | 2 +- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 5d3c9c207a..9df9fe34c3 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -1,22 +1,20 @@ from __future__ import annotations import json +import pickle import time from pathlib import Path -import pickle +import matplotlib.pyplot as plt import numpy as np import scipy.interpolate - from spikeinterface.core import get_noise_levels +from spikeinterface.sortingcomponents.benchmark.benchmark_tools import ( + Benchmark, BenchmarkStudy, _simpleaxis) +from spikeinterface.sortingcomponents.motion_estimation import estimate_motion from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks -from spikeinterface.sortingcomponents.motion_estimation import estimate_motion -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis - - -import matplotlib.pyplot as plt +from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.widgets import plot_probe_map # import MEArec as mr @@ -670,7 +668,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # peak_locations_corrected = correct_motion_on_peaks( # self.selected_peaks, # self.peak_locations, -# self.recording.sampling_frequency, +# self.recording, # self.motion, # self.temporal_bins, # self.spatial_bins, diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 889e89446d..c65e94ee9a 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -12,7 +12,7 @@ def correct_motion_on_peaks( peaks, peak_locations, - sampling_frequency, + rec, motion, ): """ @@ -35,18 +35,16 @@ def correct_motion_on_peaks( Motion-corrected peak locations """ corrected_peak_locations = peak_locations.copy() + times_s = rec.sample_index_to_time(peaks["sample_index"]) for 
segment_index in range(motion.num_segments): i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) - # TODO delegate times to recording object - spike_times = peaks["sample_index"][i0:i1] / sampling_frequency + spike_times = times_s[i0:i1] spike_locs = peak_locations[motion.direction][i0:i1] - spike_displacement = motion.get_displacement_at_time_and_depth( spike_times, spike_locs, segment_index=segment_index ) - corrected_peak_locations[i0:i1][motion.direction] -= spike_displacement return corrected_peak_locations @@ -403,6 +401,7 @@ def __init__( interpolation_time_bin_centers_s=interpolation_time_bin_centers_s, dtype=dtype_.str, ) + self._kwargs.update(spatial_interpolation_kwargs) class InterpolateMotionRecordingSegment(BasePreprocessorSegment): diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index ffed3e72fc..cbfaa8adfb 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -43,7 +43,7 @@ def test_correct_motion_on_peaks(): corrected_peak_locations = correct_motion_on_peaks( peaks, peak_locations, - rec.sampling_frequency, + rec, motion, ) # print(corrected_peak_locations) From 97bda65474dbc78d22776e59a11facfd8653cc6c Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 31 May 2024 12:28:46 -0400 Subject: [PATCH 032/136] Update todo list --- src/spikeinterface/sortingcomponents/motion_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index 0f19c2a2de..1edf484aa4 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -18,8 +18,8 @@ # * uncomment assert in test_estimate_motion (aka debug torch vs numpy diff): DONE # * delegate times to recording object in # * estimate motion: DONE -# * correct_motion_on_peaks() -# * interpolate_motion_on_traces() +# * correct_motion_on_peaks(): DONE +# * interpolate_motion_on_traces(): DONE # propagate to benchmark estimate motion # update plot_motion() dans widget From accdd0af708176d4495c9138f30251b6eb7dc86c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 May 2024 16:29:17 +0000 Subject: [PATCH 033/136] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/benchmark/benchmark_motion_estimation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 9df9fe34c3..4e8bf71044 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -9,8 +9,7 @@ import numpy as np import scipy.interpolate from spikeinterface.core import get_noise_levels -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import ( - Benchmark, BenchmarkStudy, _simpleaxis) +from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis from spikeinterface.sortingcomponents.motion_estimation import estimate_motion from 
spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks From dd0edc8d71e906f3ba94a50940c5c5932682a265 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 31 May 2024 13:57:24 -0400 Subject: [PATCH 034/136] Fix test --- .../sortingcomponents/tests/test_motion_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py index 2fbbea0a25..84dda89d0d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -21,15 +21,15 @@ def test_Motion(): displacement[:, :] = np.linspace(-20, 20, temporal_bins_s.shape[0])[:, np.newaxis] motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y") - print(motion) + assert motion.interpolators is None # serialize with pickle before interpolation fit motion2 = pickle.loads(pickle.dumps(motion)) assert motion2.interpolators is None # serialize with pickle after interpolation fit - motion.make_interpolators() + motion2.make_interpolators() assert motion2.interpolators is not None - motion2 = pickle.loads(pickle.dumps(motion)) + motion2 = pickle.loads(pickle.dumps(motion2)) assert motion2.interpolators is not None # to/from dict From abda223a4ee14e43d8a7cfe56a6b53aa972b0960 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sat, 1 Jun 2024 09:45:41 -0400 Subject: [PATCH 035/136] Clean/doc --- src/spikeinterface/preprocessing/motion.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 9af21a76f2..c2abf65692 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -275,11 +275,8 @@ def correct_motion( recording_corrected: Recording The motion corrected recording motion_info: dict - Optional output if `output_motion_info=True` + Optional output if `output_motion_info=True`. The key "motion" holds the Motion object. 
""" - - # TODO : Use motion object - # local import are important because "sortingcomponents" is not important by default from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods from spikeinterface.sortingcomponents.peak_selection import select_peaks From ab696e63065d82009ee949570a551583dfa9ffb3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 4 Jun 2024 10:15:03 +0200 Subject: [PATCH 036/136] Fix tests --- .../sortingcomponents/benchmark/benchmark_motion_estimation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 2d62547778..86428cf1ee 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -5,7 +5,6 @@ import pickle import time -import matplotlib.pyplot as plt import numpy as np from spikeinterface.core import get_noise_levels @@ -289,6 +288,7 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): ax.set_ylim(0, lim) def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)): + import matplotlib.pyplot as plt if case_keys is None: case_keys = list(self.cases.keys()) From 7cfc6f99b8b5d10cb06d18fc0858a14c92d25271 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 4 Jun 2024 10:18:21 +0200 Subject: [PATCH 037/136] Fix imports in tests --- .../sortingcomponents/benchmark/benchmark_motion_estimation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 86428cf1ee..7428629c4a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -10,7 +10,8 @@ from spikeinterface.core import get_noise_levels from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis from spikeinterface.sortingcomponents.motion_estimation import estimate_motion -from spikeinterface.sortingcomponents.peak_detection import detect_peaks, select_peaks +from spikeinterface.sortingcomponents.peak_detection import detect_peaks +from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks from spikeinterface.widgets import plot_probe_map From 64453c6f532e5bc88b8615ec87c7562a0d16dace Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 5 Jun 2024 14:07:51 +0200 Subject: [PATCH 038/136] Update on the curation format. --- doc/modules/curation.rst | 105 ++++++++++++ .../curation/curation_format.py | 50 ++++-- .../curation/tests/test_curation_format.py | 155 ++++++------------ 3 files changed, 187 insertions(+), 123 deletions(-) diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index 401ceea5dc..83410494ba 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -41,6 +41,111 @@ The merging and splitting operations are handled by the :py:class:`~spikeinterfa # here is the final clean sorting clean_sorting = cs.sorting +Manual curation format +---------------------- + +SpikeInterface internally support a manual curation format JSON based. 
+When a mnual curation is necessary, modifying in place a dataset is a bad practice.
+Instead, to keep the reproducibility in the spike sorting piepline, we introduce a manual curation format,
+simple and JSON based. This format defines at the moment : merges + deletions + manual tags.
+The simple file can be kept alongside the output of a sorter and applied on the result to have a "clean" result.
+
+This format has two parts:
+
+  * **definition** with the following keys:
+
+    * "format_version" : format specification
+    * "unit_ids" : give the list of unit_ids
+    * "label_definitions" : list of label category and possible labels per category.
+      Every category can be *exclusive=True* (only one label) or *exclusive=False* (several labels possible)
+
+  * **manual output** curation with the following keys:
+
+    * "manual_labels"
+    * "merged_unit_groups"
+    * "removed_units"
+
+Here the description of the format with a simple example:
+
+.. code-block:: json
+
+
+    {
+        # the first part of the format is the definition
+        "format_version": "1",
+        "unit_ids": [
+            "u1",
+            "u2",
+            "u3",
+            "u6",
+            "u10",
+            "u14",
+            "u20",
+            "u31",
+            "u42"
+        ],
+        "label_definitions": {
+            "quality": {
+                "name": "quality",
+                "label_options": [
+                    "good",
+                    "noise",
+                    "MUA",
+                    "artifact"
+                ],
+                "exclusive": true
+            },
+            "experimental": {
+                "name": "experimental",
+                "label_options": [
+                    "acute",
+                    "chronic",
+                    "headfixed",
+                    "freelymoving"
+                ],
+                "exclusive": false
+            }
+        },
+        # the second part of the format is the manual actions
+        "manual_labels": [
+            {
+                "unit_id": "u1",
+                "label_category": "quality",
+                "labels": "good"
+            },
+            {
+                "unit_id": "u2",
+                "label_category": "quality",
+                "labels": "noise"
+            },
+            {
+                "unit_id": "u2",
+                "label_category": "experimental",
+                "labels": [
+                    "chronic",
+                    "headfixed"
+                ]
+            }
+        ],
+        "merged_unit_groups": [
+            [
+                "u3",
+                "u6"
+            ],
+            [
+                "u10",
+                "u14",
+                "u20"
+            ]
+        ],
+        "removed_units": [
+            "u31",
+            "u42"
+        ]
+    }
+
+
+
 Automatic curation tools
 ------------------------

diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py
index 43c7181baf..d25efaa6d8 100644
--- a/src/spikeinterface/curation/curation_format.py
+++ b/src/spikeinterface/curation/curation_format.py
@@ -1,10 +1,15 @@
 from itertools import combinations


+supported_curation_format_versions = {"1"}
+
+
 def validate_curation_dict(curation_dict):
     """
     Validate that the curation dictionary given as parameter complies with the format

+    The function does not return anything. It raises an error if something is wrong in the format.
+
     Parameters
     ----------
     curation_dict : dict
@@ -12,39 +17,52 @@

     Returns
     -------
+    Nothing.
+
     """

-    supported_versions = {1}
+    # format
+    if "format_version" not in curation_dict:
+        raise ValueError("No format_version")
+
+    if curation_dict["format_version"] not in supported_curation_format_versions:
+        raise ValueError(
" f"Only {supported_curation_format_versions} are valid" + ) + + # unit_ids unit_set = set(curation_dict["unit_ids"]) labeled_unit_set = set([lbl["unit_id"] for lbl in curation_dict["manual_labels"]]) merged_units_set = set(sum(curation_dict["merged_unit_groups"], [])) removed_units_set = set(curation_dict["removed_units"]) if not labeled_unit_set.issubset(unit_set): - raise ValueError("Some labeled units are not in the unit list") + raise ValueError("Curation format: some labeled units are not in the unit list") if not merged_units_set.issubset(unit_set): - raise ValueError("Some merged units are not in the unit list") + raise ValueError("Curation format: some merged units are not in the unit list") if not removed_units_set.issubset(unit_set): - raise ValueError("Some removed units are not in the unit list") + raise ValueError("Curation format: some removed units are not in the unit list") + all_merging_groups = [set(group) for group in curation_dict["merged_unit_groups"]] for gp_1, gp_2 in combinations(all_merging_groups, 2): if len(gp_1.intersection(gp_2)) != 0: raise ValueError("Some units belong to multiple merge groups") if len(removed_units_set.intersection(merged_units_set)) != 0: raise ValueError("Some units were merged and deleted") - if curation_dict["format_version"] not in supported_versions: - raise ValueError( - f"Format version ({curation_dict['format_version']}) not supported. " f"Only {supported_versions} are valid" - ) + # Check the labels exclusivity for lbl in curation_dict["manual_labels"]: - lbl_key = lbl["label_category"] - is_exclusive = curation_dict["label_definitions"][lbl_key]["auto_exclusive"] - if is_exclusive and not isinstance(lbl["labels"], str): - raise ValueError(f"{lbl_key} are mutually exclusive labels. {lbl['labels']} is invalid") - elif not is_exclusive and not isinstance(lbl["labels"], list): - raise ValueError(f"{lbl_key} are not mutually exclusive labels. " f"{lbl['labels']} should be a lists") - return True + for label_key in curation_dict["label_definitions"].keys(): + if label_key in lbl: + unit_id = lbl["unit_id"] + label_value = lbl[label_key] + if not isinstance(label_value, list): + raise ValueError(f"Curation format: manual_labels {unit_id} is invalid shoudl be a list") + + is_exclusive = curation_dict["label_definitions"][label_key]["exclusive"] + + if is_exclusive and not len(label_value) <=1: + raise ValueError(f"Curation format: manual_labels {unit_id} {label_key} are exclusive labels. 
{label_value} is invalid") def convert_from_sortingview(sortingview_dict, destination_format=1): @@ -83,7 +101,7 @@ def convert_from_sortingview(sortingview_dict, destination_format=1): u_id = unit_id_type(unit_id) all_units.append(u_id) manual_labels.append({"unit_id": u_id, "label_category": general_cat, "labels": l_labels}) - labels_def = {"all_labels": {"name": "all_labels", "label_options": all_labels, "auto_exclusive": False}} + labels_def = {"all_labels": {"name": "all_labels", "label_options": all_labels, "exclusive": False}} curation_dict = { "unit_ids": None, diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 92fc963cef..6d3700b94a 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -1,6 +1,7 @@ from spikeinterface.curation.curation_format import validate_curation_dict import pytest +import json """example = { 'unit_ids': List[str, int], @@ -8,12 +9,11 @@ 'category_key1': {'name': str, 'label_options': List[str], - 'auto_exclusive': bool} + 'exclusive': bool} }, 'manual_labels': [ {'unit_id': str or int, - 'label_category': str, - 'labels': list or str + category_key1': List[str], } ], 'merged_unit_groups': List[List[unit_ids]], # one cell goes into at most one list @@ -21,136 +21,64 @@ } """ -valid_int = { + +curation_ids_int = { + "format_version": "1", "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], "label_definitions": { - "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True}, - "experimental": { - "name": "experimental", - "label_options": ["acute", "chronic", "headfixed", "freelymoving"], - "auto_exclusive": False, + "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, + "putative_type": {"name": "putative_type", "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral" ], "exclusive": False}, }, - }, "manual_labels": [ - {"unit_id": 1, "label_category": "quality", "labels": "good"}, - {"unit_id": 2, "label_category": "quality", "labels": "noise"}, - {"unit_id": 2, "label_category": "experimental", "labels": ["chronic", "headfixed"]}, + {"unit_id": 1, "quality": ["good"]}, + {"unit_id": 2, "quality": ["noise", ], "putative_type":["excitatory", "pyramidal"]}, + {"unit_id": 3, "putative_type": ["inhibitory"]}, ], "merged_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list "removed_units": [31, 42], # Can not be in the merged_units - "format_version": 1, + } - -valid_str = { +curation_ids_str = { + "format_version": "1", "unit_ids": ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"], "label_definitions": { - "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True}, - "experimental": { - "name": "experimental", - "label_options": ["acute", "chronic", "headfixed", "freelymoving"], - "auto_exclusive": False, + "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, + "putative_type": {"name": "putative_type", "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral" ], "exclusive": False}, }, - }, "manual_labels": [ - {"unit_id": "u1", "label_category": "quality", "labels": "good"}, - {"unit_id": "u2", "label_category": "quality", "labels": "noise"}, - {"unit_id": "u2", "label_category": "experimental", "labels": ["chronic", "headfixed"]}, + {"unit_id": 
"u1", "quality": ["good"]}, + {"unit_id": "u2", "quality": ["noise", ], "putative_type":["excitatory", "pyramidal"]}, + {"unit_id": "u3", "putative_type": ["inhibitory"]}, ], "merged_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list "removed_units": ["u31", "u42"], # Can not be in the merged_units - "format_version": 1, } -# This is a failure example -duplicate_merge = { - "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], - "label_definitions": { - "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True}, - "experimental": { - "name": "experimental", - "label_options": ["acute", "chronic", "headfixed", "freelymoving"], - "auto_exclusive": False, - }, - }, - "manual_labels": [ - {"unit_id": 1, "label_category": "quality", "labels": "good"}, - {"unit_id": 2, "label_category": "quality", "labels": "noise"}, - {"unit_id": 2, "label_category": "experimental", "labels": ["chronic", "headfixed"]}, - ], - "merged_unit_groups": [[3, 6, 10], [10, 14, 20]], # one cell goes into at most one list - "removed_units": [31, 42], # Can not be in the merged_units - "format_version": 1, -} +# This is a failure example with duplicated merge +duplicate_merge = curation_ids_int.copy() +duplicate_merge["merged_unit_groups"] = [[3, 6, 10], [10, 14, 20]] -# This is a failure example -merged_and_removed = { - "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], - "label_definitions": { - "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True}, - "experimental": { - "name": "experimental", - "label_options": ["acute", "chronic", "headfixed", "freelymoving"], - "auto_exclusive": False, - }, - }, - "manual_labels": [ - {"unit_id": 1, "label_category": "quality", "labels": "good"}, - {"unit_id": 2, "label_category": "quality", "labels": "noise"}, - {"unit_id": 2, "label_category": "experimental", "labels": ["chronic", "headfixed"]}, - ], - "merged_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list - "removed_units": [3, 31, 42], # Can not be in the merged_units - "format_version": 1, -} +# This is a failure example with unit 3 both in removed and merged +merged_and_removed = curation_ids_int.copy() +merged_and_removed["merged_unit_groups"] = [[3, 6], [10, 14, 20]] +merged_and_removed["removed_units"] = [3, 31, 42] +# this is a failure because unit 99 is not in the initial list +unknown_merged_unit = curation_ids_int.copy() +unknown_merged_unit["merged_unit_groups"] = [[3, 6, 99], [10, 14, 20]] -unknown_merged_unit = { - "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], - "label_definitions": { - "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True}, - "experimental": { - "name": "experimental", - "label_options": ["acute", "chronic", "headfixed", "freelymoving"], - "auto_exclusive": False, - }, - }, - "manual_labels": [ - {"unit_id": 1, "label_category": "quality", "labels": "good"}, - {"unit_id": 2, "label_category": "quality", "labels": "noise"}, - {"unit_id": 2, "label_category": "experimental", "labels": ["chronic", "headfixed"]}, - ], - "merged_unit_groups": [[3, 6, 99], [10, 14, 20]], # one cell goes into at most one list - "removed_units": [31, 42], # Can not be in the merged_units - "format_version": 1, -} +# this is a failure because unit 99 is not in the initial list +unknown_removed_unit = curation_ids_int.copy() +unknown_removed_unit["removed_units"] = [31, 42, 99] -unknown_removed_unit = { - 
"unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], - "label_definitions": { - "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "auto_exclusive": True}, - "experimental": { - "name": "experimental", - "label_options": ["acute", "chronic", "headfixed", "freelymoving"], - "auto_exclusive": False, - }, - }, - "manual_labels": [ - {"unit_id": 1, "label_category": "quality", "labels": "good"}, - {"unit_id": 2, "label_category": "quality", "labels": "noise"}, - {"unit_id": 2, "label_category": "experimental", "labels": ["chronic", "headfixed"]}, - ], - "merged_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list - "removed_units": [31, 42, 99], # Can not be in the merged_units - "format_version": 1, -} +def test_curation_format_validation(): + validate_curation_dict(curation_ids_int) + validate_curation_dict(curation_ids_str) -def test_curation_format_validation(): - assert validate_curation_dict(valid_int) - assert validate_curation_dict(valid_str) with pytest.raises(ValueError): # Raised because duplicated merged units validate_curation_dict(duplicate_merge) @@ -163,3 +91,16 @@ def test_curation_format_validation(): with pytest.raises(ValueError): # Raise beecause Some removed units are not in the unit list validate_curation_dict(unknown_removed_unit) + + +def test_to_from_json(): + + json.loads(json.dumps(curation_ids_int, indent=4)) + json.loads(json.dumps(curation_ids_str, indent=4)) + + + + +if __name__ == "__main__": + test_curation_format_validation() + # test_to_from_json() From 5a4630bd05bc3b494635ada94cee6dafb7f303f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 12:10:27 +0000 Subject: [PATCH 039/136] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/curation_format.py | 13 +++++--- .../curation/tests/test_curation_format.py | 32 ++++++++++++++----- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index d25efaa6d8..b32fca5ab9 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -18,7 +18,7 @@ def validate_curation_dict(curation_dict): Returns ------- Nothing. - + """ @@ -28,10 +28,11 @@ def validate_curation_dict(curation_dict): if curation_dict["format_version"] not in supported_curation_format_versions: raise ValueError( - f"Format version ({curation_dict['format_version']}) not supported. " f"Only {supported_curation_format_versions} are valid" + f"Format version ({curation_dict['format_version']}) not supported. " + f"Only {supported_curation_format_versions} are valid" ) - # unit_ids + # unit_ids unit_set = set(curation_dict["unit_ids"]) labeled_unit_set = set([lbl["unit_id"] for lbl in curation_dict["manual_labels"]]) merged_units_set = set(sum(curation_dict["merged_unit_groups"], [])) @@ -61,8 +62,10 @@ def validate_curation_dict(curation_dict): is_exclusive = curation_dict["label_definitions"][label_key]["exclusive"] - if is_exclusive and not len(label_value) <=1: - raise ValueError(f"Curation format: manual_labels {unit_id} {label_key} are exclusive labels. {label_value} is invalid") + if is_exclusive and not len(label_value) <= 1: + raise ValueError( + f"Curation format: manual_labels {unit_id} {label_key} are exclusive labels. 
{label_value} is invalid" + ) def convert_from_sortingview(sortingview_dict, destination_format=1): diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 6d3700b94a..a5a1418f32 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -27,16 +27,25 @@ "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], "label_definitions": { "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, - "putative_type": {"name": "putative_type", "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral" ], "exclusive": False}, + "putative_type": { + "name": "putative_type", + "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"], + "exclusive": False, }, + }, "manual_labels": [ {"unit_id": 1, "quality": ["good"]}, - {"unit_id": 2, "quality": ["noise", ], "putative_type":["excitatory", "pyramidal"]}, + { + "unit_id": 2, + "quality": [ + "noise", + ], + "putative_type": ["excitatory", "pyramidal"], + }, {"unit_id": 3, "putative_type": ["inhibitory"]}, ], "merged_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list "removed_units": [31, 42], # Can not be in the merged_units - } curation_ids_str = { @@ -44,11 +53,21 @@ "unit_ids": ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"], "label_definitions": { "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, - "putative_type": {"name": "putative_type", "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral" ], "exclusive": False}, + "putative_type": { + "name": "putative_type", + "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"], + "exclusive": False, }, + }, "manual_labels": [ {"unit_id": "u1", "quality": ["good"]}, - {"unit_id": "u2", "quality": ["noise", ], "putative_type":["excitatory", "pyramidal"]}, + { + "unit_id": "u2", + "quality": [ + "noise", + ], + "putative_type": ["excitatory", "pyramidal"], + }, {"unit_id": "u3", "putative_type": ["inhibitory"]}, ], "merged_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list @@ -78,7 +97,6 @@ def test_curation_format_validation(): validate_curation_dict(curation_ids_int) validate_curation_dict(curation_ids_str) - with pytest.raises(ValueError): # Raised because duplicated merged units validate_curation_dict(duplicate_merge) @@ -99,8 +117,6 @@ def test_to_from_json(): json.loads(json.dumps(curation_ids_str, indent=4)) - - if __name__ == "__main__": test_curation_format_validation() # test_to_from_json() From 8ef0633a8248e1aba2c9f49afbd88ac0a4b0b333 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 5 Jun 2024 14:40:28 +0200 Subject: [PATCH 040/136] more tests for converting curation format --- src/spikeinterface/curation/__init__.py | 4 ++ .../curation/curation_format.py | 50 ++++++++++--------- .../curation/tests/test_curation_format.py | 28 ++++++++++- 3 files changed, 56 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 9c6e17edb5..db67479123 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -11,4 +11,8 @@ from .mergeunitssorting import MergeUnitsSorting, merge_units_sorting from .splitunitsorting import SplitUnitSorting, split_unit_sorting +# curation format +from .curation_format import 
validate_curation_dict
+
 from .sortingview_curation import apply_sortingview_curation
+
diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py
index d25efaa6d8..b30a030c5f 100644
--- a/src/spikeinterface/curation/curation_format.py
+++ b/src/spikeinterface/curation/curation_format.py
@@ -31,17 +31,20 @@ def validate_curation_dict(curation_dict):
             f"Format version ({curation_dict['format_version']}) not supported. " f"Only {supported_curation_format_versions} are valid"
         )
-    # unit_ids
-    unit_set = set(curation_dict["unit_ids"])
+
+    # unit_ids
     labeled_unit_set = set([lbl["unit_id"] for lbl in curation_dict["manual_labels"]])
     merged_units_set = set(sum(curation_dict["merged_unit_groups"], []))
     removed_units_set = set(curation_dict["removed_units"])
-    if not labeled_unit_set.issubset(unit_set):
-        raise ValueError("Curation format: some labeled units are not in the unit list")
-    if not merged_units_set.issubset(unit_set):
-        raise ValueError("Curation format: some merged units are not in the unit list")
-    if not removed_units_set.issubset(unit_set):
-        raise ValueError("Curation format: some removed units are not in the unit list")
+
+    if curation_dict["unit_ids"] is not None:
+        # old format v0 did not contain unit_ids so this can contain None
+        unit_set = set(curation_dict["unit_ids"])
+        if not labeled_unit_set.issubset(unit_set):
+            raise ValueError("Curation format: some labeled units are not in the unit list")
+        if not merged_units_set.issubset(unit_set):
+            raise ValueError("Curation format: some merged units are not in the unit list")
+        if not removed_units_set.issubset(unit_set):
+            raise ValueError("Curation format: some removed units are not in the unit list")

     all_merging_groups = [set(group) for group in curation_dict["merged_unit_groups"]]
     for gp_1, gp_2 in combinations(all_merging_groups, 2):
         if len(gp_1.intersection(gp_2)) != 0:
             raise ValueError("Some units belong to multiple merge groups")
     if len(removed_units_set.intersection(merged_units_set)) != 0:
         raise ValueError("Some units were merged and deleted")

-def convert_from_sortingview(sortingview_dict, destination_format=1):
+def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_format="1"):
     """
-    Converts the sortingview curation format into a curation dictionary
+    Converts the old sortingview curation format (v0) into the new curation dictionary format (v1)
     Couple of caveats:
     * The list of units is not available in the original sortingview dictionary. We set it to None
     * Labels can not be mutually exclusive.
     * Labels are not restricted to a set of possible values

     Parameters
     ----------
     sortingview_dict : dict
         Dictionary containing the curation information from sortingview
-    destination_format : int
+    destination_format : str
         Version of the format to use.
-        Default to 1
+        Default to "1"

     Returns
     -------
     curation_dict: dict
         A curation dictionary
     """
+
+    assert destination_format == "1"
+
     merge_groups = sortingview_dict["mergeGroups"]
     merged_units = sum(merge_groups, [])
     if len(merged_units) > 0:
@@ -96,28 +102,24 @@
     all_labels = []
     manual_labels = []
     general_cat = "all_labels"
-    for unit_id, l_labels in sortingview_dict["labelsByUnit"].items():
+    for unit_id_, l_labels in sortingview_dict["labelsByUnit"].items():
         all_labels.extend(l_labels)
-        u_id = unit_id_type(unit_id)
-        all_units.append(u_id)
-        manual_labels.append({"unit_id": u_id, "label_category": general_cat, "labels": l_labels})
-    labels_def = {"all_labels": {"name": "all_labels", "label_options": all_labels, "exclusive": False}}
+        # recover the correct type for unit_id
+        unit_id = unit_id_type(unit_id_)
+        all_units.append(unit_id)
+        manual_labels.append({"unit_id": unit_id, general_cat: l_labels})
+    labels_def = {"all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False}}

     curation_dict = {
+        "format_version": destination_format,
         "unit_ids": None,
         "label_definitions": labels_def,
         "manual_labels": manual_labels,
         "merged_unit_groups": merge_groups,
         "removed_units": [],
-        "format_version": destination_format,
+
     }

     return curation_dict

-if __name__ == "__main__":
-    import json
-
-    with open("src/spikeinterface/curation/tests/sv-sorting-curation-str.json") as jf:
-        sv_curation = json.load(jf)
-    cur_d = convert_from_sortingview(sortingview_dict=sv_curation)
diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py
--- a/src/spikeinterface/curation/tests/test_curation_format.py
+++ b/src/spikeinterface/curation/tests/test_curation_format.py
 import pytest
+from pathlib import Path

 import json

+from spikeinterface.curation.curation_format import validate_curation_dict, convert_from_sortingview_curation_format_v0
+
+
 """example = {
     'unit_ids': List[str, int],
     'label_definitions': {
@@
 def test_to_from_json():

     json.loads(json.dumps(curation_ids_int, indent=4))
     json.loads(json.dumps(curation_ids_str, indent=4))

+def test_convert_from_sortingview_curation_format_v0():
+
+    parent_folder = Path(__file__).parent
+    for filename in (
+        "sv-sorting-curation.json",
+        "sv-sorting-curation-int.json",
+        "sv-sorting-curation-str.json",
+        "sv-sorting-curation-false-positive.json",
+    ):
+
+        json_file = parent_folder / filename
+        with open(json_file, "r") as f:
+            curation_v0 = json.load(f)
+        # print(curation_v0)
+        curation_v1 = convert_from_sortingview_curation_format_v0(curation_v0)
+        # print(curation_v1)
+        validate_curation_dict(curation_v1)

 if __name__ == "__main__":
-    test_curation_format_validation()
+    # test_curation_format_validation()
     # test_to_from_json()
+    test_convert_from_sortingview_curation_format_v0()

From 85dcea7d745b0ac3daa7036dea0602dc9b460092 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 5 Jun 2024 12:42:56 +0000
Subject: [PATCH 041/136] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/curation/__init__.py | 1 -
 src/spikeinterface/curation/curation_format.py | 3 ---
 src/spikeinterface/curation/tests/test_curation_format.py | 3 ---
 3 files changed, 7 deletions(-)

diff
--git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index db67479123..f541ff8ca5 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -15,4 +15,3 @@ from .curation_format import validate_curation_dict from .sortingview_curation import apply_sortingview_curation - diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 85a689dfd0..3535bb67a5 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -120,9 +120,6 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo "manual_labels": manual_labels, "merged_unit_groups": merge_groups, "removed_units": [], - } return curation_dict - - diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index d1d5120ffb..1945e4ca02 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -6,8 +6,6 @@ from spikeinterface.curation.curation_format import validate_curation_dict, convert_from_sortingview_curation_format_v0 - - """example = { 'unit_ids': List[str, int], 'label_definitions': { @@ -141,7 +139,6 @@ def test_convert_from_sortingview_curation_format_v0(): validate_curation_dict(curation_v1) - if __name__ == "__main__": # test_curation_format_validation() # test_to_from_json() From 127b42611b9a54a7ad9084905ba5ccc63f0931ac Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 5 Jun 2024 13:48:52 +0100 Subject: [PATCH 042/136] Fix all numpydoc validate PR01 in preprocessing docstrings --- src/spikeinterface/preprocessing/astype.py | 16 ++++++++-- .../preprocessing/depth_order.py | 2 +- .../preprocessing/detect_bad_channels.py | 30 +++++++++---------- src/spikeinterface/preprocessing/filter.py | 21 +++++++++++-- .../preprocessing/normalize_scale.py | 2 +- .../preprocessing/phase_shift.py | 3 ++ src/spikeinterface/preprocessing/resample.py | 3 ++ .../preprocessing/silence_periods.py | 3 +- 8 files changed, 57 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/preprocessing/astype.py b/src/spikeinterface/preprocessing/astype.py index da1435130c..4b0d5f9e55 100644 --- a/src/spikeinterface/preprocessing/astype.py +++ b/src/spikeinterface/preprocessing/astype.py @@ -14,8 +14,20 @@ class AstypeRecording(BasePreprocessor): For recording with an unsigned dtype, please use the `unsigned_to_signed` preprocessing function. - If `round` is True, will round the values to the nearest integer. - If `round` is None, will round in the case of float to integer conversion. + Parameters + ---------- + dtype : None | str | dtype, default: None + dtype of the output recording. + recording : Recording + The recording extractor to be converted. + round : Bool + If True, will round the values to the nearest integer. + If None, will round in the case of float to integer conversion. 
+ + Returns + ------- + astype_recording : AstypeRecording + The converted recording extractor object """ name = "astype" diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index 9569459080..f08f6404da 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -12,7 +12,7 @@ class DepthOrderRecording(ChannelSliceRecording): Parameters ---------- - recording : BaseRecording + parent_recording : BaseRecording The recording to re-order. channel_ids : list/array or None If given, a subset of channels to order locations for diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 276a8ac0b4..218c9cb822 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -57,29 +57,29 @@ def detect_bad_channels( The method to be used for bad channel detection std_mad_threshold : float, default: 5 The standard deviation/mad multiplier threshold - psd_hf_threshold (coeherence+psd) : float, default: 0.02 - An absolute threshold (uV^2/Hz) used as a cutoff for noise channels. + psd_hf_threshold : float, default: 0.02 + Coeherence+psd. An absolute threshold (uV^2/Hz) used as a cutoff for noise channels. Channels with average power at >80% Nyquist larger than this threshold will be labeled as noise - dead_channel_threshold (coeherence+psd) : float, default: -0.5 - Threshold for channel coherence below which channels are labeled as dead - noisy_channel_threshold (coeherence+psd) : float, default: 1 + dead_channel_threshold : float, default: -0.5 + Coeherence+psd. Threshold for channel coherence below which channels are labeled as dead + noisy_channel_threshold : float, default: 1 Threshold for channel coherence above which channels are labeled as noisy (together with psd condition) - outside_channel_threshold (coeherence+psd) : float, default: -0.75 - Threshold for channel coherence above which channels at the edge of the recording are marked as outside + outside_channel_threshold : float, default: -0.75 + Coeherence+psd. Threshold for channel coherence above which channels at the edge of the recording are marked as outside of the brain - outside_channels_location (coeherence+psd) : "top" | "bottom" | "both", default: "top" - Location of the outside channels. If "top", only the channels at the top of the probe can be + outside_channels_location : "top" | "bottom" | "both", default: "top" + Coeherence+psd. Location of the outside channels. If "top", only the channels at the top of the probe can be marked as outside channels. If "bottom", only the channels at the bottom of the probe can be marked as outside channels. If "both", both the channels at the top and bottom of the probe can be marked as outside channels - n_neighbors (coeherence+psd) : int, default: 11 - Number of channel neighbors to compute median filter (needs to be odd) - nyquist_threshold (coeherence+psd) : float, default: 0.8 - Frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared + n_neighbors : int, default: 11 + Coeherence+psd. Number of channel neighbors to compute median filter (needs to be odd) + nyquist_threshold : float, default: 0.8 + Coeherence+psd. 
Frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared with psd_hf_threshold - direction (coeherence+psd) : "x" | "y" | "z", default: "y" - The depth dimension + direction : "x" | "y" | "z", default: "y" + Coeherence+psd. The depth dimension highpass_filter_cutoff : float, default: 300 If the recording is not filtered, the cutoff frequency of the highpass filter chunk_duration_s : float, default: 0.5 diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 3f1a155d0d..84ac542acc 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -43,11 +43,16 @@ class FilterRecording(BasePreprocessor): Filter form of the filter coefficients: - second-order sections ("sos") - numerator/denominator : ("ba") - coef : array or None, default: None + coeff : array | None, default: None Filter coefficients in the filter_mode form. dtype : dtype or None, default: None The dtype of the returned traces. If None, the dtype of the parent recording is used - {} + add_reflect_padding : Bool, default False + If True, uses a left and right margin during calculation. + ftype : str | None, default: "butter" + The type of IIR filter to design, used in `scipy.signal.iirfilter`. + filter_order : int, default: 5 + The order of the filter, used in `scipy.signal.iirfilter`. Returns ------- @@ -178,7 +183,9 @@ class BandpassFilterRecording(FilterRecording): Margin in ms on border to avoid border effect dtype : dtype or None The dtype of the returned traces. If None, the dtype of the parent recording is used - {} + **filter_kwargs : dict + Keyword arguments for `spikeinterface.preprocessing.FilterRecording` class. + Returns ------- filter_recording : BandpassFilterRecording @@ -212,6 +219,9 @@ class HighpassFilterRecording(FilterRecording): Margin in ms on border to avoid border effect dtype : dtype or None The dtype of the returned traces. If None, the dtype of the parent recording is used + **filter_kwargs : dict + Keyword arguments for `spikeinterface.preprocessing.FilterRecording` class. + {} Returns ------- @@ -240,6 +250,11 @@ class NotchFilterRecording(BasePreprocessor): The target frequency in Hz of the notch filter q : int The quality factor of the notch filter + dtype : None | dtype, default: None + dtype of recording. If None, will take from `recording` + margin_ms : float, default: 5.0 + Margin in ms on border to avoid border effect + {} Returns ------- diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 44b9ac9937..e537be4694 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -54,7 +54,7 @@ class NormalizeByQuantileRecording(BasePreprocessor): Median for the output distribution q1 : float, default: 0.01 Lower quantile used for measuring the scale - q1 : float, default: 0.99 + q2 : float, default: 0.99 Upper quantile used for measuring the mode : "by_channel" | "pool_channel", default: "by_channel" If "by_channel" each channel is rescaled independently. 
diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py
index 41c18e2f38..ac308c975d 100644
--- a/src/spikeinterface/preprocessing/phase_shift.py
+++ b/src/spikeinterface/preprocessing/phase_shift.py
@@ -31,6 +31,9 @@ class PhaseShiftRecording(BasePreprocessor):
     inter_sample_shift : None or numpy array, default: None
         If "inter_sample_shift" is not in recording properties,
         we can externally provide one.
+    dtype : None | str | dtype, default: None
+        Dtype of input and output `recording` objects.
+

     Returns
     -------
diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py
index cc110118a5..54a602b7c0 100644
--- a/src/spikeinterface/preprocessing/resample.py
+++ b/src/spikeinterface/preprocessing/resample.py
@@ -34,6 +34,9 @@ class ResampleRecording(BasePreprocessor):
         The dtype of the returned traces. If None, the dtype of the parent recording is used.
     skip_checks : bool, default: False
         If True, checks on sampling frequencies and cutoff filter frequencies are skipped
+    margin_ms : float, default: 100.0
+        Margin in ms on border to avoid border effect
+

     Returns
     -------
diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index 5f70bfbb40..88c7e2109c 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -25,7 +25,8 @@ class SilencedPeriodsRecording(BasePreprocessor):
         One list per segment of tuples (start_frame, end_frame) to silence
     noise_levels : array
         Noise levels if already computed
-
+    seed : int | None, default: None
+        Random seed for `get_noise_levels`
     mode : "zeros" | "noise, default: "zeros"
         Determines what periods are replaced by. Can be one of the following:

From 1d10bdc986456c02a05853f02f96b7e562130920 Mon Sep 17 00:00:00 2001
From: Garcia Samuel
Date: Wed, 5 Jun 2024 15:10:43 +0200
Subject: [PATCH 043/136] Thanks Zach

Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com>
---
 doc/modules/curation.rst | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst
index 83410494ba..2c5bb84071 100644
--- a/doc/modules/curation.rst
+++ b/doc/modules/curation.rst
@@ -44,10 +44,10 @@ The merging and splitting operations are handled by the :py:class:`~spikeinterfa
 Manual curation format
 ----------------------

-SpikeInterface internally support a manual curation format JSON based.
-When a mnual curation is necessary, modifying in place a dataset is a bad practice.
-Instead, to keep the reproducibility in the spike sorting piepline, we introduce a manual curation format,
-simple and JSON based. This format defines at the moment : merges + deletions + manual tags.
+SpikeInterface internally supports a JSON-based manual curation format.
+When manual curation is necessary, modifying a dataset in place is a bad practice.
+Instead, to ensure the reproducibility of the spike sorting pipelines, we have introduced a simple and JSON-based manual curation format.
+This format defines at the moment : merges + deletions + manual tags.
 The simple file can be kept alongside the output of a sorter and applied on the result to have a "clean" result.
This format has two parts:
@@ -55,8 +55,8 @@ This format has two parts:

 * **definition** with the following keys:

   * "format_version" : format specification
-  * "unit_ids" : give the list of unit_ds
-  * "label_definitions" : list of label category and possible labels per category.
+  * "unit_ids" : the list of unit_ids
+  * "label_definitions" : list of label categories and possible labels per category.
     Every category can be *exclusive=True* (only one label) or *exclusive=False* (several labels possible)

 * **manual output** curation with the following keys:
@@ -65,7 +65,7 @@ This format has two parts:
   * "merged_unit_groups"
   * "removed_units"

-Here the description of the format with a simple example:
+Here is the description of the format with a simple example:

 .. code-block:: json

From 710eade8d75b65b9db5ad74624acae69b3ae1ce8 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 5 Jun 2024 15:40:35 +0200
Subject: [PATCH 044/136] curation_label_to_dataframe()

---
 src/spikeinterface/curation/__init__.py      |  2 +-
 .../curation/curation_format.py              | 41 +++++++++++++++++++
 .../curation/tests/test_curation_format.py   | 22 ++++++----
 3 files changed, 57 insertions(+), 8 deletions(-)

diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py
index db67479123..1cfcfe2db2 100644
--- a/src/spikeinterface/curation/__init__.py
+++ b/src/spikeinterface/curation/__init__.py
@@ -12,7 +12,7 @@ from .splitunitsorting import SplitUnitSorting, split_unit_sorting

 # curation format
-from .curation_format import validate_curation_dict
+from .curation_format import validate_curation_dict, curation_label_to_dataframe

 from .sortingview_curation import apply_sortingview_curation

diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py
index 85a689dfd0..0ce4264010 100644
--- a/src/spikeinterface/curation/curation_format.py
+++ b/src/spikeinterface/curation/curation_format.py
@@ -126,3 +126,44 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo

     return curation_dict

+def curation_label_to_dataframe(curation_dict):
+    """
+    Transform the curation dict into a pandas dataframe.
+    For a label category with exclusive=True : a column is created and values are the unique label.
+    For a label category with exclusive=False : one column per possible label is created and values are boolean.
+
+    If exclusive=False and the same label appears several times then an error is raised.
+
+    Parameters
+    ----------
+    curation_dict : dict
+        A curation dictionary
+
+    Returns
+    -------
+    labels : pd.DataFrame
+        dataframe with labels.
+ """ + import pandas as pd + labels = pd.DataFrame(index=curation_dict["unit_ids"]) + + for label_key, label_def in curation_dict["label_definitions"].items(): + if label_def["exclusive"]: + assert label_key not in labels.columns, f"{label_key} is already a column" + labels[label_key] = pd.Series(dtype=str) + labels[label_key][:] = "" + for lbl in curation_dict["manual_labels"]: + value = lbl.get(label_key, []) + if len(value) == 1: + labels.at[lbl["unit_id"], label_key] = value[0] + else: + for label_opt in label_def["label_options"]: + assert label_opt not in labels.columns, f"{label_opt} is already a column" + labels[label_opt] = pd.Series(dtype=bool) + labels[label_opt][:] = False + for lbl in curation_dict["manual_labels"]: + values = lbl.get(label_key, []) + for value in values: + labels.at[lbl["unit_id"], value] = True + + return labels diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index d1d5120ffb..3a4b2a7ec5 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -3,7 +3,7 @@ from pathlib import Path import json -from spikeinterface.curation.curation_format import validate_curation_dict, convert_from_sortingview_curation_format_v0 +from spikeinterface.curation.curation_format import validate_curation_dict, convert_from_sortingview_curation_format_v0, curation_label_to_dataframe @@ -12,7 +12,7 @@ 'unit_ids': List[str, int], 'label_definitions': { 'category_key1': - {'name': str, + { 'label_options': List[str], 'exclusive': bool} }, @@ -31,9 +31,8 @@ "format_version": "1", "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], "label_definitions": { - "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, + "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, "putative_type": { - "name": "putative_type", "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"], "exclusive": False, }, @@ -57,9 +56,8 @@ "format_version": "1", "unit_ids": ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"], "label_definitions": { - "quality": {"name": "quality", "label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, + "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, "putative_type": { - "name": "putative_type", "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"], "exclusive": False, }, @@ -140,9 +138,19 @@ def test_convert_from_sortingview_curation_format_v0(): # print(curation_v1) validate_curation_dict(curation_v1) +def test_curation_label_to_dataframe(): + + df = curation_label_to_dataframe(curation_ids_int) + assert "quality" in df.columns + assert "excitatory" in df.columns + print(df) + + df = curation_label_to_dataframe(curation_ids_str) + # print(df) if __name__ == "__main__": # test_curation_format_validation() # test_to_from_json() - test_convert_from_sortingview_curation_format_v0() + # test_convert_from_sortingview_curation_format_v0() + test_curation_label_to_dataframe() From 52857ccfc6a00aa3daef1e7d130cb498caecfc99 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 13:42:01 +0000 Subject: [PATCH 045/136] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/curation_format.py | 5 +++-- 
src/spikeinterface/curation/tests/test_curation_format.py | 8 +++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index e6fbaac4f7..82921a56b5 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -132,7 +132,7 @@ def curation_label_to_dataframe(curation_dict): For label category with exclusive=False : one column per possible is created and values are boolean. If exclusive=False and the same label appear several times then it raises an error. - + Parameters ---------- curation_dict : dict @@ -144,6 +144,7 @@ def curation_label_to_dataframe(curation_dict): dataframe with labels. """ import pandas as pd + labels = pd.DataFrame(index=curation_dict["unit_ids"]) for label_key, label_def in curation_dict["label_definitions"].items(): @@ -164,5 +165,5 @@ def curation_label_to_dataframe(curation_dict): values = lbl.get(label_key, []) for value in values: labels.at[lbl["unit_id"], value] = True - + return labels diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 802d7ce49c..c691543414 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -3,7 +3,11 @@ from pathlib import Path import json -from spikeinterface.curation.curation_format import validate_curation_dict, convert_from_sortingview_curation_format_v0, curation_label_to_dataframe +from spikeinterface.curation.curation_format import ( + validate_curation_dict, + convert_from_sortingview_curation_format_v0, + curation_label_to_dataframe, +) """example = { @@ -136,6 +140,7 @@ def test_convert_from_sortingview_curation_format_v0(): # print(curation_v1) validate_curation_dict(curation_v1) + def test_curation_label_to_dataframe(): df = curation_label_to_dataframe(curation_ids_int) @@ -146,6 +151,7 @@ def test_curation_label_to_dataframe(): df = curation_label_to_dataframe(curation_ids_str) # print(df) + if __name__ == "__main__": # test_curation_format_validation() # test_to_from_json() From 22ff94dff957f874c395aa3b7c10edc85b1a20ae Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 5 Jun 2024 15:45:05 +0200 Subject: [PATCH 046/136] TODO for alessio in the code. 
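For context, a minimal sketch of how the new-format helpers exercised in the patches above fit together. The dict below is illustrative (modelled on the fixtures in test_curation_format.py), and the exact keys required by validation are an assumption:

.. code-block:: python

    from spikeinterface.curation import validate_curation_dict, curation_label_to_dataframe

    # Illustrative curation dict, modelled on the test fixtures
    curation_dict = {
        "format_version": "1",
        "unit_ids": [1, 2, 3],
        "label_definitions": {
            "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True},
        },
        "manual_labels": [
            {"unit_id": 1, "quality": ["good"]},
            {"unit_id": 2, "quality": ["noise"]},
        ],
        "merged_unit_groups": [],
        "removed_units": [],
    }

    validate_curation_dict(curation_dict)  # assumed to raise if the dict is malformed
    labels = curation_label_to_dataframe(curation_dict)  # one "quality" column, indexed by unit id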
--- src/spikeinterface/curation/sortingview_curation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index b31b8c39d5..267f1e423b 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -5,7 +5,8 @@ from .curationsorting import CurationSorting - +# @alessio +# TODO later : this should be reimplemented using the new curation format def apply_sortingview_curation( sorting, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=False ): From a9c5aadc3e3a8833c782abe6ba3a4b2734d013f8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 5 Jun 2024 15:50:01 +0200 Subject: [PATCH 047/136] oups --- doc/modules/curation.rst | 36 +++++++++---------- .../curation/tests/test_curation_format.py | 4 ++- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index 2c5bb84071..46fdcc6d65 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -69,7 +69,6 @@ Here is the description of the format with a simple example: .. code-block:: json - { # the first part of the format is the definitation "format_version": "1", @@ -86,7 +85,6 @@ Here is the description of the format with a simple example: ], "label_definitions": { "quality": { - "name": "quality", "label_options": [ "good", "noise", @@ -95,13 +93,12 @@ Here is the description of the format with a simple example: ], "exclusive": true }, - "experimental": { - "name": "experimental", + "putative_type": { "label_options": [ - "acute", - "chronic", - "headfixed", - "freelymoving" + "excitatory", + "inhibitory", + "pyramidal", + "mitral" ], "exclusive": false } @@ -110,20 +107,24 @@ Here is the description of the format with a simple example: "manual_labels": [ { "unit_id": "u1", - "label_category": "quality", - "labels": "good" + "quality": [ + "good" + ] }, { "unit_id": "u2", - "label_category": "quality", - "labels": "noise" + "quality": [ + "noise" + ], + "putative_type": [ + "excitatory", + "pyramidal" + ] }, { - "unit_id": "u2", - "label_category": "experimental", - "labels": [ - "chronic", - "headfixed" + "unit_id": "u3", + "putative_type": [ + "inhibitory" ] } ], @@ -146,7 +147,6 @@ Here is the description of the format with a simple example: - Automatic curation tools ------------------------ diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 802d7ce49c..778091664c 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -150,4 +150,6 @@ def test_curation_label_to_dataframe(): # test_curation_format_validation() # test_to_from_json() # test_convert_from_sortingview_curation_format_v0() - test_curation_label_to_dataframe() + # test_curation_label_to_dataframe() + + print(json.dumps(curation_ids_str, indent=4)) From 6eef836a16c1d17d96266d60828d6d7c151382c3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 13:50:37 +0000 Subject: [PATCH 048/136] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/sortingview_curation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/curation/sortingview_curation.py 
b/src/spikeinterface/curation/sortingview_curation.py
index 267f1e423b..c4d2a32958 100644
--- a/src/spikeinterface/curation/sortingview_curation.py
+++ b/src/spikeinterface/curation/sortingview_curation.py
@@ -5,6 +5,7 @@

 from .curationsorting import CurationSorting

+
 # @alessio
 # TODO later : this should be reimplemented using the new curation format
 def apply_sortingview_curation(

From 3e5d1a429d16ba3a80842d1279d821161ee3e939 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Wed, 5 Jun 2024 17:51:48 +0100
Subject: [PATCH 049/136] Use pytest fixture instead of unittest, base class
 and test_amplitude_scalings.

---
 .../tests/common_extension_tests.py | 23 ++++++++++++++-----
 .../tests/test_amplitude_scalings.py |  4 ++--
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
index 605997f5f6..706918187f 100644
--- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py
+++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
@@ -73,11 +73,24 @@ class AnalyzerExtensionCommonTestSuite:
     extension_class = None
     extension_function_params_list = None

-    @classmethod
-    def setUpClass(cls):
-        cls.recording, cls.sorting = get_dataset()
+    @pytest.fixture(autouse=True, scope="class")
+    def setUpClass(self):
+        """
+        This method sets up the class once at the start of testing. It is
+        in scope for the lifetime of the class and is reused across all
+        tests that inherit from this base class, to save processing time
+        (the sparsity below is estimated once, with a deliberately small radius).
+
+        When setting attributes on `self` in `scope="class"`, a new
+        class instance is used for each test. In this case, we have to set
+        from the base object `__class__` to ensure the attributes
+        are available to all subclass instances.
+ """ + self.__class__.recording, self.__class__.sorting = get_dataset() # sparsity is computed once for all cases to save processing time and force a small radius - cls.sparsity = estimate_sparsity(cls.recording, cls.sorting, method="radius", radius_um=20) + self.__class__.sparsity = estimate_sparsity( + self.__class__.recording, self.__class__.sorting, method="radius", radius_um=20 + ) @property def extension_name(self): @@ -114,12 +127,10 @@ def _check_one(self, sorting_analyzer): some_unit_ids = sorting_analyzer.unit_ids[::2] sliced = sorting_analyzer.select_units(some_unit_ids, format="memory") assert np.array_equal(sliced.unit_ids, sorting_analyzer.unit_ids[::2]) - # print(sliced) def test_extension(self): for sparse in (True, False): for format in ("memory", "binary_folder", "zarr"): - print() print("sparse", sparse, format) sorting_analyzer = self._prepare_sorting_analyzer(format, sparse) self._check_one(sorting_analyzer) diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index b59aca16a8..f5ef0db956 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -7,7 +7,7 @@ from spikeinterface.postprocessing import ComputeAmplitudeScalings -class AmplitudeScalingsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): +class TestAmplitudeScalingsExtension(AnalyzerExtensionCommonTestSuite): extension_class = ComputeAmplitudeScalings extension_function_params_list = [ dict(handle_collisions=True), @@ -36,7 +36,7 @@ def test_scaling_values(self): if __name__ == "__main__": - test = AmplitudeScalingsExtensionTest() + test = TestAmplitudeScalingsExtension() test.setUpClass() test.test_extension() test.test_scaling_values() From f0de1fd6357d4ee976a68d03421d549a5581e44e Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 17:55:44 +0100 Subject: [PATCH 050/136] test_correlograms.py --- .../postprocessing/tests/common_extension_tests.py | 2 ++ src/spikeinterface/postprocessing/tests/test_correlograms.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 706918187f..d8575cbeb2 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -115,6 +115,8 @@ def _check_one(self, sorting_analyzer): else: job_kwargs = dict() + # TODO: a downside of this approach is each parameterisation does + # not get it's own test, but all falls under the same test. 
for params in self.extension_function_params_list: print(" params", params) ext = sorting_analyzer.compute(self.extension_name, **params, **job_kwargs) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 6d727e6448..56b4032630 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -16,7 +16,7 @@ from spikeinterface.postprocessing.correlograms import compute_correlograms_on_sorting, _make_bins -class ComputeCorrelogramsTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): +class TestComputeCorrelograms(AnalyzerExtensionCommonTestSuite): extension_class = ComputeCorrelograms extension_function_params_list = [ dict(method="numpy"), From 471d0c8f5b7c6e27293acc486b5442c3d2f950f4 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 18:01:10 +0100 Subject: [PATCH 051/136] test_isi.py --- src/spikeinterface/postprocessing/tests/test_isi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 8626e56453..3eff96ebfb 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -16,7 +16,7 @@ HAVE_NUMBA = False -class ComputeISIHistogramsTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): +class TestComputeISIHistograms(AnalyzerExtensionCommonTestSuite): extension_class = ComputeISIHistograms extension_function_params_list = [ dict(method="numpy"), From 38ee32284cb34916a71d4dc79bd964b47f5c2aa3 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 18:02:18 +0100 Subject: [PATCH 052/136] Add test noise levels note. --- src/spikeinterface/postprocessing/tests/test_noise_levels.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/postprocessing/tests/test_noise_levels.py b/src/spikeinterface/postprocessing/tests/test_noise_levels.py index f334f92fa6..0f9265d00f 100644 --- a/src/spikeinterface/postprocessing/tests/test_noise_levels.py +++ b/src/spikeinterface/postprocessing/tests/test_noise_levels.py @@ -1 +1,3 @@ # "noise_levels" extensions is now in core + +# TODO: can this page now be deleted? 
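The fixture pattern introduced in patch 049 and rolled out across these test modules, reduced to a self-contained sketch (class and attribute names are illustrative): with ``scope="class"``, pytest still creates a fresh test-class instance for each test, so shared state must be written to the class object itself.

.. code-block:: python

    import pytest


    class TestWithSharedSetup:
        @pytest.fixture(autouse=True, scope="class")
        def setup_shared_data(self):
            # Runs once for the whole class. Each test method receives a new
            # instance, so shared state is stored on the class, not on `self`.
            self.__class__.dataset = list(range(1000))  # stand-in for an expensive setup

        def test_length(self):
            assert len(self.dataset) == 1000

        def test_first_element(self):
            assert self.dataset[0] == 0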
From d1a05b9396d8a225c6a2a78381d8f31649090e16 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 18:05:13 +0100 Subject: [PATCH 053/136] test_principal_component.py --- .../postprocessing/tests/test_principal_component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index d94d7ea586..91b23bab9a 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -11,7 +11,7 @@ DEBUG = False -class PrincipalComponentsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): +class TestPrincipalComponentsExtension(AnalyzerExtensionCommonTestSuite): extension_class = ComputePrincipalComponents extension_function_params_list = [ dict(mode="by_channel_local"), From 6853f90addac639acfad29f326c521f6b8ea8fc7 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 18:22:46 +0100 Subject: [PATCH 054/136] test_spike_amplitudes.py --- .../postprocessing/tests/test_spike_amplitudes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index 8ff7666371..08f1ed31db 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -5,7 +5,7 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -class ComputeSpikeAmplitudesTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): +class TestComputeSpikeAmplitudes(AnalyzerExtensionCommonTestSuite): extension_class = ComputeSpikeAmplitudes extension_function_params_list = [ dict(), @@ -13,7 +13,7 @@ class ComputeSpikeAmplitudesTest(AnalyzerExtensionCommonTestSuite, unittest.Test if __name__ == "__main__": - test = ComputeSpikeAmplitudesTest() + test = TestComputeSpikeAmplitudes() test.setUpClass() test.test_extension() From e8b3c734a7633b3c17372acc1be76a96a88b4e51 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 18:23:46 +0100 Subject: [PATCH 055/136] Fix missed class renaming in __main__ blocks. 
--- src/spikeinterface/postprocessing/tests/test_correlograms.py | 2 +- src/spikeinterface/postprocessing/tests/test_isi.py | 2 +- .../postprocessing/tests/test_principal_component.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 56b4032630..da3a3697eb 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -199,6 +199,6 @@ def test_detect_injected_correlation(): # test_auto_equal_cross_correlograms() # test_detect_injected_correlation() - test = ComputeCorrelogramsTest() + test = TestComputeCorrelograms() test.setUpClass() test.test_extension() diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 3eff96ebfb..e4ce38cea8 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -47,7 +47,7 @@ def _test_ISI(sorting, window_ms: float, bin_ms: float, methods: List[str]): if __name__ == "__main__": - test = ComputeISIHistogramsTest() + test = TestComputeISIHistograms() test.setUpClass() test.test_extension() test.test_compute_ISI() diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 91b23bab9a..d1ee1589db 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -133,7 +133,7 @@ def test_project_new(self): if __name__ == "__main__": - test = PrincipalComponentsExtensionTest() + test = TestPrincipalComponentsExtension() test.setUpClass() test.test_extension() test.test_mode_concatenated() From c43fd1f0f5e4e01cb92a92c4f9bb8bbb44a9a11f Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 18:43:21 +0100 Subject: [PATCH 056/136] test_spike_locations.py --- .../postprocessing/tests/test_spike_locations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index d48ff3d84b..09250c9813 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -5,7 +5,7 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -class SpikeLocationsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): +class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite): extension_class = ComputeSpikeLocations extension_function_params_list = [ dict( @@ -21,6 +21,6 @@ class SpikeLocationsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.Tes if __name__ == "__main__": - test = SpikeLocationsExtensionTest() + test = TestSpikeLocationsExtension() test.setUpClass() test.test_extension() From 7c587cf30300fc355548991a3e05ae6e3ce47521 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 18:44:07 +0100 Subject: [PATCH 057/136] test_template_metrics.py --- .../postprocessing/tests/test_template_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 
360f0f379f..fda8d19da5 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -5,7 +5,7 @@ from spikeinterface.postprocessing import ComputeTemplateMetrics -class TemplateMetricsTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): +class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite): extension_class = ComputeTemplateMetrics extension_function_params_list = [ dict(), @@ -15,6 +15,6 @@ class TemplateMetricsTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): if __name__ == "__main__": - test = TemplateMetricsTest() + test = TestTemplateMetrics() test.setUpClass() test.test_extension() From 4d2b52a438bbdbbad395772354064b6b337b2cc4 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 18:44:44 +0100 Subject: [PATCH 058/136] test_template_similarity.py --- .../postprocessing/tests/test_template_similarity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 534c909592..7e25db14f7 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -9,7 +9,7 @@ from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity -class SimilarityExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): +class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite): extension_class = ComputeTemplateSimilarity extension_function_params_list = [ dict(method="cosine_similarity"), From 9c4fc69e6b419cd08cd9c07d61d2059d678641f7 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 18:45:43 +0100 Subject: [PATCH 059/136] test_unit_localization.py --- .../postprocessing/tests/test_unit_localization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_unit_localization.py b/src/spikeinterface/postprocessing/tests/test_unit_localization.py index b23adf5868..fd5df3105c 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_localization.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_localization.py @@ -3,7 +3,7 @@ from spikeinterface.postprocessing import ComputeUnitLocations -class UnitLocationsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): +class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite): extension_class = ComputeUnitLocations extension_function_params_list = [ dict(method="center_of_mass", radius_um=100), @@ -15,7 +15,7 @@ class UnitLocationsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.Test if __name__ == "__main__": - test = UnitLocationsExtensionTest() + test = TestUnitLocationsExtension() test.setUpClass() test.test_extension() # test.tearDown() From e0466a22ebc8d2dd2a827c65db962c198a36fee5 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 18:49:59 +0100 Subject: [PATCH 060/136] Remove now-broken main blocks. 
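The removed ``__main__`` blocks called ``setUpClass()`` by hand, which no longer works now that setup happens inside a pytest fixture. A hedged sketch of an alternative, if direct execution of a single module is still wanted (not part of this patch):

.. code-block:: python

    # Hypothetical replacement for a removed __main__ block: let pytest
    # drive fixture setup and test collection instead of manual calls.
    if __name__ == "__main__":
        import pytest

        pytest.main([__file__, "-k", "test_extension"])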
--- .../postprocessing/tests/test_amplitude_scalings.py | 7 ------- .../postprocessing/tests/test_correlograms.py | 12 ------------ src/spikeinterface/postprocessing/tests/test_isi.py | 7 ------- .../postprocessing/tests/test_principal_component.py | 11 ----------- .../postprocessing/tests/test_spike_amplitudes.py | 10 ---------- .../postprocessing/tests/test_spike_locations.py | 6 ------ .../postprocessing/tests/test_template_metrics.py | 6 ------ .../postprocessing/tests/test_template_similarity.py | 8 -------- .../postprocessing/tests/test_unit_localization.py | 7 ------- 9 files changed, 74 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index f5ef0db956..ebc35da348 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -33,10 +33,3 @@ def test_scaling_values(self): # fig, ax = plt.subplots() # ax.hist(ext.data["amplitude_scalings"]) # plt.show() - - -if __name__ == "__main__": - test = TestAmplitudeScalingsExtension() - test.setUpClass() - test.test_extension() - test.test_scaling_values() diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index da3a3697eb..c7c8e00722 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -190,15 +190,3 @@ def test_detect_injected_correlation(): # ax.set_title(method) # ax.legend() # plt.show() - - -if __name__ == "__main__": - # test_make_bins() - # test_equal_results_correlograms() - # test_flat_cross_correlogram() - # test_auto_equal_cross_correlograms() - # test_detect_injected_correlation() - - test = TestComputeCorrelograms() - test.setUpClass() - test.test_extension() diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index e4ce38cea8..3110cd7c93 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -44,10 +44,3 @@ def _test_ISI(sorting, window_ms: float, bin_ms: float, methods: List[str]): else: assert np.all(ISI == ref_ISI), f"Failed with method={method}" assert np.allclose(bins, ref_bins, atol=1e-10), f"Failed with method={method}" - - -if __name__ == "__main__": - test = TestComputeISIHistograms() - test.setUpClass() - test.test_extension() - test.test_compute_ISI() diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index d1ee1589db..5d616575c9 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -112,7 +112,6 @@ def test_compute_for_all_spikes(self): assert np.array_equal(all_pc1, all_pc2) def test_project_new(self): - from sklearn.decomposition import IncrementalPCA sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=False) @@ -131,16 +130,6 @@ def test_project_new(self): assert new_proj.shape[1] == n_components assert new_proj.shape[2] == ext_pca.data["pca_projection"].shape[2] - -if __name__ == "__main__": - test = TestPrincipalComponentsExtension() - test.setUpClass() - test.test_extension() - test.test_mode_concatenated() - test.test_get_projections() - test.test_compute_for_all_spikes() - 
test.test_project_new() - # ext = test.sorting_analyzers["sparseTrue_memory"].get_extension("principal_components") # pca = ext.data["pca_projection"] # import matplotlib.pyplot as plt diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index 08f1ed31db..3b288c540c 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -10,13 +10,3 @@ class TestComputeSpikeAmplitudes(AnalyzerExtensionCommonTestSuite): extension_function_params_list = [ dict(), ] - - -if __name__ == "__main__": - test = TestComputeSpikeAmplitudes() - test.setUpClass() - test.test_extension() - - # for k, sorting_analyzer in test.sorting_analyzers.items(): - # print(sorting_analyzer) - # print(sorting_analyzer.get_extension("spike_amplitudes").data["amplitudes"].shape) diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 09250c9813..223d74046f 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -18,9 +18,3 @@ class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite): dict(method="monopolar_triangulation"), # , chunk_size=10000, n_jobs=1 dict(method="grid_convolution"), # , chunk_size=10000, n_jobs=1 ] - - -if __name__ == "__main__": - test = TestSpikeLocationsExtension() - test.setUpClass() - test.test_extension() diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index fda8d19da5..96b1635b27 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -12,9 +12,3 @@ class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite): dict(upsampling_factor=2), dict(include_multi_channel_metrics=True), ] - - -if __name__ == "__main__": - test = TestTemplateMetrics() - test.setUpClass() - test.test_extension() diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 7e25db14f7..96a9b5f3ee 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -34,11 +34,3 @@ def test_check_equal_template_with_distribution_overlap(): continue waveforms1 = wf_ext.get_waveforms_one_unit(unit_id1) check_equal_template_with_distribution_overlap(waveforms0, waveforms1) - - -if __name__ == "__main__": - # test = SimilarityExtensionTest() - # test.setUpClass() - # test.test_extension() - - test_check_equal_template_with_distribution_overlap() diff --git a/src/spikeinterface/postprocessing/tests/test_unit_localization.py b/src/spikeinterface/postprocessing/tests/test_unit_localization.py index fd5df3105c..1546a22056 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_localization.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_localization.py @@ -12,10 +12,3 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite): dict(method="monopolar_triangulation", radius_um=150), dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), ] - - -if __name__ == "__main__": - test = TestUnitLocationsExtension() - test.setUpClass() - 
test.test_extension() - # test.tearDown() From 83abe6ccafd17b2e97ec3decb20e2ee538f13d13 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 19:12:13 +0100 Subject: [PATCH 061/136] Make common extension tests and test_amplitude_scalings.py use parameterized params. --- .../tests/common_extension_tests.py | 40 +++++++------------ .../tests/test_amplitude_scalings.py | 15 +++---- 2 files changed, 21 insertions(+), 34 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index d8575cbeb2..759a15e772 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -70,9 +70,6 @@ class AnalyzerExtensionCommonTestSuite: This also test the select_units() ability. """ - extension_class = None - extension_function_params_list = None - @pytest.fixture(autouse=True, scope="class") def setUpClass(self): """ @@ -87,52 +84,45 @@ class instance is used for each. In this case, we have to set are available to all subclass instances. """ self.__class__.recording, self.__class__.sorting = get_dataset() - # sparsity is computed once for all cases to save processing time and force a small radius + self.__class__.sparsity = estimate_sparsity( self.__class__.recording, self.__class__.sorting, method="radius", radius_um=20 ) - @property - def extension_name(self): - return self.extension_class.extension_name - - def _prepare_sorting_analyzer(self, format, sparse): - # prepare a SortingAnalyzer object with depencies already computed + def _prepare_sorting_analyzer(self, format, sparse, extension_class): + """prepare a SortingAnalyzer object with depencies already computed""" sparsity_ = self.sparsity if sparse else None sorting_analyzer = get_sorting_analyzer( - self.recording, self.sorting, format=format, sparsity=sparsity_, name=self.extension_class.extension_name + self.recording, self.sorting, format=format, sparsity=sparsity_, name=extension_class.extension_name ) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=50, seed=2205) - for dependency_name in self.extension_class.depend_on: + for dependency_name in extension_class.depend_on: if "|" in dependency_name: dependency_name = dependency_name.split("|")[0] sorting_analyzer.compute(dependency_name) return sorting_analyzer - def _check_one(self, sorting_analyzer): - if self.extension_class.need_job_kwargs: + def _check_one(self, sorting_analyzer, extension_class, params): + """""" + if extension_class.need_job_kwargs: job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) else: job_kwargs = dict() - # TODO: a downside of this approach is each parameterisation does - # not get it's own test, but all falls under the same test. 
- for params in self.extension_function_params_list: - print(" params", params) - ext = sorting_analyzer.compute(self.extension_name, **params, **job_kwargs) - assert len(ext.data) > 0 - main_data = ext.get_data() + ext = sorting_analyzer.compute(extension_class.extension_name, **params, **job_kwargs) + assert len(ext.data) > 0 + main_data = ext.get_data() - ext = sorting_analyzer.get_extension(self.extension_name) + ext = sorting_analyzer.get_extension(extension_class.extension_name) assert ext is not None some_unit_ids = sorting_analyzer.unit_ids[::2] sliced = sorting_analyzer.select_units(some_unit_ids, format="memory") assert np.array_equal(sliced.unit_ids, sorting_analyzer.unit_ids[::2]) - def test_extension(self): + def run_extension_tests(self, extension_class, params): for sparse in (True, False): for format in ("memory", "binary_folder", "zarr"): print("sparse", sparse, format) - sorting_analyzer = self._prepare_sorting_analyzer(format, sparse) - self._check_one(sorting_analyzer) + sorting_analyzer = self._prepare_sorting_analyzer(format, sparse, extension_class) + self._check_one(sorting_analyzer, extension_class, params) diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index ebc35da348..2fec970534 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -1,6 +1,5 @@ -import unittest import numpy as np - +import pytest from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite @@ -8,14 +7,13 @@ class TestAmplitudeScalingsExtension(AnalyzerExtensionCommonTestSuite): - extension_class = ComputeAmplitudeScalings - extension_function_params_list = [ - dict(handle_collisions=True), - dict(handle_collisions=False), - ] + + @pytest.mark.parametrize("params", [dict(handle_collisions=True), dict(handle_collisions=False)]) + def test_extension(self, params): + self.run_extension_tests(ComputeAmplitudeScalings, params) def test_scaling_values(self): - sorting_analyzer = self._prepare_sorting_analyzer("memory", True) + sorting_analyzer = self._prepare_sorting_analyzer("memory", True, ComputeAmplitudeScalings) sorting_analyzer.compute("amplitude_scalings", handle_collisions=False) spikes = sorting_analyzer.sorting.to_spike_vector() @@ -26,7 +24,6 @@ def test_scaling_values(self): mask = spikes["unit_index"] == unit_index scalings = ext.data["amplitude_scalings"][mask] median_scaling = np.median(scalings) - # print(unit_index, median_scaling) np.testing.assert_array_equal(np.round(median_scaling), 1) # import matplotlib.pyplot as plt From 53d3680da8981f85a6d8631e0c0d5edaf325c810 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 19:23:03 +0100 Subject: [PATCH 062/136] move test_correlograms to parameterised method. 
--- .../postprocessing/tests/test_correlograms.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index c7c8e00722..e9bdec827f 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -14,16 +14,21 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeCorrelograms from spikeinterface.postprocessing.correlograms import compute_correlograms_on_sorting, _make_bins +import pytest class TestComputeCorrelograms(AnalyzerExtensionCommonTestSuite): - extension_class = ComputeCorrelograms - extension_function_params_list = [ - dict(method="numpy"), - dict(method="auto"), - ] - if HAVE_NUMBA: - extension_function_params_list.append(dict(method="numba")) + + @pytest.mark.parametrize( + "params", + [ + dict(method="numpy"), + dict(method="auto"), + pytest.param(dict(method="numba"), marks=pytest.mark.skipif("not HAVE_NUMBA")), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeCorrelograms, params) def test_make_bins(): From bcf87d4690eee2b86e4e88e063ad912cc7655223 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 19:32:11 +0100 Subject: [PATCH 063/136] test_isi.py to parameterised method. --- .../postprocessing/tests/test_isi.py | 47 +++++++++++-------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 3110cd7c93..801e5621c3 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -6,7 +6,7 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import compute_isi_histograms, ComputeISIHistograms from spikeinterface.postprocessing.isi import _compute_isi_histograms - +import pytest try: import numba @@ -17,30 +17,37 @@ class TestComputeISIHistograms(AnalyzerExtensionCommonTestSuite): - extension_class = ComputeISIHistograms - extension_function_params_list = [ - dict(method="numpy"), - dict(method="auto"), - ] - if HAVE_NUMBA: - extension_function_params_list.append(dict(method="numba")) + + @pytest.mark.parametrize( + "params", + [ + dict(method="numpy"), + dict(method="auto"), + pytest.param(dict(method="numba"), marks=pytest.mark.skipif("not HAVE_NUMBA")), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeISIHistograms, params) def test_compute_ISI(self): + """ + Requires as list because everything tested against Numpy. + But numpy is not tested against anything. 
+ """ methods = ["numpy", "auto"] if HAVE_NUMBA: methods.append("numba") - _test_ISI(self.sorting, window_ms=60.0, bin_ms=1.0, methods=methods) - _test_ISI(self.sorting, window_ms=43.57, bin_ms=1.6421, methods=methods) - + self._test_ISI(self.sorting, window_ms=60.0, bin_ms=1.0, methods=methods) + self._test_ISI(self.sorting, window_ms=43.57, bin_ms=1.6421, methods=methods) -def _test_ISI(sorting, window_ms: float, bin_ms: float, methods: List[str]): - for method in methods: - ISI, bins = _compute_isi_histograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) + def _test_ISI(self, sorting, window_ms: float, bin_ms: float, methods: List[str]): + for method in methods: + ISI, bins = _compute_isi_histograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) - if method == "numpy": - ref_ISI = ISI - ref_bins = bins - else: - assert np.all(ISI == ref_ISI), f"Failed with method={method}" - assert np.allclose(bins, ref_bins, atol=1e-10), f"Failed with method={method}" + if method == "numpy": + ref_ISI = ISI + ref_bins = bins + else: + assert np.all(ISI == ref_ISI), f"Failed with method={method}" + assert np.allclose(bins, ref_bins, atol=1e-10), f"Failed with method={method}" From 0b3432c260317e20856719cacd34b149b4d94ea5 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 20:01:50 +0100 Subject: [PATCH 064/136] test_principal_component.py to parameterized method. --- .../tests/test_principal_component.py | 171 ++++++++++-------- 1 file changed, 91 insertions(+), 80 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 5d616575c9..8e46a6b672 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -12,17 +12,23 @@ class TestPrincipalComponentsExtension(AnalyzerExtensionCommonTestSuite): - extension_class = ComputePrincipalComponents - extension_function_params_list = [ - dict(mode="by_channel_local"), - dict(mode="by_channel_global"), - # mode concatenated cannot be tested here because it do not work with sparse=True - ] + + @pytest.mark.parametrize( + "params", + [ + dict(mode="by_channel_local"), + dict(mode="by_channel_global"), + # mode concatenated cannot be tested here because it do not work with sparse=True + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputePrincipalComponents, params=params) def test_mode_concatenated(self): # this is tested outside "extension_function_params_list" because it do not support sparsity! 
- - sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=False) + sorting_analyzer = self._prepare_sorting_analyzer( + format="memory", sparse=False, extension_class=ComputePrincipalComponents + ) n_components = 3 sorting_analyzer.compute("principal_components", mode="concatenated", n_components=n_components) @@ -33,93 +39,98 @@ def test_mode_concatenated(self): assert pca.ndim == 2 assert pca.shape[1] == n_components - def test_get_projections(self): - - for sparse in (False, True): - - sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=sparse) - num_chans = sorting_analyzer.get_num_channels() - n_components = 2 - - sorting_analyzer.compute("principal_components", mode="by_channel_global", n_components=n_components) - ext = sorting_analyzer.get_extension("principal_components") - - for unit_id in sorting_analyzer.unit_ids: - if not sparse: - one_proj = ext.get_projections_one_unit(unit_id, sparse=False) - assert one_proj.shape[1] == n_components - assert one_proj.shape[2] == num_chans - else: - one_proj = ext.get_projections_one_unit(unit_id, sparse=False) - assert one_proj.shape[1] == n_components - assert one_proj.shape[2] == num_chans - - one_proj, chan_inds = ext.get_projections_one_unit(unit_id, sparse=True) - assert one_proj.shape[1] == n_components - assert one_proj.shape[2] < num_chans - assert one_proj.shape[2] == chan_inds.size - - some_unit_ids = sorting_analyzer.unit_ids[::2] - some_channel_ids = sorting_analyzer.channel_ids[::2] + @pytest.mark.parametrize("sparse", [True, False]) + def test_get_projections(self, sparse): - random_spikes_indices = sorting_analyzer.get_extension("random_spikes").get_data() + sorting_analyzer = self._prepare_sorting_analyzer( + format="memory", sparse=sparse, extension_class=ComputePrincipalComponents + ) + num_chans = sorting_analyzer.get_num_channels() + n_components = 2 - # this should be all spikes all channels - some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) - assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] == random_spikes_indices.size - assert some_projections.shape[1] == n_components - assert some_projections.shape[2] == num_chans - - # this should be some spikes all channels - some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=some_unit_ids) - assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] < random_spikes_indices.size - assert some_projections.shape[1] == n_components - assert some_projections.shape[2] == num_chans - assert 1 not in spike_unit_index - - # this should be some spikes some channels - some_projections, spike_unit_index = ext.get_some_projections( - channel_ids=some_channel_ids, unit_ids=some_unit_ids - ) - assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] < random_spikes_indices.size - assert some_projections.shape[1] == n_components - assert some_projections.shape[2] == some_channel_ids.size - assert 1 not in spike_unit_index - - def test_compute_for_all_spikes(self): - - for sparse in (True, False): - sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=sparse) + sorting_analyzer.compute("principal_components", mode="by_channel_global", n_components=n_components) + ext = sorting_analyzer.get_extension("principal_components") - num_spikes = sorting_analyzer.sorting.to_spike_vector().size + for unit_id in sorting_analyzer.unit_ids: 
+ if not sparse: + one_proj = ext.get_projections_one_unit(unit_id, sparse=False) + assert one_proj.shape[1] == n_components + assert one_proj.shape[2] == num_chans + else: + one_proj = ext.get_projections_one_unit(unit_id, sparse=False) + assert one_proj.shape[1] == n_components + assert one_proj.shape[2] == num_chans + + one_proj, chan_inds = ext.get_projections_one_unit(unit_id, sparse=True) + assert one_proj.shape[1] == n_components + assert one_proj.shape[2] < num_chans + assert one_proj.shape[2] == chan_inds.size + + some_unit_ids = sorting_analyzer.unit_ids[::2] + some_channel_ids = sorting_analyzer.channel_ids[::2] + + random_spikes_indices = sorting_analyzer.get_extension("random_spikes").get_data() + + # this should be all spikes all channels + some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) + assert some_projections.shape[0] == spike_unit_index.shape[0] + assert spike_unit_index.shape[0] == random_spikes_indices.size + assert some_projections.shape[1] == n_components + assert some_projections.shape[2] == num_chans + + # this should be some spikes all channels + some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=some_unit_ids) + assert some_projections.shape[0] == spike_unit_index.shape[0] + assert spike_unit_index.shape[0] < random_spikes_indices.size + assert some_projections.shape[1] == n_components + assert some_projections.shape[2] == num_chans + assert 1 not in spike_unit_index + + # this should be some spikes some channels + some_projections, spike_unit_index = ext.get_some_projections( + channel_ids=some_channel_ids, unit_ids=some_unit_ids + ) + assert some_projections.shape[0] == spike_unit_index.shape[0] + assert spike_unit_index.shape[0] < random_spikes_indices.size + assert some_projections.shape[1] == n_components + assert some_projections.shape[2] == some_channel_ids.size + assert 1 not in spike_unit_index + + @pytest.mark.parametrize("sparse", [True, False]) + def test_compute_for_all_spikes(self, sparse): + + sorting_analyzer = self._prepare_sorting_analyzer( + format="memory", sparse=sparse, extension_class=ComputePrincipalComponents + ) + + num_spikes = sorting_analyzer.sorting.to_spike_vector().size - n_components = 3 - sorting_analyzer.compute("principal_components", mode="by_channel_local", n_components=n_components) - ext = sorting_analyzer.get_extension("principal_components") + n_components = 3 + sorting_analyzer.compute("principal_components", mode="by_channel_local", n_components=n_components) + ext = sorting_analyzer.get_extension("principal_components") - pc_file1 = cache_folder / "all_pc1.npy" - ext.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) - all_pc1 = np.load(pc_file1) - assert all_pc1.shape[0] == num_spikes + pc_file1 = cache_folder / "all_pc1.npy" + ext.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) + all_pc1 = np.load(pc_file1) + assert all_pc1.shape[0] == num_spikes - pc_file2 = cache_folder / "all_pc2.npy" - ext.run_for_all_spikes(pc_file2, chunk_size=10000, n_jobs=2) - all_pc2 = np.load(pc_file2) + pc_file2 = cache_folder / "all_pc2.npy" + ext.run_for_all_spikes(pc_file2, chunk_size=10000, n_jobs=2) + all_pc2 = np.load(pc_file2) - assert np.array_equal(all_pc1, all_pc2) + assert np.array_equal(all_pc1, all_pc2) def test_project_new(self): - sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=False) + sorting_analyzer = self._prepare_sorting_analyzer( + format="memory", sparse=False, 
extension_class=ComputePrincipalComponents + ) waveforms = sorting_analyzer.get_extension("waveforms").data["waveforms"] n_components = 3 sorting_analyzer.compute("principal_components", mode="by_channel_local", n_components=n_components) - ext_pca = sorting_analyzer.get_extension(self.extension_name) + ext_pca = sorting_analyzer.get_extension(ComputePrincipalComponents.extension_name) num_spike = 100 new_spikes = sorting_analyzer.sorting.to_spike_vector()[:num_spike] From 5d3a66520bbd9bc69cdb919eca714c8aee7a47ef Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 20:02:19 +0100 Subject: [PATCH 065/136] test_spike_amplitudes.py to parametrized methods. --- .../postprocessing/tests/test_spike_amplitudes.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index 3b288c540c..3f29b923cd 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -6,7 +6,6 @@ class TestComputeSpikeAmplitudes(AnalyzerExtensionCommonTestSuite): - extension_class = ComputeSpikeAmplitudes - extension_function_params_list = [ - dict(), - ] + + def test_extension(self): + self.run_extension_tests(ComputeSpikeAmplitudes, params=dict()) From c9fa259b379732c54fe4835808c1a2a95a75898a Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 20:02:42 +0100 Subject: [PATCH 066/136] test_spike_locations.py to parametrized method. --- .../tests/test_spike_locations.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 223d74046f..382b3baf7c 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -3,18 +3,20 @@ from spikeinterface.postprocessing import ComputeSpikeLocations from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite +import pytest class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite): - extension_class = ComputeSpikeLocations - extension_function_params_list = [ - dict( - method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True) - ), # chunk_size=10000, n_jobs=1, - dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)), - dict( - method="center_of_mass", - ), - dict(method="monopolar_triangulation"), # , chunk_size=10000, n_jobs=1 - dict(method="grid_convolution"), # , chunk_size=10000, n_jobs=1 - ] + + @pytest.mark.parametrize( + "params", + [ + dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True)), + dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)), + dict(method="center_of_mass"), + dict(method="monopolar_triangulation"), + dict(method="grid_convolution"), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeSpikeLocations, params) From c88f470e3ff9239a5716f912e6be02d9b695630b Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 20:05:01 +0100 Subject: [PATCH 067/136] test_template_metrics.py a parametrized method. 
--- .../tests/test_template_metrics.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 96b1635b27..f5cf03a5e3 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -3,12 +3,18 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeTemplateMetrics +import pytest class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite): - extension_class = ComputeTemplateMetrics - extension_function_params_list = [ - dict(), - dict(upsampling_factor=2), - dict(include_multi_channel_metrics=True), - ] + + @pytest.mark.parametrize( + "params", + [ + dict(), + dict(upsampling_factor=2), + dict(include_multi_channel_metrics=True), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeTemplateMetrics, params) From 2ca41e9c34039a92275ed252cc970aa76408fa90 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 20:16:34 +0100 Subject: [PATCH 068/136] test_template_similarity.py to parametrized method. --- .../tests/test_template_similarity.py | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 96a9b5f3ee..dbced65237 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -10,27 +10,23 @@ class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite): - extension_class = ComputeTemplateSimilarity - extension_function_params_list = [ - dict(method="cosine_similarity"), - ] + def test_extension(self): + self.run_extension_tests(ComputeTemplateSimilarity, params=dict(method="cosine_similarity")) -def test_check_equal_template_with_distribution_overlap(): + def test_check_equal_template_with_distribution_overlap(self): - recording, sorting = get_dataset() + sorting_analyzer = self._prepare_sorting_analyzer("memory", None, ComputeTemplateSimilarity) + sorting_analyzer.compute("random_spikes") + sorting_analyzer.compute("waveforms") + sorting_analyzer.compute("templates") - sorting_analyzer = get_sorting_analyzer(recording, sorting, sparsity=None) - sorting_analyzer.compute("random_spikes") - sorting_analyzer.compute("waveforms") - sorting_analyzer.compute("templates") + wf_ext = sorting_analyzer.get_extension("waveforms") - wf_ext = sorting_analyzer.get_extension("waveforms") - - for unit_id0 in sorting_analyzer.unit_ids: - waveforms0 = wf_ext.get_waveforms_one_unit(unit_id0) - for unit_id1 in sorting_analyzer.unit_ids: - if unit_id0 == unit_id1: - continue - waveforms1 = wf_ext.get_waveforms_one_unit(unit_id1) - check_equal_template_with_distribution_overlap(waveforms0, waveforms1) + for unit_id0 in sorting_analyzer.unit_ids: + waveforms0 = wf_ext.get_waveforms_one_unit(unit_id0) + for unit_id1 in sorting_analyzer.unit_ids: + if unit_id0 == unit_id1: + continue + waveforms1 = wf_ext.get_waveforms_one_unit(unit_id1) + check_equal_template_with_distribution_overlap(waveforms0, waveforms1) From fbe04043e70c41b03ded5028b8cfd6259fe4f573 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 20:18:11 +0100 Subject: [PATCH 069/136] 
test_unit_localization.py to parameterized method. --- .../tests/test_unit_localization.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_unit_localization.py b/src/spikeinterface/postprocessing/tests/test_unit_localization.py index 1546a22056..6fd589e2f4 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_localization.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_localization.py @@ -1,14 +1,20 @@ import unittest from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeUnitLocations +import pytest class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite): - extension_class = ComputeUnitLocations - extension_function_params_list = [ - dict(method="center_of_mass", radius_um=100), - dict(method="grid_convolution", radius_um=50), - dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}), - dict(method="monopolar_triangulation", radius_um=150), - dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), - ] + + @pytest.mark.parametrize( + "params", + [ + dict(method="center_of_mass", radius_um=100), + dict(method="grid_convolution", radius_um=50), + dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}), + dict(method="monopolar_triangulation", radius_um=150), + dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeUnitLocations, params=params) From 3ae8479e9d32bcc089343942d9d5d8fa44107c85 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 20:24:09 +0100 Subject: [PATCH 070/136] Tidy up imports. 
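
Unused imports of this kind can be flagged automatically with flake8's F401
check ("module imported but unused"); an illustrative invocation, assuming
flake8 is installed:

    import subprocess

    # report unused imports across the postprocessing tests
    subprocess.run(
        ["flake8", "--select=F401", "src/spikeinterface/postprocessing/tests"],
        check=False,
    )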
--- .../postprocessing/tests/test_correlograms.py | 4 ---- src/spikeinterface/postprocessing/tests/test_isi.py | 3 +-- .../postprocessing/tests/test_principal_component.py | 9 +-------- .../postprocessing/tests/test_spike_amplitudes.py | 3 --- .../postprocessing/tests/test_spike_locations.py | 3 --- .../postprocessing/tests/test_template_metrics.py | 3 --- .../postprocessing/tests/test_template_similarity.py | 4 ---- .../postprocessing/tests/test_unit_localization.py | 1 - 8 files changed, 2 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index e9bdec827f..49d083cedf 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -1,6 +1,4 @@ -import unittest import numpy as np -from typing import List try: import numba @@ -38,13 +36,11 @@ def test_make_bins(): bin_ms = 1.6421 bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms) assert bins.size == np.floor(window_ms / bin_ms) + 1 - # print(bins, window_size, bin_size) window_ms = 60.0 bin_ms = 2.0 bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms) assert bins.size == np.floor(window_ms / bin_ms) + 1 - # print(bins, window_size, bin_size) def _test_correlograms(sorting, window_ms, bin_ms, methods): diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 801e5621c3..3c7a1e1463 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -1,10 +1,9 @@ -import unittest import numpy as np from typing import List from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -from spikeinterface.postprocessing import compute_isi_histograms, ComputeISIHistograms +from spikeinterface.postprocessing import ComputeISIHistograms from spikeinterface.postprocessing.isi import _compute_isi_histograms import pytest diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 8e46a6b672..1d150d73da 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -1,16 +1,9 @@ -import unittest import pytest -from pathlib import Path - import numpy as np - -from spikeinterface.postprocessing import ComputePrincipalComponents, compute_principal_components +from spikeinterface.postprocessing import ComputePrincipalComponents from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite, cache_folder -DEBUG = False - - class TestPrincipalComponentsExtension(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index 3f29b923cd..a68483a1b2 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -1,6 +1,3 @@ -import unittest -import numpy as np - from spikeinterface.postprocessing import ComputeSpikeAmplitudes from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py 
b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 382b3baf7c..46a39d23ea 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -1,6 +1,3 @@ -import unittest -import numpy as np - from spikeinterface.postprocessing import ComputeSpikeLocations from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite import pytest diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index f5cf03a5e3..694aa083cc 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,6 +1,3 @@ -import unittest - - from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeTemplateMetrics import pytest diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index dbced65237..a0f57bf3c5 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -1,9 +1,5 @@ -import unittest - from spikeinterface.postprocessing.tests.common_extension_tests import ( AnalyzerExtensionCommonTestSuite, - get_sorting_analyzer, - get_dataset, ) from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity diff --git a/src/spikeinterface/postprocessing/tests/test_unit_localization.py b/src/spikeinterface/postprocessing/tests/test_unit_localization.py index 6fd589e2f4..c40a917a2b 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_localization.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_localization.py @@ -1,4 +1,3 @@ -import unittest from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeUnitLocations import pytest From 43a76f831e3808dec87dfd339154ec0f29f0fde5 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 21:54:35 +0100 Subject: [PATCH 071/136] Remove missed __main__. 
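
With the guard gone, running the file directly no longer executes the test;
it is collected by pytest instead. A minimal sketch of the equivalent
invocation, assuming pytest is available:

    import pytest

    # collect and run this module, as the removed __main__ block used to do
    pytest.main(["src/spikeinterface/postprocessing/tests/test_align_sorting.py"])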
---
 src/spikeinterface/postprocessing/tests/test_align_sorting.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/spikeinterface/postprocessing/tests/test_align_sorting.py b/src/spikeinterface/postprocessing/tests/test_align_sorting.py
index e5c70ae4b2..fbb54035bb 100644
--- a/src/spikeinterface/postprocessing/tests/test_align_sorting.py
+++ b/src/spikeinterface/postprocessing/tests/test_align_sorting.py
@@ -40,7 +40,3 @@ def test_align_sorting():
         st = sorting.get_unit_spike_train(unit_id)
         st_clean = sorting_aligned.get_unit_spike_train(unit_id)
         assert np.array_equal(st, st_clean)
-
-
-if __name__ == "__main__":
-    test_align_sorting()

From 6bbc6274783f8b11ece6c2d342f0a0cc2651785f Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Wed, 5 Jun 2024 22:15:07 +0100
Subject: [PATCH 072/136] Extend docstrings for common_extension_tests.py

---
 .../tests/common_extension_tests.py | 48 +++++++++++++++----
 1 file changed, 39 insertions(+), 9 deletions(-)

diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
index 759a15e772..6b445aa746 100644
--- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py
+++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
@@ -61,13 +61,27 @@ def get_sorting_analyzer(recording, sorting, format="memory", sparsity=None, nam

 class AnalyzerExtensionCommonTestSuite:
     """
-    Common tests with class approach to compute extension on several cases (3 format x 2 sparsity)
-      - This is done a a list of differents parameters (extension_function_params_list).
-      - This automatically precompute extension dependencies with default params before running computation.
-      - This also test the select_units() ability.
+    Common tests with class approach to compute extension on several cases,
+    format ("memory", "binary_folder", "zarr") and sparsity (True, False).
+    Extensions refer to the extension classes that handle the postprocessing,
+    for example extracting principal components or amplitude scalings.
+
+    This base class provides a fixture which sets a recording
+    and sorting object onto itself, which are set up once each time
+    the base class is subclassed in a test environment. The recording
+    and sorting object are used in the creation of the `sorting_analyzer`
+    object used to run postprocessing routines.
+
+    When subclassed, a test function that parametrises the arguments
+    that are passed to `sorting_analyzer.compute()` can be set up.
+    This must call `run_extension_tests()` which sets up a `sorting_analyzer`
+    with the relevant format and sparsity. This also automatically precomputes
+    extension dependencies with default params. Then, `check_one()` is called
+    which runs the compute function with the passed params and tests that:
+
+    1) the returned extension object has data on it
+    2) `sorting_analyzer.get_extension()` does not return None
+    3) the correct units are sliced with the `select_units()` function.
     """

     @pytest.fixture(autouse=True, scope="class")
@@ -90,20 +104,31 @@ class instance is used for each. In this case, we have to set
         )

     def _prepare_sorting_analyzer(self, format, sparse, extension_class):
-        """prepare a SortingAnalyzer object with depencies already computed"""
+        """
+        Prepare a SortingAnalyzer object with dependencies already computed
+        according to format (e.g. "memory", "binary_folder", "zarr")
+        and sparsity (e.g. True, False).
+ """ sparsity_ = self.sparsity if sparse else None + sorting_analyzer = get_sorting_analyzer( self.recording, self.sorting, format=format, sparsity=sparsity_, name=extension_class.extension_name ) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=50, seed=2205) + for dependency_name in extension_class.depend_on: if "|" in dependency_name: dependency_name = dependency_name.split("|")[0] sorting_analyzer.compute(dependency_name) + return sorting_analyzer def _check_one(self, sorting_analyzer, extension_class, params): - """""" + """ + Take a prepared sorting analyzer object, compute the extension of interest + with the passed parameters, and check the output is not empty, the extension + exists and `select_units()` method works. + """ if extension_class.need_job_kwargs: job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) else: @@ -121,6 +146,11 @@ def _check_one(self, sorting_analyzer, extension_class, params): assert np.array_equal(sliced.unit_ids, sorting_analyzer.unit_ids[::2]) def run_extension_tests(self, extension_class, params): + """ + Convenience function to perform all checks on the extension + of interest with the passed parameters. Will perform tests + for sparsity and format. + """ for sparse in (True, False): for format in ("memory", "binary_folder", "zarr"): print("sparse", sparse, format) From 8462a0782af255f9ec24c5b5c0689edb9e7282b7 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 22:15:36 +0100 Subject: [PATCH 073/136] Extend docstrings and refactor 'test_align_sorting()'. --- .../tests/test_align_sorting.py | 43 +++++++++++++------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_align_sorting.py b/src/spikeinterface/postprocessing/tests/test_align_sorting.py index fbb54035bb..3a7befe019 100644 --- a/src/spikeinterface/postprocessing/tests/test_align_sorting.py +++ b/src/spikeinterface/postprocessing/tests/test_align_sorting.py @@ -1,5 +1,3 @@ -import pytest -import shutil from pathlib import Path import pytest @@ -18,8 +16,16 @@ def test_align_sorting(): + """ + `align_sorting()` shifts, in time, the spikes belonging to a unit. + For each unit, an offset is provided and the spike peak index is shifted. + + This test creates a sorting object, then creates an 'unaligned' sorting + object in which the peaks for some of the units are shifted. Next, the `align_sorting()` + function is unused to unshift them, and the original sorting spike train + peak times compared with the corrected sorting train. 
+ """ sorting = generate_sorting(durations=[10.0], seed=0) - print(sorting) unit_ids = sorting.unit_ids @@ -27,16 +33,27 @@ def test_align_sorting(): unit_peak_shifts[unit_ids[-1]] = 5 unit_peak_shifts[unit_ids[-2]] = -5 - # sorting to dict - d = {unit_id: sorting.get_unit_spike_train(unit_id) + unit_peak_shifts[unit_id] for unit_id in sorting.unit_ids} - sorting_unaligned = NumpySorting.from_unit_dict(d, sampling_frequency=sorting.get_sampling_frequency()) - print(sorting_unaligned) + shifted_unit_dict = { + unit_id: sorting.get_unit_spike_train(unit_id) + unit_peak_shifts[unit_id] for unit_id in sorting.unit_ids + } + sorting_unaligned = NumpySorting.from_unit_dict( + shifted_unit_dict, sampling_frequency=sorting.get_sampling_frequency() + ) sorting_aligned = align_sorting(sorting_unaligned, unit_peak_shifts) - print(sorting_aligned) - for start_frame, end_frame in [(None, None), (10000, 50000)]: - for unit_id in unit_ids[-2:]: - st = sorting.get_unit_spike_train(unit_id) - st_clean = sorting_aligned.get_unit_spike_train(unit_id) - assert np.array_equal(st, st_clean) + for unit_id in unit_ids: + spiketrain_orig = sorting.get_unit_spike_train(unit_id) + spiketrain_aligned = sorting_aligned.get_unit_spike_train(unit_id) + spiketrain_unaligned = sorting_unaligned.get_unit_spike_train(unit_id) + + # check the shift induced in the test has changed the + # spiketrain as expected. + if unit_peak_shifts[unit_id] == 0: + assert np.array_equal(spiketrain_orig, spiketrain_unaligned) + else: + assert not np.array_equal(spiketrain_orig, spiketrain_unaligned) + + # Perform the key test, that after correction the spiketrain + # matches the original spiketrain for all units (shifted and unshifted). + assert np.array_equal(spiketrain_orig, spiketrain_aligned) From 9a9cc62326e073827247ff264a010e138a19236e Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 22:42:01 +0100 Subject: [PATCH 074/136] Add docstrings to test_amplitude_scalings.py and test_isi.py --- .../postprocessing/tests/test_amplitude_scalings.py | 12 +++++++++++- src/spikeinterface/postprocessing/tests/test_isi.py | 6 ++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index 2fec970534..6ea6b436bf 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -13,7 +13,17 @@ def test_extension(self, params): self.run_extension_tests(ComputeAmplitudeScalings, params) def test_scaling_values(self): - sorting_analyzer = self._prepare_sorting_analyzer("memory", True, ComputeAmplitudeScalings) + """ + Amplitude finds the scaling factor for each waveform + to best match its unit template. In this test, amplitude scalings + are calculated from the `sorting_analyzer`. In the test environment, + injected waveforms are not scaled from the template and so + should only differ by Gaussian noise. Therefore the median + scaling should be close to 1. 
+ """ + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", sparse=True, extension_class=ComputeAmplitudeScalings + ) sorting_analyzer.compute("amplitude_scalings", handle_collisions=False) spikes = sorting_analyzer.sorting.to_spike_vector() diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 3c7a1e1463..0f9ecb3d7d 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -30,8 +30,10 @@ def test_extension(self, params): def test_compute_ISI(self): """ - Requires as list because everything tested against Numpy. - But numpy is not tested against anything. + This test checks the creation of ISI histograms matches across + "numpy", "auto" and "numba" methods. Does not parameterize as requires + as list because everything tested against Numpy. The Numpy result is not + explicitly tested. """ methods = ["numpy", "auto"] if HAVE_NUMBA: From 063a6e633450d8a740658c87a99971070d10dfc2 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 5 Jun 2024 22:47:56 +0100 Subject: [PATCH 075/136] Add assert and docstring to 'test_template_similarity.py' --- .../postprocessing/tests/test_template_similarity.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index a0f57bf3c5..a4de2a3a90 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -11,7 +11,12 @@ def test_extension(self): self.run_extension_tests(ComputeTemplateSimilarity, params=dict(method="cosine_similarity")) def test_check_equal_template_with_distribution_overlap(self): - + """ + Create a sorting object, extract its waveforms. Compare waveforms + from all pairs of units (excluding a unit against itself) + and check `check_equal_template_with_distribution_overlap()` + correctly determines they are different. + """ sorting_analyzer = self._prepare_sorting_analyzer("memory", None, ComputeTemplateSimilarity) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms") @@ -25,4 +30,5 @@ def test_check_equal_template_with_distribution_overlap(self): if unit_id0 == unit_id1: continue waveforms1 = wf_ext.get_waveforms_one_unit(unit_id1) - check_equal_template_with_distribution_overlap(waveforms0, waveforms1) + + assert not check_equal_template_with_distribution_overlap(waveforms0, waveforms1) From 9d296b49a6bf5576750eb8718c363c80bf786c16 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 6 Jun 2024 00:52:08 +0100 Subject: [PATCH 076/136] Add docstring and some small improvement to test_principal_component.py --- .../tests/test_principal_component.py | 47 ++++++++++++++++--- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 1d150d73da..79166e5400 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -18,7 +18,12 @@ def test_extension(self, params): self.run_extension_tests(ComputePrincipalComponents, params=params) def test_mode_concatenated(self): - # this is tested outside "extension_function_params_list" because it do not support sparsity! 
+ """ + Replicate the "extension_function_params_list" test outside of + AnalyzerExtensionCommonTestSuite because it does not support sparsity. + + Also, add two additional checks on the dimension and n components of the output. + """ sorting_analyzer = self._prepare_sorting_analyzer( format="memory", sparse=False, extension_class=ComputePrincipalComponents ) @@ -34,7 +39,13 @@ def test_mode_concatenated(self): @pytest.mark.parametrize("sparse", [True, False]) def test_get_projections(self, sparse): - + """ + Test the shape of output projection score matrices are + correct when adjusting sparsity and using the + `get_some_projections()` function. We expect them + to hold, for each spike and each channel, the loading + for each of the specified number of components. + """ sorting_analyzer = self._prepare_sorting_analyzer( format="memory", sparse=sparse, extension_class=ComputePrincipalComponents ) @@ -44,6 +55,8 @@ def test_get_projections(self, sparse): sorting_analyzer.compute("principal_components", mode="by_channel_global", n_components=n_components) ext = sorting_analyzer.get_extension("principal_components") + # First, check the created projections have the expected number + # of components and the expected number of channels based on sparsity. for unit_id in sorting_analyzer.unit_ids: if not sparse: one_proj = ext.get_projections_one_unit(unit_id, sparse=False) @@ -56,13 +69,19 @@ def test_get_projections(self, sparse): one_proj, chan_inds = ext.get_projections_one_unit(unit_id, sparse=True) assert one_proj.shape[1] == n_components - assert one_proj.shape[2] < num_chans + num_channels_for_unit = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id].size + assert one_proj.shape[2] == num_channels_for_unit assert one_proj.shape[2] == chan_inds.size + # Next, check that the `get_some_projections()` function returns + # projections with the expected shapes when selecting subjsets + # of channel and unit IDs. 
        some_unit_ids = sorting_analyzer.unit_ids[::2]
        some_channel_ids = sorting_analyzer.channel_ids[::2]

        random_spikes_indices = sorting_analyzer.get_extension("random_spikes").get_data()
+        all_num_spikes = sorting_analyzer.sorting.get_total_num_spikes()
+        unit_ids_num_spikes = np.sum(all_num_spikes[unit_id] for unit_id in some_unit_ids)

        # this should be all spikes all channels
        some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None)
        assert some_projections.shape[0] == spike_unit_index.shape[0]
@@ -74,7 +93,7 @@ def test_get_projections(self, sparse):
        # this should be some spikes all channels
        some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=some_unit_ids)
        assert some_projections.shape[0] == spike_unit_index.shape[0]
-        assert spike_unit_index.shape[0] < random_spikes_indices.size
+        assert spike_unit_index.shape[0] == unit_ids_num_spikes
        assert some_projections.shape[1] == n_components
        assert some_projections.shape[2] == num_chans
        assert 1 not in spike_unit_index
@@ -84,14 +103,19 @@ def test_get_projections(self, sparse):
            channel_ids=some_channel_ids, unit_ids=some_unit_ids
        )
        assert some_projections.shape[0] == spike_unit_index.shape[0]
-        assert spike_unit_index.shape[0] < random_spikes_indices.size
+        assert spike_unit_index.shape[0] == unit_ids_num_spikes
        assert some_projections.shape[1] == n_components
        assert some_projections.shape[2] == some_channel_ids.size
        assert 1 not in spike_unit_index

    @pytest.mark.parametrize("sparse", [True, False])
    def test_compute_for_all_spikes(self, sparse):
-
+        """
+        Compute the principal component scores, checking that the shape
+        matches the number of spikes as expected. This is re-run
+        with n_jobs=2 and the output projection score matrices are
+        checked against the n_jobs=1 result.
+        """
        sorting_analyzer = self._prepare_sorting_analyzer(
            format="memory", sparse=sparse, extension_class=ComputePrincipalComponents
        )
@@ -114,7 +138,16 @@ def test_compute_for_all_spikes(self, sparse):
        assert np.array_equal(all_pc1, all_pc2)

    def test_project_new(self):
-
+        """
+        `project_new` projects new (unseen) waveforms onto the PCA components.
+        First compute principal components from existing waveforms. Then,
+        generate a new 'spikes' vector that includes sample_index, unit_index
+        and segment_index alongside some waveforms (the spike vector is required
+        to generate some corresponding unit IDs for the generated waveforms following
+        the API of principal_components.py).
+
+        Finally, check that the new projection scores matrix has the expected shape.
+        """
        sorting_analyzer = self._prepare_sorting_analyzer(
            format="memory", sparse=False, extension_class=ComputePrincipalComponents
        )

From 75ce4d4783440aaaced7c5ced785cd9b9f50fe66 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Thu, 6 Jun 2024 01:14:51 +0100
Subject: [PATCH 077/136] Add docstring to test_correlograms.py

---
 .../postprocessing/tests/test_correlograms.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py
index 49d083cedf..f3d7617512 100644
--- a/src/spikeinterface/postprocessing/tests/test_correlograms.py
+++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py
@@ -30,6 +30,10 @@ def test_extension(self, params):

 def test_make_bins():
+    """
+    Check that the `_make_bins()` function, which generates the time bins
+    (lags) for the correlogram, creates the expected number of bins.
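+    Here the expected size is floor(window_ms / bin_ms) + 1, as asserted
+    against the `_make_bins()` output below.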
+ """ sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) window_ms = 43.57 @@ -79,6 +83,10 @@ def test_equal_results_correlograms(): def test_flat_cross_correlogram(): + """ + Check that the correlogram (num_units x num_units x num_bins) does not + vary too much across time bins (lags), for entries representing two different units. + """ sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0], seed=0) methods = ["numpy"] @@ -150,6 +158,11 @@ def test_auto_equal_cross_correlograms(): def test_detect_injected_correlation(): + """ + Inject 1.44 ms of correlation every 13 spikes and compute + cross-correlation. Check that the time bin lag with the peak + correlation lag is 1.44 ms (within tolerance of a sampling period). + """ methods = ["numpy"] if HAVE_NUMBA: methods.append("numba") From 74169b0b7b257460c1e5af81d59d699648b732a2 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 6 Jun 2024 01:17:50 +0100 Subject: [PATCH 078/136] Remove unused / commented debugging code. --- .../tests/test_amplitude_scalings.py | 5 ---- .../postprocessing/tests/test_correlograms.py | 27 ------------------- .../tests/test_principal_component.py | 7 ----- 3 files changed, 39 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index 6ea6b436bf..0868f5238e 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -35,8 +35,3 @@ def test_scaling_values(self): scalings = ext.data["amplitude_scalings"][mask] median_scaling = np.median(scalings) np.testing.assert_array_equal(np.round(median_scaling), 1) - - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # ax.hist(ext.data["amplitude_scalings"]) - # plt.show() diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index f3d7617512..56eac6ef9a 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -93,9 +93,6 @@ def test_flat_cross_correlogram(): if HAVE_NUMBA: methods.append("numba") - # ~ import matplotlib.pyplot as plt - # ~ fig, ax = plt.subplots() - for method in methods: correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=50.0, bin_ms=1.0, method=method) cc = correlograms[0, 1, :].copy() @@ -103,11 +100,6 @@ def test_flat_cross_correlogram(): assert np.all(cc > (m * 0.90)) assert np.all(cc < (m * 1.10)) - # ~ ax.plot(bins[:-1], cc, label=method) - # ~ ax.legend() - # ~ ax.set_ylim(0, np.max(correlograms) * 1.1) - # ~ plt.show() - def test_auto_equal_cross_correlograms(): """ @@ -146,16 +138,6 @@ def test_auto_equal_cross_correlograms(): else: assert np.array_equal(cc_corrected, ac) - # ~ import matplotlib.pyplot as plt - # ~ fig, ax = plt.subplots() - # ~ ax.plot(bins[:-1], cc, marker='*', color='red', label='cross-corr') - # ~ ax.plot(bins[:-1], cc_corrected, marker='*', color='orange', label='cross-corr corrected') - # ~ ax.plot(bins[:-1], ac, marker='*', color='green', label='auto-corr') - # ~ ax.set_title(method) - # ~ ax.legend() - # ~ ax.set_ylim(0, np.max(correlograms) * 1.1) - # ~ plt.show() - def test_detect_injected_correlation(): """ @@ -195,12 +177,3 @@ def test_detect_injected_correlation(): sampling_period_ms = 1000.0 / sampling_frequency assert 
abs(peak_location_01_ms) - injected_delta_ms < sampling_period_ms assert abs(peak_location_02_ms) - injected_delta_ms < sampling_period_ms - - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # half_bin_ms = np.mean(np.diff(bins)) / 2. - # ax.plot(bins[:-1]+half_bin_ms, cc_01, marker='*', color='red', label='cross-corr 0>1') - # ax.plot(bins[:-1]+half_bin_ms, cc_10, marker='*', color='orange', label='cross-corr 1>0') - # ax.set_title(method) - # ax.legend() - # plt.show() diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 79166e5400..ebfc781cd9 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -166,10 +166,3 @@ def test_project_new(self): assert new_proj.shape[0] == num_spike assert new_proj.shape[1] == n_components assert new_proj.shape[2] == ext_pca.data["pca_projection"].shape[2] - - # ext = test.sorting_analyzers["sparseTrue_memory"].get_extension("principal_components") - # pca = ext.data["pca_projection"] - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # ax.scatter(pca[:, 0, 0], pca[:, 0, 1]) - # plt.show() From e1c890a8914513b65e2357253419ba563a134de5 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 6 Jun 2024 12:07:36 +0100 Subject: [PATCH 079/136] Update src/spikeinterface/postprocessing/tests/common_extension_tests.py Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- .../postprocessing/tests/common_extension_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 6b445aa746..6cca483e29 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -136,7 +136,7 @@ def _check_one(self, sorting_analyzer, extension_class, params): ext = sorting_analyzer.compute(extension_class.extension_name, **params, **job_kwargs) assert len(ext.data) > 0 - main_data = ext.get_data() + assert len(main_data) > 0 ext = sorting_analyzer.get_extension(extension_class.extension_name) assert ext is not None From 6277f9f8229d707227313ef8c7d65c001cf14e43 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Thu, 6 Jun 2024 18:09:44 +0100 Subject: [PATCH 080/136] Reformat pytest param exclusion for test_correlograms Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- src/spikeinterface/postprocessing/tests/test_correlograms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 56eac6ef9a..d0b9a1e28b 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -22,7 +22,7 @@ class TestComputeCorrelograms(AnalyzerExtensionCommonTestSuite): [ dict(method="numpy"), dict(method="auto"), - pytest.param(dict(method="numba"), marks=pytest.mark.skipif("not HAVE_NUMBA")), + pytest.param(dict(method="numba"), marks=pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")), ], ) def test_extension(self, params): From 
e599cf66b2d12f77d3cbca64a072d7b66abfa9d9 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 6 Jun 2024 18:10:53 +0100 Subject: [PATCH 081/136] Reformat pytest param exclusion for test_isi. --- src/spikeinterface/postprocessing/tests/test_isi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 0f9ecb3d7d..444e837cb4 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -22,7 +22,7 @@ class TestComputeISIHistograms(AnalyzerExtensionCommonTestSuite): [ dict(method="numpy"), dict(method="auto"), - pytest.param(dict(method="numba"), marks=pytest.mark.skipif("not HAVE_NUMBA")), + pytest.param(dict(method="numba"), marks=pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")), ], ) def test_extension(self, params): From e8a006e08b1adc0041340cf707be945bf633f1e1 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 6 Jun 2024 18:21:26 +0100 Subject: [PATCH 082/136] Remove commented code in '_test_correlograms()' --- .../postprocessing/tests/test_correlograms.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index d0b9a1e28b..eef4af10fc 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -51,22 +51,8 @@ def _test_correlograms(sorting, window_ms, bin_ms, methods): for method in methods: correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) if method == "numpy": - ref_correlograms = correlograms ref_bins = bins else: - # ~ import matplotlib.pyplot as plt - # ~ for i in range(ref_correlograms.shape[1]): - # ~ for j in range(ref_correlograms.shape[1]): - # ~ fig, ax = plt.subplots() - # ~ ax.plot(bins[:-1], ref_correlograms[i, j, :], color='green', label='numpy') - # ~ ax.plot(bins[:-1], correlograms[i, j, :], color='red', label=method) - # ~ ax.legend() - # ~ ax.set_title(f'{i} {j}') - # ~ plt.show() - - # numba and numyp do not have exactly the same output - # assert np.all(correlograms == ref_correlograms), f"Failed with method={method}" - assert np.allclose(bins, ref_bins, atol=1e-10), f"Failed with method={method}" From 4e617bf6c5b607d154ef5aa6853eb16a51b38686 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 6 Jun 2024 22:39:39 +0200 Subject: [PATCH 083/136] Moving unit_locations --- .../{test_unit_localization.py => test_unit_locations.py} | 0 .../{unit_localization.py => unit_locations.py} | 0 .../benchmark/tests/test_benchmark_peak_localization.py | 6 +++--- 3 files changed, 3 insertions(+), 3 deletions(-) rename src/spikeinterface/postprocessing/tests/{test_unit_localization.py => test_unit_locations.py} (100%) rename src/spikeinterface/postprocessing/{unit_localization.py => unit_locations.py} (100%) diff --git a/src/spikeinterface/postprocessing/tests/test_unit_localization.py b/src/spikeinterface/postprocessing/tests/test_unit_locations.py similarity index 100% rename from src/spikeinterface/postprocessing/tests/test_unit_localization.py rename to src/spikeinterface/postprocessing/tests/test_unit_locations.py diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_locations.py similarity index 100% rename from 
src/spikeinterface/postprocessing/unit_localization.py rename to src/spikeinterface/postprocessing/unit_locations.py diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py index b6f89dcd36..23060c4ddb 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py @@ -56,14 +56,14 @@ def test_benchmark_peak_localization(create_cache_folder): @pytest.mark.skip() -def test_benchmark_unit_localization(create_cache_folder): +def test_benchmark_unit_locations(create_cache_folder): cache_folder = create_cache_folder job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") recording, gt_sorting = make_dataset() # create study - study_folder = cache_folder / "study_unit_localization" + study_folder = cache_folder / "study_unit_locations" datasets = {"toy": (recording, gt_sorting)} cases = {} for method in ["center_of_mass", "grid_convolution", "monopolar_triangulation"]: @@ -100,4 +100,4 @@ def test_benchmark_unit_localization(create_cache_folder): if __name__ == "__main__": # test_benchmark_peak_localization() - test_benchmark_unit_localization() + test_benchmark_unit_locations() From 8a24b80024b16bf03177caabf9cea1417bcc3dec Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 6 Jun 2024 22:48:29 +0200 Subject: [PATCH 084/136] Messing with git --- src/spikeinterface/exporters/report.py | 2 +- src/spikeinterface/postprocessing/__init__.py | 2 +- .../benchmark/benchmark_peak_localization.py | 2 +- src/spikeinterface/sortingcomponents/peak_detection.py | 2 +- src/spikeinterface/sortingcomponents/peak_localization.py | 8 ++++---- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index e12bb9b588..3a4be9213a 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -111,7 +111,7 @@ def export_report( # global figures fig = plt.figure(figsize=(20, 10)) w = sw.plot_unit_locations(sorting_analyzer, figure=fig, unit_colors=unit_colors) - fig.savefig(output_folder / f"unit_localization.{format}") + fig.savefig(output_folder / f"unit_locations.{format}") if not show_figures: plt.close(fig) diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 528f2d3761..ae071a55e0 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -37,7 +37,7 @@ from .spike_locations import compute_spike_locations, ComputeSpikeLocations -from .unit_localization import ( +from .unit_locations import ( compute_unit_locations, ComputeUnitLocations, compute_center_of_mass, diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index 5c4085af7c..3eda5db3b6 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -1,6 +1,6 @@ from __future__ import annotations -from spikeinterface.postprocessing.unit_localization import ( +from spikeinterface.postprocessing.unit_locations import ( compute_center_of_mass, compute_monopolar_triangulation, compute_grid_convolution, diff --git 
a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index d23f0fec74..11218a688f 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -23,7 +23,7 @@ base_peak_dtype, ) -from spikeinterface.postprocessing.unit_localization import get_convolution_weights +from spikeinterface.postprocessing.unit_locations import get_convolution_weights from ..core import get_chunk_with_margin from .tools import make_multi_method_doc diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index b06f6fac3e..fcae485af9 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -21,7 +21,7 @@ from spikeinterface.core import get_channel_distances -from ..postprocessing.unit_localization import ( +from ..postprocessing.unit_locations import ( dtype_localize_by_method, possible_localization_methods, solve_monopolar_triangulation, @@ -163,7 +163,7 @@ class LocalizeCenterOfMass(LocalizeBase): Notes ----- - See spikeinterface.postprocessing.unit_localization. + See spikeinterface.postprocessing.unit_locations. """ need_waveforms = True @@ -225,7 +225,7 @@ class LocalizeMonopolarTriangulation(PipelineNode): Notes ----- This method is from Julien Boussard, Erdem Varol and Charlie Windolf - See spikeinterface.postprocessing.unit_localization. + See spikeinterface.postprocessing.unit_locations. """ need_waveforms = False @@ -316,7 +316,7 @@ class LocalizeGridConvolution(PipelineNode): Notes ----- - See spikeinterface.postprocessing.unit_localization. + See spikeinterface.postprocessing.unit_locations. 
""" need_waveforms = True From 9019a51369cbcdc4d0684271b6b525715e1a9986 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jun 2024 09:33:02 +0200 Subject: [PATCH 085/136] Motion object in benchmark motionestimation --- .../benchmark/benchmark_motion_estimation.py | 152 +++++++++--------- .../benchmark/benchmark_tools.py | 6 + .../sortingcomponents/motion_utils.py | 8 + 3 files changed, 92 insertions(+), 74 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 7428629c4a..96b277de6e 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -15,6 +15,8 @@ from spikeinterface.sortingcomponents.peak_localization import localize_peaks from spikeinterface.widgets import plot_probe_map +from spikeinterface.sortingcomponents.motion_utils import Motion + # import MEArec as mr # TODO : plot_peaks @@ -28,8 +30,8 @@ def get_gt_motion_from_unit_displacement( unit_displacements, displacement_sampling_frequency, unit_locations, - temporal_bins, - spatial_bins, + temporal_bins_s, + spatial_bins_um, direction_dim=1, ): import scipy.interpolate @@ -37,21 +39,29 @@ def get_gt_motion_from_unit_displacement( unit_displacements = unit_displacements[:, :, direction_dim] times = np.arange(unit_displacements.shape[0]) / displacement_sampling_frequency f = scipy.interpolate.interp1d(times, unit_displacements, axis=0) - unit_displacements = f(temporal_bins) + unit_displacements = f(temporal_bins_s) # spatial interpolataion of units discplacement - if spatial_bins.shape[0] == 1: + if spatial_bins_um.shape[0] == 1: # rigid - gt_motion = np.mean(unit_displacements, axis=1)[:, None] + gt_displacement = np.mean(unit_displacements, axis=1)[:, None] else: # non rigid - gt_motion = np.zeros((temporal_bins.size, spatial_bins.size)) - for t in range(temporal_bins.shape[0]): + gt_displacement = np.zeros((temporal_bins_s.size, spatial_bins_um.size)) + for t in range(temporal_bins_s.shape[0]): f = scipy.interpolate.interp1d( unit_locations[:, direction_dim], unit_displacements[t, :], fill_value="extrapolate" ) - gt_motion[t, :] = f(spatial_bins) - + gt_displacement[t, :] = f(spatial_bins_um) + + gt_motion = Motion( + gt_displacement, + temporal_bins_s, + spatial_bins_um, + direction="xyz"[direction_dim], + interpolation_method="linear" + ) + return gt_motion @@ -92,7 +102,7 @@ def run(self, **job_kwargs): t2 = time.perf_counter() peak_locations = localize_peaks(self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs) t3 = time.perf_counter() - motion, temporal_bins, spatial_bins = estimate_motion( + motion = estimate_motion( self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"] ) t4 = time.perf_counter() @@ -106,43 +116,37 @@ def run(self, **job_kwargs): self.result["step_run_times"] = step_run_times self.result["raw_motion"] = motion - self.result["temporal_bins"] = temporal_bins - self.result["spatial_bins"] = spatial_bins def compute_result(self, **result_params): raw_motion = self.result["raw_motion"] - temporal_bins = self.result["temporal_bins"] - spatial_bins = self.result["spatial_bins"] gt_motion = get_gt_motion_from_unit_displacement( self.unit_displacements, self.displacement_sampling_frequency, self.unit_locations, - temporal_bins, - spatial_bins, + raw_motion.temporal_bins_s[0], + raw_motion.spatial_bins_um, 
direction_dim=self.direction_dim, ) # align globally gt_motion and motion to avoid offsets motion = raw_motion.copy() - motion += np.median(gt_motion - motion) + motion.displacement += np.median(gt_motion.displacement - motion.displacement) self.result["gt_motion"] = gt_motion self.result["motion"] = motion _run_key_saved = [ - ("raw_motion", "npy"), - ("temporal_bins", "npy"), - ("spatial_bins", "npy"), + ("raw_motion", "Motion"), ("step_run_times", "pickle"), ] _result_key_saved = [ ( "gt_motion", - "npy", + "Motion", ), ( "motion", - "npy", + "Motion", ), ] @@ -189,20 +193,20 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_p # dirft ax = ax1 = fig.add_subplot(gs[2:7]) ax1.sharey(ax0) - temporal_bins = bench.result["temporal_bins"] - spatial_bins = bench.result["spatial_bins"] + # temporal_bins_s = bench.result["temporal_bins_s"] + # spatial_bins_um = bench.result["spatial_bins_um"] gt_motion = bench.result["gt_motion"] motion = bench.result["motion"] # for i in range(self.gt_unit_positions.shape[1]): - # ax.plot(temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") + # ax.plot(temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") for i in range(gt_motion.shape[1]): - depth = spatial_bins[i] + depth = motion.spatial_bins_um[i] if gt_drift: - ax.plot(temporal_bins, gt_motion[:, i] + depth, color="green", lw=4) + ax.plot(motion.temporal_bins_s[0], gt_motion.displacement[0][:, i] + depth, color="green", lw=4) if tested_drift: - ax.plot(temporal_bins, motion[:, i] + depth, color="cyan", lw=2) + ax.plot(motion.temporal_bins_s[0], motion.displacement[0][:, i] + depth, color="cyan", lw=2) ax.set_xlabel("time (s)") _simpleaxis(ax) @@ -241,14 +245,14 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): gt_motion = bench.result["gt_motion"] motion = bench.result["motion"] - temporal_bins = bench.result["temporal_bins"] - spatial_bins = bench.result["spatial_bins"] + # temporal_bins_s = bench.result["temporal_bins_s"] + # spatial_bins_um = bench.result["spatial_bins_um"] fig = plt.figure(figsize=figsize) gs = fig.add_gridspec(2, 2) - errors = gt_motion - motion + errors = gt_motion.displacement[0] - motion.displacement[0] channel_positions = bench.recording.get_channel_locations() probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() @@ -259,7 +263,7 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): aspect="auto", interpolation="nearest", origin="lower", - extent=(temporal_bins[0], temporal_bins[-1], spatial_bins[0], spatial_bins[-1]), + extent=(motion.temporal_bins_s[0], motion.temporal_bins_s[-1], motion.spatial_bins_um[0], motion.spatial_bins_um[-1]), ) plt.colorbar(im, ax=ax, label="error") ax.set_ylabel("depth (um)") @@ -270,7 +274,7 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): ax = fig.add_subplot(gs[1, 0]) mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) - ax.plot(temporal_bins, mean_error) + ax.plot(motion.temporal_bins_s, mean_error) ax.set_xlabel("time (s)") ax.set_ylabel("error") _simpleaxis(ax) @@ -279,7 +283,7 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): ax = fig.add_subplot(gs[1, 1]) depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) - ax.plot(spatial_bins, depth_error) + ax.plot(motion.spatial_bins_um, depth_error) ax.axvline(probe_y_min, color="k", ls="--", alpha=0.5) ax.axvline(probe_y_max, color="k", ls="--", alpha=0.5) ax.set_xlabel("depth (um)") @@ -305,17 +309,17 @@ def 
plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) gt_motion = bench.result["gt_motion"] motion = bench.result["motion"] - temporal_bins = bench.result["temporal_bins"] - spatial_bins = bench.result["spatial_bins"] + # temporal_bins_s = bench.result["temporal_bins_s"] + # spatial_bins_um = bench.result["spatial_bins_um"] # c = colors[count] if colors is not None else None c = colors[key] - errors = gt_motion - motion + errors = gt_motion.displacement[0] - motion.displacement[0] mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) - axes[0].plot(temporal_bins, mean_error, lw=1, label=label, color=c) + axes[0].plot(motion.temporal_bins_s, mean_error, lw=1, label=label, color=c) parts = axes[1].violinplot(mean_error, [count], showmeans=True) if c is not None: for pc in parts["bodies"]: @@ -325,7 +329,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) if k != "bodies": # for line in parts[k]: parts[k].set_color(c) - axes[2].plot(spatial_bins, depth_error, label=label, color=c) + axes[2].plot(motion.spatial_bins_um, depth_error, label=label, color=c) ax0 = ax = axes[0] ax.set_xlabel("Time [s]") @@ -362,8 +366,8 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # "peaks", # "selected_peaks", # "motion", -# "temporal_bins", -# "spatial_bins", +# "temporal_bins_s", +# "spatial_bins_um", # "peak_locations", # "gt_motion", # ) @@ -439,7 +443,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # self.recording, self.selected_peaks, **self.localize_kwargs, **self.job_kwargs # ) # t3 = time.perf_counter() -# self.motion, self.temporal_bins, self.spatial_bins = estimate_motion( +# self.motion, self.temporal_bins_s, self.spatial_bins_um = estimate_motion( # self.recording, self.selected_peaks, self.peak_locations, **self.estimate_motion_kwargs # ) @@ -464,7 +468,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # def run_estimate_motion(self): # # usefull to re run only the motion estimate with peak localization # t3 = time.perf_counter() -# self.motion, self.temporal_bins, self.spatial_bins = estimate_motion( +# self.motion, self.temporal_bins_s, self.spatial_bins_um = estimate_motion( # self.recording, self.selected_peaks, self.peak_locations, **self.estimate_motion_kwargs # ) # t4 = time.perf_counter() @@ -480,7 +484,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # self.save_to_folder() # def compute_gt_motion(self): -# self.gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filename, time_vector=self.temporal_bins) +# self.gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filename, time_vector=self.temporal_bins_s) # template_locations = np.array(mr.load_recordings(self.mearec_filename).template_locations) # assert len(template_locations.shape) == 3 @@ -490,18 +494,18 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # unit_motions = self.gt_unit_positions - unit_mid_positions # # unit_positions = np.mean(self.gt_unit_positions, axis=0) -# if self.spatial_bins is None: +# if self.spatial_bins_um is None: # self.gt_motion = np.mean(unit_motions, axis=1)[:, None] # channel_positions = self.recording.get_channel_locations() # probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() # center = (probe_y_min + probe_y_max) // 2 -# self.spatial_bins = 
np.array([center]) +# self.spatial_bins_um = np.array([center]) # else: # # time, units # self.gt_motion = np.zeros_like(self.motion) # for t in range(self.gt_unit_positions.shape[0]): # f = scipy.interpolate.interp1d(unit_mid_positions, unit_motions[t, :], fill_value="extrapolate") -# self.gt_motion[t, :] = f(self.spatial_bins) +# self.gt_motion[t, :] = f(self.spatial_bins_um) # def plot_true_drift(self, scaling_probe=1.5, figsize=(15, 10), axes=None): # if axes is None: @@ -535,11 +539,11 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax = axes[1] # for i in range(self.gt_unit_positions.shape[1]): -# ax.plot(self.temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") +# ax.plot(self.temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") # for i in range(self.gt_motion.shape[1]): -# depth = self.spatial_bins[i] -# ax.plot(self.temporal_bins, self.gt_motion[:, i] + depth, color="green", lw=4) +# depth = self.spatial_bins_um[i] +# ax.plot(self.temporal_bins_s, self.gt_motion[:, i] + depth, color="green", lw=4) # # ax.set_ylim(ymin, ymax) # ax.set_xlabel("time (s)") @@ -618,15 +622,15 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) # if show_drift: -# if self.spatial_bins is None: +# if self.spatial_bins_um is None: # center = (probe_y_min + probe_y_max) // 2 -# ax.plot(self.temporal_bins, self.gt_motion[:, 0] + center, color="green", lw=1.5) -# ax.plot(self.temporal_bins, self.motion[:, 0] + center, color="orange", lw=1.5) +# ax.plot(self.temporal_bins_s, self.gt_motion[:, 0] + center, color="green", lw=1.5) +# ax.plot(self.temporal_bins_s, self.motion[:, 0] + center, color="orange", lw=1.5) # else: # for i in range(self.gt_motion.shape[1]): -# depth = self.spatial_bins[i] -# ax.plot(self.temporal_bins, self.gt_motion[:, i] + depth, color="green", lw=1.5) -# ax.plot(self.temporal_bins, self.motion[:, i] + depth, color="orange", lw=1.5) +# depth = self.spatial_bins_um[i] +# ax.plot(self.temporal_bins_s, self.gt_motion[:, i] + depth, color="green", lw=1.5) +# ax.plot(self.temporal_bins_s, self.motion[:, i] + depth, color="orange", lw=1.5) # if show_histogram: # ax2 = fig.add_subplot(gs[3]) @@ -672,8 +676,8 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # self.peak_locations, # self.recording, # self.motion, -# self.temporal_bins, -# self.spatial_bins, +# self.temporal_bins_s, +# self.spatial_bins_um, # direction="y", # ) # if axes is None: @@ -735,18 +739,18 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # colors = plt.colormaps["jet"].resampled(n) # for i in range(0, n, step): # ax = axs[0] -# ax.plot(self.temporal_bins, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i)) +# ax.plot(self.temporal_bins_s, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i)) # ax.plot( -# self.temporal_bins, +# self.temporal_bins_s, # self.motion[:, i], # lw=1.5, # ls="-", # color=colors(i), -# label=f"{self.spatial_bins[i]:0.1f}", +# label=f"{self.spatial_bins_um[i]:0.1f}", # ) # ax = axs[1] -# ax.plot(self.temporal_bins, self.motion[:, i] - self.gt_motion[:, i], lw=1.5, ls="-", color=colors(i)) +# ax.plot(self.temporal_bins_s, self.motion[:, i] - self.gt_motion[:, i], lw=1.5, ls="-", color=colors(i)) # ax = axs[0] # ax.set_title(self.title) @@ -775,7 +779,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # aspect="auto", # 
interpolation="nearest", # origin="lower", -# extent=(self.temporal_bins[0], self.temporal_bins[-1], self.spatial_bins[0], self.spatial_bins[-1]), +# extent=(self.temporal_bins_s[0], self.temporal_bins_s[-1], self.spatial_bins_um[0], self.spatial_bins_um[-1]), # ) # plt.colorbar(im, ax=ax, label="error") # ax.set_ylabel("depth (um)") @@ -786,7 +790,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax = fig.add_subplot(gs[1, 0]) # mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) -# ax.plot(self.temporal_bins, mean_error) +# ax.plot(self.temporal_bins_s, mean_error) # ax.set_xlabel("time (s)") # ax.set_ylabel("error") # _simpleaxis(ax) @@ -795,7 +799,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax = fig.add_subplot(gs[1, 1]) # depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) -# ax.plot(self.spatial_bins, depth_error) +# ax.plot(self.spatial_bins_um, depth_error) # ax.axvline(probe_y_min, color="k", ls="--", alpha=0.5) # ax.axvline(probe_y_max, color="k", ls="--", alpha=0.5) # ax.set_xlabel("depth (um)") @@ -817,7 +821,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) # depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) -# axes[0].plot(benchmark.temporal_bins, mean_error, lw=1, label=benchmark.title, color=c) +# axes[0].plot(benchmark.temporal_bins_s, mean_error, lw=1, label=benchmark.title, color=c) # parts = axes[1].violinplot(mean_error, [count], showmeans=True) # if c is not None: # for pc in parts["bodies"]: @@ -827,7 +831,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # if k != "bodies": # # for line in parts[k]: # parts[k].set_color(c) -# axes[2].plot(benchmark.spatial_bins, depth_error, label=benchmark.title, color=c) +# axes[2].plot(benchmark.spatial_bins_um, depth_error, label=benchmark.title, color=c) # ax0 = ax = axes[0] # ax.set_xlabel("Time [s]") @@ -876,10 +880,10 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # interpolation="nearest", # origin="lower", # extent=( -# benchmark.temporal_bins[0], -# benchmark.temporal_bins[-1], -# benchmark.spatial_bins[0], -# benchmark.spatial_bins[-1], +# benchmark.temporal_bins_s[0], +# benchmark.temporal_bins_s[-1], +# benchmark.spatial_bins_um[0], +# benchmark.spatial_bins_um[-1], # ), # ) # fig.colorbar(im, ax=ax, label="error") @@ -897,11 +901,11 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # def plot_motions_several_benchmarks(benchmarks): # fig, ax = plt.subplots(figsize=(15, 5)) -# ax.plot(list(benchmarks)[0].temporal_bins, list(benchmarks)[0].gt_motion[:, 0], lw=2, c="k", label="real motion") +# ax.plot(list(benchmarks)[0].temporal_bins_s, list(benchmarks)[0].gt_motion[:, 0], lw=2, c="k", label="real motion") # for count, benchmark in enumerate(benchmarks): -# ax.plot(benchmark.temporal_bins, benchmark.motion.mean(1), lw=1, c=f"C{count}", label=benchmark.title) +# ax.plot(benchmark.temporal_bins_s, benchmark.motion.mean(1), lw=1, c=f"C{count}", label=benchmark.title) # ax.fill_between( -# benchmark.temporal_bins, +# benchmark.temporal_bins_s, # benchmark.motion.mean(1) - benchmark.motion.std(1), # benchmark.motion.mean(1) + benchmark.motion.std(1), # color=f"C{count}", diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index b2cf56eb9c..e9f128993d 
100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -406,6 +406,8 @@ def _save_keys(self, saved_keys, folder): pickle.dump(self.result[k], f) elif format == "sorting": self.result[k].save(folder=folder / k, format="numpy_folder", overwrite=True) + elif format == "Motion": + self.result[k].save(folder=folder / k) elif format == "zarr_templates": self.result[k].to_zarr(folder / k) elif format == "sorting_analyzer": @@ -440,6 +442,10 @@ def load_folder(cls, folder): from spikeinterface.core import load_extractor result[k] = load_extractor(folder / k) + elif format == "Motion": + from spikeinterface.sortingcomponents.motion_utils import Motion + + result[k] = Motion.load(folder / k) elif format == "zarr_templates": from spikeinterface.core.template import Templates diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index 1edf484aa4..93c4a0741f 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -236,3 +236,11 @@ def __eq__(self, other): return False return True + + def copy(self): + return Motion( + self.displacement.copy(), + self.temporal_bins_s.copy(), + self.spatial_bins_um.copy(), + interpolation_method=self.interpolation_method + ) From 4a1282f7e0dd090e8ef400414f58be9c0dbddbe2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 7 Jun 2024 16:12:00 +0200 Subject: [PATCH 086/136] wip --- src/spikeinterface/widgets/amplitudes.py | 6 +-- src/spikeinterface/widgets/unit_waveforms.py | 43 +++++++++++-------- .../widgets/utils_ipywidgets.py | 2 +- .../widgets/utils_matplotlib.py | 8 +--- src/spikeinterface/widgets/widget_list.py | 3 ++ 5 files changed, 34 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index efbf6f3f32..9d0222cdee 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -189,7 +189,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.next_data_plot = data_plot.copy() cm = 1 / 2.54 - we = data_plot["sorting_analyzer"] + analyzer = data_plot["sorting_analyzer"] width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -202,8 +202,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - self.unit_selector = UnitSelector(we.unit_ids) - self.unit_selector.value = list(we.unit_ids)[:1] + self.unit_selector = UnitSelector(analyzer.unit_ids) + self.unit_selector.value = list(analyzer.unit_ids)[:1] self.checkbox_histograms = W.Checkbox( value=data_plot["plot_histograms"], diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index add8c820b8..b046e55fbf 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -252,7 +252,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: if dp.same_axis: backend_kwargs["num_axes"] = 1 - backend_kwargs["ncols"] = None + backend_kwargs["ncols"] = 1 else: backend_kwargs["num_axes"] = len(dp.unit_ids) backend_kwargs["ncols"] = min(dp.ncols, len(dp.unit_ids)) @@ -487,11 +487,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_plot(None) - - self.unit_selector.observe(self._update_plot, 
names="value", type="change") - self.scaler.observe(self._update_plot, names="value", type="change") - self.widen_narrow.observe(self._update_plot, names="value", type="change") for w in ( + self.unit_selector, + self.scaler, + self.widen_narrow, self.same_axis_button, self.plot_templates_button, self.template_shading_button, @@ -592,30 +591,38 @@ def _update_plot(self, change): ax.axis("off") # update probe plot - self.ax_probe.plot( + self._plot_probe( + self.ax_probe, + channel_locations, + unit_ids, + ) + fig_probe = self.ax_probe.get_figure() + + self.fig_wf.canvas.draw() + self.fig_wf.canvas.flush_events() + fig_probe.canvas.draw() + fig_probe.canvas.flush_events() + + def _plot_probe(self, ax, channel_locations, unit_ids): + # update probe plot + ax.plot( channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 ) - self.ax_probe.axis("off") - self.ax_probe.axis("equal") + ax.axis("off") + ax.axis("equal") # TODO this could be done with probeinterface plotting plotting tools!! for unit in unit_ids: - channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] - self.ax_probe.plot( + channel_inds = self.data_plot["sparsity"].unit_id_to_channel_indices[unit] + ax.plot( channel_locations[channel_inds, 0], channel_locations[channel_inds, 1], ls="", marker="o", markersize=3, - color=self.next_data_plot["unit_colors"][unit], + color=self.data_plot["unit_colors"][unit], ) - self.ax_probe.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10) - fig_probe = self.ax_probe.get_figure() - - self.fig_wf.canvas.draw() - self.fig_wf.canvas.flush_events() - fig_probe.canvas.draw() - fig_probe.canvas.flush_events() + ax.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10) def get_waveforms_scales(templates, channel_locations, nbefore, x_offset_units=False, widen_narrow_scale=1.0): diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 12985d366f..75738209a1 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -401,7 +401,7 @@ def __init__(self, unit_ids, **kwargs): options=self.unit_ids, value=self.unit_ids, disabled=False, - layout=W.Layout(height="100%", width="80%", align="center"), + layout=W.Layout(height="100%", width="4cm", align="center"), ) super(W.VBox, self).__init__(children=[label, self.selector], **kwargs) diff --git a/src/spikeinterface/widgets/utils_matplotlib.py b/src/spikeinterface/widgets/utils_matplotlib.py index 825245750f..ceb7605d25 100644 --- a/src/spikeinterface/widgets/utils_matplotlib.py +++ b/src/spikeinterface/widgets/utils_matplotlib.py @@ -1,6 +1,5 @@ from __future__ import annotations -import matplotlib import matplotlib.pyplot as plt import numpy as np @@ -12,13 +11,10 @@ def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, if figure is not None: assert ax is None and axes is None, "figure/ax/axes : only one of then can be not None" if num_axes is None: - if "ipympl" not in matplotlib.get_backend(): - ax = figure.add_subplot(111) - else: - ax = figure.add_subplot(111) + ax = figure.add_subplot(111) axes = np.array([[ax]]) else: - assert ncols is not None + assert ncols is not None, "ncols must be provided when num_axes is provided" axes = [] nrows = int(np.ceil(num_axes / ncols)) axes = np.full((nrows, ncols), fill_value=None, dtype=object) diff --git a/src/spikeinterface/widgets/widget_list.py 
b/src/spikeinterface/widgets/widget_list.py index b3c1820276..b65fe97a3c 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -13,6 +13,7 @@ from .motion import MotionWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget from .peak_activity import PeakActivityMapWidget +from .potential_merges import PotentialMergesWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget from .rasters import RasterWidget @@ -48,6 +49,7 @@ MultiCompAgreementBySorterWidget, MultiCompGraphWidget, PeakActivityMapWidget, + PotentialMergesWidget, ProbeMapWidget, QualityMetricsWidget, RasterWidget, @@ -119,6 +121,7 @@ plot_multicomparison_agreement_by_sorter = MultiCompAgreementBySorterWidget plot_multicomparison_graph = MultiCompGraphWidget plot_peak_activity = PeakActivityMapWidget +plot_potential_merges = PotentialMergesWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget plot_rasters = RasterWidget From 968923c47919358f8090052268f8e87a01c06395 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 7 Jun 2024 16:56:59 +0200 Subject: [PATCH 087/136] wip2 --- .../widgets/potential_merges.py | 232 ++++++++++++++++++ 1 file changed, 232 insertions(+) create mode 100644 src/spikeinterface/widgets/potential_merges.py diff --git a/src/spikeinterface/widgets/potential_merges.py b/src/spikeinterface/widgets/potential_merges.py new file mode 100644 index 0000000000..da967d2f87 --- /dev/null +++ b/src/spikeinterface/widgets/potential_merges.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import numpy as np +from warnings import warn + +from .base import BaseWidget, default_backend_kwargs + +from .amplitudes import AmplitudesWidget +from .crosscorrelograms import CrossCorrelogramsWidget +from .unit_templates import UnitTemplatesWidget + +from .utils import get_some_colors + +from ..core.sortinganalyzer import SortingAnalyzer + + +class PotentialMergesWidget(BaseWidget): + """ + Plots potential merges + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The input waveform extractor + potential_merges : list of lists or tuples + List of potential merges (see `spikeinterface.curation.get_potential_auto_merges`) + segment_index : int + The segment index to display + max_spike_samples : int or None, default: None + The maximum number of spikes to display per unit + """ + + def __init__( + self, + sorting_analyzer: SortingAnalyzer, + potential_merges: list, + unit_colors: list = None, + segment_index: int = 0, + max_spikes_per_unit: int = 100, + backend=None, + **backend_kwargs, + ): + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + + self.check_extensions(sorting_analyzer, ["templates", "spike_amplitudes", "correlograms"]) + + unique_merge_units = np.unique([u for merge in potential_merges for u in merge]) + if unit_colors is None: + unit_colors = get_some_colors(sorting_analyzer.unit_ids) + + plot_data = dict( + sorting_analyzer=sorting_analyzer, + potential_merges=potential_merges, + unit_colors=unit_colors, + segment_index=segment_index, + max_spikes_per_unit=max_spikes_per_unit, + unique_merge_units=unique_merge_units, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + + # import ipywidgets.widgets as widgets + import ipywidgets.widgets as W + from IPython.display 
import display + from .utils_ipywidgets import check_ipywidget_backend, ScaleWidget, WidenNarrowWidget + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + analyzer = data_plot["sorting_analyzer"] + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] * 3 + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = W.Output() + with output: + self.figure = plt.figure( + figsize=((ratios[1] * width_cm) * cm, height_cm * cm), + constrained_layout=True, + ) + plt.show() + # find max number of merges: + self.gs = None + self.axes_amplitudes = None + self.ax_templates = None + self.ax_probe = None + self.axes_cc = None + + # Instantiate sub-widgets + self.w_amplitudes = AmplitudesWidget( + analyzer, + unit_colors=data_plot["unit_colors"], + unit_ids=data_plot["unique_merge_units"], + plot_histograms=True, + plot_legend=False, + immediate_plot=False, + ) + self.w_templates = UnitTemplatesWidget( + analyzer, + unit_ids=data_plot["unique_merge_units"], + unit_colors=data_plot["unit_colors"], + plot_legend=False, + immediate_plot=False, + ) + self.w_crosscorrelograms = CrossCorrelogramsWidget( + analyzer, + unit_ids=data_plot["unique_merge_units"], + min_similarity_for_correlograms=0, + unit_colors=data_plot["unit_colors"], + immediate_plot=False, + ) + + self.unit_selector = W.Dropdown( + options=data_plot["potential_merges"], value=data_plot["potential_merges"][0], layout=W.Layout(width="3cm") + ) + self.previous_num_merges = len(data_plot["potential_merges"][0]) + self.scaler = ScaleWidget(value=1.0) + self.widen_narrow = WidenNarrowWidget(value=1.0) + + left_sidebar = W.VBox([self.unit_selector, self.scaler, self.widen_narrow], layout=W.Layout(width="3cm")) + + self.widget = W.AppLayout( + center=self.figure.canvas, + left_sidebar=left_sidebar, + pane_widths=ratios + [0], + ) + + # a first update + self._update_plot(None) + + self.unit_selector.observe(self._update_plot, names="value", type="change") + self.scaler.observe(self._update_plot, names="value", type="change") + self.widen_narrow.observe(self._update_plot, names="value", type="change") + + if backend_kwargs["display"]: + display(self.widget) + + def _update_gs(self, merge_units, ncols, right_axes): + import matplotlib.gridspec as gridspec + + # we create a vertical grid with 1 row between the 3 first plots + n_units = len(merge_units) + unit_len_in_gs = ncols // n_units + nrows = ncols * 3 + 2 + print("Unit len in gs", unit_len_in_gs) + + if self.gs is not None and self.previous_num_merges == len(merge_units): + self.ax_templates.clear() + self.ax_probe.clear() + for ax in self.axes_amplitudes: + ax.clear() + for ax in self.axes_cc.flatten(): + ax.clear() + else: + self.figure.clear() + if self.axes_cc is not None: + for ax in self.axes_cc.flatten(): + ax.remove() + + self.gs = gridspec.GridSpec(nrows, ncols, figure=self.figure) + self.ax_templates = self.figure.add_subplot(self.gs[:ncols, :right_axes]) + self.ax_probe = self.figure.add_subplot(self.gs[:ncols, right_axes:]) + row_offset = ncols + 1 + ax_amplitudes_ts = self.figure.add_subplot(self.gs[row_offset : row_offset + ncols, :right_axes]) + ax_amplitudes_hist = self.figure.add_subplot(self.gs[row_offset : row_offset + ncols, right_axes:]) + self.axes_amplitudes = [ax_amplitudes_ts, ax_amplitudes_hist] + row_offset += ncols + 1 + self.axes_cc = [] + for i in range(0, n_units): + for j in range(0, n_units): + self.axes_cc.append( + self.figure.add_subplot( + self.gs[ + row_offset + (unit_len_in_gs) * i : 
row_offset + (unit_len_in_gs) * (i + 1), + j * unit_len_in_gs : (j + 1) * unit_len_in_gs, + ] + ) + ) + self.axes_cc = np.array(self.axes_cc).reshape((n_units, n_units)) + self.previous_num_merges = len(merge_units) + + def _update_plot(self, change=None): + from math import lcm + + merge_units = self.unit_selector.value + channel_locations = self.data_plot["sorting_analyzer"].get_channel_locations() + + if len(np.unique([len(m) for m in self.data_plot["potential_merges"]])) == 1: + ncols = 2 * len(merge_units) + else: + ncols = lcm(*[len(m) for m in self.data_plot["potential_merges"]]) + right_axes = int(ncols * 2 / 3) + print(ncols, right_axes) + self._update_gs(merge_units, ncols, right_axes) + + # unroll the merges + plot_unit_ids = [] + for m in merge_units: + plot_unit_ids.append(m) + + backend_kwargs_mpl = default_backend_kwargs["matplotlib"].copy() + backend_kwargs_mpl.pop("axes") + backend_kwargs_mpl.pop("ax") + + amplitude_data_plot = self.w_amplitudes.data_plot.copy() + amplitude_data_plot["unit_ids"] = plot_unit_ids + self.w_amplitudes.plot_matplotlib(amplitude_data_plot, ax=None, axes=self.axes_amplitudes, **backend_kwargs_mpl) + + unit_template_data_plot = self.w_templates.data_plot.copy() + unit_template_data_plot["unit_ids"] = plot_unit_ids + unit_template_data_plot["same_axis"] = True + unit_template_data_plot["set_title"] = False + unit_template_data_plot["scale"] = self.scaler.value + unit_template_data_plot["widen_narrow_scale"] = self.widen_narrow.value + self.w_templates.plot_matplotlib(unit_template_data_plot, ax=self.ax_templates, axes=None, **backend_kwargs_mpl) + self.ax_templates.axis("off") + self.w_templates._plot_probe(self.ax_probe, channel_locations, plot_unit_ids) + crosscorrelograms_data_plot = self.w_crosscorrelograms.data_plot.copy() + crosscorrelograms_data_plot["unit_ids"] = plot_unit_ids + self.w_crosscorrelograms.plot_matplotlib( + crosscorrelograms_data_plot, axes=self.axes_cc, ax=None, **backend_kwargs_mpl + ) + self.figure.canvas.draw() + self.figure.canvas.flush_events() From a00e37cb184551bc085e2e141116e1364623fbb0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 7 Jun 2024 19:53:11 +0200 Subject: [PATCH 088/136] First prototype --- .../widgets/potential_merges.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/widgets/potential_merges.py b/src/spikeinterface/widgets/potential_merges.py index da967d2f87..52035f919f 100644 --- a/src/spikeinterface/widgets/potential_merges.py +++ b/src/spikeinterface/widgets/potential_merges.py @@ -60,6 +60,7 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_ipywidgets(self, data_plot, **backend_kwargs): + from math import lcm import matplotlib.pyplot as plt # import ipywidgets.widgets as widgets @@ -118,9 +119,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): immediate_plot=False, ) - self.unit_selector = W.Dropdown( - options=data_plot["potential_merges"], value=data_plot["potential_merges"][0], layout=W.Layout(width="3cm") - ) + options = ["-".join([str(u) for u in m]) for m in data_plot["potential_merges"]] + value = options[0] + self.unit_selector = W.Dropdown(options=options, value=value, layout=W.Layout(width="3cm")) self.previous_num_merges = len(data_plot["potential_merges"][0]) self.scaler = ScaleWidget(value=1.0) self.widen_narrow = WidenNarrowWidget(value=1.0) @@ -133,6 +134,14 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): pane_widths=ratios + [0], ) + if 
len(np.unique([len(m) for m in self.data_plot["potential_merges"]])) == 1: + ncols = 2 * len(self.data_plot) + else: + ncols = lcm(*[len(m) for m in self.data_plot["potential_merges"]]) + right_axes = int(ncols * 2 / 3) + self.ncols = ncols + self.right_axes = right_axes + # a first update self._update_plot(None) @@ -143,14 +152,15 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if backend_kwargs["display"]: display(self.widget) - def _update_gs(self, merge_units, ncols, right_axes): + def _update_gs(self, merge_units): import matplotlib.gridspec as gridspec # we create a vertical grid with 1 row between the 3 first plots n_units = len(merge_units) - unit_len_in_gs = ncols // n_units + ncols = self.ncols + right_axes = self.right_axes + unit_len_in_gs = self.ncols // n_units nrows = ncols * 3 + 2 - print("Unit len in gs", unit_len_in_gs) if self.gs is not None and self.previous_num_merges == len(merge_units): self.ax_templates.clear() @@ -161,10 +171,6 @@ def _update_gs(self, merge_units, ncols, right_axes): ax.clear() else: self.figure.clear() - if self.axes_cc is not None: - for ax in self.axes_cc.flatten(): - ax.remove() - self.gs = gridspec.GridSpec(nrows, ncols, figure=self.figure) self.ax_templates = self.figure.add_subplot(self.gs[:ncols, :right_axes]) self.ax_probe = self.figure.add_subplot(self.gs[:ncols, right_axes:]) @@ -188,23 +194,17 @@ def _update_gs(self, merge_units, ncols, right_axes): self.previous_num_merges = len(merge_units) def _update_plot(self, change=None): - from math import lcm merge_units = self.unit_selector.value channel_locations = self.data_plot["sorting_analyzer"].get_channel_locations() - - if len(np.unique([len(m) for m in self.data_plot["potential_merges"]])) == 1: - ncols = 2 * len(merge_units) - else: - ncols = lcm(*[len(m) for m in self.data_plot["potential_merges"]]) - right_axes = int(ncols * 2 / 3) - print(ncols, right_axes) - self._update_gs(merge_units, ncols, right_axes) + unit_ids = self.data_plot["sorting_analyzer"].unit_ids # unroll the merges + unit_ids_str = [str(u) for u in unit_ids] plot_unit_ids = [] - for m in merge_units: - plot_unit_ids.append(m) + for m in merge_units.split("-"): + plot_unit_ids.append(unit_ids[unit_ids_str.index(m)]) + self._update_gs(plot_unit_ids) backend_kwargs_mpl = default_backend_kwargs["matplotlib"].copy() backend_kwargs_mpl.pop("axes") From f1a33d05cc5d74bb139374af45e7d1508868c040 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 7 Jun 2024 22:44:02 +0200 Subject: [PATCH 089/136] Update src/spikeinterface/widgets/potential_merges.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/widgets/potential_merges.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/potential_merges.py b/src/spikeinterface/widgets/potential_merges.py index 52035f919f..a9d2911526 100644 --- a/src/spikeinterface/widgets/potential_merges.py +++ b/src/spikeinterface/widgets/potential_merges.py @@ -21,7 +21,7 @@ class PotentialMergesWidget(BaseWidget): Parameters ---------- sorting_analyzer : SortingAnalyzer - The input waveform extractor + The input sorting analyzer potential_merges : list of lists or tuples List of potential merges (see `spikeinterface.curation.get_potential_auto_merges`) segment_index : int From 177ada59361e9fc407469f0bb43fc4dd4052f6cb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 10 Jun 2024 15:01:24 +0200 Subject: [PATCH 090/136] Final adjustments --- src/spikeinterface/widgets/amplitudes.py | 3 ++- 
src/spikeinterface/widgets/potential_merges.py | 17 +++++++++++------ src/spikeinterface/widgets/utils_ipywidgets.py | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 9d0222cdee..d8073b9806 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -1,5 +1,6 @@ from __future__ import annotations +from networkx import layout import numpy as np from warnings import warn @@ -215,7 +216,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector, self.checkbox_histograms, ], - layout=W.Layout(align_items="center", width="4cm", height="100%"), + layout=W.Layout(align_items="center", width="100%", height="100%"), ) self.widget = W.AppLayout( diff --git a/src/spikeinterface/widgets/potential_merges.py b/src/spikeinterface/widgets/potential_merges.py index a9d2911526..ab34f830c0 100644 --- a/src/spikeinterface/widgets/potential_merges.py +++ b/src/spikeinterface/widgets/potential_merges.py @@ -78,7 +78,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] * 3 - ratios = [0.15, 0.85] + ratios = [0.2, 0.8] with plt.ioff(): output = W.Output() @@ -121,12 +121,16 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): options = ["-".join([str(u) for u in m]) for m in data_plot["potential_merges"]] value = options[0] - self.unit_selector = W.Dropdown(options=options, value=value, layout=W.Layout(width="3cm")) + self.unit_selector_label = W.Label(value="Potential merges:") + self.unit_selector = W.Dropdown(options=options, value=value, layout=W.Layout(width="80%")) self.previous_num_merges = len(data_plot["potential_merges"][0]) - self.scaler = ScaleWidget(value=1.0) - self.widen_narrow = WidenNarrowWidget(value=1.0) + self.scaler = ScaleWidget(value=1.0, layout=W.Layout(width="80%")) + self.widen_narrow = WidenNarrowWidget(value=1.0, layout=W.Layout(width="80%")) - left_sidebar = W.VBox([self.unit_selector, self.scaler, self.widen_narrow], layout=W.Layout(width="3cm")) + left_sidebar = W.VBox( + [self.unit_selector_label, self.unit_selector, self.scaler, self.widen_narrow], + layout=W.Layout(width="100%"), + ) self.widget = W.AppLayout( center=self.figure.canvas, @@ -135,7 +139,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) if len(np.unique([len(m) for m in self.data_plot["potential_merges"]])) == 1: - ncols = 2 * len(self.data_plot) + # in this case we multiply the number of columns by 3 to have 2/3 of the space for the templates + ncols = 3 * len(self.data_plot["potential_merges"]) else: ncols = lcm(*[len(m) for m in self.data_plot["potential_merges"]]) right_axes = int(ncols * 2 / 3) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 75738209a1..e31f0e0444 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -401,7 +401,7 @@ def __init__(self, unit_ids, **kwargs): options=self.unit_ids, value=self.unit_ids, disabled=False, - layout=W.Layout(height="100%", width="4cm", align="center"), + layout=W.Layout(height="100%", width="3cm", align="center"), ) super(W.VBox, self).__init__(children=[label, self.selector], **kwargs) From 4fdb1fbb328b8843c62a30ce88d1750c77d14efd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 10 Jun 2024 15:25:23 +0200 Subject: [PATCH 091/136] Fix unwanted import --- 
src/spikeinterface/widgets/amplitudes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index d8073b9806..ac73c57249 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -1,6 +1,5 @@ from __future__ import annotations -from networkx import layout import numpy as np from warnings import warn From d5bf38619a2c1e72c97c1a6ed5e1f1a418a83b1b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 10 Jun 2024 16:44:04 +0200 Subject: [PATCH 092/136] Fix template and shading retrieval in plot_potential_merges --- src/spikeinterface/widgets/potential_merges.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/widgets/potential_merges.py b/src/spikeinterface/widgets/potential_merges.py index ab34f830c0..2cdb7b8682 100644 --- a/src/spikeinterface/widgets/potential_merges.py +++ b/src/spikeinterface/widgets/potential_merges.py @@ -140,7 +140,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if len(np.unique([len(m) for m in self.data_plot["potential_merges"]])) == 1: # in this case we multiply the number of columns by 3 to have 2/3 of the space for the templates - ncols = 3 * len(self.data_plot["potential_merges"]) + ncols = 3 * len(self.data_plot["potential_merges"][0]) else: ncols = lcm(*[len(m) for m in self.data_plot["potential_merges"]]) right_axes = int(ncols * 2 / 3) @@ -201,8 +201,9 @@ def _update_gs(self, merge_units): def _update_plot(self, change=None): merge_units = self.unit_selector.value - channel_locations = self.data_plot["sorting_analyzer"].get_channel_locations() - unit_ids = self.data_plot["sorting_analyzer"].unit_ids + sorting_analyzer = self.data_plot["sorting_analyzer"] + channel_locations = sorting_analyzer.get_channel_locations() + unit_ids = sorting_analyzer.unit_ids # unroll the merges unit_ids_str = [str(u) for u in unit_ids] @@ -225,6 +226,12 @@ def _update_plot(self, change=None): unit_template_data_plot["set_title"] = False unit_template_data_plot["scale"] = self.scaler.value unit_template_data_plot["widen_narrow_scale"] = self.widen_narrow.value + # update templates and shading + templates_ext = sorting_analyzer.get_extension("templates") + unit_template_data_plot["templates"] = templates_ext.get_templates(unit_ids=plot_unit_ids, operator="average") + unit_template_data_plot["templates_shading"] = self.w_templates._get_template_shadings( + plot_unit_ids, self.w_templates.data_plot["templates_percentile_shading"] + ) self.w_templates.plot_matplotlib(unit_template_data_plot, ax=self.ax_templates, axes=None, **backend_kwargs_mpl) self.ax_templates.axis("off") self.w_templates._plot_probe(self.ax_probe, channel_locations, plot_unit_ids) From 6f0eadcc2a68f8338498493dbcd905ba87e9d335 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 10 Jun 2024 17:32:00 +0200 Subject: [PATCH 093/136] Add `peak_to_peak` mode to SNR --- src/spikeinterface/qualitymetrics/misc_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index f1082386cc..1e7d4d0444 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -194,7 +194,7 @@ def compute_snrs( A SortingAnalyzer object. peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the template to compute best channels. 
- peak_mode : "extremum" | "at_index", default: "extremum" + peak_mode : "extremum" | "at_index", "peak_to_peak", default: "extremum" How to compute the amplitude. Extremum takes the maxima/minima At_index takes the value at t=sorting_analyzer.nbefore. @@ -210,7 +210,7 @@ def compute_snrs( noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() assert peak_sign in ("neg", "pos", "both") - assert peak_mode in ("extremum", "at_index") + assert peak_mode in ("extremum", "at_index", "peak_to_peak") if unit_ids is None: unit_ids = sorting_analyzer.unit_ids From 140b248110fb06ca7beaa2e357b032e465e2bcf3 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 11 Jun 2024 16:11:36 +0100 Subject: [PATCH 094/136] add AstypeRecording round default value --- src/spikeinterface/preprocessing/astype.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/astype.py b/src/spikeinterface/preprocessing/astype.py index 4b0d5f9e55..ce8dbc3ca7 100644 --- a/src/spikeinterface/preprocessing/astype.py +++ b/src/spikeinterface/preprocessing/astype.py @@ -20,7 +20,7 @@ class AstypeRecording(BasePreprocessor): dtype of the output recording. recording : Recording The recording extractor to be converted. - round : Bool + round : Bool | None, default: None If True, will round the values to the nearest integer. If None, will round in the case of float to integer conversion. From f89ea90cf40d2b0485ed138b66b9efe653eac399 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 11 Jun 2024 16:19:44 +0100 Subject: [PATCH 095/136] Update to fix PR02 rule --- src/spikeinterface/preprocessing/filter.py | 2 -- src/spikeinterface/preprocessing/resample.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 84ac542acc..ffad9a2029 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -222,7 +222,6 @@ class HighpassFilterRecording(FilterRecording): **filter_kwargs : dict Keyword arguments for `spikeinterface.preprocessing.FilterRecording` class. - {} Returns ------- filter_recording : HighpassFilterRecording @@ -255,7 +254,6 @@ class NotchFilterRecording(BasePreprocessor): margin_ms : float, default: 5.0 Margin in ms on border to avoid border effect - {} Returns ------- filter_recording : NotchFilterRecording diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index 54a602b7c0..ed77ec504d 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -28,7 +28,7 @@ class ResampleRecording(BasePreprocessor): The recording extractor to be re-referenced resample_rate : int The resampling frequency - margin : float, default: 100.0 + margin_ms : float, default: 100.0 Margin in ms for computations, will be used to decrease edge effects. dtype : dtype or None, default: None The dtype of the returned traces. If None, the dtype of the parent recording is used. 
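
[Editor's note] The `peak_to_peak` mode added a few patches above gives compute_snrs()
a third way to measure template amplitude, besides "extremum" and "at_index". A minimal,
hedged sketch of a call (assumes `recording` and `sorting` already exist; the extension
names are taken from the surrounding patches, this is not part of the patch series):

    from spikeinterface.core import create_sorting_analyzer
    from spikeinterface.qualitymetrics import compute_snrs

    analyzer = create_sorting_analyzer(sorting, recording)
    # "noise_levels" is needed because the SNR divides the template amplitude
    # by the per-channel noise level; "templates" needs "random_spikes" first
    analyzer.compute(["random_spikes", "templates", "noise_levels"])

    snrs = compute_snrs(analyzer, peak_sign="neg", peak_mode="peak_to_peak")
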
From fcd6f8e274eab4648f91772d36000025d6b1f22e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 12 Jun 2024 14:08:50 +0200 Subject: [PATCH 096/136] Update src/spikeinterface/qualitymetrics/misc_metrics.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/qualitymetrics/misc_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 1e7d4d0444..cbb55aeb8b 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -194,7 +194,7 @@ def compute_snrs( A SortingAnalyzer object. peak_sign : "neg" | "pos" | "both", default: "neg" The sign of the template to compute best channels. - peak_mode : "extremum" | "at_index", "peak_to_peak", default: "extremum" + peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" How to compute the amplitude. Extremum takes the maxima/minima At_index takes the value at t=sorting_analyzer.nbefore. From 311a4175b01c9a0eb2ab21804a0261607af49790 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 12 Jun 2024 14:21:29 +0200 Subject: [PATCH 097/136] Remove un-used argument --- src/spikeinterface/curation/auto_merge.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 818b6a72b0..3ce12809dd 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -251,7 +251,7 @@ def get_potential_auto_merge( def compute_correlogram_diff( - sorting, correlograms_smoothed, bins, win_sizes, adaptative_window_threshold=0.5, pair_mask=None + sorting, correlograms_smoothed, bins, win_sizes, pair_mask=None ): """ Original author: Aurelien Wyngaard (lussac) @@ -267,8 +267,6 @@ def compute_correlogram_diff( Bins of the correlograms win_sized: TODO - adaptative_window_threshold : float - TODO pair_mask : None or boolean array A bool matrix of size (num_units, num_units) to select which pair to compute. 
From 80c8847e17418bbfbe80920be90fde02a9c4a51c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jun 2024 12:22:39 +0000 Subject: [PATCH 098/136] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 3ce12809dd..6629e61bfc 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -250,9 +250,7 @@ def get_potential_auto_merge( return potential_merges -def compute_correlogram_diff( - sorting, correlograms_smoothed, bins, win_sizes, pair_mask=None -): +def compute_correlogram_diff(sorting, correlograms_smoothed, bins, win_sizes, pair_mask=None): """ Original author: Aurelien Wyngaard (lussac) From d5e3c8c1be5a3e7f20d1996e9deba3cc842f3a2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 12 Jun 2024 14:46:21 +0200 Subject: [PATCH 099/136] Oops --- src/spikeinterface/curation/auto_merge.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 6629e61bfc..c652089a39 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -193,7 +193,6 @@ def get_potential_auto_merge( correlograms_smoothed, bins, win_sizes, - adaptative_window_threshold=adaptative_window_threshold, pair_mask=pair_mask, ) # print(correlogram_diff) From fa363303145d136fac464918d0279404cabc9f82 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 12 Jun 2024 15:08:38 +0200 Subject: [PATCH 100/136] various fixes --- .../benchmark/benchmark_motion_estimation.py | 13 +++++++------ .../benchmark/benchmark_motion_interpolation.py | 2 +- .../tests/test_benchmark_motion_estimation.py | 15 ++++++++------- .../tests/test_benchmark_motion_interpolation.py | 6 +++++- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 96b277de6e..2278bfbd3e 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -39,7 +39,7 @@ def get_gt_motion_from_unit_displacement( unit_displacements = unit_displacements[:, :, direction_dim] times = np.arange(unit_displacements.shape[0]) / displacement_sampling_frequency f = scipy.interpolate.interp1d(times, unit_displacements, axis=0) - unit_displacements = f(temporal_bins_s) + unit_displacements = f(temporal_bins_s.clip(times[0], times[-1])) # spatial interpolataion of units discplacement if spatial_bins_um.shape[0] == 1: @@ -131,7 +131,7 @@ def compute_result(self, **result_params): # align globally gt_motion and motion to avoid offsets motion = raw_motion.copy() - motion.displacement += np.median(gt_motion.displacement - motion.displacement) + motion.displacement[0] += np.median(gt_motion.displacement[0] - motion.displacement[0]) self.result["gt_motion"] = gt_motion self.result["motion"] = motion @@ -201,7 +201,7 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_p # for i in range(self.gt_unit_positions.shape[1]): # ax.plot(temporal_bins_s, self.gt_unit_positions[:, 
i], alpha=0.5, ls="--", c="0.5") - for i in range(gt_motion.shape[1]): + for i in range(gt_motion.displacement[0].shape[1]): depth = motion.spatial_bins_um[i] if gt_drift: ax.plot(motion.temporal_bins_s[0], gt_motion.displacement[0][:, i] + depth, color="green", lw=4) @@ -263,7 +263,8 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): aspect="auto", interpolation="nearest", origin="lower", - extent=(motion.temporal_bins_s[0], motion.temporal_bins_s[-1], motion.spatial_bins_um[0], motion.spatial_bins_um[-1]), + extent=(motion.temporal_bins_s[0][0], motion.temporal_bins_s[0][-1], + motion.spatial_bins_um[0], motion.spatial_bins_um[-1]), ) plt.colorbar(im, ax=ax, label="error") ax.set_ylabel("depth (um)") @@ -274,7 +275,7 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): ax = fig.add_subplot(gs[1, 0]) mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) - ax.plot(motion.temporal_bins_s, mean_error) + ax.plot(motion.temporal_bins_s[0], mean_error) ax.set_xlabel("time (s)") ax.set_ylabel("error") _simpleaxis(ax) @@ -319,7 +320,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) - axes[0].plot(motion.temporal_bins_s, mean_error, lw=1, label=label, color=c) + axes[0].plot(motion.temporal_bins_s[0], mean_error, lw=1, label=label, color=c) parts = axes[1].violinplot(mean_error, [count], showmeans=True) if c is not None: for pc in parts["bodies"]: diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index a515424648..5688d2eaf3 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -45,7 +45,7 @@ def run(self, **job_kwargs): elif self.params["recording_source"] == "corrected": correct_motion_kwargs = self.params["correct_motion_kwargs"] recording = InterpolateMotionRecording( - self.drifting_recording, self.motion, self.temporal_bins, self.spatial_bins, **correct_motion_kwargs + self.drifting_recording, self.motion, **correct_motion_kwargs ) else: raise ValueError("recording_source") diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py index dec0e612f8..f1aeeb54fb 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py @@ -52,13 +52,13 @@ def test_benchmark_motion_estimaton(): ) study_folder = cache_folder / "study_motion_estimation" - if study_folder.exists(): - shutil.rmtree(study_folder) - study = MotionEstimationStudy.create(study_folder, datasets, cases) + # if study_folder.exists(): + # shutil.rmtree(study_folder) + # study = MotionEstimationStudy.create(study_folder, datasets, cases) - # run and result - study.run(**job_kwargs) - study.compute_results() + # # run and result + # study.run(**job_kwargs) + # study.compute_results() # load study to check persistency study = MotionEstimationStudy(study_folder) @@ -66,10 +66,11 @@ def test_benchmark_motion_estimaton(): # plots study.plot_true_drift() + study.plot_drift() study.plot_errors() study.plot_summary_errors() - import matplotlib.pyplot as plt + 
import matplotlib.pyplot as plt plt.show() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py index bf4522df94..f1b05dbb6d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -49,8 +49,11 @@ def test_benchmark_motion_interpolation(): spatial_bins, direction_dim=1, ) + # print(gt_motion) + + # import matplotlib.pyplot as plt # fig, ax = plt.subplots() - # ax.imshow(gt_motion.T) + # ax.imshow(gt_motion.displacement[0].T) # plt.show() cases = {} @@ -131,6 +134,7 @@ def test_benchmark_motion_interpolation(): study.plot_sorting_accuracy(mode="depth", mode_best_merge=False) study.plot_sorting_accuracy(mode="depth", mode_best_merge=True) + import matplotlib.pyplot as plt plt.show() From 919e16ad826a5df7e74d4efd3aaff1ac89dfa9ba Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 12 Jun 2024 18:40:05 +0200 Subject: [PATCH 101/136] wip Motion propagation in widgets --- .../sorters/internal/spyking_circus2.py | 4 +- .../tests/test_benchmark_motion_estimation.py | 12 +- .../tests/test_motion_utils.py | 25 +++ src/spikeinterface/widgets/motion.py | 159 +++++++++++++++--- .../widgets/tests/test_widgets.py | 79 +++++++-- src/spikeinterface/widgets/widget_list.py | 4 +- 6 files changed, 233 insertions(+), 50 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 05853b4c39..5d04495c7e 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -317,7 +317,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.preprocessing.motion import load_motion_info motion_info = load_motion_info(motion_folder) - merging_params["maximum_distance_um"] = max(50, 2 * np.abs(motion_info["motion"]).max()) + motion = motion_info["motion"] + max_motion = max(np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement))) + merging_params["maximum_distance_um"] = max(50, 2 * max_motion) # peak_sign = params['detection'].get('peak_sign', 'neg') # best_amplitudes = get_template_extremum_amplitude(templates, peak_sign=peak_sign) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py index f1aeeb54fb..e0f151eafe 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py @@ -52,13 +52,13 @@ def test_benchmark_motion_estimaton(): ) study_folder = cache_folder / "study_motion_estimation" - # if study_folder.exists(): - # shutil.rmtree(study_folder) - # study = MotionEstimationStudy.create(study_folder, datasets, cases) + if study_folder.exists(): + shutil.rmtree(study_folder) + study = MotionEstimationStudy.create(study_folder, datasets, cases) - # # run and result - # study.run(**job_kwargs) - # study.compute_results() + # run and result + study.run(**job_kwargs) + study.compute_results() # load study to check persistency study = MotionEstimationStudy(study_folder) diff --git 
a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py
index 84dda89d0d..1542c8531a 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py
@@ -5,6 +5,7 @@
 import numpy as np
 import pytest
 from spikeinterface.sortingcomponents.motion_utils import Motion
+from spikeinterface.generation import make_one_displacement_vector
 
 if hasattr(pytest, "global_test_folder"):
     cache_folder = pytest.global_test_folder / "sortingcomponents"
@@ -12,6 +13,30 @@
     cache_folder = Path("cache_folder") / "sortingcomponents"
 
 
+def make_fake_motion():
+    displacement_sampling_frequency = 5.
+    spatial_bins_um = np.array([100.0, 200.0, 300., 400.])
+
+    displacement_vector = make_one_displacement_vector(
+        drift_mode="zigzag",
+        duration=50.0,
+        amplitude_factor=1.0,
+        displacement_sampling_frequency=displacement_sampling_frequency,
+        period_s=25.,
+    )
+    temporal_bins_s = np.arange(displacement_vector.size) / displacement_sampling_frequency
+    displacement = np.zeros((temporal_bins_s.size, spatial_bins_um.size))
+
+    n = spatial_bins_um.size
+    for i in range(n):
+        displacement[:, i] = displacement_vector * ((i + 1) / n)
+
+    motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y")
+
+    return motion
+
+
+
 def test_Motion():
 
     temporal_bins_s = np.arange(0.0, 10.0, 1.0)
diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py
index 9d64c89e46..d83be77eb3 100644
--- a/src/spikeinterface/widgets/motion.py
+++ b/src/spikeinterface/widgets/motion.py
@@ -4,15 +4,110 @@
 
 from .base import BaseWidget, to_attr
 
-
 class MotionWidget(BaseWidget):
     """
-    Plot unit depths
+    Plot the Motion object
+
+    Parameters
+    ----------
+    motion: Motion
+        The motion object
+    segment_index: None | int
+        If the Motion object has multiple segments, this must not be None
+    mode: "auto" | "line" | "map"
+        How to plot the motion: as lines ("line") or as a map ("map").
+        "auto" chooses the mode automatically, based on the number of depth bins.
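+
+    Examples
+    --------
+    A minimal sketch (hedged: ``motion`` is assumed to be a Motion object, e.g.
+    the one built by make_fake_motion() in the tests above):
+
+    >>> from spikeinterface.widgets import plot_motion
+    >>> plot_motion(motion, mode="line")  # one displacement trace per depth bin
+    >>> plot_motion(motion, mode="map")   # displacement as a time x depth image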
+    """
+    def __init__(
+        self,
+        motion,
+        segment_index=None,
+        mode="line",
+        motion_lim=None,
+        backend=None,
+        **backend_kwargs,
+    ):
+        if isinstance(motion, dict):
+            raise ValueError("The API has changed: plot_motion() now takes a Motion object, maybe you want plot_motion_info(motion_info)")
+
+        if segment_index is None:
+            if len(motion.displacement) == 1:
+                segment_index = 0
+            else:
+                raise ValueError("plot_motion(): the Motion object is multi-segment, you must provide segment_index")
+
+        plot_data = dict(
+            motion=motion,
+            segment_index=segment_index,
+            mode=mode,
+            motion_lim=motion_lim,
+        )
+
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .utils_matplotlib import make_mpl_figure
+        from matplotlib.colors import Normalize
+
+        dp = to_attr(data_plot)
+
+        motion = data_plot["motion"]
+        segment_index = data_plot["segment_index"]
+
+        assert backend_kwargs["axes"] is None
+
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+
+        displacement = motion.displacement[dp.segment_index]
+        temporal_bins_s = motion.temporal_bins_s[dp.segment_index]
+        depth = motion.spatial_bins_um
+
+        if dp.motion_lim is None:
+            motion_lim = np.max(np.abs(displacement)) * 1.05
+        else:
+            motion_lim = dp.motion_lim
+
+
+        ax = self.ax
+        fig = self.figure
+        if dp.mode == "line":
+            ax.plot(temporal_bins_s, displacement, alpha=0.2, color="black")
+            ax.plot(temporal_bins_s, np.mean(displacement, axis=1), color="C0")
+            ax.set_xlabel("Time [s]")
+            ax.set_ylabel("motion [um]")
+        elif dp.mode == "map":
+            im = ax.imshow(
+                displacement.T,
+                interpolation="nearest",
+                aspect="auto",
+                origin="lower",
+                extent=(temporal_bins_s[0], temporal_bins_s[-1], depth[0], depth[-1]),
+                cmap="PiYG"
+            )
+            im.set_clim(-motion_lim, motion_lim)
+
+            cbar = fig.colorbar(im)
+            cbar.ax.set_ylabel("motion [um]")
+            ax.set_xlabel("Time [s]")
+            ax.set_ylabel("Depth [um]")
+
+
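+# Editor's note (added for this write-up, not part of the original patch): a hedged
+# sketch tying MotionWidget above to MotionInfoWidget below. `motion_folder` and
+# `recording` are illustrative names:
+#
+#     from spikeinterface.preprocessing.motion import load_motion_info
+#     from spikeinterface.widgets import plot_motion, plot_motion_info
+#
+#     motion_info = load_motion_info(motion_folder)   # dict saved by correct_motion()
+#     plot_motion(motion_info["motion"], mode="map")  # the displacement alone
+#     plot_motion_info(motion_info, recording=recording)  # + peak depths before/after
+
+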
+class MotionInfoWidget(BaseWidget):
+    """
+    Plot motion information from the motion_info dict returned by correct_motion().
+    This plot:
+    * the motion itself
+    * the peak depth vs time before correction
+    * the peak depth vs time after correction
 
     Parameters
     ----------
     motion_info: dict
         The motion info returned by correct_motion(), or loaded back with load_motion_info()
+    segment_index: int | None, default: None
+        The segment index to plot (required when the Motion object has several segments)
     recording : RecordingExtractor, default: None
         The recording extractor object (only used to get "real" times)
     sampling_frequency : float, default: None
@@ -36,6 +131,7 @@ class MotionWidget(BaseWidget):
     def __init__(
         self,
         motion_info,
+        segment_index=None,
         recording=None,
         depth_lim=None,
         motion_lim=None,
@@ -47,6 +143,14 @@ def __init__(
         backend=None,
         **backend_kwargs,
     ):
+
+        motion = motion_info["motion"]
+        if segment_index is None:
+            if len(motion.displacement) == 1:
+                segment_index = 0
+            else:
+                raise ValueError("plot_motion_info(): the Motion object is multi-segment, you must provide segment_index")
+
         times = recording.get_times() if recording is not None else None
 
         plot_data = dict(
@@ -59,6 +163,8 @@ def __init__(
             amplitude_cmap=amplitude_cmap,
             amplitude_clim=amplitude_clim,
             amplitude_alpha=amplitude_alpha,
+            segment_index=segment_index,
+            recording=recording,
             **motion_info,
         )
@@ -80,7 +186,20 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         fig = self.figure
         fig.clear()
 
-        is_rigid = dp.motion.shape[1] == 1
+        is_rigid = dp.motion.spatial_bins_um.shape[0] == 1
+
+        motion = dp.motion
+
+
+        displacement = motion.displacement[dp.segment_index]
+        temporal_bins_s = motion.temporal_bins_s[dp.segment_index]
+        spatial_bins_um = motion.spatial_bins_um
+
+        if dp.motion_lim is None:
+            motion_lim = np.max(np.abs(displacement)) * 1.05
+        else:
+            motion_lim = dp.motion_lim
+
 
         gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.3)
         ax0 = fig.add_subplot(gs[0, 0])
@@ -91,31 +210,23 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         ax1.sharex(ax0)
         ax1.sharey(ax0)
 
-        if dp.motion_lim is None:
-            motion_lim = np.max(np.abs(dp.motion)) * 1.05
-        else:
-            motion_lim = dp.motion_lim
-
         if dp.times is None:
-            temporal_bins_plot = dp.temporal_bins
+            # temporal_bins_plot = dp.temporal_bins
            x = dp.peaks["sample_index"] / dp.sampling_frequency
         else:
             # use real times and adjust temporal bins with t_start
-            temporal_bins_plot = dp.temporal_bins + dp.times[0]
+            # temporal_bins_plot = dp.temporal_bins + dp.times[0]
            x = dp.times[dp.peaks["sample_index"]]
 
         corrected_location = correct_motion_on_peaks(
             dp.peaks,
             dp.peak_locations,
-            dp.sampling_frequency,
-            dp.motion,
-            dp.temporal_bins,
-            dp.spatial_bins,
-            direction="y",
+            dp.recording,
+            dp.motion,
         )
 
-        y = dp.peak_locations["y"]
-        y2 = corrected_location["y"]
+        y = dp.peak_locations[motion.direction]
+        y2 = corrected_location[motion.direction]
         if dp.scatter_decimate is not None:
             x = x[:: dp.scatter_decimate]
             y = y[:: dp.scatter_decimate]
@@ -156,8 +267,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
             ax1.set_ylabel("Depth [um]")
             ax1.set_title("Corrected peak depth")
 
-        ax2.plot(temporal_bins_plot, dp.motion, alpha=0.2, color="black")
-        ax2.plot(temporal_bins_plot, np.mean(dp.motion, axis=1), color="C0")
+        ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black")
+        ax2.plot(temporal_bins_s, np.mean(displacement, axis=1), color="C0")
         ax2.set_ylim(-motion_lim, motion_lim)
         ax2.set_ylabel("Motion [um]")
         ax2.set_title("Motion vectors")
@@ -165,14 +276,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
 
         if not is_rigid:
             im = ax3.imshow(
-                dp.motion.T,
+                displacement.T,
                 aspect="auto",
                 origin="lower",
                 extent=(
-                    temporal_bins_plot[0],
-                    temporal_bins_plot[-1],
-                    dp.spatial_bins[0],
-
dp.spatial_bins[-1], + temporal_bins_s[0], + temporal_bins_s[-1], + spatial_bins_um[0], + spatial_bins_um[-1], ), ) im.set_clim(-motion_lim, motion_lim) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 156d1d92e2..8360842572 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -3,6 +3,8 @@ import os from pathlib import Path +import numpy as np + if __name__ != "__main__": try: import matplotlib @@ -76,25 +78,25 @@ def setUpClass(cls): ) job_kwargs = dict(n_jobs=-1) - # create dense - cls.sorting_analyzer_dense = create_sorting_analyzer(cls.sorting, cls.recording, format="memory", sparse=False) - cls.sorting_analyzer_dense.compute("random_spikes") - cls.sorting_analyzer_dense.compute(extensions_to_compute, **job_kwargs) + # # create dense + # cls.sorting_analyzer_dense = create_sorting_analyzer(cls.sorting, cls.recording, format="memory", sparse=False) + # cls.sorting_analyzer_dense.compute("random_spikes") + # cls.sorting_analyzer_dense.compute(extensions_to_compute, **job_kwargs) - sw.set_default_plotter_backend("matplotlib") + # sw.set_default_plotter_backend("matplotlib") - # make sparse waveforms - cls.sparsity_radius = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=50) - cls.sparsity_strict = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=20) - cls.sparsity_large = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=80) - cls.sparsity_best = compute_sparsity(cls.sorting_analyzer_dense, method="best_channels", num_channels=5) + # # make sparse waveforms + # cls.sparsity_radius = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=50) + # cls.sparsity_strict = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=20) + # cls.sparsity_large = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=80) + # cls.sparsity_best = compute_sparsity(cls.sorting_analyzer_dense, method="best_channels", num_channels=5) - # create sparse - cls.sorting_analyzer_sparse = create_sorting_analyzer( - cls.sorting, cls.recording, format="memory", sparsity=cls.sparsity_radius - ) - cls.sorting_analyzer_sparse.compute("random_spikes") - cls.sorting_analyzer_sparse.compute(extensions_to_compute, **job_kwargs) + # # create sparse + # cls.sorting_analyzer_sparse = create_sorting_analyzer( + # cls.sorting, cls.recording, format="memory", sparsity=cls.sparsity_radius + # ) + # cls.sorting_analyzer_sparse.compute("random_spikes") + # cls.sorting_analyzer_sparse.compute(extensions_to_compute, **job_kwargs) cls.skip_backends = ["ipywidgets", "ephyviewer", "spikeinterface_gui"] # cls.skip_backends = ["ipywidgets", "ephyviewer", "sortingview"] @@ -111,7 +113,7 @@ def setUpClass(cls): "spikeinterface_gui": {}, } - cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) + # cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) from spikeinterface.sortingcomponents.peak_detection import detect_peaks @@ -583,6 +585,45 @@ def test_plot_multicomparison(self): _, axes = plt.subplots(len(mcmp.object_list), 1) sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) + + def test_plot_motion(self): + from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion + motion = make_fake_motion() + + possible_backends = list(sw.MotionWidget.get_possible_backends()) + for backend in possible_backends: + if backend 
not in self.skip_backends: + sw.plot_motion(motion, backend=backend, mode='line') + sw.plot_motion(motion, backend=backend, mode='map') + + def test_plot_motion_info(self): + from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion + + + motion = make_fake_motion() + rng = np.random.default_rng(seed=2205) + peak_locations = np.zeros(self.peaks.size, dtype=[("x", "float64"), ("y", "float64")]) + peak_locations['y'] = rng.uniform(motion.spatial_bins_um[0], + motion.spatial_bins_um[-1], + size=self.peaks.size) + + motion_info = dict( + motion=motion, + parameters=dict(sampling_frequency=30000.), + run_times=dict(), + peaks=self.peaks, + peak_locations=peak_locations, + ) + + + possible_backends = list(sw.MotionWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_motion_info(motion_info, recording=self.recording, backend=backend) + + + + if __name__ == "__main__": @@ -598,7 +639,7 @@ def test_plot_multicomparison(self): # mytest.test_plot_traces() # mytest.test_plot_spikes_on_traces() # mytest.test_plot_unit_waveforms() - mytest.test_plot_spikes_on_traces() + # mytest.test_plot_spikes_on_traces() # mytest.test_plot_unit_depths() # mytest.test_plot_autocorrelograms() # mytest.test_plot_crosscorrelograms() @@ -618,6 +659,8 @@ def test_plot_multicomparison(self): # mytest.test_plot_peak_activity() # mytest.test_plot_multicomparison() # mytest.test_plot_sorting_summary() + # mytest.test_plot_motion() + mytest.test_plot_motion_info() plt.show() # TestWidgets.tearDownClass() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index b3c1820276..19ce40ca2b 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -10,7 +10,7 @@ from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget from .isi_distribution import ISIDistributionWidget -from .motion import MotionWidget +from .motion import MotionWidget, MotionInfoWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget from .peak_activity import PeakActivityMapWidget from .probe_map import ProbeMapWidget @@ -44,6 +44,7 @@ CrossCorrelogramsWidget, ISIDistributionWidget, MotionWidget, + MotionInfoWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget, MultiCompGraphWidget, @@ -115,6 +116,7 @@ plot_crosscorrelograms = CrossCorrelogramsWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget +plot_motion_info = MotionInfoWidget plot_multicomparison_agreement = MultiCompGlobalAgreementWidget plot_multicomparison_agreement_by_sorter = MultiCompAgreementBySorterWidget plot_multicomparison_graph = MultiCompGraphWidget From a891045b24b0d55df2f70a8598554533669bcaeb Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 12 Jun 2024 11:20:22 -0600 Subject: [PATCH 102/136] remove upper bound in scipy --- pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dadb677056..ea798c31ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ extractors = [ "pyedflib>=0.1.30", "sonpy;python_version<'3.10'", "lxml", # lxml for neuroscope - "scipy<1.13", + "scipy", "ONE-api>=2.7.0", # alf sorter and streaming IBL "ibllib>=2.32.5", # streaming IBL "pymatreader>=0.0.32", # For cell explorer matlab files @@ -75,8 +75,6 @@ extractors = [ 
streaming_extractors = [ "ONE-api>=2.7.0", # alf sorter and streaming IBL "ibllib>=2.32.5", # streaming IBL - "scipy<1.13", # ibl has a dependency on scipy but it does not have an upper bound - # Remove this once https://github.com/int-brain-lab/ibllib/issues/753 # Following dependencies are for streaming with nwb files "pynwb>=2.6.0", "fsspec", From 58cfcc481141bede3c97676f916d2e3dd4b06389 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 12 Jun 2024 11:23:06 -0600 Subject: [PATCH 103/136] update ibllib --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ea798c31ad..ef7f4bebf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ extractors = [ streaming_extractors = [ "ONE-api>=2.7.0", # alf sorter and streaming IBL - "ibllib>=2.32.5", # streaming IBL + "ibllib>=2.36.0", # streaming IBL # Following dependencies are for streaming with nwb files "pynwb>=2.6.0", "fsspec", From 8111789f0848f4b79ab52a8b36d06231b2f62286 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Proville?= Date: Thu, 13 Jun 2024 16:15:33 +0200 Subject: [PATCH 104/136] Apply suggestions from code review: typography et al Co-authored-by: Alessio Buccino --- src/spikeinterface/curation/curation_format.py | 6 ------ src/spikeinterface/curation/tests/test_curation_format.py | 4 ++-- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 82921a56b5..d6eded4345 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -14,12 +14,6 @@ def validate_curation_dict(curation_dict): ---------- curation_dict : dict - - Returns - ------- - Nothing. 
- - """ # format diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 1bee5b524c..6d132fbe97 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -106,13 +106,13 @@ def test_curation_format_validation(): # Raised because duplicated merged units validate_curation_dict(duplicate_merge) with pytest.raises(ValueError): - # Raised because Some units belong to multiple merge groups" + # Raised because some units belong to merged and removed unit groups validate_curation_dict(merged_and_removed) with pytest.raises(ValueError): # Some merged units are not in the unit list validate_curation_dict(unknown_merged_unit) with pytest.raises(ValueError): - # Raise beecause Some removed units are not in the unit list + # Raise because some removed units are not in the unit list validate_curation_dict(unknown_removed_unit) From 609cd34732362f4de92672c0574385d6626bdd31 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 13 Jun 2024 09:46:14 -0600 Subject: [PATCH 105/136] update the other ibl --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ef7f4bebf0..51528fcc8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ extractors = [ "lxml", # lxml for neuroscope "scipy", "ONE-api>=2.7.0", # alf sorter and streaming IBL - "ibllib>=2.32.5", # streaming IBL + "ibllib>=2.36.0", # streaming IBL "pymatreader>=0.0.32", # For cell explorer matlab files "zugbruecke>=0.2; sys_platform!='win32'", # For plexon2 ] From bdc7fe45c3937e4f5fcc9f6f015100040c85a595 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 14 Jun 2024 10:49:50 +0200 Subject: [PATCH 106/136] Merci Pierre --- src/spikeinterface/widgets/potential_merges.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/spikeinterface/widgets/potential_merges.py b/src/spikeinterface/widgets/potential_merges.py index 2cdb7b8682..c1c7d86522 100644 --- a/src/spikeinterface/widgets/potential_merges.py +++ b/src/spikeinterface/widgets/potential_merges.py @@ -237,6 +237,10 @@ def _update_plot(self, change=None): self.w_templates._plot_probe(self.ax_probe, channel_locations, plot_unit_ids) crosscorrelograms_data_plot = self.w_crosscorrelograms.data_plot.copy() crosscorrelograms_data_plot["unit_ids"] = plot_unit_ids + merge_unit_indices = np.flatnonzero(np.isin(self.unique_merge_units, plot_unit_ids)) + updated_correlograms = crosscorrelograms_data_plot["correlograms"] + updated_correlograms = updated_correlograms[merge_unit_indices][:, merge_unit_indices] + crosscorrelograms_data_plot["correlograms"] = updated_correlograms self.w_crosscorrelograms.plot_matplotlib( crosscorrelograms_data_plot, axes=self.axes_cc, ax=None, **backend_kwargs_mpl ) From d8a3826fe3c5f74f7db335be01966ed0f55cb2a7 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 17 Jun 2024 17:24:45 +0100 Subject: [PATCH 107/136] Add Zach improvements --- src/spikeinterface/preprocessing/astype.py | 6 +++--- .../preprocessing/detect_bad_channels.py | 14 +++++++------- src/spikeinterface/preprocessing/resample.py | 3 --- .../preprocessing/silence_periods.py | 3 ++- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/preprocessing/astype.py b/src/spikeinterface/preprocessing/astype.py index ce8dbc3ca7..a05610ea2e 100644 --- 
a/src/spikeinterface/preprocessing/astype.py
+++ b/src/spikeinterface/preprocessing/astype.py
@@ -17,12 +17,12 @@ class AstypeRecording(BasePreprocessor):
     Parameters
     ----------
     dtype : None | str | dtype, default: None
-        dtype of the output recording.
+        dtype of the output recording. If None, takes dtype from input `recording`.
     recording : Recording
         The recording extractor to be converted.
     round : Bool | None, default: None
-        If True, will round the values to the nearest integer.
-        If None, will round in the case of float to integer conversion.
+        If True, will round the values to the nearest integer using `numpy.round`.
+        If None and dtype is an integer, will round floats to nearest integer.

     Returns
     -------
diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py
index 218c9cb822..5d8f7107c7 100644
--- a/src/spikeinterface/preprocessing/detect_bad_channels.py
+++ b/src/spikeinterface/preprocessing/detect_bad_channels.py
@@ -58,28 +58,28 @@ def detect_bad_channels(
     std_mad_threshold : float, default: 5
         The standard deviation/mad multiplier threshold
     psd_hf_threshold : float, default: 0.02
-        Coeherence+psd. An absolute threshold (uV^2/Hz) used as a cutoff for noise channels.
+        For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels.
         Channels with average power at >80% Nyquist larger than this threshold
         will be labeled as noise
     dead_channel_threshold : float, default: -0.5
-        Coeherence+psd. Threshold for channel coherence below which channels are labeled as dead
+        For coherence+psd - threshold for channel coherence below which channels are labeled as dead
     noisy_channel_threshold : float, default: 1
         Threshold for channel coherence above which channels are labeled as noisy (together with psd condition)
     outside_channel_threshold : float, default: -0.75
-        Coeherence+psd. Threshold for channel coherence above which channels at the edge of the recording are marked as outside
+        For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside
         of the brain
     outside_channels_location : "top" | "bottom" | "both", default: "top"
-        Coeherence+psd. Location of the outside channels. If "top", only the channels at the top of the probe can be
+        For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be
         marked as outside channels. If "bottom", only the channels at the bottom of the probe can be marked as outside
         channels. If "both", both the channels at the top and bottom of the probe can be marked as outside channels
     n_neighbors : int, default: 11
-        Coeherence+psd. Number of channel neighbors to compute median filter (needs to be odd)
+        For coherence+psd - number of channel neighbors to compute median filter (needs to be odd)
     nyquist_threshold : float, default: 0.8
-        Coeherence+psd. Frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared
+        For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared
         with psd_hf_threshold
     direction : "x" | "y" | "z", default: "y"
-        Coeherence+psd.
The depth dimension
+        For coherence+psd - the depth dimension
     highpass_filter_cutoff : float, default: 300
         If the recording is not filtered, the cutoff frequency of the highpass filter
     chunk_duration_s : float, default: 0.5
diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py
index ed77ec504d..4843df5444 100644
--- a/src/spikeinterface/preprocessing/resample.py
+++ b/src/spikeinterface/preprocessing/resample.py
@@ -34,9 +34,6 @@ class ResampleRecording(BasePreprocessor):
         The dtype of the returned traces. If None, the dtype of the parent recording is used
     skip_checks : bool, default: False
         If True, checks on sampling frequencies and cutoff filter frequencies are skipped
-    margin_ms : float, default: 100.0
-        Margin in ms on border to avoid border effect
-

     Returns
     -------
diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index 88c7e2109c..74d370b3a9 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -26,7 +26,8 @@ class SilencedPeriodsRecording(BasePreprocessor):
    noise_levels : array
        Noise levels if already computed
    seed : int | None, default: None
-        Random seed for `get_noise_levels`
+        Random seed for `get_noise_levels` and `NoiseGeneratorRecording`.
+        If None, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`.
    mode : "zeros" | "noise", default: "zeros"
        Determines what periods are replaced by. Can be one of the following:

From bcaafeaefcab0d0bf0c70273e69b8d6f79f33fc4 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Mon, 17 Jun 2024 12:31:30 -0600
Subject: [PATCH 108/136] fix most egregious deprecated behavior and cap
 version

---
 pyproject.toml                                |  2 +-
 src/spikeinterface/core/core_tools.py         |  2 +-
 .../core/tests/test_jsonification.py          |  1 -
 .../tests/test_template_database.py           |  2 +-
 .../tests/test_highpass_spatial_filter.py     | 19 ++++++++-------
 .../tests/test_interpolate_bad_channels.py    | 24 +++++++++++--------
 .../sortingcomponents/peak_localization.py    |  4 ++--
 .../test_waveform_thresholder.py              |  2 +-
 .../waveforms/waveform_thresholder.py         |  2 +-
 9 files changed, 31 insertions(+), 27 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index dadb677056..9dbf0c0229 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [


 dependencies = [
-    "numpy",
+    "numpy>=1.20, <2.0",  # Minimal needed for np.ptp
     "threadpoolctl>=3.0.0",
     "tqdm",
     "zarr>=2.16,<2.18",
diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py
index 5232539422..681392d3f4 100644
--- a/src/spikeinterface/core/core_tools.py
+++ b/src/spikeinterface/core/core_tools.py
@@ -83,7 +83,7 @@ def default(self, obj):
         if isinstance(obj, np.generic):
             return obj.item()

-        if np.issctype(obj):  # Cast numpy datatypes to their names
+        if isinstance(obj, np.dtype):
             return np.dtype(obj).name

         if isinstance(obj, np.ndarray):
diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py
index f63cfb16d8..4417ea342f 100644
--- a/src/spikeinterface/core/tests/test_jsonification.py
+++ b/src/spikeinterface/core/tests/test_jsonification.py
@@ -122,7 +122,6 @@ def test_numpy_dtype_alises_encoding():
     # People tend to use this as a dtype instead of the proper classes
     json.dumps(np.int32, cls=SIJsonEncoder)
     json.dumps(np.float32, cls=SIJsonEncoder)
-    json.dumps(np.bool_,
cls=SIJsonEncoder)  # Note that np.bool was deperecated in numpy 1.20.0


 def test_recording_encoding(numpy_generated_recording):
diff --git a/src/spikeinterface/generation/tests/test_template_database.py b/src/spikeinterface/generation/tests/test_template_database.py
index a71faf0683..9e2a013ad0 100644
--- a/src/spikeinterface/generation/tests/test_template_database.py
+++ b/src/spikeinterface/generation/tests/test_template_database.py
@@ -36,7 +36,7 @@ def test_fetch_templates_database_info():
 def test_query_templates_from_database():
     templates_info = fetch_templates_database_info()

-    templates_info = templates_info.iloc[::15]
+    templates_info = templates_info.iloc[[1, 3, 5]]
     num_selected = len(templates_info)

     templates = query_templates_from_database(templates_info)
diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py
index 5c843e7c0b..0dd75fd476 100644
--- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py
+++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py
@@ -8,14 +8,7 @@
 import spikeinterface.extractors as se
 from spikeinterface.core import generate_recording
 import spikeinterface.widgets as sw
-
-try:
-    import spikeglx
-    import neurodsp.voltage as voltage
-
-    HAVE_IBL_NPIX = True
-except ImportError:
-    HAVE_IBL_NPIX = False
+import importlib.util

 ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))

@@ -31,7 +24,10 @@
 # ----------------------------------------------------------------------------------------------------------------------


-@pytest.mark.skipif(not HAVE_IBL_NPIX or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install")
+@pytest.mark.skipif(
+    importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
+    reason="Only local. Requires ibl-neuropixel install",
+)
 @pytest.mark.parametrize("lagc", [False, 1, 300])
 def test_highpass_spatial_filter_real_data(lagc):
     """
@@ -56,6 +52,9 @@
     use DEBUG = true to visualise.

     """
+    import spikeglx
+    import neurodsp.voltage as voltage
+
     options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None)
     print(options)
@@ -146,6 +145,8 @@ def get_ibl_si_data():
     """
     Set fixture to session to ensure original data is not changed.
     """
+    import spikeglx
+
     local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
     ibl_recording = spikeglx.Reader(
         local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True
diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
index ad073e40aa..1189f04f7d 100644
--- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
+++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
@@ -6,17 +6,10 @@
 import spikeinterface.preprocessing as spre
 import spikeinterface.extractors as se
 from spikeinterface.core.generate import generate_recording
+import importlib.util

-try:
-    import spikeglx
-    import neurodsp.voltage as voltage
-
-    HAVE_IBL_NPIX = True
-except ImportError:
-    HAVE_IBL_NPIX = False

 ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
-
 DEBUG = False
 if DEBUG:
     import matplotlib.pyplot as plt
@@ -30,7 +23,10 @@
 # -------------------------------------------------------------------------------


-@pytest.mark.skipif(not HAVE_IBL_NPIX or ON_GITHUB, reason="Only local.
Requires ibl-neuropixel install")
+@pytest.mark.skipif(
+    importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
+    reason="Only local. Requires ibl-neuropixel install",
+)
 def test_compare_real_data_with_ibl():
     """
     Test SI implementation of bad channel interpolation against native IBL.
@@ -43,6 +39,9 @@
     si_scaled_recording.get_traces(0) is also close to 1e-2.
     """
     # Download and load data
+    import spikeglx
+    import neurodsp.voltage as voltage
+
     local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
     si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
     ibl_recording = spikeglx.Reader(
@@ -80,7 +79,10 @@ def test_compare_real_data_with_ibl():
     assert np.mean(is_close) > 0.999


-@pytest.mark.skipif(not HAVE_IBL_NPIX, reason="Requires ibl-neuropixel install")
+@pytest.mark.skipif(
+    importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None,
+    reason="Requires ibl-neuropixel install",
+)
 @pytest.mark.parametrize("num_channels", [32, 64])
 @pytest.mark.parametrize("sigma_um", [1.25, 40])
 @pytest.mark.parametrize("p", [0, -0.5, 1, 5])
@@ -90,6 +92,8 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan
     Perform an extended test across a range of function inputs to check IBL and SI
     interpolation results match.
     """
+    import neurodsp.voltage as voltage
+
     recording = generate_recording(num_channels=num_channels, durations=[1])

     # distribute default probe locations across 4 shanks if set
diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py
index b06f6fac3e..716eecf123 100644
--- a/src/spikeinterface/sortingcomponents/peak_localization.py
+++ b/src/spikeinterface/sortingcomponents/peak_localization.py
@@ -204,7 +204,7 @@ def compute(self, traces, peaks, waveforms):
             wf = waveforms[idx][:, :, chan_inds]

             if self.feature == "ptp":
-                wf_data = wf.ptp(axis=1)
+                wf_data = np.ptp(wf, axis=1)
             elif self.feature == "mean":
                 wf_data = wf.mean(axis=1)
             elif self.feature == "energy":
@@ -293,7 +293,7 @@ def compute(self, traces, peaks, waveforms):
             wf = waveforms[i, :][:, chan_inds]

             if self.feature == "ptp":
-                wf_data = wf.ptp(axis=0)
+                wf_data = np.ptp(wf, axis=0)
             elif self.feature == "energy":
                 wf_data = np.linalg.norm(wf, axis=0)
             elif self.feature == "peak_voltage":
diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py
index 4f55030283..79a9603b8d 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py
@@ -37,7 +37,7 @@ def test_waveform_thresholder_ptp(
         recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs
     )

-    data = tresholded_waveforms.ptp(axis=1) / noise_levels
+    data = np.ptp(tresholded_waveforms, axis=1) / noise_levels
     assert np.all(data[data != 0] > 3)


diff --git a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py
index 76d72f3b08..b4c54be6ad 100644
--- a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py
+++ b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py
@@ -78,7 +78,7 @@ def __init__(
     def compute(self, traces, peaks, waveforms):
         if
self.feature == "ptp": - wf_data = waveforms.ptp(axis=1) / self.noise_levels + wf_data = np.ptp(waveforms, axis=1) / self.noise_levels elif self.feature == "mean": wf_data = waveforms.mean(axis=1) / self.noise_levels elif self.feature == "energy": From 61a7c2cae0743d4030ed8c0e4d6089cfd1768e81 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 17 Jun 2024 12:55:15 -0600 Subject: [PATCH 109/136] fix for python api class --- src/spikeinterface/core/core_tools.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 681392d3f4..a1a23aaade 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -83,9 +83,15 @@ def default(self, obj): if isinstance(obj, np.generic): return obj.item() + # # Standard numpy dtypes like np.dtype('int32") are transformed this way if isinstance(obj, np.dtype): return np.dtype(obj).name + # This will transform to a sring canonical representation of the dtype (e.g. np.int32 -> 'int32') + + if isinstance(obj, type) and issubclass(obj, np.generic): + return np.dtype(obj).name + if isinstance(obj, np.ndarray): return obj.tolist() From c82d085d5d116c61ef67ff6148588d8ab7e6eb04 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 17 Jun 2024 14:53:47 -0600 Subject: [PATCH 110/136] bump pickle --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9dbf0c0229..d6a627cb97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ dependencies = [ - "numpy>=1.20, <2.0", # Minimal needed for np.ptp + "numpy>=1.26, <2.0", # 1.20 np.ptp, 1.26 for avoiding pikcling errors when numpy >2.0 "threadpoolctl>=3.0.0", "tqdm", "zarr>=2.16,<2.18", From a014e5e714f33f10216206943de32830d3c20c52 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 17 Jun 2024 16:32:34 -0600 Subject: [PATCH 111/136] Apply suggestions from code review Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d6a627cb97..734fe04962 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ dependencies = [ - "numpy>=1.26, <2.0", # 1.20 np.ptp, 1.26 for avoiding pikcling errors when numpy >2.0 + "numpy>=1.26, <2.0", # 1.20 np.ptp, 1.26 for avoiding pickling errors when numpy >2.0 "threadpoolctl>=3.0.0", "tqdm", "zarr>=2.16,<2.18", From da99a89c873d216cfb15c760d07ca0b19d4f79dc Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 17 Jun 2024 16:33:05 -0600 Subject: [PATCH 112/136] Update src/spikeinterface/core/core_tools.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/core_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index a1a23aaade..0c6a6cbcd5 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -87,7 +87,7 @@ def default(self, obj): if isinstance(obj, np.dtype): return np.dtype(obj).name - # This will transform to a sring canonical representation of the dtype (e.g. np.int32 -> 'int32') + # This will transform to a string canonical representation of the dtype (e.g. 
np.int32 -> 'int32') if isinstance(obj, type) and issubclass(obj, np.generic): return np.dtype(obj).name From fb179016d87e4f29529b93f2e45eb080681bb9bb Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 17 Jun 2024 17:27:09 -0600 Subject: [PATCH 113/136] add time slice --- src/spikeinterface/core/baserecording.py | 24 ++++++++++++++++++ .../core/tests/test_baserecording.py | 25 +++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 68a7dd744b..184959512b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -679,6 +679,30 @@ def frame_slice(self, start_frame: int, end_frame: int) -> BaseRecording: sub_recording = FrameSliceRecording(self, start_frame=start_frame, end_frame=end_frame) return sub_recording + def time_slice(self, start_time: float, end_time: float) -> BaseRecording: + """ + Returns a new recording with sliced time. Note that this operation is not in place. + + Parameters + ---------- + start_time : float + The start time in seconds. + end_time : float + The end time in seconds. + + Returns + ------- + BaseRecording + The object with sliced time. + """ + + assert self.get_num_segments() == 1, "Time slicing is only supported for single segment recordings." + + start_frame = self.time_to_sample_index(start_time) + end_frame = self.time_to_sample_index(end_time) + + return self.frame_slice(start_frame=start_frame, end_frame=end_frame) + def _select_segments(self, segment_indices): from .segmentutils import SelectSegmentRecording diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index eb6cf7ac12..682881af8a 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -361,5 +361,30 @@ def test_select_channels(): assert np.array_equal(selected_channel_ids, ["a", "c"]) +def test_time_slice(): + # Case with sampling frequency + sampling_frequency = 10_000.0 + recording = generate_recording(durations=[1.0], num_channels=3, sampling_frequency=sampling_frequency) + + sliced_recording_times = recording.time_slice(start_time=0.1, end_time=0.8) + sliced_recording_frames = recording.frame_slice(start_frame=1000, end_frame=8000) + + assert np.allclose(sliced_recording_times.get_traces(), sliced_recording_frames.get_traces()) + + +def test_time_slice_with_time_vector(): + + # Case with time vector + sampling_frequency = 10_000.0 + recording = generate_recording(durations=[1.0], num_channels=3, sampling_frequency=sampling_frequency) + times = 1 + np.arange(0, 10_000) / sampling_frequency + recording.set_times(times=times, segment_index=0, with_warning=False) + + sliced_recording_times = recording.time_slice(start_time=1.1, end_time=1.8) + sliced_recording_frames = recording.frame_slice(start_frame=1000, end_frame=8000) + + assert np.allclose(sliced_recording_times.get_traces(), sliced_recording_frames.get_traces()) + + if __name__ == "__main__": test_BaseRecording() From da83ec9e3b545e5f4d5db26b3eab2a7c3037426d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jun 2024 11:30:37 +0200 Subject: [PATCH 114/136] Set DEV=True --- src/spikeinterface/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/__init__.py b/src/spikeinterface/__init__.py index 97fb95b623..306c12d516 100644 --- a/src/spikeinterface/__init__.py +++ b/src/spikeinterface/__init__.py 
@@ -30,5 +30,5 @@ # This flag must be set to False for release # This avoids using versioning that contains ".dev0" (and this is a better choice) # This is mainly useful when using run_sorter in a container and spikeinterface install -# DEV_MODE = True -DEV_MODE = False +DEV_MODE = True +# DEV_MODE = False From 14970e1b33a46e9308509eab4b3d6b21b18d957e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jun 2024 11:37:29 +0200 Subject: [PATCH 115/136] Comment linting --- src/spikeinterface/core/core_tools.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 0c6a6cbcd5..664eac169f 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -83,12 +83,11 @@ def default(self, obj): if isinstance(obj, np.generic): return obj.item() - # # Standard numpy dtypes like np.dtype('int32") are transformed this way + # Standard numpy dtypes like np.dtype('int32") are transformed this way if isinstance(obj, np.dtype): return np.dtype(obj).name # This will transform to a string canonical representation of the dtype (e.g. np.int32 -> 'int32') - if isinstance(obj, type) and issubclass(obj, np.generic): return np.dtype(obj).name From 2b35a0880326551ce1d4e179b0af4c8953d873ed Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jun 2024 13:15:14 +0200 Subject: [PATCH 116/136] Update plot_motion and correct_motion_on_peaks --- .../benchmark/benchmark_motion_estimation.py | 5 +- .../sortingcomponents/motion_interpolation.py | 26 ++++---- .../sortingcomponents/motion_utils.py | 26 ++------ .../tests/test_motion_interpolation.py | 2 +- src/spikeinterface/widgets/motion.py | 59 +++++++++++-------- 5 files changed, 55 insertions(+), 63 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 7428629c4a..b353b75817 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -670,11 +670,8 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # peak_locations_corrected = correct_motion_on_peaks( # self.selected_peaks, # self.peak_locations, -# self.recording, # self.motion, -# self.temporal_bins, -# self.spatial_bins, -# direction="y", +# self.recording, # ) # if axes is None: # if show_probe: diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 32c3e706cf..4b3c081f3f 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -7,25 +7,23 @@ from spikeinterface.preprocessing.filter import fix_dtype -def correct_motion_on_peaks( - peaks, - peak_locations, - rec, - motion, -): +def correct_motion_on_peaks(peaks, peak_locations, motion, recording=None, sampling_frequency=None): """ Given the output of estimate_motion(), apply inverse motion on peak locations. Parameters ---------- - peaks: np.array + peaks : np.array peaks vector - peak_locations: np.array + peak_locations : np.array peaks location vector - sampling_frequency: np.array - sampling_frequency of the recording - motion: Motion + motion : Motion The motion object. + recording : Recording | None, default: None + The recording object. 
If given, this is used to convert sample indices to times. + sampling_frequency : float | None + Sampling_frequency of the recording, required if recording is None. + Returns ------- @@ -33,7 +31,11 @@ def correct_motion_on_peaks( Motion-corrected peak locations """ corrected_peak_locations = peak_locations.copy() - times_s = rec.sample_index_to_time(peaks["sample_index"]) + assert recording is not None or sampling_frequency is not None, "recording or sampling_frequency must be provided" + if recording is not None: + times_s = recording.sample_index_to_time(peaks["sample_index"]) + else: + times_s = peaks["sample_index"] / sampling_frequency for segment_index in range(motion.num_segments): i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index 1edf484aa4..9bccfae1e2 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -5,24 +5,6 @@ import spikeinterface from spikeinterface.core.core_tools import check_json -# @charlie @sam -# here TODO list for motion object -# * simple test for Motion: DONE -# * save/load Motion DONE -# * make simple test for Motion object with save/load DONE -# * propagate to estimate_motion : DONE -# * handle multi segment in estimate_motion(): maybe in another PR -# * propagate to motion_interpolation.py: DONE -# * propagate to preprocessing/correct_motion(): DONE -# * generate drifting signals for test estimate_motion and interpolate_motion: SIMPLE ONE DONE? -# * uncomment assert in test_estimate_motion (aka debug torch vs numpy diff): DONE -# * delegate times to recording object in -# * estimate motion: DONE -# * correct_motion_on_peaks(): DONE -# * interpolate_motion_on_traces(): DONE -# propagate to benchmark estimate motion -# update plot_motion() dans widget - class Motion: """ @@ -30,19 +12,21 @@ class Motion: Parameters ---------- - displacement: numpy array 2d or list of + displacement : numpy array 2d or list of Motion estimate in um. List is the number of segment. For each semgent : * shape (temporal bins, spatial bins) * motion.shape[0] = temporal_bins.shape[0] * motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) - temporal_bins_s: numpy.array 1d or list of + temporal_bins_s : numpy.array 1d or list of temporal bins (bin center) - spatial_bins_um: numpy.array 1d + spatial_bins_um : numpy.array 1d Windows center. spatial_bins_um.shape[0] == displacement.shape[1] If rigid then spatial_bins_um.shape[0] == 1 + direction : str, default: 'y' + Direction of the motion. interpolation_method : str How to determine the displacement between bin centers? See the docs for scipy.interpolate.RegularGridInterpolator for options. 
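A minimal usage sketch of the Motion / correct_motion_on_peaks API refactored in the hunks above. It assumes `rec`, `peaks`, and `peak_locations` already exist (e.g. from detect_peaks() and localize_peaks() on a single-segment recording); the rigid, all-zeros Motion built here is illustrative only:

    import numpy as np
    from spikeinterface.sortingcomponents.motion_utils import Motion
    from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks

    # rigid motion: a single spatial window, displacement sampled once per second
    temporal_bins_s = np.arange(0.0, 10.0, 1.0)
    spatial_bins_um = np.array([0.0])
    displacement = np.zeros((temporal_bins_s.size, spatial_bins_um.size))
    motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y")

    # with a recording, peak sample indices are converted to times internally
    # via recording.sample_index_to_time()
    corrected_locs = correct_motion_on_peaks(peaks, peak_locations, motion, recording=rec)

    # without a recording, the sampling frequency must be given explicitly
    corrected_locs = correct_motion_on_peaks(
        peaks, peak_locations, motion, sampling_frequency=rec.sampling_frequency
    )
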
diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index c2be600586..3628d534a4 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -46,8 +46,8 @@ def test_correct_motion_on_peaks(): corrected_peak_locations = correct_motion_on_peaks( peaks, peak_locations, - rec, motion, + recording=rec, ) # print(corrected_peak_locations) assert np.any(corrected_peak_locations["y"] != 0) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 9d64c89e46..f4ed7fecf5 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -11,10 +11,12 @@ class MotionWidget(BaseWidget): Parameters ---------- - motion_info: dict + motion: dict The motion info return by correct_motion() or load back with load_motion_info() recording : RecordingExtractor, default: None The recording extractor object (only used to get "real" times) + segment_index : int, default: 0 + The segment index to display. sampling_frequency : float, default: None The sampling frequency (needed if recording is None) depth_lim : tuple or None, default: None @@ -36,6 +38,7 @@ class MotionWidget(BaseWidget): def __init__( self, motion_info, + segment_index=0, recording=None, depth_lim=None, motion_lim=None, @@ -47,11 +50,12 @@ def __init__( backend=None, **backend_kwargs, ): - times = recording.get_times() if recording is not None else None + times = recording.get_times(segment_index=segment_index) if recording is not None else None plot_data = dict( sampling_frequency=motion_info["parameters"]["sampling_frequency"], times=times, + segment_index=segment_index, depth_lim=depth_lim, motion_lim=motion_lim, color_amplitude=color_amplitude, @@ -59,6 +63,7 @@ def __init__( amplitude_cmap=amplitude_cmap, amplitude_clim=amplitude_clim, amplitude_alpha=amplitude_alpha, + recording=recording, **motion_info, ) @@ -73,16 +78,20 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) - assert backend_kwargs["axes"] is None - assert backend_kwargs["ax"] is None + assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionWidget" + assert backend_kwargs["ax"] is None, "ax argument is not allowed in MotionWidget" self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) fig = self.figure fig.clear() - is_rigid = dp.motion.shape[1] == 1 + motion_array = dp.motion.displacement[dp.segment_index] + temporal_bins_s = dp.motion.temporal_bins_s[dp.segment_index] + spatial_bins_um = dp.motion.spatial_bins_um - gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.3) + is_rigid = motion_array.shape[1] == 1 + + gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.5) ax0 = fig.add_subplot(gs[0, 0]) ax1 = fig.add_subplot(gs[0, 1]) ax2 = fig.add_subplot(gs[1, 0]) @@ -92,30 +101,29 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1.sharey(ax0) if dp.motion_lim is None: - motion_lim = np.max(np.abs(dp.motion)) * 1.05 + motion_lim = np.max(np.abs(motion_array)) * 1.05 else: motion_lim = dp.motion_lim if dp.times is None: - temporal_bins_plot = dp.temporal_bins + temporal_bins_plot = temporal_bins_s x = dp.peaks["sample_index"] / dp.sampling_frequency else: # use real times and adjust temporal bins with t_start - temporal_bins_plot = dp.temporal_bins + dp.times[0] + temporal_bins_plot = temporal_bins_s + dp.times[0] x = 
dp.times[dp.peaks["sample_index"]] corrected_location = correct_motion_on_peaks( dp.peaks, dp.peak_locations, - dp.sampling_frequency, dp.motion, - dp.temporal_bins, - dp.spatial_bins, - direction="y", + dp.recording, + dp.sampling_frequency, ) + dim = ["x", "y", "z"][dp.motion.dim] - y = dp.peak_locations["y"] - y2 = corrected_location["y"] + y = dp.peak_locations[dim] + y2 = corrected_location[dim] if dp.scatter_decimate is not None: x = x[:: dp.scatter_decimate] y = y[:: dp.scatter_decimate] @@ -149,37 +157,38 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax0.set_ylim(*dp.depth_lim) ax0.set_title("Peak depth") ax0.set_xlabel("Times [s]") - ax0.set_ylabel("Depth [um]") + ax0.set_ylabel("Depth [$\\mu$m]") ax1.scatter(x, y2, s=1, **color_kwargs) ax1.set_xlabel("Times [s]") - ax1.set_ylabel("Depth [um]") + ax1.set_ylabel("Depth [$\\mu$m]") ax1.set_title("Corrected peak depth") - ax2.plot(temporal_bins_plot, dp.motion, alpha=0.2, color="black") - ax2.plot(temporal_bins_plot, np.mean(dp.motion, axis=1), color="C0") + ax2.plot(temporal_bins_plot, motion_array, alpha=0.2, color="black") + ax2.plot(temporal_bins_plot, np.mean(motion_array, axis=1), color="C0") ax2.set_ylim(-motion_lim, motion_lim) - ax2.set_ylabel("Motion [um]") + ax2.set_ylabel("Motion [$\\mu$m]") + ax2.set_xlabel("Times [s]") ax2.set_title("Motion vectors") axes = [ax0, ax1, ax2] if not is_rigid: im = ax3.imshow( - dp.motion.T, + motion_array.T, aspect="auto", origin="lower", extent=( temporal_bins_plot[0], temporal_bins_plot[-1], - dp.spatial_bins[0], - dp.spatial_bins[-1], + spatial_bins_um[0], + spatial_bins_um[-1], ), ) im.set_clim(-motion_lim, motion_lim) cbar = fig.colorbar(im) - cbar.ax.set_xlabel("motion [um]") + cbar.ax.set_ylabel("Motion [$\\mu$m]") ax3.set_xlabel("Times [s]") - ax3.set_ylabel("Depth [um]") + ax3.set_ylabel("Depth [$\\mu$m]") ax3.set_title("Motion vectors") axes.append(ax3) self.axes = np.array(axes) From ecbe9a399d2b6432bef4d265620cc34861d88a2c Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 18 Jun 2024 08:43:57 -0400 Subject: [PATCH 117/136] fix is filtered check --- src/spikeinterface/sorters/basesorter.py | 2 +- src/spikeinterface/sorters/external/mountainsort5.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 2f87065d9f..8c52626703 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -183,7 +183,7 @@ def set_params_to_folder(cls, recording, output_folder, new_params, verbose): # custom check params params = cls._check_params(recording, output_folder, params) # common check : filter warning - if recording.is_filtered and cls._check_apply_filter_in_params(params) and verbose: + if recording.is_filtered() and cls._check_apply_filter_in_params(params) and verbose: print(f"Warning! 
The recording is already filtered, but {cls.sorter_name} filter is enabled") # dump parameters inside the folder with json diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 6fa68de190..cf6933c9e6 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -120,7 +120,6 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 - from mountainsort5.util import create_cached_recording recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) if recording is None: From eb31e08f3c03e082960f3d95d1bb3fcdc68c96bd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jun 2024 15:38:04 +0200 Subject: [PATCH 118/136] Add self._tmp_recording to store temp recording in analyzer --- src/spikeinterface/core/recording_tools.py | 8 +++--- src/spikeinterface/core/sortinganalyzer.py | 19 ++++++++----- .../core/tests/test_sortinganalyzer.py | 28 +++++++++++-------- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 81fb0b3eb8..2b193f4164 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -919,7 +919,7 @@ def get_rec_attributes(recording): return rec_attributes -def check_recording_attributes_match(recording1, recording2_attributes, skip_properties=True) -> bool: +def check_recording_attributes_match(recording1, recording2_attributes) -> bool: """ Check if two recordings have the same attributes @@ -937,9 +937,9 @@ def check_recording_attributes_match(recording1, recording2_attributes, skip_pro """ recording1_attributes = get_rec_attributes(recording1) recording2_attributes = deepcopy(recording2_attributes) - if skip_properties: - recording1_attributes.pop("properties") - recording2_attributes.pop("properties") + recording1_attributes.pop("properties") + recording2_attributes.pop("properties") + return ( np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]) and recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"] diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 5862c247a1..706c1d5a73 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -203,6 +203,8 @@ def __init__( self.format = format self.sparsity = sparsity self.return_scaled = return_scaled + # this is used to store temporary recording + self._tmp_recording = None # extensions are not loaded at init self.extensions = dict() @@ -619,15 +621,13 @@ def set_temporary_recording(self, recording: BaseRecording): The recording object to set as temporary recording. """ # check that recording is compatible - assert check_recording_attributes_match( - recording, self.rec_attributes, skip_properties=True - ), "Recording attributes do not match." + assert check_recording_attributes_match(recording, self.rec_attributes), "Recording attributes do not match." assert np.array_equal( recording.get_channel_locations(), self.get_channel_locations() ), "Recording channel locations do not match." if self._recording is not None: - warnings.warn("SortingAnalyzer recording is already set. 
" "The current recording is temporarily replaced.") - self._recording = recording + warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.") + self._tmp_recording = recording def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingAnalyzer": """ @@ -635,7 +635,9 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> """ if self.has_recording(): - recording = self.recording + recording = self._recording + elif self.has_temporary_recording(): + recording = self._tmp_recording else: recording = None @@ -754,7 +756,7 @@ def is_read_only(self) -> bool: def recording(self) -> BaseRecording: if not self.has_recording(): raise ValueError("SortingAnalyzer could not load the recording") - return self._recording + return self._tmp_recording or self._recording @property def channel_ids(self) -> np.ndarray: @@ -771,6 +773,9 @@ def unit_ids(self) -> np.ndarray: def has_recording(self) -> bool: return self._recording is not None + def has_temporary_recording(self) -> bool: + return self._tmp_recording is not None + def is_sparse(self) -> bool: return self.sparsity is not None diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index e3003e693f..d780932146 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -19,7 +19,7 @@ import numpy as np -def _get_dataset(): +def get_dataset(): recording, sorting = generate_ground_truth_recording( durations=[30.0], sampling_frequency=16000.0, @@ -33,12 +33,12 @@ def _get_dataset(): @pytest.fixture(scope="module") -def get_dataset(): - return _get_dataset() +def dataset(): + return get_dataset() -def test_SortingAnalyzer_memory(tmp_path, get_dataset): - recording, sorting = get_dataset +def test_SortingAnalyzer_memory(tmp_path, dataset): + recording, sorting = dataset sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) @@ -57,8 +57,8 @@ def test_SortingAnalyzer_memory(tmp_path, get_dataset): assert not sorting_analyzer.return_scaled -def test_SortingAnalyzer_binary_folder(tmp_path, get_dataset): - recording, sorting = get_dataset +def test_SortingAnalyzer_binary_folder(tmp_path, dataset): + recording, sorting = dataset folder = tmp_path / "test_SortingAnalyzer_binary_folder" if folder.exists(): @@ -87,8 +87,8 @@ def test_SortingAnalyzer_binary_folder(tmp_path, get_dataset): _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) -def test_SortingAnalyzer_zarr(tmp_path, get_dataset): - recording, sorting = get_dataset +def test_SortingAnalyzer_zarr(tmp_path, dataset): + recording, sorting = dataset folder = tmp_path / "test_SortingAnalyzer_zarr.zarr" if folder.exists(): @@ -108,12 +108,18 @@ def test_SortingAnalyzer_zarr(tmp_path, get_dataset): ) -def test_SortingAnalyzer_tmp_recording(get_dataset): - recording, sorting = get_dataset +def test_SortingAnalyzer_tmp_recording(dataset): + recording, sorting = dataset recording_cached = recording.save(mode="memory") sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) sorting_analyzer.set_temporary_recording(recording_cached) + assert sorting_analyzer.has_temporary_recording() + # check that saving as uses the original recording + sorting_analyzer_saved = 
sorting_analyzer.save_as(format="memory") + assert sorting_analyzer_saved.has_recording() + assert not sorting_analyzer_saved.has_temporary_recording() + assert isinstance(sorting_analyzer_saved.recording, type(recording)) recording_sliced = recording.channel_slice(recording.channel_ids[:-1]) From 68e0a339c074769a5d36946b098b70b93d0cbbd8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jun 2024 15:48:15 +0200 Subject: [PATCH 119/136] _tmp_recording -> temporary_recording --- src/spikeinterface/core/sortinganalyzer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 706c1d5a73..c1080b3d4b 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -204,7 +204,7 @@ def __init__( self.sparsity = sparsity self.return_scaled = return_scaled # this is used to store temporary recording - self._tmp_recording = None + self._temporary_recording = None # extensions are not loaded at init self.extensions = dict() @@ -627,7 +627,7 @@ def set_temporary_recording(self, recording: BaseRecording): ), "Recording channel locations do not match." if self._recording is not None: warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.") - self._tmp_recording = recording + self._temporary_recording = recording def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingAnalyzer": """ @@ -637,7 +637,7 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> if self.has_recording(): recording = self._recording elif self.has_temporary_recording(): - recording = self._tmp_recording + recording = self._temporary_recording else: recording = None @@ -756,7 +756,7 @@ def is_read_only(self) -> bool: def recording(self) -> BaseRecording: if not self.has_recording(): raise ValueError("SortingAnalyzer could not load the recording") - return self._tmp_recording or self._recording + return self._temporary_recording or self._recording @property def channel_ids(self) -> np.ndarray: @@ -774,7 +774,7 @@ def has_recording(self) -> bool: return self._recording is not None def has_temporary_recording(self) -> bool: - return self._tmp_recording is not None + return self._temporary_recording is not None def is_sparse(self) -> bool: return self.sparsity is not None From 025b86f47d8ba2e277eb86b7421b3c16b1de1e3c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jun 2024 15:51:24 +0200 Subject: [PATCH 120/136] check_recording_attributes_match -> do_recording_attributes_match --- src/spikeinterface/core/recording_tools.py | 2 +- src/spikeinterface/core/sortinganalyzer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 2b193f4164..024382dea2 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -919,7 +919,7 @@ def get_rec_attributes(recording): return rec_attributes -def check_recording_attributes_match(recording1, recording2_attributes) -> bool: +def do_recording_attributes_match(recording1, recording2_attributes) -> bool: """ Check if two recordings have the same attributes diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index c1080b3d4b..bdb1a6c248 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ 
b/src/spikeinterface/core/sortinganalyzer.py @@ -22,7 +22,7 @@ from .basesorting import BaseSorting from .base import load_extractor -from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, check_recording_attributes_match +from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match from .core_tools import check_json, retrieve_importing_provenance from .job_tools import split_job_kwargs from .numpyextractors import NumpySorting @@ -621,7 +621,7 @@ def set_temporary_recording(self, recording: BaseRecording): The recording object to set as temporary recording. """ # check that recording is compatible - assert check_recording_attributes_match(recording, self.rec_attributes), "Recording attributes do not match." + assert do_recording_attributes_match(recording, self.rec_attributes), "Recording attributes do not match." assert np.array_equal( recording.get_channel_locations(), self.get_channel_locations() ), "Recording channel locations do not match." From 88ec0717e3694314e4e50d895da8257a5db26a2e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jun 2024 16:25:22 +0200 Subject: [PATCH 121/136] Improve check in recording property --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index bdb1a6c248..46d02099d5 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -754,7 +754,7 @@ def is_read_only(self) -> bool: @property def recording(self) -> BaseRecording: - if not self.has_recording(): + if not self.has_recording() and not self.has_temporary_recording(): raise ValueError("SortingAnalyzer could not load the recording") return self._temporary_recording or self._recording From 6ed34caac56ad20b9a311b26a4b5f2537a269caa Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Tue, 18 Jun 2024 12:27:22 -0400 Subject: [PATCH 122/136] add miniconda latest. 
--- .github/workflows/installation-tips-test.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml index cbe313b12e..e83399cf7c 100644 --- a/.github/workflows/installation-tips-test.yml +++ b/.github/workflows/installation-tips-test.yml @@ -28,8 +28,9 @@ jobs: with: python-version: '3.10' - name: Test Conda Environment Creation - uses: conda-incubator/setup-miniconda@v2.2.0 + uses: conda-incubator/setup-miniconda@v3 with: + miniconda-version: "latest" environment-file: ./installation_tips/full_spikeinterface_environment_${{ matrix.label }}.yml activate-environment: si_env - name: Check Installation Tips From 0b3065b58653392b33853f044a23c7936a66d3c5 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 18 Jun 2024 18:51:35 +0100 Subject: [PATCH 123/136] Add _common_filter_docs --- src/spikeinterface/preprocessing/filter.py | 37 ++++++++++------------ 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index ffad9a2029..8236acf848 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -8,14 +8,16 @@ from ..core import get_chunk_with_margin -_common_filter_docs = """**filter_kwargs : keyword arguments for parallel processing: - - * filter_order : order - The order of the filter - * filter_mode : "sos or "ba" - "sos" is bi quadratic and more stable than ab so thery are prefered. - * ftype : str - Filter type for iirdesign ("butter" / "cheby1" / ... all possible of scipy.signal.iirdesign) +_common_filter_docs = """**filter_kwargs : dict + Certain keyword arguments for `scipy.signal` filters: + filter_order : order + The order of the filter + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". """ @@ -39,20 +41,13 @@ class FilterRecording(BasePreprocessor): Type of the filter margin_ms : float, default: 5.0 Margin in ms on border to avoid border effect - filter_mode : "sos" | "ba", default: "sos" - Filter form of the filter coefficients: - - second-order sections ("sos") - - numerator/denominator : ("ba") coeff : array | None, default: None Filter coefficients in the filter_mode form. dtype : dtype or None, default: None The dtype of the returned traces. If None, the dtype of the parent recording is used add_reflect_padding : Bool, default False If True, uses a left and right margin during calculation. - ftype : str | None, default: "butter" - The type of IIR filter to design, used in `scipy.signal.iirfilter`. - filter_order : int, default: 5 - The order of the filter, used in `scipy.signal.iirfilter`. + {} Returns ------- @@ -183,8 +178,7 @@ class BandpassFilterRecording(FilterRecording): Margin in ms on border to avoid border effect dtype : dtype or None The dtype of the returned traces. If None, the dtype of the parent recording is used - **filter_kwargs : dict - Keyword arguments for `spikeinterface.preprocessing.FilterRecording` class. + {} Returns ------- @@ -219,8 +213,7 @@ class HighpassFilterRecording(FilterRecording): Margin in ms on border to avoid border effect dtype : dtype or None The dtype of the returned traces. 
If None, the dtype of the parent recording is used - **filter_kwargs : dict - Keyword arguments for `spikeinterface.preprocessing.FilterRecording` class. + {} Returns ------- @@ -297,6 +290,10 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") +filter.__doc__ = filter.__doc__.format(_common_filter_docs) +bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs) +highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs) + def fix_dtype(recording, dtype): if dtype is None: From 257950d5859521730ed1da746b6fd32b7b6335bb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jun 2024 06:36:40 +0000 Subject: [PATCH 124/136] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 4 ++- .../benchmark/benchmark_motion_estimation.py | 20 ++++++------- .../benchmark_motion_interpolation.py | 4 +-- .../tests/test_benchmark_motion_estimation.py | 1 + .../test_benchmark_motion_interpolation.py | 1 + .../sortingcomponents/motion_utils.py | 4 +-- .../tests/test_motion_utils.py | 11 ++++--- src/spikeinterface/widgets/motion.py | 30 ++++++++----------- .../widgets/tests/test_widgets.py | 21 +++++-------- 9 files changed, 41 insertions(+), 55 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 1a064dcb31..b5df0f1059 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -314,7 +314,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): motion_info = load_motion_info(motion_folder) motion = motion_info["motion"] - max_motion = max(np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement))) + max_motion = max( + np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement)) + ) merging_params["maximum_distance_um"] = max(50, 2 * max_motion) # peak_sign = params['detection'].get('peak_sign', 'neg') diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 1408a5cb32..55ef21de9d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -55,13 +55,9 @@ def get_gt_motion_from_unit_displacement( gt_displacement[t, :] = f(spatial_bins_um) gt_motion = Motion( - gt_displacement, - temporal_bins_s, - spatial_bins_um, - direction="xyz"[direction_dim], - interpolation_method="linear" + gt_displacement, temporal_bins_s, spatial_bins_um, direction="xyz"[direction_dim], interpolation_method="linear" ) - + return gt_motion @@ -102,9 +98,7 @@ def run(self, **job_kwargs): t2 = time.perf_counter() peak_locations = localize_peaks(self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs) t3 = time.perf_counter() - motion = estimate_motion( - self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"] - ) + motion = estimate_motion(self.recording, selected_peaks, peak_locations, 
**p["estimate_motion_kwargs"]) t4 = time.perf_counter() step_run_times = dict( @@ -263,8 +257,12 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): aspect="auto", interpolation="nearest", origin="lower", - extent=(motion.temporal_bins_s[0][0], motion.temporal_bins_s[0][-1], - motion.spatial_bins_um[0], motion.spatial_bins_um[-1]), + extent=( + motion.temporal_bins_s[0][0], + motion.temporal_bins_s[0][-1], + motion.spatial_bins_um[0], + motion.spatial_bins_um[-1], + ), ) plt.colorbar(im, ax=ax, label="error") ax.set_ylabel("depth (um)") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index 5688d2eaf3..a6ff05fc55 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -44,9 +44,7 @@ def run(self, **job_kwargs): recording = self.drifting_recording elif self.params["recording_source"] == "corrected": correct_motion_kwargs = self.params["correct_motion_kwargs"] - recording = InterpolateMotionRecording( - self.drifting_recording, self.motion, **correct_motion_kwargs - ) + recording = InterpolateMotionRecording(self.drifting_recording, self.motion, **correct_motion_kwargs) else: raise ValueError("recording_source") diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py index 14a5fe9138..526cc2e92f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py @@ -70,6 +70,7 @@ def test_benchmark_motion_estimaton(create_cache_folder): study.plot_summary_errors() import matplotlib.pyplot as plt + plt.show() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py index 07eb35b693..6d80d027f2 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -134,6 +134,7 @@ def test_benchmark_motion_interpolation(create_cache_folder): study.plot_sorting_accuracy(mode="depth", mode_best_merge=True) import matplotlib.pyplot as plt + plt.show() diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index d4f0bb93b5..26d4b35b1a 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -220,11 +220,11 @@ def __eq__(self, other): return False return True - + def copy(self): return Motion( self.displacement.copy(), self.temporal_bins_s.copy(), self.spatial_bins_um.copy(), - interpolation_method=self.interpolation_method + interpolation_method=self.interpolation_method, ) diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py index 1542c8531a..0b67be39c0 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -14,29 +14,28 @@ def make_fake_motion(): - displacement_sampling_frequency = 
5. - spatial_bins_um = np.array([100.0, 200.0, 300., 400.]) + displacement_sampling_frequency = 5.0 + spatial_bins_um = np.array([100.0, 200.0, 300.0, 400.0]) displacement_vector = make_one_displacement_vector( drift_mode="zigzag", duration=50.0, amplitude_factor=1.0, displacement_sampling_frequency=displacement_sampling_frequency, - period_s=25., + period_s=25.0, ) temporal_bins_s = np.arange(displacement_vector.size) / displacement_sampling_frequency displacement = np.zeros((temporal_bins_s.size, spatial_bins_um.size)) - + n = spatial_bins_um.size for i in range(n): - displacement[:, i] = displacement_vector * ((i +1 ) / n) + displacement[:, i] = displacement_vector * ((i + 1) / n) motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y") return motion - def test_Motion(): temporal_bins_s = np.arange(0.0, 10.0, 1.0) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index b98e619bf7..dcb7b26f7e 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -4,6 +4,7 @@ from .base import BaseWidget, to_attr + class MotionWidget(BaseWidget): """ Plot the Motion object @@ -18,6 +19,7 @@ class MotionWidget(BaseWidget): How to plot map or lines. "auto" make it automatic if the number of depth is too high. """ + def __init__( self, motion, @@ -26,10 +28,11 @@ def __init__( motion_lim=None, backend=None, **backend_kwargs, - - ): + ): if isinstance(motion, dict): - raise ValueError("The API has changed, plot_motion() used Motion object now, maybe you want plot_motion_info(motion_info)") + raise ValueError( + "The API has changed, plot_motion() used Motion object now, maybe you want plot_motion_info(motion_info)" + ) if segment_index is None: if len(motion.displacement) == 1: @@ -43,7 +46,7 @@ def __init__( mode=mode, ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt @@ -59,7 +62,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - displacement = motion.displacement[dp.segment_index] temporal_bins_s = motion.temporal_bins_s[dp.segment_index] depth = motion.spatial_bins_um @@ -69,7 +71,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: motion_lim = dp.motion_lim - ax = self.ax fig = self.figure if dp.mode == "line": @@ -84,10 +85,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): aspect="auto", origin="lower", extent=(temporal_bins_s[0], temporal_bins_s[-1], depth[0], depth[-1]), - cmap="PiYG" + cmap="PiYG", ) im.set_clim(-motion_lim, motion_lim) - + cbar = fig.colorbar(im) cbar.ax.set_ylabel("motion [um]") ax.set_xlabel("Times [s]") @@ -106,7 +107,7 @@ class MotionInfoWidget(BaseWidget): ---------- motion_info : dict The motion info return by correct_motion() or load back with load_motion_info() - segment_index: + segment_index: recording : RecordingExtractor, default: None The recording extractor object (only used to get "real" times) @@ -145,7 +146,7 @@ def __init__( backend=None, **backend_kwargs, ): - + motion = motion_info["motion"] if segment_index is None: if len(motion.displacement) == 1: @@ -193,7 +194,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): motion = dp.motion - displacement = motion.displacement[dp.segment_index] temporal_bins_s = motion.temporal_bins_s[dp.segment_index] spatial_bins_um = 
motion.spatial_bins_um @@ -203,7 +203,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: motion_lim = dp.motion_lim - is_rigid = displacement.shape[1] == 1 gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.5) @@ -223,12 +222,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # temporal_bins_plot = dp.temporal_bins + dp.times[0] x = dp.times[dp.peaks["sample_index"]] - corrected_location = correct_motion_on_peaks( - dp.peaks, - dp.peak_locations, - dp.recording, - dp.motion - ) + corrected_location = correct_motion_on_peaks(dp.peaks, dp.peak_locations, dp.recording, dp.motion) dim = ["x", "y", "z"][dp.motion.dim] y = dp.peak_locations[motion.direction] diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 3e3e432817..0198e24626 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -579,47 +579,40 @@ def test_plot_multicomparison(self): _, axes = plt.subplots(len(mcmp.object_list), 1) sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) - + def test_plot_motion(self): from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion + motion = make_fake_motion() possible_backends = list(sw.MotionWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_motion(motion, backend=backend, mode='line') - sw.plot_motion(motion, backend=backend, mode='map') + sw.plot_motion(motion, backend=backend, mode="line") + sw.plot_motion(motion, backend=backend, mode="map") def test_plot_motion_info(self): from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion - motion = make_fake_motion() rng = np.random.default_rng(seed=2205) peak_locations = np.zeros(self.peaks.size, dtype=[("x", "float64"), ("y", "float64")]) - peak_locations['y'] = rng.uniform(motion.spatial_bins_um[0], - motion.spatial_bins_um[-1], - size=self.peaks.size) - + peak_locations["y"] = rng.uniform(motion.spatial_bins_um[0], motion.spatial_bins_um[-1], size=self.peaks.size) + motion_info = dict( motion=motion, - parameters=dict(sampling_frequency=30000.), + parameters=dict(sampling_frequency=30000.0), run_times=dict(), peaks=self.peaks, peak_locations=peak_locations, ) - possible_backends = list(sw.MotionWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_motion_info(motion_info, recording=self.recording, backend=backend) - - - - if __name__ == "__main__": # unittest.main() import matplotlib.pyplot as plt From 3eafcb35c876ccafd6cf9de62c87165105471c1c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jun 2024 09:29:57 +0200 Subject: [PATCH 125/136] fix --- .../sortingcomponents/motion_interpolation.py | 9 +++------ .../sortingcomponents/tests/test_motion_interpolation.py | 2 +- src/spikeinterface/widgets/motion.py | 7 +++---- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 4b3c081f3f..935b574565 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -7,7 +7,7 @@ from spikeinterface.preprocessing.filter import fix_dtype -def correct_motion_on_peaks(peaks, peak_locations, motion, recording=None, sampling_frequency=None): +def correct_motion_on_peaks(peaks, peak_locations, motion, 
recording): """ Given the output of estimate_motion(), apply inverse motion on peak locations. @@ -19,11 +19,8 @@ def correct_motion_on_peaks(peaks, peak_locations, motion, recording=None, sampl peaks location vector motion : Motion The motion object. - recording : Recording | None, default: None - The recording object. If given, this is used to convert sample indices to times. - sampling_frequency : float | None - Sampling_frequency of the recording, required if recording is None. - + recording : Recording + The recording object. This is used to convert sample indices to times. Returns ------- diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index 2c6ff7ecdd..cb26560272 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -42,7 +42,7 @@ def test_correct_motion_on_peaks(): peaks, peak_locations, motion, - recording=rec, + rec, ) # print(corrected_peak_locations) assert np.any(corrected_peak_locations["y"] != 0) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index b98e619bf7..2f1f9d4adf 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -106,8 +106,8 @@ class MotionInfoWidget(BaseWidget): ---------- motion_info : dict The motion info return by correct_motion() or load back with load_motion_info() - segment_index: - + segment_index: int, default: None + The segment index to display. recording : RecordingExtractor, default: None The recording extractor object (only used to get "real" times) segment_index : int, default: 0 @@ -226,10 +226,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): corrected_location = correct_motion_on_peaks( dp.peaks, dp.peak_locations, + dp.motion, dp.recording, - dp.motion ) - dim = ["x", "y", "z"][dp.motion.dim] y = dp.peak_locations[motion.direction] y2 = corrected_location[motion.direction] From 2d72c96411fd892d0eb3cb030ef0522712a23786 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 19 Jun 2024 09:20:28 +0100 Subject: [PATCH 126/136] Change filter kwargs --- src/spikeinterface/preprocessing/filter.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 8236acf848..6a1733c57c 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -47,13 +47,19 @@ class FilterRecording(BasePreprocessor): The dtype of the returned traces. If None, the dtype of the parent recording is used add_reflect_padding : Bool, default False If True, uses a left and right margin during calculation. - {} + filter_order : order + The order of the filter for `scipy.signal.iirfilter` + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients for `scipy.signal.iirfilter`: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". 
Returns ------- filter_recording : FilterRecording The filtered recording extractor object - """ name = "filter" @@ -290,7 +296,6 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") -filter.__doc__ = filter.__doc__.format(_common_filter_docs) bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs) highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs) From 225269dc1f36ade930bedb4b331f92db75e48d23 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jun 2024 10:28:05 +0200 Subject: [PATCH 127/136] more fix --- doc/modules/motion_correction.rst | 22 ++++++++----------- .../sortingcomponents/motion_interpolation.py | 12 +++++----- src/spikeinterface/widgets/motion.py | 11 +++++----- 3 files changed, 19 insertions(+), 26 deletions(-) diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index 8be2456caa..af81cb42d1 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -163,21 +163,19 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte max_distance_um=150.0, **job_kwargs) # Step 2: motion inference - motion, temporal_bins, spatial_bins = estimate_motion(recording=rec, - peaks=peaks, - peak_locations=peak_locations, - method="decentralized", - direction="y", - bin_duration_s=2.0, - bin_um=5.0, - win_step_um=50.0, - win_sigma_um=150.0) + motion = estimate_motion(recording=rec, + peaks=peaks, + peak_locations=peak_locations, + method="decentralized", + direction="y", + bin_duration_s=2.0, + bin_um=5.0, + win_step_um=50.0, + win_sigma_um=150.0) # Step 3: motion interpolation # this step is lazy rec_corrected = interpolate_motion(recording=rec, motion=motion, - temporal_bins=temporal_bins, - spatial_bins=spatial_bins, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=30.) 
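Put end to end, the updated example reads as the following sketch. It is illustrative rather than a drop-in recipe: `rec` is assumed to be an already-preprocessed recording, the job kwargs are typical values, and the detection/localization methods shown are one common choice among several; the estimation and interpolation parameters simply mirror the example above.

    from spikeinterface.sortingcomponents.peak_detection import detect_peaks
    from spikeinterface.sortingcomponents.peak_localization import localize_peaks
    from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
    from spikeinterface.sortingcomponents.motion_interpolation import interpolate_motion

    job_kwargs = dict(n_jobs=-1, chunk_duration="1s", progress_bar=True)

    # Step 1: activity profile (peaks + peak locations) on the drifting recording
    peaks = detect_peaks(rec, method="locally_exclusive", **job_kwargs)
    peak_locations = localize_peaks(rec, peaks, method="monopolar_triangulation", **job_kwargs)

    # Step 2: motion inference now returns a single Motion object
    # instead of the old (motion_array, temporal_bins, spatial_bins) tuple
    motion = estimate_motion(recording=rec,
                             peaks=peaks,
                             peak_locations=peak_locations,
                             method="decentralized",
                             direction="y",
                             bin_duration_s=2.0,
                             bin_um=5.0,
                             win_step_um=50.0,
                             win_sigma_um=150.0)

    # Step 3: lazy interpolation, driven by the Motion object alone
    rec_corrected = interpolate_motion(recording=rec,
                                       motion=motion,
                                       border_mode="remove_channels",
                                       spatial_interpolation_method="kriging",
                                       sigma_um=30.0)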
@@ -220,8 +218,6 @@ different preprocessing chains: one for motion correction and one for spike sort rec_corrected2 = interpolate_motion( recording=rec2, motion=motion_info['motion'], - temporal_bins=motion_info['temporal_bins'], - spatial_bins=motion_info['spatial_bins'], **motion_info['parameters']['interpolate_motion_kwargs']) sorting = run_sorter(sorter_name="mountainsort5", recording=rec_corrected2) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 935b574565..203eafbb6e 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -28,13 +28,10 @@ def correct_motion_on_peaks(peaks, peak_locations, motion, recording): Motion-corrected peak locations """ corrected_peak_locations = peak_locations.copy() - assert recording is not None or sampling_frequency is not None, "recording or sampling_frequency must be provided" - if recording is not None: - times_s = recording.sample_index_to_time(peaks["sample_index"]) - else: - times_s = peaks["sample_index"] / sampling_frequency + for segment_index in range(motion.num_segments): + times_s = recording.sample_index_to_time(peaks["sample_index"], segment_index=segment_index) i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) spike_times = times_s[i0:i1] @@ -368,7 +365,7 @@ def __init__( if interpolation_time_bin_centers_s is None: # in this case, interpolation_time_bin_size_s is set. s_end = parent_segment.get_num_samples() - t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) + t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end]), segment_index=segment_index) halfbin = interpolation_time_bin_size_s / 2.0 segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) else: @@ -441,11 +438,12 @@ def get_traces(self, start_frame, end_frame, channel_indices): times, self.channel_locations, self.motion, + segment_index=self.segment_index, channel_inds=self.channel_inds, spatial_interpolation_method=self.spatial_interpolation_method, spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s, - segment_index=self.segment_index, + ) if channel_indices is not None: diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 2f1f9d4adf..7811415614 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -10,11 +10,11 @@ class MotionWidget(BaseWidget): Parameters ---------- - motion: Motion + motion : Motion The motion object - segment_index: None | int + segment_index : None | int If Motion is multi segment, the must be not None - mode: "auto" | "line" | "map" + mode : "auto" | "line" | "map" How to plot map or lines. "auto" make it automatic if the number of depth is too high.
""" @@ -35,7 +35,7 @@ def __init__( if len(motion.displacement) == 1: segment_index = 0 else: - raise ValueError("plot motion : teh Motion object is multi segment you must provide segmentindex=XX") + raise ValueError("plot motion : the Motion object is multi segment you must provide segment_index=XX") plot_data = dict( motion=motion, @@ -106,7 +106,7 @@ class MotionInfoWidget(BaseWidget): ---------- motion_info : dict The motion info return by correct_motion() or load back with load_motion_info() - segment_index: int, default: None + segment_index : int, default: None The segment index to display. recording : RecordingExtractor, default: None The recording extractor object (only used to get "real" times) @@ -166,7 +166,6 @@ def __init__( amplitude_cmap=amplitude_cmap, amplitude_clim=amplitude_clim, amplitude_alpha=amplitude_alpha, - segment_index=segment_index, recording=recording, **motion_info, ) From eb21af56ed496a9d183aaeb4bf6d5e7a4609a282 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jun 2024 08:30:44 +0000 Subject: [PATCH 128/136] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/motion_interpolation.py | 2 - src/spikeinterface/widgets/motion.py | 74 +++++++++---------- 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 203eafbb6e..9006ebdcf0 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -28,7 +28,6 @@ def correct_motion_on_peaks(peaks, peak_locations, motion, recording): Motion-corrected peak locations """ corrected_peak_locations = peak_locations.copy() - for segment_index in range(motion.num_segments): times_s = recording.sample_index_to_time(peaks["sample_index"], segment_index=segment_index) @@ -443,7 +442,6 @@ def get_traces(self, start_frame, end_frame, channel_indices): spatial_interpolation_method=self.spatial_interpolation_method, spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s, - ) if channel_indices is not None: diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 27831061ef..12b43ce7de 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -97,43 +97,43 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): class MotionInfoWidget(BaseWidget): """ - Plot motion information from the motion_info dict returned by correct_motion(). - This plot: - * the motion iself - * the peak depth vs time before correction - * the peak depth vs time after correction - - Parameters - ---------- - motion_info : dict - The motion info return by correct_motion() or load back with load_motion_info() -<<<<<<< HEAD - segment_index : int, default: None - The segment index to display. -======= - segment_index: - ->>>>>>> 257950d5859521730ed1da746b6fd32b7b6335bb - recording : RecordingExtractor, default: None - The recording extractor object (only used to get "real" times) - segment_index : int, default: 0 - The segment index to display. 
- sampling_frequency : float, default: None - The sampling frequency (needed if recording is None) - depth_lim : tuple or None, default: None - The min and max depth to display, if None (min and max of the recording) - motion_lim : tuple or None, default: None - The min and max motion to display, if None (min and max of the motion) - color_amplitude : bool, default: False - If True, the color of the scatter points is the amplitude of the peaks - scatter_decimate : int, default: None - If > 1, the scatter points are decimated - amplitude_cmap : str, default: "inferno" - The colormap to use for the amplitude - amplitude_clim : tuple or None, default: None - The min and max amplitude to display, if None (min and max of the amplitudes) - amplitude_alpha : float, default: 1 - The alpha of the scatter points + Plot motion information from the motion_info dict returned by correct_motion(). + This plot: + * the motion itself + * the peak depth vs time before correction + * the peak depth vs time after correction + + Parameters + ---------- + motion_info : dict + The motion info return by correct_motion() or load back with load_motion_info() + <<<<<<< HEAD + segment_index : int, default: None + The segment index to display. + ======= + segment_index: + + >>>>>>> 257950d5859521730ed1da746b6fd32b7b6335bb + recording : RecordingExtractor, default: None + The recording extractor object (only used to get "real" times) + segment_index : int, default: 0 + The segment index to display. + sampling_frequency : float, default: None + The sampling frequency (needed if recording is None) + depth_lim : tuple or None, default: None + The min and max depth to display, if None (min and max of the recording) + motion_lim : tuple or None, default: None + The min and max motion to display, if None (min and max of the motion) + color_amplitude : bool, default: False + If True, the color of the scatter points is the amplitude of the peaks + scatter_decimate : int, default: None + If > 1, the scatter points are decimated + amplitude_cmap : str, default: "inferno" + The colormap to use for the amplitude + amplitude_clim : tuple or None, default: None + The min and max amplitude to display, if None (min and max of the amplitudes) + amplitude_alpha : float, default: 1 + The alpha of the scatter points """ def __init__( From 8bd92e29e641afe5c2f8f6c122d15520c4d79519 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jun 2024 10:45:33 +0200 Subject: [PATCH 129/136] Fix bug in plot_potential_merges --- src/spikeinterface/widgets/potential_merges.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/potential_merges.py b/src/spikeinterface/widgets/potential_merges.py index c1c7d86522..be882209b8 100644 --- a/src/spikeinterface/widgets/potential_merges.py +++ b/src/spikeinterface/widgets/potential_merges.py @@ -237,7 +237,7 @@ def _update_plot(self, change=None): self.w_templates._plot_probe(self.ax_probe, channel_locations, plot_unit_ids) crosscorrelograms_data_plot = self.w_crosscorrelograms.data_plot.copy() crosscorrelograms_data_plot["unit_ids"] = plot_unit_ids - merge_unit_indices = np.flatnonzero(np.isin(self.unique_merge_units, plot_unit_ids)) + merge_unit_indices = np.flatnonzero(np.isin(self.data_plot["unique_merge_units"], plot_unit_ids)) updated_correlograms = crosscorrelograms_data_plot["correlograms"] updated_correlograms = updated_correlograms[merge_unit_indices][:, merge_unit_indices] crosscorrelograms_data_plot["correlograms"] = updated_correlograms From
aab214893cecd941b74af1af0df838217c047285 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jun 2024 10:47:19 +0200 Subject: [PATCH 130/136] oups --- .../widgets/tests/test_widgets.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 0198e24626..e841a1c93b 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -72,25 +72,25 @@ def setUpClass(cls): ) job_kwargs = dict(n_jobs=-1) - # # create dense - # cls.sorting_analyzer_dense = create_sorting_analyzer(cls.sorting, cls.recording, format="memory", sparse=False) - # cls.sorting_analyzer_dense.compute("random_spikes") - # cls.sorting_analyzer_dense.compute(extensions_to_compute, **job_kwargs) - - # sw.set_default_plotter_backend("matplotlib") - - # # make sparse waveforms - # cls.sparsity_radius = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=50) - # cls.sparsity_strict = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=20) - # cls.sparsity_large = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=80) - # cls.sparsity_best = compute_sparsity(cls.sorting_analyzer_dense, method="best_channels", num_channels=5) - - # # create sparse - # cls.sorting_analyzer_sparse = create_sorting_analyzer( - # cls.sorting, cls.recording, format="memory", sparsity=cls.sparsity_radius - # ) - # cls.sorting_analyzer_sparse.compute("random_spikes") - # cls.sorting_analyzer_sparse.compute(extensions_to_compute, **job_kwargs) + # create dense + cls.sorting_analyzer_dense = create_sorting_analyzer(cls.sorting, cls.recording, format="memory", sparse=False) + cls.sorting_analyzer_dense.compute("random_spikes") + cls.sorting_analyzer_dense.compute(extensions_to_compute, **job_kwargs) + + sw.set_default_plotter_backend("matplotlib") + + # make sparse waveforms + cls.sparsity_radius = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=50) + cls.sparsity_strict = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=20) + cls.sparsity_large = compute_sparsity(cls.sorting_analyzer_dense, method="radius", radius_um=80) + cls.sparsity_best = compute_sparsity(cls.sorting_analyzer_dense, method="best_channels", num_channels=5) + + # create sparse + cls.sorting_analyzer_sparse = create_sorting_analyzer( + cls.sorting, cls.recording, format="memory", sparsity=cls.sparsity_radius + ) + cls.sorting_analyzer_sparse.compute("random_spikes") + cls.sorting_analyzer_sparse.compute(extensions_to_compute, **job_kwargs) cls.skip_backends = ["ipywidgets", "ephyviewer", "spikeinterface_gui"] # cls.skip_backends = ["ipywidgets", "ephyviewer", "sortingview"] @@ -107,7 +107,7 @@ def setUpClass(cls): "spikeinterface_gui": {}, } - # cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) + cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) from spikeinterface.sortingcomponents.peak_detection import detect_peaks From 82c7ee51115086eff01cc2ed63abefa4243c30c6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jun 2024 11:13:43 +0200 Subject: [PATCH 131/136] oups --- src/spikeinterface/widgets/motion.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 27831061ef..5ef3c2d4af 100644 --- a/src/spikeinterface/widgets/motion.py +++ 
b/src/spikeinterface/widgets/motion.py @@ -44,6 +44,7 @@ def __init__( motion=motion, segment_index=segment_index, mode=mode, + motion_lim=motion_lim, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -107,13 +108,8 @@ class MotionInfoWidget(BaseWidget): ---------- motion_info : dict The motion info return by correct_motion() or load back with load_motion_info() -<<<<<<< HEAD segment_index : int, default: None The segment index to display. -======= - segment_index: - ->>>>>>> 257950d5859521730ed1da746b6fd32b7b6335bb recording : RecordingExtractor, default: None The recording extractor object (only used to get "real" times) segment_index : int, default: 0 From 1c6dcf403914b2cfabd3d6174a40b049b7231f97 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jun 2024 11:52:52 +0200 Subject: [PATCH 132/136] oups --- src/spikeinterface/sortingcomponents/motion_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 9006ebdcf0..32bb7634e9 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -364,7 +364,7 @@ def __init__( if interpolation_time_bin_centers_s is None: # in this case, interpolation_time_bin_size_s is set. s_end = parent_segment.get_num_samples() - t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end]), segment_index=segment_index) + t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) halfbin = interpolation_time_bin_size_s / 2.0 segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) else: From 2be582b155fa820de809590e5343a4bf2de695ef Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jun 2024 12:41:14 +0200 Subject: [PATCH 133/136] Clean up docs --- src/spikeinterface/curation/auto_merge.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index c652089a39..0797947f87 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -191,7 +191,6 @@ def get_potential_auto_merge( correlogram_diff = compute_correlogram_diff( sorting, correlograms_smoothed, - bins, win_sizes, pair_mask=pair_mask, ) @@ -249,28 +248,28 @@ def get_potential_auto_merge( return potential_merges -def compute_correlogram_diff(sorting, correlograms_smoothed, bins, win_sizes, pair_mask=None): +def compute_correlogram_diff(sorting, correlograms_smoothed, win_sizes, pair_mask=None): """ Original author: Aurelien Wyngaard (lussac) Parameters ---------- sorting : BaseSorting - The sorting object + The sorting object. correlograms_smoothed : array 3d The 3d array containing all cross and auto correlograms - (smoothed by a convolution with a gaussian curve) - bins : array - Bins of the correlograms - win_sized: - TODO + (smoothed by a convolution with a gaussian curve). + win_sizes : np.array[int] + Window size for each unit correlogram. pair_mask : None or boolean array A bool matrix of size (num_units, num_units) to select which pair to compute. Returns ------- - corr_diff + corr_diff : 2D array + The difference between the cross-correlogram and the auto-correlogram + for each pair of units. 
""" # bin_ms = bins[1] - bins[0] @@ -367,7 +366,7 @@ def get_unit_adaptive_window(auto_corr: np.ndarray, threshold: float): Returns ------- - unit_window (int): + unit_window : int Index at which the adaptive window has been calculated. """ import scipy.signal From 5ef22d6894552a8008b5cae9547b65904f99ed47 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jun 2024 12:41:50 +0200 Subject: [PATCH 134/136] Remove comment --- src/spikeinterface/curation/auto_merge.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 0797947f87..8ab4a07dd6 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -271,8 +271,6 @@ def compute_correlogram_diff(sorting, correlograms_smoothed, win_sizes, pair_mas The difference between the cross-correlogram and the auto-correlogram for each pair of units. """ - # bin_ms = bins[1] - bins[0] - unit_ids = sorting.unit_ids n = len(unit_ids) From e139e75d6d9a5bcbde9ce133db84a83ab7724a9e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 19 Jun 2024 15:36:09 -0600 Subject: [PATCH 135/136] Cell explorer deprecations (#3046) * cell explorer deprecations --- .../cellexplorersortingextractor.py | 26 ++----------------- 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 3436313b4d..9b77965c43 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -2,8 +2,7 @@ import numpy as np from pathlib import Path -import warnings -import datetime + from ..core import BaseSorting, BaseSortingSegment from ..core.core_tools import define_function_from_class @@ -36,36 +35,15 @@ class CellExplorerSortingExtractor(BaseSorting): def __init__( self, - file_path: str | Path | None = None, + file_path: str | Path, sampling_frequency: float | None = None, session_info_file_path: str | Path | None = None, - spikes_matfile_path: str | Path | None = None, ): try: from pymatreader import read_mat except ImportError: raise ImportError(self.installation_mesg) - assert ( - file_path is not None or spikes_matfile_path is not None - ), "Either file_path or spikes_matfile_path must be provided!" - - if spikes_matfile_path is not None: - # Raise an error if the warning period has expired - deprecation_issued = datetime.datetime(2023, 4, 1) - deprecation_deadline = deprecation_issued + datetime.timedelta(days=180) - if datetime.datetime.now() > deprecation_deadline: - raise ValueError("The spikes_matfile_path argument is no longer supported in. Use file_path instead.") - - # Otherwise, issue a DeprecationWarning - else: - warnings.warn( - "The spikes_matfile_path argument is deprecated and will be removed in six months. 
" - "Use file_path instead.", - DeprecationWarning, - ) - file_path = spikes_matfile_path if file_path is None else file_path - self.spikes_cellinfo_path = Path(file_path) self.session_path = self.spikes_cellinfo_path.parent self.session_id = self.spikes_cellinfo_path.stem.split(".")[0] From 328bba007391c47a5f08a1f0897119b5ea09bc05 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 20 Jun 2024 08:56:54 -0600 Subject: [PATCH 136/136] Fix intan kwargs (#3054) * add "ignore_integrity_checks" to intan kwargs --- src/spikeinterface/extractors/neoextractors/intan.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 7b3816a04d..c37ff47807 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -53,6 +53,8 @@ def __init__( ) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()))) + if "ignore_integrity_checks" in neo_kwargs: + self._kwargs["ignore_integrity_checks"] = neo_kwargs["ignore_integrity_checks"] @classmethod def map_to_neo_kwargs(cls, file_path, ignore_integrity_checks: bool = False):