Skip to content

Commit

Permalink
refactor: updated to use cell_types property of CellResponse and upda…
Browse files Browse the repository at this point in the history
…ted variable names
  • Loading branch information
gtdang committed Oct 25, 2024
1 parent 4ad591e commit 4d3a706
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,50 +539,48 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
if trial_idx is None:
trial_idx = list(range(n_trials))

# Get spike types
spike_types_data = np.concatenate(np.array(cell_response.spike_types,
dtype=object))
spike_types = np.unique(spike_types_data).tolist()
# Get spike types from cell response
unique_spike_types = cell_response.cell_types

# validate trial argument
if isinstance(trial_idx, int):
trial_idx = [trial_idx]
_validate_type(trial_idx, list, 'trial_idx', 'int, list of int')

# validate cell types
default_cell_types = ['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal']
if cell_types:
_validate_type(cell_types, list, 'cell_types', 'list of str')
if not set(cell_types).issubset(set(spike_types)):
if not set(cell_types).issubset(set(unique_spike_types)):
raise ValueError("Invalid cell types provided. "
f"Must be of set {spike_types}. "
f"Must be of set {unique_spike_types}. "
f"Got {cell_types}")
default_cell_types = cell_types
else:
# Use default cell types
cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal']

# Set default colors
default_colors = (plt.rcParams['axes.prop_cycle']
.by_key()['color'][:len(default_cell_types)])
.by_key()['color'][:len(cell_types)])
cell_colors = {cell: color
for cell, color in zip(default_cell_types, default_colors)}
for cell, color in zip(cell_types, default_colors)}

# validate colors argument
_validate_type(colors, (list, dict, None), 'color', 'list of str, or dict')
if colors:
if isinstance(colors, list):
if len(colors) != len(default_cell_types):
if len(colors) != len(cell_types):
raise ValueError(
f"Number of colors must be equal to number of "
f"cell types. {len(colors)} colors provided "
f"for {len(default_cell_types)} cell types.")
f"for {len(cell_types)} cell types.")
cell_colors = {cell: color
for cell, color in zip(default_cell_types, colors)}
for cell, color in zip(cell_types, colors)}

if isinstance(colors, dict):
# Check valid cell types
if not set(colors.keys()).issubset(set(spike_types)):
if not set(colors.keys()).issubset(set(unique_spike_types)):
raise ValueError("Invalid cell types provided. "
f"Must be of set {spike_types}. "
f"Must be of set {unique_spike_types}. "
f"Got {colors.keys()}")
cell_colors.update(colors)

Expand Down

0 comments on commit 4d3a706

Please sign in to comment.