-
Notifications
You must be signed in to change notification settings - Fork 3
/
analysis_hom.py
186 lines (154 loc) · 6.43 KB
/
analysis_hom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""
A python implementation of matlab/analysis.m
TB - 8/4/21
"""
import argparse
import os
import sys
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from fooof import FOOOF
from fooof.sim.gen import gen_aperiodic
from scipy.signal import decimate, welch
from scipy.signal.windows import hann as hanning
# Multiplier applied to every node-id boundary of the node_set populations in run().
scale = 1
def raster(spikes_df, node_set, skip_ms=0, ax=None, xlim=(8500, 9000)):
    """Draw a spike raster, one coloured scatter series per cell population.

    Parameters
    ----------
    spikes_df : pandas.DataFrame
        Spike table with ``'node_ids'`` and ``'timestamps'`` columns.
    node_set : list of dict
        Populations to draw. Each dict needs ``'name'``, ``'start'`` and
        ``'end'`` (an inclusive node-id range) and ``'color'`` (appended to
        ``'tab:'`` to form a matplotlib colour, e.g. ``'blue'``).
    skip_ms : float
        Spikes with ``timestamps <= skip_ms`` are not drawn.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).
    xlim : tuple of float or None
        X-axis window to display, in ``timestamps`` units. The default keeps
        the previously hard-coded 8500-9000 window. ``None`` leaves
        matplotlib's automatic limits.
    """
    if ax is None:
        # Only needed when the caller did not supply axes.
        import matplotlib.pyplot as plt
        ax = plt.gca()

    spikes_df = spikes_df[spikes_df['timestamps'] > skip_ms]
    for node in node_set:
        cells = range(node['start'], node['end'] + 1)  # +1 to include the last cell
        cell_spikes = spikes_df[spikes_df['node_ids'].isin(cells)]
        ax.scatter(cell_spikes['timestamps'], cell_spikes['node_ids'],
                   c='tab:' + node['color'], s=0.25, label=node['name'])

    handles, labels = ax.get_legend_handles_labels()
    # Reverse the legend order, presumably so it reads top-to-bottom like the
    # populations on the y-axis. Lists (not one-shot iterators) are passed.
    ax.legend(handles[::-1], labels[::-1])
    ax.grid(True)
    if xlim is not None:
        ax.set_xlim(*xlim)
def raw_ecp(lfp):
    """Stub hook for handling a raw LFP/ECP trace.

    Not implemented: the argument is accepted, ignored, and ``None`` is returned.
    """
    return None
def _band_stats(freqs, power, band):
    """Return ``(mean, peak)`` of *power* over the bins of *freqs* inside *band*.

    *band* is an inclusive ``(low_hz, high_hz)`` pair. *freqs* and *power*
    are equal-length arrays.

    Raises
    ------
    ValueError
        If no frequency bin falls inside *band*. This happens, for example,
        when the band lies above the Nyquist frequency or ``nfft`` is too
        coarse to resolve it.
    """
    low, high = band
    selected = power[(freqs >= low) & (freqs <= high)]
    if selected.size == 0:
        raise ValueError(f"no frequency bins between {low} Hz and {high} Hz")
    return selected.mean(), selected.max()


def _band_label(band):
    """Format an inclusive ``(low, high)`` band as e.g. ``'4Hz-12Hz'``."""
    low, high = band
    return f"{low:g}Hz-{high:g}Hz"


def ecp_psd(ecp, skip_n=0, downsample=10, nfft=1024, fs=1000, noverlap=0, ax=None,
            use_fooof=True, theta_band=(4, 12), gamma_band=(50, 60)):
    """Compute (and optionally plot) the ECP power spectrum and summarise band power.

    Parameters
    ----------
    ecp : array_like
        One-dimensional extracellular potential trace.
    skip_n : int
        Number of leading samples to discard before analysis.
    downsample : int
        Decimation factor passed to ``scipy.signal.decimate``. The default of
        10 turns the 0.1 ms step used by ``run()`` into 1 kHz, matching the
        default ``fs``.
    nfft : int
        FFT length for ``welch``. On the non-FOOOF path it is also the
        Hann window (segment) length.
    fs : float
        Sampling rate, in Hz, of the decimated signal (used on both paths).
    noverlap : int
        Segment overlap for ``welch`` (non-FOOOF path only).
    ax : matplotlib.axes.Axes, optional
        Axes to plot the spectrum on. When ``None``, nothing is plotted.
    use_fooof : bool
        Select between two paths:

        * FOOOF path: default Welch segmenting, a FOOOF 'knee' fit over
          1-150 Hz, and the spectrum truncated to its first 150 bins.
        * Non-FOOOF path: Hann-windowed Welch, power scaled by 1000, and a
          log y-axis.
    theta_band, gamma_band : tuple of float
        Inclusive ``(low_hz, high_hz)`` bands to summarise. Bins are selected
        by their actual frequency, not by array index.

    Returns
    -------
    str
        Four lines (mean theta, mean gamma, peak theta, peak gamma, each
        rounded to 8 decimals), with a leading and a trailing newline so
        callers can concatenate further reports.

    Raises
    ------
    ValueError
        If a band contains no frequency bins.
    """
    lfp_d = decimate(ecp[skip_n:], downsample)

    if not use_fooof:
        win = hanning(nfft, True)
        freqs, pxx = welch(lfp_d, fs, window=win, noverlap=noverlap, nfft=nfft)
        # The x1000 scaling is kept from the original; its units are not documented here.
        power = pxx * 1000
        if ax is not None:
            ax.set_yscale('log')
            ax.plot(freqs, power, linewidth=0.6)
            # Only cap the top: a lower limit of 0 is invalid on a log axis
            # (matplotlib ignores it with a warning).
            ax.set_ylim(top=0.1)
    else:
        all_freqs, spectrum = welch(lfp_d, fs=fs, nfft=nfft)
        fm = FOOOF(aperiodic_mode='knee')
        fm.fit(all_freqs, spectrum, [1, 150])
        # NOTE(review): the aperiodic fit is not subtracted (that step was
        # disabled), so the values below are raw Welch power. The fit above
        # currently has no effect on the returned text or the plot.
        power = spectrum[:150]
        freqs = all_freqs[:len(power)]
        if ax is not None:
            ax.plot(freqs, power)
            ax.grid()

    mean_theta, peak_theta = _band_stats(freqs, power, theta_band)
    mean_gamma, peak_gamma = _band_stats(freqs, power, gamma_band)

    theta_label = _band_label(theta_band)
    gamma_label = _band_label(gamma_band)
    lines = [
        f"Mean theta ({theta_label}) : {round(mean_theta, 8)}",
        f"Mean gamma ({gamma_label}) : {round(mean_gamma, 8)}",
        f"Peak theta ({theta_label}) : {round(peak_theta, 8)}",
        f"Peak gamma ({gamma_label}) : {round(peak_gamma, 8)}",
    ]
    return "\n" + "\n".join(lines) + "\n"
def spike_frequency_histogram(spikes_df, node_set, ms, skip_ms=0, ax=None, n_bins=10,
                              include_silent=True):
    """Summarise (and optionally histogram) per-cell firing rates for each population.

    Parameters
    ----------
    spikes_df : pandas.DataFrame
        Spike table with ``'node_ids'`` and ``'timestamps'`` (ms) columns.
    node_set : list of dict
        Populations to report. Each dict needs ``'name'``, ``'start'`` and
        ``'end'`` (an inclusive node-id range) and ``'color'`` (appended to
        ``'tab:'`` for the histogram colour).
    ms : float
        End of the analysis window in ms.
    skip_ms : float
        Spikes with ``timestamps <= skip_ms`` are ignored. Rates are divided
        by the window length ``(ms - skip_ms) / 1000`` seconds.
    ax : matplotlib.axes.Axes, optional
        When given, a density histogram per population is drawn on it (log
        x-axis) with a legend. When ``None``, nothing is plotted.
    n_bins : int
        Number of histogram bins.
    include_silent : bool
        Count cells with no spikes in the window as 0 spikes/s (default).
        ``False`` reproduces the older behaviour, which averaged only over
        cells that spiked and so biased the mean upward.

    Returns
    -------
    str
        A header line ``Type : mean (std)`` followed by one line
        ``<name> : <mean> (<std>)`` per population, each newline-terminated.
        Values are spikes/s with 2 decimals; the std is the sample std
        (pandas default).

    Raises
    ------
    ValueError
        If ``ms <= skip_ms``, i.e. the analysis window is empty.
    """
    total_seconds = (ms - skip_ms) / 1000
    if total_seconds <= 0:
        raise ValueError(f"ms ({ms}) must be greater than skip_ms ({skip_ms})")

    # Drop the start-up period once instead of once per population.
    spikes_df = spikes_df[spikes_df['timestamps'] > skip_ms]

    report = ["Type : mean (std)"]
    for node in node_set:
        cells = range(node['start'], node['end'] + 1)  # +1 to include the last cell
        cell_spikes = spikes_df[spikes_df['node_ids'].isin(cells)]
        spike_counts = cell_spikes['node_ids'].value_counts()
        if include_silent:
            # value_counts only lists cells that spiked; add the silent ones as 0.
            spike_counts = spike_counts.reindex(cells, fill_value=0)
        rates = spike_counts / total_seconds

        label = "{} : {:.2f} ({:.2f})".format(node['name'], rates.mean(), rates.std())
        report.append(label)
        if ax is not None:
            ax.hist(rates, n_bins, density=True, histtype='bar', label=label,
                    color='tab:' + node['color'])

    if ax is not None:
        ax.set_xscale('log')
        ax.legend()
    return "\n".join(report) + "\n"
def _read_spikes(spikes_location):
    """Read the BLA population's spike ``node_ids``/``timestamps`` from an HDF5 file.

    Returns a DataFrame with columns ``'node_ids'`` and ``'timestamps'``.
    The file is opened read-only and closed before returning.
    """
    with h5py.File(spikes_location, 'r') as h5:
        bla = h5['spikes']['BLA']
        # [()] reads each dataset into memory so the data outlives the open file.
        return pd.DataFrame({'node_ids': bla['node_ids'][()],
                             'timestamps': bla['timestamps'][()]})


def _read_ecp_channel(ecp_h5_location, channel):
    """Read column *channel* of the ``ecp/data`` dataset from an HDF5 file.

    For a 2-D dataset this equals ``np.array(data).T[channel]``, but only
    that column is read from disk. The file is opened read-only and closed
    before returning.
    """
    with h5py.File(ecp_h5_location, 'r') as h5:
        return np.asarray(h5['ecp']['data'][:, channel])


def run(show_plots=False, save_plots=False, slack=True, tstop=15000.0, path="outputECP"):
    """Load a simulation's output and print (and optionally plot) spike/ECP summaries.

    Parameters
    ----------
    show_plots, save_plots : bool
        Display the figure interactively and/or save it as ``analysis.png``
        in the current working directory. When both are False, ``ecp.h5``
        is not read and only the firing-rate summary is printed.
    slack : bool
        When plotting, also send the text summary plus the ``git diff`` of
        ``components_homogenous/synaptic_models/`` through the project-local
        ``upload_analysis_slack.upload``.
    tstop : float
        Simulation end time in ms. Firing rates use a window length of
        ``tstop`` minus the 5 s start-up skip.
    path : str
        Directory containing ``spikes.h5`` and ``ecp.h5``.
    """
    dt = 0.1  # simulation step in ms; hard-coded, presumably must match the simulation
    steps_per_ms = 1 / dt
    skip_seconds = 5  # start-up period excluded from every analysis
    skip_ms = skip_seconds * 1000
    skip_n = int(skip_ms * steps_per_ms)  # the same period, in raw ECP samples
    end_ms = tstop
    make_plots = show_plots or save_plots

    spikes_location = os.path.join(path, 'spikes.h5')
    print("loading " + spikes_location)
    spikes_df = _read_spikes(spikes_location)
    print("done")

    node_set = [
        {"name": "PN", "start": 0 * scale, "end": 799 * scale, "color": "blue"},
        {"name": "PV", "start": 800 * scale, "end": 892 * scale, "color": "red"},
        {"name": "SOM", "start": 893 * scale, "end": 943 * scale, "color": "green"},
        {"name": "CR", "start": 944 * scale, "end": 999 * scale, "color": "purple"},
    ]

    if not make_plots:
        # Print the firing-rate summary even when no figure is requested.
        print(spike_frequency_histogram(spikes_df, node_set, end_ms, skip_ms=skip_ms))
        return

    ecp_h5_location = os.path.join(path, 'ecp.h5')
    print("loading " + ecp_h5_location)
    ecp = _read_ecp_channel(ecp_h5_location, channel=0)
    print("done")

    print("plotting...")
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4.8))  # matplotlib default is 6.4 x 4.8
    fig.suptitle('Amygdala Theta Analysis')
    output_text = ecp_psd(ecp, skip_n=skip_n, ax=ax2)
    output_text += spike_frequency_histogram(spikes_df, node_set, end_ms,
                                             skip_ms=skip_ms, ax=ax3)
    print(output_text)
    raster(spikes_df, node_set, skip_ms=skip_ms, ax=ax1)
    # Lay out before saving as well as before showing, so the saved PNG gets
    # the same subplot layout as the displayed figure.
    fig.tight_layout()

    if save_plots:
        f_name = 'analysis.png'
        print("saving " + f_name)
        plt.savefig(f_name, bbox_inches='tight')
    if show_plots:
        print("showing plots...")
        plt.show()
    if slack:
        import upload_analysis_slack
        import subprocess
        raw_diff = subprocess.check_output(
            ['git', 'diff', 'components_homogenous/synaptic_models/'])
        # Decode the bytes; str(bytes) would embed a b'...' repr with escaped newlines.
        git_diff = raw_diff.decode('utf-8', errors='replace')
        output_text = output_text + '\n\n' + git_diff
        upload_analysis_slack.upload(output_text)
if __name__ == '__main__':
    # Passed positionally, the string would become the program name (prog)
    # shown in the usage line; it is meant as the description.
    parser = argparse.ArgumentParser(description="analysis of results")
    parser.add_argument("--show-plots", action="store_true",
                        help="display the analysis figure")
    parser.add_argument("--save-plots", action="store_true",
                        help="save the analysis figure as analysis.png")
    parser.add_argument("--tstop", type=float, default=15000.0,
                        help="simulation end time in ms (default: %(default)s)")
    parser.add_argument("--path", default="outputECP",
                        help="directory containing spikes.h5 and ecp.h5 (default: %(default)s)")
    # run() uploads to Slack by default whenever plots are made; this is the
    # only command-line way to opt out.
    parser.add_argument("--no-slack", dest="slack", action="store_false",
                        help="do not upload the summary and git diff to Slack")
    args = parser.parse_args()
    run(show_plots=args.show_plots, save_plots=args.save_plots, slack=args.slack,
        tstop=args.tstop, path=args.path)