Skip to content

Commit

Permalink
Merge pull request SpikeInterface#3312 from alejoe91/fix-waveforms-backcompatibility
Browse files Browse the repository at this point in the history

Add dtype in load_waveforms and analyzer.is_filtered()
  • Loading branch information
samuelgarcia authored Aug 29, 2024
2 parents 9238023 + 5461e43 commit 261d671
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 16 deletions.
41 changes: 32 additions & 9 deletions src/spikeinterface/core/recording_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,9 @@ def get_rec_attributes(recording):
return rec_attributes


def do_recording_attributes_match(
    recording1: "BaseRecording", recording2_attributes: dict, check_dtype: bool = True
) -> tuple[bool, str]:
    """
    Check if two recordings have the same attributes

    Parameters
    ----------
    recording1 : BaseRecording
        The first recording object
    recording2_attributes : dict
        The recording attributes to test against
    check_dtype : bool, default: True
        If True, check if the recordings have the same dtype

    Returns
    -------
    attributes_match : bool
        True if the recordings have the same attributes
    exception_str : str
        A string listing the attributes that do not match (empty when they all match)
    """
    recording1_attributes = get_rec_attributes(recording1)
    recording2_attributes = deepcopy(recording2_attributes)
    # "properties" are not part of the comparison
    recording1_attributes.pop("properties")
    recording2_attributes.pop("properties")

    non_matching_attrs = []

    if not np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]):
        non_matching_attrs.append("channel_ids")
    if recording1_attributes["sampling_frequency"] != recording2_attributes["sampling_frequency"]:
        non_matching_attrs.append("sampling_frequency")
    if recording1_attributes["num_channels"] != recording2_attributes["num_channels"]:
        non_matching_attrs.append("num_channels")
    if recording1_attributes["num_samples"] != recording2_attributes["num_samples"]:
        non_matching_attrs.append("num_samples")
    # dtype is optional: attributes saved by older versions may not contain it,
    # so it is only compared when present on both sides (and check_dtype is True)
    if check_dtype and "dtype" in recording1_attributes and "dtype" in recording2_attributes:
        if recording1_attributes["dtype"] != recording2_attributes["dtype"]:
            non_matching_attrs.append("dtype")

    if len(non_matching_attrs) > 0:
        attributes_match = False
        exception_str = f"Recordings do not match in the following attributes: {non_matching_attrs}"
    else:
        attributes_match = True
        exception_str = ""

    return attributes_match, exception_str
18 changes: 13 additions & 5 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None):

return sorting_analyzer

def set_temporary_recording(self, recording: BaseRecording, check_dtype: bool = True):
    """
    Sets a temporary recording object. This function can be useful to temporarily set
    a "cached" recording object that is not saved in the SortingAnalyzer object to speed up
    computations.

    Parameters
    ----------
    recording : BaseRecording
        The recording object to set as temporary recording.
    check_dtype : bool, default: True
        If True, check that the dtype of the temporary recording is the same as the original recording.

    Raises
    ------
    ValueError
        If the recording attributes or the channel locations do not match the analyzer's.
    """
    # check that recording is compatible with the analyzer's stored attributes
    attributes_match, exception_str = do_recording_attributes_match(
        recording, self.rec_attributes, check_dtype=check_dtype
    )
    if not attributes_match:
        raise ValueError(exception_str)
    if not np.array_equal(recording.get_channel_locations(), self.get_channel_locations()):
        raise ValueError("Recording channel locations do not match.")
    # a permanently attached recording may already exist; the temporary one takes precedence
    if self._recording is not None:
        warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.")
    self._temporary_recording = recording
Expand Down Expand Up @@ -1026,6 +1031,9 @@ def has_temporary_recording(self) -> bool:
def is_sparse(self) -> bool:
    """Return True if a sparsity is set on this analyzer (``self.sparsity`` is not None)."""
    return self.sparsity is not None

def is_filtered(self) -> bool:
    """Return the "is_filtered" flag stored in the recording attributes."""
    return self.rec_attributes["is_filtered"]

def get_sorting_provenance(self):
"""
Get the original sorting if possible otherwise return None
Expand Down
31 changes: 31 additions & 0 deletions src/spikeinterface/core/tests/test_recording_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
get_channel_distances,
get_noise_levels,
order_channels_by_depth,
do_recording_attributes_match,
get_rec_attributes,
)


Expand Down Expand Up @@ -300,6 +302,35 @@ def test_order_channels_by_depth():
assert np.array_equal(order_2d[::-1], order_2d_fliped)


def test_do_recording_attributes_match():
    """Check matching/non-matching recording attributes and the dtype options."""
    recording = NoiseGeneratorRecording(
        num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated"
    )

    # identical attributes match
    attrs = get_rec_attributes(recording)
    matches, _ = do_recording_attributes_match(recording, attrs)
    assert matches

    # a modified sampling frequency is detected and reported
    attrs = get_rec_attributes(recording)
    attrs["sampling_frequency"] = 1.0
    matches, message = do_recording_attributes_match(recording, attrs)
    assert not matches
    assert "sampling_frequency" in message

    # a dtype mismatch is reported only when check_dtype is True
    attrs = get_rec_attributes(recording)
    attrs["dtype"] = "int16"
    matches, message = do_recording_attributes_match(recording, attrs)
    assert not matches
    assert "dtype" in message
    matches, _ = do_recording_attributes_match(recording, attrs, check_dtype=False)
    assert matches

    # a missing dtype is tolerated
    del attrs["dtype"]
    matches, _ = do_recording_attributes_match(recording, attrs)
    assert matches


if __name__ == "__main__":
# Create a temporary folder using the standard library
import tempfile
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_SortingAnalyzer_tmp_recording(dataset):
recording_sliced = recording.channel_slice(recording.channel_ids[:-1])

# wrong channels
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
sorting_analyzer.set_temporary_recording(recording_sliced)


Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/exporters/to_phy.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def export_to_phy(
f.write(f"dtype = '{dtype_str}'\n")
f.write(f"offset = 0\n")
f.write(f"sample_rate = {fs}\n")
f.write(f"hp_filtered = {sorting_analyzer.recording.is_filtered()}")
f.write(f"hp_filtered = {sorting_analyzer.is_filtered()}")

# export spike_times/spike_templates/spike_clusters
# here spike_labels is a remapping to unit_index
Expand Down

0 comments on commit 261d671

Please sign in to comment.