diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py
index b4c07e77c9..2c7e75668f 100644
--- a/src/spikeinterface/core/recording_tools.py
+++ b/src/spikeinterface/core/recording_tools.py
@@ -929,7 +929,9 @@ def get_rec_attributes(recording):
     return rec_attributes
 
 
-def do_recording_attributes_match(recording1, recording2_attributes) -> bool:
+def do_recording_attributes_match(
+    recording1: "BaseRecording", recording2_attributes: dict, check_dtype: bool = True
+) -> tuple[bool, str]:
     """
     Check if two recordings have the same attributes
 
@@ -939,22 +941,43 @@ def do_recording_attributes_match(recording1, recording2_attributes) -> bool:
         The first recording object
     recording2_attributes : dict
         The recording attributes to test against
+    check_dtype : bool, default: True
+        If True, check if the recordings have the same dtype
 
     Returns
     -------
     bool
         True if the recordings have the same attributes
+    str
+        A string with the exception message with the attributes that do not match
     """
     recording1_attributes = get_rec_attributes(recording1)
     recording2_attributes = deepcopy(recording2_attributes)
     recording1_attributes.pop("properties")
     recording2_attributes.pop("properties")
 
-    return (
-        np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"])
-        and recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"]
-        and recording1_attributes["num_channels"] == recording2_attributes["num_channels"]
-        and recording1_attributes["num_samples"] == recording2_attributes["num_samples"]
-        and recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"]
-        and recording1_attributes["dtype"] == recording2_attributes["dtype"]
-    )
+    attributes_match = True
+    non_matching_attrs = []
+
+    if not np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]):
+        non_matching_attrs.append("channel_ids")
+    if not recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"]:
+        non_matching_attrs.append("sampling_frequency")
+    if not recording1_attributes["num_channels"] == recording2_attributes["num_channels"]:
+        non_matching_attrs.append("num_channels")
+    if not recording1_attributes["num_samples"] == recording2_attributes["num_samples"]:
+        non_matching_attrs.append("num_samples")
+    # dtype is optional
+    if "dtype" in recording1_attributes and "dtype" in recording2_attributes:
+        if check_dtype:
+            if not recording1_attributes["dtype"] == recording2_attributes["dtype"]:
+                non_matching_attrs.append("dtype")
+
+    if len(non_matching_attrs) > 0:
+        attributes_match = False
+        exception_str = f"Recordings do not match in the following attributes: {non_matching_attrs}"
+    else:
+        attributes_match = True
+        exception_str = ""
+
+    return attributes_match, exception_str
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index fbf0307498..dfdb44bcde 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -616,7 +616,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None):
 
         return sorting_analyzer
 
-    def set_temporary_recording(self, recording: BaseRecording):
+    def set_temporary_recording(self, recording: BaseRecording, check_dtype: bool = True):
         """
         Sets a temporary recording object. This function can be useful to temporarily set
         a "cached" recording object that is not saved in the SortingAnalyzer object to speed up
@@ -628,12 +628,17 @@ def set_temporary_recording(self, recording: BaseRecording):
         ----------
         recording : BaseRecording
             The recording object to set as temporary recording.
+        check_dtype : bool, default: True
+            If True, check that the dtype of the temporary recording is the same as the original recording.
         """
         # check that recording is compatible
-        assert do_recording_attributes_match(recording, self.rec_attributes), "Recording attributes do not match."
-        assert np.array_equal(
-            recording.get_channel_locations(), self.get_channel_locations()
-        ), "Recording channel locations do not match."
+        attributes_match, exception_str = do_recording_attributes_match(
+            recording, self.rec_attributes, check_dtype=check_dtype
+        )
+        if not attributes_match:
+            raise ValueError(exception_str)
+        if not np.array_equal(recording.get_channel_locations(), self.get_channel_locations()):
+            raise ValueError("Recording channel locations do not match.")
         if self._recording is not None:
             warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.")
         self._temporary_recording = recording
@@ -1026,6 +1031,9 @@ def has_temporary_recording(self) -> bool:
     def is_sparse(self) -> bool:
         return self.sparsity is not None
 
+    def is_filtered(self) -> bool:
+        return self.rec_attributes["is_filtered"]
+
     def get_sorting_provenance(self):
         """
         Get the original sorting if possible otherwise return None
diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py
index d83e4d76fc..23a1574f2a 100644
--- a/src/spikeinterface/core/tests/test_recording_tools.py
+++ b/src/spikeinterface/core/tests/test_recording_tools.py
@@ -17,6 +17,8 @@
     get_channel_distances,
     get_noise_levels,
     order_channels_by_depth,
+    do_recording_attributes_match,
+    get_rec_attributes,
 )
 
 
@@ -300,6 +302,35 @@ def test_order_channels_by_depth():
     assert np.array_equal(order_2d[::-1], order_2d_fliped)
 
 
+def test_do_recording_attributes_match():
+    recording = NoiseGeneratorRecording(
+        num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated"
+    )
+    rec_attributes = get_rec_attributes(recording)
+    do_match, _ = do_recording_attributes_match(recording, rec_attributes)
+    assert do_match
+
+    rec_attributes = get_rec_attributes(recording)
+    rec_attributes["sampling_frequency"] = 1.0
+    do_match, exc = do_recording_attributes_match(recording, rec_attributes)
+    assert not do_match
+    assert "sampling_frequency" in exc
+
+    # check dtype options
+    rec_attributes = get_rec_attributes(recording)
+    rec_attributes["dtype"] = "int16"
+    do_match, exc = do_recording_attributes_match(recording, rec_attributes)
+    assert not do_match
+    assert "dtype" in exc
+    do_match, exc = do_recording_attributes_match(recording, rec_attributes, check_dtype=False)
+    assert do_match
+
+    # check missing dtype
+    rec_attributes.pop("dtype")
+    do_match, exc = do_recording_attributes_match(recording, rec_attributes)
+    assert do_match
+
+
 if __name__ == "__main__":
     # Create a temporary folder using the standard library
     import tempfile
diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py
index d89eb7fac0..689073d6bf 100644
--- a/src/spikeinterface/core/tests/test_sortinganalyzer.py
+++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -141,7 +141,7 @@ def test_SortingAnalyzer_tmp_recording(dataset):
     recording_sliced = recording.channel_slice(recording.channel_ids[:-1])
 
     # wrong channels
-    with pytest.raises(AssertionError):
+    with pytest.raises(ValueError):
         sorting_analyzer.set_temporary_recording(recording_sliced)
 
 
diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py
index 7b3c7daab0..06041da231 100644
--- a/src/spikeinterface/exporters/to_phy.py
+++ b/src/spikeinterface/exporters/to_phy.py
@@ -167,7 +167,7 @@ def export_to_phy(
         f.write(f"dtype = '{dtype_str}'\n")
         f.write(f"offset = 0\n")
         f.write(f"sample_rate = {fs}\n")
-        f.write(f"hp_filtered = {sorting_analyzer.recording.is_filtered()}")
+        f.write(f"hp_filtered = {sorting_analyzer.is_filtered()}")
 
     # export spike_times/spike_templates/spike_clusters
     # here spike_labels is a remapping to unit_index