Skip to content

Commit

Permalink
Refactor raster plot and remove GUI docs from conf.py
Browse files Browse the repository at this point in the history
  • Loading branch information
samadpls committed May 30, 2024
1 parent aa773ec commit a331a23
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 16 deletions.
1 change: 0 additions & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@
'navbar_sidebarrel': False,
'navbar_links': [
("Examples", "auto_examples/index"),
("GUI", "gui/index"),
("API", "api"),
("Glossary", "glossary"),
("What's new", "whats_new"),
Expand Down
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ Changelog
- Added :class:`~hnn_core/viz/NetworkPlotter` to visualize and animate network simulations,
by `Nick Tolley`_ in :gh:`649`.

- Updated :func:`~hnn_core/viz/plot_spikes_raster` logic to always include all
cell types in raster plots, by `Abdul Samad Siddiqui`_ in :gh:`754`.

Bug
~~~
- Fix inconsistent connection mapping from drive gids to cell gids, by
Expand Down
3 changes: 2 additions & 1 deletion hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def test_dipole_visualization(setup_net):
with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_raster(trial_idx='blah', show=False)
net.cell_response.plot_spikes_raster(trial_idx=0, show=False)
net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False)
fig = net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False)
assert len(fig.axes[0].collections) > 0, "No data plotted in raster plot"

with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_hist(trial_idx='blah')
Expand Down
26 changes: 12 additions & 14 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,17 +514,12 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
_validate_type(trial_idx, list, 'trial_idx', 'int, list of int')

# Extract desired trials
if len(cell_response._spike_times[0]) > 0:
spike_times = np.concatenate(
np.array(cell_response._spike_times, dtype=object)[trial_idx])
spike_types = np.concatenate(
np.array(cell_response._spike_types, dtype=object)[trial_idx])
spike_gids = np.concatenate(
np.array(cell_response._spike_gids, dtype=object)[trial_idx])
else:
spike_times = np.array([])
spike_types = np.array([])
spike_gids = np.array([])
spike_times = np.concatenate(
np.array(cell_response._spike_times, dtype=object)[trial_idx])
spike_types = np.concatenate(
np.array(cell_response._spike_types, dtype=object)[trial_idx])
spike_gids = np.concatenate(
np.array(cell_response._spike_gids, dtype=object)[trial_idx])

cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal']
cell_type_colors = {'L5_pyramidal': 'r', 'L5_basket': 'b',
Expand All @@ -533,22 +528,25 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
if ax is None:
_, ax = plt.subplots(1, 1, constrained_layout=True)

ypos = 0
events = []
for cell_type in cell_types:
cell_type_gids = np.unique(spike_gids[spike_types == cell_type])
cell_type_times, cell_type_ypos = [], []
for gid in cell_type_gids:
gid_time = spike_times[spike_gids == gid]
cell_type_times.append(gid_time)
cell_type_ypos.append(ypos)
ypos = ypos - 1
cell_type_ypos.append(-gid)

if cell_type_times:
events.append(
ax.eventplot(cell_type_times, lineoffsets=cell_type_ypos,
color=cell_type_colors[cell_type],
label=cell_type, linelengths=5))
else:
events.append(
ax.eventplot([-1], lineoffsets=[-1],
color=cell_type_colors[cell_type],
label=cell_type, linelengths=5))

ax.legend(handles=[e[0] for e in events], loc=1)
ax.set_facecolor('k')
Expand Down

0 comments on commit a331a23

Please sign in to comment.