-
Notifications
You must be signed in to change notification settings - Fork 3
/
analysis_feng_hom.py
137 lines (104 loc) · 4.15 KB
/
analysis_feng_hom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
A python implementation of matlab/analysis.m
TB - 8/4/21
"""
import sys

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
try:
    from scipy.signal import hanning,welch,decimate
except ImportError:
    # SciPy >= 1.13 removed the window functions from the scipy.signal
    # namespace; the Hann window now lives in scipy.signal.windows.
    from scipy.signal import welch, decimate
    from scipy.signal.windows import hann as hanning
# Multiplier applied to the node-id ranges of the populations defined in
# run()'s node_set (PN: 0-899, PV: 900-999 when scale == 1).
scale = 1
def raster(spikes_df, node_set, skip_ms=0, ax=None):
    """Draw a spike raster with one colour per population.

    Parameters
    ----------
    spikes_df : pandas.DataFrame
        Spike table with ``node_ids`` and ``timestamps`` columns.
    node_set : list of dict
        Populations to draw. Each dict has ``name`` (legend label),
        ``color`` (a Tableau colour name without the ``tab:`` prefix) and
        an inclusive ``start``/``end`` node-id range.
    skip_ms : float
        Only spikes with ``timestamps > skip_ms`` are drawn.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current axes (``plt.gca()``).

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        ax = plt.gca()
    spikes_df = spikes_df[spikes_df['timestamps'] > skip_ms]
    for node in node_set:
        cells = range(node['start'], node['end'] + 1)  # +1: end id is inclusive
        cell_spikes = spikes_df[spikes_df['node_ids'].isin(cells)]
        ax.scatter(cell_spikes['timestamps'], cell_spikes['node_ids'],
                   c='tab:' + node['color'], s=0.25, label=node['name'])
    # Legend entries are listed in reverse plotting order, presumably so the
    # legend matches the vertical order of the id ranges (higher ids plot
    # higher). Lists rather than reversed() iterators, since legend() may
    # call len() on its arguments.
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1], labels[::-1])
    ax.grid(True)
    return ax
def raw_ecp(lfp):
    """Placeholder hook for handling the decimated ECP trace.

    Not implemented yet: ``lfp`` is ignored and the function returns None.
    NOTE(review): presumably meant to plot the raw trace; confirm intent.
    """
    return None
def ecp_psd(ecp, skip_n=0, downsample=20, nfft=1024, fs=1000, noverlap=0, ax=None):
    """Estimate, and optionally plot, the power spectral density of an ECP trace.

    Drops the first ``skip_n`` samples. Low-pass filters and decimates the
    rest by ``downsample`` using ``scipy.signal.decimate``. Then applies
    Welch's method with an ``nfft``-point symmetric Hann window.

    Parameters
    ----------
    ecp : array_like
        1-D extracellular potential trace.
    skip_n : int
        Number of leading samples to discard.
    downsample : int
        Decimation factor. The default of 20 corresponds to a 0.05 ms
        simulation step reduced to one sample per ms.
    nfft : int
        Hann window length (and therefore Welch segment length) and FFT length.
    fs : float
        Sampling frequency of the *decimated* signal. The default 1000
        corresponds to one sample per ms.
    noverlap : int
        Number of samples shared by consecutive Welch segments.
    ax : matplotlib Axes, optional
        If given, ``pxx * 1000`` is drawn against ``f`` on log-log axes.

    Returns
    -------
    f, pxx : numpy.ndarray
        Frequencies and power spectral density as returned by
        ``scipy.signal.welch``.
    """
    # scipy.signal.hanning was removed in SciPy 1.13; windows.hann is the
    # replacement (available since SciPy 1.1).
    from scipy.signal.windows import hann

    data = ecp[skip_n:]
    lfp_d = decimate(data, downsample)
    # sym=True reproduces the original hanning(nfft, True) window.
    win = hann(nfft, True)
    f, pxx = welch(lfp_d, fs, window=win, noverlap=noverlap, nfft=nfft)
    if ax is not None:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.plot(f, pxx * 1000, linewidth=0.6)
        # A log-scaled axis cannot have a bottom limit of 0 (the previous
        # set_ylim([0, 0.1]) had its bottom rejected with a warning), so
        # only the upper limit is set.
        ax.set_ylim(top=0.1)
    return f, pxx
