Skip to content

Commit

Permalink
Merge pull request #2333 from h-mayorquin/add_t_start_to_nwb_sorting_…
Browse files Browse the repository at this point in the history
…extractor

Add t start to nwb sorting extractor
  • Loading branch information
alejoe91 authored Dec 15, 2023
2 parents b88ee2d + 30f8557 commit 4f46965
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 21 deletions.
41 changes: 32 additions & 9 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,19 @@ class NwbSortingExtractor(BaseSorting):
if False, the file is not cached.
stream_cache_path: str or Path or None, default: None
Local path for caching. If None it uses the system temporary directory.
t_start: float or None, default: None
This is the time at which the corresponding ElectricalSeries start. NWB stores its spikes as times
and the `t_start` is used to convert the times to seconds. Concretely, the returned frames are computed as:
`frames = (times - t_start) * sampling_frequency`.
This is because SpikeInterface always considers the first frame to be at the beginning of the recording,
independently of the `t_start`.
When a `t_start` is not provided it will be inferred from the corresponding ElectricalSeries with name equal
to `electrical_series_name`. The `t_start` then will be either the `ElectricalSeries.starting_time` or the
first timestamp in the `ElectricalSeries.timestamps`.
Returns
-------
Expand All @@ -838,8 +851,10 @@ def __init__(
sampling_frequency: float | None = None,
samples_for_rate_estimation: int = 1000,
stream_mode: str | None = None,
cache: bool = False,
stream_cache_path: str | Path | None = None,
*,
t_start: float | None = None,
cache: bool = False,
):
try:
from pynwb import NWBHDF5IO, NWBFile
Expand All @@ -857,22 +872,27 @@ def __init__(
)

timestamps = None
self.t_start = t_start
if sampling_frequency is None:
# defines the electrical series from where the sorting came from
# important to know the sampling_frequency
self.electrical_series = retrieve_electrical_series(self._nwbfile, self._electrical_series_name)
# get rate
if self.electrical_series.rate is not None:
sampling_frequency = self.electrical_series.rate
self.t_start = self.electrical_series.starting_time
else:
if hasattr(self.electrical_series, "timestamps"):
if self.electrical_series.timestamps is not None:
timestamps = self.electrical_series.timestamps
sampling_frequency = 1 / np.median(np.diff(timestamps[samples_for_rate_estimation]))

assert sampling_frequency is not None, (
"Couldn't load sampling frequency. Please provide it with the " "'sampling_frequency' argument"
)
self.t_start = timestamps[0]
assert (
sampling_frequency is not None
), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument"
assert (
self.t_start is not None
), "Couldn't load a starting time for the sorting. Please provide it with the 't_start' argument"

units_table = self._nwbfile.units

Expand All @@ -890,6 +910,7 @@ def __init__(
spike_times_data=spike_times_data,
spike_times_index_data=spike_times_index_data,
sampling_frequency=sampling_frequency,
t_start=self.t_start,
)
self.add_sorting_segment(sorting_segment)

Expand Down Expand Up @@ -932,15 +953,19 @@ def __init__(
"cache": cache,
"stream_mode": stream_mode,
"stream_cache_path": stream_cache_path,
"t_start": self.t_start,
}


class NwbSortingSegment(BaseSortingSegment):
    """A sorting segment backed by the ragged spike-times arrays of an NWB units table."""

    def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency: float, t_start: float):
        # The diff rendering had duplicated the signature and the first two
        # assignments (old + new lines); keep exactly one copy of each.
        BaseSortingSegment.__init__(self)
        # Flat array of spike times and its ragged-array index, as stored by NWB.
        self.spike_times_data = spike_times_data
        self.spike_times_index_data = spike_times_index_data
        self._sampling_frequency = sampling_frequency
        # Start time of the corresponding ElectricalSeries; used to convert
        # absolute spike times into frames relative to the recording start.
        self._t_start = t_start

