Skip to content

Commit

Permalink
[MRG] ENH: add new property to access spike times by cell type (jones…
Browse files Browse the repository at this point in the history
…compneurolab#916)

* ENH: add new property to access spike times by cell type

* flake8

* DOC: update whats new
  • Loading branch information
jasmainak authored Oct 25, 2024
1 parent b616356 commit d74080a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Bug

API
~~~
- Add :func:`~hnn_core.CellResponse.spike_times_by_type` to get cell spiking times
organized by cell type, by `Mainak Jas`_ in :gh:`916`.

.. _0.4:

Expand Down
20 changes: 20 additions & 0 deletions hnn_core/cell_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,26 @@ def __eq__(self, other):
def spike_times(self):
return self._spike_times

@property
def cell_types(self):
"""Get unique cell types."""
spike_types_data = np.concatenate(np.array(self.spike_types,
dtype=object))
return np.unique(spike_types_data).tolist()

@property
def spike_times_by_type(self):
"""Get a dictionary of spike times by cell type"""
spike_times = dict()
for cell_type in self.cell_types:
spike_times[cell_type] = list()
for trial_spike_times, trial_spike_types in zip(self.spike_times,
self.spike_types):
mask = np.isin(trial_spike_types, cell_type)
cell_spike_times = np.array(trial_spike_times)[mask].tolist()
spike_times[cell_type].append(cell_spike_times)
return spike_times

@property
def spike_gids(self):
return self._spike_gids
Expand Down
5 changes: 5 additions & 0 deletions hnn_core/tests/test_cell_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def test_cell_response(tmp_path):
spike_gids=spike_gids,
spike_types=spike_types,
times=sim_times)

assert set(cell_response.cell_types) == set(gid_ranges.keys())
assert cell_response.spike_times_by_type['L2_basket'] == [[7.89], []]
assert cell_response.spike_times_by_type['L5_pyramidal'] == [[], [4.2812]]

kwargs_hist = dict(alpha=0.25)
fig = cell_response.plot_spikes_hist(show=False, **kwargs_hist)
assert all(patch.get_alpha() == kwargs_hist['alpha']
Expand Down

0 comments on commit d74080a

Please sign in to comment.