From ca985072f1d87990b4e86b4c1fdc18c59f3c7869 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 18 Aug 2024 09:52:22 +0200 Subject: [PATCH 1/5] Add dtype in load_waveforms and analyzer.is_filtered() --- src/spikeinterface/core/sortinganalyzer.py | 3 +++ .../core/waveforms_extractor_backwards_compatibility.py | 4 ++++ src/spikeinterface/exporters/to_phy.py | 2 +- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index ac142405ab..e427236e15 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1011,6 +1011,9 @@ def has_temporary_recording(self) -> bool: def is_sparse(self) -> bool: return self.sparsity is not None + def is_filtered(self) -> bool: + return self.rec_attributes["filtered"] + def get_sorting_provenance(self): """ Get the original sorting if possible otherwise return None diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index da1f5a71f5..d9514c7fce 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -446,6 +446,10 @@ def _read_old_waveforms_extractor_binary(folder, sorting): else: rec_attributes["probegroup"] = None + if "dtype" not in rec_attributes: + warnings.warn("dtype not found in rec_attributes. Setting to float32") + rec_attributes["dtype"] = "float32" + # recording recording = None if (folder / "recording.json").exists(): diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 7b3c7daab0..06041da231 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -167,7 +167,7 @@ def export_to_phy( f.write(f"dtype = '{dtype_str}'\n") f.write(f"offset = 0\n") f.write(f"sample_rate = {fs}\n") - f.write(f"hp_filtered = {sorting_analyzer.recording.is_filtered()}") + f.write(f"hp_filtered = {sorting_analyzer.is_filtered()}") # export spike_times/spike_templates/spike_clusters # here spike_labels is a remapping to unit_index From df238228d2bbe2650b4e1cb8be2ce61dbe424578 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 18 Aug 2024 10:45:36 +0200 Subject: [PATCH 2/5] oups --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index e427236e15..7a9510e72f 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1012,7 +1012,7 @@ def is_sparse(self) -> bool: return self.sparsity is not None def is_filtered(self) -> bool: - return self.rec_attributes["filtered"] + return self.rec_attributes["is_filtered"] def get_sorting_provenance(self): """ From bc1c704dadb5c771b60ca11847d3bc2169fb4086 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Aug 2024 12:35:09 +0200 Subject: [PATCH 3/5] Improve do_recording_attributes_match impelmentation, errors, and tests --- src/spikeinterface/core/recording_tools.py | 52 +++++++++++++++---- src/spikeinterface/core/sortinganalyzer.py | 19 +++++-- .../core/tests/test_recording_tools.py | 41 +++++++++++++++ .../core/tests/test_sortinganalyzer.py | 11 +++- ...forms_extractor_backwards_compatibility.py | 4 -- 5 files changed, 107 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index b4c07e77c9..cd2f563fba 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -1,6 +1,6 @@ from __future__ import annotations from copy import deepcopy -from typing import Literal +from typing import Literal, Tuple import warnings from pathlib import Path import os @@ -929,7 +929,9 @@ def get_rec_attributes(recording): return rec_attributes -def do_recording_attributes_match(recording1, recording2_attributes) -> bool: +def do_recording_attributes_match( + recording1: "BaseRecording", recording2_attributes: bool, check_is_filtered: bool = True, check_dtype: bool = True +) -> Tuple[bool, str]: """ Check if two recordings have the same attributes @@ -939,22 +941,52 @@ def do_recording_attributes_match(recording1, recording2_attributes) -> bool: The first recording object recording2_attributes : dict The recording attributes to test against + check_is_filtered : bool, default: True + If True, check if the recordings are filtered + check_dtype : bool, default: True + If True, check if the recordings have the same dtype Returns ------- bool True if the recordings have the same attributes + str + A string with the an exception message with attributes that do not match """ recording1_attributes = get_rec_attributes(recording1) recording2_attributes = deepcopy(recording2_attributes) recording1_attributes.pop("properties") recording2_attributes.pop("properties") - return ( - np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]) - and recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"] - and recording1_attributes["num_channels"] == recording2_attributes["num_channels"] - and recording1_attributes["num_samples"] == recording2_attributes["num_samples"] - and recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"] - and recording1_attributes["dtype"] == recording2_attributes["dtype"] - ) + attributes_match = True + non_matching_attrs = [] + + if not np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]): + attributes_match = False + non_matching_attrs.append("channel_ids") + if not recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"]: + attributes_match = False + non_matching_attrs.append("sampling_frequency") + if not recording1_attributes["num_channels"] == recording2_attributes["num_channels"]: + attributes_match = False + non_matching_attrs.append("num_channels") + if not recording1_attributes["num_samples"] == recording2_attributes["num_samples"]: + attributes_match = False + non_matching_attrs.append("num_samples") + if check_is_filtered: + if not recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"]: + attributes_match = False + non_matching_attrs.append("is_filtered") + # dtype is optional + if "dtype" in recording1_attributes and "dtype" in recording2_attributes: + if check_dtype: + if not recording1_attributes["dtype"] == recording2_attributes["dtype"]: + attributes_match = False + non_matching_attrs.append("dtype") + + if len(non_matching_attrs) > 0: + exception_str = f"Recordings do not match in the following attributes: {non_matching_attrs}" + else: + exception_str = "" + + return attributes_match, exception_str diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 7a9510e72f..d034dcb46a 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -608,7 +608,9 @@ def load_from_zarr(cls, folder, recording=None): return sorting_analyzer - def set_temporary_recording(self, recording: BaseRecording): + def set_temporary_recording( + self, recording: BaseRecording, check_is_filtered: bool = True, check_dtype: bool = True + ): """ Sets a temporary recording object. This function can be useful to temporarily set a "cached" recording object that is not saved in the SortingAnalyzer object to speed up @@ -620,12 +622,19 @@ def set_temporary_recording(self, recording: BaseRecording): ---------- recording : BaseRecording The recording object to set as temporary recording. + check_is_filtered : bool, default: True + If True, check that the temporary recording is filtered in the same way as the original recording. + check_dtype : bool, default: True + If True, check that the dtype of the temporary recording is the same as the original recording. """ # check that recording is compatible - assert do_recording_attributes_match(recording, self.rec_attributes), "Recording attributes do not match." - assert np.array_equal( - recording.get_channel_locations(), self.get_channel_locations() - ), "Recording channel locations do not match." + attributes_match, exception_str = do_recording_attributes_match( + recording, self.rec_attributes, check_is_filtered=check_is_filtered, check_dtype=check_dtype + ) + if not attributes_match: + raise ValueError(exception_str) + if not np.array_equal(recording.get_channel_locations(), self.get_channel_locations()): + raise ValueError("Recording channel locations do not match.") if self._recording is not None: warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.") self._temporary_recording = recording diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index d83e4d76fc..8a8fc3a358 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -17,6 +17,8 @@ get_channel_distances, get_noise_levels, order_channels_by_depth, + do_recording_attributes_match, + get_rec_attributes, ) @@ -300,6 +302,45 @@ def test_order_channels_by_depth(): assert np.array_equal(order_2d[::-1], order_2d_fliped) +def test_do_recording_attributes_match(): + recording = NoiseGeneratorRecording( + num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" + ) + rec_attributes = get_rec_attributes(recording) + do_match, _ = do_recording_attributes_match(recording, rec_attributes) + assert do_match + + rec_attributes = get_rec_attributes(recording) + rec_attributes["sampling_frequency"] = 1.0 + do_match, exc = do_recording_attributes_match(recording, rec_attributes) + assert not do_match + assert "sampling_frequency" in exc + + # check is_filtered options + rec_attributes = get_rec_attributes(recording) + rec_attributes["is_filtered"] = not rec_attributes["is_filtered"] + + do_match, exc = do_recording_attributes_match(recording, rec_attributes) + assert not do_match + assert "is_filtered" in exc + do_match, exc = do_recording_attributes_match(recording, rec_attributes, check_is_filtered=False) + assert do_match + + # check dtype options + rec_attributes = get_rec_attributes(recording) + rec_attributes["dtype"] = "int16" + do_match, exc = do_recording_attributes_match(recording, rec_attributes) + assert not do_match + assert "dtype" in exc + do_match, exc = do_recording_attributes_match(recording, rec_attributes, check_dtype=False) + assert do_match + + # check missing dtype + rec_attributes.pop("dtype") + do_match, exc = do_recording_attributes_match(recording, rec_attributes) + assert do_match + + if __name__ == "__main__": # Create a temporary folder using the standard library import tempfile diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index d89eb7fac0..9de725239d 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -141,9 +141,18 @@ def test_SortingAnalyzer_tmp_recording(dataset): recording_sliced = recording.channel_slice(recording.channel_ids[:-1]) # wrong channels - with pytest.raises(AssertionError): + with pytest.raises(ValueError): sorting_analyzer.set_temporary_recording(recording_sliced) + # test with different is_filtered + recording_filt = recording.clone() + recording_filt.annotate(is_filtered=False) + with pytest.raises(ValueError): + sorting_analyzer.set_temporary_recording(recording_filt) + + # thest with additional check_is_filtered + sorting_analyzer.set_temporary_recording(recording_filt, check_is_filtered=False) + def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index d9514c7fce..da1f5a71f5 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -446,10 +446,6 @@ def _read_old_waveforms_extractor_binary(folder, sorting): else: rec_attributes["probegroup"] = None - if "dtype" not in rec_attributes: - warnings.warn("dtype not found in rec_attributes. Setting to float32") - rec_attributes["dtype"] = "float32" - # recording recording = None if (folder / "recording.json").exists(): From d6f3ced15f938943f83741b6867de7e7916a3de8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 28 Aug 2024 13:09:18 +0200 Subject: [PATCH 4/5] Update src/spikeinterface/core/recording_tools.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/recording_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index cd2f563fba..5833f81ff8 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -951,7 +951,7 @@ def do_recording_attributes_match( bool True if the recordings have the same attributes str - A string with the an exception message with attributes that do not match + A string with the exception message with the attributes that do not match """ recording1_attributes = get_rec_attributes(recording1) recording2_attributes = deepcopy(recording2_attributes) From 945fc15980122c7d4cd3f9b5fc67c4b40816c948 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 28 Aug 2024 15:08:15 +0200 Subject: [PATCH 5/5] Suggestions from code review --- src/spikeinterface/core/recording_tools.py | 19 +++++-------------- src/spikeinterface/core/sortinganalyzer.py | 8 ++------ .../core/tests/test_recording_tools.py | 10 ---------- .../core/tests/test_sortinganalyzer.py | 9 --------- 4 files changed, 7 insertions(+), 39 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index cd2f563fba..7cbc236eda 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -1,6 +1,6 @@ from __future__ import annotations from copy import deepcopy -from typing import Literal, Tuple +from typing import Literal import warnings from pathlib import Path import os @@ -930,8 +930,8 @@ def get_rec_attributes(recording): def do_recording_attributes_match( - recording1: "BaseRecording", recording2_attributes: bool, check_is_filtered: bool = True, check_dtype: bool = True -) -> Tuple[bool, str]: + recording1: "BaseRecording", recording2_attributes: bool, check_dtype: bool = True +) -> tuple[bool, str]: """ Check if two recordings have the same attributes @@ -941,8 +941,6 @@ def do_recording_attributes_match( The first recording object recording2_attributes : dict The recording attributes to test against - check_is_filtered : bool, default: True - If True, check if the recordings are filtered check_dtype : bool, default: True If True, check if the recordings have the same dtype @@ -962,31 +960,24 @@ def do_recording_attributes_match( non_matching_attrs = [] if not np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]): - attributes_match = False non_matching_attrs.append("channel_ids") if not recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"]: - attributes_match = False non_matching_attrs.append("sampling_frequency") if not recording1_attributes["num_channels"] == recording2_attributes["num_channels"]: - attributes_match = False non_matching_attrs.append("num_channels") if not recording1_attributes["num_samples"] == recording2_attributes["num_samples"]: - attributes_match = False non_matching_attrs.append("num_samples") - if check_is_filtered: - if not recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"]: - attributes_match = False - non_matching_attrs.append("is_filtered") # dtype is optional if "dtype" in recording1_attributes and "dtype" in recording2_attributes: if check_dtype: if not recording1_attributes["dtype"] == recording2_attributes["dtype"]: - attributes_match = False non_matching_attrs.append("dtype") if len(non_matching_attrs) > 0: + attributes_match = False exception_str = f"Recordings do not match in the following attributes: {non_matching_attrs}" else: + attributes_match = True exception_str = "" return attributes_match, exception_str diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index d034dcb46a..7687017db6 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -608,9 +608,7 @@ def load_from_zarr(cls, folder, recording=None): return sorting_analyzer - def set_temporary_recording( - self, recording: BaseRecording, check_is_filtered: bool = True, check_dtype: bool = True - ): + def set_temporary_recording(self, recording: BaseRecording, check_dtype: bool = True): """ Sets a temporary recording object. This function can be useful to temporarily set a "cached" recording object that is not saved in the SortingAnalyzer object to speed up @@ -622,14 +620,12 @@ def set_temporary_recording( ---------- recording : BaseRecording The recording object to set as temporary recording. - check_is_filtered : bool, default: True - If True, check that the temporary recording is filtered in the same way as the original recording. check_dtype : bool, default: True If True, check that the dtype of the temporary recording is the same as the original recording. """ # check that recording is compatible attributes_match, exception_str = do_recording_attributes_match( - recording, self.rec_attributes, check_is_filtered=check_is_filtered, check_dtype=check_dtype + recording, self.rec_attributes, check_dtype=check_dtype ) if not attributes_match: raise ValueError(exception_str) diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 8a8fc3a358..23a1574f2a 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -316,16 +316,6 @@ def test_do_recording_attributes_match(): assert not do_match assert "sampling_frequency" in exc - # check is_filtered options - rec_attributes = get_rec_attributes(recording) - rec_attributes["is_filtered"] = not rec_attributes["is_filtered"] - - do_match, exc = do_recording_attributes_match(recording, rec_attributes) - assert not do_match - assert "is_filtered" in exc - do_match, exc = do_recording_attributes_match(recording, rec_attributes, check_is_filtered=False) - assert do_match - # check dtype options rec_attributes = get_rec_attributes(recording) rec_attributes["dtype"] = "int16" diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 9de725239d..689073d6bf 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -144,15 +144,6 @@ def test_SortingAnalyzer_tmp_recording(dataset): with pytest.raises(ValueError): sorting_analyzer.set_temporary_recording(recording_sliced) - # test with different is_filtered - recording_filt = recording.clone() - recording_filt.annotate(is_filtered=False) - with pytest.raises(ValueError): - sorting_analyzer.set_temporary_recording(recording_filt) - - # thest with additional check_is_filtered - sorting_analyzer.set_temporary_recording(recording_filt, check_is_filtered=False) - def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):