Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] GUI synaptic gains implementation #918

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 62 additions & 21 deletions hnn_core/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,11 @@ def __init__(self, theme_color="#802989",
# Connectivity list
self.connectivity_widgets = list()

# Cell parameter list
self.cell_pameters_widgets = dict()
# Cell parameter dict
self.cell_parameters_widgets = dict()

# Synaptic Gains dict
self.synaptic_gain_widgets = dict()

self._init_ui_components()
self.add_logging_window_logger()
Expand Down Expand Up @@ -473,6 +476,7 @@ def _init_ui_components(self):
self._drives_out = Output() # tab to add new drives
self._connectivity_out = Output() # tab to tune connectivity.
self._cell_params_out = Output()
self._syn_gain_out = Output()

self._log_out = Output()

Expand Down Expand Up @@ -569,7 +573,9 @@ def _run_button_clicked(b):
self.widget_mpi_cmd, self.widget_n_jobs, self.params,
self._simulation_status_bar, self._simulation_status_contents,
self.connectivity_widgets, self.viz_manager,
self.simulation_list_widget, self.cell_pameters_widgets)
self.simulation_list_widget, self.cell_parameters_widgets,
self.synaptic_gain_widgets
)

def _simulation_list_change(value):
# Simulation Data
Expand Down Expand Up @@ -612,13 +618,13 @@ def _driver_type_change(value):

def _cell_type_radio_change(value):
_update_cell_params_vbox(self._cell_params_out,
self.cell_pameters_widgets,
self.cell_parameters_widgets,
value.new,
self.cell_layer_radio_buttons.value)

def _cell_layer_radio_change(value):
_update_cell_params_vbox(self._cell_params_out,
self.cell_pameters_widgets,
self.cell_parameters_widgets,
self.cell_type_radio_buttons.value,
value.new)

Expand Down Expand Up @@ -673,7 +679,7 @@ def compose(self, return_layout=True):
self._backend_config_out]),
], layout=self.layout['config_box'])

connectivity_configuration = Tab()
network_configuration = Tab()

connectivity_box = VBox([
HBox([self.load_connectivity_button, ]),
Expand All @@ -686,10 +692,14 @@ def compose(self, return_layout=True):
self._cell_params_out
])

connectivity_configuration.children = [connectivity_box,
cell_parameters]
connectivity_configuration.titles = ['Connectivity',
'Cell parameters']
syn_gain = VBox([self._syn_gain_out])

network_configuration.children = [connectivity_box,
cell_parameters,
syn_gain]
network_configuration.titles = ['Connectivity',
'Cell parameters',
'Synaptic gains']

