Skip to content

Commit

Permalink
Add plot_spike_raster changes in Chanalog
Browse files Browse the repository at this point in the history
  • Loading branch information
samadpls committed May 13, 2024
1 parent 5e04b17 commit d804451
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ Changelog
- Added gui widgets to save simulation as csv and updated the file upload to support csv data,
by `Camilo Diaz`_ in :gh:`753`

- Updated :func:`~hnn_core/viz/plot_spikes_raster` logic to always include all
cell types in raster plots, by `Abdul Samad Siddiqui`_ in :gh:`754`.

Bug
~~~
- Fix inconsistent connection mapping from drive gids to cell gids, by
Expand Down
3 changes: 2 additions & 1 deletion hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def test_dipole_visualization():
with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_raster(trial_idx='blah', show=False)
net.cell_response.plot_spikes_raster(trial_idx=0, show=False)
net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False)
fig = net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False)
assert len(fig.axes[0].collections) > 0, "No data plotted in raster plot"

with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_hist(trial_idx='blah')
Expand Down
22 changes: 9 additions & 13 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,17 +514,13 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
_validate_type(trial_idx, list, 'trial_idx', 'int, list of int')

# Extract desired trials
if len(cell_response._spike_times[0]) > 0:
spike_times = np.concatenate(
np.array(cell_response._spike_times, dtype=object)[trial_idx])
spike_types = np.concatenate(
np.array(cell_response._spike_types, dtype=object)[trial_idx])
spike_gids = np.concatenate(
np.array(cell_response._spike_gids, dtype=object)[trial_idx])
else:
spike_times = np.array([])
spike_types = np.array([])
spike_gids = np.array([])
spike_times = []
spike_types = []
spike_gids = []
for trial in trial_idx:
spike_times.append(cell_response.spike_times[trial])
spike_types.append(cell_response.spike_types[trial])
spike_gids.append(cell_response.spike_gids[trial])

cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal']
cell_type_colors = {'L5_pyramidal': 'r', 'L5_basket': 'b',
Expand All @@ -539,10 +535,10 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
cell_type_gids = np.unique(spike_gids[spike_types == cell_type])
cell_type_times, cell_type_ypos = [], []
for gid in cell_type_gids:
gid_time = spike_times[spike_gids == gid]
gid_time = np.array(spike_times)[np.array(spike_gids) == gid]
cell_type_times.append(gid_time)
cell_type_ypos.append(ypos)
ypos = ypos - 1
ypos -= 1

if cell_type_times:
events.append(
Expand Down

0 comments on commit d804451

Please sign in to comment.