Skip to content

Commit

Permalink
add reps arg to sim_dev_spiking(), minor edits to plotting, decrease …
Browse files Browse the repository at this point in the history
…net back down to 12x12, and tune
  • Loading branch information
rythorpe committed Jun 14, 2024
1 parent 3aee9a4 commit 6fac3c9
Show file tree
Hide file tree
Showing 5 changed files with 2,451 additions and 122 deletions.
24 changes: 12 additions & 12 deletions hnn_core/network_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,21 +356,21 @@ def L6_model(params=None, add_drives_from_params=False,
# net.cell_types[cell_type].synapses['gabab']['tau1'] = 45.0
# net.cell_types[cell_type].synapses['gabab']['tau2'] = 200.0

conn_weights = {"L2e_L2e_ampa": 0.00080, # 0.00070
conn_weights = {"L2e_L2e_ampa": 0.00055, # 0.00070
"L2e_L2e_nmda": 0.00020,
"L2i_L2e_gabaa": 0.010,
"L2i_L2e_gabab": 0.0010,
"L2i_L2e_gabab": 0.0001,
"L2e_L2i_ampa": 0.0060, # 0.00090
"L2i_L2i_gabaa": 0.005,
"L6i_cross_L2e_gabaa": 0.015,
"L6i_cross_L2e_gabaa": 0.020,
"L2e_L5e_ampa": 0.00010,
"L2i_L5e_gabaa": 0.00002,
"L5e_L5e_ampa": 0.00205, # 0.00077
"L5e_L5e_ampa": 0.00220, # 0.00077
"L5e_L5e_nmda": 0.00005,
"L5i_L5e_gabaa": 0.0025, # 0.018
"L5i_L5e_gabaa": 0.0035, # 0.018
"L5i_L5e_gabab": 0.0001, # changed from jones09
"L6i_cross_L5e_gabaa": 0.0030,
"L2e_L5i_ampa": 0.0005, # 0.00084
"L2e_L5i_ampa": 0.0010, # 0.00084
"L5e_L5i_ampa": 0.0040, # 0.00043
"L5i_L5i_gabaa": 0.005,
"L5e_L6e_ampa": 0.0001,
Expand All @@ -381,7 +381,7 @@ def L6_model(params=None, add_drives_from_params=False,
"L6e_L6i_ampa": 0.0060,
"L6i_L6i_gabaa": 0.005}
lamtha = 2.0
lamtha_L6_cross = 12.0
lamtha_L6_cross = 16.0
delay = 1.0
if rng is None:
rng = np.random.default_rng()
Expand Down Expand Up @@ -451,7 +451,7 @@ def L6_model(params=None, add_drives_from_params=False,
# loop over cell type connections that have more than one source group
######################################################################
for src_group in [1, 2]:
targ_group = src_group
# for now, target group and source group are the same!!!

# general connection probabilities
prob_e_e = 0.33
Expand Down Expand Up @@ -496,7 +496,7 @@ def L6_model(params=None, add_drives_from_params=False,
# layer5 Pyr -> layer6 Pyr
for loc in ['proximal', 'deep_basal']:
net.add_connection(src_gids='L5e',
target_gids=f'L6e_{targ_group}',
target_gids=f'L6e_{src_group}',
loc=loc,
receptor='ampa',
weight=conn_weights['L5e_L6e_ampa'],
Expand Down Expand Up @@ -538,19 +538,19 @@ def L6_model(params=None, add_drives_from_params=False,
prob_e_e_6 = prob_e_e
prob_i_e_6 = prob_i_e
prob_e_i_6 = prob_e_i + prob_offset_L6
prob_i_e_cross = 1.0
prob_i_e_cross = 11 / 12
else:
# between-group connection probabilities
prob_e_e = 0.00
prob_i_e = 0.75
prob_i_i = 0.25
prob_e_i = 0.00
lamtha_subpop = lamtha * 10
lamtha_subpop = lamtha * 8

prob_e_e_6 = prob_e_e
prob_i_e_6 = prob_i_e + prob_offset_L6
prob_e_i_6 = prob_e_i
prob_i_e_cross = 0.0
prob_i_e_cross = 1 / 12

# layer2 Pyr -> layer2 Pyr
for receptor in ['nmda', 'ampa']:
Expand Down
2,186 changes: 2,186 additions & 0 deletions rs_dd_project/main_sim.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion rs_dd_project/opt_baseline_drive_refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
target_sr_unconn = {cell: rate * 0.4 for cell, rate in
target_sr.items()}

net = L6_model(grid_shape=(16, 12), layer_6_fb=layer_6_fb, rng=rng)
net = L6_model(grid_shape=(12, 12), layer_6_fb=layer_6_fb, rng=rng)
net, dpls = sim_net_baseline(net.copy(), sim_time, burn_in_time,
poiss_params=poiss_params, clear_conn=clear_conn,
n_trials=n_trials, n_procs=n_procs, rng=rng,
Expand Down
270 changes: 196 additions & 74 deletions rs_dd_project/plot_intrinsic_net_dynamics.ipynb

Large diffs are not rendered by default.

91 changes: 56 additions & 35 deletions rs_dd_project/plot_simulate_rep_L6.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
# emp_dpl = read_dipole('S1_SupraT.txt')


def sim_dev_spiking(dev_magnitude=-1, n_trials=1, burn_in_time=300.0,
def sim_dev_spiking(dev_magnitude=-1, reps=4, n_trials=1, burn_in_time=300.0,
n_procs=10, record_vsec=False, rng=None):

# Hyperparameters of repetitive drive sequence
reps = 4
stim_interval = 100. # in ms; 10 Hz
rep_duration = 100. # 170 ms for human M/EEG

syn_depression = 0.05 # synaptic depression [0, 1]
# synaptic depression (fractional decrease between [0, 1])
syn_depression = 0.05

# see Constantinople and Bruno (2013) and de Kock et al. (2007) for
# experimental values re: evoked response peak timing
Expand All @@ -56,10 +56,10 @@ def sim_dev_spiking(dev_magnitude=-1, n_trials=1, burn_in_time=300.0,
# is maintained across standard (std) and deviant (dev) trials despite
# rounding to the nearest whole unit when drive cells are assigned via
# probabilities
grid_shape = (16, 12)
grid_shape = (12, 12)
n_1_delta = 6 # n_cells from group 1 constituting dev drive change
n_2_delta = 3 # n_cells from group 1 constituting dev drive change
dev_delta_prob = 1 / 3
dev_delta_prob = 1 / 4
n_agg_cells = grid_shape[0] * grid_shape[1]
# proportion of red to blue cells targetted by drive
prop_1_to_2 = n_1_delta / n_2_delta
Expand Down Expand Up @@ -95,12 +95,12 @@ def sim_dev_spiking(dev_magnitude=-1, n_trials=1, burn_in_time=300.0,
# undergo synaptic depletion

# prox drive weights and delays
weights_ampa_prox = {'L2/3i': 0.0010, 'L2/3e': 0.0020,
'L5i': 0.0018, 'L5e': 0.0015, 'L6e': 0.0030}
synaptic_delays_prox = {'L2/3i': 0.0, 'L2/3e': 1.0,
'L5i': 1., 'L5e': 2., 'L6e': 0.0}
weights_ampa_dist = {'L2/3i': 0.0000, 'L2/3e': 0.0021, 'L5e': 0.0014}
weights_nmda_dist = {'L2/3i': 0.0, 'L2/3e': 0.0, 'L5e': 0.0}
weights_ampa_prox = {'L2/3i': 0.0008, 'L2/3e': 0.0018,
'L5i': 0.0018, 'L5e': 0.0015, 'L6e': 0.0031}
synaptic_delays_prox = {'L2/3i': 2.0, 'L2/3e': 3.0,
'L5i': 3.0, 'L5e': 4.0, 'L6e': 0.0}
weights_ampa_dist = {'L2/3i': 0.0, 'L2/3e': 0.0007, 'L5e': 0.0007}
weights_nmda_dist = {'L2/3i': 0.0, 'L2/3e': 0.0001, 'L5e': 0.00005}
synaptic_delays_dist = {'L2/3i': 0.1, 'L2/3e': 0.1, 'L5e': 0.1}

# convert each dictionary to a more granular version with specific cell
Expand Down Expand Up @@ -137,11 +137,18 @@ def sim_dev_spiking(dev_magnitude=-1, n_trials=1, burn_in_time=300.0,
w_ampa_prox_depressed = {key: val * df for key, val in
weights_ampa_prox_group.items()}
# drive_strength = (prob_avg + prob_delta) * depression_factor
drive_strength = prob_avg * (1 + prob_delta)
drive_strengths.append(drive_strength)
drive_strength_default = prob_avg * (1 + prob_delta)
drive_strengths.append(drive_strength_default)

prob_prox = dict()
for layer_type in synaptic_delays_prox.keys():
# scale L6 delta to make it more extreme
# must by an integer number to allow a whole number change in
# the number of driven cells
drive_strength = drive_strength_default
if 'L6' in layer_type:
drive_strength = prob_avg * (1 + (2 * prob_delta))
print(f'increasing L6 delta to {2 * prob_delta} on rep {rep_idx}')
# group-type 1 (red) will be preferentially targetted
for group_type in cell_groups[layer_type]:
if '1' in group_type:
Expand All @@ -160,6 +167,7 @@ def sim_dev_spiking(dev_magnitude=-1, n_trials=1, burn_in_time=300.0,
# repetition
prob_dist = dict()
for layer_type in synaptic_delays_dist.keys():
drive_strength = drive_strength_default
for group_type in cell_groups[layer_type]:
if '1' in group_type:
prop = prop_1_to_2 * 2 / (prop_1_to_2 + 1)
Expand Down Expand Up @@ -187,7 +195,7 @@ def sim_dev_spiking(dev_magnitude=-1, n_trials=1, burn_in_time=300.0,
# dist drive
net.add_evoked_drive(
f'evdist_rep{rep_idx}', mu=drive_times[rep_idx]['dist'],
sigma=8.0, numspikes=1, weights_ampa=weights_ampa_dist_group,
sigma=5.0, numspikes=1, weights_ampa=weights_ampa_dist_group,
weights_nmda=weights_nmda_dist_group,
location='distal', synaptic_delays=synaptic_delays_dist_group,
space_constant=1e50, probability=prob_dist,
Expand Down Expand Up @@ -222,7 +230,7 @@ def plot_dev_spiking_v1(net, burn_in_time, rep_start_times, drive_times,
gridspec_kw=gridspec, constrained_layout=True)

# plot drive strength
arrow_height_max = 15
arrow_height_max = 33
head_length = arrow_height_max / 5
head_width = 12.0
stim_interval = np.unique(np.diff(rep_start_times))
Expand Down Expand Up @@ -405,7 +413,7 @@ def plot_dev_spiking_v2(net, burn_in_time, rep_start_times, drive_times,
gridspec_kw=gridspec, constrained_layout=True)

# plot drive strength
arrow_height_max = 15
arrow_height_max = 33
head_length = arrow_height_max / 5
head_width = 12.0
stim_interval = np.unique(np.diff(rep_start_times))
Expand All @@ -431,7 +439,7 @@ def plot_dev_spiking_v2(net, burn_in_time, rep_start_times, drive_times,
xmax=rep_time + stim_interval, colors='k',
linestyle=':')
axes[0].set_ylim([0, arrow_height_max])
axes[0].set_yticks([0, arrow_height_max])
axes[0].set_yticks([0, int(drive_strengths[0] * 100)])
axes[0].set_ylabel('external drive\n(% total\ndriven units)')

# vertical lines separating reps
Expand All @@ -454,6 +462,7 @@ def plot_dev_spiking_v2(net, burn_in_time, rep_start_times, drive_times,
cell_type_colors_hist = {'L2/3e': 'm', 'P': 'r', 'NP': 'b',
'L6e': 'm'}
spike_rates_all = dict()
mean_dev_peak_rates = dict()
for layer_idx, layer_spike_types in enumerate(spike_types_hist):

# first plot spike raster in background
Expand Down Expand Up @@ -494,26 +503,46 @@ def plot_dev_spiking_v2(net, burn_in_time, rep_start_times, drive_times,
trial_idx=trial_idx, show=False
)

# finally, plot a horizontal line at the peak agg. spike rate/rep
if 'P' not in spike_type:
sr_times = np.array(spike_rates['times'])
sr = np.array(spike_rates[spike_type])
for rep_time in rep_start_times:
rep_time_stop = rep_time + stim_interval
rep_mask = np.logical_and(sr_times >= rep_time,
sr_times < rep_time_stop)
peak = sr[rep_mask].max()
# finally, calculate peak spike rates and
# plot a horizontal line at the peak agg. spike rate/rep
sr_times = np.array(spike_rates['times'])
sr = np.array(spike_rates[spike_type])
for rep_time in rep_start_times:
rep_time_stop = rep_time + stim_interval
rep_mask = np.logical_and(sr_times >= rep_time,
sr_times < rep_time_stop)
peak = sr[rep_mask].max()
if 'P' not in spike_type:
axes[layer_idx + 1].hlines(
y=peak,
xmin=rep_time,
xmax=rep_time_stop,
colors=cell_type_colors_hist[spike_type],
linestyle=':'
)
# store peak spike rates on dev
if rep_time == rep_start_times[-1]:
if 'P' not in spike_type:
spike_type_name = spike_type
else:
spike_type_name = spike_type_groups[0]
mean_dev_peak_rates[spike_type_name] = peak

if spike_type != 'L2/3e' and spike_type != 'L6e':
spike_rates_all[spike_type_groups[0]] = spike_rates[spike_type]
spike_rates_all['times'] = spike_rates['times']

# round up upper y-axis tick to the nearest multiple of 5 for
# aesthetics
ylim_max = axes[layer_idx + 1].get_ylim()[1]
if layer_spike_types == 'L2/3e':
round_tick_to = 1 # try 5 if peaks are bigger
else:
round_tick_to = 5
ylim_max = (ylim_max // round_tick_to + 1) * round_tick_to
axes[layer_idx + 1].set_ylim([0, ylim_max])
axes[layer_idx + 1].set_yticks([0, ylim_max])

axes[1].set_ylabel('L2/3\nspikes/s')
handles, _ = axes[1].get_legend_handles_labels()
axes[1].legend(handles, ['agg. (eP+eNP)', 'eP', 'eNP'], ncol=3,
Expand All @@ -524,14 +553,6 @@ def plot_dev_spiking_v2(net, burn_in_time, rep_start_times, drive_times,
axes[2].get_legend().remove()
# fig.supylabel('mean single-unit spikes/s')

# make ylim consistent
ylim_max = max([axes[1].get_ylim()[1], axes[2].get_ylim()[1]])
# round up to the nearest multiple of 5 for aesthetics
ylim_max = (ylim_max // 5 + 1) * 5
axes[1].set_ylim([0, ylim_max])
axes[1].set_yticks([0, ylim_max])
axes[2].set_ylim([0, ylim_max])
axes[2].set_yticks([0, ylim_max])
axes[-1].set_xlim([burn_in_time - 100, tstop])
xticks = np.arange(burn_in_time - 100, tstop + 1, 100)
xticks_labels = (xticks - rep_start_times[0]).astype(int).astype(str)
Expand All @@ -540,7 +561,7 @@ def plot_dev_spiking_v2(net, burn_in_time, rep_start_times, drive_times,
axes[-1].set_xlabel('time (ms)')

if return_spike_rates is True:
return fig, spike_rates_all
return fig, spike_rates_all, mean_dev_peak_rates
else:
return fig

Expand Down

0 comments on commit 6fac3c9

Please sign in to comment.