Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Spikes raster plot colors #895

Merged
merged 15 commits into from
Nov 22, 2024

Conversation

gtdang
Copy link
Collaborator

@gtdang gtdang commented Sep 20, 2024

Changes

  1. Changed the spike raster colors to have a white background and use the current color cycle's first 4 colors.
  2. Added argument to change the cell colors by either dictionary or list of colors.
  3. Added argument to add custom cell types if the Network has different cell types than the default.
  4. Changed the line lengths to 1 to prevent overlap

With new colors
Screenshot 2024-09-20 at 4 12 30 PM

With new shorter lines.
Screenshot 2024-10-23 at 10 06 41 AM

Question:

  • Should the API allow users to be able to specify colors? Yes
    • With this new implementation the user could technically change the plot colors outside of the hnn-core API if they changed the Matplotlib color cycle and default background with the matplotlib API.

closes #888

@gtdang gtdang changed the title Spikes raster plot colors [WIP] Spikes raster plot colors Sep 20, 2024
@gtdang
Copy link
Collaborator Author

gtdang commented Sep 25, 2024

Hi @ntolley. I was looking into writing a test for making sure the colors are working expected for this plot. I was looking into the existing tests for the plotter in test_viz.py. It is shown below with the relevant tests for the plotter on lines 208-212. When I plot the spike event plot for the network specified I noticed that the plot was empty. Is this what we would expect for the network specified? The network only has 2 rhythmic drives added (lines 136-146).

