From 4b1cd08adfd1749db753a915badd568f127ed873 Mon Sep 17 00:00:00 2001 From: George Dang <53052793+gtdang@users.noreply.github.com> Date: Tue, 19 Mar 2024 09:11:55 -0400 Subject: [PATCH] refactor: black refactor of gui scripts --- hnn_core/gui/__init__.py | 2 +- hnn_core/gui/_viz_manager.py | 527 ++++++++------ hnn_core/gui/gui.py | 1328 +++++++++++++++++++++------------- 3 files changed, 1112 insertions(+), 745 deletions(-) diff --git a/hnn_core/gui/__init__.py b/hnn_core/gui/__init__.py index 3397d699b..3142afc2e 100644 --- a/hnn_core/gui/__init__.py +++ b/hnn_core/gui/__init__.py @@ -1 +1 @@ -from .gui import HNNGUI, launch \ No newline at end of file +from .gui import HNNGUI, launch diff --git a/hnn_core/gui/_viz_manager.py b/hnn_core/gui/_viz_manager.py index 4f9437454..c4e242197 100644 --- a/hnn_core/gui/_viz_manager.py +++ b/hnn_core/gui/_viz_manager.py @@ -10,34 +10,45 @@ import matplotlib.pyplot as plt import numpy as np from IPython.display import display -from ipywidgets import (Box, Button, Dropdown, FloatText, HBox, Label, Layout, - Output, Tab, VBox, link) - -from hnn_core.dipole import average_dipoles, _rmse +from ipywidgets import ( + Box, + Button, + Dropdown, + FloatText, + HBox, + Label, + Layout, + Output, + Tab, + VBox, + link, +) + +from hnn_core.dipole import _rmse, average_dipoles from hnn_core.gui._logging import logger from hnn_core.viz import plot_dipole -_fig_placeholder = 'Run simulation to add figures here.' +_fig_placeholder = "Run simulation to add figures here." _plot_types = [ - 'current dipole', - 'layer2 dipole', - 'layer5 dipole', - 'input histogram', - 'spikes', - 'PSD', - 'spectrogram', - 'network', + "current dipole", + "layer2 dipole", + "layer5 dipole", + "input histogram", + "spikes", + "PSD", + "spectrogram", + "network", ] _no_overlay_plot_types = [ - 'network', - 'spectrogram', - 'spikes', - 'input histogram', + "network", + "spectrogram", + "spikes", + "input histogram", ] -_ext_data_disabled_plot_types = ['spikes', 'input histogram', 'network'] +_ext_data_disabled_plot_types = ["spikes", "input histogram", "network"] _spectrogram_color_maps = [ "viridis", @@ -49,15 +60,15 @@ fig_templates = { "2row x 1col (1:3)": { - "kwargs": "gridspec_kw={\"height_ratios\":[1,3]}", + "kwargs": 'gridspec_kw={"height_ratios":[1,3]}', "mosaic": "00\n11", }, "2row x 1col (1:1)": { - "kwargs": "gridspec_kw={\"height_ratios\":[1,1]}", + "kwargs": 'gridspec_kw={"height_ratios":[1,1]}', "mosaic": "00\n11", }, "1row x 2col (1:1)": { - "kwargs": "gridspec_kw={\"height_ratios\":[1,1]}", + "kwargs": 'gridspec_kw={"height_ratios":[1,1]}', "mosaic": "01\n01", }, "single figure": { @@ -65,15 +76,14 @@ "mosaic": "00\n00", }, "2row x 2col (1:1)": { - "kwargs": "gridspec_kw={\"height_ratios\":[1,1]}", + "kwargs": 'gridspec_kw={"height_ratios":[1,1]}', "mosaic": "01\n23", }, } -def check_sim_plot_types( - new_sim_name, plot_type_selection, target_selection, data): - if data["simulations"][new_sim_name.new]['net'] is None: +def check_sim_plot_types(new_sim_name, plot_type_selection, target_selection, data): + if data["simulations"][new_sim_name.new]["net"] is None: plot_type_selection.options = [ pt for pt in _plot_types if pt not in _ext_data_disabled_plot_types ] @@ -82,18 +92,17 @@ def check_sim_plot_types( # deal with target data all_possible_targets = list(data["simulations"].keys()) all_possible_targets.remove(new_sim_name.new) - target_selection.options = all_possible_targets + ['None'] - target_selection.value = 'None' + target_selection.options = all_possible_targets + ["None"] + target_selection.value = "None" def target_comparison_change(new_target_name, simulation_selection, data): - """Triggered when the target data is turned on or changed. - """ + """Triggered when the target data is turned on or changed.""" pass def plot_type_coupled_change(new_plot_type, target_data_selection): - if new_plot_type != 'current dipole': + if new_plot_type != "current dipole": target_data_selection.disabled = True else: target_data_selection.disabled = False @@ -113,6 +122,7 @@ def unlink_relink(attribute): widgets """ + def _unlink_relink(f): @wraps(f) def wrapper(self, *args, **kwargs): @@ -127,7 +137,9 @@ def wrapper(self, *args, **kwargs): link_attribute.link() return result + return wrapper + return _unlink_relink @@ -156,14 +168,15 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): A dict that specifies the preprocessing and style of plots. """ # Make sure that visualization does not change the original data - dpls_copied = copy.deepcopy(single_simulation['dpls']) - net_copied = copy.deepcopy(single_simulation['net']) + dpls_copied = copy.deepcopy(single_simulation["dpls"]) + net_copied = copy.deepcopy(single_simulation["net"]) for dpl in dpls_copied: - if plot_config['dipole_smooth'] > 0: - dpl.smooth(plot_config['dipole_smooth']).scale( - plot_config['dipole_scaling']) + if plot_config["dipole_smooth"] > 0: + dpl.smooth(plot_config["dipole_smooth"]).scale( + plot_config["dipole_scaling"] + ) else: - dpl.scale(plot_config['dipole_scaling']) + dpl.scale(plot_config["dipole_scaling"]) if net_copied is None: assert plot_type not in _ext_data_disabled_plot_types @@ -172,35 +185,38 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): # x and y axis are hidden after plotting some functions. ax.get_yaxis().set_visible(True) ax.get_xaxis().set_visible(True) - if plot_type == 'spikes': + if plot_type == "spikes": if net_copied.cell_response: net_copied.cell_response.plot_spikes_raster(ax=ax, show=False) - elif plot_type == 'input histogram': + elif plot_type == "input histogram": if net_copied.cell_response: net_copied.cell_response.plot_spikes_hist(ax=ax, show=False) - elif plot_type == 'PSD': + elif plot_type == "PSD": if len(dpls_copied) > 0: color = ax._get_lines.get_next_color() - dpls_copied[0].plot_psd(fmin=0, fmax=50, ax=ax, color=color, - label=sim_name, show=False) + dpls_copied[0].plot_psd( + fmin=0, fmax=50, ax=ax, color=color, label=sim_name, show=False + ) - elif plot_type == 'spectrogram': + elif plot_type == "spectrogram": if len(dpls_copied) > 0: min_f = 10.0 - max_f = plot_config['max_spectral_frequency'] + max_f = plot_config["max_spectral_frequency"] step_f = 1.0 freqs = np.arange(min_f, max_f, step_f) - n_cycles = freqs / 8. + n_cycles = freqs / 8.0 dpls_copied[0].plot_tfr_morlet( freqs, n_cycles=n_cycles, - colormap=plot_config['spectrogram_cm'], - ax=ax, colorbar_inside=True, - show=False) + colormap=plot_config["spectrogram_cm"], + ax=ax, + colorbar_inside=True, + show=False, + ) - elif 'dipole' in plot_type: + elif "dipole" in plot_type: if len(dpls_copied) > 0: if len(dpls_copied) > 1: label = f"{sim_name}: average" @@ -208,46 +224,50 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): label = sim_name color = ax._get_lines.get_next_color() - if plot_type == 'current dipole': - plot_dipole(dpls_copied, - ax=ax, - label=label, - color=color, - average=True, - show=False) + if plot_type == "current dipole": + plot_dipole( + dpls_copied, + ax=ax, + label=label, + color=color, + average=True, + show=False, + ) else: layer_namemap = { "layer2": "L2", "layer5": "L5", } - plot_dipole(dpls_copied, - ax=ax, - label=label, - color=color, - layer=layer_namemap[plot_type.split(" ")[0]], - average=True, - show=False) + plot_dipole( + dpls_copied, + ax=ax, + label=label, + color=color, + layer=layer_namemap[plot_type.split(" ")[0]], + average=True, + show=False, + ) else: print("No dipole data") - elif plot_type == 'network': + elif plot_type == "network": if net_copied: with plt.ioff(): _fig = plt.figure() - _ax = _fig.add_subplot(111, projection='3d') + _ax = _fig.add_subplot(111, projection="3d") net_copied.plot_cells(ax=_ax, show=False) io_buf = io.BytesIO() - _fig.savefig(io_buf, format='raw') + _fig.savefig(io_buf, format="raw") io_buf.seek(0) - img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), - dtype=np.uint8), - newshape=(int(_fig.bbox.bounds[3]), - int(_fig.bbox.bounds[2]), -1)) + img_arr = np.reshape( + np.frombuffer(io_buf.getvalue(), dtype=np.uint8), + newshape=(int(_fig.bbox.bounds[3]), int(_fig.bbox.bounds[2]), -1), + ) io_buf.close() _ = ax.imshow(img_arr) # set up alignment - if plot_type not in ['network', 'PSD']: + if plot_type not in ["network", "PSD"]: margin_x = 0 max_x = max([dpl.times[-1] for dpl in dpls_copied]) ax.set_xlim(left=-margin_x, right=max_x + margin_x) @@ -256,11 +276,11 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): def _static_rerender(widgets, fig, fig_idx): - logger.debug('_static_re_render is called') - figs_tabs = widgets['figs_tabs'] + logger.debug("_static_re_render is called") + figs_tabs = widgets["figs_tabs"] titles = figs_tabs.titles fig_tab_idx = titles.index(_idx2figname(fig_idx)) - fig_output = widgets['figs_tabs'].children[fig_tab_idx] + fig_output = widgets["figs_tabs"].children[fig_tab_idx] fig_output.clear_output() with fig_output: fig.tight_layout() @@ -273,11 +293,22 @@ def _dynamic_rerender(fig): fig.tight_layout() -def _plot_on_axes(b, widgets_simulation, widgets_plot_type, - target_simulations, - spectrogram_colormap_selection, dipole_smooth, - max_spectral_frequency, dipole_scaling, widgets, data, - fig_idx, fig, ax, existing_plots): +def _plot_on_axes( + b, + widgets_simulation, + widgets_plot_type, + target_simulations, + spectrogram_colormap_selection, + dipole_smooth, + max_spectral_frequency, + dipole_scaling, + widgets, + data, + fig_idx, + fig, + ax, + existing_plots, +): """Plotting different types of data on the given axes. Now this function is also responsible for comparing multiple simulations, @@ -323,34 +354,39 @@ def _plot_on_axes(b, widgets_simulation, widgets_plot_type, # freeze plot type widgets_plot_type.disabled = True - single_simulation = data['simulations'][sim_name] + single_simulation = data["simulations"][sim_name] plot_config = { "max_spectral_frequency": max_spectral_frequency.value, "dipole_scaling": dipole_scaling.value, "dipole_smooth": dipole_smooth.value, - "spectrogram_cm": spectrogram_colormap_selection.value + "spectrogram_cm": spectrogram_colormap_selection.value, } - dpls_processed = _update_ax(fig, ax, single_simulation, sim_name, - plot_type, plot_config) + dpls_processed = _update_ax( + fig, ax, single_simulation, sim_name, plot_type, plot_config + ) # If target_simulations is not None and we are plotting a dipole, # we need to plot the target dipole as well. - if target_simulations.value in data['simulations'].keys( - ) and plot_type == 'current dipole': + if ( + target_simulations.value in data["simulations"].keys() + and plot_type == "current dipole" + ): target_sim_name = target_simulations.value - target_sim = data['simulations'][target_sim_name] + target_sim = data["simulations"][target_sim_name] # plot the target dipole. # disable scaling for the target dipole. - plot_config['dipole_scaling'] = 1. + plot_config["dipole_scaling"] = 1.0 # plot the target dipole. target_dpl_processed = _update_ax( - fig, ax, target_sim, target_sim_name, plot_type, - plot_config)[0] # we assume there is only one dipole. + fig, ax, target_sim, target_sim_name, plot_type, plot_config + )[ + 0 + ] # we assume there is only one dipole. # calculate the RMSE between the two dipoles. t0 = 0.0 @@ -361,45 +397,58 @@ def _plot_on_axes(b, widgets_simulation, widgets_plot_type, dpl = dpls_processed rmse = _rmse(dpl, target_dpl_processed, t0, tstop) # Show the RMSE between the two dipoles. - ax.annotate(f'RMSE({sim_name}, {target_sim_name}): {rmse:.4f}', - xy=(0.95, 0.05), - xycoords='axes fraction', - horizontalalignment='right', - verticalalignment='bottom', - fontsize=12) - - existing_plots.children = (*existing_plots.children, - Label(f"{sim_name}: {plot_type}")) - if data['use_ipympl'] is False: + ax.annotate( + f"RMSE({sim_name}, {target_sim_name}): {rmse:.4f}", + xy=(0.95, 0.05), + xycoords="axes fraction", + horizontalalignment="right", + verticalalignment="bottom", + fontsize=12, + ) + + existing_plots.children = ( + *existing_plots.children, + Label(f"{sim_name}: {plot_type}"), + ) + if data["use_ipympl"] is False: _static_rerender(widgets, fig, fig_idx) else: _dynamic_rerender(fig) -def _clear_axis(b, widgets, data, fig_idx, fig, ax, widgets_plot_type, - existing_plots, add_plot_button): +def _clear_axis( + b, + widgets, + data, + fig_idx, + fig, + ax, + widgets_plot_type, + existing_plots, + add_plot_button, +): ax.clear() # remove attached colorbar if exists - if hasattr(fig, f'_cbar-ax-{id(ax)}'): - getattr(fig, f'_cbar-ax-{id(ax)}').ax.remove() - delattr(fig, f'_cbar-ax-{id(ax)}') + if hasattr(fig, f"_cbar-ax-{id(ax)}"): + getattr(fig, f"_cbar-ax-{id(ax)}").ax.remove() + delattr(fig, f"_cbar-ax-{id(ax)}") - ax.set_facecolor('w') - ax.set_aspect('auto') + ax.set_facecolor("w") + ax.set_aspect("auto") widgets_plot_type.disabled = False add_plot_button.disabled = False existing_plots.children = () - if data['use_ipympl'] is False: + if data["use_ipympl"] is False: _static_rerender(widgets, fig, fig_idx) else: _dynamic_rerender(fig) def _get_ax_control(widgets, data, fig_idx, fig, ax): - analysis_style = {'description_width': '200px'} + analysis_style = {"description_width": "200px"} layout = Layout(width="98%") - simulation_names = tuple(data['simulations'].keys()) + simulation_names = tuple(data["simulations"].keys()) sim_name_default = simulation_names[-1] if len(simulation_names) == 0: simulation_names = [ @@ -409,13 +458,13 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax): simulation_selection = Dropdown( options=simulation_names, value=sim_name_default, - description='Simulation Data:', + description="Simulation Data:", disabled=False, layout=layout, style=analysis_style, ) - if data['simulations'][sim_name_default]['net'] is None: + if data["simulations"][sim_name_default]["net"] is None: valid_plot_types = [ pt for pt in _plot_types if pt not in _ext_data_disabled_plot_types ] @@ -425,65 +474,70 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax): plot_type_selection = Dropdown( options=valid_plot_types, value=valid_plot_types[0], - description='Type:', + description="Type:", disabled=False, layout=layout, style=analysis_style, ) target_data_selection = Dropdown( - options=simulation_names[:-1] + ('None',), - value='None', - description='Data to Compare:', + options=simulation_names[:-1] + ("None",), + value="None", + description="Data to Compare:", disabled=False, layout=layout, style=analysis_style, ) spectrogram_colormap_selection = Dropdown( - description='Spectrogram Colormap:', + description="Spectrogram Colormap:", options=[(cm, cm) for cm in _spectrogram_color_maps], value=_spectrogram_color_maps[0], layout=layout, style=analysis_style, ) - dipole_smooth = FloatText(value=30, - description='Dipole Smooth Window (ms):', - disabled=False, - layout=layout, - style=analysis_style) - dipole_scaling = FloatText(value=3000, - description='Dipole Scaling:', - disabled=False, - layout=layout, - style=analysis_style) + dipole_smooth = FloatText( + value=30, + description="Dipole Smooth Window (ms):", + disabled=False, + layout=layout, + style=analysis_style, + ) + dipole_scaling = FloatText( + value=3000, + description="Dipole Scaling:", + disabled=False, + layout=layout, + style=analysis_style, + ) max_spectral_frequency = FloatText( value=100, - description='Max Spectral Frequency (Hz):', + description="Max Spectral Frequency (Hz):", disabled=False, layout=layout, - style=analysis_style) + style=analysis_style, + ) existing_plots = VBox([]) - plot_button = Button(description='Add plot') - clear_button = Button(description='Clear axis') + plot_button = Button(description="Add plot") + clear_button = Button(description="Clear axis") def _on_sim_data_change(new_sim_name): return check_sim_plot_types( - new_sim_name, plot_type_selection, target_data_selection, data) + new_sim_name, plot_type_selection, target_data_selection, data + ) def _on_target_comparison_change(new_target_name): - return target_comparison_change(new_target_name, simulation_selection, - data) + return target_comparison_change(new_target_name, simulation_selection, data) def _on_plot_type_change(new_plot_type): return plot_type_coupled_change(new_plot_type, target_data_selection) - simulation_selection.observe(_on_sim_data_change, 'value') - target_data_selection.observe(_on_target_comparison_change, 'value') - plot_type_selection.observe(_on_plot_type_change, 'value') + simulation_selection.observe(_on_sim_data_change, "value") + target_data_selection.observe(_on_target_comparison_change, "value") + plot_type_selection.observe(_on_plot_type_change, "value") clear_button.on_click( partial( @@ -496,7 +550,8 @@ def _on_plot_type_change(new_plot_type): widgets_plot_type=plot_type_selection, existing_plots=existing_plots, add_plot_button=plot_button, - )) + ) + ) plot_button.on_click( partial( @@ -514,22 +569,32 @@ def _on_plot_type_change(new_plot_type): fig=fig, ax=ax, existing_plots=existing_plots, - )) + ) + ) - vbox = VBox([ - simulation_selection, plot_type_selection, target_data_selection, - dipole_smooth, dipole_scaling, max_spectral_frequency, - spectrogram_colormap_selection, - HBox( - [plot_button, clear_button], - layout=Layout(justify_content='space-between'), - ), existing_plots], layout=Layout(width="98%")) + vbox = VBox( + [ + simulation_selection, + plot_type_selection, + target_data_selection, + dipole_smooth, + dipole_scaling, + max_spectral_frequency, + spectrogram_colormap_selection, + HBox( + [plot_button, clear_button], + layout=Layout(justify_content="space-between"), + ), + existing_plots, + ], + layout=Layout(width="98%"), + ) return vbox def _close_figure(b, widgets, data, fig_idx): - fig_related_widgets = [widgets['figs_tabs'], widgets['axes_config_tabs']] + fig_related_widgets = [widgets["figs_tabs"], widgets["axes_config_tabs"]] for w_idx, tab in enumerate(fig_related_widgets): # Get tab object's list of children and their titles tab_children = list(tab.children) @@ -547,27 +612,27 @@ def _close_figure(b, widgets, data, fig_idx): # If the figure tab group... if w_idx == 0: # Close figure and delete the data - plt.close(data['figs'][fig_idx]) - data['figs'].pop(fig_idx) + plt.close(data["figs"][fig_idx]) + data["figs"].pop(fig_idx) # Redisplay the remaining children n_tabs = len(tab.children) for idx in range(n_tabs): _fig_idx = _figname2idx(tab.get_title(idx)) - assert _fig_idx in data['figs'].keys() + assert _fig_idx in data["figs"].keys() tab.children[idx].clear_output() with tab.children[idx]: - display(data['figs'][_fig_idx].canvas) + display(data["figs"][_fig_idx].canvas) # If all children have been deleted display the placeholder if n_tabs == 0: - widgets['figs_output'].clear_output() - with widgets['figs_output']: + widgets["figs_output"].clear_output() + with widgets["figs_output"]: display(Label(_fig_placeholder)) def _add_axes_controls(widgets, data, fig, axd): - fig_idx = data['fig_idx']['idx'] + fig_idx = data["fig_idx"]["idx"] controls = Tab() children = [ @@ -576,63 +641,66 @@ def _add_axes_controls(widgets, data, fig, axd): ] controls.children = children for i in range(len(children)): - controls.set_title(i, f'ax{i}') + controls.set_title(i, f"ax{i}") - close_fig_button = Button(description=f'Close {_idx2figname(fig_idx)}', - button_style='danger', - icon='close', - layout=Layout(width="98%")) + close_fig_button = Button( + description=f"Close {_idx2figname(fig_idx)}", + button_style="danger", + icon="close", + layout=Layout(width="98%"), + ) close_fig_button.on_click( - partial(_close_figure, widgets=widgets, data=data, fig_idx=fig_idx)) + partial(_close_figure, widgets=widgets, data=data, fig_idx=fig_idx) + ) - n_tabs = len(widgets['axes_config_tabs'].children) - widgets['axes_config_tabs'].children = widgets[ - 'axes_config_tabs'].children + (VBox([close_fig_button, controls]), ) - widgets['axes_config_tabs'].set_title(n_tabs, _idx2figname(fig_idx)) + n_tabs = len(widgets["axes_config_tabs"].children) + widgets["axes_config_tabs"].children = widgets["axes_config_tabs"].children + ( + VBox([close_fig_button, controls]), + ) + widgets["axes_config_tabs"].set_title(n_tabs, _idx2figname(fig_idx)) def _add_figure(b, widgets, data, scale=0.95, dpi=96): - template_name = widgets['templates_dropdown'].value - fig_idx = data['fig_idx']['idx'] - viz_output_layout = data['visualization_output'] + template_name = widgets["templates_dropdown"].value + fig_idx = data["fig_idx"]["idx"] + viz_output_layout = data["visualization_output"] fig_outputs = Output() - n_tabs = len(widgets['figs_tabs'].children) + n_tabs = len(widgets["figs_tabs"].children) if n_tabs == 0: - widgets['figs_output'].clear_output() - with widgets['figs_output']: - display(widgets['figs_tabs']) + widgets["figs_output"].clear_output() + with widgets["figs_output"]: + display(widgets["figs_tabs"]) - widgets['figs_tabs'].children = ( - [s for s in widgets['figs_tabs'].children] + [fig_outputs] - ) - widgets['figs_tabs'].set_title(n_tabs, _idx2figname(fig_idx)) + widgets["figs_tabs"].children = [s for s in widgets["figs_tabs"].children] + [ + fig_outputs + ] + widgets["figs_tabs"].set_title(n_tabs, _idx2figname(fig_idx)) with fig_outputs: - figsize = (scale * ((int(viz_output_layout.width[:-2]) - 10) / dpi), - scale * ((int(viz_output_layout.height[:-2]) - 10) / dpi)) - mosaic = fig_templates[template_name]['mosaic'] + figsize = ( + scale * ((int(viz_output_layout.width[:-2]) - 10) / dpi), + scale * ((int(viz_output_layout.height[:-2]) - 10) / dpi), + ) + mosaic = fig_templates[template_name]["mosaic"] kwargs = eval(f"dict({fig_templates[template_name]['kwargs']})") plt.ioff() - fig, axd = plt.subplot_mosaic(mosaic, - figsize=figsize, - dpi=dpi, - **kwargs) + fig, axd = plt.subplot_mosaic(mosaic, figsize=figsize, dpi=dpi, **kwargs) plt.ion() fig.tight_layout() fig.canvas.header_visible = False fig.canvas.footer_visible = False - if data['use_ipympl'] is False: + if data["use_ipympl"] is False: plt.show() else: display(fig.canvas) _add_axes_controls(widgets, data, fig=fig, axd=axd) - data['figs'][fig_idx] = fig - widgets['figs_tabs'].selected_index = n_tabs - data['fig_idx']['idx'] += 1 + data["figs"][fig_idx] = fig + widgets["figs_tabs"].selected_index = n_tabs + data["fig_idx"]["idx"] += 1 class _VizManager: @@ -656,7 +724,7 @@ class _VizManager: def __init__(self, gui_data, viz_layout): plt.close("all") self.viz_layout = viz_layout - self.use_ipympl = 'ipympl' in matplotlib.get_backend() + self.use_ipympl = "ipympl" in matplotlib.get_backend() self.axes_config_output = Output() self.figs_output = Output() @@ -667,22 +735,24 @@ def __init__(self, gui_data, viz_layout): self.axes_config_tabs.selected_index = None self.figs_tabs.selected_index = None self.figs_config_tab_link = link( - (self.axes_config_tabs, 'selected_index'), - (self.figs_tabs, 'selected_index'), + (self.axes_config_tabs, "selected_index"), + (self.figs_tabs, "selected_index"), ) template_names = list(fig_templates.keys()) self.templates_dropdown = Dropdown( - description='Layout template:', + description="Layout template:", options=template_names, value=template_names[0], - style={'description_width': 'initial'}, - layout=Layout(width="98%")) + style={"description_width": "initial"}, + layout=Layout(width="98%"), + ) self.make_fig_button = Button( - description='Make figure', + description="Make figure", button_style="primary", - style={'button_color': self.viz_layout['theme_color']}, - layout=self.viz_layout['btn']) + style={"button_color": self.viz_layout["theme_color"]}, + layout=self.viz_layout["btn"], + ) self.make_fig_button.on_click(self.add_figure) # data @@ -696,7 +766,7 @@ def widgets(self): "figs_output": self.figs_output, "axes_config_tabs": self.axes_config_tabs, "figs_tabs": self.figs_tabs, - "templates_dropdown": self.templates_dropdown + "templates_dropdown": self.templates_dropdown, } @property @@ -706,13 +776,13 @@ def data(self): "use_ipympl": self.use_ipympl, "simulations": self.gui_data["simulation_data"], "fig_idx": self.fig_idx, - "visualization_output": self.viz_layout['visualization_output'], - "figs": self.figs + "visualization_output": self.viz_layout["visualization_output"], + "figs": self.figs, } def reset_fig_config_tabs(self, template_name=None): """Reset the figure config tabs with most recent simulation data.""" - simulation_names = tuple(self.data['simulations'].keys()) + simulation_names = tuple(self.data["simulations"].keys()) for tab in self.axes_config_tabs.children: controls = tab.children[1] for ax_control in controls.children: @@ -731,34 +801,34 @@ def compose(self): display(Label(_fig_placeholder)) fig_output_container = VBox( - [self.figs_output], layout=self.viz_layout['visualization_window']) - - config_panel = VBox([ - Box( - [ - self.templates_dropdown, - self.make_fig_button, - ], - layout=Layout( - display='flex', - flex_flow='column', - align_items='stretch', + [self.figs_output], layout=self.viz_layout["visualization_window"] + ) + + config_panel = VBox( + [ + Box( + [ + self.templates_dropdown, + self.make_fig_button, + ], + layout=Layout( + display="flex", + flex_flow="column", + align_items="stretch", + ), ), - ), - Label("Figure config:"), - self.axes_config_output, - ]) + Label("Figure config:"), + self.axes_config_output, + ] + ) return config_panel, fig_output_container - @unlink_relink(attribute='figs_config_tab_link') + @unlink_relink(attribute="figs_config_tab_link") def add_figure(self, b=None): - """Add a figure and corresponding config tabs to the dashboard. - """ - _add_figure(None, - self.widgets, - self.data, - scale=0.97, - dpi=self.viz_layout['dpi']) + """Add a figure and corresponding config tabs to the dashboard.""" + _add_figure( + None, self.widgets, self.data, scale=0.97, dpi=self.viz_layout["dpi"] + ) def _simulate_add_fig(self): self.make_fig_button.click() @@ -777,8 +847,15 @@ def _simulate_delete_figure(self, fig_name): close_button = self.axes_config_tabs.children[tab_idx].children[0] close_button.click() - def _simulate_edit_figure(self, fig_name, ax_name, simulation_name, - plot_type, preprocessing_config, operation): + def _simulate_edit_figure( + self, + fig_name, + ax_name, + simulation_name, + plot_type, + preprocessing_config, + operation, + ): """Manipulate a certain figure. Parameters @@ -799,7 +876,7 @@ def _simulate_edit_figure(self, fig_name, ax_name, simulation_name, `"plot"` if you want to plot and `"clear"` if you want to remove previously plotted visualizations. """ - assert simulation_name in self.data['simulations'].keys() + assert simulation_name in self.data["simulations"].keys() assert plot_type in _plot_types assert operation in ("plot", "clear") diff --git a/hnn_core/gui/gui.py b/hnn_core/gui/gui.py index 70622d1d8..f0bc3307f 100644 --- a/hnn_core/gui/gui.py +++ b/hnn_core/gui/gui.py @@ -10,21 +10,46 @@ import urllib.parse import urllib.request from collections import defaultdict -from pathlib import Path from datetime import datetime +from pathlib import Path + from IPython.display import IFrame, display -from ipywidgets import (HTML, Accordion, AppLayout, BoundedFloatText, - BoundedIntText, Button, Dropdown, FileUpload, VBox, - HBox, IntText, Layout, Output, RadioButtons, Tab, Text) +from ipywidgets import ( + HTML, + Accordion, + AppLayout, + BoundedFloatText, + BoundedIntText, + Button, + Dropdown, + FileUpload, + HBox, + IntText, + Layout, + Output, + RadioButtons, + Tab, + Text, + VBox, +) from ipywidgets.embed import embed_minimal_html + import hnn_core -from hnn_core import (JoblibBackend, MPIBackend, jones_2009_model, read_params, - simulate_dipole) +from hnn_core import ( + JoblibBackend, + MPIBackend, + jones_2009_model, + read_params, + simulate_dipole, +) from hnn_core.gui._logging import logger -from hnn_core.gui._viz_manager import _VizManager, _idx2figname +from hnn_core.gui._viz_manager import _idx2figname, _VizManager from hnn_core.network import pick_connection -from hnn_core.params import (_extract_drive_specs_from_hnn_params, _read_json, - _read_legacy_params) +from hnn_core.params import ( + _extract_drive_specs_from_hnn_params, + _read_json, + _read_legacy_params, +) class _OutputWidgetHandler(logging.Handler): @@ -35,11 +60,11 @@ def __init__(self, output_widget, *args, **kwargs): def emit(self, record): formatted_record = self.format(record) new_output = { - 'name': 'stdout', - 'output_type': 'stream', - 'text': formatted_record + '\n' + "name": "stdout", + "output_type": "stream", + "text": formatted_record + "\n", } - self.out.outputs = (new_output, ) + self.out.outputs + self.out.outputs = (new_output,) + self.out.outputs class HNNGUI: @@ -120,18 +145,20 @@ class HNNGUI: in the network. """ - def __init__(self, theme_color="#8A2BE2", - total_height=800, - total_width=1300, - header_height=50, - button_height=30, - operation_box_height=60, - drive_widget_width=200, - left_sidebar_width=576, - log_window_height=150, - status_height=30, - dpi=96, - ): + def __init__( + self, + theme_color="#8A2BE2", + total_height=800, + total_width=1300, + header_height=50, + button_height=30, + operation_box_height=60, + drive_widget_width=200, + left_sidebar_width=576, + log_window_height=150, + status_height=30, + dpi=96, + ): # set up styling. self.total_height = total_height self.total_width = total_width @@ -139,41 +166,50 @@ def __init__(self, theme_color="#8A2BE2", viz_win_width = self.total_width - left_sidebar_width main_content_height = self.total_height - status_height - config_box_height = main_content_height - (log_window_height + - operation_box_height) + config_box_height = main_content_height - ( + log_window_height + operation_box_height + ) self.layout = { "dpi": dpi, "header_height": f"{header_height}px", "theme_color": theme_color, - "btn": Layout(height=f"{button_height}px", width='auto'), - "btn_full_w": Layout(height=f"{button_height}px", width='100%'), - "del_fig_btn": Layout(height=f"{button_height}px", width='auto'), - "log_out": Layout(border='1px solid gray', - height=f"{log_window_height-10}px", - overflow='auto'), - "viz_config": Layout(width='99%'), + "btn": Layout(height=f"{button_height}px", width="auto"), + "btn_full_w": Layout(height=f"{button_height}px", width="100%"), + "del_fig_btn": Layout(height=f"{button_height}px", width="auto"), + "log_out": Layout( + border="1px solid gray", + height=f"{log_window_height-10}px", + overflow="auto", + ), + "viz_config": Layout(width="99%"), "visualization_window": Layout( width=f"{viz_win_width-10}px", height=f"{main_content_height-10}px", - border='1px solid gray', - overflow='scroll'), + border="1px solid gray", + overflow="scroll", + ), "visualization_output": Layout( width=f"{viz_win_width-50}px", height=f"{main_content_height-100}px", - border='1px solid gray', - overflow='scroll'), - "left_sidebar": Layout(width=f"{left_sidebar_width}px", - height=f"{main_content_height}px"), - "left_tab": Layout(width=f"{left_sidebar_width}px", - height=f"{config_box_height}px"), - "operation_box": Layout(width=f"{left_sidebar_width}px", - height=f"{operation_box_height}px", - flex_wrap="wrap", - ), - "config_box": Layout(width=f"{left_sidebar_width}px", - height=f"{config_box_height-100}px"), + border="1px solid gray", + overflow="scroll", + ), + "left_sidebar": Layout( + width=f"{left_sidebar_width}px", height=f"{main_content_height}px" + ), + "left_tab": Layout( + width=f"{left_sidebar_width}px", height=f"{config_box_height}px" + ), + "operation_box": Layout( + width=f"{left_sidebar_width}px", + height=f"{operation_box_height}px", + flex_wrap="wrap", + ), + "config_box": Layout( + width=f"{left_sidebar_width}px", height=f"{config_box_height-100}px" + ), "drive_widget": Layout(width="auto"), - "drive_textbox": Layout(width='270px', height='auto'), + "drive_textbox": Layout(width="270px", height="auto"), # simulation status related "simulation_status_height": f"{status_height}px", "simulation_status_common": "background:gray;padding-left:10px", @@ -183,17 +219,13 @@ def __init__(self, theme_color="#8A2BE2", } self._simulation_status_contents = { - "not_running": - f"""
Not running
""", - "running": - f"""
Running...
""", - "finished": - f"""
Simulation finished
""", - "failed": - f"""
Simulation failed
""", } @@ -205,68 +237,102 @@ def __init__(self, theme_color="#8A2BE2", # Simulation parameters self.widget_tstop = BoundedFloatText( - value=170, description='tstop (ms):', min=0, max=1e6, step=1, - disabled=False) + value=170, description="tstop (ms):", min=0, max=1e6, step=1, disabled=False + ) self.widget_dt = BoundedFloatText( - value=0.025, description='dt (ms):', min=0, max=10, step=0.01, - disabled=False) - self.widget_ntrials = IntText(value=1, description='Trials:', - disabled=False) - self.widget_simulation_name = Text(value='default', - placeholder='ID of your simulation', - description='Name:', - disabled=False) - self.widget_backend_selection = Dropdown(options=[('Joblib', 'Joblib'), - ('MPI', 'MPI')], - value='Joblib', - description='Backend:') - self.widget_mpi_cmd = Text(value='mpiexec', - placeholder='Fill if applies', - description='MPI cmd:', disabled=False) - self.widget_n_jobs = BoundedIntText(value=1, min=1, - max=multiprocessing.cpu_count(), - description='Cores:', - disabled=False) + value=0.025, + description="dt (ms):", + min=0, + max=10, + step=0.01, + disabled=False, + ) + self.widget_ntrials = IntText(value=1, description="Trials:", disabled=False) + self.widget_simulation_name = Text( + value="default", + placeholder="ID of your simulation", + description="Name:", + disabled=False, + ) + self.widget_backend_selection = Dropdown( + options=[("Joblib", "Joblib"), ("MPI", "MPI")], + value="Joblib", + description="Backend:", + ) + self.widget_mpi_cmd = Text( + value="mpiexec", + placeholder="Fill if applies", + description="MPI cmd:", + disabled=False, + ) + self.widget_n_jobs = BoundedIntText( + value=1, + min=1, + max=multiprocessing.cpu_count(), + description="Cores:", + disabled=False, + ) self.load_data_button = FileUpload( - accept='.txt', multiple=False, - style={'button_color': self.layout['theme_color']}, - description='Load data', - button_style='success') + accept=".txt", + multiple=False, + style={"button_color": self.layout["theme_color"]}, + description="Load data", + button_style="success", + ) # Drive selection self.widget_drive_type_selection = RadioButtons( - options=['Evoked', 'Poisson', 'Rhythmic'], - value='Evoked', - description='Drive:', + options=["Evoked", "Poisson", "Rhythmic"], + value="Evoked", + description="Drive:", disabled=False, - layout=self.layout['drive_widget']) + layout=self.layout["drive_widget"], + ) self.widget_location_selection = RadioButtons( - options=['proximal', 'distal'], value='proximal', - description='Location', disabled=False, - layout=self.layout['drive_widget']) + options=["proximal", "distal"], + value="proximal", + description="Location", + disabled=False, + layout=self.layout["drive_widget"], + ) self.add_drive_button = create_expanded_button( - 'Add drive', 'primary', layout=self.layout['btn'], - button_color=self.layout['theme_color']) + "Add drive", + "primary", + layout=self.layout["btn"], + button_color=self.layout["theme_color"], + ) # Dashboard level buttons self.run_button = create_expanded_button( - 'Run', 'success', layout=self.layout['btn'], - button_color=self.layout['theme_color']) + "Run", + "success", + layout=self.layout["btn"], + button_color=self.layout["theme_color"], + ) self.load_connectivity_button = FileUpload( - accept='.json,.param', multiple=False, - style={'button_color': self.layout['theme_color']}, - description='Load local network connectivity', - layout=self.layout['btn_full_w'], button_style='success') + accept=".json,.param", + multiple=False, + style={"button_color": self.layout["theme_color"]}, + description="Load local network connectivity", + layout=self.layout["btn_full_w"], + button_style="success", + ) self.load_drives_button = FileUpload( - accept='.json,.param', multiple=False, - style={'button_color': self.layout['theme_color']}, - description='Load external drives', layout=self.layout['btn'], - button_style='success') + accept=".json,.param", + multiple=False, + style={"button_color": self.layout["theme_color"]}, + description="Load external drives", + layout=self.layout["btn"], + button_style="success", + ) self.delete_drive_button = create_expanded_button( - 'Delete drives', 'success', layout=self.layout['btn'], - button_color=self.layout['theme_color']) + "Delete drives", + "success", + layout=self.layout["btn"], + button_color=self.layout["theme_color"], + ) # Plotting window @@ -288,7 +354,8 @@ def __init__(self, theme_color="#8A2BE2", def add_logging_window_logger(self): handler = _OutputWidgetHandler(self._log_out) handler.setFormatter( - logging.Formatter('%(asctime)s - [%(levelname)s] %(message)s')) + logging.Formatter("%(asctime)s - [%(levelname)s] %(message)s") + ) logger.addHandler(handler) def _init_ui_components(self): @@ -311,23 +378,27 @@ def _init_ui_components(self): # static parts # Running status self._simulation_status_bar = HTML( - value=self._simulation_status_contents['not_running']) + value=self._simulation_status_contents["not_running"] + ) - self._log_window = HBox([self._log_out], layout=self.layout['log_out']) + self._log_window = HBox([self._log_out], layout=self.layout["log_out"]) self._operation_buttons = HBox( [self.run_button, self.load_data_button], - layout=self.layout['operation_box']) + layout=self.layout["operation_box"], + ) # title - self._header = HTML(value=f"""
- HUMAN NEOCORTICAL NEUROSOLVER
""") + HUMAN NEOCORTICAL NEUROSOLVER""" + ) @property def analysis_config(self): """Provides everything viz window needs except for the data.""" return { - "viz_style": self.layout['visualization_output'], + "viz_style": self.layout["visualization_output"], # widgets "plot_outputs": self.plot_outputs_dict, "plot_dropdowns": self.plot_dropdown_types_dict, @@ -346,23 +417,30 @@ def load_parameters(params_fname=None): if not params_fname: # by default load default.json hnn_core_root = Path(hnn_core.__file__).parent - params_fname = hnn_core_root / 'param/default.json' + params_fname = hnn_core_root / "param/default.json" return read_params(params_fname) def _link_callbacks(self): """Link callbacks to UI components.""" + def _handle_backend_change(backend_type): - return handle_backend_change(backend_type.new, - self._backend_config_out, - self.widget_mpi_cmd, - self.widget_n_jobs) + return handle_backend_change( + backend_type.new, + self._backend_config_out, + self.widget_mpi_cmd, + self.widget_n_jobs, + ) def _add_drive_button_clicked(b): - return add_drive_widget(self.widget_drive_type_selection.value, - self.drive_boxes, self.drive_widgets, - self._drives_out, self.widget_tstop, - self.widget_location_selection.value, - layout=self.layout['drive_textbox']) + return add_drive_widget( + self.widget_drive_type_selection.value, + self.drive_boxes, + self.drive_widgets, + self._drives_out, + self.widget_tstop, + self.widget_location_selection.value, + layout=self.layout["drive_textbox"], + ) def _delete_drives_clicked(b): self._drives_out.clear_output() @@ -374,42 +452,68 @@ def _delete_drives_clicked(b): def _on_upload_connectivity(change): return on_upload_params_change( - change, self.params, self.widget_tstop, self.widget_dt, - self._log_out, self.drive_boxes, self.drive_widgets, - self._drives_out, self._connectivity_out, - self.connectivity_widgets, self.layout['drive_textbox'], - "connectivity") + change, + self.params, + self.widget_tstop, + self.widget_dt, + self._log_out, + self.drive_boxes, + self.drive_widgets, + self._drives_out, + self._connectivity_out, + self.connectivity_widgets, + self.layout["drive_textbox"], + "connectivity", + ) def _on_upload_drives(change): return on_upload_params_change( - change, self.params, self.widget_tstop, self.widget_dt, - self._log_out, self.drive_boxes, self.drive_widgets, - self._drives_out, self._connectivity_out, - self.connectivity_widgets, self.layout['drive_textbox'], - "drives") + change, + self.params, + self.widget_tstop, + self.widget_dt, + self._log_out, + self.drive_boxes, + self.drive_widgets, + self._drives_out, + self._connectivity_out, + self.connectivity_widgets, + self.layout["drive_textbox"], + "drives", + ) def _on_upload_data(change): - return on_upload_data_change(change, self.data, self.viz_manager, - self._log_out) + return on_upload_data_change( + change, self.data, self.viz_manager, self._log_out + ) def _run_button_clicked(b): return run_button_clicked( - self.widget_simulation_name, self._log_out, self.drive_widgets, - self.data, self.widget_dt, self.widget_tstop, - self.widget_ntrials, self.widget_backend_selection, - self.widget_mpi_cmd, self.widget_n_jobs, self.params, - self._simulation_status_bar, self._simulation_status_contents, - self.connectivity_widgets, self.viz_manager) - - self.widget_backend_selection.observe(_handle_backend_change, 'value') + self.widget_simulation_name, + self._log_out, + self.drive_widgets, + self.data, + self.widget_dt, + self.widget_tstop, + self.widget_ntrials, + self.widget_backend_selection, + self.widget_mpi_cmd, + self.widget_n_jobs, + self.params, + self._simulation_status_bar, + self._simulation_status_contents, + self.connectivity_widgets, + self.viz_manager, + ) + + self.widget_backend_selection.observe(_handle_backend_change, "value") self.add_drive_button.on_click(_add_drive_button_clicked) self.delete_drive_button.on_click(_delete_drives_clicked) - self.load_connectivity_button.observe(_on_upload_connectivity, - names='value') - self.load_drives_button.observe(_on_upload_drives, names='value') + self.load_connectivity_button.observe(_on_upload_connectivity, names="value") + self.load_drives_button.observe(_on_upload_drives, names="value") self.run_button.on_click(_run_button_clicked) - self.load_data_button.observe(_on_upload_data, names='value') + self.load_data_button.observe(_on_upload_data, names="value") def compose(self, return_layout=True): """Compose widgets. @@ -420,70 +524,108 @@ def compose(self, return_layout=True): If the method returns the layout object which can be rendered by IPython.display.display() method. """ - simulation_box = VBox([ - VBox([ - self.widget_simulation_name, self.widget_tstop, self.widget_dt, - self.widget_ntrials, self.widget_backend_selection, - self._backend_config_out]), - ], layout=self.layout['config_box']) - - connectivity_box = VBox([ - HBox([self.load_connectivity_button, ]), - self._connectivity_out, - ]) + simulation_box = VBox( + [ + VBox( + [ + self.widget_simulation_name, + self.widget_tstop, + self.widget_dt, + self.widget_ntrials, + self.widget_backend_selection, + self._backend_config_out, + ] + ), + ], + layout=self.layout["config_box"], + ) + + connectivity_box = VBox( + [ + HBox( + [ + self.load_connectivity_button, + ] + ), + self._connectivity_out, + ] + ) # accordions to group local-connectivity by cell type - connectivity_boxes = [ - VBox(slider) for slider in self.connectivity_widgets] + connectivity_boxes = [VBox(slider) for slider in self.connectivity_widgets] connectivity_names = ( - 'Layer 2/3 Pyramidal', 'Layer 5 Pyramidal', 'Layer 2 Basket', - 'Layer 5 Basket') + "Layer 2/3 Pyramidal", + "Layer 5 Pyramidal", + "Layer 2 Basket", + "Layer 5 Basket", + ) cell_connectivity = Accordion(children=connectivity_boxes) cell_connectivity.titles = [s for s in connectivity_names] - drive_selections = VBox([ - self.add_drive_button, self.widget_drive_type_selection, - self.widget_location_selection], - layout=Layout(flex="1")) + drive_selections = VBox( + [ + self.add_drive_button, + self.widget_drive_type_selection, + self.widget_location_selection, + ], + layout=Layout(flex="1"), + ) - drives_options = VBox([ - HBox([ - VBox([self.load_drives_button, self.delete_drive_button], - layout=Layout(flex="1")), - drive_selections, - ]), self._drives_out - ]) + drives_options = VBox( + [ + HBox( + [ + VBox( + [self.load_drives_button, self.delete_drive_button], + layout=Layout(flex="1"), + ), + drive_selections, + ] + ), + self._drives_out, + ] + ) config_panel, figs_output = self.viz_manager.compose() # Tabs for left pane left_tab = Tab() left_tab.children = [ - simulation_box, connectivity_box, drives_options, + simulation_box, + connectivity_box, + drives_options, config_panel, ] - titles = ('Simulation', 'Network connectivity', 'External drives', - 'Visualization') + titles = ( + "Simulation", + "Network connectivity", + "External drives", + "Visualization", + ) for idx, title in enumerate(titles): left_tab.set_title(idx, title) self.app_layout = AppLayout( header=self._header, - left_sidebar=VBox([ - VBox([left_tab], layout=self.layout['left_tab']), - self._operation_buttons, - self._log_window, - ], layout=self.layout['left_sidebar']), + left_sidebar=VBox( + [ + VBox([left_tab], layout=self.layout["left_tab"]), + self._operation_buttons, + self._log_window, + ], + layout=self.layout["left_sidebar"], + ), right_sidebar=figs_output, footer=self._simulation_status_bar, pane_widths=[ - self.layout['left_sidebar'].width, '0px', - self.layout['visualization_window'].width + self.layout["left_sidebar"].width, + "0px", + self.layout["visualization_window"].width, ], pane_heights=[ - self.layout['header_height'], - self.layout['visualization_window'].height, - self.layout['simulation_status_height'] + self.layout["header_height"], + self.layout["visualization_window"].height, + self.layout["simulation_status_height"], ], ) @@ -492,11 +634,17 @@ def compose(self, return_layout=True): # self.simulation_data[self.widget_simulation_name.value] # initialize drive and connectivity ipywidgets - load_drive_and_connectivity(self.params, self._log_out, - self._drives_out, self.drive_widgets, - self.drive_boxes, self._connectivity_out, - self.connectivity_widgets, - self.widget_tstop, self.layout) + load_drive_and_connectivity( + self.params, + self._log_out, + self._drives_out, + self.drive_widgets, + self.drive_boxes, + self._connectivity_out, + self.connectivity_widgets, + self.widget_tstop, + self.layout, + ) if not return_layout: return @@ -525,13 +673,13 @@ def capture(self, width=None, height=None, extra_margin=100, render=True): snapshot : An iframe snapshot object that can be rendered in notebooks. """ file = io.StringIO() - embed_minimal_html(file, views=[self.app_layout], title='') + embed_minimal_html(file, views=[self.app_layout], title="") if not width: width = self.total_width + extra_margin if not height: height = self.total_height + extra_margin - content = urllib.parse.quote(file.getvalue().encode('utf8')) + content = urllib.parse.quote(file.getvalue().encode("utf8")) data_url = f"data:text/html,{content}" screenshot = IFrame(data_url, width=width, height=height) if render: @@ -605,15 +753,15 @@ def run_notebook_cells(self): # below are a series of methods that are used to manipulate the GUI def _simulate_upload_data(self, file_url): uploaded_value = _prepare_upload_file_from_url(file_url) - self.load_data_button.set_trait('value', uploaded_value) + self.load_data_button.set_trait("value", uploaded_value) def _simulate_upload_connectivity(self, file_url): uploaded_value = _prepare_upload_file_from_url(file_url) - self.load_connectivity_button.set_trait('value', uploaded_value) + self.load_connectivity_button.set_trait("value", uploaded_value) def _simulate_upload_drives(self, file_url): uploaded_value = _prepare_upload_file_from_url(file_url) - self.load_drives_button.set_trait('value', uploaded_value) + self.load_drives_button.set_trait("value", uploaded_value) def _simulate_left_tab_click(self, tab_title): # Get left tab group object @@ -625,7 +773,9 @@ def _simulate_left_tab_click(self, tab_title): else: raise ValueError("Tab title does not exist.") - def _simulate_make_figure(self,): + def _simulate_make_figure( + self, + ): self._simulate_left_tab_click("Visualization") self.viz_manager.make_fig_button.click() @@ -654,45 +804,63 @@ def _prepare_upload_file_from_url(file_url): for line in data: content += line - return [{ - 'name': params_name, - 'type': 'application/json', - 'size': len(content), - 'content': content, - 'last_modified': datetime.now() - }] + return [ + { + "name": params_name, + "type": "application/json", + "size": len(content), + "content": content, + "last_modified": datetime.now(), + } + ] -def create_expanded_button(description, button_style, layout, disabled=False, - button_color="#8A2BE2"): - return Button(description=description, button_style=button_style, - layout=layout, style={'button_color': button_color}, - disabled=disabled) +def create_expanded_button( + description, button_style, layout, disabled=False, button_color="#8A2BE2" +): + return Button( + description=description, + button_style=button_style, + layout=layout, + style={"button_color": button_color}, + disabled=disabled, + ) def _get_connectivity_widgets(conn_data): """Create connectivity box widgets from specified weight and probability""" - style = {'description_width': '150px'} + style = {"description_width": "150px"} style = {} sliders = list() for receptor_name in conn_data.keys(): w_text_input = BoundedFloatText( - value=conn_data[receptor_name]['weight'], disabled=False, - continuous_update=False, min=0, max=1e6, step=0.01, - description="weight", style=style) + value=conn_data[receptor_name]["weight"], + disabled=False, + continuous_update=False, + min=0, + max=1e6, + step=0.01, + description="weight", + style=style, + ) - conn_widget = VBox([ - HTML(value=f"""

- Receptor: {conn_data[receptor_name]['receptor']}

"""), - w_text_input, HTML(value="
") - ]) + conn_widget = VBox( + [ + HTML( + value=f"""

+ Receptor: {conn_data[receptor_name]['receptor']}

""" + ), + w_text_input, + HTML(value="
"), + ] + ) conn_widget._belongsto = { - "receptor": conn_data[receptor_name]['receptor'], - "location": conn_data[receptor_name]['location'], - "src_gids": conn_data[receptor_name]['src_gids'], - "target_gids": conn_data[receptor_name]['target_gids'], + "receptor": conn_data[receptor_name]["receptor"], + "location": conn_data[receptor_name]["location"], + "src_gids": conn_data[receptor_name]["src_gids"], + "target_gids": conn_data[receptor_name]["target_gids"], } sliders.append(conn_widget) @@ -701,23 +869,23 @@ def _get_connectivity_widgets(conn_data): def _get_cell_specific_widgets(layout, style, location, data=None): default_data = { - 'weights_ampa': { - 'L5_pyramidal': 0., - 'L2_pyramidal': 0., - 'L5_basket': 0., - 'L2_basket': 0. + "weights_ampa": { + "L5_pyramidal": 0.0, + "L2_pyramidal": 0.0, + "L5_basket": 0.0, + "L2_basket": 0.0, }, - 'weights_nmda': { - 'L5_pyramidal': 0., - 'L2_pyramidal': 0., - 'L5_basket': 0., - 'L2_basket': 0. + "weights_nmda": { + "L5_pyramidal": 0.0, + "L2_pyramidal": 0.0, + "L5_basket": 0.0, + "L2_basket": 0.0, }, - 'delays': { - 'L5_pyramidal': 0.1, - 'L2_pyramidal': 0.1, - 'L5_basket': 0.1, - 'L2_basket': 0.1 + "delays": { + "L5_pyramidal": 0.1, + "L2_pyramidal": 0.1, + "L5_basket": 0.1, + "L2_basket": 0.1, }, } if isinstance(data, dict): @@ -726,159 +894,224 @@ def _get_cell_specific_widgets(layout, style, location, data=None): default_data[k].update(data[k]) kwargs = dict(layout=layout, style=style) - cell_types = ['L5_pyramidal', 'L2_pyramidal', 'L5_basket', 'L2_basket'] + cell_types = ["L5_pyramidal", "L2_pyramidal", "L5_basket", "L2_basket"] if location == "distal": - cell_types.remove('L5_basket') + cell_types.remove("L5_basket") weights_ampa, weights_nmda, delays = dict(), dict(), dict() for cell_type in cell_types: - weights_ampa[f'{cell_type}'] = BoundedFloatText( - value=default_data['weights_ampa'][cell_type], - description=f'{cell_type}:', min=0, max=1e6, step=0.01, **kwargs) - weights_nmda[f'{cell_type}'] = BoundedFloatText( - value=default_data['weights_nmda'][cell_type], - description=f'{cell_type}:', min=0, max=1e6, step=0.01, **kwargs) - delays[f'{cell_type}'] = BoundedFloatText( - value=default_data['delays'][cell_type], - description=f'{cell_type}:', min=0, max=1e6, step=0.1, **kwargs) + weights_ampa[f"{cell_type}"] = BoundedFloatText( + value=default_data["weights_ampa"][cell_type], + description=f"{cell_type}:", + min=0, + max=1e6, + step=0.01, + **kwargs, + ) + weights_nmda[f"{cell_type}"] = BoundedFloatText( + value=default_data["weights_nmda"][cell_type], + description=f"{cell_type}:", + min=0, + max=1e6, + step=0.01, + **kwargs, + ) + delays[f"{cell_type}"] = BoundedFloatText( + value=default_data["delays"][cell_type], + description=f"{cell_type}:", + min=0, + max=1e6, + step=0.1, + **kwargs, + ) widgets_dict = { - 'weights_ampa': weights_ampa, - 'weights_nmda': weights_nmda, - 'delays': delays + "weights_ampa": weights_ampa, + "weights_nmda": weights_nmda, + "delays": delays, } - widgets_list = ([HTML(value="AMPA weights")] + - list(weights_ampa.values()) + - [HTML(value="NMDA weights")] + - list(weights_nmda.values()) + - [HTML(value="Synaptic delays")] + - list(delays.values())) + widgets_list = ( + [HTML(value="AMPA weights")] + + list(weights_ampa.values()) + + [HTML(value="NMDA weights")] + + list(weights_nmda.values()) + + [HTML(value="Synaptic delays")] + + list(delays.values()) + ) return widgets_list, widgets_dict -def _get_rhythmic_widget(name, tstop_widget, layout, style, location, - data=None, default_weights_ampa=None, - default_weights_nmda=None, default_delays=None): +def _get_rhythmic_widget( + name, + tstop_widget, + layout, + style, + location, + data=None, + default_weights_ampa=None, + default_weights_nmda=None, + default_delays=None, +): default_data = { - 'tstart': 0., - 'tstart_std': 0., - 'tstop': 0., - 'burst_rate': 7.5, - 'burst_std': 0, - 'repeats': 1, - 'seedcore': 14, + "tstart": 0.0, + "tstart_std": 0.0, + "tstop": 0.0, + "burst_rate": 7.5, + "burst_std": 0, + "repeats": 1, + "seedcore": 14, } if isinstance(data, dict): default_data.update(data) kwargs = dict(layout=layout, style=style) tstart = BoundedFloatText( - value=default_data['tstart'], description='Start time (ms)', - min=0, max=1e6, **kwargs) + value=default_data["tstart"], + description="Start time (ms)", + min=0, + max=1e6, + **kwargs, + ) tstart_std = BoundedFloatText( - value=default_data['tstart_std'], description='Start time dev (ms)', - min=0, max=1e6, **kwargs) + value=default_data["tstart_std"], + description="Start time dev (ms)", + min=0, + max=1e6, + **kwargs, + ) tstop = BoundedFloatText( - value=default_data['tstop'], - description='Stop time (ms)', + value=default_data["tstop"], + description="Stop time (ms)", max=tstop_widget.value, **kwargs, ) burst_rate = BoundedFloatText( - value=default_data['burst_rate'], description='Burst rate (Hz)', - min=0, max=1e6, **kwargs) + value=default_data["burst_rate"], + description="Burst rate (Hz)", + min=0, + max=1e6, + **kwargs, + ) burst_std = BoundedFloatText( - value=default_data['burst_std'], description='Burst std dev (Hz)', - min=0, max=1e6, **kwargs) + value=default_data["burst_std"], + description="Burst std dev (Hz)", + min=0, + max=1e6, + **kwargs, + ) repeats = BoundedIntText( - value=default_data['repeats'], description='Repeats', min=0, - max=int(1e6), **kwargs) - seedcore = IntText(value=default_data['seedcore'], - description='Seed', - **kwargs) + value=default_data["repeats"], + description="Repeats", + min=0, + max=int(1e6), + **kwargs, + ) + seedcore = IntText(value=default_data["seedcore"], description="Seed", **kwargs) widgets_list, widgets_dict = _get_cell_specific_widgets( layout, style, location, data={ - 'weights_ampa': default_weights_ampa, - 'weights_nmda': default_weights_nmda, - 'delays': default_delays, + "weights_ampa": default_weights_ampa, + "weights_nmda": default_weights_nmda, + "delays": default_delays, }, ) drive_box = VBox( - [tstart, tstart_std, tstop, burst_rate, burst_std, repeats, seedcore] + - widgets_list) - drive = dict(type='Rhythmic', - name=name, - tstart=tstart, - tstart_std=tstart_std, - burst_rate=burst_rate, - burst_std=burst_std, - repeats=repeats, - seedcore=seedcore, - location=location, - tstop=tstop) + [tstart, tstart_std, tstop, burst_rate, burst_std, repeats, seedcore] + + widgets_list + ) + drive = dict( + type="Rhythmic", + name=name, + tstart=tstart, + tstart_std=tstart_std, + burst_rate=burst_rate, + burst_std=burst_std, + repeats=repeats, + seedcore=seedcore, + location=location, + tstop=tstop, + ) drive.update(widgets_dict) return drive, drive_box -def _get_poisson_widget(name, tstop_widget, layout, style, location, data=None, - default_weights_ampa=None, default_weights_nmda=None, - default_delays=None): +def _get_poisson_widget( + name, + tstop_widget, + layout, + style, + location, + data=None, + default_weights_ampa=None, + default_weights_nmda=None, + default_delays=None, +): default_data = { - 'tstart': 0.0, - 'tstop': 0.0, - 'seedcore': 14, - 'rate_constant': { - 'L5_pyramidal': 8.5, - 'L2_pyramidal': 8.5, - 'L5_basket': 8.5, - 'L2_basket': 8.5, - } + "tstart": 0.0, + "tstop": 0.0, + "seedcore": 14, + "rate_constant": { + "L5_pyramidal": 8.5, + "L2_pyramidal": 8.5, + "L5_basket": 8.5, + "L2_basket": 8.5, + }, } if isinstance(data, dict): default_data.update(data) tstart = BoundedFloatText( - value=default_data['tstart'], description='Start time (ms)', - min=0, max=1e6, layout=layout, style=style) + value=default_data["tstart"], + description="Start time (ms)", + min=0, + max=1e6, + layout=layout, + style=style, + ) tstop = BoundedFloatText( - value=default_data['tstop'], + value=default_data["tstop"], max=tstop_widget.value, - description='Stop time (ms)', + description="Stop time (ms)", layout=layout, style=style, ) - seedcore = IntText(value=default_data['seedcore'], - description='Seed', - layout=layout, - style=style) + seedcore = IntText( + value=default_data["seedcore"], description="Seed", layout=layout, style=style + ) - cell_types = ['L5_pyramidal', 'L2_pyramidal', 'L5_basket', 'L2_basket'] + cell_types = ["L5_pyramidal", "L2_pyramidal", "L5_basket", "L2_basket"] rate_constant = dict() for cell_type in cell_types: - rate_constant[f'{cell_type}'] = BoundedFloatText( - value=default_data['rate_constant'][cell_type], - description=f'{cell_type}:', min=0, max=1e6, step=0.01, - layout=layout, style=style) + rate_constant[f"{cell_type}"] = BoundedFloatText( + value=default_data["rate_constant"][cell_type], + description=f"{cell_type}:", + min=0, + max=1e6, + step=0.01, + layout=layout, + style=style, + ) widgets_list, widgets_dict = _get_cell_specific_widgets( layout, style, location, data={ - 'weights_ampa': default_weights_ampa, - 'weights_nmda': default_weights_nmda, - 'delays': default_delays, + "weights_ampa": default_weights_ampa, + "weights_nmda": default_weights_nmda, + "delays": default_delays, }, ) - widgets_dict.update({'rate_constant': rate_constant}) - widgets_list.extend([HTML(value="Rate constants")] + - list(widgets_dict['rate_constant'].values())) + widgets_dict.update({"rate_constant": rate_constant}) + widgets_list.extend( + [HTML(value="Rate constants")] + + list(widgets_dict["rate_constant"].values()) + ) drive_box = VBox([tstart, tstop, seedcore] + widgets_list) drive = dict( - type='Poisson', + type="Poisson", name=name, tstart=tstart, tstop=tstop, @@ -890,66 +1123,92 @@ def _get_poisson_widget(name, tstop_widget, layout, style, location, data=None, return drive, drive_box -def _get_evoked_widget(name, layout, style, location, data=None, - default_weights_ampa=None, default_weights_nmda=None, - default_delays=None): +def _get_evoked_widget( + name, + layout, + style, + location, + data=None, + default_weights_ampa=None, + default_weights_nmda=None, + default_delays=None, +): default_data = { - 'mu': 0, - 'sigma': 1, - 'numspikes': 1, - 'seedcore': 14, + "mu": 0, + "sigma": 1, + "numspikes": 1, + "seedcore": 14, } if isinstance(data, dict): default_data.update(data) kwargs = dict(layout=layout, style=style) mu = BoundedFloatText( - value=default_data['mu'], description='Mean time:', min=0, max=1e6, - step=0.01, **kwargs) + value=default_data["mu"], + description="Mean time:", + min=0, + max=1e6, + step=0.01, + **kwargs, + ) sigma = BoundedFloatText( - value=default_data['sigma'], description='Std dev time:', min=0, - max=1e6, step=0.01, **kwargs) - numspikes = IntText(value=default_data['numspikes'], - description='No. Spikes:', - **kwargs) - seedcore = IntText(value=default_data['seedcore'], - description='Seed: ', - **kwargs) + value=default_data["sigma"], + description="Std dev time:", + min=0, + max=1e6, + step=0.01, + **kwargs, + ) + numspikes = IntText( + value=default_data["numspikes"], description="No. Spikes:", **kwargs + ) + seedcore = IntText(value=default_data["seedcore"], description="Seed: ", **kwargs) widgets_list, widgets_dict = _get_cell_specific_widgets( layout, style, location, data={ - 'weights_ampa': default_weights_ampa, - 'weights_nmda': default_weights_nmda, - 'delays': default_delays, + "weights_ampa": default_weights_ampa, + "weights_nmda": default_weights_nmda, + "delays": default_delays, }, ) drive_box = VBox([mu, sigma, numspikes, seedcore] + widgets_list) - drive = dict(type='Evoked', - name=name, - mu=mu, - sigma=sigma, - numspikes=numspikes, - seedcore=seedcore, - location=location, - sync_within_trial=False) + drive = dict( + type="Evoked", + name=name, + mu=mu, + sigma=sigma, + numspikes=numspikes, + seedcore=seedcore, + location=location, + sync_within_trial=False, + ) drive.update(widgets_dict) return drive, drive_box -def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out, - tstop_widget, location, layout, - prespecified_drive_name=None, - prespecified_drive_data=None, - prespecified_weights_ampa=None, - prespecified_weights_nmda=None, - prespecified_delays=None, render=True, - expand_last_drive=True, event_seed=14): +def add_drive_widget( + drive_type, + drive_boxes, + drive_widgets, + drives_out, + tstop_widget, + location, + layout, + prespecified_drive_name=None, + prespecified_drive_data=None, + prespecified_weights_ampa=None, + prespecified_weights_nmda=None, + prespecified_delays=None, + render=True, + expand_last_drive=True, + event_seed=14, +): """Add a widget for a new drive.""" - style = {'description_width': '150px'} + style = {"description_width": "150px"} drives_out.clear_output() if not prespecified_drive_data: prespecified_drive_data = {} @@ -960,7 +1219,7 @@ def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out, name = drive_type + str(len(drive_boxes)) else: name = prespecified_drive_name - if drive_type in ('Rhythmic', 'Bursty'): + if drive_type in ("Rhythmic", "Bursty"): drive, drive_box = _get_rhythmic_widget( name, tstop_widget, @@ -972,7 +1231,7 @@ def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out, default_weights_nmda=prespecified_weights_nmda, default_delays=prespecified_delays, ) - elif drive_type == 'Poisson': + elif drive_type == "Poisson": drive, drive_box = _get_poisson_widget( name, tstop_widget, @@ -984,7 +1243,7 @@ def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out, default_weights_nmda=prespecified_weights_nmda, default_delays=prespecified_delays, ) - elif drive_type in ('Evoked', 'Gaussian'): + elif drive_type in ("Evoked", "Gaussian"): drive, drive_box = _get_evoked_widget( name, layout, @@ -996,33 +1255,28 @@ def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out, default_delays=prespecified_delays, ) - if drive_type in [ - 'Evoked', 'Poisson', 'Rhythmic', 'Bursty', 'Gaussian' - ]: + if drive_type in ["Evoked", "Poisson", "Rhythmic", "Bursty", "Gaussian"]: drive_boxes.append(drive_box) drive_widgets.append(drive) if render: accordion = Accordion( children=drive_boxes, - selected_index=len(drive_boxes) - - 1 if expand_last_drive else None, + selected_index=len(drive_boxes) - 1 if expand_last_drive else None, ) for idx, drive in enumerate(drive_widgets): - accordion.set_title(idx, - f"{drive['name']} ({drive['location']})") + accordion.set_title(idx, f"{drive['name']} ({drive['location']})") display(accordion) -def add_connectivity_tab(params, connectivity_out, - connectivity_textfields): +def add_connectivity_tab(params, connectivity_out, connectivity_textfields): """Add all possible connectivity boxes to connectivity tab.""" net = jones_2009_model(params) cell_types = [ct for ct in net.cell_types.keys()] - receptors = ('ampa', 'nmda', 'gabaa', 'gabab') - locations = ('proximal', 'distal', 'soma') + receptors = ("ampa", "nmda", "gabaa", "gabab") + locations = ("proximal", "distal", "soma") # clear existing connectivity connectivity_out.clear_output() @@ -1036,18 +1290,18 @@ def add_connectivity_tab(params, connectivity_out, # the connectivity list should be built on this level receptor_related_conn = {} for receptor in receptors: - conn_indices = pick_connection(net=net, - src_gids=src_gids, - target_gids=target_gids, - loc=location, - receptor=receptor) + conn_indices = pick_connection( + net=net, + src_gids=src_gids, + target_gids=target_gids, + loc=location, + receptor=receptor, + ) if len(conn_indices) > 0: assert len(conn_indices) == 1 conn_idx = conn_indices[0] - current_w = net.connectivity[ - conn_idx]['nc_dict']['A_weight'] - current_p = net.connectivity[ - conn_idx]['probability'] + current_w = net.connectivity[conn_idx]["nc_dict"]["A_weight"] + current_p = net.connectivity[conn_idx]["probability"] # valid connection receptor_related_conn[receptor] = { "weight": current_w, @@ -1059,10 +1313,10 @@ def add_connectivity_tab(params, connectivity_out, "target_gids": target_gids, } if len(receptor_related_conn) > 0: - connectivity_names.append( - f"{src_gids}→{target_gids} ({location})") + connectivity_names.append(f"{src_gids}→{target_gids} ({location})") connectivity_textfields.append( - _get_connectivity_widgets(receptor_related_conn)) + _get_connectivity_widgets(receptor_related_conn) + ) connectivity_boxes = [VBox(slider) for slider in connectivity_textfields] cell_connectivity = Accordion(children=connectivity_boxes) @@ -1075,11 +1329,11 @@ def add_connectivity_tab(params, connectivity_out, return net -def add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop, - layout): +def add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop, layout): net = jones_2009_model(params) drive_specs = _extract_drive_specs_from_hnn_params( - params, list(net.cell_types.keys()), legacy_mode=net._legacy_mode) + params, list(net.cell_types.keys()), legacy_mode=net._legacy_mode + ) # clear before adding drives drives_out.clear_output() @@ -1093,111 +1347,136 @@ def add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop, should_render = idx == (len(drive_names) - 1) add_drive_widget( - specs['type'].capitalize(), + specs["type"].capitalize(), drive_boxes, drive_widgets, drives_out, tstop, - specs['location'], + specs["location"], layout=layout, prespecified_drive_name=drive_name, - prespecified_drive_data=specs['dynamics'], - prespecified_weights_ampa=specs['weights_ampa'], - prespecified_weights_nmda=specs['weights_nmda'], - prespecified_delays=specs['synaptic_delays'], + prespecified_drive_data=specs["dynamics"], + prespecified_weights_ampa=specs["weights_ampa"], + prespecified_weights_nmda=specs["weights_nmda"], + prespecified_delays=specs["synaptic_delays"], render=should_render, expand_last_drive=False, - event_seed=specs['event_seed'], + event_seed=specs["event_seed"], ) -def load_drive_and_connectivity(params, log_out, drives_out, - drive_widgets, drive_boxes, connectivity_out, - connectivity_textfields, tstop, layout): +def load_drive_and_connectivity( + params, + log_out, + drives_out, + drive_widgets, + drive_boxes, + connectivity_out, + connectivity_textfields, + tstop, + layout, +): """Add drive and connectivity ipywidgets from params.""" log_out.clear_output() with log_out: # Add connectivity add_connectivity_tab(params, connectivity_out, connectivity_textfields) # Add drives - add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop, - layout) + add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop, layout) def on_upload_data_change(change, data, viz_manager, log_out): - if len(change['owner'].value) == 0: + if len(change["owner"].value) == 0: logger.info("Empty change") return - data_dict = change['new'][0] + data_dict = change["new"][0] - data_fname = data_dict['name'].rstrip('.txt') - if data_fname in data['simulation_data'].keys(): + data_fname = data_dict["name"].rstrip(".txt") + if data_fname in data["simulation_data"].keys(): logger.error(f"Found existing data: {data_fname}.") return - ext_content = data_dict['content'] + ext_content = data_dict["content"] ext_content = codecs.decode(ext_content, encoding="utf-8") with log_out: - data['simulation_data'][data_fname] = {'net': None, 'dpls': [ - hnn_core.read_dipole(io.StringIO(ext_content)) - ]} - logger.info(f'External data {data_fname} loaded.') - viz_manager.reset_fig_config_tabs(template_name='single figure') + data["simulation_data"][data_fname] = { + "net": None, + "dpls": [hnn_core.read_dipole(io.StringIO(ext_content))], + } + logger.info(f"External data {data_fname} loaded.") + viz_manager.reset_fig_config_tabs(template_name="single figure") viz_manager.add_figure() - fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1) + fig_name = _idx2figname(viz_manager.data["fig_idx"]["idx"] - 1) ax_plots = [("ax0", "current dipole")] for ax_name, plot_type in ax_plots: viz_manager._simulate_edit_figure( - fig_name, ax_name, data_fname, plot_type, {}, "plot") + fig_name, ax_name, data_fname, plot_type, {}, "plot" + ) -def on_upload_params_change(change, params, tstop, dt, log_out, drive_boxes, - drive_widgets, drives_out, connectivity_out, - connectivity_textfields, layout, load_type): - if len(change['owner'].value) == 0: +def on_upload_params_change( + change, + params, + tstop, + dt, + log_out, + drive_boxes, + drive_widgets, + drives_out, + connectivity_out, + connectivity_textfields, + layout, + load_type, +): + if len(change["owner"].value) == 0: logger.info("Empty change") return logger.info("Loading connectivity...") - param_dict = change['new'][0] - params_fname = param_dict['name'] - param_data = param_dict['content'] + param_dict = change["new"][0] + params_fname = param_dict["name"] + param_data = param_dict["content"] param_data = codecs.decode(param_data, encoding="utf-8") ext = Path(params_fname).suffix - read_func = {'.json': _read_json, '.param': _read_legacy_params} + read_func = {".json": _read_json, ".param": _read_legacy_params} params_network = read_func[ext](param_data) # update simulation settings and params log_out.clear_output() with log_out: - if 'tstop' in params_network.keys(): - tstop.value = params_network['tstop'] - if 'dt' in params_network.keys(): - dt.value = params_network['dt'] + if "tstop" in params_network.keys(): + tstop.value = params_network["tstop"] + if "dt" in params_network.keys(): + dt.value = params_network["dt"] params.update(params_network) # init network, add drives & connectivity - if load_type == 'connectivity': + if load_type == "connectivity": add_connectivity_tab(params, connectivity_out, connectivity_textfields) - elif load_type == 'drives': - add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop, - layout) + elif load_type == "drives": + add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop, layout) else: raise ValueError # Resets file counter to 0 - change['owner'].set_trait('value', ([])) - - -def _init_network_from_widgets(params, dt, tstop, single_simulation_data, - drive_widgets, connectivity_textfields, - add_drive=True): + change["owner"].set_trait("value", ([])) + + +def _init_network_from_widgets( + params, + dt, + tstop, + single_simulation_data, + drive_widgets, + connectivity_textfields, + add_drive=True, +): """Construct network and add drives.""" print("init network") - params['dt'] = dt.value - params['tstop'] = tstop.value - single_simulation_data['net'] = jones_2009_model( + params["dt"] = dt.value + params["tstop"] = tstop.value + single_simulation_data["net"] = jones_2009_model( params, add_drives_from_params=False, ) @@ -1205,136 +1484,146 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data, for connectivity_slider in connectivity_textfields: for vbox in connectivity_slider: conn_indices = pick_connection( - net=single_simulation_data['net'], - src_gids=vbox._belongsto['src_gids'], - target_gids=vbox._belongsto['target_gids'], - loc=vbox._belongsto['location'], - receptor=vbox._belongsto['receptor']) + net=single_simulation_data["net"], + src_gids=vbox._belongsto["src_gids"], + target_gids=vbox._belongsto["target_gids"], + loc=vbox._belongsto["location"], + receptor=vbox._belongsto["receptor"], + ) if len(conn_indices) > 0: assert len(conn_indices) == 1 conn_idx = conn_indices[0] - single_simulation_data['net'].connectivity[conn_idx][ - 'nc_dict']['A_weight'] = vbox.children[1].value - single_simulation_data['net'].connectivity[conn_idx][ - 'probability'] = vbox.children[2].value + single_simulation_data["net"].connectivity[conn_idx]["nc_dict"][ + "A_weight" + ] = vbox.children[1].value + single_simulation_data["net"].connectivity[conn_idx]["probability"] = ( + vbox.children[2].value + ) if add_drive is False: return # add drives to network for drive in drive_widgets: - weights_ampa = { - k: v.value - for k, v in drive['weights_ampa'].items() - } - weights_nmda = { - k: v.value - for k, v in drive['weights_nmda'].items() - } - synaptic_delays = {k: v.value for k, v in drive['delays'].items()} - print( - f"drive type is {drive['type']}, location={drive['location']}") - if drive['type'] == 'Poisson': + weights_ampa = {k: v.value for k, v in drive["weights_ampa"].items()} + weights_nmda = {k: v.value for k, v in drive["weights_nmda"].items()} + synaptic_delays = {k: v.value for k, v in drive["delays"].items()} + print(f"drive type is {drive['type']}, location={drive['location']}") + if drive["type"] == "Poisson": rate_constant = { - k: v.value - for k, v in drive['rate_constant'].items() if v.value > 0 - } - weights_ampa = { - k: v - for k, v in weights_ampa.items() if k in rate_constant - } - weights_nmda = { - k: v - for k, v in weights_nmda.items() if k in rate_constant + k: v.value for k, v in drive["rate_constant"].items() if v.value > 0 } - single_simulation_data['net'].add_poisson_drive( - name=drive['name'], - tstart=drive['tstart'].value, - tstop=drive['tstop'].value, + weights_ampa = {k: v for k, v in weights_ampa.items() if k in rate_constant} + weights_nmda = {k: v for k, v in weights_nmda.items() if k in rate_constant} + single_simulation_data["net"].add_poisson_drive( + name=drive["name"], + tstart=drive["tstart"].value, + tstop=drive["tstop"].value, rate_constant=rate_constant, - location=drive['location'], + location=drive["location"], weights_ampa=weights_ampa, weights_nmda=weights_nmda, synaptic_delays=synaptic_delays, space_constant=100.0, - event_seed=drive['seedcore'].value) - elif drive['type'] in ('Evoked', 'Gaussian'): - single_simulation_data['net'].add_evoked_drive( - name=drive['name'], - mu=drive['mu'].value, - sigma=drive['sigma'].value, - numspikes=drive['numspikes'].value, - location=drive['location'], + event_seed=drive["seedcore"].value, + ) + elif drive["type"] in ("Evoked", "Gaussian"): + single_simulation_data["net"].add_evoked_drive( + name=drive["name"], + mu=drive["mu"].value, + sigma=drive["sigma"].value, + numspikes=drive["numspikes"].value, + location=drive["location"], weights_ampa=weights_ampa, weights_nmda=weights_nmda, synaptic_delays=synaptic_delays, space_constant=3.0, - event_seed=drive['seedcore'].value) - elif drive['type'] in ('Rhythmic', 'Bursty'): - single_simulation_data['net'].add_bursty_drive( - name=drive['name'], - tstart=drive['tstart'].value, - tstart_std=drive['tstart_std'].value, - burst_rate=drive['burst_rate'].value, - burst_std=drive['burst_std'].value, - location=drive['location'], - tstop=drive['tstop'].value, + event_seed=drive["seedcore"].value, + ) + elif drive["type"] in ("Rhythmic", "Bursty"): + single_simulation_data["net"].add_bursty_drive( + name=drive["name"], + tstart=drive["tstart"].value, + tstart_std=drive["tstart_std"].value, + burst_rate=drive["burst_rate"].value, + burst_std=drive["burst_std"].value, + location=drive["location"], + tstop=drive["tstop"].value, weights_ampa=weights_ampa, weights_nmda=weights_nmda, synaptic_delays=synaptic_delays, - event_seed=drive['seedcore'].value) + event_seed=drive["seedcore"].value, + ) -def run_button_clicked(widget_simulation_name, log_out, drive_widgets, - all_data, dt, tstop, ntrials, backend_selection, - mpi_cmd, n_jobs, params, simulation_status_bar, - simulation_status_contents, connectivity_textfields, - viz_manager): +def run_button_clicked( + widget_simulation_name, + log_out, + drive_widgets, + all_data, + dt, + tstop, + ntrials, + backend_selection, + mpi_cmd, + n_jobs, + params, + simulation_status_bar, + simulation_status_contents, + connectivity_textfields, + viz_manager, +): """Run the simulation and plot outputs.""" log_out.clear_output() simulation_data = all_data["simulation_data"] with log_out: # clear empty trash simulations for _name in tuple(simulation_data.keys()): - if len(simulation_data[_name]['dpls']) == 0: + if len(simulation_data[_name]["dpls"]) == 0: del simulation_data[_name] _sim_name = widget_simulation_name.value - if simulation_data[_sim_name]['net'] is not None: + if simulation_data[_sim_name]["net"] is not None: print("Simulation with the same name exists!") - simulation_status_bar.value = simulation_status_contents[ - 'failed'] + simulation_status_bar.value = simulation_status_contents["failed"] return - _init_network_from_widgets(params, dt, tstop, - simulation_data[_sim_name], drive_widgets, - connectivity_textfields) + _init_network_from_widgets( + params, + dt, + tstop, + simulation_data[_sim_name], + drive_widgets, + connectivity_textfields, + ) print("start simulation") if backend_selection.value == "MPI": backend = MPIBackend( - n_procs=multiprocessing.cpu_count() - 1, mpi_cmd=mpi_cmd.value) + n_procs=multiprocessing.cpu_count() - 1, mpi_cmd=mpi_cmd.value + ) else: backend = JoblibBackend(n_jobs=n_jobs.value) print(f"Using Joblib with {n_jobs.value} core(s).") with backend: - simulation_status_bar.value = simulation_status_contents['running'] - simulation_data[_sim_name]['dpls'] = simulate_dipole( - simulation_data[_sim_name]['net'], + simulation_status_bar.value = simulation_status_contents["running"] + simulation_data[_sim_name]["dpls"] = simulate_dipole( + simulation_data[_sim_name]["net"], tstop=tstop.value, dt=dt.value, - n_trials=ntrials.value) + n_trials=ntrials.value, + ) - simulation_status_bar.value = simulation_status_contents[ - 'finished'] + simulation_status_bar.value = simulation_status_contents["finished"] viz_manager.reset_fig_config_tabs() viz_manager.add_figure() - fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1) + fig_name = _idx2figname(viz_manager.data["fig_idx"]["idx"] - 1) ax_plots = [("ax0", "input histogram"), ("ax1", "current dipole")] for ax_name, plot_type in ax_plots: - viz_manager._simulate_edit_figure(fig_name, ax_name, _sim_name, - plot_type, {}, "plot") + viz_manager._simulate_edit_figure( + fig_name, ax_name, _sim_name, plot_type, {}, "plot" + ) def handle_backend_change(backend_type, backend_config, mpi_cmd, n_jobs): @@ -1353,5 +1642,6 @@ def launch(): You can pass voila commandline parameters as usual. """ from voila.app import main - notebook_path = Path(__file__).parent / 'hnn_widget.ipynb' + + notebook_path = Path(__file__).parent / "hnn_widget.ipynb" main([str(notebook_path.resolve()), *sys.argv[1:]])