Skip to content

Commit

Permalink
Merge pull request SpikeInterface#2275 from h-mayorquin/add_rem_to_nw…
Browse files Browse the repository at this point in the history
…b_sorting

Add nwb sorting rem file support
  • Loading branch information
alejoe91 authored Dec 1, 2023
2 parents 217eec1 + 18de6a7 commit 682d727
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 16 deletions.
28 changes: 12 additions & 16 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ class NwbRecordingExtractor(BaseRecording):
samples_for_rate_estimation: int, default: 100000
The number of timestamp samples to use to estimate the rate.
Used if "rate" is not specified in the ElectricalSeries.
stream_mode: str or None, default: None
Specify the stream mode: "fsspec" or "ros3".
stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None
The streaming mode to use. If None it assumes the file is on the local disk.
cache: bool, default: False
If True, the file is cached in the file passed to stream_cache_path
if False, the file is not cached.
Expand Down Expand Up @@ -412,12 +412,11 @@ def __init__(
else:
self.set_property(property_name, values)

if stream_mode not in ["fsspec", "ros3", "remfile"]:
if file_path is not None:
file_path = str(Path(file_path).absolute())
if stream_mode == "fsspec":
if stream_cache_path is not None:
stream_cache_path = str(Path(self.stream_cache_path).absolute())
if stream_mode is None and file_path is not None:
file_path = str(Path(file_path).resolve())

if stream_mode == "fsspec" and stream_cache_path is not None:
stream_cache_path = str(Path(self.stream_cache_path).absolute())

self.extra_requirements.extend(["pandas", "pynwb", "hdmf"])
self._electrical_series = electrical_series
Expand Down Expand Up @@ -493,8 +492,8 @@ class NwbSortingExtractor(BaseSorting):
samples_for_rate_estimation: int, default: 100000
The number of timestamp samples to use to estimate the rate.
Used if "rate" is not specified in the ElectricalSeries.
stream_mode: str or None, default: None
Specify the stream mode: "fsspec" or "ros3".
stream_mode : "fsspec" | "ros3" | "remfile" | None, default: None
The streaming mode to use. If None it assumes the file is on the local disk.
cache: bool, default: False
If True, the file is cached in the file passed to stream_cache_path
if False, the file is not cached.
Expand Down Expand Up @@ -591,12 +590,9 @@ def __init__(
for prop_name, values in properties.items():
self.set_property(prop_name, np.array(values))

if stream_mode not in ["fsspec", "ros3"]:
file_path = str(Path(file_path).absolute())
if stream_mode == "fsspec":
# only add stream_cache_path to kwargs if it was passed as an argument
if stream_cache_path is not None:
stream_cache_path = str(Path(self.stream_cache_path).absolute())
if stream_mode is None and file_path is not None:
file_path = str(Path(file_path).resolve())

self._kwargs = {
"file_path": file_path,
"electrical_series_name": self._electrical_series_name,
Expand Down
32 changes: 32 additions & 0 deletions src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,38 @@ def test_sorting_s3_nwb_fsspec(tmp_path, cache):
check_sortings_equal(reloaded_sorting, sorting)


@pytest.mark.streaming_extractors
def test_sorting_s3_nwb_remfile(tmp_path):
file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
# We provide the 'sampling_frequency' because the NWB file does not have the electrical series
sorting = NwbSortingExtractor(
file_path,
sampling_frequency=30000.0,
stream_mode="remfile",
)

num_seg = sorting.get_num_segments()
assert num_seg == 1
num_units = len(sorting.unit_ids)
assert num_units == 64

for segment_index in range(num_seg):
for unit in sorting.unit_ids:
spike_train = sorting.get_unit_spike_train(unit_id=unit, segment_index=segment_index)
assert len(spike_train) > 0
assert spike_train.dtype == "int64"
assert np.all(spike_train >= 0)

tmp_file = tmp_path / "test_remfile_sorting.pkl"
with open(tmp_file, "wb") as f:
pickle.dump(sorting, f)

with open(tmp_file, "rb") as f:
reloaded_sorting = pickle.load(f)

check_sortings_equal(reloaded_sorting, sorting)


if __name__ == "__main__":
test_recording_s3_nwb_ros3()
test_recording_s3_nwb_fsspec()
Expand Down

0 comments on commit 682d727

Please sign in to comment.