Skip to content

Commit

Permalink
SNN reservoir and plotting functions (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
fmilisav authored Nov 24, 2023
1 parent c53e241 commit 41cf888
Show file tree
Hide file tree
Showing 2 changed files with 595 additions and 0 deletions.
86 changes: 86 additions & 0 deletions conn2res/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,3 +734,89 @@ def plot_phase_space(

# reset rc defaults
mpl.rcdefaults()

def plot_spike_raster(tspike, x1, x2, title = "Spike Raster"):
"""
Plot a spike raster plot.
Parameters
----------
tspike : (N_spikes, 2) numpy.ndarray
spike times (in s)
N_spikes: number of spikes
tspike[:, 0]: spike neuronal indices
tspike[:, 1]: spike times (in s)
x1 : float
start time (in s)
x2 : float
end time (in s)
title : str, optional
title of the plot, by default "Spike Raster"
"""

nneurons = int(np.max(np.unique(tspike[:, 0])) + 1)

plt.figure(figsize=(max((x2 - x1)/0.1*10, 10), max(.02*nneurons, 1)))
plt.title(title)
plt.xlabel("Time (ms)")
plt.ylabel("Neuron")

spike_times = []
for neuron in range(nneurons):
idx = np.where(tspike[:, 0] == neuron)[0]
spike_times.append(tspike[:, 1][idx]*1000)

for neuron in range(nneurons):
spike_train = spike_times[neuron]
plt.scatter(spike_train, [neuron] * len(spike_train), marker='|', color='black')

plt.xlim(x1*1000, x2*1000)
plt.ylim(-0.5, nneurons - 0.5)
plt.gca().invert_yaxis()
plt.grid(True, linestyle='--', alpha = 0.7)
plt.show()

def plot_membrane_voltages(membrane_voltages, x1, x2, neuron_idx = None,
dt = 0.05, title="Membrane Voltages"):
"""
Plot the membrane voltages of the neurons.
Parameters
----------
membrane_voltages : (nt, N) numpy.ndarray
membrane voltage tracings (mV)
nt: number of time steps
N: number of nodes
x1 : float
start timestep
x2 : float
end timestep
neuron_idx : numpy.ndarray, optional
indices of neurons to plot
Default: None
dt : float, optional
sampling rate (in s)
Default: 0.05
title : str, optional
title of the plot, by default "Membrane Voltages"
"""

if neuron_idx is None:
neuron_idx = np.arange(membrane_voltages.shape[1])
nneurons = len(neuron_idx)

time_arr = np.arange(x1, x2)*dt

fig, axes = plt.subplots(nneurons, 1, figsize=(max(10, (x2 - x1)/1000*10), nneurons), sharex=True, sharey=False)
if nneurons == 1:
axes = [axes]
for idx, ax in enumerate(axes):
ax.plot(time_arr, membrane_voltages[x1:x2, neuron_idx[idx]], c = 'k')
ax.set_ylabel(f"Neuron {neuron_idx[idx]}")
ax.grid(True)

fig.suptitle(title, y = 1.03)
fig.supylabel('Voltage (mV)', x = -0.03)
fig.supxlabel('Time (ms)', y = -0.03)
fig.tight_layout()
plt.show()
Loading

0 comments on commit 41cf888

Please sign in to comment.