def test_dipole_visualization(setup_net):
"""Test dipole visualisations."""
net = setup_net
# Test plotting of simulations with no spiking
dpls = simulate_dipole(net, tstop=100., n_trials=1)
net.cell_response.plot_spikes_raster()
net.cell_response.plot_spikes_hist()
weights_ampa = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5}
syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}
net.add_bursty_drive(
'beta_prox', tstart=0., burst_rate=25, burst_std=5,
numspikes=1, spike_isi=0, n_drive_cells=11, location='proximal',
weights_ampa=weights_ampa, synaptic_delays=syn_delays,
event_seed=14)
net.add_bursty_drive(
'beta_dist', tstart=0., burst_rate=25, burst_std=5,
numspikes=1, spike_isi=0, n_drive_cells=11, location='distal',
weights_ampa=weights_ampa, synaptic_delays=syn_delays,
event_seed=14)
dpls = simulate_dipole(net, tstop=100., n_trials=2, record_vsec='all')
fig = dpls[0].plot() # plot the first dipole alone
axes = fig.get_axes()[0]
dpls[0].copy().smooth(window_len=10).plot(ax=axes) # add smoothed versions
dpls[0].copy().savgol_filter(h_freq=30).plot(ax=axes) # on top
# test decimation options
plot_dipole(dpls[0], decim=2, show=False)
for dec in [-1, [2, 2.]]:
with pytest.raises(ValueError,
match='each decimation factor must be a positive'):
plot_dipole(dpls[0], decim=dec, show=False)
# test plotting multiple dipoles as overlay
fig = plot_dipole(dpls, show=False)
# test plotting multiple dipoles with average
fig = plot_dipole(dpls, average=True, show=False)
plt.close('all')
# test plotting dipoles with multiple layers
fig, ax = plt.subplots()
fig = plot_dipole(dpls, show=False, ax=[ax], layer=['L2'])
fig = plot_dipole(dpls, show=False, layer=['L2', 'L5', 'agg'])
fig, axes = plt.subplots(nrows=3, ncols=1)
fig = plot_dipole(dpls, show=False, ax=axes, layer=['L2', 'L5', 'agg'])
fig, axes = plt.subplots(nrows=3, ncols=1)
fig = plot_dipole(dpls,
show=False,
ax=[axes[0], axes[1], axes[2]],
layer=['L2', 'L5', 'agg'])
plt.close('all')
with pytest.raises(AssertionError,
match="ax and layer should have the same size"):
fig, axes = plt.subplots(nrows=3, ncols=1)
fig = plot_dipole(dpls, show=False, ax=axes, layer=['L2', 'L5'])
# multiple TFRs get averaged
fig = plot_tfr_morlet(dpls, freqs=np.arange(23, 26, 1.), n_cycles=3,
show=False)
with pytest.raises(RuntimeError,
match="All dipoles must be scaled equally!"):
plot_dipole([dpls[0].copy().scale(10), dpls[1].copy().scale(20)])
with pytest.raises(RuntimeError,
match="All dipoles must be scaled equally!"):
plot_psd([dpls[0].copy().scale(10), dpls[1].copy().scale(20)])
with pytest.raises(RuntimeError,
match="All dipoles must be sampled equally!"):
dpl_sfreq = dpls[0].copy()
dpl_sfreq.sfreq /= 10
plot_psd([dpls[0], dpl_sfreq])
# pytest deprecation warning for tmin and tmax
with pytest.deprecated_call():
plot_dipole(dpls[0], show=False, tmin=10, tmax=100)
# test cell response plotting
with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_raster(trial_idx='blah', show=False)
net.cell_response.plot_spikes_raster(trial_idx=0, show=False)
fig = net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False)
assert len(fig.axes[0].collections) > 0, "No data plotted in raster plot"
with pytest.raises(TypeError, match="trial_idx must be an instance of"):
net.cell_response.plot_spikes_hist(trial_idx='blah')
net.cell_response.plot_spikes_hist(trial_idx=0, show=False)
net.cell_response.plot_spikes_hist(trial_idx=[0, 1], show=False)
net.cell_response.plot_spikes_hist(color='r')
net.cell_response.plot_spikes_hist(color=['C0', 'C1'])
net.cell_response.plot_spikes_hist(color={'beta_prox': 'r',
'beta_dist': 'g'})
net.cell_response.plot_spikes_hist(
spike_types={'group1': ['beta_prox', 'beta_dist']},
color={'group1': 'r'})
net.cell_response.plot_spikes_hist(
spike_types={'group1': ['beta']}, color={'group1': 'r'})
with pytest.raises(TypeError, match="color must be an instance of"):
net.cell_response.plot_spikes_hist(color=123)
with pytest.raises(ValueError):
net.cell_response.plot_spikes_hist(color='z')
with pytest.raises(ValueError):
net.cell_response.plot_spikes_hist(color={'beta_prox': 'z',
'beta_dist': 'g'})
with pytest.raises(TypeError, match="Dictionary values of color must"):
net.cell_response.plot_spikes_hist(color={'beta_prox': 123,
'beta_dist': 'g'})
with pytest.raises(ValueError, match="'beta_dist' must be"):
net.cell_response.plot_spikes_hist(color={'beta_prox': 'r'})
plt.close('all')

Here's a code snip to recreate:

from hnn_core import simulate_dipole, jones_2009_model, read_params
from pathlib import Path

hnn_core_root = Path.cwd().parents[0]
params_fname = Path(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
net = jones_2009_model(params, mesh_shape=(3, 3))

weights_ampa = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5}
syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}

net.add_bursty_drive(
    'beta_prox', tstart=0., burst_rate=25, burst_std=5,
    numspikes=1, spike_isi=0, n_drive_cells=11, location='proximal',
    weights_ampa=weights_ampa, synaptic_delays=syn_delays,
    event_seed=14)

net.add_bursty_drive(
    'beta_dist', tstart=0., burst_rate=25, burst_std=5,
    numspikes=1, spike_isi=0, n_drive_cells=11, location='distal',
    weights_ampa=weights_ampa, synaptic_delays=syn_delays,
    event_seed=14)

dpls = simulate_dipole(net, tstop=100., n_trials=2, record_vsec='all')

fig1 = net.cell_response.plot_spikes_raster()
fig1.show()

@ntolley
Copy link
Contributor

ntolley commented Sep 26, 2024

@gtdang this test was written for the edge case with the drives are too weak to produce spiking activity in the network

There was an earlier bug where the plotting function would throw an error if not spikes occurred, so this test is to make sure that an empty plot is generated (the desired behavior in these simulations)