def get_unit_spike_train(
self,
Expand All @@ -958,13 +983,11 @@ def get_unit_spike_train(
spike_times = self.spike_times_data[start_index:end_index]

# Transform spike times to frames and subset
frames = np.round(spike_times * self._sampling_frequency)
frames = np.round((spike_times - self._t_start) * self._sampling_frequency)

start_index = 0
if start_frame is not None:
start_index = np.searchsorted(frames, start_frame, side="left")
else:
start_index = 0

if end_frame is not None:
end_index = np.searchsorted(frames, end_frame, side="left")
Expand Down
106 changes: 95 additions & 11 deletions src/spikeinterface/extractors/tests/test_nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from pynwb.ecephys import ElectricalSeries
from pynwb.testing.mock.file import mock_NWBFile
from pynwb.testing.mock.device import mock_Device
from pynwb.testing.mock.ecephys import mock_ElectricalSeries, mock_ElectrodeGroup

from pynwb.testing.mock.ecephys import mock_ElectricalSeries, mock_ElectrodeGroup, mock_electrodes
from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor

from spikeinterface.extractors.tests.common_tests import RecordingCommonTestSuite, SortingCommonTestSuite
Expand Down Expand Up @@ -234,10 +233,10 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path):
nwbfile.add_unit_column(name="unit_name", description="the name of the unit")
nwbfile.add_unit_column(name="a_property", description="a_cool_property")

spike_times1 = np.array([0.0, 1.0, 2.0])
nwbfile.add_unit(spike_times=spike_times1, unit_name="a", a_property="a_property_value")
spike_times2 = np.array([0.0, 1.0, 2.0, 3.0])
nwbfile.add_unit(spike_times=spike_times2, unit_name="b", a_property="b_property_value")
spike_times_a = np.array([0.0, 1.0, 2.0])
nwbfile.add_unit(spike_times=spike_times_a, unit_name="a", a_property="a_property_value")
spike_times_b = np.array([0.0, 1.0, 2.0, 3.0])
nwbfile.add_unit(spike_times=spike_times_b, unit_name="b", a_property="b_property_value")

non_uniform_ragged_array = [[1, 2, 3, 8, 10], [1, 2, 3, 5]]
nwbfile.add_unit_column(
Expand Down Expand Up @@ -268,7 +267,7 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path):
with NWBHDF5IO(path=file_path, mode="w") as io:
io.write(nwbfile)

sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0)
sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0, t_start=0)

units_ids = sorting_extractor.get_unit_ids()

Expand All @@ -280,11 +279,96 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path):
assert "uniform_ragged_array" in added_properties
assert "a_property" in added_properties

spike_train1 = sorting_extractor.get_unit_spike_train(unit_id="a", return_times=True)
np.testing.assert_allclose(spike_train1, spike_times1)
extracted_spike_times_a = sorting_extractor.get_unit_spike_train(unit_id="a", return_times=True)
np.testing.assert_allclose(extracted_spike_times_a, spike_times_a)

extracted_spike_times_b = sorting_extractor.get_unit_spike_train(unit_id="b", return_times=True)
np.testing.assert_allclose(extracted_spike_times_b, spike_times_b)


def test_sorting_extraction_start_time(tmp_path):
    """Frames and returned times must be consistent with an explicitly passed `t_start`."""
    nwbfile = mock_NWBFile()

    # Build two units whose spike times are offset by t_start
    t_start = 10
    sampling_frequency = 100.0
    unit_spike_times = [
        np.array([0.0, 1.0, 2.0]) + t_start,
        np.array([0.0, 1.0, 2.0, 3.0]) + t_start,
    ]
    for times in unit_spike_times:
        nwbfile.add_unit(spike_times=times)

    # Round-trip through an on-disk NWB file
    file_path = tmp_path / "test.nwb"
    with NWBHDF5IO(path=file_path, mode="w") as io:
        io.write(nwbfile)

    sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=sampling_frequency, t_start=t_start)

    for unit_id, times in enumerate(unit_spike_times):
        # Frames are relative to t_start: frames = (times - t_start) * sampling_frequency
        extracted_frames = sorting_extractor.get_unit_spike_train(unit_id=unit_id, return_times=False)
        expected_frames = ((times - t_start) * sampling_frequency).astype("int64")
        np.testing.assert_allclose(extracted_frames, expected_frames)

        # Returned times are the original absolute spike times
        extracted_times = sorting_extractor.get_unit_spike_train(unit_id=unit_id, return_times=True)
        np.testing.assert_allclose(extracted_times, times)


