diff --git a/hnn_core/gui/gui.py b/hnn_core/gui/gui.py index b091a16f4..c3a27af08 100644 --- a/hnn_core/gui/gui.py +++ b/hnn_core/gui/gui.py @@ -422,8 +422,11 @@ def __init__(self, theme_color="#802989", # Connectivity list self.connectivity_widgets = list() - # Cell parameter list - self.cell_pameters_widgets = dict() + # Cell parameter dict + self.cell_parameters_widgets = dict() + + # Synaptic Gains dict + self.synaptic_gain_widgets = dict() self._init_ui_components() self.add_logging_window_logger() @@ -473,6 +476,7 @@ def _init_ui_components(self): self._drives_out = Output() # tab to add new drives self._connectivity_out = Output() # tab to tune connectivity. self._cell_params_out = Output() + self._syn_gain_out = Output() self._log_out = Output() @@ -569,7 +573,9 @@ def _run_button_clicked(b): self.widget_mpi_cmd, self.widget_n_jobs, self.params, self._simulation_status_bar, self._simulation_status_contents, self.connectivity_widgets, self.viz_manager, - self.simulation_list_widget, self.cell_pameters_widgets) + self.simulation_list_widget, self.cell_parameters_widgets, + self.synaptic_gain_widgets + ) def _simulation_list_change(value): # Simulation Data @@ -612,13 +618,13 @@ def _driver_type_change(value): def _cell_type_radio_change(value): _update_cell_params_vbox(self._cell_params_out, - self.cell_pameters_widgets, + self.cell_parameters_widgets, value.new, self.cell_layer_radio_buttons.value) def _cell_layer_radio_change(value): _update_cell_params_vbox(self._cell_params_out, - self.cell_pameters_widgets, + self.cell_parameters_widgets, self.cell_type_radio_buttons.value, value.new) @@ -673,7 +679,7 @@ def compose(self, return_layout=True): self._backend_config_out]), ], layout=self.layout['config_box']) - connectivity_configuration = Tab() + network_configuration = Tab() connectivity_box = VBox([ HBox([self.load_connectivity_button, ]), @@ -686,10 +692,14 @@ def compose(self, return_layout=True): self._cell_params_out ]) - connectivity_configuration.children = [connectivity_box, - cell_parameters] - connectivity_configuration.titles = ['Connectivity', - 'Cell parameters'] + syn_gain = VBox([self._syn_gain_out]) + + network_configuration.children = [connectivity_box, + cell_parameters, + syn_gain] + network_configuration.titles = ['Connectivity', + 'Cell parameters', + 'Synaptic gains'] drive_selections = VBox([ self.add_drive_button, self.widget_drive_type_selection, @@ -709,7 +719,7 @@ def compose(self, return_layout=True): # Tabs for left pane left_tab = Tab() left_tab.children = [ - simulation_box, connectivity_configuration, drives_options, + simulation_box, network_configuration, drives_options, config_panel, ] titles = ('Simulation', 'Network', 'External drives', @@ -902,9 +912,11 @@ def load_drive_and_connectivity(self): self._connectivity_out, self.connectivity_widgets, self._cell_params_out, - self.cell_pameters_widgets, + self.cell_parameters_widgets, self.cell_layer_radio_buttons, self.cell_type_radio_buttons, + self._syn_gain_out, + self.synaptic_gain_widgets, self.layout) # Add drives @@ -1034,9 +1046,11 @@ def on_upload_params_change(self, change, layout, load_type): if load_type == 'connectivity': add_connectivity_tab( params, self._connectivity_out, self.connectivity_widgets, - self._cell_params_out, self.cell_pameters_widgets, + self._cell_params_out, self.cell_parameters_widgets, self.cell_layer_radio_buttons, - self.cell_type_radio_buttons, layout) + self.cell_type_radio_buttons, + self._syn_gain_out, self.synaptic_gain_widgets, + layout) elif load_type == 'drives': self.add_drive_tab(params) else: @@ -1598,9 +1612,9 @@ def _build_drive_objects(drive_type, name, tstop_widget, layout, style, def add_connectivity_tab(params, connectivity_out, connectivity_textfields, - cell_params_out, cell_pameters_vboxes, + cell_params_out, cell_parameters_vboxes, cell_layer_radio_button, cell_type_radio_button, - layout): + syn_gain_out, syn_gain_textfields, layout): """Add all possible connectivity boxes to connectivity tab.""" net = dict_to_network(params) @@ -1609,9 +1623,13 @@ def add_connectivity_tab(params, connectivity_out, connectivity_textfields, connectivity_textfields) # build cell parameters tab - add_cell_parameters_tab(cell_params_out, cell_pameters_vboxes, + add_cell_parameters_tab(cell_params_out, cell_parameters_vboxes, cell_layer_radio_button, cell_type_radio_button, layout) + + # build synaptic gains tab + add_synaptic_gain_tab(net, syn_gain_out, syn_gain_textfields, layout) + return net @@ -1719,6 +1737,24 @@ def add_cell_parameters_tab(cell_params_out, cell_pameters_vboxes, cell_layer_radio_button.value) +def add_synaptic_gain_tab(net, syn_gain_out, syn_gain_textfields, layout): + """Creates widgets for global synaptic gains""" + gain_values = net.get_synaptic_gains() + gain_types = ('e_e', 'e_i', 'i_e', 'i_i') + for gain_type in gain_types: + gain_widget = BoundedFloatText( + value=gain_values[gain_type], + description=f'{gain_type}', + min=0, max=1e6, step=.1, + disabled=False, layout=layout) + syn_gain_textfields[gain_type] = gain_widget + + gain_vbox = VBox([widget for widget in syn_gain_textfields.values()]) + + with syn_gain_out: + display(gain_vbox) + + def get_cell_param_default_value(cell_type_key, param_dict): return param_dict[cell_type_key] @@ -1794,7 +1830,7 @@ def _drive_widget_to_dict(drive, name): def _init_network_from_widgets(params, dt, tstop, single_simulation_data, drive_widgets, connectivity_textfields, - cell_params_vboxes, + cell_params_vboxes, syn_gain_textfields, add_drive=True): """Construct network and add drives.""" print("init network") @@ -1819,7 +1855,6 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data, 'nc_dict']['A_weight'] = vbox_key.children[1].value # Update cell params - update_functions = { 'L2 Geometry': _update_L2_geometry_cell_params, 'L5 Geometry': _update_L5_geometry_cell_params, @@ -1842,6 +1877,11 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data, single_simulation_data['net'].cell_types[ cell_type]._compute_section_mechs() + # Update with synaptic gains + syn_gain_values = {key: widget.value + for key, widget in syn_gain_textfields.items()} + single_simulation_data['net'].set_synaptic_gains(**syn_gain_values) + if add_drive is False: return # add drives to network @@ -1914,7 +1954,7 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets, mpi_cmd, n_jobs, params, simulation_status_bar, simulation_status_contents, connectivity_textfields, viz_manager, simulations_list_widget, - cell_pameters_widgets): + cell_parameters_widgets, syn_gain_textfields): """Run the simulation and plot outputs.""" simulation_data = all_data["simulation_data"] with log_out: @@ -1933,7 +1973,8 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets, _init_network_from_widgets(params, dt, tstop, simulation_data[_sim_name], drive_widgets, connectivity_textfields, - cell_pameters_widgets) + cell_parameters_widgets, + syn_gain_textfields) print("start simulation") if backend_selection.value == "MPI": diff --git a/hnn_core/network.py b/hnn_core/network.py index c00de12b0..5ca3962a7 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -299,6 +299,40 @@ def pick_connection(net, src_gids=None, target_gids=None, return sorted(conn_set) +def _get_cell_index_by_synapse_type(net): + """Returns the indices of excitatory and inhibitory cells in the network. + + This function extracts the source GIDs (cell ID) of excitatory + and inhibitory cells based on their connection types. Excitatory cells are + identified by their synaptic connections using AMPA and NMDA receptors, + while inhibitory cells are identified by their connections using GABAA and + GABAB receptors. + + Parameters + ---------- + net : Instance of Network object + The Network object + + Returns + ------- + tuple: A tuple containing two lists: + - e_cells (list): The source GIDs of excitatory cells. + - i_cells (list): The source GIDs of inhibitory cells. + """ + + def list_src_gids(indices): + return np.concatenate([list(net.connectivity[conn_idx]['src_gids']) + for conn_idx in indices]).tolist() + + picks_e = pick_connection(net, receptor=['ampa', 'nmda']) + e_cells = list_src_gids(picks_e) + + picks_i = pick_connection(net, receptor=['gabaa', 'gabab']) + i_cells = list_src_gids(picks_i) + + return e_cells, i_cells + + class Network: """The Network class. @@ -1427,8 +1461,8 @@ def add_electrode_array(self, name, electrode_pos, *, conductivity=0.3, method=method, min_distance=min_distance)}) - def update_weights(self, e_e=None, e_i=None, - i_e=None, i_i=None, copy=False): + def set_synaptic_gains(self, e_e=None, e_i=None, + i_e=None, i_i=None, copy=False): """Update synaptic weights of the network. Parameters @@ -1466,13 +1500,7 @@ def update_weights(self, e_e=None, e_i=None, net = self.copy() if copy else self - e_conns = pick_connection(self, receptor=['ampa', 'nmda']) - e_cells = np.concatenate([list(net.connectivity[ - conn_idx]['src_gids']) for conn_idx in e_conns]).tolist() - - i_conns = pick_connection(self, receptor=['gabaa', 'gabab']) - i_cells = np.concatenate([list(net.connectivity[ - conn_idx]['src_gids']) for conn_idx in i_conns]).tolist() + e_cells, i_cells = _get_cell_index_by_synapse_type(net) conn_types = { 'e_e': (e_e, e_cells, e_cells), 'e_i': (e_i, e_cells, i_cells), @@ -1497,6 +1525,50 @@ def update_weights(self, e_e=None, e_i=None, if copy: return net + def get_synaptic_gains(self): + """Retrieve gain values for different connection types in the network. + + This function identifies excitatory and inhibitory cells in the network + and retrieves the gain value for each type of synaptic connection: + - excitatory to excitatory (e_e) + - excitatory to inhibitory (e_i) + - inhibitory to excitatory (i_e) + - inhibitory to inhibitory (i_i) + + The gain is assumed to be uniform within each connection type, and only + the first connection's gain value is used for each type. + + Returns + ------- + values : dict + A dictionary with the connection types ('e_e', 'e_i', 'i_e', + 'i_i') as keys and their corresponding gain values. + """ + values = {} + e_cells, i_cells = _get_cell_index_by_synapse_type(self) + + # Define the connection types and source/target cell indexes + conn_types = { + 'e_e': (e_cells, e_cells), + 'e_i': (e_cells, i_cells), + 'i_e': (i_cells, e_cells), + 'i_i': (i_cells, i_cells) + } + + # Retrieve the gain value for each connection type + for conn_type, (src_idxs, target_idxs) in conn_types.items(): + picks = pick_connection(self, + src_gids=src_idxs, + target_gids=target_idxs) + + if picks: + # Extract the gain from the first connection + values[conn_type] = ( + self.connectivity[picks[0]]['nc_dict']['gain'] + ) + + return values + def plot_cells(self, ax=None, show=True): """Plot the cells using Network.pos_dict. diff --git a/hnn_core/tests/test_gui.py b/hnn_core/tests/test_gui.py index 257caf70c..6a1f030b9 100644 --- a/hnn_core/tests/test_gui.py +++ b/hnn_core/tests/test_gui.py @@ -116,6 +116,8 @@ def test_gui_compose(): gui = HNNGUI() gui.compose() assert len(gui.connectivity_widgets) == 12 + assert len(gui.synaptic_gain_widgets) == 4 + assert len(gui.cell_parameters_widgets) == 6 assert len(gui.drive_widgets) == 3 plt.close('all') @@ -309,7 +311,8 @@ def test_gui_change_connectivity(): _single_simulation, gui.drive_widgets, gui.connectivity_widgets, - gui.cell_pameters_widgets, + gui.cell_parameters_widgets, + gui.synaptic_gain_widgets, add_drive=False) # test if the new value is reflected in the network @@ -348,7 +351,8 @@ def test_gui_init_network(setup_gui): _init_network_from_widgets(gui.params, gui.widget_dt, gui.widget_tstop, _single_simulation, gui.drive_widgets, gui.connectivity_widgets, - gui.cell_pameters_widgets) + gui.cell_parameters_widgets, + gui.synaptic_gain_widgets) plt.close('all') net_from_gui = _single_simulation['net'] @@ -939,7 +943,8 @@ def test_gui_add_tonic_input(): _init_network_from_widgets(gui.params, gui.widget_dt, gui.widget_tstop, _single_simulation, gui.drive_widgets, gui.connectivity_widgets, - gui.cell_pameters_widgets) + gui.cell_parameters_widgets, + gui.synaptic_gain_widgets) net = _single_simulation['net'] assert net.external_biases['tonic'] is not None @@ -966,7 +971,7 @@ def test_gui_cell_params_widgets(setup_gui): layers = gui.cell_layer_radio_buttons.options assert (len(layers) == 3) - keys = gui.cell_pameters_widgets.keys() + keys = gui.cell_parameters_widgets.keys() num_cell_params = 0 for pyramid_cell_type in pyramid_cell_types: cell_type = pyramid_cell_type.split('_')[0] @@ -1122,3 +1127,33 @@ def test_delete_single_drive(setup_gui): 'alpha_prox (proximal)', 'poisson (proximal)', 'tonic') + + +def test_adjust_synaptic_weights(setup_gui): + """Test adjusting synaptic weight widgets.""" + + gui = setup_gui + _single_simulation = {} + _single_simulation['net'] = dict_to_network(gui.params) + _init_network_from_widgets(gui.params, gui.widget_dt, gui.widget_tstop, + _single_simulation, gui.drive_widgets, + gui.connectivity_widgets, + gui.cell_parameters_widgets, + gui.synaptic_gain_widgets) + + gains_default = _single_simulation['net'].get_synaptic_gains() + assert gains_default == {'e_e': 1.0, 'e_i': 1.0, 'i_e': 1.0, 'i_i': 1.0} + + # Change the synaptic weight widgets + gui.synaptic_gain_widgets['e_e'].value = 0.5 + gui.synaptic_gain_widgets['e_i'].value = 0.5 + gui.synaptic_gain_widgets['i_i'].value = 1.1 + gui.synaptic_gain_widgets['i_e'].value = 1.1 + _init_network_from_widgets(gui.params, gui.widget_dt, gui.widget_tstop, + _single_simulation, gui.drive_widgets, + gui.connectivity_widgets, + gui.cell_parameters_widgets, + gui.synaptic_gain_widgets) + + gains_altered = _single_simulation['net'].get_synaptic_gains() + assert gains_altered == {'e_e': 0.5, 'e_i': 0.5, 'i_e': 1.1, 'i_i': 1.1} diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index 5b7d7a742..a0c170c39 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -895,8 +895,8 @@ def test_network_mesh(): _ = law_2021_model(mesh_shape=mesh_shape) -def test_synaptic_gains(): - """Test synaptic gains update""" +def test_set_synaptic_gains(): + """Test synaptic gains setter""" net = jones_2009_model() nb_base = NetworkBuilder(net) e_cell_names = ['L2_pyramidal', 'L5_pyramidal'] @@ -906,16 +906,16 @@ def test_synaptic_gains(): arg_names = ['e_e', 'e_i', 'i_e', 'i_i'] for arg in arg_names: with pytest.raises(TypeError, match='must be an instance of int or'): - net.update_weights(**{arg: 'abc'}) + net.set_synaptic_gains(**{arg: 'abc'}) with pytest.raises(ValueError, match='must be non-negative'): - net.update_weights(**{arg: -1}) + net.set_synaptic_gains(**{arg: -1}) with pytest.raises(TypeError, match='must be an instance of bool'): - net.update_weights(copy='True') + net.set_synaptic_gains(copy='True') # Single argument check with copy - net_updated = net.update_weights(e_e=2.0, copy=True) + net_updated = net.set_synaptic_gains(e_e=2.0, copy=True) for conn in net_updated.connectivity: if (conn['src_type'] in e_cell_names and conn['target_type'] in e_cell_names): @@ -927,7 +927,7 @@ def test_synaptic_gains(): assert conn['nc_dict']['gain'] == 1.0 # Single argument with inplace change - net.update_weights(i_e=0.5, copy=False) + net.set_synaptic_gains(i_e=0.5, copy=False) for conn in net.connectivity: if (conn['src_type'] in i_cell_names and conn['target_type'] in e_cell_names): @@ -936,7 +936,7 @@ def test_synaptic_gains(): assert conn['nc_dict']['gain'] == 1.0 # Two argument check - net.update_weights(i_e=0.5, i_i=0.25, copy=False) + net.set_synaptic_gains(i_e=0.5, i_i=0.25, copy=False) for conn in net.connectivity: if (conn['src_type'] in i_cell_names and conn['target_type'] in e_cell_names): @@ -963,6 +963,16 @@ def _get_weight(nb, conn_name, idx=0): _get_weight(nb_base, 'L2Pyr_L5Basket_ampa')) == 1 +def test_get_synaptic_gains(): + """Test synaptic gains getter.""" + net = jones_2009_model() + assert net.get_synaptic_gains() == {'e_e': 1.0, 'e_i': 1.0, + 'i_e': 1.0, 'i_i': 1.0} + new_gains = {'e_e': 0.5, 'e_i': 1.5, 'i_e': 0.75, 'i_i': 1.0} + net.set_synaptic_gains(**new_gains) + assert net.get_synaptic_gains() == new_gains + + class TestPickConnection: """Tests for the pick_connection function.""" @pytest.mark.parametrize("arg_name",