-
Notifications
You must be signed in to change notification settings - Fork 3
/
build_input_vpsi_inh_spikes.py
202 lines (150 loc) · 7.43 KB
/
build_input_vpsi_inh_spikes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import scipy.stats as st
from functools import partial
import h5py
import numpy as np
import matplotlib.pyplot as plt
import scipy
def build_vpsi_input_jitter(t_sim=15000.0, n_cells=100, plot=False, output='vpsi_inh_spikes.h5', freq=8, jitter_amount=0.001):
    """Write a jittered periodic spike train for each VPSI cell to HDF5.

    Every cell gets the same base train -- one spike per cycle of a ``freq`` Hz
    rhythm, from t = 1 s up to (not including) the end of the simulation --
    with independent Gaussian jitter added to each spike.

    Args:
        t_sim: Simulation length in milliseconds.
        n_cells: Number of cells (node ids 0..n_cells-1).
        plot: If True, scatter-plot the spikes between 6000 and 7000 ms.
        output: Path of the HDF5 file to (over)write.
        freq: Firing frequency of the base train in Hz.
        jitter_amount: Non-negative jitter std expressed as a fraction of the
            base train's span (first to last spike). Lower means less jitter;
            0 means no jitter.

    Output file layout:
        /spikes/vpsi_inh/node_ids   int array, one entry per spike
        /spikes/vpsi_inh/timestamps float array (ms), grouped by node; times
            within a node are not re-sorted after jittering
    """
    sim_length = t_sim / 1000  # seconds
    period = 1 / freq  # seconds between consecutive base spikes

    def rand_jitter(arr, jitter_amount=0.001):
        """Return ``arr`` plus N(0, sigma) noise, sigma = jitter_amount * span."""
        if arr.size == 0:
            return arr
        span = arr.max() - arr.min()
        if span <= 0:
            # A single-spike train has no span; use 1 s as the reference
            # length so jitter_amount still scales the noise (and 0 -> none).
            span = 1
        return arr + np.random.randn(arr.size) * (jitter_amount * span)

    base_train = np.arange(1, sim_length, period)  # seconds

    # One row per cell, each an independently jittered copy of the base train.
    jittered = np.empty((n_cells, base_train.size))
    for cell in range(n_cells):
        jittered[cell] = rand_jitter(base_train, jitter_amount)

    # Row-major flattening keeps spikes grouped by cell, matching node_ids.
    timestamps = jittered.ravel() * 1000  # s -> ms
    node_ids = np.repeat(np.arange(n_cells), base_train.size)

    with h5py.File(output, 'w') as vpsi:
        vpsi_spikes_vp = vpsi.create_group("spikes").create_group("vpsi_inh")
        vpsi_spikes_vp.create_dataset("node_ids", data=node_ids.astype(int))
        vpsi_spikes_vp.create_dataset("timestamps", data=timestamps.astype(float))

    if plot:
        plt.scatter(timestamps, node_ids)
        plt.xlim(6000, 7000)
        plt.show()
def build_vpsi_input(t_sim=15000.0, n_cells=100, plot=False, depth_of_mod=1, output='vpsi_inh_spikes.h5', freq=8):
    """Generate sinusoidally modulated Poisson spike trains and write them to HDF5.

    Each of ``n_cells`` presynaptic cells gets a baseline firing rate ``fr_i``
    drawn from a normal distribution (mean 5 Hz, std 2 Hz) truncated to
    [0, 100] Hz. The rate is modulated at ``freq`` Hz::

        rate_i(t) = fr_i * (1 + depth_of_mod * sin(2*pi*freq*t))

    In every 1 ms bin a Poisson count is drawn with mean ``rate_i(t) / 1000``.
    A bin with at least one event yields a single spike whose timestamp is the
    bin index, i.e. milliseconds from 0.

    Args:
        t_sim: Simulation length in milliseconds.
        n_cells: Number of presynaptic cells (node ids 0..n_cells-1).
        plot: If True, show the rate traces of up to four cells and the raster.
        depth_of_mod: Modulation depth; 0 = unmodulated, 1 = full modulation.
            Values > 1 push the rate below zero for part of each cycle; those
            rates are clipped to 0.
        output: Path of the HDF5 file to (over)write.
        freq: Modulation frequency in Hz (default 8, theta inhibition).

    Output file layout:
        /spikes/vpsi_inh/node_ids   int array, one entry per spike
        /spikes/vpsi_inh/timestamps float array (ms), sorted by node, then time
    """
    mean_fr = 5  # mean baseline firing rate (Hz)
    std_fr = 2  # std of baseline firing rate (Hz)
    dt = 0.001  # bin width in seconds (1 ms, so bin index == time in ms)
    t_stop = t_sim / 1000  # seconds
    print(f"Building VPSI input for {n_cells} cells.")

    # truncnorm takes its bounds in std units relative to loc, so the
    # physical range [0, 100] Hz becomes [a, b].
    a, b = (0 - mean_fr) / std_fr, (100 - mean_fr) / std_fr
    print('a = ', a)
    print('b = ', b)
    frs = st.truncnorm.rvs(a=a, b=b, loc=mean_fr, scale=std_fr, size=n_cells)

    # Integer step count avoids the float-step np.arange length ambiguity.
    n_steps = int(round(t_stop / dt))
    t = np.arange(n_steps) * dt

    phase = 0  # phase of the modulating sine (rad)
    modulation = 1 + depth_of_mod * np.sin(2 * np.pi * freq * t + phase)
    # z[i, j]: instantaneous firing rate (Hz) of cell i in bin j.
    z = frs[:, np.newaxis] * modulation[np.newaxis, :]
    z[z < 0] = 0  # negative rates are not physical (only when depth_of_mod > 1)

    # Rate (Hz) / 1000 = expected events per 1 ms bin.
    counts = st.poisson(z / 1000).rvs()
    # np.nonzero scans in C order, so spikes come out grouped by cell and
    # time-sorted within each cell.
    spike_cells, spike_bins = np.nonzero(counts > 0)

    if plot:
        plt.subplot(1, 2, 1)
        for trace in z[:4]:
            plt.plot(trace)
        plt.xlim(0, n_steps)
        plt.xlabel('Time (ms)')
        plt.ylabel('Firing Rate')
        plt.subplot(1, 2, 2)
        plt.plot(spike_bins, spike_cells, 'k.')
        plt.plot(100 * depth_of_mod * np.sin(2 * np.pi * freq * t) + 500)
        plt.xlabel('time(ms)')
        plt.ylabel('node ID')
        plt.show()

    with h5py.File(output, 'w') as vpsi:
        vpsi_spikes_vp = vpsi.create_group("spikes").create_group("vpsi_inh")
        vpsi_spikes_vp.create_dataset("node_ids", data=spike_cells.astype(int))
        vpsi_spikes_vp.create_dataset("timestamps", data=spike_bins.astype(float))
if __name__ == "__main__":
    # Script entry point: build the default modulated input file and show
    # its diagnostic plots.
    build_vpsi_input(plot=True)