diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 6a8b17bfb8..cd6d713a5a 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -58,6 +58,9 @@ Changelog - Added feature to read/write :class:`~hnn_core.Network` configurations to json, by `George Dang`_ and `Rajat Partani`_ in :gh:`757` +- Updated :func:`~hnn_core/viz/plot_spikes_raster` logic to always include all + cell types in raster plots, by `Abdul Samad Siddiqui`_ in :gh:`754`. + Bug ~~~ - Fix inconsistent connection mapping from drive gids to cell gids, by diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index b26754c5a2..b5c8df4f20 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -184,7 +184,8 @@ def test_dipole_visualization(): with pytest.raises(TypeError, match="trial_idx must be an instance of"): net.cell_response.plot_spikes_raster(trial_idx='blah', show=False) net.cell_response.plot_spikes_raster(trial_idx=0, show=False) - net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False) + fig = net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False) + assert len(fig.axes[0].collections) > 0, "No data plotted in raster plot" with pytest.raises(TypeError, match="trial_idx must be an instance of"): net.cell_response.plot_spikes_hist(trial_idx='blah') diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 16e3e8c548..ac4a83c04e 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -514,22 +514,19 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): _validate_type(trial_idx, list, 'trial_idx', 'int, list of int') # Extract desired trials - if len(cell_response._spike_times[0]) > 0: - spike_times = np.concatenate( - np.array(cell_response._spike_times, dtype=object)[trial_idx]) - spike_types = np.concatenate( - np.array(cell_response._spike_types, dtype=object)[trial_idx]) - spike_gids = np.concatenate( - np.array(cell_response._spike_gids, dtype=object)[trial_idx]) - else: - spike_times = np.array([]) - spike_types = np.array([]) - spike_gids = np.array([]) + spike_times = np.concatenate( + np.array(cell_response._spike_times, dtype=object)[trial_idx]) + spike_types = np.concatenate( + np.array(cell_response._spike_types, dtype=object)[trial_idx]) + spike_gids = np.concatenate( + np.array(cell_response._spike_gids, dtype=object)[trial_idx]) cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal'] cell_type_colors = {'L5_pyramidal': 'r', 'L5_basket': 'b', 'L2_pyramidal': 'g', 'L2_basket': 'w'} + total_neurons = len(np.unique(np.concatenate(cell_response._spike_gids))) + if ax is None: _, ax = plt.subplots(1, 1, constrained_layout=True) @@ -542,19 +539,25 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): gid_time = spike_times[spike_gids == gid] cell_type_times.append(gid_time) cell_type_ypos.append(ypos) - ypos = ypos - 1 + ypos -= 1 if cell_type_times: events.append( ax.eventplot(cell_type_times, lineoffsets=cell_type_ypos, color=cell_type_colors[cell_type], label=cell_type, linelengths=5)) + else: + events.append( + ax.eventplot([-1], lineoffsets=[-1], + color=cell_type_colors[cell_type], + label=cell_type, linelengths=5)) ax.legend(handles=[e[0] for e in events], loc=1) ax.set_facecolor('k') ax.set_xlabel('Time (ms)') ax.get_yaxis().set_visible(False) ax.set_xlim(left=0) + ax.set_ylim(-total_neurons, 0) plt_show(show) return ax.get_figure()