From d74080a0cf49866e013de0fdc3fb1d211fea3f71 Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Fri, 25 Oct 2024 12:24:46 -0400 Subject: [PATCH] [MRG] ENH: add new property to access spike times by cell type (#916) * ENH: add new property to access spike times by cell type * flake8 * DOC: update whats new --- doc/whats_new.rst | 2 ++ hnn_core/cell_response.py | 20 ++++++++++++++++++++ hnn_core/tests/test_cell_response.py | 5 +++++ 3 files changed, 27 insertions(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 338aa16e3..4cebe27d6 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -33,6 +33,8 @@ Bug API ~~~ +- Add :func:`~hnn_core.CellResponse.spike_times_by_type` to get cell spiking times + organized by cell type, by `Mainak Jas`_ in :gh:`916`. .. _0.4: diff --git a/hnn_core/cell_response.py b/hnn_core/cell_response.py index 248eca559..f6246dbd2 100644 --- a/hnn_core/cell_response.py +++ b/hnn_core/cell_response.py @@ -153,6 +153,26 @@ def __eq__(self, other): def spike_times(self): return self._spike_times + @property + def cell_types(self): + """Get unique cell types.""" + spike_types_data = np.concatenate(np.array(self.spike_types, + dtype=object)) + return np.unique(spike_types_data).tolist() + + @property + def spike_times_by_type(self): + """Get a dictionary of spike times by cell type""" + spike_times = dict() + for cell_type in self.cell_types: + spike_times[cell_type] = list() + for trial_spike_times, trial_spike_types in zip(self.spike_times, + self.spike_types): + mask = np.isin(trial_spike_types, cell_type) + cell_spike_times = np.array(trial_spike_times)[mask].tolist() + spike_times[cell_type].append(cell_spike_times) + return spike_times + @property def spike_gids(self): return self._spike_gids diff --git a/hnn_core/tests/test_cell_response.py b/hnn_core/tests/test_cell_response.py index 22a77f1af..32eb61ea4 100644 --- a/hnn_core/tests/test_cell_response.py +++ b/hnn_core/tests/test_cell_response.py @@ -24,6 +24,11 @@ def test_cell_response(tmp_path): spike_gids=spike_gids, spike_types=spike_types, times=sim_times) + + assert set(cell_response.cell_types) == set(gid_ranges.keys()) + assert cell_response.spike_times_by_type['L2_basket'] == [[7.89], []] + assert cell_response.spike_times_by_type['L5_pyramidal'] == [[], [4.2812]] + kwargs_hist = dict(alpha=0.25) fig = cell_response.plot_spikes_hist(show=False, **kwargs_hist) assert all(patch.get_alpha() == kwargs_hist['alpha']