diff --git a/doc/conf.py b/doc/conf.py index 98758abf56..c53dc1046d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -115,7 +115,6 @@ 'navbar_sidebarrel': False, 'navbar_links': [ ("Examples", "auto_examples/index"), - ("GUI", "gui/index"), ("API", "api"), ("Glossary", "glossary"), ("What's new", "whats_new"), diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 62c0fb3e9c..1b3fcaa088 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -61,6 +61,9 @@ Changelog - Added :class:`~hnn_core/viz/NetworkPlotter` to visualize and animate network simulations, by `Nick Tolley`_ in :gh:`649`. +- Updated :func:`~hnn_core/viz/plot_spikes_raster` logic to always include all + cell types in raster plots, by `Abdul Samad Siddiqui`_ in :gh:`754`. + Bug ~~~ - Fix inconsistent connection mapping from drive gids to cell gids, by diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index bb4020d507..8546c9e6b8 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -198,7 +198,8 @@ def test_dipole_visualization(setup_net): with pytest.raises(TypeError, match="trial_idx must be an instance of"): net.cell_response.plot_spikes_raster(trial_idx='blah', show=False) net.cell_response.plot_spikes_raster(trial_idx=0, show=False) - net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False) + fig = net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False) + assert len(fig.axes[0].collections) > 0, "No data plotted in raster plot" with pytest.raises(TypeError, match="trial_idx must be an instance of"): net.cell_response.plot_spikes_hist(trial_idx='blah') diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 066cc74e1d..f15e9fd418 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -514,17 +514,12 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): _validate_type(trial_idx, list, 'trial_idx', 'int, list of int') # Extract desired trials - if len(cell_response._spike_times[0]) > 0: - spike_times = np.concatenate( - np.array(cell_response._spike_times, dtype=object)[trial_idx]) - spike_types = np.concatenate( - np.array(cell_response._spike_types, dtype=object)[trial_idx]) - spike_gids = np.concatenate( - np.array(cell_response._spike_gids, dtype=object)[trial_idx]) - else: - spike_times = np.array([]) - spike_types = np.array([]) - spike_gids = np.array([]) + spike_times = np.concatenate( + np.array(cell_response._spike_times, dtype=object)[trial_idx]) + spike_types = np.concatenate( + np.array(cell_response._spike_types, dtype=object)[trial_idx]) + spike_gids = np.concatenate( + np.array(cell_response._spike_gids, dtype=object)[trial_idx]) cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal'] cell_type_colors = {'L5_pyramidal': 'r', 'L5_basket': 'b', @@ -533,7 +528,6 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): if ax is None: _, ax = plt.subplots(1, 1, constrained_layout=True) - ypos = 0 events = [] for cell_type in cell_types: cell_type_gids = np.unique(spike_gids[spike_types == cell_type]) @@ -541,14 +535,18 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): for gid in cell_type_gids: gid_time = spike_times[spike_gids == gid] cell_type_times.append(gid_time) - cell_type_ypos.append(ypos) - ypos = ypos - 1 + cell_type_ypos.append(-gid) if cell_type_times: events.append( ax.eventplot(cell_type_times, lineoffsets=cell_type_ypos, color=cell_type_colors[cell_type], label=cell_type, linelengths=5)) + else: + events.append( + ax.eventplot([-1], lineoffsets=[-1], + color=cell_type_colors[cell_type], + label=cell_type, linelengths=5)) ax.legend(handles=[e[0] for e in events], loc=1) ax.set_facecolor('k')