Skip to content

Commit

Permalink
Refactor logic to concatenate spike gids for total neuron count
Browse files Browse the repository at this point in the history
  • Loading branch information
samadpls committed May 16, 2024
1 parent 5b57580 commit f064b35
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ Changelog
- Added feature to read/write :class:`~hnn_core.Network` configurations to
json, by `George Dang`_ and `Rajat Partani`_ in :gh:`757`

- Updated :func:`~hnn_core/viz/plot_spikes_raster` logic to always include all
cell types in raster plots, by `Abdul Samad Siddiqui`_ in :gh:`754`.

Bug
~~~
- Fix inconsistent connection mapping from drive gids to cell gids, by
Expand Down
3 changes: 2 additions & 1 deletion hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def test_dipole_visualization():
with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_raster(trial_idx='blah', show=False)
net.cell_response.plot_spikes_raster(trial_idx=0, show=False)
net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False)
fig = net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False)
assert len(fig.axes[0].collections) > 0, "No data plotted in raster plot"

with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_hist(trial_idx='blah')
Expand Down
27 changes: 15 additions & 12 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,22 +514,19 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
_validate_type(trial_idx, list, 'trial_idx', 'int, list of int')

# Extract desired trials
if len(cell_response._spike_times[0]) > 0:
spike_times = np.concatenate(
np.array(cell_response._spike_times, dtype=object)[trial_idx])
spike_types = np.concatenate(
np.array(cell_response._spike_types, dtype=object)[trial_idx])
spike_gids = np.concatenate(
np.array(cell_response._spike_gids, dtype=object)[trial_idx])
else:
spike_times = np.array([])
spike_types = np.array([])
spike_gids = np.array([])
spike_times = np.concatenate(
np.array(cell_response._spike_times, dtype=object)[trial_idx])
spike_types = np.concatenate(
np.array(cell_response._spike_types, dtype=object)[trial_idx])
spike_gids = np.concatenate(
np.array(cell_response._spike_gids, dtype=object)[trial_idx])

cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal']
cell_type_colors = {'L5_pyramidal': 'r', 'L5_basket': 'b',
'L2_pyramidal': 'g', 'L2_basket': 'w'}

total_neurons = len(np.unique(np.concatenate(cell_response._spike_gids)))

if ax is None:
_, ax = plt.subplots(1, 1, constrained_layout=True)

Expand All @@ -542,19 +539,25 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
gid_time = spike_times[spike_gids == gid]
cell_type_times.append(gid_time)
cell_type_ypos.append(ypos)
ypos = ypos - 1
ypos -= 1

if cell_type_times:
events.append(
ax.eventplot(cell_type_times, lineoffsets=cell_type_ypos,
color=cell_type_colors[cell_type],
label=cell_type, linelengths=5))
else:
events.append(
ax.eventplot([-1], lineoffsets=[-1],
color=cell_type_colors[cell_type],
label=cell_type, linelengths=5))

ax.legend(handles=[e[0] for e in events], loc=1)
ax.set_facecolor('k')
ax.set_xlabel('Time (ms)')
ax.get_yaxis().set_visible(False)
ax.set_xlim(left=0)
ax.set_ylim(-total_neurons, 0)

plt_show(show)
return ax.get_figure()
Expand Down

0 comments on commit f064b35

Please sign in to comment.