drive_selections = VBox([
self.add_drive_button, self.widget_drive_type_selection,
Expand All @@ -709,7 +719,7 @@ def compose(self, return_layout=True):
# Tabs for left pane
left_tab = Tab()
left_tab.children = [
simulation_box, connectivity_configuration, drives_options,
simulation_box, network_configuration, drives_options,
config_panel,
]
titles = ('Simulation', 'Network', 'External drives',
Expand Down Expand Up @@ -902,9 +912,11 @@ def load_drive_and_connectivity(self):
self._connectivity_out,
self.connectivity_widgets,
self._cell_params_out,
self.cell_pameters_widgets,
self.cell_parameters_widgets,
self.cell_layer_radio_buttons,
self.cell_type_radio_buttons,
self._syn_gain_out,
self.synaptic_gain_widgets,
self.layout)

# Add drives
Expand Down Expand Up @@ -1034,9 +1046,11 @@ def on_upload_params_change(self, change, layout, load_type):
if load_type == 'connectivity':
add_connectivity_tab(
params, self._connectivity_out, self.connectivity_widgets,
self._cell_params_out, self.cell_pameters_widgets,
self._cell_params_out, self.cell_parameters_widgets,
self.cell_layer_radio_buttons,
self.cell_type_radio_buttons, layout)
self.cell_type_radio_buttons,
self._syn_gain_out, self.synaptic_gain_widgets,
layout)
elif load_type == 'drives':
self.add_drive_tab(params)
else:
Expand Down Expand Up @@ -1598,9 +1612,9 @@ def _build_drive_objects(drive_type, name, tstop_widget, layout, style,


def add_connectivity_tab(params, connectivity_out, connectivity_textfields,
cell_params_out, cell_pameters_vboxes,
cell_params_out, cell_parameters_vboxes,
cell_layer_radio_button, cell_type_radio_button,
layout):
syn_gain_out, syn_gain_textfields, layout):
"""Add all possible connectivity boxes to connectivity tab."""
net = dict_to_network(params)

Expand All @@ -1609,9 +1623,13 @@ def add_connectivity_tab(params, connectivity_out, connectivity_textfields,
connectivity_textfields)

# build cell parameters tab
add_cell_parameters_tab(cell_params_out, cell_pameters_vboxes,
add_cell_parameters_tab(cell_params_out, cell_parameters_vboxes,
cell_layer_radio_button, cell_type_radio_button,
layout)

# build synaptic gains tab
add_synaptic_gain_tab(net, syn_gain_out, syn_gain_textfields, layout)

return net


Expand Down Expand Up @@ -1719,6 +1737,24 @@ def add_cell_parameters_tab(cell_params_out, cell_pameters_vboxes,
cell_layer_radio_button.value)


def add_synaptic_gain_tab(net, syn_gain_out, syn_gain_textfields, layout):
jasmainak marked this conversation as resolved.
Show resolved Hide resolved
"""Creates widgets for global synaptic gains"""
gain_values = net.get_synaptic_gains()
gain_types = ('e_e', 'e_i', 'i_e', 'i_i')
for gain_type in gain_types:
gain_widget = BoundedFloatText(
value=gain_values[gain_type],
description=f'{gain_type}',
min=0, max=1e6, step=.1,
disabled=False, layout=layout)
syn_gain_textfields[gain_type] = gain_widget

gain_vbox = VBox([widget for widget in syn_gain_textfields.values()])

with syn_gain_out:
display(gain_vbox)


def get_cell_param_default_value(cell_type_key, param_dict):
return param_dict[cell_type_key]

Expand Down Expand Up @@ -1794,7 +1830,7 @@ def _drive_widget_to_dict(drive, name):

def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
drive_widgets, connectivity_textfields,
cell_params_vboxes,
cell_params_vboxes, syn_gain_textfields,
add_drive=True):
"""Construct network and add drives."""
print("init network")
Expand All @@ -1819,7 +1855,6 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
'nc_dict']['A_weight'] = vbox_key.children[1].value

# Update cell params

update_functions = {
'L2 Geometry': _update_L2_geometry_cell_params,
'L5 Geometry': _update_L5_geometry_cell_params,
Expand All @@ -1842,6 +1877,11 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
single_simulation_data['net'].cell_types[
cell_type]._compute_section_mechs()

# Update with synaptic gains
syn_gain_values = {key: widget.value
for key, widget in syn_gain_textfields.items()}
single_simulation_data['net'].set_synaptic_gains(**syn_gain_values)

if add_drive is False:
return
# add drives to network
Expand Down Expand Up @@ -1914,7 +1954,7 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
mpi_cmd, n_jobs, params, simulation_status_bar,
simulation_status_contents, connectivity_textfields,
viz_manager, simulations_list_widget,
cell_pameters_widgets):
cell_parameters_widgets, syn_gain_textfields):
"""Run the simulation and plot outputs."""
simulation_data = all_data["simulation_data"]
with log_out:
Expand All @@ -1933,7 +1973,8 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
_init_network_from_widgets(params, dt, tstop,
simulation_data[_sim_name], drive_widgets,
connectivity_textfields,
cell_pameters_widgets)
cell_parameters_widgets,
syn_gain_textfields)

