From ebecb22554a22a91849cde2516c8f058bd45e2a8 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 24 Nov 2023 11:26:13 +0100 Subject: [PATCH] Fix `NwbSortingExtractor` reading of ragged arrays (#2255) * add failing test * fix test * added test for properties --- .../extractors/nwbextractors.py | 40 +++++++++++------- .../extractors/tests/test_nwbextractors.py | 42 +++++++++++++++++++ 2 files changed, 66 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 010b22975c..55aa4f6943 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -508,28 +508,36 @@ def __init__( "Couldn't load sampling frequency. Please provide it with the " "'sampling_frequency' argument" ) - # get all units ids + BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=units_ids) + sorting_segment = NwbSortingSegment( + nwbfile=self._nwbfile, sampling_frequency=sampling_frequency, timestamps=timestamps + ) + self.add_sorting_segment(sorting_segment) - # store units properties and spike features to dictionaries + # Add properties: properties = dict() + import warnings for column in list(self._nwbfile.units.colnames): if column == "spike_times": continue - # if it is unit_property + + # Note that this has a different behavior than self._nwbfile.units[column].data property_values = self._nwbfile.units[column][:] - # only load columns with same shape for all units - if np.all(p.shape == property_values[0].shape for p in property_values): - properties[column] = property_values - else: - print(f"Skipping {column} because of unequal shapes across units") + # Making this explicit because I am not sure this is the best test + is_raggged_array = isinstance(property_values, list) + if is_raggged_array: + all_values_have_equal_shape = np.all([p.shape == property_values[0].shape for p in property_values]) + if 
all_values_have_equal_shape: + properties[column] = property_values + else: + warnings.warn(f"Skipping {column} because of unequal shapes across units") - BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=units_ids) - sorting_segment = NwbSortingSegment( - nwbfile=self._nwbfile, sampling_frequency=sampling_frequency, timestamps=timestamps - ) - self.add_sorting_segment(sorting_segment) + continue # To next property + + # The rest of the properties are added as they come + properties[column] = property_values for prop_name, values in properties.items(): self.set_property(prop_name, np.array(values)) @@ -569,10 +577,10 @@ def get_unit_spike_train( spike_times = self._nwbfile.units["spike_times"][list(self._nwbfile.units.id[:]).index(unit_id)][:] if self._timestamps is not None: - frames = np.searchsorted(spike_times, self.timestamps).astype("int64") + frames = np.searchsorted(spike_times, self.timestamps) else: - frames = np.round(spike_times * self._sampling_frequency).astype("int64") - return frames[(frames >= start_frame) & (frames < end_frame)] + frames = np.round(spike_times * self._sampling_frequency) + return frames[(frames >= start_frame) & (frames < end_frame)].astype("int64", copy=False) read_nwb_recording = define_function_from_class(source_class=NwbRecordingExtractor, name="read_nwb_recording") diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index dad80111a7..400cb9311e 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -29,6 +29,9 @@ class NwbSortingTest(SortingCommonTestSuite, unittest.TestCase): entities = [] +from pynwb.testing.mock.ecephys import mock_ElectrodeGroup + + @pytest.fixture(scope="module") def nwbfile_with_ecephys_content(): nwbfile = mock_NWBFile() @@ -182,5 +185,44 @@ def test_nwb_extractor_offset_from_series(path_to_nwbfile, 
nwbfile_with_ecephys_ assert np.array_equal(extracted_offsets_uV, expected_offsets_uV) +def test_sorting_extraction_of_ragged_arrays(tmp_path): + nwbfile = mock_NWBFile() + + # Add the spikes + nwbfile.add_unit_column(name="unit_name", description="the name of the unit") + spike_times1 = np.array([0.0, 1.0, 2.0]) + nwbfile.add_unit(spike_times=spike_times1, unit_name="a") + spike_times2 = np.array([0.0, 1.0, 2.0, 3.0]) + nwbfile.add_unit(spike_times=spike_times2, unit_name="b") + + ragged_array_bad = [[1, 2, 3], [1, 2, 3, 5]] + nwbfile.add_unit_column( + name="ragged_array_bad", + description="an evill array that wants to destroy your test", + data=ragged_array_bad, + index=True, + ) + + ragged_array_good = [[1, 2], [3, 4]] + nwbfile.add_unit_column( + name="ragged_array_good", + description="a good array that wants to help your test be nice to nice arrays", + data=ragged_array_good, + index=True, + ) + + file_path = tmp_path / "test.nwb" + # Write the nwbfile to a temporary file + with NWBHDF5IO(path=file_path, mode="w") as io: + io.write(nwbfile) + + sorting_extractor = NwbSortingExtractor(file_path=file_path, sampling_frequency=10.0) + + # Check that the bad array was not added + added_properties = sorting_extractor.get_property_keys() + assert "ragged_array_bad" not in added_properties + assert "ragged_array_good" in added_properties + + if __name__ == "__main__": test = NwbRecordingTest()