From 5e77fdafecedcf32b28b4a2e3d4d46a988c33e11 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 8 Dec 2023 09:15:32 +0100 Subject: [PATCH 1/8] fix testing --- .../extractors/nwbextractors.py | 100 +++++++++++------- .../extractors/tests/test_nwbextractors.py | 34 ++++-- 2 files changed, 85 insertions(+), 49 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index d5f5ac60a5..edc2be51fb 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path from typing import Union, List, Optional, Literal, Dict, BinaryIO +import warnings import numpy as np @@ -551,8 +552,6 @@ def __init__( file_path=file_path, stream_mode=stream_mode, cache=cache, stream_cache_path=stream_cache_path ) - units_ids = list(self._nwbfile.units.id[:]) - timestamps = None if sampling_frequency is None: # defines the electrical series from where the sorting came from @@ -571,39 +570,48 @@ def __init__( "Couldn't load sampling frequency. Please provide it with the " "'sampling_frequency' argument" ) + units_table = self._nwbfile.units + + name_to_column_data = {c.name: c for c in units_table.columns} + spike_times_data = name_to_column_data.pop("spike_times").data + spike_times_index_data = name_to_column_data.pop("spike_times_index").data + + units_ids = name_to_column_data.pop("unit_name", None) + if units_ids is None: + units_ids = units_table["id"].data + BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=units_ids) + sorting_segment = NwbSortingSegment( - nwbfile=self._nwbfile, sampling_frequency=sampling_frequency, timestamps=timestamps + spike_times_data=spike_times_data, + spike_times_index_data=spike_times_index_data, + sampling_frequency=sampling_frequency, ) self.add_sorting_segment(sorting_segment) - # Add properties: - properties = dict() - import warnings - - for column in list(self._nwbfile.units.colnames): - if column == "spike_times": - continue - - # Note that this has a different behavior than self._nwbfile.units[column].data - property_values = self._nwbfile.units[column][:] + # Add properties + properties_to_add = [name for name in name_to_column_data if "index" not in name] + for property_name in properties_to_add: + data = name_to_column_data.pop(property_name).data + data_index = name_to_column_data.get(f"{property_name}_index", None) + not_ragged_array = data_index is None + if not_ragged_array: + values = data[:] + else: + data_index = data_index.data + index_spacing = np.diff(data_index, prepend=0) + all_index_spacing_are_the_same = np.unique(index_spacing).size == 1 + if all_index_spacing_are_the_same: + start_indices = [0] + list(data_index[:-1]) + end_indices = list(data_index) + values = [data[start_index:end_index] for start_index, end_index in zip(start_indices, end_indices)] + self.set_property(property_name, np.asarray(values)) - # Making this explicit because I am not sure this is the best test - is_raggged_array = isinstance(property_values, list) - if is_raggged_array: - all_values_have_equal_shape = np.all([p.shape == property_values[0].shape for p in property_values]) - if all_values_have_equal_shape: - properties[column] = property_values else: - warnings.warn(f"Skipping {column} because of unequal shapes across units") - - continue # To next property - - # The rest of the properties are added as they come - properties[column] = property_values + 
warnings.warn(f"Skipping {property_name} because of unequal shapes across units") + continue - for prop_name, values in properties.items(): - self.set_property(prop_name, np.array(values)) + self.set_property(property_name, np.asarray(values)) if stream_mode is None and file_path is not None: file_path = str(Path(file_path).resolve()) @@ -620,11 +628,11 @@ def __init__( class NwbSortingSegment(BaseSortingSegment): - def __init__(self, nwbfile, sampling_frequency, timestamps): + def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency): BaseSortingSegment.__init__(self) - self._nwbfile = nwbfile + self.spike_times_data = spike_times_data + self.spike_times_index_data = spike_times_index_data self._sampling_frequency = sampling_frequency - self._timestamps = timestamps def get_unit_spike_train( self, @@ -632,18 +640,30 @@ def get_unit_spike_train( start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, ) -> np.ndarray: - # must be implemented in subclass - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = np.inf - spike_times = self._nwbfile.units["spike_times"][list(self._nwbfile.units.id[:]).index(unit_id)][:] + # Extract the spike times for the unit + unit_index = self.parent_extractor.id_to_index(unit_id) + if unit_index == 0: + start_index = 0 + else: + start_index = self.spike_times_index_data[unit_index - 1] + end_index = self.spike_times_index_data[unit_index] + spike_times = self.spike_times_data[start_index:end_index] + + # Transform spike times to frames and subset + frames = np.round(spike_times * self._sampling_frequency) - if self._timestamps is not None: - frames = np.searchsorted(spike_times, self.timestamps) + start_index = 0 + if start_frame is not None: + start_index = np.searchsorted(frames, start_frame, side="left") else: - frames = np.round(spike_times * self._sampling_frequency) - return frames[(frames >= start_frame) & (frames < end_frame)].astype("int64", copy=False) + start_index = 0 + + if end_frame is not None: + end_index = np.searchsorted(frames, end_frame, side="left") + else: + end_index = frames.size + + return frames[start_index:end_index].astype("int64", copy=False) read_nwb_recording = define_function_from_class(source_class=NwbRecordingExtractor, name="read_nwb_recording") diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index 400cb9311e..996ad5715a 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -98,8 +98,15 @@ def nwbfile_with_ecephys_content(): rng = np.random.default_rng(0) data = rng.random(size=(num_frames, len(electrode_indices))) rate = 30_000.0 + conversion = 5.0 + a_different_offset = offset + 1.0 electrical_series = ElectricalSeries( - name=electrical_series_name, data=data, electrodes=electrode_region, rate=rate, offset=offset + 1.0 + name=electrical_series_name, + data=data, + electrodes=electrode_region, + rate=rate, + offset=a_different_offset, + conversion=conversion, ) nwbfile.add_acquisition(electrical_series) @@ -144,7 +151,6 @@ def test_nwb_extractor_property_retrieval(path_to_nwbfile, nwbfile_with_ecephys_ electrical_series_name_list = ["ElectricalSeries1", "ElectricalSeries2"] for electrical_series_name in electrical_series_name_list: recording_extractor = NwbRecordingExtractor(path_to_nwbfile, electrical_series_name=electrical_series_name) - nwbfile = nwbfile_with_ecephys_content 
electrical_series = nwbfile.acquisition[electrical_series_name] electrical_series_electrode_indices = electrical_series.electrodes.data[:] @@ -160,7 +166,6 @@ def test_nwb_extractor_offset_from_electrodes_table(path_to_nwbfile, nwbfile_wit """Test that the offset is retrieved from the electrodes table if it is not present in the ElectricalSeries.""" electrical_series_name = "ElectricalSeries1" recording_extractor = NwbRecordingExtractor(path_to_nwbfile, electrical_series_name=electrical_series_name) - nwbfile = nwbfile_with_ecephys_content electrical_series = nwbfile.acquisition[electrical_series_name] electrical_series_electrode_indices = electrical_series.electrodes.data[:] @@ -176,7 +181,6 @@ def test_nwb_extractor_offset_from_series(path_to_nwbfile, nwbfile_with_ecephys_ """Test that the offset is retrieved from the ElectricalSeries if it is present.""" electrical_series_name = "ElectricalSeries2" recording_extractor = NwbRecordingExtractor(path_to_nwbfile, electrical_series_name=electrical_series_name) - nwbfile = nwbfile_with_ecephys_content electrical_series = nwbfile.acquisition[electrical_series_name] expected_offsets_uV = electrical_series.offset * 1e6 @@ -190,12 +194,14 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path): # Add the spikes nwbfile.add_unit_column(name="unit_name", description="the name of the unit") + nwbfile.add_unit_column(name="a_property", description="a_cool_property") + spike_times1 = np.array([0.0, 1.0, 2.0]) - nwbfile.add_unit(spike_times=spike_times1, unit_name="a") + nwbfile.add_unit(spike_times=spike_times1, unit_name="a", a_property="a_property_value") spike_times2 = np.array([0.0, 1.0, 2.0, 3.0]) - nwbfile.add_unit(spike_times=spike_times2, unit_name="b") + nwbfile.add_unit(spike_times=spike_times2, unit_name="b", a_property="b_property_value") - ragged_array_bad = [[1, 2, 3], [1, 2, 3, 5]] + ragged_array_bad = [[1, 2, 3, 8, 10], [1, 2, 3, 5]] nwbfile.add_unit_column( name="ragged_array_bad", description="an evill array that wants to destroy your test", @@ -203,7 +209,7 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path): index=True, ) - ragged_array_good = [[1, 2], [3, 4]] + ragged_array_good = [[1, 2, 3], [4, 5, 6]] nwbfile.add_unit_column( name="ragged_array_good", description="a good array that wants to help your test be nice to nice arrays", @@ -218,10 +224,20 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path): sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0) - # Check that the bad array was not added + units_ids = sorting_extractor.get_unit_ids() + + np.testing.assert_equal(units_ids, ["a", "b"]) + added_properties = sorting_extractor.get_property_keys() assert "ragged_array_bad" not in added_properties assert "ragged_array_good" in added_properties + assert "a_property" in added_properties + + spike_train1 = sorting_extractor.get_unit_spike_train(unit_id="a", return_times=True) + np.testing.assert_allclose(spike_train1, spike_times1) + + spike_train2 = sorting_extractor.get_unit_spike_train(unit_id="b", return_times=True) + np.testing.assert_allclose(spike_train2, spike_times2) if __name__ == "__main__": From da46881bb9960d78ca922ee487abe67a18e04ce9 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 13 Dec 2023 17:26:37 +0100 Subject: [PATCH 2/8] added starting time --- src/spikeinterface/extractors/nwbextractors.py | 16 +++++++++++----- .../extractors/tests/test_nwbextractors.py | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git 
a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index edc2be51fb..59379a673a 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -536,6 +536,8 @@ def __init__( stream_mode: str | None = None, cache: bool = False, stream_cache_path: str | Path | None = None, + *, + t_start: float | None = None, ): try: from pynwb import NWBHDF5IO, NWBFile @@ -560,15 +562,17 @@ def __init__( # get rate if self.electrical_series.rate is not None: sampling_frequency = self.electrical_series.rate + self.t_start = self.electrical_series.starting_time else: if hasattr(self.electrical_series, "timestamps"): if self.electrical_series.timestamps is not None: timestamps = self.electrical_series.timestamps sampling_frequency = 1 / np.median(np.diff(timestamps[samples_for_rate_estimation])) - - assert sampling_frequency is not None, ( - "Couldn't load sampling frequency. Please provide it with the " "'sampling_frequency' argument" - ) + self.t_start = timestamps[0] + assert ( + sampling_frequency is not None + ), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument" + assert t_start is not None, "Couldn't load t_start. Please provide it with the 't_start' argument" units_table = self._nwbfile.units @@ -586,6 +590,7 @@ def __init__( spike_times_data=spike_times_data, spike_times_index_data=spike_times_index_data, sampling_frequency=sampling_frequency, + t_start=t_start, ) self.add_sorting_segment(sorting_segment) @@ -628,11 +633,12 @@ def __init__( class NwbSortingSegment(BaseSortingSegment): - def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency): + def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency: float, t_start: float): BaseSortingSegment.__init__(self) self.spike_times_data = spike_times_data self.spike_times_index_data = spike_times_index_data self._sampling_frequency = sampling_frequency + self._t_start = t_start def get_unit_spike_train( self, diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index 996ad5715a..ac6b0810c1 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -222,7 +222,7 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path): with NWBHDF5IO(path=file_path, mode="w") as io: io.write(nwbfile) - sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0) + sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0, t_start=0) units_ids = sorting_extractor.get_unit_ids() From 345a9e7298aca9096618017df5e157235233bdec Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 14 Dec 2023 10:40:58 +0100 Subject: [PATCH 3/8] add tests --- .../extractors/nwbextractors.py | 7 +- .../extractors/tests/test_nwb_s3_extractor.py | 259 ------------------ .../extractors/tests/test_nwbextractors.py | 67 ++++- 3 files changed, 70 insertions(+), 263 deletions(-) delete mode 100644 src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 59379a673a..3c09bc7240 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -555,6 +555,7 @@ def __init__( ) timestamps = None + self.t_start = t_start if sampling_frequency is 
None: # defines the electrical series from where the sorting came from # important to know the sampling_frequency @@ -572,7 +573,9 @@ def __init__( assert ( sampling_frequency is not None ), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument" - assert t_start is not None, "Couldn't load t_start. Please provide it with the 't_start' argument" + assert ( + self.t_start is not None + ), "Couldn't load a starting time for the sorting. Please provide it with the 't_start' argument" units_table = self._nwbfile.units @@ -590,7 +593,7 @@ def __init__( spike_times_data=spike_times_data, spike_times_index_data=spike_times_index_data, sampling_frequency=sampling_frequency, - t_start=t_start, + t_start=self.t_start, ) self.add_sorting_segment(sorting_segment) diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py deleted file mode 100644 index 9183c5b728..0000000000 --- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py +++ /dev/null @@ -1,259 +0,0 @@ -from pathlib import Path -import pickle - -import pytest -import numpy as np -import h5py -from spikeinterface.core.testing import check_recordings_equal - -from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal -from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor - - -@pytest.mark.ros3_test -@pytest.mark.streaming_extractors -@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") -def test_recording_s3_nwb_ros3(tmp_path): - file_path = ( - "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" - ) - rec = NwbRecordingExtractor(file_path, stream_mode="ros3") - - start_frame = 0 - end_frame = 300 - num_frames = end_frame - start_frame - - num_seg = rec.get_num_segments() - num_chans = rec.get_num_channels() - dtype = rec.get_dtype() - - for segment_index in range(num_seg): - num_samples = rec.get_num_samples(segment_index=segment_index) - - full_traces = rec.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame) - assert full_traces.shape == (num_frames, num_chans) - assert full_traces.dtype == dtype - - if rec.has_scaled(): - trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) - assert trace_scaled.dtype == "float32" - - tmp_file = tmp_path / "test_ros3_recording.pkl" - with open(tmp_file, "wb") as f: - pickle.dump(rec, f) - - with open(tmp_file, "rb") as f: - reloaded_recording = pickle.load(f) - - check_recordings_equal(rec, reloaded_recording) - - -@pytest.mark.streaming_extractors -@pytest.mark.parametrize("cache", [True, False]) # Test with and without cache -def test_recording_s3_nwb_fsspec(tmp_path, cache): - file_path = ( - "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" - ) - - # Instantiate NwbRecordingExtractor with the cache parameter - rec = NwbRecordingExtractor( - file_path, stream_mode="fsspec", cache=cache, stream_cache_path=tmp_path if cache else None - ) - - start_frame = 0 - end_frame = 300 - num_frames = end_frame - start_frame - - num_seg = rec.get_num_segments() - num_chans = rec.get_num_channels() - dtype = rec.get_dtype() - - for segment_index in range(num_seg): - num_samples = rec.get_num_samples(segment_index=segment_index) - - full_traces = rec.get_traces(segment_index=segment_index, start_frame=start_frame, 
end_frame=end_frame) - assert full_traces.shape == (num_frames, num_chans) - assert full_traces.dtype == dtype - - if rec.has_scaled(): - trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) - assert trace_scaled.dtype == "float32" - - tmp_file = tmp_path / "test_fsspec_recording.pkl" - with open(tmp_file, "wb") as f: - pickle.dump(rec, f) - - with open(tmp_file, "rb") as f: - reloaded_recording = pickle.load(f) - - check_recordings_equal(rec, reloaded_recording) - - -@pytest.mark.streaming_extractors -def test_recording_s3_nwb_remfile(): - file_path = ( - "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" - ) - rec = NwbRecordingExtractor(file_path, stream_mode="remfile") - - start_frame = 0 - end_frame = 300 - num_frames = end_frame - start_frame - - num_seg = rec.get_num_segments() - num_chans = rec.get_num_channels() - dtype = rec.get_dtype() - - for segment_index in range(num_seg): - num_samples = rec.get_num_samples(segment_index=segment_index) - - full_traces = rec.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame) - assert full_traces.shape == (num_frames, num_chans) - assert full_traces.dtype == dtype - - if rec.has_scaled(): - trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) - assert trace_scaled.dtype == "float32" - - -@pytest.mark.streaming_extractors -def test_recording_s3_nwb_remfile_file_like(tmp_path): - import remfile - - file_path = ( - "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" - ) - file = remfile.File(file_path) - rec = NwbRecordingExtractor(file=file) - - start_frame = 0 - end_frame = 300 - num_frames = end_frame - start_frame - - num_seg = rec.get_num_segments() - num_chans = rec.get_num_channels() - dtype = rec.get_dtype() - - for segment_index in range(num_seg): - num_samples = rec.get_num_samples(segment_index=segment_index) - - full_traces = rec.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame) - assert full_traces.shape == (num_frames, num_chans) - assert full_traces.dtype == dtype - - if rec.has_scaled(): - trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) - assert trace_scaled.dtype == "float32" - - # test pickling - with open(tmp_path / "rec.pkl", "wb") as f: - pickle.dump(rec, f) - with open(tmp_path / "rec.pkl", "rb") as f: - rec2 = pickle.load(f) - check_recordings_equal(rec, rec2) - - -@pytest.mark.ros3_test -@pytest.mark.streaming_extractors -@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") -def test_sorting_s3_nwb_ros3(tmp_path): - file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" - # we provide the 'sampling_frequency' because the NWB file does not the electrical series - sort = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3") - - start_frame = 0 - end_frame = 300 - num_frames = end_frame - start_frame - - num_seg = sort.get_num_segments() - num_units = len(sort.unit_ids) - - for segment_index in range(num_seg): - for unit in sort.unit_ids: - spike_train = sort.get_unit_spike_train(unit_id=unit, segment_index=segment_index) - assert len(spike_train) > 0 - assert spike_train.dtype == "int64" - assert np.all(spike_train >= 0) - - tmp_file = tmp_path / "test_ros3_sorting.pkl" - with open(tmp_file, "wb") as f: - 
pickle.dump(sort, f) - - with open(tmp_file, "rb") as f: - reloaded_sorting = pickle.load(f) - - check_sortings_equal(reloaded_sorting, sort) - - -@pytest.mark.streaming_extractors -@pytest.mark.parametrize("cache", [True, False]) # Test with and without cache -def test_sorting_s3_nwb_fsspec(tmp_path, cache): - file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" - # We provide the 'sampling_frequency' because the NWB file does not have the electrical series - sorting = NwbSortingExtractor( - file_path, - sampling_frequency=30000.0, - stream_mode="fsspec", - cache=cache, - stream_cache_path=tmp_path if cache else None, - ) - - num_seg = sorting.get_num_segments() - assert num_seg == 1 - num_units = len(sorting.unit_ids) - assert num_units == 64 - - for segment_index in range(num_seg): - for unit in sorting.unit_ids: - spike_train = sorting.get_unit_spike_train(unit_id=unit, segment_index=segment_index) - assert len(spike_train) > 0 - assert spike_train.dtype == "int64" - assert np.all(spike_train >= 0) - - tmp_file = tmp_path / "test_fsspec_sorting.pkl" - with open(tmp_file, "wb") as f: - pickle.dump(sorting, f) - - with open(tmp_file, "rb") as f: - reloaded_sorting = pickle.load(f) - - check_sortings_equal(reloaded_sorting, sorting) - - -@pytest.mark.streaming_extractors -def test_sorting_s3_nwb_remfile(tmp_path): - file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" - # We provide the 'sampling_frequency' because the NWB file does not have the electrical series - sorting = NwbSortingExtractor( - file_path, - sampling_frequency=30000.0, - stream_mode="remfile", - ) - - num_seg = sorting.get_num_segments() - assert num_seg == 1 - num_units = len(sorting.unit_ids) - assert num_units == 64 - - for segment_index in range(num_seg): - for unit in sorting.unit_ids: - spike_train = sorting.get_unit_spike_train(unit_id=unit, segment_index=segment_index) - assert len(spike_train) > 0 - assert spike_train.dtype == "int64" - assert np.all(spike_train >= 0) - - tmp_file = tmp_path / "test_remfile_sorting.pkl" - with open(tmp_file, "wb") as f: - pickle.dump(sorting, f) - - with open(tmp_file, "rb") as f: - reloaded_sorting = pickle.load(f) - - check_sortings_equal(reloaded_sorting, sorting) - - -if __name__ == "__main__": - test_recording_s3_nwb_ros3() - test_recording_s3_nwb_fsspec() - test_sorting_s3_nwb_ros3() - test_sorting_s3_nwb_fsspec() diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index ac6b0810c1..d518ec8760 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -10,8 +10,7 @@ from pynwb.ecephys import ElectricalSeries from pynwb.testing.mock.file import mock_NWBFile from pynwb.testing.mock.device import mock_Device -from pynwb.testing.mock.ecephys import mock_ElectricalSeries, mock_ElectrodeGroup - +from pynwb.testing.mock.ecephys import mock_ElectricalSeries, mock_ElectrodeGroup, mock_electrodes from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor from spikeinterface.extractors.tests.common_tests import RecordingCommonTestSuite, SortingCommonTestSuite @@ -240,5 +239,69 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path): np.testing.assert_allclose(spike_train2, spike_times2) +def test_sorting_extraction_start_time(tmp_path): + nwbfile = mock_NWBFile() + + # Add the spikes + spike_times1 = 
np.array([0.0, 1.0, 2.0]) + nwbfile.add_unit(spike_times=spike_times1) + spike_times2 = np.array([0.0, 1.0, 2.0, 3.0]) + nwbfile.add_unit(spike_times=spike_times2) + + file_path = tmp_path / "test.nwb" + # Write the nwbfile to a temporary file + with NWBHDF5IO(path=file_path, mode="w") as io: + io.write(nwbfile) + + t_start = 10 + sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0, t_start=t_start) + + extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) + expected_spike_times1 = spike_times1 + t_start + np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) + + extracted_spike_times2 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True) + expected_spike_times2 = spike_times2 + t_start + np.testing.assert_allclose(extracted_spike_times2, expected_spike_times2) + + +def test_sorting_extraction_start_time_from_series(tmp_path): + nwbfile = mock_NWBFile() + electrical_series_name = "ElectricalSeries" + t_start = 10.0 + + n_electrodes = 5 + electrodes = mock_electrodes(n_electrodes=n_electrodes, nwbfile=nwbfile) + electrical_series = ElectricalSeries( + name=electrical_series_name, + starting_time=t_start, + rate=1.0, + data=np.ones((10, 5)), + electrodes=electrodes, + ) + nwbfile.add_acquisition(electrical_series) + # Add the spikes + spike_times1 = np.array([0.0, 1.0, 2.0]) + t_start + nwbfile.add_unit(spike_times=spike_times1) + spike_times2 = np.array([0.0, 1.0, 2.0, 3.0]) + t_start + nwbfile.add_unit(spike_times=spike_times2) + + file_path = tmp_path / "test.nwb" + # Write the nwbfile to a temporary file + with NWBHDF5IO(path=file_path, mode="w") as io: + io.write(nwbfile) + + t_start = 10 + sorting_extractor = NwbSortingExtractor(file_path=file_path, electrical_series_name=electrical_series_name) + + extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) + expected_spike_times1 = spike_times1 + t_start + np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) + + extracted_spike_times2 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True) + expected_spike_times2 = spike_times2 + t_start + np.testing.assert_allclose(extracted_spike_times2, expected_spike_times2) + + if __name__ == "__main__": test = NwbRecordingTest() From f93b04211cd751e27221aa23cf4d156ab5029f1a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 14 Dec 2023 12:33:47 +0100 Subject: [PATCH 4/8] recover tests --- .../extractors/nwbextractors.py | 21 +----- .../extractors/tests/test_nwbextractors.py | 68 +++++++++++++++++-- 2 files changed, 66 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index f51b6d0d93..73c76cf898 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Union, List, Optional, Literal, Dict, BinaryIO import warnings -import warnings import numpy as np @@ -590,30 +589,13 @@ def __init__( units_table = self._nwbfile.units - name_to_column_data = {c.name: c for c in units_table.columns} - spike_times_data = name_to_column_data.pop("spike_times").data - spike_times_index_data = name_to_column_data.pop("spike_times_index").data - - units_ids = name_to_column_data.pop("unit_name", None) - if units_ids is None: - units_ids = units_table["id"].data - - units_table = self._nwbfile.units - - name_to_column_data = 
{c.name: c for c in units_table.columns} - spike_times_data = name_to_column_data.pop("spike_times").data - spike_times_index_data = name_to_column_data.pop("spike_times_index").data - - units_ids = name_to_column_data.pop("unit_name", None) - if units_ids is None: - units_ids = units_table["id"].data - BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=units_ids) sorting_segment = NwbSortingSegment( spike_times_data=spike_times_data, spike_times_index_data=spike_times_index_data, sampling_frequency=sampling_frequency, + t_start=self.t_start, ) self.add_sorting_segment(sorting_segment) @@ -656,6 +638,7 @@ def __init__( "cache": cache, "stream_mode": stream_mode, "stream_cache_path": stream_cache_path, + "t_start": self.t_start, } diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index c363e1bec1..7fb1af5440 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -235,10 +235,6 @@ def test_sorting_property_extraction(tmp_path): np.testing.assert_equal(units_ids, ["a", "b"]) - units_ids = sorting_extractor.get_unit_ids() - - np.testing.assert_equal(units_ids, ["a", "b"]) - added_properties = sorting_extractor.get_property_keys() assert "non_uniform_ragged_array" not in added_properties assert "doubled_ragged_array" not in added_properties @@ -252,5 +248,69 @@ def test_sorting_property_extraction(tmp_path): np.testing.assert_allclose(spike_train2, spike_times2) +def test_sorting_extraction_start_time(tmp_path): + nwbfile = mock_NWBFile() + + # Add the spikes + spike_times1 = np.array([0.0, 1.0, 2.0]) + nwbfile.add_unit(spike_times=spike_times1) + spike_times2 = np.array([0.0, 1.0, 2.0, 3.0]) + nwbfile.add_unit(spike_times=spike_times2) + + file_path = tmp_path / "test.nwb" + # Write the nwbfile to a temporary file + with NWBHDF5IO(path=file_path, mode="w") as io: + io.write(nwbfile) + + t_start = 10 + sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0, t_start=t_start) + + extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) + expected_spike_times1 = spike_times1 + t_start + np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) + + extracted_spike_times2 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True) + expected_spike_times2 = spike_times2 + t_start + np.testing.assert_allclose(extracted_spike_times2, expected_spike_times2) + + +def test_sorting_extraction_start_time_from_series(tmp_path): + nwbfile = mock_NWBFile() + electrical_series_name = "ElectricalSeries" + t_start = 10.0 + + n_electrodes = 5 + electrodes = mock_electrodes(n_electrodes=n_electrodes, nwbfile=nwbfile) + electrical_series = ElectricalSeries( + name=electrical_series_name, + starting_time=t_start, + rate=1.0, + data=np.ones((10, 5)), + electrodes=electrodes, + ) + nwbfile.add_acquisition(electrical_series) + # Add the spikes + spike_times1 = np.array([0.0, 1.0, 2.0]) + t_start + nwbfile.add_unit(spike_times=spike_times1) + spike_times2 = np.array([0.0, 1.0, 2.0, 3.0]) + t_start + nwbfile.add_unit(spike_times=spike_times2) + + file_path = tmp_path / "test.nwb" + # Write the nwbfile to a temporary file + with NWBHDF5IO(path=file_path, mode="w") as io: + io.write(nwbfile) + + t_start = 10 + sorting_extractor = NwbSortingExtractor(file_path=file_path, electrical_series_name=electrical_series_name) + + extracted_spike_times1 = 
sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) + expected_spike_times1 = spike_times1 + t_start + np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) + + extracted_spike_times2 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True) + expected_spike_times2 = spike_times2 + t_start + np.testing.assert_allclose(extracted_spike_times2, expected_spike_times2) + + if __name__ == "__main__": test = NwbRecordingTest() From c312d8ac97849fbb3c2c3d0ac58b77c5df345af7 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 14 Dec 2023 13:23:51 +0100 Subject: [PATCH 5/8] alessio test --- .../extractors/nwbextractors.py | 2 +- .../extractors/tests/test_nwbextractors.py | 83 ++++++++++++------- 2 files changed, 53 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index e67dbc9762..d711265cc8 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -972,7 +972,7 @@ def get_unit_spike_train( spike_times = self.spike_times_data[start_index:end_index] # Transform spike times to frames and subset - frames = np.round(spike_times * self._sampling_frequency) + frames = np.round((spike_times - self._t_start) * self._sampling_frequency) start_index = 0 if start_frame is not None: diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index 14ae7b0e03..eb08c6f893 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -233,10 +233,10 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path): nwbfile.add_unit_column(name="unit_name", description="the name of the unit") nwbfile.add_unit_column(name="a_property", description="a_cool_property") - spike_times1 = np.array([0.0, 1.0, 2.0]) - nwbfile.add_unit(spike_times=spike_times1, unit_name="a", a_property="a_property_value") - spike_times2 = np.array([0.0, 1.0, 2.0, 3.0]) - nwbfile.add_unit(spike_times=spike_times2, unit_name="b", a_property="b_property_value") + spike_times_a = np.array([0.0, 1.0, 2.0]) + nwbfile.add_unit(spike_times=spike_times_a, unit_name="a", a_property="a_property_value") + spike_times_b = np.array([0.0, 1.0, 2.0, 3.0]) + nwbfile.add_unit(spike_times=spike_times_b, unit_name="b", a_property="b_property_value") non_uniform_ragged_array = [[1, 2, 3, 8, 10], [1, 2, 3, 5]] nwbfile.add_unit_column( @@ -279,75 +279,96 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path): assert "uniform_ragged_array" in added_properties assert "a_property" in added_properties - spike_train1 = sorting_extractor.get_unit_spike_train(unit_id="a", return_times=True) - np.testing.assert_allclose(spike_train1, spike_times1) + extracted_spike_times_a = sorting_extractor.get_unit_spike_train(unit_id="a", return_times=True) + np.testing.assert_allclose(extracted_spike_times_a, spike_times_a) - spike_train2 = sorting_extractor.get_unit_spike_train(unit_id="b", return_times=True) - np.testing.assert_allclose(spike_train2, spike_times2) + extracted_spike_times_b = sorting_extractor.get_unit_spike_train(unit_id="b", return_times=True) + np.testing.assert_allclose(extracted_spike_times_b, spike_times_b) def test_sorting_extraction_start_time(tmp_path): nwbfile = mock_NWBFile() # Add the spikes - spike_times1 = np.array([0.0, 1.0, 2.0]) + + t_start = 10 + sampling_frequency = 1.0 + spike_times0 = np.array([0.0, 
1.0, 2.0]) + t_start + nwbfile.add_unit(spike_times=spike_times0) + spike_times1 = np.array([0.0, 1.0, 2.0, 3.0]) + t_start nwbfile.add_unit(spike_times=spike_times1) - spike_times2 = np.array([0.0, 1.0, 2.0, 3.0]) - nwbfile.add_unit(spike_times=spike_times2) file_path = tmp_path / "test.nwb" # Write the nwbfile to a temporary file with NWBHDF5IO(path=file_path, mode="w") as io: io.write(nwbfile) - t_start = 10 - sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0, t_start=t_start) + sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=sampling_frequency, t_start=t_start) + + # Test frames + extracted_frames0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=False) + expected_frames = ((spike_times0 - t_start) * sampling_frequency).astype("int64") + np.testing.assert_allclose(extracted_frames0, expected_frames) + + extracted_frames1 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=False) + expected_frames = ((spike_times1 - t_start) * sampling_frequency).astype("int64") + np.testing.assert_allclose(extracted_frames1, expected_frames) - extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) - expected_spike_times1 = spike_times1 + t_start - np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) + # # Test times + # extracted_spike_times0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) + # expected_spike_times0 = spike_times0 + t_start + # np.testing.assert_allclose(extracted_spike_times0, expected_spike_times0) - extracted_spike_times2 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True) - expected_spike_times2 = spike_times2 + t_start - np.testing.assert_allclose(extracted_spike_times2, expected_spike_times2) + # extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True) + # expected_spike_times1 = spike_times1 + t_start + # np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) def test_sorting_extraction_start_time_from_series(tmp_path): nwbfile = mock_NWBFile() electrical_series_name = "ElectricalSeries" t_start = 10.0 - + sampling_frequency = 1.0 n_electrodes = 5 electrodes = mock_electrodes(n_electrodes=n_electrodes, nwbfile=nwbfile) electrical_series = ElectricalSeries( name=electrical_series_name, starting_time=t_start, - rate=1.0, + rate=sampling_frequency, data=np.ones((10, 5)), electrodes=electrodes, ) nwbfile.add_acquisition(electrical_series) # Add the spikes - spike_times1 = np.array([0.0, 1.0, 2.0]) + t_start + spike_times0 = np.array([0.0, 1.0, 2.0]) + t_start + nwbfile.add_unit(spike_times=spike_times0) + spike_times1 = np.array([0.0, 1.0, 2.0, 3.0]) + t_start nwbfile.add_unit(spike_times=spike_times1) - spike_times2 = np.array([0.0, 1.0, 2.0, 3.0]) + t_start - nwbfile.add_unit(spike_times=spike_times2) file_path = tmp_path / "test.nwb" # Write the nwbfile to a temporary file with NWBHDF5IO(path=file_path, mode="w") as io: io.write(nwbfile) - t_start = 10 sorting_extractor = NwbSortingExtractor(file_path=file_path, electrical_series_name=electrical_series_name) - extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) - expected_spike_times1 = spike_times1 + t_start - np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) + # Test frames + extracted_frames0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=False) + expected_frames = ((spike_times0 - t_start) * 
sampling_frequency).astype("int64") + np.testing.assert_allclose(extracted_frames0, expected_frames) + + extracted_frames1 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=False) + expected_frames = ((spike_times1 - t_start) * sampling_frequency).astype("int64") + np.testing.assert_allclose(extracted_frames1, expected_frames) + + # # Test returned times + # extracted_spike_times0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) + # expected_spike_times0 = spike_times1 + t_start + # np.testing.assert_allclose(extracted_spike_times0, expected_spike_times0) - extracted_spike_times2 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True) - expected_spike_times2 = spike_times2 + t_start - np.testing.assert_allclose(extracted_spike_times2, expected_spike_times2) + # extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True) + # expected_spike_times1 = spike_times1 + t_start + # np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) if __name__ == "__main__": From 21e3b58c1149bbabacb1552a46e0981b2797e3eb Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 14 Dec 2023 16:09:04 +0100 Subject: [PATCH 6/8] fix the tests --- .../extractors/tests/test_nwbextractors.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index eb08c6f893..5d5c8f51fd 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -314,14 +314,14 @@ def test_sorting_extraction_start_time(tmp_path): expected_frames = ((spike_times1 - t_start) * sampling_frequency).astype("int64") np.testing.assert_allclose(extracted_frames1, expected_frames) - # # Test times - # extracted_spike_times0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) - # expected_spike_times0 = spike_times0 + t_start - # np.testing.assert_allclose(extracted_spike_times0, expected_spike_times0) + # Test times + extracted_spike_times0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) + expected_spike_times0 = spike_times0 + np.testing.assert_allclose(extracted_spike_times0, expected_spike_times0) - # extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True) - # expected_spike_times1 = spike_times1 + t_start - # np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) + extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True) + expected_spike_times1 = spike_times1 + np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) def test_sorting_extraction_start_time_from_series(tmp_path): @@ -361,14 +361,14 @@ def test_sorting_extraction_start_time_from_series(tmp_path): expected_frames = ((spike_times1 - t_start) * sampling_frequency).astype("int64") np.testing.assert_allclose(extracted_frames1, expected_frames) - # # Test returned times - # extracted_spike_times0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) - # expected_spike_times0 = spike_times1 + t_start - # np.testing.assert_allclose(extracted_spike_times0, expected_spike_times0) + # Test returned times + extracted_spike_times0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) + expected_spike_times0 = spike_times0 + np.testing.assert_allclose(extracted_spike_times0, 
expected_spike_times0)
 
-    # extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True)
-    # expected_spike_times1 = spike_times1 + t_start
-    # np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1)
+    extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True)
+    expected_spike_times1 = spike_times1
+    np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1)
 
 
 if __name__ == "__main__":

From cbdf4001f516b62aa52db134e1aa5b9f4900525f Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Thu, 14 Dec 2023 16:20:49 +0100
Subject: [PATCH 7/8] add docstring

---
 src/spikeinterface/extractors/nwbextractors.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py
index d711265cc8..96f4081796 100644
--- a/src/spikeinterface/extractors/nwbextractors.py
+++ b/src/spikeinterface/extractors/nwbextractors.py
@@ -819,6 +819,19 @@ class NwbSortingExtractor(BaseSorting):
         if False, the file is not cached.
     stream_cache_path: str or Path or None, default: None
         Local path for caching. If None it uses the system temporary directory.
+    t_start: float or None, default: None
+        This is the time at which the corresponding ElectricalSeries starts. NWB stores spike times in
+        seconds, and `t_start` is used to convert them to frames. Concretely, the returned frames are computed as:
+
+        `frames = (times - t_start) * sampling_frequency`.
+
+        This is because SpikeInterface always considers the first frame to be the beginning of the recording,
+        independently of `t_start`.
+
+        When `t_start` is not provided, it will be inferred from the corresponding ElectricalSeries with name equal
+        to `electrical_series_name`: it will be either the `ElectricalSeries.starting_time` or the
+        first timestamp in the `ElectricalSeries.timestamps`.
+ Returns ------- @@ -838,10 +851,10 @@ def __init__( sampling_frequency: float | None = None, samples_for_rate_estimation: int = 1000, stream_mode: str | None = None, - cache: bool = False, stream_cache_path: str | Path | None = None, *, t_start: float | None = None, + cache: bool = False, ): try: from pynwb import NWBHDF5IO, NWBFile From 34aaf917e640ff5ab1f0e665f36ab1398d532eeb Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 15 Dec 2023 11:41:05 +0100 Subject: [PATCH 8/8] Alessio feedback, eliminate duplication of units table definition --- src/spikeinterface/extractors/nwbextractors.py | 2 -- src/spikeinterface/extractors/tests/test_nwbextractors.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 96f4081796..6d3b49cb3f 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -904,8 +904,6 @@ def __init__( if units_ids is None: units_ids = units_table["id"].data - units_table = self._nwbfile.units - BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=units_ids) sorting_segment = NwbSortingSegment( diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index 5d5c8f51fd..4a076fac7f 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -292,7 +292,7 @@ def test_sorting_extraction_start_time(tmp_path): # Add the spikes t_start = 10 - sampling_frequency = 1.0 + sampling_frequency = 100.0 spike_times0 = np.array([0.0, 1.0, 2.0]) + t_start nwbfile.add_unit(spike_times=spike_times0) spike_times1 = np.array([0.0, 1.0, 2.0, 3.0]) + t_start @@ -328,7 +328,7 @@ def test_sorting_extraction_start_time_from_series(tmp_path): nwbfile = mock_NWBFile() electrical_series_name = "ElectricalSeries" t_start = 10.0 - sampling_frequency = 1.0 + sampling_frequency = 100.0 n_electrodes = 5 electrodes = mock_electrodes(n_electrodes=n_electrodes, nwbfile=nwbfile) electrical_series = ElectricalSeries(
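
# --- Editor's note (not part of the patch series) ---------------------------
# A minimal, self-contained sketch of the ragged-array slicing that PATCH 1
# introduces in NwbSortingExtractor: a ragged units-table column is stored as a
# flat data array plus a cumulative index column, and the column is only added
# as a property when every row has the same length. All values below are made up.
import numpy as np

data = np.array([1, 2, 3, 4, 5, 6])  # flat storage of two rows: [1, 2, 3] and [4, 5, 6]
data_index = np.array([3, 6])  # cumulative end offset of each row

index_spacing = np.diff(data_index, prepend=0)  # row lengths: [3, 3]
if np.unique(index_spacing).size == 1:  # uniform rows -> safe to stack as a 2D property
    start_indices = [0] + list(data_index[:-1])  # [0, 3]
    end_indices = list(data_index)  # [3, 6]
    values = [data[start:end] for start, end in zip(start_indices, end_indices)]
    print(np.asarray(values))  # [[1 2 3] [4 5 6]]
else:
    print("skipped: unequal shapes across units")  # mirrors the warnings.warn branch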
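# --- Editor's note (not part of the patch series) ---------------------------
# A sketch of the time-to-frame conversion documented in PATCH 7 and applied in
# PATCH 5's get_unit_spike_train, plus the searchsorted windowing from PATCH 1.
# The numbers are made up: spikes recorded from t_start = 10.0 s at 100 Hz.
import numpy as np

t_start = 10.0
sampling_frequency = 100.0
spike_times = np.array([10.0, 10.1, 10.2, 10.3])  # seconds, as stored in the NWB units table

# frames = (times - t_start) * sampling_frequency, so frame 0 is the start of
# the recording regardless of t_start
frames = np.round((spike_times - t_start) * sampling_frequency).astype("int64")
print(frames)  # [ 0 10 20 30]

# subset to a frame window, as in get_unit_spike_train(start_frame, end_frame)
start_frame, end_frame = 5, 25
start_index = np.searchsorted(frames, start_frame, side="left")
end_index = np.searchsorted(frames, end_frame, side="left")
print(frames[start_index:end_index])  # [10 20]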