From 41cf888e2e7e437d1f4de2c1cb00ead5a5939e11 Mon Sep 17 00:00:00 2001 From: fmilisav <74116444+fmilisav@users.noreply.github.com> Date: Fri, 24 Nov 2023 14:39:38 -0500 Subject: [PATCH] SNN reservoir and plotting functions (#37) --- conn2res/plotting.py | 86 +++++++ conn2res/reservoir.py | 509 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 595 insertions(+) diff --git a/conn2res/plotting.py b/conn2res/plotting.py index 025c82d..0ad4ba7 100755 --- a/conn2res/plotting.py +++ b/conn2res/plotting.py @@ -734,3 +734,89 @@ def plot_phase_space( # reset rc defaults mpl.rcdefaults() + +def plot_spike_raster(tspike, x1, x2, title = "Spike Raster"): + """ + Plot a spike raster plot. + + Parameters + ---------- + tspike : (N_spikes, 2) numpy.ndarray + spike times (in s) + N_spikes: number of spikes + tspike[:, 0]: spike neuronal indices + tspike[:, 1]: spike times (in s) + x1 : float + start time (in s) + x2 : float + end time (in s) + title : str, optional + title of the plot, by default "Spike Raster" + """ + + nneurons = int(np.max(np.unique(tspike[:, 0])) + 1) + + plt.figure(figsize=(max((x2 - x1)/0.1*10, 10), max(.02*nneurons, 1))) + plt.title(title) + plt.xlabel("Time (ms)") + plt.ylabel("Neuron") + + spike_times = [] + for neuron in range(nneurons): + idx = np.where(tspike[:, 0] == neuron)[0] + spike_times.append(tspike[:, 1][idx]*1000) + + for neuron in range(nneurons): + spike_train = spike_times[neuron] + plt.scatter(spike_train, [neuron] * len(spike_train), marker='|', color='black') + + plt.xlim(x1*1000, x2*1000) + plt.ylim(-0.5, nneurons - 0.5) + plt.gca().invert_yaxis() + plt.grid(True, linestyle='--', alpha = 0.7) + plt.show() + +def plot_membrane_voltages(membrane_voltages, x1, x2, neuron_idx = None, + dt = 0.05, title="Membrane Voltages"): + """ + Plot the membrane voltages of the neurons. + + Parameters + ---------- + membrane_voltages : (nt, N) numpy.ndarray + membrane voltage tracings (mV) + nt: number of time steps + N: number of nodes + x1 : float + start timestep + x2 : float + end timestep + neuron_idx : numpy.ndarray, optional + indices of neurons to plot + Default: None + dt : float, optional + sampling rate (in s) + Default: 0.05 + title : str, optional + title of the plot, by default "Membrane Voltages" + """ + + if neuron_idx is None: + neuron_idx = np.arange(membrane_voltages.shape[1]) + nneurons = len(neuron_idx) + + time_arr = np.arange(x1, x2)*dt + + fig, axes = plt.subplots(nneurons, 1, figsize=(max(10, (x2 - x1)/1000*10), nneurons), sharex=True, sharey=False) + if nneurons == 1: + axes = [axes] + for idx, ax in enumerate(axes): + ax.plot(time_arr, membrane_voltages[x1:x2, neuron_idx[idx]], c = 'k') + ax.set_ylabel(f"Neuron {neuron_idx[idx]}") + ax.grid(True) + + fig.suptitle(title, y = 1.03) + fig.supylabel('Voltage (mV)', x = -0.03) + fig.supxlabel('Time (ms)', y = -0.03) + fig.tight_layout() + plt.show() \ No newline at end of file diff --git a/conn2res/reservoir.py b/conn2res/reservoir.py index 596c260..3226867 100755 --- a/conn2res/reservoir.py +++ b/conn2res/reservoir.py @@ -243,6 +243,515 @@ def step(x, thr=0.5, vmin=0, vmax=1): elif function == 'step': return step +class SpikingNeuralNetwork(Reservoir): + """ + Class that represents a Spiking Neural Network + Adapted from Kim et al., 2019 and Nicola & Clopath, 2017 + (https://github.com/rkim35/spikeRNN/blob/master/spiking/LIF_network_fnc.m) + ... + + Attributes + ---------- + w : (N, N) numpy.ndarray + reservoir connectivity matrix (source, target) + N: number of nodes in the network. If w is directed, then rows + (columns) should correspond to source (target) nodes. + _state : same as ext_input + reservoir activation states + n_nodes : int + dimension of the reservoir + inh : (N,) numpy.ndarray + boolean array indicating whether a node is + inhibitory (True) or excitatory (False) + N: number of nodes in the network + exc : (N,) numpy.ndarray + boolean array indicating whether a node is + excitatory (True) or inhibitory (False) + N: number of nodes in the network + som : (N,) numpy.ndarray + boolean array indicating whether a node is + somatostatin-expressing (True) or not (False) + N: number of nodes in the network + dt : float + sampling rate (in s) + T : float + trial duration (in s) + nt : int + number of time steps + td : float or numpy.ndarray + decay time constants of the synaptic filter model (in s) + REC : (nt, N) numpy.ndarray + membrane voltage tracings (mV) + nt: number of time steps + N: number of nodes in the network + Is : (N, nt) numpy.ndarray + external input current + nt: number of time steps + N: number of nodes in the network + IPSCs : (N, nt) numpy.ndarray + post synaptic currents over time + nt: number of time steps + N: number of nodes in the network + spk : (N, nt) numpy.ndarray + spike raster + nt: number of time steps + N: number of nodes in the network + rs : (N, nt/timescale) numpy.ndarray + filtered firing rates over time + nt: number of time steps + N: number of nodes in the network + timescale: number of internal time steps per external time steps + hs : (N, nt) numpy.ndarray + filtered firing rates over time (synaptic input accumulation) + nt: number of time steps + N: number of nodes in the network + tspike : (N_spikes, 2) numpy.ndarray + spike times (in s) + N_spikes: number of spikes + tspike[:, 0]: spike neuronal indices + tspike[:, 1]: spike times (in s) + inh_fr : (N_inh,) numpy.ndarray + average firing rates of inhibitory neurons + N_inh: number of inhibitory neurons + exc_fr : (N_exc,) numpy.ndarray + average firing rates of excitatory neurons + N_exc: number of excitatory neurons + all_fr : (N,) numpy.ndarray + average firing rates of all neurons + N: number of neurons in the network + + Methods + ------- + # TODO + + simulate + + """ + + def __init__(self, *args, inh = 0.2, som = 0., apply_Dale = True, **kwargs): + """ + Constructor class for Spiking Neural Networks + + Parameters + ---------- + w: (N, N) numpy.ndarray + Reservoir connectivity matrix (source, target) + N: number of nodes in the network. If w is directed, + then rows (columns) should correspond to source (target) nodes. + inh: float or (N,) numpy.ndarray, optional + If float, inh should be in the range [0, 1] and + indicates the proportion of inhibitory neurons in the network. + If numpy.ndarray of shape (N,), then inh is a + boolean array indicating whether a node is + inhibitory (True) or excitatory (False). + This parameter is used to apply Dale's principle, constraining + the connectivity matrix such that a neuron can only be either + excitatory or inhibitory, but not both. + Default: 0.2 + som: float or (N,) numpy.ndarray, optional + If float, som inh should be in the range [0, 1] and + indicates the proportion of somatostatin-expressing interneurons + in the network. + If numpy.ndarray of shape (N,), then som is a + boolean array indicating whether a node is + somatostatin-expressing (True) or not (False). + Importantly, somatostatin-expressing interneurons are a + subset of the inhibitory neurons. + This parameter is used to constrain a + common cortical microcircuit motif where somatostatin-expressing + inhibitory neurons do not receive inhibitory input. + Default: 0 + """ + + super().__init__(*args, **kwargs) + + if not(isinstance(inh, (float, np.ndarray))): + raise TypeError('inh must be float or numpy.ndarray') + + if not(isinstance(som, (float, np.ndarray))): + raise TypeError('som must be float or numpy.ndarray') + + if isinstance(inh, float) and isinstance(som, np.ndarray): + raise TypeError('inh must be numpy.ndarray if som is numpy.ndarray') + + if isinstance(inh, np.ndarray) and isinstance(som, np.ndarray): + if not np.all(som <= inh): + raise ValueError('som must be a subset of inh') + + if isinstance(inh, np.ndarray): + if not np.issubdtype(inh.dtype, np.bool_): + raise TypeError('inh must be boolean') + exc = ~inh + else: + if not (0 <= inh <= 1): + raise ValueError('inh and exc must be in the range [0, 1]') + inh = np.random.rand(self.n_nodes) < inh + exc = ~inh + + if isinstance(som, np.ndarray): + if not np.issubdtype(som.dtype, np.bool_): + raise TypeError('som must be boolean') + som_idx = np.where(som)[0] + else: + if not (0 <= som <= 1): + raise ValueError('inh and exc must be in the range [0, 1]') + if som > 0: + som_size = int(np.round(som * np.sum(inh))) + som = np.zeros(self.n_nodes, dtype=bool) + som_idx = np.random.choice(np.where(inh)[0], som_size, replace=False) + som[som_idx] = True + else: + som = np.zeros(self.n_nodes, dtype=bool) + + # apply Dale's principle if inhibitory neurons are specified + if np.any(inh): + w = np.abs(self.w) + + # mask matrix imposing Dale's principle + mask = np.eye(self.n_nodes, dtype=np.float32) + mask[np.where(inh)[0], np.where(inh)[0]] = -1 + + # mask matrix imposing wiring motif mediated by + # somatostatin-expressing interneurons + som_mask = np.ones((self.n_nodes, self.n_nodes), dtype=np.float32) + if np.any(som): + for i in som_idx: + som_mask[i, np.where(inh)[0]] = 0 + + self.w = np.multiply(np.matmul(w, mask), som_mask) + + self.inh = inh + self.exc = exc + self.som = som + + def simulate( + self, ext_input, w_in, + downsample = 1, taus = 35, + tau_min = 20, tau_max = 50, sig_param = None, + timescale = 100, dt = 0.05, tref = 2, tm = 10, + vreset = -65, vpeak = -40, tr = 2, + stim_mode = None, stim_dur = None, stim_units = None, stim_val = 0.5, + input_gain=None, ic=None, output_nodes=None, + return_states=True + ): + """ + Simulates the dynamics of a spiking neural network given + an external input signal 'ext_input', + an input connectivity matrix 'w_in', and + synaptic decay time constants 'taus' + + Parameters + ---------- + ext_input : (time, N_inputs) numpy.ndarray + External input signal + N_inputs: number of external input signals + w_in : (N_inputs, N) numpy.ndarray + Input connectivity matrix (source, target) + N_inputs: number of external input signals + N: number of nodes in the network + taus : float or (2,) array_like, optional + Parameter(s) that modify the decay time constants of the + synaptic filter model. + If float, then the same decay time constant is used + for all neurons. + If array_like, then: + taus[0]: minimum + taus[1]: maximum + downsample : int, optional + Downsamples external input signal by a factor of 'downsample'. + Default: 1 + taus : float or (N,) numpy.ndarray, optional + Decay time constants of the synaptic filter model (in ms). + If float, then the same decay time constant tau + is used for all neurons. + If numpy.ndarray of shape (N,), then: + taus[i]: decay time constant of neuron i + N: number of nodes in the network + Default: 35 + tau_min : float, optional + Minimum decay time constant of the synaptic filter model (in ms). + Default: 20 + Note: used in combination with tau_max and sig_param; + overrides taus + tau_max : float, optional + Maximum decay time constant of the synaptic filter model (in ms). + Default: 50 + Note: used in combination with tau_min and sig_param; + overrides taus + sig_param : float, (N,) numpy.ndarray, or string, optional + Parameter(s) of the sigmoid function that constrains + the decay time constants of the synaptic filter model. + If float, then the same parameter is used for all neurons + yielding a single decay time constant for all neurons. + If numpy.ndarray of shape (N,), then: + sig_param[i]: parameter of the sigmoid function of neuron i + and a different decay time constant is used for each neuron. + If 'normal', then N values are sampled from a normal distribution + with mean = 0 and standard deviation = 1. + N: number of nodes in the network. + Default: None + Note: used in combination with tau_min and tau_max; + overrides taus + dt : float, optional + Sampling rate (in ms). Default: 0.05 + timescale : float, optional + number of internal time steps per external time steps + Default: 100 + tref : float, optional + Refractory time constant (in ms). Default: 2 + tm : float, optional + Membrane time constant (in ms). Default: 10 + vreset : float, optional + Reset voltage (in mV). Default: -65 + vpeak : float, optional + Peak voltage (in mV). Default: -40 + tr : float, optional + Rise time constant (in ms). Default: 2 + stim_mode : {'exc', 'inh'}, optional + Indicates whether to apply artificial + depolarizing ('exc') or hyperpolarizing ('inh') + stimulation (modelling optogenetic stimulation). + Default: None + stim_dur : (2,) numpy.ndarray, optional + Time interval (in timesteps) during which + artificial stimulation or inhibition is applied. + stim_dur[0]: stimulus onset + stim_dur[1]: stimulus offset + Default: None + stim_units : (N,) numpy.ndarray, optional + Indices of neurons that will be stimulated or inhibited. + Default: None + stim_val : float, optional + Value of the artificial stimulation or inhibition (in mV). + Default: 0.5 + input_gain : float, optional + Constant gain that scales w_in. Default: None + ic : (N,) numpy.ndarray, optional + Initial voltage conditions + N: number of nodes in the network. + Default: None + output_nodes : array_like, optional + List of nodes for which reservoir states will be returned if + 'return_states' is True. Default: None + return_states : bool, optional + If True, simulated reservoir states are returned. + Default: True + + Returns + ------- + self._state : (time, N) numpy.ndarray + Activation states of the reservoir. + N: number of nodes in the network if output_nodes is None, else + number of output_nodes + """ + + # inhibitory and excitatory neuron indices + inh_ind = np.where(self.inh)[0] + exc_ind = np.where(self.exc)[0] + + # if ext_input is list or tuple convert to numpy.ndarray + if isinstance(ext_input, (list, tuple)): + sections = utils.get_sections(ext_input) + ext_input = utils.concat(ext_input) + convert_to_list = True + else: + convert_to_list = False + + # scale input connectivity matrix + if input_gain is not None: + w_in = input_gain * w_in + w_in = w_in.T + + # Downsample input stimulus + ext_input = ext_input.T + ext_input = ext_input[:, ::downsample] + ext_stim = np.dot(w_in, ext_input) + + # Set simulation parameters + # sampling rate (s) + dt = dt/1000 * downsample + # trial duration (s) + T = (ext_input.shape[1]) * dt * timescale + # number of time steps + nt = int(np.round(T / dt)) + # refractory time constant (s) + tref = tref/1000 + # membrane time constant (s) + tm = tm/1000 + # rise time constant (s) + tr = tr/1000 + + # Synaptic decay time constants (in sec) + # for the synaptic filter + # td: decay time constants + if sig_param is not None: + if sig_param == 'normal': + sig_param = np.random.randn(self.n_nodes) + td = (1 / (1 + np.exp(-sig_param)) * (tau_max - tau_min) + + tau_min) / 1000 + else: + td = taus/1000 + + # Initialize variables for LIF neurons simulation + # post synaptic current + IPSC = np.zeros(self.n_nodes) + # filtered firing rates (synaptic input accumulation) + h = np.zeros(self.n_nodes) + # filtered firing rates + r = np.zeros(self.n_nodes) + # filtered firing rates (rising phase) + hr = np.zeros(self.n_nodes) + # contribution of each neuron to IPSC + JD = np.zeros(self.n_nodes) + # number of spikes + ns = 0 + + # Initialize voltage + if ic is not None: + v = ic + else: + v = vreset + np.random.rand(self.n_nodes) * (30 - vreset) + + # Initialize storage arrays for recording results + # membrane voltage tracings (mV) + REC = np.zeros((nt, self.n_nodes)) + # external input current + Is = np.zeros((self.n_nodes, nt)) + # post synaptic currents over time + IPSCs = np.zeros((self.n_nodes, nt)) + # spike raster + spk = np.zeros((self.n_nodes, nt)) + # filtered firing rates over time + rs = np.zeros((self.n_nodes, nt)) + # filtered firing rates over time (synaptic input accumulation) + hs = np.zeros((self.n_nodes, nt)) + + tlast = np.zeros(self.n_nodes) # last spike time + + BIAS = vpeak # bias current + + # Start the simulation loop + for i in range(nt): + # Record IPSC over time + IPSCs[:, i] = IPSC + + # Calculate synaptic current + I = IPSC + BIAS + I = I + ext_stim[:, i // timescale] + Is[:, i] = ext_stim[:, i // timescale] + + # Compute voltage change according to LIF equation + dv = (dt * i > tlast + tref) * (-v + I) / tm + v = v + dt * dv + np.random.randn(self.n_nodes) / 10 + + # Apply artificial stimulation/inhibition + if stim_mode == 'exc': + if stim_dur is None: + raise ValueError('stim_dur not specified') + elif stim_dur[0] <= i < stim_dur[1]: + if stim_units is None: + raise ValueError('stim_units not specified') + elif np.random.rand() < 0.5: + v[stim_units] = v[stim_units] + stim_val + elif stim_mode == 'inh': + if stim_dur is None: + raise ValueError('stim_dur not specified') + elif stim_dur[0] <= i < stim_dur[1]: + if stim_units is None: + raise ValueError('stim_units not specified') + elif np.random.rand() < 0.5: + v[stim_units] = v[stim_units] - stim_val + + # Indices of neurons that have fired + index = np.where(v >= vpeak)[0] + + # Store spike times and compute weighted contributions to IPSC + if len(index) > 0: + JD = np.sum(self.w[:, index], axis=1) + curr_ts = np.column_stack((index, np.zeros(len(index)) + dt * i)) + if ns == 0: + tspike = curr_ts + else: + tspike = np.append(tspike, curr_ts, axis=0) + ns = ns + len(index) + + # Set refractory period + tlast = tlast + (dt * i - tlast) * (v >= vpeak) + + # Compute IPSC and filtered firing rates + # If the rise time is 0, then use the single synaptic filter, + # otherwise (i.e. if the rise time is positive) + # use the double-exponential filter + if tr == 0: + IPSC = IPSC * np.exp(-dt / td) + JD * (len(index) > 0) / td + r = r * np.exp(-dt / td) + (v >= vpeak) / td + rs[:, i] = r + else: + IPSC = IPSC * np.exp(-dt / td) + h * dt + h = h * np.exp(-dt / tr) + JD * (len(index) > 0) / (tr * td) + hs[:, i] = h + + r = r * np.exp(-dt / td) + hr * dt + hr = hr * np.exp(-dt / tr) + (v >= vpeak) / (tr * td) + rs[:, i] = r + + # Record spikes + spk[:, i] = v >= vpeak + + # Cap depolarization + v = v + (30 - v) * (v >= vpeak) + + # Record membrane voltage + REC[i, :] = v + + # Reset voltage after spike + v = v + (vreset - v) * (v >= vpeak) + + # Compute average firing rates for different populations + inh_fr = np.zeros(len(inh_ind)) + for i in range(len(inh_ind)): + inh_fr[i] = np.sum(spk[inh_ind[i], :] > 0) / T + + exc_fr = np.zeros(len(exc_ind)) + for i in range(len(exc_ind)): + exc_fr[i] = np.sum(spk[exc_ind[i], :] > 0) / T + + all_fr = np.zeros(self.n_nodes) + for i in range(self.n_nodes): + all_fr[i] = np.sum(spk[i, 10:] > 0) / T + + # Average over every 'timescale' time steps + rs = rs.reshape(rs.shape[0], int(rs.shape[1]/timescale), timescale).mean(axis = -1) + self._state = rs.T + + # Convert back to list or tuple + if convert_to_list: + self._state = utils.split(self._state, sections) + + self.dt = dt + self.T = T + self.nt = nt + self.td = td + self.REC = REC + self.Is = Is + self.IPSCs = IPSCs + self.spk = spk + self.rs = rs + self.hs = hs + self.tspike = tspike + self.inh_fr = inh_fr + self.exc_fr = exc_fr + self.all_fr = all_fr + + # Return the same type + if return_states: + if output_nodes is not None: + if convert_to_list: + return [state[:, output_nodes] for state in self._state] + else: + return self._state[:, output_nodes] + else: + return self._state class MemristiveReservoir: """