@ntolley
Copy link
Contributor

ntolley commented Sep 26, 2024

If you want spiking just change the weights to a bigger number like 0.1 or 1.0!

@gtdang gtdang force-pushed the spike-plot-colors branch from 26c9fa8 to 9717ae7 Compare October 11, 2024 19:21
@gtdang gtdang marked this pull request as ready for review October 11, 2024 19:50
@gtdang gtdang changed the title [WIP] Spikes raster plot colors [MRG] Spikes raster plot colors Oct 11, 2024
@gtdang gtdang requested a review from asoplata October 15, 2024 19:16
hnn_core/viz.py Outdated
def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True):
def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
cell_types=['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal'],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad idea to have a default list in a function. You will get funky effects in Python ... default should not be a mutable

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, we're going to refactor this call to be aligned with the plot_spikes_hist implementation so that it can also take a dict of color assignments. I don't think we need to expose the cell_types as an argument... though I wish there was a way to get it dynamically from the network instead of hard-coding the types.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see #916 ... it allows you to dynamically extract the cell types

class TestCellResponsePlotters:

@pytest.fixture(scope='class')
def class_setup_net(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing docstring

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gtdang this comment is not addressed yet, otherwise I can go ahead and merge. Still see some test functions without a one-line docstring

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, do you just want a doc string for these fixtures? Perhaps:
"""Creates a base network from the default json for tests within this class"""
and
"""Adds bursty drives with spikes to the base network for testing visualizations of spikes"""

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep that's good!

events = []
for cell_type in cell_types:
cell_type_gids = np.unique(spike_gids[spike_types == cell_type])
cell_type_times, cell_type_ypos = [], []
color = next(color_iter)

for gid in cell_type_gids:
gid_time = spike_times[spike_gids == gid]
cell_type_times.append(gid_time)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while you are at it, I am wondering if this could be addressed as well. The following line:

cell_type_ypos.append(-gid)

causes cells that spike with neighboring gids to overlap. I have been staring at these raster plots recently and it's very hard to tell how many times the same cell spiked (important to understand the underlying dynamics). Adding a small offset between nearby cells should address that problem

Copy link
Collaborator Author

@gtdang gtdang Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a look into this. The overlap is due to both the y-position as you've identified and the line lengths defined during the plot function call on line 568. The y-position is the center of each line and the length is how much it extends from that center point (+-2.5 each way for a value of 5).

hnn-core/hnn_core/viz.py

Lines 559 to 568 in 27c6fc1

for gid in cell_type_gids:
gid_time = spike_times[spike_gids == gid]
cell_type_times.append(gid_time)
cell_type_ypos.append(-gid)
if cell_type_times:
events.append(
ax.eventplot(cell_type_times, lineoffsets=cell_type_ypos,
color=cell_type_colors[cell_type],
label=cell_type, linelengths=5))

A simple solution is to change the line length to 1. However the lines will look more like dots with this change.
Screenshot 2024-10-23 at 10 06 41 AM

Another solution would be to analyze the cell times and gids, and apply a larger y-offset if they are within an X and Y bounding box of one another.

Let me know what you think.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with the spikes looking like dots ... it's just a function of the number of cells in our network. Did a quick google image search of "spike raster plot" and the plots do look dotted when there are more neurons. I guess the y-offset = -gid is helpful since it allows you to identify the cell, so maybe best not to touch that. @ntolley any opinion here?

Copy link
Collaborator

@asoplata asoplata Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yesterday, both @ntolley and I had a brief look at the "dot" version and, IIRC, we both agreed it looks good. @ntolley LMK if I'm remembering wrong

@gtdang gtdang force-pushed the spike-plot-colors branch from be192f3 to 4d3a706 Compare October 25, 2024 18:15
@asoplata asoplata modified the milestones: 0.4, 0.5 Nov 15, 2024
Copy link
Collaborator

@asoplata asoplata left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@asoplata asoplata merged commit cbc86d2 into jonescompneurolab:master Nov 22, 2024
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

API: Spiking plot
4 participants