Skip to content

Commit

Permalink
feat: Set default poisson tstop to the tstop widget's value by default.
Browse files Browse the repository at this point in the history
  • Loading branch information
gtdang committed Aug 7, 2024
1 parent a6f1461 commit adff1a2
Showing 1 changed file with 40 additions and 37 deletions.
77 changes: 40 additions & 37 deletions hnn_core/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ def _get_poisson_widget(name, tstop_widget, layout, style, location, data=None,
default_delays=None, sync_evinput=False):
default_data = {
'tstart': 0.0,
'tstop': 0.0,
'tstop': tstop_widget.value,
'seedcore': 14,
'rate_constant': {
'L5_pyramidal': 40.,
Expand Down Expand Up @@ -1537,7 +1537,7 @@ def add_drive_tab(params, log_out, drives_out, drive_widgets, drive_boxes,
drive_widgets.pop()
drive_boxes.pop()

drive_names = sorted(drive_specs.keys())
drive_names = drive_specs.keys()

for idx, drive_name in enumerate(drive_names): # order matters
specs = drive_specs[drive_name]
Expand Down Expand Up @@ -1682,38 +1682,39 @@ def _drive_widget_to_dict(drive, name):
}


def _filter_poisson_inputs(ampa, nmda, rates, delays):
def _filter_by_keys(dictionary, keys):
return {
key: value
for key, value in dictionary.items()
if key in keys
}


def _filter_drive_inputs(weights_dicts, constants_dicts):
def _get_positive_weights(weight_dict):
return {
key: value
for key, value in weight_dict.items()
if value > 0
}

def _filter_by_keys(dictionary, keys):
return {
key: value
for key, value in dictionary.items()
if key in keys
}

# Filter the weights for positive values
weights_ampa, weights_nmda = [_get_positive_weights(weights)
for weights in [ampa, nmda]]
pos_weights_dicts = [_get_positive_weights(weights)
for weights in weights_dicts]

# Filter the rates and delays that match the weights
cell_types = set(list(weights_ampa.keys()) + list(weights_nmda.keys()))
rate_constant = _filter_by_keys(rates, cell_types)
synaptic_delays = _filter_by_keys(delays, cell_types)
cell_types = set()
for dictionary in pos_weights_dicts:
cell_types.update(dictionary.keys())

filtered_constants = [_filter_by_keys(dictionary, cell_types)
for dictionary in constants_dicts]

# Set weights to None if the dict is empty
weights_ampa, weights_nmda = [weights if weights else None
for weights in [weights_ampa, weights_nmda]]
weights_dicts_cleaned = [weights if weights else None
for weights in pos_weights_dicts]

return dict(weights_ampa=weights_ampa,
weights_nmda=weights_nmda,
rate_constant=rate_constant,
synaptic_delays=synaptic_delays)
return weights_dicts_cleaned, filtered_constants


def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
Expand Down Expand Up @@ -1785,25 +1786,27 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
weights_ampa = _drive_widget_to_dict(drive, 'weights_ampa')
weights_nmda = _drive_widget_to_dict(drive, 'weights_nmda')
synaptic_delays = _drive_widget_to_dict(drive, 'delays')

weights, constants = _filter_drive_inputs(
weights_dicts=[weights_ampa, weights_nmda],
constants_dicts=[synaptic_delays]
)

print(
f"drive type is {drive['type']}, location={drive['location']}")
if drive['type'] == 'Poisson':
rate_constant = _drive_widget_to_dict(drive, 'rate_constant')
poisson_params = _filter_poisson_inputs(
weights_ampa,
weights_nmda,
rate_constant,
synaptic_delays
)
filtered_rates = _filter_by_keys(rate_constant,
constants[0].keys())
single_simulation_data['net'].add_poisson_drive(
name=drive['name'],
tstart=drive['tstart'].value,
tstop=drive['tstop'].value,
rate_constant=poisson_params['rate_constant'],
rate_constant=filtered_rates,
location=drive['location'],
weights_ampa=poisson_params['weights_ampa'],
weights_nmda=poisson_params['weights_nmda'],
synaptic_delays=poisson_params['synaptic_delays'],
weights_ampa=weights[0],
weights_nmda=weights[1],
synaptic_delays=constants[0],
space_constant=100.0,
event_seed=drive['seedcore'].value,
**synch_inputs_kwargs)
Expand All @@ -1814,9 +1817,9 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
sigma=drive['sigma'].value,
numspikes=drive['numspikes'].value,
location=drive['location'],
weights_ampa=weights_ampa,
weights_nmda=weights_nmda,
synaptic_delays=synaptic_delays,
weights_ampa=weights[0],
weights_nmda=weights[1],
synaptic_delays=constants[0],
space_constant=3.0,
event_seed=drive['seedcore'].value,
**synch_inputs_kwargs)
Expand All @@ -1829,9 +1832,9 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
burst_rate=drive['burst_rate'].value,
burst_std=drive['burst_std'].value,
location=drive['location'],
weights_ampa=weights_ampa,
weights_nmda=weights_nmda,
synaptic_delays=synaptic_delays,
weights_ampa=weights[0],
weights_nmda=weights[1],
synaptic_delays=constants[0],
event_seed=drive['seedcore'].value,
**synch_inputs_kwargs)

Expand Down

0 comments on commit adff1a2

Please sign in to comment.