diff --git a/doc/whats_new.rst b/doc/whats_new.rst index a61a3e6abb..21c0571d67 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -58,6 +58,9 @@ Changelog - Added gui widgets to save simulation as csv and updated the file upload to support csv data, by `Camilo Diaz`_ in :gh:`753` +- Updated :func:`~hnn_core/viz/plot_spikes_raster` logic to always include all + cell types in raster plots, by `Abdul Samad Siddiqui`_ in :gh:`754`. + Bug ~~~ - Fix inconsistent connection mapping from drive gids to cell gids, by diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index b26754c5a2..b5c8df4f20 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -184,7 +184,8 @@ def test_dipole_visualization(): with pytest.raises(TypeError, match="trial_idx must be an instance of"): net.cell_response.plot_spikes_raster(trial_idx='blah', show=False) net.cell_response.plot_spikes_raster(trial_idx=0, show=False) - net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False) + fig = net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False) + assert len(fig.axes[0].collections) > 0, "No data plotted in raster plot" with pytest.raises(TypeError, match="trial_idx must be an instance of"): net.cell_response.plot_spikes_hist(trial_idx='blah') diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 16e3e8c548..f9da47e8ea 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -514,17 +514,13 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): _validate_type(trial_idx, list, 'trial_idx', 'int, list of int') # Extract desired trials - if len(cell_response._spike_times[0]) > 0: - spike_times = np.concatenate( - np.array(cell_response._spike_times, dtype=object)[trial_idx]) - spike_types = np.concatenate( - np.array(cell_response._spike_types, dtype=object)[trial_idx]) - spike_gids = np.concatenate( - np.array(cell_response._spike_gids, dtype=object)[trial_idx]) - else: - spike_times = np.array([]) - spike_types = np.array([]) - spike_gids = np.array([]) + spike_times = [] + spike_types = [] + spike_gids = [] + for trial in trial_idx: + spike_times.append(cell_response.spike_times[trial]) + spike_types.append(cell_response.spike_types[trial]) + spike_gids.append(cell_response.spike_gids[trial]) cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal'] cell_type_colors = {'L5_pyramidal': 'r', 'L5_basket': 'b', @@ -539,10 +535,10 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): cell_type_gids = np.unique(spike_gids[spike_types == cell_type]) cell_type_times, cell_type_ypos = [], [] for gid in cell_type_gids: - gid_time = spike_times[spike_gids == gid] + gid_time = np.array(spike_times)[np.array(spike_gids) == gid] cell_type_times.append(gid_time) cell_type_ypos.append(ypos) - ypos = ypos - 1 + ypos -= 1 if cell_type_times: events.append(