diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 4bc4b51e3..8406a8427 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -539,10 +539,8 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True, if trial_idx is None: trial_idx = list(range(n_trials)) - # Get spike types - spike_types_data = np.concatenate(np.array(cell_response.spike_types, - dtype=object)) - spike_types = np.unique(spike_types_data).tolist() + # Get spike types from cell response + unique_spike_types = cell_response.cell_types # validate trial argument if isinstance(trial_idx, int): @@ -550,39 +548,39 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True, _validate_type(trial_idx, list, 'trial_idx', 'int, list of int') # validate cell types - default_cell_types = ['L2_basket', 'L2_pyramidal', - 'L5_basket', 'L5_pyramidal'] if cell_types: _validate_type(cell_types, list, 'cell_types', 'list of str') - if not set(cell_types).issubset(set(spike_types)): + if not set(cell_types).issubset(set(unique_spike_types)): raise ValueError("Invalid cell types provided. " - f"Must be of set {spike_types}. " + f"Must be of set {unique_spike_types}. " f"Got {cell_types}") - default_cell_types = cell_types + else: + # Use default cell types + cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal'] # Set default colors default_colors = (plt.rcParams['axes.prop_cycle'] - .by_key()['color'][:len(default_cell_types)]) + .by_key()['color'][:len(cell_types)]) cell_colors = {cell: color - for cell, color in zip(default_cell_types, default_colors)} + for cell, color in zip(cell_types, default_colors)} # validate colors argument _validate_type(colors, (list, dict, None), 'color', 'list of str, or dict') if colors: if isinstance(colors, list): - if len(colors) != len(default_cell_types): + if len(colors) != len(cell_types): raise ValueError( f"Number of colors must be equal to number of " f"cell types. {len(colors)} colors provided " - f"for {len(default_cell_types)} cell types.") + f"for {len(cell_types)} cell types.") cell_colors = {cell: color - for cell, color in zip(default_cell_types, colors)} + for cell, color in zip(cell_types, colors)} if isinstance(colors, dict): # Check valid cell types - if not set(colors.keys()).issubset(set(spike_types)): + if not set(colors.keys()).issubset(set(unique_spike_types)): raise ValueError("Invalid cell types provided. " - f"Must be of set {spike_types}. " + f"Must be of set {unique_spike_types}. " f"Got {colors.keys()}") cell_colors.update(colors)