From 4b1cd08adfd1749db753a915badd568f127ed873 Mon Sep 17 00:00:00 2001
From: George Dang <53052793+gtdang@users.noreply.github.com>
Date: Tue, 19 Mar 2024 09:11:55 -0400
Subject: [PATCH] refactor: black refactor of gui scripts
---
hnn_core/gui/__init__.py | 2 +-
hnn_core/gui/_viz_manager.py | 527 ++++++++------
hnn_core/gui/gui.py | 1328 +++++++++++++++++++++-------------
3 files changed, 1112 insertions(+), 745 deletions(-)
diff --git a/hnn_core/gui/__init__.py b/hnn_core/gui/__init__.py
index 3397d699b..3142afc2e 100644
--- a/hnn_core/gui/__init__.py
+++ b/hnn_core/gui/__init__.py
@@ -1 +1 @@
-from .gui import HNNGUI, launch
\ No newline at end of file
+from .gui import HNNGUI, launch
diff --git a/hnn_core/gui/_viz_manager.py b/hnn_core/gui/_viz_manager.py
index 4f9437454..c4e242197 100644
--- a/hnn_core/gui/_viz_manager.py
+++ b/hnn_core/gui/_viz_manager.py
@@ -10,34 +10,45 @@
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display
-from ipywidgets import (Box, Button, Dropdown, FloatText, HBox, Label, Layout,
- Output, Tab, VBox, link)
-
-from hnn_core.dipole import average_dipoles, _rmse
+from ipywidgets import (
+ Box,
+ Button,
+ Dropdown,
+ FloatText,
+ HBox,
+ Label,
+ Layout,
+ Output,
+ Tab,
+ VBox,
+ link,
+)
+
+from hnn_core.dipole import _rmse, average_dipoles
from hnn_core.gui._logging import logger
from hnn_core.viz import plot_dipole
-_fig_placeholder = 'Run simulation to add figures here.'
+_fig_placeholder = "Run simulation to add figures here."
_plot_types = [
- 'current dipole',
- 'layer2 dipole',
- 'layer5 dipole',
- 'input histogram',
- 'spikes',
- 'PSD',
- 'spectrogram',
- 'network',
+ "current dipole",
+ "layer2 dipole",
+ "layer5 dipole",
+ "input histogram",
+ "spikes",
+ "PSD",
+ "spectrogram",
+ "network",
]
_no_overlay_plot_types = [
- 'network',
- 'spectrogram',
- 'spikes',
- 'input histogram',
+ "network",
+ "spectrogram",
+ "spikes",
+ "input histogram",
]
-_ext_data_disabled_plot_types = ['spikes', 'input histogram', 'network']
+_ext_data_disabled_plot_types = ["spikes", "input histogram", "network"]
_spectrogram_color_maps = [
"viridis",
@@ -49,15 +60,15 @@
fig_templates = {
"2row x 1col (1:3)": {
- "kwargs": "gridspec_kw={\"height_ratios\":[1,3]}",
+ "kwargs": 'gridspec_kw={"height_ratios":[1,3]}',
"mosaic": "00\n11",
},
"2row x 1col (1:1)": {
- "kwargs": "gridspec_kw={\"height_ratios\":[1,1]}",
+ "kwargs": 'gridspec_kw={"height_ratios":[1,1]}',
"mosaic": "00\n11",
},
"1row x 2col (1:1)": {
- "kwargs": "gridspec_kw={\"height_ratios\":[1,1]}",
+ "kwargs": 'gridspec_kw={"height_ratios":[1,1]}',
"mosaic": "01\n01",
},
"single figure": {
@@ -65,15 +76,14 @@
"mosaic": "00\n00",
},
"2row x 2col (1:1)": {
- "kwargs": "gridspec_kw={\"height_ratios\":[1,1]}",
+ "kwargs": 'gridspec_kw={"height_ratios":[1,1]}',
"mosaic": "01\n23",
},
}
-def check_sim_plot_types(
- new_sim_name, plot_type_selection, target_selection, data):
- if data["simulations"][new_sim_name.new]['net'] is None:
+def check_sim_plot_types(new_sim_name, plot_type_selection, target_selection, data):
+ if data["simulations"][new_sim_name.new]["net"] is None:
plot_type_selection.options = [
pt for pt in _plot_types if pt not in _ext_data_disabled_plot_types
]
@@ -82,18 +92,17 @@ def check_sim_plot_types(
# deal with target data
all_possible_targets = list(data["simulations"].keys())
all_possible_targets.remove(new_sim_name.new)
- target_selection.options = all_possible_targets + ['None']
- target_selection.value = 'None'
+ target_selection.options = all_possible_targets + ["None"]
+ target_selection.value = "None"
def target_comparison_change(new_target_name, simulation_selection, data):
- """Triggered when the target data is turned on or changed.
- """
+ """Triggered when the target data is turned on or changed."""
pass
def plot_type_coupled_change(new_plot_type, target_data_selection):
- if new_plot_type != 'current dipole':
+ if new_plot_type != "current dipole":
target_data_selection.disabled = True
else:
target_data_selection.disabled = False
@@ -113,6 +122,7 @@ def unlink_relink(attribute):
widgets
"""
+
def _unlink_relink(f):
@wraps(f)
def wrapper(self, *args, **kwargs):
@@ -127,7 +137,9 @@ def wrapper(self, *args, **kwargs):
link_attribute.link()
return result
+
return wrapper
+
return _unlink_relink
@@ -156,14 +168,15 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config):
A dict that specifies the preprocessing and style of plots.
"""
# Make sure that visualization does not change the original data
- dpls_copied = copy.deepcopy(single_simulation['dpls'])
- net_copied = copy.deepcopy(single_simulation['net'])
+ dpls_copied = copy.deepcopy(single_simulation["dpls"])
+ net_copied = copy.deepcopy(single_simulation["net"])
for dpl in dpls_copied:
- if plot_config['dipole_smooth'] > 0:
- dpl.smooth(plot_config['dipole_smooth']).scale(
- plot_config['dipole_scaling'])
+ if plot_config["dipole_smooth"] > 0:
+ dpl.smooth(plot_config["dipole_smooth"]).scale(
+ plot_config["dipole_scaling"]
+ )
else:
- dpl.scale(plot_config['dipole_scaling'])
+ dpl.scale(plot_config["dipole_scaling"])
if net_copied is None:
assert plot_type not in _ext_data_disabled_plot_types
@@ -172,35 +185,38 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config):
# x and y axis are hidden after plotting some functions.
ax.get_yaxis().set_visible(True)
ax.get_xaxis().set_visible(True)
- if plot_type == 'spikes':
+ if plot_type == "spikes":
if net_copied.cell_response:
net_copied.cell_response.plot_spikes_raster(ax=ax, show=False)
- elif plot_type == 'input histogram':
+ elif plot_type == "input histogram":
if net_copied.cell_response:
net_copied.cell_response.plot_spikes_hist(ax=ax, show=False)
- elif plot_type == 'PSD':
+ elif plot_type == "PSD":
if len(dpls_copied) > 0:
color = ax._get_lines.get_next_color()
- dpls_copied[0].plot_psd(fmin=0, fmax=50, ax=ax, color=color,
- label=sim_name, show=False)
+ dpls_copied[0].plot_psd(
+ fmin=0, fmax=50, ax=ax, color=color, label=sim_name, show=False
+ )
- elif plot_type == 'spectrogram':
+ elif plot_type == "spectrogram":
if len(dpls_copied) > 0:
min_f = 10.0
- max_f = plot_config['max_spectral_frequency']
+ max_f = plot_config["max_spectral_frequency"]
step_f = 1.0
freqs = np.arange(min_f, max_f, step_f)
- n_cycles = freqs / 8.
+ n_cycles = freqs / 8.0
dpls_copied[0].plot_tfr_morlet(
freqs,
n_cycles=n_cycles,
- colormap=plot_config['spectrogram_cm'],
- ax=ax, colorbar_inside=True,
- show=False)
+ colormap=plot_config["spectrogram_cm"],
+ ax=ax,
+ colorbar_inside=True,
+ show=False,
+ )
- elif 'dipole' in plot_type:
+ elif "dipole" in plot_type:
if len(dpls_copied) > 0:
if len(dpls_copied) > 1:
label = f"{sim_name}: average"
@@ -208,46 +224,50 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config):
label = sim_name
color = ax._get_lines.get_next_color()
- if plot_type == 'current dipole':
- plot_dipole(dpls_copied,
- ax=ax,
- label=label,
- color=color,
- average=True,
- show=False)
+ if plot_type == "current dipole":
+ plot_dipole(
+ dpls_copied,
+ ax=ax,
+ label=label,
+ color=color,
+ average=True,
+ show=False,
+ )
else:
layer_namemap = {
"layer2": "L2",
"layer5": "L5",
}
- plot_dipole(dpls_copied,
- ax=ax,
- label=label,
- color=color,
- layer=layer_namemap[plot_type.split(" ")[0]],
- average=True,
- show=False)
+ plot_dipole(
+ dpls_copied,
+ ax=ax,
+ label=label,
+ color=color,
+ layer=layer_namemap[plot_type.split(" ")[0]],
+ average=True,
+ show=False,
+ )
else:
print("No dipole data")
- elif plot_type == 'network':
+ elif plot_type == "network":
if net_copied:
with plt.ioff():
_fig = plt.figure()
- _ax = _fig.add_subplot(111, projection='3d')
+ _ax = _fig.add_subplot(111, projection="3d")
net_copied.plot_cells(ax=_ax, show=False)
io_buf = io.BytesIO()
- _fig.savefig(io_buf, format='raw')
+ _fig.savefig(io_buf, format="raw")
io_buf.seek(0)
- img_arr = np.reshape(np.frombuffer(io_buf.getvalue(),
- dtype=np.uint8),
- newshape=(int(_fig.bbox.bounds[3]),
- int(_fig.bbox.bounds[2]), -1))
+ img_arr = np.reshape(
+ np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
+ newshape=(int(_fig.bbox.bounds[3]), int(_fig.bbox.bounds[2]), -1),
+ )
io_buf.close()
_ = ax.imshow(img_arr)
# set up alignment
- if plot_type not in ['network', 'PSD']:
+ if plot_type not in ["network", "PSD"]:
margin_x = 0
max_x = max([dpl.times[-1] for dpl in dpls_copied])
ax.set_xlim(left=-margin_x, right=max_x + margin_x)
@@ -256,11 +276,11 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config):
def _static_rerender(widgets, fig, fig_idx):
- logger.debug('_static_re_render is called')
- figs_tabs = widgets['figs_tabs']
+ logger.debug("_static_re_render is called")
+ figs_tabs = widgets["figs_tabs"]
titles = figs_tabs.titles
fig_tab_idx = titles.index(_idx2figname(fig_idx))
- fig_output = widgets['figs_tabs'].children[fig_tab_idx]
+ fig_output = widgets["figs_tabs"].children[fig_tab_idx]
fig_output.clear_output()
with fig_output:
fig.tight_layout()
@@ -273,11 +293,22 @@ def _dynamic_rerender(fig):
fig.tight_layout()
-def _plot_on_axes(b, widgets_simulation, widgets_plot_type,
- target_simulations,
- spectrogram_colormap_selection, dipole_smooth,
- max_spectral_frequency, dipole_scaling, widgets, data,
- fig_idx, fig, ax, existing_plots):
+def _plot_on_axes(
+ b,
+ widgets_simulation,
+ widgets_plot_type,
+ target_simulations,
+ spectrogram_colormap_selection,
+ dipole_smooth,
+ max_spectral_frequency,
+ dipole_scaling,
+ widgets,
+ data,
+ fig_idx,
+ fig,
+ ax,
+ existing_plots,
+):
"""Plotting different types of data on the given axes.
Now this function is also responsible for comparing multiple simulations,
@@ -323,34 +354,39 @@ def _plot_on_axes(b, widgets_simulation, widgets_plot_type,
# freeze plot type
widgets_plot_type.disabled = True
- single_simulation = data['simulations'][sim_name]
+ single_simulation = data["simulations"][sim_name]
plot_config = {
"max_spectral_frequency": max_spectral_frequency.value,
"dipole_scaling": dipole_scaling.value,
"dipole_smooth": dipole_smooth.value,
- "spectrogram_cm": spectrogram_colormap_selection.value
+ "spectrogram_cm": spectrogram_colormap_selection.value,
}
- dpls_processed = _update_ax(fig, ax, single_simulation, sim_name,
- plot_type, plot_config)
+ dpls_processed = _update_ax(
+ fig, ax, single_simulation, sim_name, plot_type, plot_config
+ )
# If target_simulations is not None and we are plotting a dipole,
# we need to plot the target dipole as well.
- if target_simulations.value in data['simulations'].keys(
- ) and plot_type == 'current dipole':
+ if (
+ target_simulations.value in data["simulations"].keys()
+ and plot_type == "current dipole"
+ ):
target_sim_name = target_simulations.value
- target_sim = data['simulations'][target_sim_name]
+ target_sim = data["simulations"][target_sim_name]
# plot the target dipole.
# disable scaling for the target dipole.
- plot_config['dipole_scaling'] = 1.
+ plot_config["dipole_scaling"] = 1.0
# plot the target dipole.
target_dpl_processed = _update_ax(
- fig, ax, target_sim, target_sim_name, plot_type,
- plot_config)[0] # we assume there is only one dipole.
+ fig, ax, target_sim, target_sim_name, plot_type, plot_config
+ )[
+ 0
+ ] # we assume there is only one dipole.
# calculate the RMSE between the two dipoles.
t0 = 0.0
@@ -361,45 +397,58 @@ def _plot_on_axes(b, widgets_simulation, widgets_plot_type,
dpl = dpls_processed
rmse = _rmse(dpl, target_dpl_processed, t0, tstop)
# Show the RMSE between the two dipoles.
- ax.annotate(f'RMSE({sim_name}, {target_sim_name}): {rmse:.4f}',
- xy=(0.95, 0.05),
- xycoords='axes fraction',
- horizontalalignment='right',
- verticalalignment='bottom',
- fontsize=12)
-
- existing_plots.children = (*existing_plots.children,
- Label(f"{sim_name}: {plot_type}"))
- if data['use_ipympl'] is False:
+ ax.annotate(
+ f"RMSE({sim_name}, {target_sim_name}): {rmse:.4f}",
+ xy=(0.95, 0.05),
+ xycoords="axes fraction",
+ horizontalalignment="right",
+ verticalalignment="bottom",
+ fontsize=12,
+ )
+
+ existing_plots.children = (
+ *existing_plots.children,
+ Label(f"{sim_name}: {plot_type}"),
+ )
+ if data["use_ipympl"] is False:
_static_rerender(widgets, fig, fig_idx)
else:
_dynamic_rerender(fig)
-def _clear_axis(b, widgets, data, fig_idx, fig, ax, widgets_plot_type,
- existing_plots, add_plot_button):
+def _clear_axis(
+ b,
+ widgets,
+ data,
+ fig_idx,
+ fig,
+ ax,
+ widgets_plot_type,
+ existing_plots,
+ add_plot_button,
+):
ax.clear()
# remove attached colorbar if exists
- if hasattr(fig, f'_cbar-ax-{id(ax)}'):
- getattr(fig, f'_cbar-ax-{id(ax)}').ax.remove()
- delattr(fig, f'_cbar-ax-{id(ax)}')
+ if hasattr(fig, f"_cbar-ax-{id(ax)}"):
+ getattr(fig, f"_cbar-ax-{id(ax)}").ax.remove()
+ delattr(fig, f"_cbar-ax-{id(ax)}")
- ax.set_facecolor('w')
- ax.set_aspect('auto')
+ ax.set_facecolor("w")
+ ax.set_aspect("auto")
widgets_plot_type.disabled = False
add_plot_button.disabled = False
existing_plots.children = ()
- if data['use_ipympl'] is False:
+ if data["use_ipympl"] is False:
_static_rerender(widgets, fig, fig_idx)
else:
_dynamic_rerender(fig)
def _get_ax_control(widgets, data, fig_idx, fig, ax):
- analysis_style = {'description_width': '200px'}
+ analysis_style = {"description_width": "200px"}
layout = Layout(width="98%")
- simulation_names = tuple(data['simulations'].keys())
+ simulation_names = tuple(data["simulations"].keys())
sim_name_default = simulation_names[-1]
if len(simulation_names) == 0:
simulation_names = [
@@ -409,13 +458,13 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax):
simulation_selection = Dropdown(
options=simulation_names,
value=sim_name_default,
- description='Simulation Data:',
+ description="Simulation Data:",
disabled=False,
layout=layout,
style=analysis_style,
)
- if data['simulations'][sim_name_default]['net'] is None:
+ if data["simulations"][sim_name_default]["net"] is None:
valid_plot_types = [
pt for pt in _plot_types if pt not in _ext_data_disabled_plot_types
]
@@ -425,65 +474,70 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax):
plot_type_selection = Dropdown(
options=valid_plot_types,
value=valid_plot_types[0],
- description='Type:',
+ description="Type:",
disabled=False,
layout=layout,
style=analysis_style,
)
target_data_selection = Dropdown(
- options=simulation_names[:-1] + ('None',),
- value='None',
- description='Data to Compare:',
+ options=simulation_names[:-1] + ("None",),
+ value="None",
+ description="Data to Compare:",
disabled=False,
layout=layout,
style=analysis_style,
)
spectrogram_colormap_selection = Dropdown(
- description='Spectrogram Colormap:',
+ description="Spectrogram Colormap:",
options=[(cm, cm) for cm in _spectrogram_color_maps],
value=_spectrogram_color_maps[0],
layout=layout,
style=analysis_style,
)
- dipole_smooth = FloatText(value=30,
- description='Dipole Smooth Window (ms):',
- disabled=False,
- layout=layout,
- style=analysis_style)
- dipole_scaling = FloatText(value=3000,
- description='Dipole Scaling:',
- disabled=False,
- layout=layout,
- style=analysis_style)
+ dipole_smooth = FloatText(
+ value=30,
+ description="Dipole Smooth Window (ms):",
+ disabled=False,
+ layout=layout,
+ style=analysis_style,
+ )
+ dipole_scaling = FloatText(
+ value=3000,
+ description="Dipole Scaling:",
+ disabled=False,
+ layout=layout,
+ style=analysis_style,
+ )
max_spectral_frequency = FloatText(
value=100,
- description='Max Spectral Frequency (Hz):',
+ description="Max Spectral Frequency (Hz):",
disabled=False,
layout=layout,
- style=analysis_style)
+ style=analysis_style,
+ )
existing_plots = VBox([])
- plot_button = Button(description='Add plot')
- clear_button = Button(description='Clear axis')
+ plot_button = Button(description="Add plot")
+ clear_button = Button(description="Clear axis")
def _on_sim_data_change(new_sim_name):
return check_sim_plot_types(
- new_sim_name, plot_type_selection, target_data_selection, data)
+ new_sim_name, plot_type_selection, target_data_selection, data
+ )
def _on_target_comparison_change(new_target_name):
- return target_comparison_change(new_target_name, simulation_selection,
- data)
+ return target_comparison_change(new_target_name, simulation_selection, data)
def _on_plot_type_change(new_plot_type):
return plot_type_coupled_change(new_plot_type, target_data_selection)
- simulation_selection.observe(_on_sim_data_change, 'value')
- target_data_selection.observe(_on_target_comparison_change, 'value')
- plot_type_selection.observe(_on_plot_type_change, 'value')
+ simulation_selection.observe(_on_sim_data_change, "value")
+ target_data_selection.observe(_on_target_comparison_change, "value")
+ plot_type_selection.observe(_on_plot_type_change, "value")
clear_button.on_click(
partial(
@@ -496,7 +550,8 @@ def _on_plot_type_change(new_plot_type):
widgets_plot_type=plot_type_selection,
existing_plots=existing_plots,
add_plot_button=plot_button,
- ))
+ )
+ )
plot_button.on_click(
partial(
@@ -514,22 +569,32 @@ def _on_plot_type_change(new_plot_type):
fig=fig,
ax=ax,
existing_plots=existing_plots,
- ))
+ )
+ )
- vbox = VBox([
- simulation_selection, plot_type_selection, target_data_selection,
- dipole_smooth, dipole_scaling, max_spectral_frequency,
- spectrogram_colormap_selection,
- HBox(
- [plot_button, clear_button],
- layout=Layout(justify_content='space-between'),
- ), existing_plots], layout=Layout(width="98%"))
+ vbox = VBox(
+ [
+ simulation_selection,
+ plot_type_selection,
+ target_data_selection,
+ dipole_smooth,
+ dipole_scaling,
+ max_spectral_frequency,
+ spectrogram_colormap_selection,
+ HBox(
+ [plot_button, clear_button],
+ layout=Layout(justify_content="space-between"),
+ ),
+ existing_plots,
+ ],
+ layout=Layout(width="98%"),
+ )
return vbox
def _close_figure(b, widgets, data, fig_idx):
- fig_related_widgets = [widgets['figs_tabs'], widgets['axes_config_tabs']]
+ fig_related_widgets = [widgets["figs_tabs"], widgets["axes_config_tabs"]]
for w_idx, tab in enumerate(fig_related_widgets):
# Get tab object's list of children and their titles
tab_children = list(tab.children)
@@ -547,27 +612,27 @@ def _close_figure(b, widgets, data, fig_idx):
# If the figure tab group...
if w_idx == 0:
# Close figure and delete the data
- plt.close(data['figs'][fig_idx])
- data['figs'].pop(fig_idx)
+ plt.close(data["figs"][fig_idx])
+ data["figs"].pop(fig_idx)
# Redisplay the remaining children
n_tabs = len(tab.children)
for idx in range(n_tabs):
_fig_idx = _figname2idx(tab.get_title(idx))
- assert _fig_idx in data['figs'].keys()
+ assert _fig_idx in data["figs"].keys()
tab.children[idx].clear_output()
with tab.children[idx]:
- display(data['figs'][_fig_idx].canvas)
+ display(data["figs"][_fig_idx].canvas)
# If all children have been deleted display the placeholder
if n_tabs == 0:
- widgets['figs_output'].clear_output()
- with widgets['figs_output']:
+ widgets["figs_output"].clear_output()
+ with widgets["figs_output"]:
display(Label(_fig_placeholder))
def _add_axes_controls(widgets, data, fig, axd):
- fig_idx = data['fig_idx']['idx']
+ fig_idx = data["fig_idx"]["idx"]
controls = Tab()
children = [
@@ -576,63 +641,66 @@ def _add_axes_controls(widgets, data, fig, axd):
]
controls.children = children
for i in range(len(children)):
- controls.set_title(i, f'ax{i}')
+ controls.set_title(i, f"ax{i}")
- close_fig_button = Button(description=f'Close {_idx2figname(fig_idx)}',
- button_style='danger',
- icon='close',
- layout=Layout(width="98%"))
+ close_fig_button = Button(
+ description=f"Close {_idx2figname(fig_idx)}",
+ button_style="danger",
+ icon="close",
+ layout=Layout(width="98%"),
+ )
close_fig_button.on_click(
- partial(_close_figure, widgets=widgets, data=data, fig_idx=fig_idx))
+ partial(_close_figure, widgets=widgets, data=data, fig_idx=fig_idx)
+ )
- n_tabs = len(widgets['axes_config_tabs'].children)
- widgets['axes_config_tabs'].children = widgets[
- 'axes_config_tabs'].children + (VBox([close_fig_button, controls]), )
- widgets['axes_config_tabs'].set_title(n_tabs, _idx2figname(fig_idx))
+ n_tabs = len(widgets["axes_config_tabs"].children)
+ widgets["axes_config_tabs"].children = widgets["axes_config_tabs"].children + (
+ VBox([close_fig_button, controls]),
+ )
+ widgets["axes_config_tabs"].set_title(n_tabs, _idx2figname(fig_idx))
def _add_figure(b, widgets, data, scale=0.95, dpi=96):
- template_name = widgets['templates_dropdown'].value
- fig_idx = data['fig_idx']['idx']
- viz_output_layout = data['visualization_output']
+ template_name = widgets["templates_dropdown"].value
+ fig_idx = data["fig_idx"]["idx"]
+ viz_output_layout = data["visualization_output"]
fig_outputs = Output()
- n_tabs = len(widgets['figs_tabs'].children)
+ n_tabs = len(widgets["figs_tabs"].children)
if n_tabs == 0:
- widgets['figs_output'].clear_output()
- with widgets['figs_output']:
- display(widgets['figs_tabs'])
+ widgets["figs_output"].clear_output()
+ with widgets["figs_output"]:
+ display(widgets["figs_tabs"])
- widgets['figs_tabs'].children = (
- [s for s in widgets['figs_tabs'].children] + [fig_outputs]
- )
- widgets['figs_tabs'].set_title(n_tabs, _idx2figname(fig_idx))
+ widgets["figs_tabs"].children = [s for s in widgets["figs_tabs"].children] + [
+ fig_outputs
+ ]
+ widgets["figs_tabs"].set_title(n_tabs, _idx2figname(fig_idx))
with fig_outputs:
- figsize = (scale * ((int(viz_output_layout.width[:-2]) - 10) / dpi),
- scale * ((int(viz_output_layout.height[:-2]) - 10) / dpi))
- mosaic = fig_templates[template_name]['mosaic']
+ figsize = (
+ scale * ((int(viz_output_layout.width[:-2]) - 10) / dpi),
+ scale * ((int(viz_output_layout.height[:-2]) - 10) / dpi),
+ )
+ mosaic = fig_templates[template_name]["mosaic"]
kwargs = eval(f"dict({fig_templates[template_name]['kwargs']})")
plt.ioff()
- fig, axd = plt.subplot_mosaic(mosaic,
- figsize=figsize,
- dpi=dpi,
- **kwargs)
+ fig, axd = plt.subplot_mosaic(mosaic, figsize=figsize, dpi=dpi, **kwargs)
plt.ion()
fig.tight_layout()
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
- if data['use_ipympl'] is False:
+ if data["use_ipympl"] is False:
plt.show()
else:
display(fig.canvas)
_add_axes_controls(widgets, data, fig=fig, axd=axd)
- data['figs'][fig_idx] = fig
- widgets['figs_tabs'].selected_index = n_tabs
- data['fig_idx']['idx'] += 1
+ data["figs"][fig_idx] = fig
+ widgets["figs_tabs"].selected_index = n_tabs
+ data["fig_idx"]["idx"] += 1
class _VizManager:
@@ -656,7 +724,7 @@ class _VizManager:
def __init__(self, gui_data, viz_layout):
plt.close("all")
self.viz_layout = viz_layout
- self.use_ipympl = 'ipympl' in matplotlib.get_backend()
+ self.use_ipympl = "ipympl" in matplotlib.get_backend()
self.axes_config_output = Output()
self.figs_output = Output()
@@ -667,22 +735,24 @@ def __init__(self, gui_data, viz_layout):
self.axes_config_tabs.selected_index = None
self.figs_tabs.selected_index = None
self.figs_config_tab_link = link(
- (self.axes_config_tabs, 'selected_index'),
- (self.figs_tabs, 'selected_index'),
+ (self.axes_config_tabs, "selected_index"),
+ (self.figs_tabs, "selected_index"),
)
template_names = list(fig_templates.keys())
self.templates_dropdown = Dropdown(
- description='Layout template:',
+ description="Layout template:",
options=template_names,
value=template_names[0],
- style={'description_width': 'initial'},
- layout=Layout(width="98%"))
+ style={"description_width": "initial"},
+ layout=Layout(width="98%"),
+ )
self.make_fig_button = Button(
- description='Make figure',
+ description="Make figure",
button_style="primary",
- style={'button_color': self.viz_layout['theme_color']},
- layout=self.viz_layout['btn'])
+ style={"button_color": self.viz_layout["theme_color"]},
+ layout=self.viz_layout["btn"],
+ )
self.make_fig_button.on_click(self.add_figure)
# data
@@ -696,7 +766,7 @@ def widgets(self):
"figs_output": self.figs_output,
"axes_config_tabs": self.axes_config_tabs,
"figs_tabs": self.figs_tabs,
- "templates_dropdown": self.templates_dropdown
+ "templates_dropdown": self.templates_dropdown,
}
@property
@@ -706,13 +776,13 @@ def data(self):
"use_ipympl": self.use_ipympl,
"simulations": self.gui_data["simulation_data"],
"fig_idx": self.fig_idx,
- "visualization_output": self.viz_layout['visualization_output'],
- "figs": self.figs
+ "visualization_output": self.viz_layout["visualization_output"],
+ "figs": self.figs,
}
def reset_fig_config_tabs(self, template_name=None):
"""Reset the figure config tabs with most recent simulation data."""
- simulation_names = tuple(self.data['simulations'].keys())
+ simulation_names = tuple(self.data["simulations"].keys())
for tab in self.axes_config_tabs.children:
controls = tab.children[1]
for ax_control in controls.children:
@@ -731,34 +801,34 @@ def compose(self):
display(Label(_fig_placeholder))
fig_output_container = VBox(
- [self.figs_output], layout=self.viz_layout['visualization_window'])
-
- config_panel = VBox([
- Box(
- [
- self.templates_dropdown,
- self.make_fig_button,
- ],
- layout=Layout(
- display='flex',
- flex_flow='column',
- align_items='stretch',
+ [self.figs_output], layout=self.viz_layout["visualization_window"]
+ )
+
+ config_panel = VBox(
+ [
+ Box(
+ [
+ self.templates_dropdown,
+ self.make_fig_button,
+ ],
+ layout=Layout(
+ display="flex",
+ flex_flow="column",
+ align_items="stretch",
+ ),
),
- ),
- Label("Figure config:"),
- self.axes_config_output,
- ])
+ Label("Figure config:"),
+ self.axes_config_output,
+ ]
+ )
return config_panel, fig_output_container
- @unlink_relink(attribute='figs_config_tab_link')
+ @unlink_relink(attribute="figs_config_tab_link")
def add_figure(self, b=None):
- """Add a figure and corresponding config tabs to the dashboard.
- """
- _add_figure(None,
- self.widgets,
- self.data,
- scale=0.97,
- dpi=self.viz_layout['dpi'])
+ """Add a figure and corresponding config tabs to the dashboard."""
+ _add_figure(
+ None, self.widgets, self.data, scale=0.97, dpi=self.viz_layout["dpi"]
+ )
def _simulate_add_fig(self):
self.make_fig_button.click()
@@ -777,8 +847,15 @@ def _simulate_delete_figure(self, fig_name):
close_button = self.axes_config_tabs.children[tab_idx].children[0]
close_button.click()
- def _simulate_edit_figure(self, fig_name, ax_name, simulation_name,
- plot_type, preprocessing_config, operation):
+ def _simulate_edit_figure(
+ self,
+ fig_name,
+ ax_name,
+ simulation_name,
+ plot_type,
+ preprocessing_config,
+ operation,
+ ):
"""Manipulate a certain figure.
Parameters
@@ -799,7 +876,7 @@ def _simulate_edit_figure(self, fig_name, ax_name, simulation_name,
`"plot"` if you want to plot and `"clear"` if you want to
remove previously plotted visualizations.
"""
- assert simulation_name in self.data['simulations'].keys()
+ assert simulation_name in self.data["simulations"].keys()
assert plot_type in _plot_types
assert operation in ("plot", "clear")
diff --git a/hnn_core/gui/gui.py b/hnn_core/gui/gui.py
index 70622d1d8..f0bc3307f 100644
--- a/hnn_core/gui/gui.py
+++ b/hnn_core/gui/gui.py
@@ -10,21 +10,46 @@
import urllib.parse
import urllib.request
from collections import defaultdict
-from pathlib import Path
from datetime import datetime
+from pathlib import Path
+
from IPython.display import IFrame, display
-from ipywidgets import (HTML, Accordion, AppLayout, BoundedFloatText,
- BoundedIntText, Button, Dropdown, FileUpload, VBox,
- HBox, IntText, Layout, Output, RadioButtons, Tab, Text)
+from ipywidgets import (
+ HTML,
+ Accordion,
+ AppLayout,
+ BoundedFloatText,
+ BoundedIntText,
+ Button,
+ Dropdown,
+ FileUpload,
+ HBox,
+ IntText,
+ Layout,
+ Output,
+ RadioButtons,
+ Tab,
+ Text,
+ VBox,
+)
from ipywidgets.embed import embed_minimal_html
+
import hnn_core
-from hnn_core import (JoblibBackend, MPIBackend, jones_2009_model, read_params,
- simulate_dipole)
+from hnn_core import (
+ JoblibBackend,
+ MPIBackend,
+ jones_2009_model,
+ read_params,
+ simulate_dipole,
+)
from hnn_core.gui._logging import logger
-from hnn_core.gui._viz_manager import _VizManager, _idx2figname
+from hnn_core.gui._viz_manager import _idx2figname, _VizManager
from hnn_core.network import pick_connection
-from hnn_core.params import (_extract_drive_specs_from_hnn_params, _read_json,
- _read_legacy_params)
+from hnn_core.params import (
+ _extract_drive_specs_from_hnn_params,
+ _read_json,
+ _read_legacy_params,
+)
class _OutputWidgetHandler(logging.Handler):
@@ -35,11 +60,11 @@ def __init__(self, output_widget, *args, **kwargs):
def emit(self, record):
formatted_record = self.format(record)
new_output = {
- 'name': 'stdout',
- 'output_type': 'stream',
- 'text': formatted_record + '\n'
+ "name": "stdout",
+ "output_type": "stream",
+ "text": formatted_record + "\n",
}
- self.out.outputs = (new_output, ) + self.out.outputs
+ self.out.outputs = (new_output,) + self.out.outputs
class HNNGUI:
@@ -120,18 +145,20 @@ class HNNGUI:
in the network.
"""
- def __init__(self, theme_color="#8A2BE2",
- total_height=800,
- total_width=1300,
- header_height=50,
- button_height=30,
- operation_box_height=60,
- drive_widget_width=200,
- left_sidebar_width=576,
- log_window_height=150,
- status_height=30,
- dpi=96,
- ):
+ def __init__(
+ self,
+ theme_color="#8A2BE2",
+ total_height=800,
+ total_width=1300,
+ header_height=50,
+ button_height=30,
+ operation_box_height=60,
+ drive_widget_width=200,
+ left_sidebar_width=576,
+ log_window_height=150,
+ status_height=30,
+ dpi=96,
+ ):
# set up styling.
self.total_height = total_height
self.total_width = total_width
@@ -139,41 +166,50 @@ def __init__(self, theme_color="#8A2BE2",
viz_win_width = self.total_width - left_sidebar_width
main_content_height = self.total_height - status_height
- config_box_height = main_content_height - (log_window_height +
- operation_box_height)
+ config_box_height = main_content_height - (
+ log_window_height + operation_box_height
+ )
self.layout = {
"dpi": dpi,
"header_height": f"{header_height}px",
"theme_color": theme_color,
- "btn": Layout(height=f"{button_height}px", width='auto'),
- "btn_full_w": Layout(height=f"{button_height}px", width='100%'),
- "del_fig_btn": Layout(height=f"{button_height}px", width='auto'),
- "log_out": Layout(border='1px solid gray',
- height=f"{log_window_height-10}px",
- overflow='auto'),
- "viz_config": Layout(width='99%'),
+ "btn": Layout(height=f"{button_height}px", width="auto"),
+ "btn_full_w": Layout(height=f"{button_height}px", width="100%"),
+ "del_fig_btn": Layout(height=f"{button_height}px", width="auto"),
+ "log_out": Layout(
+ border="1px solid gray",
+ height=f"{log_window_height-10}px",
+ overflow="auto",
+ ),
+ "viz_config": Layout(width="99%"),
"visualization_window": Layout(
width=f"{viz_win_width-10}px",
height=f"{main_content_height-10}px",
- border='1px solid gray',
- overflow='scroll'),
+ border="1px solid gray",
+ overflow="scroll",
+ ),
"visualization_output": Layout(
width=f"{viz_win_width-50}px",
height=f"{main_content_height-100}px",
- border='1px solid gray',
- overflow='scroll'),
- "left_sidebar": Layout(width=f"{left_sidebar_width}px",
- height=f"{main_content_height}px"),
- "left_tab": Layout(width=f"{left_sidebar_width}px",
- height=f"{config_box_height}px"),
- "operation_box": Layout(width=f"{left_sidebar_width}px",
- height=f"{operation_box_height}px",
- flex_wrap="wrap",
- ),
- "config_box": Layout(width=f"{left_sidebar_width}px",
- height=f"{config_box_height-100}px"),
+ border="1px solid gray",
+ overflow="scroll",
+ ),
+ "left_sidebar": Layout(
+ width=f"{left_sidebar_width}px", height=f"{main_content_height}px"
+ ),
+ "left_tab": Layout(
+ width=f"{left_sidebar_width}px", height=f"{config_box_height}px"
+ ),
+ "operation_box": Layout(
+ width=f"{left_sidebar_width}px",
+ height=f"{operation_box_height}px",
+ flex_wrap="wrap",
+ ),
+ "config_box": Layout(
+ width=f"{left_sidebar_width}px", height=f"{config_box_height-100}px"
+ ),
"drive_widget": Layout(width="auto"),
- "drive_textbox": Layout(width='270px', height='auto'),
+ "drive_textbox": Layout(width="270px", height="auto"),
# simulation status related
"simulation_status_height": f"{status_height}px",
"simulation_status_common": "background:gray;padding-left:10px",
@@ -183,17 +219,13 @@ def __init__(self, theme_color="#8A2BE2",
}
self._simulation_status_contents = {
- "not_running":
- f"""
Not running
""",
- "running":
- f"""Running...
""",
- "finished":
- f"""Simulation finished
""",
- "failed":
- f"""Simulation failed
""",
}
@@ -205,68 +237,102 @@ def __init__(self, theme_color="#8A2BE2",
# Simulation parameters
self.widget_tstop = BoundedFloatText(
- value=170, description='tstop (ms):', min=0, max=1e6, step=1,
- disabled=False)
+ value=170, description="tstop (ms):", min=0, max=1e6, step=1, disabled=False
+ )
self.widget_dt = BoundedFloatText(
- value=0.025, description='dt (ms):', min=0, max=10, step=0.01,
- disabled=False)
- self.widget_ntrials = IntText(value=1, description='Trials:',
- disabled=False)
- self.widget_simulation_name = Text(value='default',
- placeholder='ID of your simulation',
- description='Name:',
- disabled=False)
- self.widget_backend_selection = Dropdown(options=[('Joblib', 'Joblib'),
- ('MPI', 'MPI')],
- value='Joblib',
- description='Backend:')
- self.widget_mpi_cmd = Text(value='mpiexec',
- placeholder='Fill if applies',
- description='MPI cmd:', disabled=False)
- self.widget_n_jobs = BoundedIntText(value=1, min=1,
- max=multiprocessing.cpu_count(),
- description='Cores:',
- disabled=False)
+ value=0.025,
+ description="dt (ms):",
+ min=0,
+ max=10,
+ step=0.01,
+ disabled=False,
+ )
+ self.widget_ntrials = IntText(value=1, description="Trials:", disabled=False)
+ self.widget_simulation_name = Text(
+ value="default",
+ placeholder="ID of your simulation",
+ description="Name:",
+ disabled=False,
+ )
+ self.widget_backend_selection = Dropdown(
+ options=[("Joblib", "Joblib"), ("MPI", "MPI")],
+ value="Joblib",
+ description="Backend:",
+ )
+ self.widget_mpi_cmd = Text(
+ value="mpiexec",
+ placeholder="Fill if applies",
+ description="MPI cmd:",
+ disabled=False,
+ )
+ self.widget_n_jobs = BoundedIntText(
+ value=1,
+ min=1,
+ max=multiprocessing.cpu_count(),
+ description="Cores:",
+ disabled=False,
+ )
self.load_data_button = FileUpload(
- accept='.txt', multiple=False,
- style={'button_color': self.layout['theme_color']},
- description='Load data',
- button_style='success')
+ accept=".txt",
+ multiple=False,
+ style={"button_color": self.layout["theme_color"]},
+ description="Load data",
+ button_style="success",
+ )
# Drive selection
self.widget_drive_type_selection = RadioButtons(
- options=['Evoked', 'Poisson', 'Rhythmic'],
- value='Evoked',
- description='Drive:',
+ options=["Evoked", "Poisson", "Rhythmic"],
+ value="Evoked",
+ description="Drive:",
disabled=False,
- layout=self.layout['drive_widget'])
+ layout=self.layout["drive_widget"],
+ )
self.widget_location_selection = RadioButtons(
- options=['proximal', 'distal'], value='proximal',
- description='Location', disabled=False,
- layout=self.layout['drive_widget'])
+ options=["proximal", "distal"],
+ value="proximal",
+ description="Location",
+ disabled=False,
+ layout=self.layout["drive_widget"],
+ )
self.add_drive_button = create_expanded_button(
- 'Add drive', 'primary', layout=self.layout['btn'],
- button_color=self.layout['theme_color'])
+ "Add drive",
+ "primary",
+ layout=self.layout["btn"],
+ button_color=self.layout["theme_color"],
+ )
# Dashboard level buttons
self.run_button = create_expanded_button(
- 'Run', 'success', layout=self.layout['btn'],
- button_color=self.layout['theme_color'])
+ "Run",
+ "success",
+ layout=self.layout["btn"],
+ button_color=self.layout["theme_color"],
+ )
self.load_connectivity_button = FileUpload(
- accept='.json,.param', multiple=False,
- style={'button_color': self.layout['theme_color']},
- description='Load local network connectivity',
- layout=self.layout['btn_full_w'], button_style='success')
+ accept=".json,.param",
+ multiple=False,
+ style={"button_color": self.layout["theme_color"]},
+ description="Load local network connectivity",
+ layout=self.layout["btn_full_w"],
+ button_style="success",
+ )
self.load_drives_button = FileUpload(
- accept='.json,.param', multiple=False,
- style={'button_color': self.layout['theme_color']},
- description='Load external drives', layout=self.layout['btn'],
- button_style='success')
+ accept=".json,.param",
+ multiple=False,
+ style={"button_color": self.layout["theme_color"]},
+ description="Load external drives",
+ layout=self.layout["btn"],
+ button_style="success",
+ )
self.delete_drive_button = create_expanded_button(
- 'Delete drives', 'success', layout=self.layout['btn'],
- button_color=self.layout['theme_color'])
+ "Delete drives",
+ "success",
+ layout=self.layout["btn"],
+ button_color=self.layout["theme_color"],
+ )
# Plotting window
@@ -288,7 +354,8 @@ def __init__(self, theme_color="#8A2BE2",
def add_logging_window_logger(self):
handler = _OutputWidgetHandler(self._log_out)
handler.setFormatter(
- logging.Formatter('%(asctime)s - [%(levelname)s] %(message)s'))
+ logging.Formatter("%(asctime)s - [%(levelname)s] %(message)s")
+ )
logger.addHandler(handler)
def _init_ui_components(self):
@@ -311,23 +378,27 @@ def _init_ui_components(self):
# static parts
# Running status
self._simulation_status_bar = HTML(
- value=self._simulation_status_contents['not_running'])
+ value=self._simulation_status_contents["not_running"]
+ )
- self._log_window = HBox([self._log_out], layout=self.layout['log_out'])
+ self._log_window = HBox([self._log_out], layout=self.layout["log_out"])
self._operation_buttons = HBox(
[self.run_button, self.load_data_button],
- layout=self.layout['operation_box'])
+ layout=self.layout["operation_box"],
+ )
# title
- self._header = HTML(value=f"""
- HUMAN NEOCORTICAL NEUROSOLVER
""")
+ HUMAN NEOCORTICAL NEUROSOLVER"""
+ )
@property
def analysis_config(self):
"""Provides everything viz window needs except for the data."""
return {
- "viz_style": self.layout['visualization_output'],
+ "viz_style": self.layout["visualization_output"],
# widgets
"plot_outputs": self.plot_outputs_dict,
"plot_dropdowns": self.plot_dropdown_types_dict,
@@ -346,23 +417,30 @@ def load_parameters(params_fname=None):
if not params_fname:
# by default load default.json
hnn_core_root = Path(hnn_core.__file__).parent
- params_fname = hnn_core_root / 'param/default.json'
+ params_fname = hnn_core_root / "param/default.json"
return read_params(params_fname)
def _link_callbacks(self):
"""Link callbacks to UI components."""
+
def _handle_backend_change(backend_type):
- return handle_backend_change(backend_type.new,
- self._backend_config_out,
- self.widget_mpi_cmd,
- self.widget_n_jobs)
+ return handle_backend_change(
+ backend_type.new,
+ self._backend_config_out,
+ self.widget_mpi_cmd,
+ self.widget_n_jobs,
+ )
def _add_drive_button_clicked(b):
- return add_drive_widget(self.widget_drive_type_selection.value,
- self.drive_boxes, self.drive_widgets,
- self._drives_out, self.widget_tstop,
- self.widget_location_selection.value,
- layout=self.layout['drive_textbox'])
+ return add_drive_widget(
+ self.widget_drive_type_selection.value,
+ self.drive_boxes,
+ self.drive_widgets,
+ self._drives_out,
+ self.widget_tstop,
+ self.widget_location_selection.value,
+ layout=self.layout["drive_textbox"],
+ )
def _delete_drives_clicked(b):
self._drives_out.clear_output()
@@ -374,42 +452,68 @@ def _delete_drives_clicked(b):
def _on_upload_connectivity(change):
return on_upload_params_change(
- change, self.params, self.widget_tstop, self.widget_dt,
- self._log_out, self.drive_boxes, self.drive_widgets,
- self._drives_out, self._connectivity_out,
- self.connectivity_widgets, self.layout['drive_textbox'],
- "connectivity")
+ change,
+ self.params,
+ self.widget_tstop,
+ self.widget_dt,
+ self._log_out,
+ self.drive_boxes,
+ self.drive_widgets,
+ self._drives_out,
+ self._connectivity_out,
+ self.connectivity_widgets,
+ self.layout["drive_textbox"],
+ "connectivity",
+ )
def _on_upload_drives(change):
return on_upload_params_change(
- change, self.params, self.widget_tstop, self.widget_dt,
- self._log_out, self.drive_boxes, self.drive_widgets,
- self._drives_out, self._connectivity_out,
- self.connectivity_widgets, self.layout['drive_textbox'],
- "drives")
+ change,
+ self.params,
+ self.widget_tstop,
+ self.widget_dt,
+ self._log_out,
+ self.drive_boxes,
+ self.drive_widgets,
+ self._drives_out,
+ self._connectivity_out,
+ self.connectivity_widgets,
+ self.layout["drive_textbox"],
+ "drives",
+ )
def _on_upload_data(change):
- return on_upload_data_change(change, self.data, self.viz_manager,
- self._log_out)
+ return on_upload_data_change(
+ change, self.data, self.viz_manager, self._log_out
+ )
def _run_button_clicked(b):
return run_button_clicked(
- self.widget_simulation_name, self._log_out, self.drive_widgets,
- self.data, self.widget_dt, self.widget_tstop,
- self.widget_ntrials, self.widget_backend_selection,
- self.widget_mpi_cmd, self.widget_n_jobs, self.params,
- self._simulation_status_bar, self._simulation_status_contents,
- self.connectivity_widgets, self.viz_manager)
-
- self.widget_backend_selection.observe(_handle_backend_change, 'value')
+ self.widget_simulation_name,
+ self._log_out,
+ self.drive_widgets,
+ self.data,
+ self.widget_dt,
+ self.widget_tstop,
+ self.widget_ntrials,
+ self.widget_backend_selection,
+ self.widget_mpi_cmd,
+ self.widget_n_jobs,
+ self.params,
+ self._simulation_status_bar,
+ self._simulation_status_contents,
+ self.connectivity_widgets,
+ self.viz_manager,
+ )
+
+ self.widget_backend_selection.observe(_handle_backend_change, "value")
self.add_drive_button.on_click(_add_drive_button_clicked)
self.delete_drive_button.on_click(_delete_drives_clicked)
- self.load_connectivity_button.observe(_on_upload_connectivity,
- names='value')
- self.load_drives_button.observe(_on_upload_drives, names='value')
+ self.load_connectivity_button.observe(_on_upload_connectivity, names="value")
+ self.load_drives_button.observe(_on_upload_drives, names="value")
self.run_button.on_click(_run_button_clicked)
- self.load_data_button.observe(_on_upload_data, names='value')
+ self.load_data_button.observe(_on_upload_data, names="value")
def compose(self, return_layout=True):
"""Compose widgets.
@@ -420,70 +524,108 @@ def compose(self, return_layout=True):
If the method returns the layout object which can be rendered by
IPython.display.display() method.
"""
- simulation_box = VBox([
- VBox([
- self.widget_simulation_name, self.widget_tstop, self.widget_dt,
- self.widget_ntrials, self.widget_backend_selection,
- self._backend_config_out]),
- ], layout=self.layout['config_box'])
-
- connectivity_box = VBox([
- HBox([self.load_connectivity_button, ]),
- self._connectivity_out,
- ])
+ simulation_box = VBox(
+ [
+ VBox(
+ [
+ self.widget_simulation_name,
+ self.widget_tstop,
+ self.widget_dt,
+ self.widget_ntrials,
+ self.widget_backend_selection,
+ self._backend_config_out,
+ ]
+ ),
+ ],
+ layout=self.layout["config_box"],
+ )
+
+ connectivity_box = VBox(
+ [
+ HBox(
+ [
+ self.load_connectivity_button,
+ ]
+ ),
+ self._connectivity_out,
+ ]
+ )
# accordions to group local-connectivity by cell type
- connectivity_boxes = [
- VBox(slider) for slider in self.connectivity_widgets]
+ connectivity_boxes = [VBox(slider) for slider in self.connectivity_widgets]
connectivity_names = (
- 'Layer 2/3 Pyramidal', 'Layer 5 Pyramidal', 'Layer 2 Basket',
- 'Layer 5 Basket')
+ "Layer 2/3 Pyramidal",
+ "Layer 5 Pyramidal",
+ "Layer 2 Basket",
+ "Layer 5 Basket",
+ )
cell_connectivity = Accordion(children=connectivity_boxes)
cell_connectivity.titles = [s for s in connectivity_names]
- drive_selections = VBox([
- self.add_drive_button, self.widget_drive_type_selection,
- self.widget_location_selection],
- layout=Layout(flex="1"))
+ drive_selections = VBox(
+ [
+ self.add_drive_button,
+ self.widget_drive_type_selection,
+ self.widget_location_selection,
+ ],
+ layout=Layout(flex="1"),
+ )
- drives_options = VBox([
- HBox([
- VBox([self.load_drives_button, self.delete_drive_button],
- layout=Layout(flex="1")),
- drive_selections,
- ]), self._drives_out
- ])
+ drives_options = VBox(
+ [
+ HBox(
+ [
+ VBox(
+ [self.load_drives_button, self.delete_drive_button],
+ layout=Layout(flex="1"),
+ ),
+ drive_selections,
+ ]
+ ),
+ self._drives_out,
+ ]
+ )
config_panel, figs_output = self.viz_manager.compose()
# Tabs for left pane
left_tab = Tab()
left_tab.children = [
- simulation_box, connectivity_box, drives_options,
+ simulation_box,
+ connectivity_box,
+ drives_options,
config_panel,
]
- titles = ('Simulation', 'Network connectivity', 'External drives',
- 'Visualization')
+ titles = (
+ "Simulation",
+ "Network connectivity",
+ "External drives",
+ "Visualization",
+ )
for idx, title in enumerate(titles):
left_tab.set_title(idx, title)
self.app_layout = AppLayout(
header=self._header,
- left_sidebar=VBox([
- VBox([left_tab], layout=self.layout['left_tab']),
- self._operation_buttons,
- self._log_window,
- ], layout=self.layout['left_sidebar']),
+ left_sidebar=VBox(
+ [
+ VBox([left_tab], layout=self.layout["left_tab"]),
+ self._operation_buttons,
+ self._log_window,
+ ],
+ layout=self.layout["left_sidebar"],
+ ),
right_sidebar=figs_output,
footer=self._simulation_status_bar,
pane_widths=[
- self.layout['left_sidebar'].width, '0px',
- self.layout['visualization_window'].width
+ self.layout["left_sidebar"].width,
+ "0px",
+ self.layout["visualization_window"].width,
],
pane_heights=[
- self.layout['header_height'],
- self.layout['visualization_window'].height,
- self.layout['simulation_status_height']
+ self.layout["header_height"],
+ self.layout["visualization_window"].height,
+ self.layout["simulation_status_height"],
],
)
@@ -492,11 +634,17 @@ def compose(self, return_layout=True):
# self.simulation_data[self.widget_simulation_name.value]
# initialize drive and connectivity ipywidgets
- load_drive_and_connectivity(self.params, self._log_out,
- self._drives_out, self.drive_widgets,
- self.drive_boxes, self._connectivity_out,
- self.connectivity_widgets,
- self.widget_tstop, self.layout)
+ load_drive_and_connectivity(
+ self.params,
+ self._log_out,
+ self._drives_out,
+ self.drive_widgets,
+ self.drive_boxes,
+ self._connectivity_out,
+ self.connectivity_widgets,
+ self.widget_tstop,
+ self.layout,
+ )
if not return_layout:
return
@@ -525,13 +673,13 @@ def capture(self, width=None, height=None, extra_margin=100, render=True):
snapshot : An iframe snapshot object that can be rendered in notebooks.
"""
file = io.StringIO()
- embed_minimal_html(file, views=[self.app_layout], title='')
+ embed_minimal_html(file, views=[self.app_layout], title="")
if not width:
width = self.total_width + extra_margin
if not height:
height = self.total_height + extra_margin
- content = urllib.parse.quote(file.getvalue().encode('utf8'))
+ content = urllib.parse.quote(file.getvalue().encode("utf8"))
data_url = f"data:text/html,{content}"
screenshot = IFrame(data_url, width=width, height=height)
if render:
@@ -605,15 +753,15 @@ def run_notebook_cells(self):
# below are a series of methods that are used to manipulate the GUI
def _simulate_upload_data(self, file_url):
uploaded_value = _prepare_upload_file_from_url(file_url)
- self.load_data_button.set_trait('value', uploaded_value)
+ self.load_data_button.set_trait("value", uploaded_value)
def _simulate_upload_connectivity(self, file_url):
uploaded_value = _prepare_upload_file_from_url(file_url)
- self.load_connectivity_button.set_trait('value', uploaded_value)
+ self.load_connectivity_button.set_trait("value", uploaded_value)
def _simulate_upload_drives(self, file_url):
uploaded_value = _prepare_upload_file_from_url(file_url)
- self.load_drives_button.set_trait('value', uploaded_value)
+ self.load_drives_button.set_trait("value", uploaded_value)
def _simulate_left_tab_click(self, tab_title):
# Get left tab group object
@@ -625,7 +773,9 @@ def _simulate_left_tab_click(self, tab_title):
else:
raise ValueError("Tab title does not exist.")
- def _simulate_make_figure(self,):
+ def _simulate_make_figure(
+ self,
+ ):
self._simulate_left_tab_click("Visualization")
self.viz_manager.make_fig_button.click()
@@ -654,45 +804,63 @@ def _prepare_upload_file_from_url(file_url):
for line in data:
content += line
- return [{
- 'name': params_name,
- 'type': 'application/json',
- 'size': len(content),
- 'content': content,
- 'last_modified': datetime.now()
- }]
+ return [
+ {
+ "name": params_name,
+ "type": "application/json",
+ "size": len(content),
+ "content": content,
+ "last_modified": datetime.now(),
+ }
+ ]
-def create_expanded_button(description, button_style, layout, disabled=False,
- button_color="#8A2BE2"):
- return Button(description=description, button_style=button_style,
- layout=layout, style={'button_color': button_color},
- disabled=disabled)
+def create_expanded_button(
+ description, button_style, layout, disabled=False, button_color="#8A2BE2"
+):
+ return Button(
+ description=description,
+ button_style=button_style,
+ layout=layout,
+ style={"button_color": button_color},
+ disabled=disabled,
+ )
def _get_connectivity_widgets(conn_data):
"""Create connectivity box widgets from specified weight and probability"""
- style = {'description_width': '150px'}
+ style = {"description_width": "150px"}
style = {}
sliders = list()
for receptor_name in conn_data.keys():
w_text_input = BoundedFloatText(
- value=conn_data[receptor_name]['weight'], disabled=False,
- continuous_update=False, min=0, max=1e6, step=0.01,
- description="weight", style=style)
+ value=conn_data[receptor_name]["weight"],
+ disabled=False,
+ continuous_update=False,
+ min=0,
+ max=1e6,
+ step=0.01,
+ description="weight",
+ style=style,
+ )
- conn_widget = VBox([
- HTML(value=f"""
- Receptor: {conn_data[receptor_name]['receptor']}
"""),
- w_text_input, HTML(value="
")
- ])
+ conn_widget = VBox(
+ [
+ HTML(
+ value=f"""
+ Receptor: {conn_data[receptor_name]['receptor']}
"""
+ ),
+ w_text_input,
+ HTML(value="
"),
+ ]
+ )
conn_widget._belongsto = {
- "receptor": conn_data[receptor_name]['receptor'],
- "location": conn_data[receptor_name]['location'],
- "src_gids": conn_data[receptor_name]['src_gids'],
- "target_gids": conn_data[receptor_name]['target_gids'],
+ "receptor": conn_data[receptor_name]["receptor"],
+ "location": conn_data[receptor_name]["location"],
+ "src_gids": conn_data[receptor_name]["src_gids"],
+ "target_gids": conn_data[receptor_name]["target_gids"],
}
sliders.append(conn_widget)
@@ -701,23 +869,23 @@ def _get_connectivity_widgets(conn_data):
def _get_cell_specific_widgets(layout, style, location, data=None):
default_data = {
- 'weights_ampa': {
- 'L5_pyramidal': 0.,
- 'L2_pyramidal': 0.,
- 'L5_basket': 0.,
- 'L2_basket': 0.
+ "weights_ampa": {
+ "L5_pyramidal": 0.0,
+ "L2_pyramidal": 0.0,
+ "L5_basket": 0.0,
+ "L2_basket": 0.0,
},
- 'weights_nmda': {
- 'L5_pyramidal': 0.,
- 'L2_pyramidal': 0.,
- 'L5_basket': 0.,
- 'L2_basket': 0.
+ "weights_nmda": {
+ "L5_pyramidal": 0.0,
+ "L2_pyramidal": 0.0,
+ "L5_basket": 0.0,
+ "L2_basket": 0.0,
},
- 'delays': {
- 'L5_pyramidal': 0.1,
- 'L2_pyramidal': 0.1,
- 'L5_basket': 0.1,
- 'L2_basket': 0.1
+ "delays": {
+ "L5_pyramidal": 0.1,
+ "L2_pyramidal": 0.1,
+ "L5_basket": 0.1,
+ "L2_basket": 0.1,
},
}
if isinstance(data, dict):
@@ -726,159 +894,224 @@ def _get_cell_specific_widgets(layout, style, location, data=None):
default_data[k].update(data[k])
kwargs = dict(layout=layout, style=style)
- cell_types = ['L5_pyramidal', 'L2_pyramidal', 'L5_basket', 'L2_basket']
+ cell_types = ["L5_pyramidal", "L2_pyramidal", "L5_basket", "L2_basket"]
if location == "distal":
- cell_types.remove('L5_basket')
+ cell_types.remove("L5_basket")
weights_ampa, weights_nmda, delays = dict(), dict(), dict()
for cell_type in cell_types:
- weights_ampa[f'{cell_type}'] = BoundedFloatText(
- value=default_data['weights_ampa'][cell_type],
- description=f'{cell_type}:', min=0, max=1e6, step=0.01, **kwargs)
- weights_nmda[f'{cell_type}'] = BoundedFloatText(
- value=default_data['weights_nmda'][cell_type],
- description=f'{cell_type}:', min=0, max=1e6, step=0.01, **kwargs)
- delays[f'{cell_type}'] = BoundedFloatText(
- value=default_data['delays'][cell_type],
- description=f'{cell_type}:', min=0, max=1e6, step=0.1, **kwargs)
+ weights_ampa[f"{cell_type}"] = BoundedFloatText(
+ value=default_data["weights_ampa"][cell_type],
+ description=f"{cell_type}:",
+ min=0,
+ max=1e6,
+ step=0.01,
+ **kwargs,
+ )
+ weights_nmda[f"{cell_type}"] = BoundedFloatText(
+ value=default_data["weights_nmda"][cell_type],
+ description=f"{cell_type}:",
+ min=0,
+ max=1e6,
+ step=0.01,
+ **kwargs,
+ )
+ delays[f"{cell_type}"] = BoundedFloatText(
+ value=default_data["delays"][cell_type],
+ description=f"{cell_type}:",
+ min=0,
+ max=1e6,
+ step=0.1,
+ **kwargs,
+ )
widgets_dict = {
- 'weights_ampa': weights_ampa,
- 'weights_nmda': weights_nmda,
- 'delays': delays
+ "weights_ampa": weights_ampa,
+ "weights_nmda": weights_nmda,
+ "delays": delays,
}
- widgets_list = ([HTML(value="AMPA weights")] +
- list(weights_ampa.values()) +
- [HTML(value="NMDA weights")] +
- list(weights_nmda.values()) +
- [HTML(value="Synaptic delays")] +
- list(delays.values()))
+ widgets_list = (
+ [HTML(value="AMPA weights")]
+ + list(weights_ampa.values())
+ + [HTML(value="NMDA weights")]
+ + list(weights_nmda.values())
+ + [HTML(value="Synaptic delays")]
+ + list(delays.values())
+ )
return widgets_list, widgets_dict
-def _get_rhythmic_widget(name, tstop_widget, layout, style, location,
- data=None, default_weights_ampa=None,
- default_weights_nmda=None, default_delays=None):
+def _get_rhythmic_widget(
+ name,
+ tstop_widget,
+ layout,
+ style,
+ location,
+ data=None,
+ default_weights_ampa=None,
+ default_weights_nmda=None,
+ default_delays=None,
+):
default_data = {
- 'tstart': 0.,
- 'tstart_std': 0.,
- 'tstop': 0.,
- 'burst_rate': 7.5,
- 'burst_std': 0,
- 'repeats': 1,
- 'seedcore': 14,
+ "tstart": 0.0,
+ "tstart_std": 0.0,
+ "tstop": 0.0,
+ "burst_rate": 7.5,
+ "burst_std": 0,
+ "repeats": 1,
+ "seedcore": 14,
}
if isinstance(data, dict):
default_data.update(data)
kwargs = dict(layout=layout, style=style)
tstart = BoundedFloatText(
- value=default_data['tstart'], description='Start time (ms)',
- min=0, max=1e6, **kwargs)
+ value=default_data["tstart"],
+ description="Start time (ms)",
+ min=0,
+ max=1e6,
+ **kwargs,
+ )
tstart_std = BoundedFloatText(
- value=default_data['tstart_std'], description='Start time dev (ms)',
- min=0, max=1e6, **kwargs)
+ value=default_data["tstart_std"],
+ description="Start time dev (ms)",
+ min=0,
+ max=1e6,
+ **kwargs,
+ )
tstop = BoundedFloatText(
- value=default_data['tstop'],
- description='Stop time (ms)',
+ value=default_data["tstop"],
+ description="Stop time (ms)",
max=tstop_widget.value,
**kwargs,
)
burst_rate = BoundedFloatText(
- value=default_data['burst_rate'], description='Burst rate (Hz)',
- min=0, max=1e6, **kwargs)
+ value=default_data["burst_rate"],
+ description="Burst rate (Hz)",
+ min=0,
+ max=1e6,
+ **kwargs,
+ )
burst_std = BoundedFloatText(
- value=default_data['burst_std'], description='Burst std dev (Hz)',
- min=0, max=1e6, **kwargs)
+ value=default_data["burst_std"],
+ description="Burst std dev (Hz)",
+ min=0,
+ max=1e6,
+ **kwargs,
+ )
repeats = BoundedIntText(
- value=default_data['repeats'], description='Repeats', min=0,
- max=int(1e6), **kwargs)
- seedcore = IntText(value=default_data['seedcore'],
- description='Seed',
- **kwargs)
+ value=default_data["repeats"],
+ description="Repeats",
+ min=0,
+ max=int(1e6),
+ **kwargs,
+ )
+ seedcore = IntText(value=default_data["seedcore"], description="Seed", **kwargs)
widgets_list, widgets_dict = _get_cell_specific_widgets(
layout,
style,
location,
data={
- 'weights_ampa': default_weights_ampa,
- 'weights_nmda': default_weights_nmda,
- 'delays': default_delays,
+ "weights_ampa": default_weights_ampa,
+ "weights_nmda": default_weights_nmda,
+ "delays": default_delays,
},
)
drive_box = VBox(
- [tstart, tstart_std, tstop, burst_rate, burst_std, repeats, seedcore] +
- widgets_list)
- drive = dict(type='Rhythmic',
- name=name,
- tstart=tstart,
- tstart_std=tstart_std,
- burst_rate=burst_rate,
- burst_std=burst_std,
- repeats=repeats,
- seedcore=seedcore,
- location=location,
- tstop=tstop)
+ [tstart, tstart_std, tstop, burst_rate, burst_std, repeats, seedcore]
+ + widgets_list
+ )
+ drive = dict(
+ type="Rhythmic",
+ name=name,
+ tstart=tstart,
+ tstart_std=tstart_std,
+ burst_rate=burst_rate,
+ burst_std=burst_std,
+ repeats=repeats,
+ seedcore=seedcore,
+ location=location,
+ tstop=tstop,
+ )
drive.update(widgets_dict)
return drive, drive_box
-def _get_poisson_widget(name, tstop_widget, layout, style, location, data=None,
- default_weights_ampa=None, default_weights_nmda=None,
- default_delays=None):
+def _get_poisson_widget(
+ name,
+ tstop_widget,
+ layout,
+ style,
+ location,
+ data=None,
+ default_weights_ampa=None,
+ default_weights_nmda=None,
+ default_delays=None,
+):
default_data = {
- 'tstart': 0.0,
- 'tstop': 0.0,
- 'seedcore': 14,
- 'rate_constant': {
- 'L5_pyramidal': 8.5,
- 'L2_pyramidal': 8.5,
- 'L5_basket': 8.5,
- 'L2_basket': 8.5,
- }
+ "tstart": 0.0,
+ "tstop": 0.0,
+ "seedcore": 14,
+ "rate_constant": {
+ "L5_pyramidal": 8.5,
+ "L2_pyramidal": 8.5,
+ "L5_basket": 8.5,
+ "L2_basket": 8.5,
+ },
}
if isinstance(data, dict):
default_data.update(data)
tstart = BoundedFloatText(
- value=default_data['tstart'], description='Start time (ms)',
- min=0, max=1e6, layout=layout, style=style)
+ value=default_data["tstart"],
+ description="Start time (ms)",
+ min=0,
+ max=1e6,
+ layout=layout,
+ style=style,
+ )
tstop = BoundedFloatText(
- value=default_data['tstop'],
+ value=default_data["tstop"],
max=tstop_widget.value,
- description='Stop time (ms)',
+ description="Stop time (ms)",
layout=layout,
style=style,
)
- seedcore = IntText(value=default_data['seedcore'],
- description='Seed',
- layout=layout,
- style=style)
+ seedcore = IntText(
+ value=default_data["seedcore"], description="Seed", layout=layout, style=style
+ )
- cell_types = ['L5_pyramidal', 'L2_pyramidal', 'L5_basket', 'L2_basket']
+ cell_types = ["L5_pyramidal", "L2_pyramidal", "L5_basket", "L2_basket"]
rate_constant = dict()
for cell_type in cell_types:
- rate_constant[f'{cell_type}'] = BoundedFloatText(
- value=default_data['rate_constant'][cell_type],
- description=f'{cell_type}:', min=0, max=1e6, step=0.01,
- layout=layout, style=style)
+ rate_constant[f"{cell_type}"] = BoundedFloatText(
+ value=default_data["rate_constant"][cell_type],
+ description=f"{cell_type}:",
+ min=0,
+ max=1e6,
+ step=0.01,
+ layout=layout,
+ style=style,
+ )
widgets_list, widgets_dict = _get_cell_specific_widgets(
layout,
style,
location,
data={
- 'weights_ampa': default_weights_ampa,
- 'weights_nmda': default_weights_nmda,
- 'delays': default_delays,
+ "weights_ampa": default_weights_ampa,
+ "weights_nmda": default_weights_nmda,
+ "delays": default_delays,
},
)
- widgets_dict.update({'rate_constant': rate_constant})
- widgets_list.extend([HTML(value="Rate constants")] +
- list(widgets_dict['rate_constant'].values()))
+ widgets_dict.update({"rate_constant": rate_constant})
+ widgets_list.extend(
+ [HTML(value="Rate constants")]
+ + list(widgets_dict["rate_constant"].values())
+ )
drive_box = VBox([tstart, tstop, seedcore] + widgets_list)
drive = dict(
- type='Poisson',
+ type="Poisson",
name=name,
tstart=tstart,
tstop=tstop,
@@ -890,66 +1123,92 @@ def _get_poisson_widget(name, tstop_widget, layout, style, location, data=None,
return drive, drive_box
-def _get_evoked_widget(name, layout, style, location, data=None,
- default_weights_ampa=None, default_weights_nmda=None,
- default_delays=None):
+def _get_evoked_widget(
+ name,
+ layout,
+ style,
+ location,
+ data=None,
+ default_weights_ampa=None,
+ default_weights_nmda=None,
+ default_delays=None,
+):
default_data = {
- 'mu': 0,
- 'sigma': 1,
- 'numspikes': 1,
- 'seedcore': 14,
+ "mu": 0,
+ "sigma": 1,
+ "numspikes": 1,
+ "seedcore": 14,
}
if isinstance(data, dict):
default_data.update(data)
kwargs = dict(layout=layout, style=style)
mu = BoundedFloatText(
- value=default_data['mu'], description='Mean time:', min=0, max=1e6,
- step=0.01, **kwargs)
+ value=default_data["mu"],
+ description="Mean time:",
+ min=0,
+ max=1e6,
+ step=0.01,
+ **kwargs,
+ )
sigma = BoundedFloatText(
- value=default_data['sigma'], description='Std dev time:', min=0,
- max=1e6, step=0.01, **kwargs)
- numspikes = IntText(value=default_data['numspikes'],
- description='No. Spikes:',
- **kwargs)
- seedcore = IntText(value=default_data['seedcore'],
- description='Seed: ',
- **kwargs)
+ value=default_data["sigma"],
+ description="Std dev time:",
+ min=0,
+ max=1e6,
+ step=0.01,
+ **kwargs,
+ )
+ numspikes = IntText(
+ value=default_data["numspikes"], description="No. Spikes:", **kwargs
+ )
+ seedcore = IntText(value=default_data["seedcore"], description="Seed: ", **kwargs)
widgets_list, widgets_dict = _get_cell_specific_widgets(
layout,
style,
location,
data={
- 'weights_ampa': default_weights_ampa,
- 'weights_nmda': default_weights_nmda,
- 'delays': default_delays,
+ "weights_ampa": default_weights_ampa,
+ "weights_nmda": default_weights_nmda,
+ "delays": default_delays,
},
)
drive_box = VBox([mu, sigma, numspikes, seedcore] + widgets_list)
- drive = dict(type='Evoked',
- name=name,
- mu=mu,
- sigma=sigma,
- numspikes=numspikes,
- seedcore=seedcore,
- location=location,
- sync_within_trial=False)
+ drive = dict(
+ type="Evoked",
+ name=name,
+ mu=mu,
+ sigma=sigma,
+ numspikes=numspikes,
+ seedcore=seedcore,
+ location=location,
+ sync_within_trial=False,
+ )
drive.update(widgets_dict)
return drive, drive_box
-def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out,
- tstop_widget, location, layout,
- prespecified_drive_name=None,
- prespecified_drive_data=None,
- prespecified_weights_ampa=None,
- prespecified_weights_nmda=None,
- prespecified_delays=None, render=True,
- expand_last_drive=True, event_seed=14):
+def add_drive_widget(
+ drive_type,
+ drive_boxes,
+ drive_widgets,
+ drives_out,
+ tstop_widget,
+ location,
+ layout,
+ prespecified_drive_name=None,
+ prespecified_drive_data=None,
+ prespecified_weights_ampa=None,
+ prespecified_weights_nmda=None,
+ prespecified_delays=None,
+ render=True,
+ expand_last_drive=True,
+ event_seed=14,
+):
"""Add a widget for a new drive."""
- style = {'description_width': '150px'}
+ style = {"description_width": "150px"}
drives_out.clear_output()
if not prespecified_drive_data:
prespecified_drive_data = {}
@@ -960,7 +1219,7 @@ def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out,
name = drive_type + str(len(drive_boxes))
else:
name = prespecified_drive_name
- if drive_type in ('Rhythmic', 'Bursty'):
+ if drive_type in ("Rhythmic", "Bursty"):
drive, drive_box = _get_rhythmic_widget(
name,
tstop_widget,
@@ -972,7 +1231,7 @@ def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out,
default_weights_nmda=prespecified_weights_nmda,
default_delays=prespecified_delays,
)
- elif drive_type == 'Poisson':
+ elif drive_type == "Poisson":
drive, drive_box = _get_poisson_widget(
name,
tstop_widget,
@@ -984,7 +1243,7 @@ def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out,
default_weights_nmda=prespecified_weights_nmda,
default_delays=prespecified_delays,
)
- elif drive_type in ('Evoked', 'Gaussian'):
+ elif drive_type in ("Evoked", "Gaussian"):
drive, drive_box = _get_evoked_widget(
name,
layout,
@@ -996,33 +1255,28 @@ def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out,
default_delays=prespecified_delays,
)
- if drive_type in [
- 'Evoked', 'Poisson', 'Rhythmic', 'Bursty', 'Gaussian'
- ]:
+ if drive_type in ["Evoked", "Poisson", "Rhythmic", "Bursty", "Gaussian"]:
drive_boxes.append(drive_box)
drive_widgets.append(drive)
if render:
accordion = Accordion(
children=drive_boxes,
- selected_index=len(drive_boxes) -
- 1 if expand_last_drive else None,
+ selected_index=len(drive_boxes) - 1 if expand_last_drive else None,
)
for idx, drive in enumerate(drive_widgets):
- accordion.set_title(idx,
- f"{drive['name']} ({drive['location']})")
+ accordion.set_title(idx, f"{drive['name']} ({drive['location']})")
display(accordion)
-def add_connectivity_tab(params, connectivity_out,
- connectivity_textfields):
+def add_connectivity_tab(params, connectivity_out, connectivity_textfields):
"""Add all possible connectivity boxes to connectivity tab."""
net = jones_2009_model(params)
cell_types = [ct for ct in net.cell_types.keys()]
- receptors = ('ampa', 'nmda', 'gabaa', 'gabab')
- locations = ('proximal', 'distal', 'soma')
+ receptors = ("ampa", "nmda", "gabaa", "gabab")
+ locations = ("proximal", "distal", "soma")
# clear existing connectivity
connectivity_out.clear_output()
@@ -1036,18 +1290,18 @@ def add_connectivity_tab(params, connectivity_out,
# the connectivity list should be built on this level
receptor_related_conn = {}
for receptor in receptors:
- conn_indices = pick_connection(net=net,
- src_gids=src_gids,
- target_gids=target_gids,
- loc=location,
- receptor=receptor)
+ conn_indices = pick_connection(
+ net=net,
+ src_gids=src_gids,
+ target_gids=target_gids,
+ loc=location,
+ receptor=receptor,
+ )
if len(conn_indices) > 0:
assert len(conn_indices) == 1
conn_idx = conn_indices[0]
- current_w = net.connectivity[
- conn_idx]['nc_dict']['A_weight']
- current_p = net.connectivity[
- conn_idx]['probability']
+ current_w = net.connectivity[conn_idx]["nc_dict"]["A_weight"]
+ current_p = net.connectivity[conn_idx]["probability"]
# valid connection
receptor_related_conn[receptor] = {
"weight": current_w,
@@ -1059,10 +1313,10 @@ def add_connectivity_tab(params, connectivity_out,
"target_gids": target_gids,
}
if len(receptor_related_conn) > 0:
- connectivity_names.append(
- f"{src_gids}→{target_gids} ({location})")
+ connectivity_names.append(f"{src_gids}→{target_gids} ({location})")
connectivity_textfields.append(
- _get_connectivity_widgets(receptor_related_conn))
+ _get_connectivity_widgets(receptor_related_conn)
+ )
connectivity_boxes = [VBox(slider) for slider in connectivity_textfields]
cell_connectivity = Accordion(children=connectivity_boxes)
@@ -1075,11 +1329,11 @@ def add_connectivity_tab(params, connectivity_out,
return net
-def add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop,
- layout):
+def add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop, layout):
net = jones_2009_model(params)
drive_specs = _extract_drive_specs_from_hnn_params(
- params, list(net.cell_types.keys()), legacy_mode=net._legacy_mode)
+ params, list(net.cell_types.keys()), legacy_mode=net._legacy_mode
+ )
# clear before adding drives
drives_out.clear_output()
@@ -1093,111 +1347,136 @@ def add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop,
should_render = idx == (len(drive_names) - 1)
add_drive_widget(
- specs['type'].capitalize(),
+ specs["type"].capitalize(),
drive_boxes,
drive_widgets,
drives_out,
tstop,
- specs['location'],
+ specs["location"],
layout=layout,
prespecified_drive_name=drive_name,
- prespecified_drive_data=specs['dynamics'],
- prespecified_weights_ampa=specs['weights_ampa'],
- prespecified_weights_nmda=specs['weights_nmda'],
- prespecified_delays=specs['synaptic_delays'],
+ prespecified_drive_data=specs["dynamics"],
+ prespecified_weights_ampa=specs["weights_ampa"],
+ prespecified_weights_nmda=specs["weights_nmda"],
+ prespecified_delays=specs["synaptic_delays"],
render=should_render,
expand_last_drive=False,
- event_seed=specs['event_seed'],
+ event_seed=specs["event_seed"],
)
-def load_drive_and_connectivity(params, log_out, drives_out,
- drive_widgets, drive_boxes, connectivity_out,
- connectivity_textfields, tstop, layout):
+def load_drive_and_connectivity(
+ params,
+ log_out,
+ drives_out,
+ drive_widgets,
+ drive_boxes,
+ connectivity_out,
+ connectivity_textfields,
+ tstop,
+ layout,
+):
"""Add drive and connectivity ipywidgets from params."""
log_out.clear_output()
with log_out:
# Add connectivity
add_connectivity_tab(params, connectivity_out, connectivity_textfields)
# Add drives
- add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop,
- layout)
+ add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop, layout)
def on_upload_data_change(change, data, viz_manager, log_out):
- if len(change['owner'].value) == 0:
+ if len(change["owner"].value) == 0:
logger.info("Empty change")
return
- data_dict = change['new'][0]
+ data_dict = change["new"][0]
- data_fname = data_dict['name'].rstrip('.txt')
- if data_fname in data['simulation_data'].keys():
+ data_fname = data_dict["name"].rstrip(".txt")
+ if data_fname in data["simulation_data"].keys():
logger.error(f"Found existing data: {data_fname}.")
return
- ext_content = data_dict['content']
+ ext_content = data_dict["content"]
ext_content = codecs.decode(ext_content, encoding="utf-8")
with log_out:
- data['simulation_data'][data_fname] = {'net': None, 'dpls': [
- hnn_core.read_dipole(io.StringIO(ext_content))
- ]}
- logger.info(f'External data {data_fname} loaded.')
- viz_manager.reset_fig_config_tabs(template_name='single figure')
+ data["simulation_data"][data_fname] = {
+ "net": None,
+ "dpls": [hnn_core.read_dipole(io.StringIO(ext_content))],
+ }
+ logger.info(f"External data {data_fname} loaded.")
+ viz_manager.reset_fig_config_tabs(template_name="single figure")
viz_manager.add_figure()
- fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1)
+ fig_name = _idx2figname(viz_manager.data["fig_idx"]["idx"] - 1)
ax_plots = [("ax0", "current dipole")]
for ax_name, plot_type in ax_plots:
viz_manager._simulate_edit_figure(
- fig_name, ax_name, data_fname, plot_type, {}, "plot")
+ fig_name, ax_name, data_fname, plot_type, {}, "plot"
+ )
-def on_upload_params_change(change, params, tstop, dt, log_out, drive_boxes,
- drive_widgets, drives_out, connectivity_out,
- connectivity_textfields, layout, load_type):
- if len(change['owner'].value) == 0:
+def on_upload_params_change(
+ change,
+ params,
+ tstop,
+ dt,
+ log_out,
+ drive_boxes,
+ drive_widgets,
+ drives_out,
+ connectivity_out,
+ connectivity_textfields,
+ layout,
+ load_type,
+):
+ if len(change["owner"].value) == 0:
logger.info("Empty change")
return
logger.info("Loading connectivity...")
- param_dict = change['new'][0]
- params_fname = param_dict['name']
- param_data = param_dict['content']
+ param_dict = change["new"][0]
+ params_fname = param_dict["name"]
+ param_data = param_dict["content"]
param_data = codecs.decode(param_data, encoding="utf-8")
ext = Path(params_fname).suffix
- read_func = {'.json': _read_json, '.param': _read_legacy_params}
+ read_func = {".json": _read_json, ".param": _read_legacy_params}
params_network = read_func[ext](param_data)
# update simulation settings and params
log_out.clear_output()
with log_out:
- if 'tstop' in params_network.keys():
- tstop.value = params_network['tstop']
- if 'dt' in params_network.keys():
- dt.value = params_network['dt']
+ if "tstop" in params_network.keys():
+ tstop.value = params_network["tstop"]
+ if "dt" in params_network.keys():
+ dt.value = params_network["dt"]
params.update(params_network)
# init network, add drives & connectivity
- if load_type == 'connectivity':
+ if load_type == "connectivity":
add_connectivity_tab(params, connectivity_out, connectivity_textfields)
- elif load_type == 'drives':
- add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop,
- layout)
+ elif load_type == "drives":
+ add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop, layout)
else:
raise ValueError
# Resets file counter to 0
- change['owner'].set_trait('value', ([]))
-
-
-def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
- drive_widgets, connectivity_textfields,
- add_drive=True):
+ change["owner"].set_trait("value", ([]))
+
+
+def _init_network_from_widgets(
+ params,
+ dt,
+ tstop,
+ single_simulation_data,
+ drive_widgets,
+ connectivity_textfields,
+ add_drive=True,
+):
"""Construct network and add drives."""
print("init network")
- params['dt'] = dt.value
- params['tstop'] = tstop.value
- single_simulation_data['net'] = jones_2009_model(
+ params["dt"] = dt.value
+ params["tstop"] = tstop.value
+ single_simulation_data["net"] = jones_2009_model(
params,
add_drives_from_params=False,
)
@@ -1205,136 +1484,146 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
for connectivity_slider in connectivity_textfields:
for vbox in connectivity_slider:
conn_indices = pick_connection(
- net=single_simulation_data['net'],
- src_gids=vbox._belongsto['src_gids'],
- target_gids=vbox._belongsto['target_gids'],
- loc=vbox._belongsto['location'],
- receptor=vbox._belongsto['receptor'])
+ net=single_simulation_data["net"],
+ src_gids=vbox._belongsto["src_gids"],
+ target_gids=vbox._belongsto["target_gids"],
+ loc=vbox._belongsto["location"],
+ receptor=vbox._belongsto["receptor"],
+ )
if len(conn_indices) > 0:
assert len(conn_indices) == 1
conn_idx = conn_indices[0]
- single_simulation_data['net'].connectivity[conn_idx][
- 'nc_dict']['A_weight'] = vbox.children[1].value
- single_simulation_data['net'].connectivity[conn_idx][
- 'probability'] = vbox.children[2].value
+ single_simulation_data["net"].connectivity[conn_idx]["nc_dict"][
+ "A_weight"
+ ] = vbox.children[1].value
+ single_simulation_data["net"].connectivity[conn_idx]["probability"] = (
+ vbox.children[2].value
+ )
if add_drive is False:
return
# add drives to network
for drive in drive_widgets:
- weights_ampa = {
- k: v.value
- for k, v in drive['weights_ampa'].items()
- }
- weights_nmda = {
- k: v.value
- for k, v in drive['weights_nmda'].items()
- }
- synaptic_delays = {k: v.value for k, v in drive['delays'].items()}
- print(
- f"drive type is {drive['type']}, location={drive['location']}")
- if drive['type'] == 'Poisson':
+ weights_ampa = {k: v.value for k, v in drive["weights_ampa"].items()}
+ weights_nmda = {k: v.value for k, v in drive["weights_nmda"].items()}
+ synaptic_delays = {k: v.value for k, v in drive["delays"].items()}
+ print(f"drive type is {drive['type']}, location={drive['location']}")
+ if drive["type"] == "Poisson":
rate_constant = {
- k: v.value
- for k, v in drive['rate_constant'].items() if v.value > 0
- }
- weights_ampa = {
- k: v
- for k, v in weights_ampa.items() if k in rate_constant
- }
- weights_nmda = {
- k: v
- for k, v in weights_nmda.items() if k in rate_constant
+ k: v.value for k, v in drive["rate_constant"].items() if v.value > 0
}
- single_simulation_data['net'].add_poisson_drive(
- name=drive['name'],
- tstart=drive['tstart'].value,
- tstop=drive['tstop'].value,
+ weights_ampa = {k: v for k, v in weights_ampa.items() if k in rate_constant}
+ weights_nmda = {k: v for k, v in weights_nmda.items() if k in rate_constant}
+ single_simulation_data["net"].add_poisson_drive(
+ name=drive["name"],
+ tstart=drive["tstart"].value,
+ tstop=drive["tstop"].value,
rate_constant=rate_constant,
- location=drive['location'],
+ location=drive["location"],
weights_ampa=weights_ampa,
weights_nmda=weights_nmda,
synaptic_delays=synaptic_delays,
space_constant=100.0,
- event_seed=drive['seedcore'].value)
- elif drive['type'] in ('Evoked', 'Gaussian'):
- single_simulation_data['net'].add_evoked_drive(
- name=drive['name'],
- mu=drive['mu'].value,
- sigma=drive['sigma'].value,
- numspikes=drive['numspikes'].value,
- location=drive['location'],
+ event_seed=drive["seedcore"].value,
+ )
+ elif drive["type"] in ("Evoked", "Gaussian"):
+ single_simulation_data["net"].add_evoked_drive(
+ name=drive["name"],
+ mu=drive["mu"].value,
+ sigma=drive["sigma"].value,
+ numspikes=drive["numspikes"].value,
+ location=drive["location"],
weights_ampa=weights_ampa,
weights_nmda=weights_nmda,
synaptic_delays=synaptic_delays,
space_constant=3.0,
- event_seed=drive['seedcore'].value)
- elif drive['type'] in ('Rhythmic', 'Bursty'):
- single_simulation_data['net'].add_bursty_drive(
- name=drive['name'],
- tstart=drive['tstart'].value,
- tstart_std=drive['tstart_std'].value,
- burst_rate=drive['burst_rate'].value,
- burst_std=drive['burst_std'].value,
- location=drive['location'],
- tstop=drive['tstop'].value,
+ event_seed=drive["seedcore"].value,
+ )
+ elif drive["type"] in ("Rhythmic", "Bursty"):
+ single_simulation_data["net"].add_bursty_drive(
+ name=drive["name"],
+ tstart=drive["tstart"].value,
+ tstart_std=drive["tstart_std"].value,
+ burst_rate=drive["burst_rate"].value,
+ burst_std=drive["burst_std"].value,
+ location=drive["location"],
+ tstop=drive["tstop"].value,
weights_ampa=weights_ampa,
weights_nmda=weights_nmda,
synaptic_delays=synaptic_delays,
- event_seed=drive['seedcore'].value)
+ event_seed=drive["seedcore"].value,
+ )
-def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
- all_data, dt, tstop, ntrials, backend_selection,
- mpi_cmd, n_jobs, params, simulation_status_bar,
- simulation_status_contents, connectivity_textfields,
- viz_manager):
+def run_button_clicked(
+ widget_simulation_name,
+ log_out,
+ drive_widgets,
+ all_data,
+ dt,
+ tstop,
+ ntrials,
+ backend_selection,
+ mpi_cmd,
+ n_jobs,
+ params,
+ simulation_status_bar,
+ simulation_status_contents,
+ connectivity_textfields,
+ viz_manager,
+):
"""Run the simulation and plot outputs."""
log_out.clear_output()
simulation_data = all_data["simulation_data"]
with log_out:
# clear empty trash simulations
for _name in tuple(simulation_data.keys()):
- if len(simulation_data[_name]['dpls']) == 0:
+ if len(simulation_data[_name]["dpls"]) == 0:
del simulation_data[_name]
_sim_name = widget_simulation_name.value
- if simulation_data[_sim_name]['net'] is not None:
+ if simulation_data[_sim_name]["net"] is not None:
print("Simulation with the same name exists!")
- simulation_status_bar.value = simulation_status_contents[
- 'failed']
+ simulation_status_bar.value = simulation_status_contents["failed"]
return
- _init_network_from_widgets(params, dt, tstop,
- simulation_data[_sim_name], drive_widgets,
- connectivity_textfields)
+ _init_network_from_widgets(
+ params,
+ dt,
+ tstop,
+ simulation_data[_sim_name],
+ drive_widgets,
+ connectivity_textfields,
+ )
print("start simulation")
if backend_selection.value == "MPI":
backend = MPIBackend(
- n_procs=multiprocessing.cpu_count() - 1, mpi_cmd=mpi_cmd.value)
+ n_procs=multiprocessing.cpu_count() - 1, mpi_cmd=mpi_cmd.value
+ )
else:
backend = JoblibBackend(n_jobs=n_jobs.value)
print(f"Using Joblib with {n_jobs.value} core(s).")
with backend:
- simulation_status_bar.value = simulation_status_contents['running']
- simulation_data[_sim_name]['dpls'] = simulate_dipole(
- simulation_data[_sim_name]['net'],
+ simulation_status_bar.value = simulation_status_contents["running"]
+ simulation_data[_sim_name]["dpls"] = simulate_dipole(
+ simulation_data[_sim_name]["net"],
tstop=tstop.value,
dt=dt.value,
- n_trials=ntrials.value)
+ n_trials=ntrials.value,
+ )
- simulation_status_bar.value = simulation_status_contents[
- 'finished']
+ simulation_status_bar.value = simulation_status_contents["finished"]
viz_manager.reset_fig_config_tabs()
viz_manager.add_figure()
- fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1)
+ fig_name = _idx2figname(viz_manager.data["fig_idx"]["idx"] - 1)
ax_plots = [("ax0", "input histogram"), ("ax1", "current dipole")]
for ax_name, plot_type in ax_plots:
- viz_manager._simulate_edit_figure(fig_name, ax_name, _sim_name,
- plot_type, {}, "plot")
+ viz_manager._simulate_edit_figure(
+ fig_name, ax_name, _sim_name, plot_type, {}, "plot"
+ )
def handle_backend_change(backend_type, backend_config, mpi_cmd, n_jobs):
@@ -1353,5 +1642,6 @@ def launch():
You can pass voila commandline parameters as usual.
"""
from voila.app import main
- notebook_path = Path(__file__).parent / 'hnn_widget.ipynb'
+
+ notebook_path = Path(__file__).parent / "hnn_widget.ipynb"
main([str(notebook_path.resolve()), *sys.argv[1:]])