def test_sorting_extraction_start_time_from_series(tmp_path):
    """`t_start` must be inferred from the matching ElectricalSeries when not passed explicitly.

    Note: the original hunk contained two stale diff lines referencing the
    undefined names `spike_train2` / `spike_times2` (and unit id "b"); they
    would raise a NameError at runtime and are removed here.
    """
    nwbfile = mock_NWBFile()
    electrical_series_name = "ElectricalSeries"
    t_start = 10.0
    sampling_frequency = 100.0
    n_electrodes = 5
    electrodes = mock_electrodes(n_electrodes=n_electrodes, nwbfile=nwbfile)
    electrical_series = ElectricalSeries(
        name=electrical_series_name,
        starting_time=t_start,
        rate=sampling_frequency,
        data=np.ones((10, 5)),
        electrodes=electrodes,
    )
    nwbfile.add_acquisition(electrical_series)
    # Add the spikes, offset by the series' starting time
    spike_times0 = np.array([0.0, 1.0, 2.0]) + t_start
    nwbfile.add_unit(spike_times=spike_times0)
    spike_times1 = np.array([0.0, 1.0, 2.0, 3.0]) + t_start
    nwbfile.add_unit(spike_times=spike_times1)

    file_path = tmp_path / "test.nwb"
    # Write the nwbfile to a temporary file
    with NWBHDF5IO(path=file_path, mode="w") as io:
        io.write(nwbfile)

    # No t_start / sampling_frequency passed: both come from the ElectricalSeries
    sorting_extractor = NwbSortingExtractor(file_path=file_path, electrical_series_name=electrical_series_name)

    # Test frames: frames = (times - t_start) * sampling_frequency
    extracted_frames0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=False)
    expected_frames = ((spike_times0 - t_start) * sampling_frequency).astype("int64")
    np.testing.assert_allclose(extracted_frames0, expected_frames)

    extracted_frames1 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=False)
    expected_frames = ((spike_times1 - t_start) * sampling_frequency).astype("int64")
    np.testing.assert_allclose(extracted_frames1, expected_frames)

    # Test returned times: absolute spike times round-trip unchanged
    extracted_spike_times0 = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True)
    np.testing.assert_allclose(extracted_spike_times0, spike_times0)

    extracted_spike_times1 = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True)
    np.testing.assert_allclose(extracted_spike_times1, spike_times1)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_recording_s3_nwb_remfile_file_like(tmp_path):
def test_sorting_s3_nwb_ros3(tmp_path):
file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
# we provide the 'sampling_frequency' because the NWB file does not have the electrical series
sort = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3")
sort = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3", t_start=0)

start_frame = 0
end_frame = 300
Expand Down Expand Up @@ -196,6 +196,7 @@ def test_sorting_s3_nwb_fsspec(tmp_path, cache):
stream_mode="fsspec",
cache=cache,
stream_cache_path=tmp_path if cache else None,
t_start=0,
)

num_seg = sorting.get_num_segments()
Expand Down Expand Up @@ -228,6 +229,7 @@ def test_sorting_s3_nwb_remfile(tmp_path):
file_path,
sampling_frequency=30000.0,
stream_mode="remfile",
t_start=0,
)

num_seg = sorting.get_num_segments()
Expand Down

0 comments on commit 4f46965

Please sign in to comment.