-
Notifications
You must be signed in to change notification settings - Fork 0
/
currents.py
81 lines (63 loc) · 3.28 KB
/
currents.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import h5py
import pandas as pd
from bmtk.utils.reports.compartment import CompartmentReport
from bmtk.analyzer.compartment import plot_traces
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
def sum_current(path, max_node_id=3050):
    """Sum each cell's recorded values over an entire compartment report.

    Args:
        path: Path of a report file that ``CompartmentReport`` can open.
        max_node_id: Cells whose node id is greater than this value are
            skipped. Defaults to 3050, the cutoff this script has always
            used.

    Returns:
        A ``(current_array, node_id)`` tuple of two equal-length lists.
        ``current_array[k]`` is the sum of every value that
        ``report.data(node_id=...)`` returns for the cell ``node_id[k]``.
    """
    print("Summing", path)
    report = CompartmentReport(path)
    current_array = []
    node_id = []
    # Iterate the ids directly; the full report is never loaded at once,
    # only one cell's data per iteration.
    for cell in tqdm(report.node_ids()):
        if cell > max_node_id:
            continue
        # np.sum collapses every axis in one call, so it gives a scalar
        # regardless of how many dimensions the per-cell data has.
        current_array.append(np.sum(report.data(node_id=cell)))
        node_id.append(cell)
    return current_array, node_id
def get_firing_rate(spike_path, node_id, total_seconds, population='BLA'):
    """Return spikes-per-second for the requested cells.

    Args:
        spike_path: Path of an HDF5 file holding ``node_ids`` and
            ``timestamps`` datasets under ``spikes/<population>``.
        node_id: Collection of node ids to keep; spikes from any other
            cell are discarded.
        total_seconds: Value each cell's spike count is divided by.
        population: Name of the group under ``spikes`` to read.
            Defaults to ``'BLA'``.

    Returns:
        A pandas Series indexed by node id whose values are
        ``spike count / total_seconds``. It comes from ``value_counts``, so
        it is ordered by descending count and cells that never spiked are
        absent. Callers needing one entry per requested cell must reindex
        it themselves.
    """
    # Read-only context manager: the file handle is released even if
    # building the frame fails. The datasets are copied into memory with
    # [:] because they become unreadable once the file is closed.
    with h5py.File(spike_path, 'r') as f:
        spike_group = f['spikes'][population]
        spikes_df = pd.DataFrame({
            'node_ids': spike_group['node_ids'][:],
            'timestamps': spike_group['timestamps'][:],
        })
    cell_spikes = spikes_df[spikes_df['node_ids'].isin(node_id)]
    print(node_id)
    spike_counts = cell_spikes.node_ids.value_counts()
    spike_counts_per_second = spike_counts / total_seconds
    return spike_counts_per_second
#plot_traces(report_path="output_baseline/AMPA_NMDA_STP_PN2PN_i_NMDA.h5")
def process_currents(output_dir='output_currents_baseline_blocked_server',
                     csv_path="Currents_blocked.csv", total_seconds=15):
    """Summarise the per-cell currents and firing rates into one CSV file.

    For each of the six current reports in ``output_dir`` the per-cell
    summed current is computed with ``sum_current``. The firing rate of the
    same cells is read from ``spikes.h5`` with ``get_firing_rate``, and the
    result is written to ``csv_path`` with one row per cell.

    Args:
        output_dir: Directory holding the report files and ``spikes.h5``.
        csv_path: Destination of the CSV summary.
        total_seconds: Duration passed to ``get_firing_rate`` to turn spike
            counts into rates.
    """
    # (output column name, report file name) in the order the columns are
    # written to the CSV.
    reports = (
        ('tone2PN_AMPA_current', 'tone2PN_i_AMPA.h5'),
        ('tone2PN_NMDA_current', 'tone2PN_i_NMDA.h5'),
        ('PN2PN_i_AMPA_current', 'AMPA_NMDA_STP_PN2PN_i_AMPA.h5'),
        ('PN2PN_i_NMDA_current', 'AMPA_NMDA_STP_PN2PN_i_NMDA.h5'),
        ('bg2pyr_AMPA_current', 'bg2pyr_i_AMPA.h5'),
        ('bg2pyr_NMDA_current', 'bg2pyr_i_NMDA.h5'),
    )
    column_names = []
    currents = []
    nodes = []
    for column_name, filename in reports:
        summed, nodes = sum_current(output_dir + '/' + filename)
        column_names.append(column_name)
        currents.append(summed)

    spikes = get_firing_rate(output_dir + '/spikes.h5', node_id=nodes,
                             total_seconds=total_seconds)
    # The rates come back keyed by node id, ordered by spike count and
    # without the cells that never fired. Zipping that positionally against
    # `nodes` would attach rates to the wrong cells and drop rows, so align
    # it to `nodes` first and give silent cells a rate of 0.
    spikes = spikes.reindex(nodes, fill_value=0)

    all_data = list(zip(nodes, *currents, spikes))
    df = pd.DataFrame(all_data,
                      columns=['node_id'] + column_names + ['spikes'])
    df.to_csv(csv_path)
def read_in(path):
    """Load the CSV file at ``path`` and return it as a pandas DataFrame."""
    return pd.read_csv(path)
def current_plot(current, label, ax):
    """Draw a histogram of ``current`` on ``ax`` and label the axes.

    Args:
        current: Values passed to ``ax.hist``.
        label: Text used as the title of the axes.
        ax: Object the histogram is drawn on; it must provide ``hist``,
            ``set_ylabel``, ``set_xlabel`` and ``set_title``.
    """
    ax.hist(current)
    annotations = (
        (ax.set_ylabel, "cells"),
        (ax.set_xlabel, "Current"),
        (ax.set_title, label),
    )
    for apply_text, text in annotations:
        apply_text(text)
# Script body: rebuild the CSV summary, then histogram every current column.
process_currents()
df = read_in("Currents_blocked.csv")

# 3 x 2 grid of panels that share both axes, one panel per current column.
fig, axs = plt.subplots(3, 2, figsize=(12, 6), tight_layout=True,
                        sharey=True, sharex=True)
current_columns = (
    'tone2PN_AMPA_current',
    'tone2PN_NMDA_current',
    'PN2PN_i_AMPA_current',
    'PN2PN_i_NMDA_current',
    'bg2pyr_AMPA_current',
    'bg2pyr_NMDA_current',
)
# axs.flat walks the grid row by row: (0,0), (0,1), (1,0), (1,1), (2,0), (2,1).
for column_name, panel in zip(current_columns, axs.flat):
    current_plot(df[column_name], label=column_name, ax=panel)
plt.suptitle("blocked baseline")
plt.show()