From d7c773b35bef87ff3b1a2bd5423dbf0ccdddd509 Mon Sep 17 00:00:00 2001 From: dylansdaniels-berkeley Date: Wed, 30 Oct 2024 20:08:44 -0400 Subject: [PATCH] new widget to adjust default smoothing value in gui --- hnn_core/gui/_viz_manager.py | 24 +++++++++++++++--------- hnn_core/gui/gui.py | 17 ++++++++++++----- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/hnn_core/gui/_viz_manager.py b/hnn_core/gui/_viz_manager.py index 2ec1e1c06..060b816c0 100644 --- a/hnn_core/gui/_viz_manager.py +++ b/hnn_core/gui/_viz_manager.py @@ -504,7 +504,7 @@ def _clear_axis(b, widgets, data, fig_idx, fig, ax, widgets_plot_type, _dynamic_rerender(fig) -def _get_ax_control(widgets, data, fig_idx, fig, ax): +def _get_ax_control(widgets, data, default_smoothing, fig_idx, fig, ax): analysis_style = {'description_width': '200px'} layout = Layout(width="98%") simulation_names = tuple(data['simulations'].keys()) @@ -565,7 +565,7 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax): style=analysis_style, ) simulation_dipole_smooth = FloatText( - value=0, + value=default_smoothing, description='Dipole Smooth Window (ms):', disabled=False, layout=layout, @@ -716,12 +716,13 @@ def _close_figure(b, widgets, data, fig_idx): display(Label(_fig_placeholder)) -def _add_axes_controls(widgets, data, fig, axd): +def _add_axes_controls(widgets, data, default_smoothing, fig, axd): fig_idx = data['fig_idx']['idx'] controls = Tab() children = [ - _get_ax_control(widgets, data, fig_idx=fig_idx, fig=fig, ax=ax) + _get_ax_control(widgets, data, default_smoothing, fig_idx=fig_idx, + fig=fig, ax=ax) for ax_key, ax in axd.items() ] controls.children = children @@ -741,7 +742,8 @@ def _add_axes_controls(widgets, data, fig, axd): widgets['axes_config_tabs'].set_title(n_tabs, _idx2figname(fig_idx)) -def _add_figure(b, widgets, data, template_type, scale=0.95, dpi=96): +def _add_figure(b, widgets, data, default_smoothing, + template_type, scale=0.95, dpi=96): fig_idx = data['fig_idx']['idx'] viz_output_layout = data['visualization_output'] fig_outputs = Output() @@ -773,7 +775,7 @@ def _add_figure(b, widgets, data, template_type, scale=0.95, dpi=96): else: display(fig.canvas) - _add_axes_controls(widgets, data, fig=fig, axd=axd) + _add_axes_controls(widgets, data, default_smoothing, fig=fig, axd=axd) data['figs'][fig_idx] = fig widgets['figs_tabs'].selected_index = n_tabs @@ -824,9 +826,10 @@ class _VizManager: A dict of external simulation data object """ - def __init__(self, gui_data, viz_layout): + def __init__(self, gui_data, viz_layout, default_smoothing): plt.close("all") self.viz_layout = viz_layout + self.default_smoothing = default_smoothing self.use_ipympl = 'ipympl' in matplotlib.get_backend() self.axes_config_output = Output() @@ -857,7 +860,9 @@ def __init__(self, gui_data, viz_layout): button_style="primary", style={'button_color': self.viz_layout['theme_color']}, layout=self.viz_layout['btn']) - self.make_fig_button.on_click(self.add_figure) + self.make_fig_button.on_click( + lambda b: self.add_figure(self.default_smoothing) + ) self.datasets_dropdown = Dropdown( description='Dataset:', @@ -959,7 +964,7 @@ def _layout_template_change(self, template_type): self.datasets_dropdown.layout.visibility = "hidden" @unlink_relink(attribute='figs_config_tab_link') - def add_figure(self, b=None): + def add_figure(self, default_smoothing, b=None): """Add a figure and corresponding config tabs to the dashboard. """ if len(self.data["simulations"]) == 0: @@ -984,6 +989,7 @@ def add_figure(self, b=None): _add_figure(None, self.widgets, self.data, + default_smoothing, template_type, scale=0.97, dpi=self.viz_layout['dpi']) diff --git a/hnn_core/gui/gui.py b/hnn_core/gui/gui.py index 3b5c66350..dd4bbf565 100644 --- a/hnn_core/gui/gui.py +++ b/hnn_core/gui/gui.py @@ -324,6 +324,9 @@ def __init__(self, theme_color="#802989", self.simulation_data = defaultdict(lambda: dict(net=None, dpls=list())) # Simulation parameters + self.widget_default_smoothing = BoundedFloatText( + value=30.0, description='Smoothing:', + min=0.0, max=100.0, step=1.0, disabled=False) self.widget_tstop = BoundedFloatText( value=170, description='tstop (ms):', min=0, max=1e6, step=1, disabled=False) @@ -476,7 +479,8 @@ def _init_ui_components(self): self._log_out = Output() - self.viz_manager = _VizManager(self.data, self.layout) + self.viz_manager = _VizManager(self.data, self.layout, + self.widget_default_smoothing.value) # detailed configuration of backends self._backend_config_out = Output() @@ -565,6 +569,7 @@ def _run_button_clicked(b): return run_button_clicked( self.widget_simulation_name, self._log_out, self.drive_widgets, self.data, self.widget_dt, self.widget_tstop, + self.widget_default_smoothing, self.widget_ntrials, self.widget_backend_selection, self.widget_mpi_cmd, self.widget_n_jobs, self.params, self._simulation_status_bar, self._simulation_status_contents, @@ -669,8 +674,8 @@ def compose(self, return_layout=True): simulation_box = VBox([ VBox([ self.widget_simulation_name, self.widget_tstop, self.widget_dt, - self.widget_ntrials, self.widget_backend_selection, - self._backend_config_out]), + self.widget_ntrials, self.widget_default_smoothing, + self.widget_backend_selection, self._backend_config_out]), ], layout=self.layout['config_box']) connectivity_configuration = Tab() @@ -1910,7 +1915,8 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data, def run_button_clicked(widget_simulation_name, log_out, drive_widgets, - all_data, dt, tstop, ntrials, backend_selection, + all_data, dt, tstop, widget_default_smoothing, + ntrials, backend_selection, mpi_cmd, n_jobs, params, simulation_status_bar, simulation_status_contents, connectivity_textfields, viz_manager, simulations_list_widget, @@ -1960,7 +1966,8 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets, simulations_list_widget.value = sim_names[0] viz_manager.reset_fig_config_tabs() - viz_manager.add_figure() + default_smoothing = widget_default_smoothing.value + viz_manager.add_figure(default_smoothing=default_smoothing) fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1) ax_plots = [("ax0", "input histogram"), ("ax1", "current dipole")] for ax_name, plot_type in ax_plots: