Skip to content

Commit

Permalink
Fix NwbSortingExtractor reading of ragged arrays (#2255)
Browse files Browse the repository at this point in the history
* add failing test

* fix test

* addeed test for properties
  • Loading branch information
h-mayorquin authored Nov 24, 2023
1 parent fcbf422 commit ebecb22
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 16 deletions.
40 changes: 24 additions & 16 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,28 +508,36 @@ def __init__(
"Couldn't load sampling frequency. Please provide it with the " "'sampling_frequency' argument"
)

# get all units ids
BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=units_ids)
sorting_segment = NwbSortingSegment(
nwbfile=self._nwbfile, sampling_frequency=sampling_frequency, timestamps=timestamps
)
self.add_sorting_segment(sorting_segment)

# store units properties and spike features to dictionaries
# Add properties:
properties = dict()
import warnings

for column in list(self._nwbfile.units.colnames):
if column == "spike_times":
continue
# if it is unit_property

# Note that this has a different behavior than self._nwbfile.units[column].data
property_values = self._nwbfile.units[column][:]

# only load columns with same shape for all units
if np.all(p.shape == property_values[0].shape for p in property_values):
properties[column] = property_values
else:
print(f"Skipping {column} because of unequal shapes across units")
# Making this explicit because I am not sure this is the best test
is_raggged_array = isinstance(property_values, list)
if is_raggged_array:
all_values_have_equal_shape = np.all([p.shape == property_values[0].shape for p in property_values])
if all_values_have_equal_shape:
properties[column] = property_values
else:
warnings.warn(f"Skipping {column} because of unequal shapes across units")

BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=units_ids)
sorting_segment = NwbSortingSegment(
nwbfile=self._nwbfile, sampling_frequency=sampling_frequency, timestamps=timestamps
)
self.add_sorting_segment(sorting_segment)
continue # To next property

# The rest of the properties are added as they come
properties[column] = property_values

for prop_name, values in properties.items():
self.set_property(prop_name, np.array(values))
Expand Down Expand Up @@ -569,10 +577,10 @@ def get_unit_spike_train(
spike_times = self._nwbfile.units["spike_times"][list(self._nwbfile.units.id[:]).index(unit_id)][:]

if self._timestamps is not None:
frames = np.searchsorted(spike_times, self.timestamps).astype("int64")
frames = np.searchsorted(spike_times, self.timestamps)
else:
frames = np.round(spike_times * self._sampling_frequency).astype("int64")
return frames[(frames >= start_frame) & (frames < end_frame)]
frames = np.round(spike_times * self._sampling_frequency)
return frames[(frames >= start_frame) & (frames < end_frame)].astype("int64", copy=False)


read_nwb_recording = define_function_from_class(source_class=NwbRecordingExtractor, name="read_nwb_recording")
Expand Down
42 changes: 42 additions & 0 deletions src/spikeinterface/extractors/tests/test_nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class NwbSortingTest(SortingCommonTestSuite, unittest.TestCase):
entities = []


from pynwb.testing.mock.ecephys import mock_ElectrodeGroup


@pytest.fixture(scope="module")
def nwbfile_with_ecephys_content():
nwbfile = mock_NWBFile()
Expand Down Expand Up @@ -182,5 +185,44 @@ def test_nwb_extractor_offset_from_series(path_to_nwbfile, nwbfile_with_ecephys_
assert np.array_equal(extracted_offsets_uV, expected_offsets_uV)


def test_sorting_extraction_of_ragged_arrays(tmp_path):
nwbfile = mock_NWBFile()

# Add the spikes
nwbfile.add_unit_column(name="unit_name", description="the name of the unit")
spike_times1 = np.array([0.0, 1.0, 2.0])
nwbfile.add_unit(spike_times=spike_times1, unit_name="a")
spike_times2 = np.array([0.0, 1.0, 2.0, 3.0])
nwbfile.add_unit(spike_times=spike_times2, unit_name="b")

ragged_array_bad = [[1, 2, 3], [1, 2, 3, 5]]
nwbfile.add_unit_column(
name="ragged_array_bad",
description="an evill array that wants to destroy your test",
data=ragged_array_bad,
index=True,
)

ragged_array_good = [[1, 2], [3, 4]]
nwbfile.add_unit_column(
name="ragged_array_good",
description="a good array that wants to help your test be nice to nice arrays",
data=ragged_array_good,
index=True,
)

file_path = tmp_path / "test.nwb"
# Write the nwbfile to a temporary file
with NWBHDF5IO(path=file_path, mode="w") as io:
io.write(nwbfile)

sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0)

# Check that the bad array was not added
added_properties = sorting_extractor.get_property_keys()
assert "ragged_array_bad" not in added_properties
assert "ragged_array_good" in added_properties


if __name__ == "__main__":
test = NwbRecordingTest()

0 comments on commit ebecb22

Please sign in to comment.