def spike_frequency_histogram(spikes_df, node_set, ms, skip_ms=0, ax=None, n_bins=10,
                              include_silent=False):
    """Print per-population firing-rate statistics and optionally histogram them.

    For each population in ``node_set``, spikes in the window
    ``skip_ms < timestamps <= ms`` are counted per cell. Each count is
    divided by the window length in seconds, ``(ms - skip_ms) / 1000``.

    Parameters
    ----------
    spikes_df : pandas.DataFrame
        Spike table with ``node_ids`` and ``timestamps`` columns.
    node_set : list of dict
        Populations with ``name``, ``color`` (Tableau name without the
        ``tab:`` prefix) and an inclusive ``start``/``end`` node-id range.
    ms : float
        End of the analysis window, in the same unit as ``timestamps``.
        The divisor above treats this unit as ms.
    skip_ms : float
        Start of the analysis window (exclusive).
    ax : matplotlib Axes, optional
        If given, a density histogram of the per-cell rates is drawn for
        each population on a log-scaled x axis.
    n_bins : int
        Number of histogram bins.
    include_silent : bool
        If True, cells with no spike in the window contribute a rate of 0.
        The default (False) averages only over cells that fired at least
        once, which was the original behaviour.

    Returns
    -------
    dict
        Maps each population name to ``(mean, std)`` of its per-cell rates.
        ``std`` uses pandas' default ddof=1, so it is NaN for a single cell.

    Raises
    ------
    ValueError
        If ``ms <= skip_ms``, i.e. the analysis window is empty.
    """
    total_seconds = (ms - skip_ms) / 1000
    if total_seconds <= 0:
        raise ValueError(
            "ms ({}) must be greater than skip_ms ({})".format(ms, skip_ms))
    # Clip to the analysis window once. Spikes after `ms` must not be counted
    # against a duration that excludes them.
    timestamps = spikes_df['timestamps']
    in_window = spikes_df[(timestamps > skip_ms) & (timestamps <= ms)]

    stats = {}
    print("Type : mean (std)")
    for node in node_set:
        cells = range(node['start'], node['end'] + 1)  # +1: end id is inclusive
        cell_spikes = in_window[in_window['node_ids'].isin(cells)]
        spike_counts = cell_spikes['node_ids'].value_counts()
        if include_silent:
            # value_counts() omits cells that never fired; add them as zeros.
            spike_counts = spike_counts.reindex(list(cells), fill_value=0)
        spike_counts_per_second = spike_counts / total_seconds
        spikes_mean = spike_counts_per_second.mean()
        spikes_std = spike_counts_per_second.std()
        stats[node['name']] = (spikes_mean, spikes_std)
        label = "{} : {:.2f} ({:.2f})".format(node['name'], spikes_mean, spikes_std)
        print(label)
        if ax is not None:
            ax.hist(spike_counts_per_second, n_bins, density=True, histtype='bar',
                    label=label, color="tab:" + node['color'])
    if ax is not None:
        ax.set_xscale('log')
        ax.legend()
    return stats
def _load_spikes(path, population='BLA'):
    """Read one population's spike node ids and timestamps into a DataFrame."""
    print("loading " + path)
    with h5py.File(path, 'r') as f:
        group = f['spikes'][population]
        # [()] reads the datasets into memory so the frame outlives the file.
        spikes_df = pd.DataFrame({'node_ids': group['node_ids'][()],
                                  'timestamps': group['timestamps'][()]})
    print("done")
    return spikes_df


def _load_ecp(path, channel=0):
    """Read a single channel of ecp/data from an HDF5 file as a 1-D array."""
    print("loading " + path)
    with h5py.File(path, 'r') as f:
        # Same values as the former data.T[channel] for 2-D data, but reads
        # only one column. NOTE(review): assumes ecp/data is laid out
        # (time, channel); confirm against the writer.
        ecp = f['ecp']['data'][:, channel]
    print("done")
    return ecp


def run(show_plots=False, save_plots=False):
    """Load simulation output, print firing-rate stats and optionally plot.

    Reads ``outputECP/spikes.h5``. When plotting, also reads
    ``outputECP/ecp.h5`` and builds a three-panel figure (raster, ECP PSD,
    firing-rate histogram). The first 5 s are skipped in every analysis.

    Parameters
    ----------
    show_plots : bool
        Display the figure with ``plt.show()``.
    save_plots : bool
        Save the figure to ``analysis.png``.
    """
    dt = 0.05  # simulation step, in ms per step
    steps_per_ms = 1 / dt
    skip_seconds = 5
    skip_ms = skip_seconds * 1000
    skip_n = int(skip_ms * steps_per_ms)  # samples to skip in the ECP trace
    end_ms = 15000

    spikes_df = _load_spikes('outputECP/spikes.h5')
    node_set = [
        {"name": "PN", "start": 0 * scale, "end": 899 * scale, "color": "blue"},
        {"name": "PV", "start": 900 * scale, "end": 999 * scale, "color": "red"},
    ]

    if not (show_plots or save_plots):
        spike_frequency_histogram(spikes_df, node_set, end_ms, skip_ms=skip_ms)
        return

    ecp = _load_ecp('outputECP/ecp.h5', channel=0)
    print("plotting...")
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4.8))  # 6.4x4.8 default
    fig.suptitle('Amygdala Theta Analysis')
    ecp_psd(ecp, skip_n=skip_n, ax=ax2)
    spike_frequency_histogram(spikes_df, node_set, end_ms, skip_ms=skip_ms, ax=ax3)
    raster(spikes_df, node_set, skip_ms=skip_ms, ax=ax1)
    # Lay out before saving so the saved image matches the displayed one
    # (previously tight_layout() ran only when showing).
    fig.tight_layout()
    if save_plots:
        f_name = 'analysis.png'
        print("saving " + f_name)
        fig.savefig(f_name, bbox_inches='tight')
    if show_plots:
        print("showing plots...")
        plt.show()
if __name__ == '__main__':
    # Opt-in command-line flags; any other arguments are ignored.
    cli_args = sys.argv
    run(show_plots='--show-plots' in cli_args,
        save_plots='--save-plots' in cli_args)