print("start simulation")
if backend_selection.value == "MPI":
Expand Down
90 changes: 81 additions & 9 deletions hnn_core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,40 @@ def pick_connection(net, src_gids=None, target_gids=None,
return sorted(conn_set)


def _get_cell_index_by_synapse_type(net):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to put this outside of the Network class as a standalone function if it requires an existing Network?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my general rule is to keep methods of a class minimal, so the structure of the class/object is clear. Just because it takes the object as input doesn't mean it should be a method. The method should have an explicit purpose. For example:

ica = ICA()
ica.fit(X)
y = ica.predict(X)

once something becomes a method, it's very hard to disentangle it from the class

"""Returns the indices of excitatory and inhibitory cells in the network.

This function extracts the source GIDs (cell ID) of excitatory
and inhibitory cells based on their connection types. Excitatory cells are
identified by their synaptic connections using AMPA and NMDA receptors,
while inhibitory cells are identified by their connections using GABAA and
GABAB receptors.

Parameters
----------
net : Instance of Network object
The Network object

Returns
-------
tuple: A tuple containing two lists:
- e_cells (list): The source GIDs of excitatory cells.
- i_cells (list): The source GIDs of inhibitory cells.
"""

def list_src_gids(indices):
return np.concatenate([list(net.connectivity[conn_idx]['src_gids'])
for conn_idx in indices]).tolist()

picks_e = pick_connection(net, receptor=['ampa', 'nmda'])
e_cells = list_src_gids(picks_e)

picks_i = pick_connection(net, receptor=['gabaa', 'gabab'])
i_cells = list_src_gids(picks_i)

return e_cells, i_cells


class Network:
"""The Network class.

Expand Down Expand Up @@ -1427,8 +1461,8 @@ def add_electrode_array(self, name, electrode_pos, *, conductivity=0.3,
method=method,
min_distance=min_distance)})

def update_weights(self, e_e=None, e_i=None,
i_e=None, i_i=None, copy=False):
def set_synaptic_gains(self, e_e=None, e_i=None,
i_e=None, i_i=None, copy=False):
"""Update synaptic weights of the network.

Parameters
Expand Down Expand Up @@ -1466,13 +1500,7 @@ def update_weights(self, e_e=None, e_i=None,

net = self.copy() if copy else self

e_conns = pick_connection(self, receptor=['ampa', 'nmda'])
e_cells = np.concatenate([list(net.connectivity[
conn_idx]['src_gids']) for conn_idx in e_conns]).tolist()

i_conns = pick_connection(self, receptor=['gabaa', 'gabab'])
i_cells = np.concatenate([list(net.connectivity[
conn_idx]['src_gids']) for conn_idx in i_conns]).tolist()
e_cells, i_cells = _get_cell_index_by_synapse_type(net)
conn_types = {
'e_e': (e_e, e_cells, e_cells),
'e_i': (e_i, e_cells, i_cells),
Expand All @@ -1497,6 +1525,50 @@ def update_weights(self, e_e=None, e_i=None,
if copy:
return net

def get_synaptic_gains(self):
"""Retrieve gain values for different connection types in the network.

This function identifies excitatory and inhibitory cells in the network
and retrieves the gain value for each type of synaptic connection:
- excitatory to excitatory (e_e)
- excitatory to inhibitory (e_i)
- inhibitory to excitatory (i_e)
- inhibitory to inhibitory (i_i)

The gain is assumed to be uniform within each connection type, and only
the first connection's gain value is used for each type.

Returns
-------
values : dict
A dictionary with the connection types ('e_e', 'e_i', 'i_e',
'i_i') as keys and their corresponding gain values.
"""
values = {}
e_cells, i_cells = _get_cell_index_by_synapse_type(self)

# Define the connection types and source/target cell indexes
conn_types = {
'e_e': (e_cells, e_cells),
'e_i': (e_cells, i_cells),
'i_e': (i_cells, e_cells),
'i_i': (i_cells, i_cells)
}

# Retrieve the gain value for each connection type
for conn_type, (src_idxs, target_idxs) in conn_types.items():
picks = pick_connection(self,
src_gids=src_idxs,
target_gids=target_idxs)

if picks:
# Extract the gain from the first connection
values[conn_type] = (
self.connectivity[picks[0]]['nc_dict']['gain']
)

return values

def plot_cells(self, ax=None, show=True):
"""Plot the cells using Network.pos_dict.

Expand Down
Loading
Loading