diff --git a/.github/workflows/unix_unit_tests.yml b/.github/workflows/unix_unit_tests.yml
index 1e6ff5330..e21068a84 100644
--- a/.github/workflows/unix_unit_tests.yml
+++ b/.github/workflows/unix_unit_tests.yml
@@ -54,10 +54,14 @@
       shell: bash -el {0}
       run: |
         pip install --verbose '.[opt, parallel, test, gui]'
-    - name: Lint with flake8
+    - name: Lint with ruff
       shell: bash -el {0}
       run: |
-        flake8 --count hnn_core
+        ruff check --no-fix
+    - name: Check formatting with ruff
+      shell: bash -el {0}
+      run: |
+        ruff format --check
     - name: Test with pytest
       shell: bash -el {0}
       run: |
@@ -65,4 +69,4 @@
     - name: Upload coverage to Codecov
       shell: bash -el {0}
       run: |
-        bash <(curl -s https://codecov.io/bash) -f ./coverage.xml
\ No newline at end of file
+        bash <(curl -s https://codecov.io/bash) -f ./coverage.xml
diff --git a/Makefile b/Makefile
index 9970666ed..19fb07e93 100644
--- a/Makefile
+++ b/Makefile
@@ -7,24 +7,22 @@
 all: modl
 
-modl:
-	cd hnn_core/mod/ && nrnivmodl
-
 clean :
 	rm -rf hnn_core/mod/x86_64/*
 
 check-manifest:
 	check-manifest
 
-test: flake
-	pytest .
+format:
+	ruff format
+
+lint:
+	ruff check
 
-flake:
-	@if command -v flake8 > /dev/null; then \
-		echo "Running flake8"; \
-		flake8 hnn_core --count; \
-	else \
-		echo "flake8 not found, please install it!"; \
-		exit 1; \
-	fi;
-	@echo "flake8 passed"
+modl:
+	cd hnn_core/mod/ && nrnivmodl
+
+test:
+	ruff format --check
+	ruff check --no-fix
+	pytest .
diff --git a/dev_scripts/convert_params.py b/dev_scripts/convert_params.py
index 9f52d0331..8827cbab0 100644
--- a/dev_scripts/convert_params.py
+++ b/dev_scripts/convert_params.py
@@ -31,7 +31,7 @@ def download_folder_contents(owner, repo, path):
     -------
     Path to temporary directory or None
     """
-    url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"
+    url = f'https://api.github.com/repos/{owner}/{repo}/contents/{path}'
 
     try:
         response = requests.get(url)
@@ -47,7 +47,7 @@ def download_folder_contents(owner, repo, path):
             file_name = os.path.join(temp_dir, item['name'])
             with open(file_name, 'wb') as f:
                 f.write(requests.get(download_url).content)
-            print(f"Downloaded: {file_name}")
+            print(f'Downloaded: {file_name}')
 
     return temp_dir
 
@@ -70,36 +70,38 @@ def convert_param_files_from_repo(owner, repo, repo_path, local_path):
     # Download param files
     temp_dir = download_folder_contents(owner, repo, repo_path)
     # Get list of json and param files
-    file_list = [Path(temp_dir, f)
-                 for f in os.listdir(temp_dir)
-                 if f.endswith('.param') or f.endswith('.json')]
+    file_list = [
+        Path(temp_dir, f)
+        for f in os.listdir(temp_dir)
+        if f.endswith('.param') or f.endswith('.json')
+    ]
 
     # Assign output location and names
     output_dir = Path(local_path)
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
-    output_filenames = [Path(output_dir, f.name.split('.')[0])
-                        for f in file_list]
+    output_filenames = [Path(output_dir, f.name.split('.')[0]) for f in file_list]
 
-    [convert_to_json(file, outfile)
-     for (file, outfile) in zip(file_list, output_filenames)]
+    [
+        convert_to_json(file, outfile)
+        for (file, outfile) in zip(file_list, output_filenames)
+    ]
 
     # Delete downloads
     shutil.rmtree(temp_dir)
 
 
 if __name__ == '__main__':
-    # hnn param files
-    convert_param_files_from_repo(owner='jonescompneurolab',
-                                  repo='hnn',
-                                  repo_path='param',
-                                  local_path=(root_path /
-                                              'network_configuration'),
-                                  )
+    convert_param_files_from_repo(
+        owner='jonescompneurolab',
+        repo='hnn',
+        repo_path='param',
+        local_path=(root_path / 'network_configuration'),
+    )
 
     # hnn-core json files
-    convert_param_files_from_repo(owner='jonescompneurolab',
-                                  repo='hnn-core',
-                                  repo_path='hnn_core/param',
-                                  local_path=(root_path /
-                                              'network_configuration'),
-                                  )
+    convert_param_files_from_repo(
+        owner='jonescompneurolab',
+        repo='hnn-core',
+        repo_path='hnn_core/param',
+        local_path=(root_path / 'network_configuration'),
+    )
diff --git a/doc/conf.py b/doc/conf.py
index ec461fffa..ad948dd52 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -14,8 +14,8 @@
 # import os
 import sys
+
 # sys.path.insert(0, os.path.abspath('.'))
-import sphinx_gallery
 from sphinx_gallery.sorting import ExampleTitleSortKey, ExplicitOrder
 import sphinx_bootstrap_theme
 
@@ -57,7 +57,7 @@
     'sphinx.ext.intersphinx',
     'numpydoc',
     'sphinx_copybutton',
-    'gh_substitutions'  # custom extension, see ./sphinxext/gh_substitutions.py
+    'gh_substitutions',  # custom extension, see ./sphinxext/gh_substitutions.py
 ]
 
 # generate autosummary even if no references
@@ -68,7 +68,7 @@
 default_role = 'autolink'  # XXX silently allows bad syntax, someone should fix
 
 # Sphinx-Copybutton configuration
-copybutton_prompt_text = r">>> |\.\.\. |\$ "
+copybutton_prompt_text = r'>>> |\.\.\. |\$ '
 copybutton_prompt_is_regexp = True
 
 # Add any paths that contain templates here, relative to this directory.
@@ -113,13 +113,13 @@
 html_theme_options = {
     'navbar_sidebarrel': False,
     'navbar_links': [
-        ("Examples", "auto_examples/index"),
-        ("API", "api"),
-        ("Glossary", "glossary"),
-        ("What's new", "whats_new"),
-        ("GitHub", "https://github.com/jonescompneurolab/hnn-core", True)
+        ('Examples', 'auto_examples/index'),
+        ('API', 'api'),
+        ('Glossary', 'glossary'),
+        ("What's new", 'whats_new'),
+        ('GitHub', 'https://github.com/jonescompneurolab/hnn-core', True),
     ],
-    'bootswatch_theme': "yeti"
+    'bootswatch_theme': 'yeti',
 }
 
 # Add any paths that contain custom static files (such as style sheets) here,
@@ -150,15 +150,12 @@
     # The paper size ('letterpaper' or 'a4paper').
     #
     # 'papersize': 'letterpaper',
-
     # The font size ('10pt', '11pt' or '12pt').
    #
     # 'pointsize': '10pt',
-
     # Additional stuff for the LaTeX preamble.
     #
     # 'preamble': '',
-
     # Latex figure (float) alignment
     #
     # 'figure_align': 'htbp',
@@ -168,8 +165,7 @@
 # (source start file, target name, title,
 #  author, documentclass [howto, manual, or own class]).
 latex_documents = [
-    (master_doc, 'hnn-core.tex', 'hnn-core Documentation',
-     'Mainak Jas', 'manual'),
+    (master_doc, 'hnn-core.tex', 'hnn-core Documentation', 'Mainak Jas', 'manual'),
 ]
 
 
@@ -177,10 +173,7 @@
 # One entry per manual page. List of tuples
 # (source start file, name, description, authors, manual section).
-man_pages = [
-    (master_doc, 'hnn-core', 'hnn-core Documentation',
-     [author], 1)
-]
+man_pages = [(master_doc, 'hnn-core', 'hnn-core Documentation', [author], 1)]
 
 
 # -- Options for Texinfo output ----------------------------------------------
 
@@ -189,9 +182,15 @@
 # (source start file, target name, title, author,
 #  dir menu entry, description, category)
 texinfo_documents = [
-    (master_doc, 'hnn-core', 'hnn-core Documentation',
-     author, 'hnn-core', 'One line description of project.',
-     'Miscellaneous'),
+    (
+        master_doc,
+        'hnn-core',
+        'hnn-core Documentation',
+        author,
+        'hnn-core',
+        'One line description of project.',
+        'Miscellaneous',
+    ),
 ]
 
 intersphinx_mapping = {
@@ -199,7 +198,7 @@
     'mne': ('https://mne.tools/dev', None),
     'numpy': ('https://numpy.org/devdocs', None),
     'scipy': ('https://scipy.github.io/devdocs', None),
-    'matplotlib': ('https://matplotlib.org', None)
+    'matplotlib': ('https://matplotlib.org', None),
 }
 intersphinx_timeout = 5
 
@@ -209,7 +208,7 @@
     'https://doi.org/10.1152/jn.00535.2009',
     'https://doi.org/10.1152/jn.00122.2010',
     'https://doi.org/10.1101/2021.04.16.440210',
-    'https://groups.google.com/g/hnnsolver'
+    'https://groups.google.com/g/hnnsolver',
 ]
 
 # Resolve binder filepath_prefix. From the docs:
@@ -224,29 +223,28 @@
 filepath_prefix = 'v{}'.format(version)
 
 sphinx_gallery_conf = {
-    'first_notebook_cell': ("import pyvista as pv\n"
-                            "from mne.viz import set_3d_backend\n"
-                            "set_3d_backend('notebook')\n"
-                            "pv.set_jupyter_backend('client')"
-                            ),
+    'first_notebook_cell': (
+        'import pyvista as pv\n'
+        'from mne.viz import set_3d_backend\n'
+        "set_3d_backend('notebook')\n"
+        "pv.set_jupyter_backend('client')"
+    ),
     'doc_module': 'hnn_core',
     # path to your examples scripts
     'examples_dirs': '../examples',
     # path where to save gallery generated examples
     'gallery_dirs': 'auto_examples',
     'backreferences_dir': 'generated',
-    'reference_url': {
-        'hnn_core': None
-    },
+    'reference_url': {'hnn_core': None},
     'within_subsection_order': ExampleTitleSortKey,
-    'subsection_order': ExplicitOrder(['../examples/workflows/',
-                                       '../examples/howto/']),
-    'binder': {'org': 'jonescompneurolab',
-               'repo': 'hnn-core',
-               'branch': 'gh-pages',
-               'binderhub_url': 'https://mybinder.org',
-               'filepath_prefix': filepath_prefix,
-               'notebooks_dir': 'notebooks',
-               'dependencies': 'Dockerfile'
-               }
+    'subsection_order': ExplicitOrder(['../examples/workflows/', '../examples/howto/']),
+    'binder': {
+        'org': 'jonescompneurolab',
+        'repo': 'hnn-core',
+        'branch': 'gh-pages',
+        'binderhub_url': 'https://mybinder.org',
+        'filepath_prefix': filepath_prefix,
+        'notebooks_dir': 'notebooks',
+        'dependencies': 'Dockerfile',
+    },
 }
diff --git a/doc/sphinxext/gh_substitutions.py b/doc/sphinxext/gh_substitutions.py
index b22ec3432..4f014a02c 100644
--- a/doc/sphinxext/gh_substitutions.py
+++ b/doc/sphinxext/gh_substitutions.py
@@ -6,6 +6,7 @@
 https://doughellmann.com/blog/2010/05/09/defining-custom-roles-in-sphinx/
 """
+
 from docutils.nodes import reference
 from docutils.parsers.rst.roles import set_classes
 
diff --git a/examples/howto/optimize_evoked.py b/examples/howto/optimize_evoked.py
index 0f799517d..4b74060f7 100644
--- a/examples/howto/optimize_evoked.py
+++ b/examples/howto/optimize_evoked.py
@@ -20,8 +20,7 @@
 # Let us import hnn_core
 
 import hnn_core
-from hnn_core import (MPIBackend, jones_2009_model, simulate_dipole,
-                      read_dipole)
+from hnn_core import MPIBackend, jones_2009_model, simulate_dipole, read_dipole
 
 hnn_core_root = op.join(op.dirname(hnn_core.__file__))
 
@@ -37,8 +36,10 @@
 from urllib.request import urlretrieve
 
-data_url = ('https://raw.githubusercontent.com/jonescompneurolab/hnn/master/'
-            'data/MEG_detection_data/S1_SupraT.txt')
+data_url = (
+    'https://raw.githubusercontent.com/jonescompneurolab/hnn/master/'
+    'data/MEG_detection_data/S1_SupraT.txt'
+)
 urlretrieve(data_url, 'S1_SupraT.txt')
 exp_dpl = read_dipole('S1_SupraT.txt')
 
@@ -52,50 +53,73 @@
 net_init = jones_2009_model()
 
 # Proximal 1
-weights_ampa_p1 = {'L2_basket': 0.2913, 'L2_pyramidal': 0.9337,
-                   'L5_basket': 0.1951, 'L5_pyramidal': 0.3602}
-weights_nmda_p1 = {'L2_basket': 0.9240, 'L2_pyramidal': 0.0845,
-                   'L5_basket': 0.5849, 'L5_pyramidal': 0.65105}
-synaptic_delays_p = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                     'L5_basket': 1., 'L5_pyramidal': 1.}
-net_init.add_evoked_drive('evprox1',
-                          mu=5.6813,
-                          sigma=20.3969,
-                          numspikes=1,
-                          location='proximal',
-                          weights_ampa=weights_ampa_p1,
-                          weights_nmda=weights_nmda_p1,
-                          synaptic_delays=synaptic_delays_p)
+weights_ampa_p1 = {
+    'L2_basket': 0.2913,
+    'L2_pyramidal': 0.9337,
+    'L5_basket': 0.1951,
+    'L5_pyramidal': 0.3602,
+}
+weights_nmda_p1 = {
+    'L2_basket': 0.9240,
+    'L2_pyramidal': 0.0845,
+    'L5_basket': 0.5849,
+    'L5_pyramidal': 0.65105,
+}
+synaptic_delays_p = {
+    'L2_basket': 0.1,
+    'L2_pyramidal': 0.1,
+    'L5_basket': 1.0,
+    'L5_pyramidal': 1.0,
+}
+net_init.add_evoked_drive(
+    'evprox1',
+    mu=5.6813,
+    sigma=20.3969,
+    numspikes=1,
+    location='proximal',
+    weights_ampa=weights_ampa_p1,
+    weights_nmda=weights_nmda_p1,
+    synaptic_delays=synaptic_delays_p,
+)
 
 # Distal
-weights_ampa_d1 = {'L2_basket': 0.8037, 'L2_pyramidal': 0.5738,
-                   'L5_pyramidal': 0.3626}
-weights_nmda_d1 = {'L2_basket': 0.2492, 'L2_pyramidal': 0.6183,
-                   'L5_pyramidal': 0.1121}
-synaptic_delays_d1 = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                      'L5_pyramidal': 0.1}
-net_init.add_evoked_drive('evdist1',
-                          mu=58.6539,
-                          sigma=5.5810,
-                          numspikes=1,
-                          location='distal',
-                          weights_ampa=weights_ampa_d1,
-                          weights_nmda=weights_nmda_d1,
-                          synaptic_delays=synaptic_delays_d1)
+weights_ampa_d1 = {'L2_basket': 0.8037, 'L2_pyramidal': 0.5738, 'L5_pyramidal': 0.3626}
+weights_nmda_d1 = {'L2_basket': 0.2492, 'L2_pyramidal': 0.6183, 'L5_pyramidal': 0.1121}
+synaptic_delays_d1 = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, 'L5_pyramidal': 0.1}
+net_init.add_evoked_drive(
+    'evdist1',
+    mu=58.6539,
+    sigma=5.5810,
+    numspikes=1,
+    location='distal',
+    weights_ampa=weights_ampa_d1,
+    weights_nmda=weights_nmda_d1,
+    synaptic_delays=synaptic_delays_d1,
+)
 
 # Proximal 2
-weights_ampa_p2 = {'L2_basket': 0.01, 'L2_pyramidal': 0.01, 'L5_basket': 0.01,
-                   'L5_pyramidal': 0.01}
-weights_nmda_p2 = {'L2_basket': 0.01, 'L2_pyramidal': 0.01, 'L5_basket': 0.01,
-                   'L5_pyramidal': 0.01}
-net_init.add_evoked_drive('evprox2',
-                          mu=80,
-                          sigma=1,
-                          numspikes=1,
-                          location='proximal',
-                          weights_ampa=weights_ampa_p2,
-                          weights_nmda=weights_nmda_p2,
-                          synaptic_delays=synaptic_delays_p)
+weights_ampa_p2 = {
+    'L2_basket': 0.01,
+    'L2_pyramidal': 0.01,
+    'L5_basket': 0.01,
+    'L5_pyramidal': 0.01,
+}
+weights_nmda_p2 = {
+    'L2_basket': 0.01,
+    'L2_pyramidal': 0.01,
+    'L5_basket': 0.01,
+    'L5_pyramidal': 0.01,
+}
+net_init.add_evoked_drive(
+    'evprox2',
+    mu=80,
+    sigma=1,
+    numspikes=1,
+    location='proximal',
+    weights_ampa=weights_ampa_p2,
+    weights_nmda=weights_nmda_p2,
+    synaptic_delays=synaptic_delays_p,
+)
 
 with MPIBackend(n_procs=n_procs, mpi_cmd='mpiexec'):
     init_dpl = simulate_dipole(net_init, tstop=tstop, n_trials=1)[0]
@@ -112,52 +136,54 @@ def set_params(net, params):
-    # Proximal 1
-    net.add_evoked_drive('evprox1',
-                         mu=5.6813,
-                         sigma=20.3969,
-                         numspikes=1,
-                         location='proximal',
-                         weights_ampa=weights_ampa_p1,
-                         weights_nmda=weights_nmda_p1,
-                         synaptic_delays=synaptic_delays_p)
+    net.add_evoked_drive(
+        'evprox1',
+        mu=5.6813,
+        sigma=20.3969,
+        numspikes=1,
+        location='proximal',
+        weights_ampa=weights_ampa_p1,
+        weights_nmda=weights_nmda_p1,
+        synaptic_delays=synaptic_delays_p,
+    )
 
     # Distal
-    net.add_evoked_drive('evdist1',
-                         mu=58.6539,
-                         sigma=5.5810,
-                         numspikes=1,
-                         location='distal',
-                         weights_ampa=weights_ampa_d1,
-                         weights_nmda=weights_nmda_d1,
-                         synaptic_delays=synaptic_delays_d1)
+    net.add_evoked_drive(
+        'evdist1',
+        mu=58.6539,
+        sigma=5.5810,
+        numspikes=1,
+        location='distal',
+        weights_ampa=weights_ampa_d1,
+        weights_nmda=weights_nmda_d1,
+        synaptic_delays=synaptic_delays_d1,
+    )
 
     # Proximal 2
-    weights_ampa_p2 = {'L2_basket':
-                       params['evprox2_ampa_L2_basket'],
-                       'L2_pyramidal':
-                       params['evprox2_ampa_L2_pyramidal'],
-                       'L5_basket':
-                       params['evprox2_ampa_L5_basket'],
-                       'L5_pyramidal':
-                       params['evprox2_ampa_L5_pyramidal']}
-    weights_nmda_p2 = {'L2_basket':
-                       params['evprox2_nmda_L2_basket'],
-                       'L2_pyramidal':
-                       params['evprox2_nmda_L2_pyramidal'],
-                       'L5_basket':
-                       params['evprox2_nmda_L5_basket'],
-                       'L5_pyramidal':
-                       params['evprox2_nmda_L5_pyramidal']}
-    net.add_evoked_drive('evprox2',
-                         mu=params['evprox2_mu'],
-                         sigma=params['evprox2_sigma'],
-                         numspikes=1,
-                         location='proximal',
-                         weights_ampa=weights_ampa_p2,
-                         weights_nmda=weights_nmda_p2,
-                         synaptic_delays=synaptic_delays_p)
+    weights_ampa_p2 = {
+        'L2_basket': params['evprox2_ampa_L2_basket'],
+        'L2_pyramidal': params['evprox2_ampa_L2_pyramidal'],
+        'L5_basket': params['evprox2_ampa_L5_basket'],
+        'L5_pyramidal': params['evprox2_ampa_L5_pyramidal'],
+    }
+    weights_nmda_p2 = {
+        'L2_basket': params['evprox2_nmda_L2_basket'],
+        'L2_pyramidal': params['evprox2_nmda_L2_pyramidal'],
+        'L5_basket': params['evprox2_nmda_L5_basket'],
+        'L5_pyramidal': params['evprox2_nmda_L5_pyramidal'],
+    }
+    net.add_evoked_drive(
+        'evprox2',
+        mu=params['evprox2_mu'],
+        sigma=params['evprox2_sigma'],
+        numspikes=1,
+        location='proximal',
+        weights_ampa=weights_ampa_p2,
+        weights_nmda=weights_nmda_p2,
+        synaptic_delays=synaptic_delays_p,
+    )
+
 
 ###############################################################################
 # Then, we define the constraints.
@@ -170,16 +196,20 @@ def set_params(net, params):
 # were chosen so as to keep the model in physiologically realistic regimes.
 
-constraints = dict({'evprox2_ampa_L2_basket': (0.01, 1.),
-                    'evprox2_ampa_L2_pyramidal': (0.01, 1.),
-                    'evprox2_ampa_L5_basket': (0.01, 1.),
-                    'evprox2_ampa_L5_pyramidal': (0.01, 1.),
-                    'evprox2_nmda_L2_basket': (0.01, 1.),
-                    'evprox2_nmda_L2_pyramidal': (0.01, 1.),
-                    'evprox2_nmda_L5_basket': (0.01, 1.),
-                    'evprox2_nmda_L5_pyramidal': (0.01, 1.),
-                    'evprox2_mu': (100., 120.),
-                    'evprox2_sigma': (2., 30.)})
+constraints = dict(
+    {
+        'evprox2_ampa_L2_basket': (0.01, 1.0),
+        'evprox2_ampa_L2_pyramidal': (0.01, 1.0),
+        'evprox2_ampa_L5_basket': (0.01, 1.0),
+        'evprox2_ampa_L5_pyramidal': (0.01, 1.0),
+        'evprox2_nmda_L2_basket': (0.01, 1.0),
+        'evprox2_nmda_L2_pyramidal': (0.01, 1.0),
+        'evprox2_nmda_L5_basket': (0.01, 1.0),
+        'evprox2_nmda_L5_pyramidal': (0.01, 1.0),
+        'evprox2_mu': (100.0, 120.0),
+        'evprox2_sigma': (2.0, 30.0),
+    }
+)
 
 ###############################################################################
 # Now we define and fit the optimizer.
@@ -187,11 +217,11 @@ def set_params(net, params):
 from hnn_core.optimization import Optimizer
 
 net = jones_2009_model()
-optim = Optimizer(net, tstop=tstop, constraints=constraints,
-                  set_params=set_params)
+optim = Optimizer(net, tstop=tstop, constraints=constraints, set_params=set_params)
 with MPIBackend(n_procs=n_procs, mpi_cmd='mpiexec'):
-    optim.fit(target=exp_dpl, scale_factor=scale_factor,
-              smooth_window_len=smooth_window_len)
+    optim.fit(
+        target=exp_dpl, scale_factor=scale_factor, smooth_window_len=smooth_window_len
+    )
 
 ###############################################################################
 # Finally, we can plot the experimental data alongside the post-optimization
diff --git a/examples/howto/optimize_rhythmic.py b/examples/howto/optimize_rhythmic.py
index 8e93185f3..2727747cd 100644
--- a/examples/howto/optimize_rhythmic.py
+++ b/examples/howto/optimize_rhythmic.py
@@ -15,7 +15,7 @@
 ###############################################################################
 # Let us import hnn_core
 
-from hnn_core import (MPIBackend, jones_2009_model, simulate_dipole)
+from hnn_core import MPIBackend, jones_2009_model, simulate_dipole
 
 # The number of cores may need modifying depending on your current machine.
 n_procs = 10
@@ -26,39 +26,48 @@
 # object with no attached drives, and a dictionary of the parameters we wish to
 # optimize.
 
-def set_params(net, params): 
+def set_params(net, params):
     # Proximal (alpha)
-    weights_ampa_p = {'L2_pyramidal': params['alpha_prox_weight'],
-                      'L5_pyramidal': 4.4e-5}
-    syn_delays_p = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}
-
-    net.add_bursty_drive('alpha_prox',
-                         tstart=params['alpha_prox_tstart'],
-                         burst_rate=params['alpha_prox_burst_rate'],
-                         burst_std=params['alpha_prox_burst_std'],
-                         numspikes=2,
-                         spike_isi=10,
-                         n_drive_cells=10,
-                         location='proximal',
-                         weights_ampa=weights_ampa_p,
-                         synaptic_delays=syn_delays_p)
+    weights_ampa_p = {
+        'L2_pyramidal': params['alpha_prox_weight'],
+        'L5_pyramidal': 4.4e-5,
+    }
+    syn_delays_p = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.0}
+
+    net.add_bursty_drive(
+        'alpha_prox',
+        tstart=params['alpha_prox_tstart'],
+        burst_rate=params['alpha_prox_burst_rate'],
+        burst_std=params['alpha_prox_burst_std'],
+        numspikes=2,
+        spike_isi=10,
+        n_drive_cells=10,
+        location='proximal',
+        weights_ampa=weights_ampa_p,
+        synaptic_delays=syn_delays_p,
+    )
 
     # Distal (beta)
-    weights_ampa_d = {'L2_pyramidal': params['alpha_dist_weight'],
-                      'L5_pyramidal': 4.4e-5}
-    syn_delays_d = {'L2_pyramidal': 5., 'L5_pyramidal': 5.}
-
-    net.add_bursty_drive('alpha_dist',
-                         tstart=params['alpha_dist_tstart'],
-                         burst_rate=params['alpha_dist_burst_rate'],
-                         burst_std=params['alpha_dist_burst_std'],
-                         numspikes=2,
-                         spike_isi=10,
-                         n_drive_cells=10,
-                         location='distal',
-                         weights_ampa=weights_ampa_d,
-                         synaptic_delays=syn_delays_d)
+    weights_ampa_d = {
+        'L2_pyramidal': params['alpha_dist_weight'],
+        'L5_pyramidal': 4.4e-5,
+    }
+    syn_delays_d = {'L2_pyramidal': 5.0, 'L5_pyramidal': 5.0}
+
+    net.add_bursty_drive(
+        'alpha_dist',
+        tstart=params['alpha_dist_tstart'],
+        burst_rate=params['alpha_dist_burst_rate'],
+        burst_std=params['alpha_dist_burst_std'],
+        numspikes=2,
+        spike_isi=10,
+        n_drive_cells=10,
+        location='distal',
+        weights_ampa=weights_ampa_d,
+        synaptic_delays=syn_delays_d,
+    )
+
 
 ###############################################################################
 # Then, we define the constraints.
@@ -71,14 +80,18 @@ def set_params(net, params):
 # were chosen so as to keep the model in physiologically realistic regimes.
 
 constraints = dict()
-constraints.update({'alpha_prox_weight': (4.4e-5, 6.4e-5),
-                    'alpha_prox_tstart': (45, 55),
-                    'alpha_prox_burst_rate': (1, 30),
-                    'alpha_prox_burst_std': (10, 30),
-                    'alpha_dist_weight': (4.4e-5, 6.4e-5),
-                    'alpha_dist_tstart': (45, 55),
-                    'alpha_dist_burst_rate': (1, 30),
-                    'alpha_dist_burst_std': (10, 30)})
+constraints.update(
+    {
+        'alpha_prox_weight': (4.4e-5, 6.4e-5),
+        'alpha_prox_tstart': (45, 55),
+        'alpha_prox_burst_rate': (1, 30),
+        'alpha_prox_burst_std': (10, 30),
+        'alpha_dist_weight': (4.4e-5, 6.4e-5),
+        'alpha_dist_tstart': (45, 55),
+        'alpha_dist_burst_rate': (1, 30),
+        'alpha_dist_burst_std': (10, 30),
+    }
+)
 
 ###############################################################################
 # Now we define and fit the optimizer.
@@ -89,14 +102,23 @@ def set_params(net, params):
 smooth_window_len = 20
 
 net = jones_2009_model()
-optim = Optimizer(net, tstop=tstop, constraints=constraints,
-                  set_params=set_params, obj_fun='maximize_psd')
+optim = Optimizer(
+    net,
+    tstop=tstop,
+    constraints=constraints,
+    set_params=set_params,
+    obj_fun='maximize_psd',
+)
 
 # 8-15 Hz (alpha) and 15-30 Hz (beta) are the frequency bands whose
 # power we wish to maximize in a ratio of 1 to 2.
 with MPIBackend(n_procs=n_procs, mpi_cmd='mpiexec'):
-    optim.fit(f_bands=[(9, 11), (19, 21)], relative_bandpower=(1, 2),
-              scale_factor=scale_factor, smooth_window_len=smooth_window_len)
+    optim.fit(
+        f_bands=[(9, 11), (19, 21)],
+        relative_bandpower=(1, 2),
+        scale_factor=scale_factor,
+        smooth_window_len=smooth_window_len,
+    )
 
 ###############################################################################
 # Finally, we can plot the optimized dipole, power spectral density (PSD), and
diff --git a/examples/howto/plot_batch_simulate.py b/examples/howto/plot_batch_simulate.py
index f8018e731..5e2660907 100644
--- a/examples/howto/plot_batch_simulate.py
+++ b/examples/howto/plot_batch_simulate.py
@@ -39,22 +39,31 @@ def set_params(param_values, net=None):
     net : instance of Network, optional
         If None, a new network is created using the specified model type.
     """
-    weights_ampa = {'L2_basket': param_values['weight_basket'],
-                    'L2_pyramidal': param_values['weight_pyr'],
-                    'L5_basket': param_values['weight_basket'],
-                    'L5_pyramidal': param_values['weight_pyr']}
-
-    synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                       'L5_basket': 1., 'L5_pyramidal': 1.}
+    weights_ampa = {
+        'L2_basket': param_values['weight_basket'],
+        'L2_pyramidal': param_values['weight_pyr'],
+        'L5_basket': param_values['weight_basket'],
+        'L5_pyramidal': param_values['weight_pyr'],
+    }
+
+    synaptic_delays = {
+        'L2_basket': 0.1,
+        'L2_pyramidal': 0.1,
+        'L5_basket': 1.0,
+        'L5_pyramidal': 1.0,
+    }
 
     # Add an evoked drive to the network.
-    net.add_evoked_drive('evprox',
-                         mu=40,
-                         sigma=5,
-                         numspikes=1,
-                         location='proximal',
-                         weights_ampa=weights_ampa,
-                         synaptic_delays=synaptic_delays)
+    net.add_evoked_drive(
+        'evprox',
+        mu=40,
+        sigma=5,
+        numspikes=1,
+        location='proximal',
+        weights_ampa=weights_ampa,
+        synaptic_delays=synaptic_delays,
+    )
+
 
 ###############################################################################
 # Define a parameter grid for the batch simulation.
@@ -62,7 +71,7 @@ def set_params(param_values, net=None):
 
 param_grid = {
     'weight_basket': np.logspace(-4, -1, 10),
-    'weight_pyr': np.logspace(-4, -1, 10)
+    'weight_pyr': np.logspace(-4, -1, 10),
 }
 
 ###############################################################################
@@ -92,6 +101,7 @@ def summary_func(results):
         summary_stats.append({'min_peak': min_peak, 'max_peak': max_peak})
     return summary_stats
 
+
 ###############################################################################
 # Run the batch simulation and collect the results.
@@ -102,15 +112,14 @@ def summary_func(results):
 
 # Run the batch simulation and collect the results.
 
 net = jones_2009_model(mesh_shape=(3, 3))
-batch_simulation = BatchSimulate(net=net,
-                                 set_params=set_params,
-                                 summary_func=summary_func)
-simulation_results = batch_simulation.run(param_grid,
-                                          n_jobs=n_jobs,
-                                          combinations=False,
-                                          backend='multiprocessing')
+batch_simulation = BatchSimulate(
+    net=net, set_params=set_params, summary_func=summary_func
+)
+simulation_results = batch_simulation.run(
+    param_grid, n_jobs=n_jobs, combinations=False, backend='multiprocessing'
+)
 # backend='dask' if installed
-print("Simulation results:", simulation_results)
+print('Simulation results:', simulation_results)
 ###############################################################################
 # This plot shows an overlay of all smoothed dipole waveforms from the
 # batch simulation. Each line represents a different set of parameters,
@@ -137,8 +146,9 @@
 # dipole activity changes as we vary the synaptic strength parameter.
 
 min_peaks, max_peaks, param_values = [], [], []
-for summary_list, data_list in zip(simulation_results['summary_statistics'],
-                                   simulation_results['simulated_data']):
+for summary_list, data_list in zip(
+    simulation_results['summary_statistics'], simulation_results['simulated_data']
+):
     for summary, data in zip(summary_list, data_list):
         min_peaks.append(summary['min_peak'])
         max_peaks.append(summary['max_peak'])
diff --git a/examples/howto/plot_connectivity.py b/examples/howto/plot_connectivity.py
index b2fb97aac..b944cfa71 100644
--- a/examples/howto/plot_connectivity.py
+++ b/examples/howto/plot_connectivity.py
@@ -10,12 +10,10 @@
 
 # sphinx_gallery_thumbnail_number = 2
 
-import os.path as op
 
 ###############################################################################
 # Let us import ``hnn_core``.
 
-import hnn_core
 from hnn_core import jones_2009_model, simulate_dipole
 
 ###############################################################################
@@ -42,8 +40,12 @@
 print(len(net_erp.connectivity))
 
 conn_indices = pick_connection(
-    net=net_erp, src_gids='L5_basket', target_gids='L5_pyramidal',
-    loc='soma', receptor='gabaa')
+    net=net_erp,
+    src_gids='L5_basket',
+    target_gids='L5_pyramidal',
+    loc='soma',
+    receptor='gabaa',
+)
 conn_idx = conn_indices[0]
 print(net_erp.connectivity[conn_idx])
 plot_connectivity_matrix(net_erp, conn_idx)
@@ -57,9 +59,10 @@
 # Data recorded during simulations are stored under
 # :class:`~hnn_core.Cell_Response`. Spiking activity can be visualized after
 # a simulation is using :meth:`~hnn_core.Cell_Response.plot_spikes_raster`
-dpl_erp = simulate_dipole(net_erp, tstop=170., n_trials=1)
+dpl_erp = simulate_dipole(net_erp, tstop=170.0, n_trials=1)
 net_erp.cell_response.plot_spikes_raster()
 
+
 ###############################################################################
 # We can also define our own connections to test the effect of different
 # connectivity patterns. To start, ``net.clear_connectivity()`` can be used
@@ -80,30 +83,46 @@ def get_network(probability=1.0):
     src = 'L5_pyramidal'
     conn_seed = 3
     for target in ['L5_pyramidal', 'L2_basket']:
-        net.add_connection(src, target, location, receptor,
-                           delay, weight, lamtha, probability=probability,
-                           conn_seed=conn_seed)
+        net.add_connection(
+            src,
+            target,
+            location,
+            receptor,
+            delay,
+            weight,
+            lamtha,
+            probability=probability,
+            conn_seed=conn_seed,
+        )
 
     # Basket cell connections
     location, receptor = 'soma', 'gabaa'
     weight, delay, lamtha = 1.0, 1.0, 70
     src = 'L2_basket'
     for target in ['L5_pyramidal', 'L2_basket']:
-        net.add_connection(src, target, location, receptor,
-                           delay, weight, lamtha, probability=probability,
-                           conn_seed=conn_seed)
+        net.add_connection(
+            src,
+            target,
+            location,
+            receptor,
+            delay,
+            weight,
+            lamtha,
+            probability=probability,
+            conn_seed=conn_seed,
+        )
     return net
 
 
 net_all = get_network()
-dpl_all = simulate_dipole(net_all, tstop=170., n_trials=1)
+dpl_all = simulate_dipole(net_all, tstop=170.0, n_trials=1)
 
 ###############################################################################
 # We can additionally use the ``probability`` argument to create a sparse
 # connectivity pattern instead of all-to-all. Let's try creating the same
 # network with a 10% chance of cells connecting to each other.
 net_sparse = get_network(probability=0.1)
-dpl_sparse = simulate_dipole(net_sparse, tstop=170., n_trials=1)
+dpl_sparse = simulate_dipole(net_sparse, tstop=170.0, n_trials=1)
 
 ###############################################################################
 # With the previous connection pattern there appears to be synchronous rhythmic
@@ -117,8 +136,12 @@ def get_network(probability=1.0):
 ###############################################################################
 # We can plot the sparse connectivity pattern between cell populations.
 conn_indices = pick_connection(
-    net=net_sparse, src_gids='L2_basket', target_gids='L2_basket',
-    loc='soma', receptor='gabaa')
+    net=net_sparse,
+    src_gids='L2_basket',
+    target_gids='L2_basket',
+    loc='soma',
+    receptor='gabaa',
+)
 
 conn_idx = conn_indices[0]
 plot_connectivity_matrix(net_sparse, conn_idx)
@@ -136,16 +159,17 @@ def get_network(probability=1.0):
 # the aggregate current dipole.
 import matplotlib.pyplot as plt
 from hnn_core.viz import plot_dipole
-fig, axes = plt.subplots(2, 1, sharex=True, figsize=(6, 6),
-                         constrained_layout=True)
+
+fig, axes = plt.subplots(2, 1, sharex=True, figsize=(6, 6), constrained_layout=True)
 
 window_len = 30  # ms
 scaling_factor = 3000
-dpls = [dpl_erp[0].smooth(window_len).scale(scaling_factor),
-        dpl_all[0].smooth(window_len).scale(scaling_factor),
-        dpl_sparse[0].smooth(window_len).scale(scaling_factor)]
+dpls = [
+    dpl_erp[0].smooth(window_len).scale(scaling_factor),
+    dpl_all[0].smooth(window_len).scale(scaling_factor),
+    dpl_sparse[0].smooth(window_len).scale(scaling_factor),
+]
 
 plot_dipole(dpls, ax=axes[0], layer='agg', show=False)
 axes[0].legend(['Default', 'Custom All', 'Custom Sparse'])
-net_erp.cell_response.plot_spikes_hist(
-    ax=axes[1], spike_types=['evprox', 'evdist'])
+net_erp.cell_response.plot_spikes_hist(ax=axes[1], spike_types=['evprox', 'evdist'])
diff --git a/examples/howto/plot_firing_pattern.py b/examples/howto/plot_firing_pattern.py
index d3f0eaec5..d32eac9b5 100644
--- a/examples/howto/plot_firing_pattern.py
+++ b/examples/howto/plot_firing_pattern.py
@@ -16,7 +16,6 @@
 ###############################################################################
 # Let us import ``hnn_core``.
 
-import hnn_core
 from hnn_core import read_spikes, jones_2009_model, simulate_dipole
 
 ###############################################################################
@@ -33,46 +32,83 @@
 # "evoked drive" defines inputs that are normally distributed with a certain
 # mean and standard deviation.
 
-weights_ampa_d1 = {'L2_basket': 0.006562, 'L2_pyramidal': 7e-6,
-                   'L5_pyramidal': 0.142300}
-weights_nmda_d1 = {'L2_basket': 0.019482, 'L2_pyramidal': 0.004317,
-                   'L5_pyramidal': 0.080074}
-synaptic_delays_d1 = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                      'L5_pyramidal': 0.1}
+weights_ampa_d1 = {
+    'L2_basket': 0.006562,
+    'L2_pyramidal': 7e-6,
+    'L5_pyramidal': 0.142300,
+}
+weights_nmda_d1 = {
+    'L2_basket': 0.019482,
+    'L2_pyramidal': 0.004317,
+    'L5_pyramidal': 0.080074,
+}
+synaptic_delays_d1 = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, 'L5_pyramidal': 0.1}
 
 net.add_evoked_drive(
-    'evdist1', mu=63.53, sigma=3.85, numspikes=1, weights_ampa=weights_ampa_d1,
-    weights_nmda=weights_nmda_d1, location='distal',
-    synaptic_delays=synaptic_delays_d1, event_seed=274)
+    'evdist1',
+    mu=63.53,
+    sigma=3.85,
+    numspikes=1,
+    weights_ampa=weights_ampa_d1,
+    weights_nmda=weights_nmda_d1,
+    location='distal',
+    synaptic_delays=synaptic_delays_d1,
+    event_seed=274,
+)
 
 ###############################################################################
 # The reason it is called an "evoked drive" is it can be used to simulate
 # waveforms resembling evoked responses. Here, we show how to do it with two
 # proximal drives which drive current up the dendrite and one distal drive
 # which drives current down the dendrite producing the negative deflection.
-weights_ampa_p1 = {'L2_basket': 0.08831, 'L2_pyramidal': 0.01525,
-                   'L5_basket': 0.19934, 'L5_pyramidal': 0.00865}
-synaptic_delays_prox = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                        'L5_basket': 1., 'L5_pyramidal': 1.}
+weights_ampa_p1 = {
+    'L2_basket': 0.08831,
+    'L2_pyramidal': 0.01525,
+    'L5_basket': 0.19934,
+    'L5_pyramidal': 0.00865,
+}
+synaptic_delays_prox = {
+    'L2_basket': 0.1,
+    'L2_pyramidal': 0.1,
+    'L5_basket': 1.0,
+    'L5_pyramidal': 1.0,
+}
 # all NMDA weights are zero; pass None explicitly
 net.add_evoked_drive(
-    'evprox1', mu=26.61, sigma=2.47, numspikes=1, weights_ampa=weights_ampa_p1,
-    weights_nmda=None, location='proximal',
-    synaptic_delays=synaptic_delays_prox, event_seed=544)
+    'evprox1',
+    mu=26.61,
+    sigma=2.47,
+    numspikes=1,
+    weights_ampa=weights_ampa_p1,
+    weights_nmda=None,
+    location='proximal',
+    synaptic_delays=synaptic_delays_prox,
+    event_seed=544,
+)
 
 ###############################################################################
 # Now we add the second proximal evoked drive and simulate the network
 # dynamics with somatic voltage recordings enabled. Note: only AMPA weights
 # differ from first.
-weights_ampa_p2 = {'L2_basket': 0.000003, 'L2_pyramidal': 1.438840,
-                   'L5_basket': 0.008958, 'L5_pyramidal': 0.684013}
+weights_ampa_p2 = {
+    'L2_basket': 0.000003,
+    'L2_pyramidal': 1.438840,
+    'L5_basket': 0.008958,
+    'L5_pyramidal': 0.684013,
+}
 # all NMDA weights are zero; omit weights_nmda (defaults to None)
 net.add_evoked_drive(
-    'evprox2', mu=137.12, sigma=8.33, numspikes=1,
-    weights_ampa=weights_ampa_p2, location='proximal',
-    synaptic_delays=synaptic_delays_prox, event_seed=814)
-
-dpls = simulate_dipole(net, tstop=170., record_vsec='soma')
+    'evprox2',
+    mu=137.12,
+    sigma=8.33,
+    numspikes=1,
+    weights_ampa=weights_ampa_p2,
+    location='proximal',
+    synaptic_delays=synaptic_delays_prox,
+    event_seed=814,
+)
+
+dpls = simulate_dipole(net, tstop=170.0, record_vsec='soma')
 
 ###############################################################################
 # Here, we explain more details about the data structures and how they can
@@ -115,12 +151,12 @@
 ###############################################################################
 # We can additionally calculate the mean spike rates for each cell class by
 # specifying a time window with ``tstart`` and ``tstop``.
-all_rates = cell_response.mean_rates(tstart=0, tstop=170,
-                                     gid_ranges=net.gid_ranges,
-                                     mean_type='all')
-trial_rates = cell_response.mean_rates(tstart=0, tstop=170,
-                                       gid_ranges=net.gid_ranges,
-                                       mean_type='trial')
+all_rates = cell_response.mean_rates(
+    tstart=0, tstop=170, gid_ranges=net.gid_ranges, mean_type='all'
+)
+trial_rates = cell_response.mean_rates(
+    tstart=0, tstop=170, gid_ranges=net.gid_ranges, mean_type='trial'
+)
 print('Mean spike rates across trials:')
 print(all_rates)
 print('Mean spike rates for individual trials:')
@@ -138,6 +174,6 @@
     gid = gid_ranges['L5_pyramidal'][idx]
     axes[0].plot(net.cell_response.times, vsec[gid]['soma'], color='r')
 net.cell_response.plot_spikes_raster(ax=axes[1])
-net.cell_response.plot_spikes_hist(ax=axes[2],
-                                   spike_types=['L5_pyramidal',
-                                                'L2_pyramidal'])
+net.cell_response.plot_spikes_hist(
+    ax=axes[2], spike_types=['L5_pyramidal', 'L2_pyramidal']
+)
diff --git a/examples/howto/plot_hnn_animation.py b/examples/howto/plot_hnn_animation.py
index 858ee2e99..58bda7a42 100644
--- a/examples/howto/plot_hnn_animation.py
+++ b/examples/howto/plot_hnn_animation.py
@@ -8,14 +8,11 @@
 
 # Author: Nick Tolley
 
-
 ###############################################################################
 # First, we'll import the necessary modules for instantiating a network and
 # running a simulation that we would like to animate.
-import os.path as op
 
-import hnn_core
-from hnn_core import jones_2009_model, simulate_dipole, read_params
+from hnn_core import jones_2009_model, simulate_dipole
 from hnn_core.network_models import add_erp_drives_to_jones_model
 
 ###############################################################################
diff --git a/examples/howto/plot_record_extracellular_potentials.py b/examples/howto/plot_record_extracellular_potentials.py
index 3ae0df044..fa8a3cb05 100644
--- a/examples/howto/plot_record_extracellular_potentials.py
+++ b/examples/howto/plot_record_extracellular_potentials.py
@@ -36,7 +36,7 @@
 net = jones_2009_model()
 add_erp_drives_to_jones_model(net)
 
-net.set_cell_positions(inplane_distance=30.)
+net.set_cell_positions(inplane_distance=30.0)
 
 ###############################################################################
 # Extracellular recordings require specifying the electrode positions. It can be
@@ -82,8 +82,9 @@
 trial_idx = 0
 window_len = 10  # ms
 decimate = [5, 4]  # from 40k to 8k to 2k
-fig, axs = plt.subplots(4, 1, sharex=True, figsize=(6, 8),
-                        gridspec_kw={'height_ratios': [1, 3, 3, 3]})
+fig, axs = plt.subplots(
+    4, 1, sharex=True, figsize=(6, 8), gridspec_kw={'height_ratios': [1, 3, 3, 3]}
+)
 
 # Then plot the aggregate dipole time series on its own axis
 dpl[trial_idx].smooth(window_len=window_len)
@@ -91,7 +92,8 @@
 
 # use the same smoothing window on the LFP traces to allow comparison to dipole
 net.rec_arrays['shank1'][trial_idx].smooth(window_len=window_len).plot_lfp(
-    ax=axs[1], decim=decimate, show=False)
+    ax=axs[1], decim=decimate, show=False
+)
 axs[1].grid(True, which='major', axis='x')
 axs[1].set_xlabel('')
 
@@ -99,7 +101,9 @@
 net.cell_response.plot_spikes_raster(ax=axs[2], show=False)
 
 # Finally, add the CSD to the bottom subplot
-net.rec_arrays['shank1'][trial_idx].smooth(window_len=window_len).plot_csd(ax=axs[3], show=False)
+net.rec_arrays['shank1'][trial_idx].smooth(window_len=window_len).plot_csd(
+    ax=axs[3], show=False
+)
 plt.tight_layout()
 plt.show()
diff --git a/examples/howto/plot_simulate_mpi_backend.py b/examples/howto/plot_simulate_mpi_backend.py
index bb7320e65..b64e7cbaf 100644
--- a/examples/howto/plot_simulate_mpi_backend.py
+++ b/examples/howto/plot_simulate_mpi_backend.py
@@ -18,9 +18,7 @@
 ###############################################################################
 # Let us import hnn_core
 
-import os.path as op
 
-import hnn_core
 from hnn_core import simulate_dipole, jones_2009_model
 
 ###############################################################################
@@ -35,9 +33,17 @@
 weights_ampa = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5}
 
 net.add_bursty_drive(
-    'bursty', tstart=50., burst_rate=10, burst_std=20., numspikes=2,
-    spike_isi=10, n_drive_cells=10, location='distal',
-    weights_ampa=weights_ampa, event_seed=278)
+    'bursty',
+    tstart=50.0,
+    burst_rate=10,
+    burst_std=20.0,
+    numspikes=2,
+    spike_isi=10,
+    n_drive_cells=10,
+    location='distal',
+    weights_ampa=weights_ampa,
+    event_seed=278,
+)
 
 ###############################################################################
 # Finally, to simulate we use the
@@ -48,7 +54,7 @@
 from hnn_core import MPIBackend
 
 with MPIBackend(n_procs=2, mpi_cmd='mpiexec'):
-    dpls = simulate_dipole(net, tstop=310., n_trials=1)
+    dpls = simulate_dipole(net, tstop=310.0, n_trials=1)
 
 trial_idx = 0
 dpls[trial_idx].plot()
diff --git a/examples/workflows/plot_simulate_alpha.py b/examples/workflows/plot_simulate_alpha.py
index bf41a8b60..59c7852ea 100644
--- a/examples/workflows/plot_simulate_alpha.py
+++ b/examples/workflows/plot_simulate_alpha.py
@@ -20,12 +20,10 @@
 #          Nick Tolley
 #          Christopher Bailey
 
-import os.path as op
 
 ###############################################################################
 # Let us import hnn_core
 
-import hnn_core
 from hnn_core import simulate_dipole, jones_2009_model
 
 ###############################################################################
@@ -34,22 +32,31 @@
 # simulation. Each burst consists of a pair (2) of spikes, spaced 10 ms apart.
 # The occurrence of each burst is jittered by a random, normally distributed
 # amount (20 ms standard deviation). We repeat the burst train 10 times, each
-# time with unique randomization. The drive is only connected to the 
+# time with unique randomization. The drive is only connected to the
 # :term:`proximal` (dendritic) AMPA synapses on L2/3 and L5 pyramidal neurons.
 
 net = jones_2009_model()
 
 location = 'proximal'
 burst_std = 20
 weights_ampa_p = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5}
-syn_delays_p = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}
+syn_delays_p = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.0}
 net.add_bursty_drive(
-    'alpha_prox', tstart=50., burst_rate=10, burst_std=burst_std, numspikes=2,
-    spike_isi=10, n_drive_cells=10, location=location,
-    weights_ampa=weights_ampa_p, synaptic_delays=syn_delays_p, event_seed=284)
+    'alpha_prox',
+    tstart=50.0,
+    burst_rate=10,
+    burst_std=burst_std,
+    numspikes=2,
+    spike_isi=10,
+    n_drive_cells=10,
+    location=location,
+    weights_ampa=weights_ampa_p,
+    synaptic_delays=syn_delays_p,
+    event_seed=284,
+)
 
 # simulate the dipole, but do not automatically scale or smooth the result
-dpl = simulate_dipole(net, tstop=310., n_trials=1)
+dpl = simulate_dipole(net, tstop=310.0, n_trials=1)
 
 trial_idx = 0  # single trial simulated, choose the first index
 # to emulate a larger patch of cortex, we can apply a simple scaling factor
@@ -63,7 +70,7 @@
 # included in our biophysical model. We can confirm that what we simulate is
 # indeed 10 Hz activity by plotting the power spectral density (PSD).
 import matplotlib.pyplot as plt
-from hnn_core.viz import plot_dipole, plot_psd
+from hnn_core.viz import plot_psd
 
 fig, axes = plt.subplots(2, 1, constrained_layout=True)
 tmin, tmax = 10, 300  # exclude the initial burn-in period from the plots
@@ -78,7 +85,7 @@
 dpl_smooth.plot(tmin=tmin, tmax=tmax, color='r', ax=axes[0], show=False)
 axes[0].set_xlim((1, 399))
 
-plot_psd(dpl[trial_idx], fmin=1., fmax=1e3, tmin=tmin, ax=axes[1], show=False)
+plot_psd(dpl[trial_idx], fmin=1.0, fmax=1e3, tmin=tmin, ax=axes[1], show=False)
 axes[1].set_xscale('log')
 plt.tight_layout()
 ###############################################################################
@@ -90,12 +97,21 @@
 location = 'distal'
 burst_std = 15
 weights_ampa_d = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5}
-syn_delays_d = {'L2_pyramidal': 5., 'L5_pyramidal': 5.}
+syn_delays_d = {'L2_pyramidal': 5.0, 'L5_pyramidal': 5.0}
 net.add_bursty_drive(
-    'alpha_dist', tstart=50., burst_rate=10, burst_std=burst_std, numspikes=2,
-    spike_isi=10, n_drive_cells=10, location=location,
-    weights_ampa=weights_ampa_d, synaptic_delays=syn_delays_d, event_seed=296)
-dpl = simulate_dipole(net, tstop=310., n_trials=1)
+    'alpha_dist',
+    tstart=50.0,
+    burst_rate=10,
+    burst_std=burst_std,
+    numspikes=2,
+    spike_isi=10,
+    n_drive_cells=10,
+    location=location,
+    weights_ampa=weights_ampa_d,
+    synaptic_delays=syn_delays_d,
+    event_seed=296,
+)
+dpl = simulate_dipole(net, tstop=310.0, n_trials=1)
 
 ###############################################################################
 # We can verify that beta frequency activity was produced by inspecting the PSD
@@ -114,7 +130,7 @@
 dpl[trial_idx].plot(tmin=tmin, tmax=tmax, ax=axes[1], color='b', show=False)
 smooth_dpl.plot(tmin=tmin, tmax=tmax, ax=axes[1], color='r', show=False)
 
-dpl[trial_idx].plot_psd(fmin=0., fmax=40., tmin=tmin, ax=axes[2])
+dpl[trial_idx].plot_psd(fmin=0.0, fmax=40.0, tmin=tmin, ax=axes[2])
 plt.tight_layout()
 
 ###############################################################################
diff --git a/examples/workflows/plot_simulate_beta.py b/examples/workflows/plot_simulate_beta.py
index 8452c50e7..a5d5dbea4 100644
--- a/examples/workflows/plot_simulate_beta.py
+++ b/examples/workflows/plot_simulate_beta.py
@@ -79,37 +79,74 @@
 # above.
 
 def add_erp_drives(net, stimulus_start):
     # Distal evoked drive
-    weights_ampa_d1 = {'L2_basket': 0.0005, 'L2_pyramidal': 0.004,
-                       'L5_pyramidal': 0.0005}
-    weights_nmda_d1 = {'L2_basket': 0.0005, 'L2_pyramidal': 0.004,
-                       'L5_pyramidal': 0.0005}
-    syn_delays_d1 = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                     'L5_pyramidal': 0.1}
+    weights_ampa_d1 = {
+        'L2_basket': 0.0005,
+        'L2_pyramidal': 0.004,
+        'L5_pyramidal': 0.0005,
+    }
+    weights_nmda_d1 = {
+        'L2_basket': 0.0005,
+        'L2_pyramidal': 0.004,
+        'L5_pyramidal': 0.0005,
+    }
+    syn_delays_d1 = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, 'L5_pyramidal': 0.1}
     net.add_evoked_drive(
-        'evdist1', mu=70.0 + stimulus_start, sigma=0.0, numspikes=1,
-        weights_ampa=weights_ampa_d1, weights_nmda=weights_nmda_d1,
-        location='distal', synaptic_delays=syn_delays_d1, event_seed=274)
+        'evdist1',
+        mu=70.0 + stimulus_start,
+        sigma=0.0,
+        numspikes=1,
+        weights_ampa=weights_ampa_d1,
+        weights_nmda=weights_nmda_d1,
+        location='distal',
+        synaptic_delays=syn_delays_d1,
+        event_seed=274,
+    )
 
     # Two proximal drives
-    weights_ampa_p1 = {'L2_basket': 0.002, 'L2_pyramidal': 0.0011,
-                       'L5_basket': 0.001, 'L5_pyramidal': 0.001}
-    syn_delays_prox = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                       'L5_basket': 1., 'L5_pyramidal': 1.}
+    weights_ampa_p1 = {
+        'L2_basket': 0.002,
+        'L2_pyramidal': 0.0011,
+        'L5_basket': 0.001,
+        'L5_pyramidal': 0.001,
+    }
+    syn_delays_prox = {
+        'L2_basket': 0.1,
+        'L2_pyramidal': 0.1,
+        'L5_basket': 1.0,
+        'L5_pyramidal': 1.0,
+    }
     # all NMDA weights are zero; pass None explicitly
     net.add_evoked_drive(
-        'evprox1', mu=25.0 + stimulus_start, sigma=0.0, numspikes=1,
-        weights_ampa=weights_ampa_p1, weights_nmda=None,
-        location='proximal', synaptic_delays=syn_delays_prox, event_seed=544)
+        'evprox1',
+        mu=25.0 + stimulus_start,
+        sigma=0.0,
+        numspikes=1,
+        weights_ampa=weights_ampa_p1,
+        weights_nmda=None,
+        location='proximal',
+        synaptic_delays=syn_delays_prox,
+        event_seed=544,
+    )
 
     # Second proximal evoked drive. NB: only AMPA weights differ from first
-    weights_ampa_p2 = {'L2_basket': 0.005, 'L2_pyramidal': 0.005,
-                       'L5_basket': 0.01, 'L5_pyramidal': 0.01}
+    weights_ampa_p2 = {
+        'L2_basket': 0.005,
+        'L2_pyramidal': 0.005,
+        'L5_basket': 0.01,
+        'L5_pyramidal': 0.01,
+    }
     # all NMDA weights are zero; omit weights_nmda (defaults to None)
     net.add_evoked_drive(
-        'evprox2', mu=135.0 + stimulus_start, sigma=0.0, numspikes=1,
-        weights_ampa=weights_ampa_p2, location='proximal',
-        synaptic_delays=syn_delays_prox, event_seed=814)
+        'evprox2',
+        mu=135.0 + stimulus_start,
+        sigma=0.0,
+        numspikes=1,
+        weights_ampa=weights_ampa_p2,
+        location='proximal',
+        synaptic_delays=syn_delays_prox,
+        event_seed=814,
+    )
 
     return net
 
@@ -121,27 +158,57 @@ def add_erp_drives(net, stimulus_start):
 # of the network, and ultimately suppressed sensory detection.
 
 def add_beta_drives(net, beta_start):
     # Distal Drive
-    weights_ampa_d1 = {'L2_basket': 0.00032, 'L2_pyramidal': 0.00008,
-                       'L5_pyramidal': 0.00004}
-    syn_delays_d1 = {'L2_basket': 0.5, 'L2_pyramidal': 0.5,
-                     'L5_pyramidal': 0.5}
+    weights_ampa_d1 = {
+        'L2_basket': 0.00032,
+        'L2_pyramidal': 0.00008,
+        'L5_pyramidal': 0.00004,
+    }
+    syn_delays_d1 = {'L2_basket': 0.5, 'L2_pyramidal': 0.5, 'L5_pyramidal': 0.5}
     net.add_bursty_drive(
-        'beta_dist', tstart=beta_start, tstart_std=0., tstop=beta_start + 50.,
-        burst_rate=1., burst_std=10., numspikes=2, spike_isi=10,
-        n_drive_cells=10, location='distal', weights_ampa=weights_ampa_d1,
-        synaptic_delays=syn_delays_d1, event_seed=290)
+        'beta_dist',
+        tstart=beta_start,
+        tstart_std=0.0,
+        tstop=beta_start + 50.0,
+        burst_rate=1.0,
+        burst_std=10.0,
+        numspikes=2,
+        spike_isi=10,
+        n_drive_cells=10,
+        location='distal',
+        weights_ampa=weights_ampa_d1,
+        synaptic_delays=syn_delays_d1,
+        event_seed=290,
+    )
 
     # Proximal Drive
-    weights_ampa_p1 = {'L2_basket': 0.00004, 'L2_pyramidal': 0.00002,
-                       'L5_basket': 0.00002, 'L5_pyramidal': 0.00002}
-    syn_delays_p1 = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                     'L5_basket': 1.0, 'L5_pyramidal': 1.0}
+    weights_ampa_p1 = {
+        'L2_basket': 0.00004,
+        'L2_pyramidal': 0.00002,
+        'L5_basket': 0.00002,
+        'L5_pyramidal': 0.00002,
+    }
+    syn_delays_p1 = {
+        'L2_basket': 0.1,
+        'L2_pyramidal': 0.1,
+        'L5_basket': 1.0,
+        'L5_pyramidal': 1.0,
+    }
     net.add_bursty_drive(
-        'beta_prox', tstart=beta_start, tstart_std=0., tstop=beta_start + 50.,
-        burst_rate=1., burst_std=20., numspikes=2, spike_isi=10,
-        n_drive_cells=10, location='proximal', weights_ampa=weights_ampa_p1,
-        synaptic_delays=syn_delays_p1, event_seed=300)
+        'beta_prox',
+        tstart=beta_start,
+        tstart_std=0.0,
+        tstop=beta_start + 50.0,
+        burst_rate=1.0,
+        burst_std=20.0,
+        numspikes=2,
+        spike_isi=10,
+        n_drive_cells=10,
+        location='proximal',
+        weights_ampa=weights_ampa_p1,
+        synaptic_delays=syn_delays_p1,
+        event_seed=300,
+    )
 
     return net
 
@@ -174,8 +241,8 @@ def add_beta_drives(net, beta_start):
 # is an asymmetric beta event with a long positive tail.
 import matplotlib.pyplot as plt
 import numpy as np
-fig, axes = plt.subplots(4, 1, sharex=True, figsize=(7, 7),
-                         constrained_layout=True)
+
+fig, axes = plt.subplots(4, 1, sharex=True, figsize=(7, 7), constrained_layout=True)
 net_beta.cell_response.plot_spikes_hist(ax=axes[0], show=False)
 axes[0].set_title('Beta Event Generation')
 plot_dipole(dpls_beta, ax=axes[1], layer='agg', tmin=1.0, color='b', show=False)
@@ -183,7 +250,7 @@ def add_beta_drives(net, beta_start):
 axes[2].set_title('Spike Raster')
 
 # Create a fixed-step tiling of frequencies from 1 to 40 Hz in steps of 1 Hz
-freqs = np.arange(10., 60., 1.)
+freqs = np.arange(10.0, 60.0, 1.0)
 dpls_beta[0].plot_tfr_morlet(freqs, n_cycles=7, ax=axes[3])
 
 ###############################################################################
@@ -192,10 +259,8 @@ def add_beta_drives(net, beta_start):
 # hand to arrive at the cortex is roughly 25 ms, which means the first proximal
 # input to the cortical column occurs ~100 ms after the beta event.
 dpls_beta_erp[0].smooth(45)
-fig, axes = plt.subplots(3, 1, sharex=True, figsize=(7, 7),
-                         constrained_layout=True)
-plot_dipole(dpls_beta_erp, ax=axes[0], layer='agg', tmin=1.0, color='r',
-            show=False)
+fig, axes = plt.subplots(3, 1, sharex=True, figsize=(7, 7), constrained_layout=True)
+plot_dipole(dpls_beta_erp, ax=axes[0], layer='agg', tmin=1.0, color='r', show=False)
 axes[0].set_title('Beta Event + ERP')
 net_beta_erp.cell_response.plot_spikes_hist(ax=axes[1], show=False)
 axes[1].set_title('Input Drives Histogram')
@@ -209,10 +274,8 @@ def add_beta_drives(net, beta_start):
 # The sustained inhibition of the network ultimately depresses
 # the sensory response which is associated with a reduced ERP amplitude
 dpls_erp[0].smooth(45)
-fig, axes = plt.subplots(3, 1, sharex=True, figsize=(7, 7),
-                         constrained_layout=True)
-plot_dipole(dpls_beta_erp, ax=axes[0], layer='agg', tmin=1.0, color='r',
-            show=False)
+fig, axes = plt.subplots(3, 1, sharex=True, figsize=(7, 7), constrained_layout=True)
+plot_dipole(dpls_beta_erp, ax=axes[0], layer='agg', tmin=1.0, color='r', show=False)
 plot_dipole(dpls_erp, ax=axes[0], layer='agg', tmin=1.0, color='b', show=False)
 axes[0].set_title('Beta ERP Comparison')
 axes[0].legend(['ERP + Beta', 'ERP'])
diff --git a/examples/workflows/plot_simulate_evoked.py b/examples/workflows/plot_simulate_evoked.py
index 802cfdda6..6cfc93b12 100644
--- a/examples/workflows/plot_simulate_evoked.py
+++ b/examples/workflows/plot_simulate_evoked.py
@@ -21,15 +21,10 @@
 
 # sphinx_gallery_thumbnail_number = 3
 
-import os.path as op
-import tempfile
-
-import matplotlib.pyplot as plt
 
 ###############################################################################
 # Let us import hnn_core
 
-import hnn_core
 from hnn_core import simulate_dipole, jones_2009_model
 from hnn_core.viz import plot_dipole
 
@@ -52,37 +47,74 @@
 ###############################################################################
 # First, we add a distal evoked drive
-weights_ampa_d1 = {'L2_basket': 0.006562, 'L2_pyramidal': .000007,
-                   'L5_pyramidal': 0.142300}
-weights_nmda_d1 = {'L2_basket': 0.019482, 'L2_pyramidal': 0.004317,
-                   'L5_pyramidal': 0.080074}
-synaptic_delays_d1 = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                      'L5_pyramidal': 0.1}
+weights_ampa_d1 = {
+    'L2_basket': 0.006562,
+    'L2_pyramidal': 0.000007,
+    'L5_pyramidal': 0.142300,
+}
+weights_nmda_d1 = {
+    'L2_basket': 0.019482,
+    'L2_pyramidal': 0.004317,
+    'L5_pyramidal': 0.080074,
+}
+synaptic_delays_d1 = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, 'L5_pyramidal': 0.1}
 net.add_evoked_drive(
-    'evdist1', mu=63.53, sigma=3.85, numspikes=1, weights_ampa=weights_ampa_d1,
-    weights_nmda=weights_nmda_d1, location='distal',
-    synaptic_delays=synaptic_delays_d1, event_seed=274)
+    'evdist1',
+    mu=63.53,
+    sigma=3.85,
+    numspikes=1,
+    weights_ampa=weights_ampa_d1,
+    weights_nmda=weights_nmda_d1,
+    location='distal',
+    synaptic_delays=synaptic_delays_d1,
+    event_seed=274,
+)
 
 ###############################################################################
 # Then, we add two proximal drives
-weights_ampa_p1 = {'L2_basket': 0.08831, 'L2_pyramidal': 0.01525,
-                   'L5_basket': 0.19934, 'L5_pyramidal': 0.00865}
-synaptic_delays_prox = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                        'L5_basket': 1., 'L5_pyramidal': 1.}
+weights_ampa_p1 = {
+    'L2_basket': 0.08831,
+    'L2_pyramidal': 0.01525,
+    'L5_basket': 0.19934,
+    'L5_pyramidal': 0.00865,
+}
+synaptic_delays_prox = {
+    'L2_basket': 0.1,
+    'L2_pyramidal': 0.1,
+    'L5_basket': 1.0,
+    'L5_pyramidal': 1.0,
+}
 # all NMDA weights are zero; pass None explicitly
 net.add_evoked_drive(
-    'evprox1', mu=26.61, sigma=2.47, numspikes=1, weights_ampa=weights_ampa_p1,
-    weights_nmda=None, location='proximal',
-    synaptic_delays=synaptic_delays_prox, event_seed=544)
+    'evprox1',
+    mu=26.61,
+    sigma=2.47,
+    numspikes=1,
+    weights_ampa=weights_ampa_p1,
+    weights_nmda=None,
+    location='proximal',
+    synaptic_delays=synaptic_delays_prox,
+    event_seed=544,
+)
 
 # Second proximal evoked drive. NB: only AMPA weights differ from first
-weights_ampa_p2 = {'L2_basket': 0.000003, 'L2_pyramidal': 1.438840,
-                   'L5_basket': 0.008958, 'L5_pyramidal': 0.684013}
+weights_ampa_p2 = {
+    'L2_basket': 0.000003,
+    'L2_pyramidal': 1.438840,
+    'L5_basket': 0.008958,
+    'L5_pyramidal': 0.684013,
+}
 # all NMDA weights are zero; omit weights_nmda (defaults to None)
 net.add_evoked_drive(
-    'evprox2', mu=137.12, sigma=8.33, numspikes=1,
-    weights_ampa=weights_ampa_p2, location='proximal',
-    synaptic_delays=synaptic_delays_prox, event_seed=814)
+    'evprox2',
+    mu=137.12,
+    sigma=8.33,
+    numspikes=1,
+    weights_ampa=weights_ampa_p2,
+    location='proximal',
+    synaptic_delays=synaptic_delays_prox,
+    event_seed=814,
+)
 
 ###############################################################################
 # Now let's simulate the dipole, running 2 trials with the
@@ -93,7 +125,7 @@
 from hnn_core import JoblibBackend
 
 with JoblibBackend(n_jobs=2):
-    dpls = simulate_dipole(net, tstop=170., n_trials=2)
+    dpls = simulate_dipole(net, tstop=170.0, n_trials=2)
 
 ###############################################################################
 # Rather than reading smoothing and scaling parameters from file, we recommend
@@ -107,11 +139,10 @@
 ###############################################################################
 # Plot the amplitudes of the simulated aggregate dipole moments over time
 import matplotlib.pyplot as plt
-fig, axes = plt.subplots(2, 1, sharex=True, figsize=(6, 6),
-                         constrained_layout=True)
+
+fig, axes = plt.subplots(2, 1, sharex=True, figsize=(6, 6), constrained_layout=True)
 plot_dipole(dpls, ax=axes[0], layer='agg', show=False)
-net.cell_response.plot_spikes_hist(ax=axes[1],
-                                   spike_types=['evprox', 'evdist'])
+net.cell_response.plot_spikes_hist(ax=axes[1], spike_types=['evprox', 'evdist'])
 
 ###############################################################################
 # If you want to analyze how the different cortical layers contribute to
@@ -128,23 +159,49 @@
 
 net_sync = jones_2009_model()
 
-n_drive_cells=1
-cell_specific=False
+n_drive_cells = 1
+cell_specific = False
 
 net_sync.add_evoked_drive(
-    'evdist1', mu=63.53, sigma=3.85, numspikes=1, weights_ampa=weights_ampa_d1,
-    weights_nmda=weights_nmda_d1, location='distal', n_drive_cells=n_drive_cells,
-    cell_specific=cell_specific, synaptic_delays=synaptic_delays_d1, event_seed=274)
+    'evdist1',
+    mu=63.53,
+    sigma=3.85,
+    numspikes=1,
+    weights_ampa=weights_ampa_d1,
+    weights_nmda=weights_nmda_d1,
+    location='distal',
+    n_drive_cells=n_drive_cells,
+    cell_specific=cell_specific,
+    synaptic_delays=synaptic_delays_d1,
+    event_seed=274,
+)
 
 net_sync.add_evoked_drive(
-    'evprox1', mu=26.61, sigma=2.47, numspikes=1, weights_ampa=weights_ampa_p1,
-    weights_nmda=None, location='proximal', n_drive_cells=n_drive_cells,
-    cell_specific=cell_specific, synaptic_delays=synaptic_delays_prox, event_seed=544)
+    'evprox1',
+    mu=26.61,
+    sigma=2.47,
+    numspikes=1,
+    weights_ampa=weights_ampa_p1,
+    weights_nmda=None,
+    location='proximal',
+    n_drive_cells=n_drive_cells,
+    cell_specific=cell_specific,
+    synaptic_delays=synaptic_delays_prox,
+    event_seed=544,
+)
 
 net_sync.add_evoked_drive(
-    'evprox2', mu=137.12, sigma=8.33, numspikes=1,
-    weights_ampa=weights_ampa_p2, location='proximal', n_drive_cells=n_drive_cells,
-    cell_specific=cell_specific, synaptic_delays=synaptic_delays_prox, event_seed=814)
+    'evprox2',
+    mu=137.12,
+    sigma=8.33,
+    numspikes=1,
+    weights_ampa=weights_ampa_p2,
+    location='proximal',
+    n_drive_cells=n_drive_cells,
+    cell_specific=cell_specific,
+    synaptic_delays=synaptic_delays_prox,
+    event_seed=814,
+)
 
 ###############################################################################
 # You may interrogate current values defining the spike event time dynamics by
@@ -153,7 +210,7 @@
 ###############################################################################
 # Finally, let's simulate this network. Rather than modifying the dipole
 # object, this time we make a copy of it before smoothing and scaling.
-dpls_sync = simulate_dipole(net_sync, tstop=170., n_trials=1)
+dpls_sync = simulate_dipole(net_sync, tstop=170.0, n_trials=1)
 
 trial_idx = 0
 dpls_sync[trial_idx].copy().smooth(window_len).scale(scaling_factor).plot()
diff --git a/examples/workflows/plot_simulate_gamma.py b/examples/workflows/plot_simulate_gamma.py
index 558e3f145..9b0209063 100644
--- a/examples/workflows/plot_simulate_gamma.py
+++ b/examples/workflows/plot_simulate_gamma.py
@@ -50,12 +50,16 @@
 synaptic_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.0}
 rate_constant = {'L2_pyramidal': 140.0, 'L5_pyramidal': 40.0}
 net.add_poisson_drive(
-    'poisson', rate_constant=rate_constant, weights_ampa=weights_ampa,
-    location='proximal', synaptic_delays=synaptic_delays,
-    event_seed=1349)
+    'poisson',
+    rate_constant=rate_constant,
+    weights_ampa=weights_ampa,
+    location='proximal',
+    synaptic_delays=synaptic_delays,
+    event_seed=1349,
+)
 
 ###############################################################################
-dpls = simulate_dipole(net, tstop=250.)
+dpls = simulate_dipole(net, tstop=250.0)
 
 scaling_factor = 30000
 dpls = [dpl.scale(scaling_factor) for dpl in dpls]  # scale in place
@@ -79,13 +83,12 @@
 import numpy as np
 import matplotlib.pyplot as plt
 
-fig, axes = plt.subplots(2, 1, sharex=True, figsize=(6, 6),
-                         constrained_layout=True)
+fig, axes = plt.subplots(2, 1, sharex=True, figsize=(6, 6), constrained_layout=True)
 
 dpls[trial_idx].plot(tmin=tmin, ax=axes[0], show=False)
 
 # Create an fixed-step tiling of frequencies from 20 to 100 Hz in steps of 1 Hz
-freqs = np.arange(20., 100., 1.)
+freqs = np.arange(20.0, 100.0, 1.0)
 dpls[trial_idx].plot_tfr_morlet(freqs, n_cycles=7, tmin=tmin, ax=axes[1])
 
 ###############################################################################
@@ -94,8 +97,8 @@
 # more regular, with less noise due to the fact that the tonic depolarization
 # dominates over the influence of the Poisson drive. By default, a tonic bias
 # is applied to the entire duration of the simulation.
-net.add_tonic_bias(cell_type='L5_pyramidal', amplitude=6.)
-dpls = simulate_dipole(net, tstop=250., n_trials=1)
+net.add_tonic_bias(cell_type='L5_pyramidal', amplitude=6.0)
+dpls = simulate_dipole(net, tstop=250.0, n_trials=1)
 dpls = [dpl.scale(scaling_factor) for dpl in dpls]  # scale in place
 
 dpls[trial_idx].plot()
@@ -114,7 +117,8 @@
 # Although the simulated dipole signal demonstrates clear periodicity, its
 # frequency is lower compared with the "weak" PING simulation above.
from hnn_core.viz import plot_psd -plot_psd(dpls[trial_idx], fmin=20., fmax=100., tmin=tmin) + +plot_psd(dpls[trial_idx], fmin=20.0, fmax=100.0, tmin=tmin) ############################################################################### # Finally, we demonstrate the mechanistic link between PING and the GABAA decay @@ -123,11 +127,10 @@ # refactory period between L5 pyramidal cell spikes and increase the PING # frequency from ~50 to ~65 Hz. net.cell_types['L5_pyramidal'].synapses['gabaa']['tau2'] = 2 -dpls = simulate_dipole(net, tstop=250., n_trials=1) +dpls = simulate_dipole(net, tstop=250.0, n_trials=1) dpls = [dpl.scale(scaling_factor) for dpl in dpls] # scale in place -fig, axes = plt.subplots(3, 1, sharex=True, figsize=(6, 6), - constrained_layout=True) +fig, axes = plt.subplots(3, 1, sharex=True, figsize=(6, 6), constrained_layout=True) dpls[trial_idx].plot(ax=axes[0], show=False) net.cell_response.plot_spikes_raster(ax=axes[1], show=False) dpls[trial_idx].plot_tfr_morlet(freqs, n_cycles=7, tmin=tmin, ax=axes[2]) diff --git a/examples/workflows/plot_simulate_somato.py b/examples/workflows/plot_simulate_somato.py index 14141bfaa..afdb45f64 100644 --- a/examples/workflows/plot_simulate_somato.py +++ b/examples/workflows/plot_simulate_somato.py @@ -38,10 +38,18 @@ data_path = somato.data_path() subject = '01' task = 'somato' -raw_fname = op.join(data_path, 'sub-{}'.format(subject), 'meg', - 'sub-{}_task-{}_meg.fif'.format(subject, task)) -fwd_fname = op.join(data_path, 'derivatives', 'sub-{}'.format(subject), - 'sub-{}_task-{}-fwd.fif'.format(subject, task)) +raw_fname = op.join( + data_path, + 'sub-{}'.format(subject), + 'meg', + 'sub-{}_task-{}_meg.fif'.format(subject, task), +) +fwd_fname = op.join( + data_path, + 'derivatives', + 'sub-{}'.format(subject), + 'sub-{}_task-{}-fwd.fif'.format(subject, task), +) subjects_dir = op.join(data_path, 'derivatives', 'freesurfer', 'subjects') ############################################################################### @@ -56,10 +64,18 @@ events = mne.find_events(raw, stim_channel='STI 014') # Define epochs within the time series -event_id, tmin, tmax = 1, -.2, .17 +event_id, tmin, tmax = 1, -0.2, 0.17 baseline = None -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, baseline=baseline, - reject=dict(grad=4000e-13, eog=350e-6), preload=True) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + baseline=baseline, + reject=dict(grad=4000e-13, eog=350e-6), + preload=True, +) # Compute the inverse operator fwd = mne.read_forward_solution(fwd_fname) @@ -75,21 +91,27 @@ # to note that the dipole currents simulated with HNN are assumed to be normal # to the cortical surface. Hence, using the option ``pick_ori='normal'`` is # appropriate. -snr = 3. -lambda2 = 1. 
/ snr ** 2 +snr = 3.0 +lambda2 = 1.0 / snr**2 evoked = epochs.average() -stc = apply_inverse(evoked, inv, lambda2, method='MNE', - pick_ori="normal", return_residual=False, - verbose=True) +stc = apply_inverse( + evoked, + inv, + lambda2, + method='MNE', + pick_ori='normal', + return_residual=False, + verbose=True, +) ############################################################################### # To extract the primary response in primary somatosensory cortex (S1), we # create a label for the postcentral gyrus (S1) in source-space hemi = 'rh' label_tag = 'G_postcentral' -label_s1 = mne.read_labels_from_annot(subject, parc='aparc.a2009s', hemi=hemi, - regexp=label_tag, - subjects_dir=subjects_dir)[0] +label_s1 = mne.read_labels_from_annot( + subject, parc='aparc.a2009s', hemi=hemi, regexp=label_tag, subjects_dir=subjects_dir +)[0] ############################################################################### # Visualizing the distributed S1 activation in reference to the geometric @@ -103,7 +125,7 @@ # post-central gyrus label from which the dipole time course was extracted and # the second showing MNE activation at 0.040 sec that resemble the following # images. -''' +""" Brain = mne.viz.get_brain_class() brain_label = Brain(subject, hemi, 'white', subjects_dir=subjects_dir) brain_label.add_label(label_s1, color='green', alpha=0.9) @@ -111,7 +133,7 @@ brain = stc_label.plot(subjects_dir=subjects_dir, hemi=hemi, surface='white', view_layout='horizontal', initial_time=0.04, backend='pyvista') -''' +""" ############################################################################### # |mne_label_fig| @@ -127,8 +149,7 @@ # extracted waveform so that the deflection at ~0.040 sec is pointed downwards. # Thus, the ~0.040 sec deflection corresponds to current flow traveling from # superficial to deep layers of cortex. -flip_data = stc.extract_label_time_course(label_s1, inv['src'], - mode='pca_flip') +flip_data = stc.extract_label_time_course(label_s1, inv['src'], mode='pca_flip') dipole_tc = -flip_data[0] * 1e9 plt.figure() @@ -162,60 +183,110 @@ # in the network. 
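# A quick cross-check once the four drives below have been added (a minimal
# sketch, not part of the patch hunks; ``net.external_drives`` is the public
# registry that hnn-core keeps per named drive):
for name, drive in net.external_drives.items():
    print(name, drive['location'], drive['dynamics'])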
# Early proximal drive -weights_ampa_p = {'L2_basket': 0.0036, 'L2_pyramidal': 0.0039, - 'L5_basket': 0.0019, 'L5_pyramidal': 0.0020} -weights_nmda_p = {'L2_basket': 0.0029, 'L2_pyramidal': 0.0005, - 'L5_basket': 0.0030, 'L5_pyramidal': 0.0019} -synaptic_delays_p = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, - 'L5_basket': 1.0, 'L5_pyramidal': 1.0} +weights_ampa_p = { + 'L2_basket': 0.0036, + 'L2_pyramidal': 0.0039, + 'L5_basket': 0.0019, + 'L5_pyramidal': 0.0020, +} +weights_nmda_p = { + 'L2_basket': 0.0029, + 'L2_pyramidal': 0.0005, + 'L5_basket': 0.0030, + 'L5_pyramidal': 0.0019, +} +synaptic_delays_p = { + 'L2_basket': 0.1, + 'L2_pyramidal': 0.1, + 'L5_basket': 1.0, + 'L5_pyramidal': 1.0, +} net.add_evoked_drive( - 'evprox1', mu=21., sigma=4., numspikes=1, location='proximal', - n_drive_cells=1, cell_specific=False, weights_ampa=weights_ampa_p, - weights_nmda=weights_nmda_p, synaptic_delays=synaptic_delays_p, - event_seed=276) + 'evprox1', + mu=21.0, + sigma=4.0, + numspikes=1, + location='proximal', + n_drive_cells=1, + cell_specific=False, + weights_ampa=weights_ampa_p, + weights_nmda=weights_nmda_p, + synaptic_delays=synaptic_delays_p, + event_seed=276, +) # Late proximal drive -weights_ampa_p = {'L2_basket': 0.003, 'L2_pyramidal': 0.0039, - 'L5_basket': 0.004, 'L5_pyramidal': 0.0020} -weights_nmda_p = {'L2_basket': 0.001, 'L2_pyramidal': 0.0005, - 'L5_basket': 0.002, 'L5_pyramidal': 0.0020} -synaptic_delays_p = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, - 'L5_basket': 1.0, 'L5_pyramidal': 1.0} +weights_ampa_p = { + 'L2_basket': 0.003, + 'L2_pyramidal': 0.0039, + 'L5_basket': 0.004, + 'L5_pyramidal': 0.0020, +} +weights_nmda_p = { + 'L2_basket': 0.001, + 'L2_pyramidal': 0.0005, + 'L5_basket': 0.002, + 'L5_pyramidal': 0.0020, +} +synaptic_delays_p = { + 'L2_basket': 0.1, + 'L2_pyramidal': 0.1, + 'L5_basket': 1.0, + 'L5_pyramidal': 1.0, +} net.add_evoked_drive( - 'evprox2', mu=134., sigma=4.5, numspikes=1, location='proximal', - n_drive_cells=1, cell_specific=False, weights_ampa=weights_ampa_p, - weights_nmda=weights_nmda_p, synaptic_delays=synaptic_delays_p, - event_seed=276) + 'evprox2', + mu=134.0, + sigma=4.5, + numspikes=1, + location='proximal', + n_drive_cells=1, + cell_specific=False, + weights_ampa=weights_ampa_p, + weights_nmda=weights_nmda_p, + synaptic_delays=synaptic_delays_p, + event_seed=276, +) # Early distal drive -weights_ampa_d = {'L2_basket': 0.0043, 'L2_pyramidal': 0.0032, - 'L5_pyramidal': 0.0009} -weights_nmda_d = {'L2_basket': 0.0029, 'L2_pyramidal': 0.0051, - 'L5_pyramidal': 0.0010} -synaptic_delays_d = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, - 'L5_pyramidal': 0.1} +weights_ampa_d = {'L2_basket': 0.0043, 'L2_pyramidal': 0.0032, 'L5_pyramidal': 0.0009} +weights_nmda_d = {'L2_basket': 0.0029, 'L2_pyramidal': 0.0051, 'L5_pyramidal': 0.0010} +synaptic_delays_d = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, 'L5_pyramidal': 0.1} net.add_evoked_drive( - 'evdist1', mu=32., sigma=2.5, numspikes=1, location='distal', - n_drive_cells=1, cell_specific=False, weights_ampa=weights_ampa_d, - weights_nmda=weights_nmda_d, synaptic_delays=synaptic_delays_d, - event_seed=277) + 'evdist1', + mu=32.0, + sigma=2.5, + numspikes=1, + location='distal', + n_drive_cells=1, + cell_specific=False, + weights_ampa=weights_ampa_d, + weights_nmda=weights_nmda_d, + synaptic_delays=synaptic_delays_d, + event_seed=277, +) # Late distal drive -weights_ampa_d = {'L2_basket': 0.0041, 'L2_pyramidal': 0.0019, - 'L5_pyramidal': 0.0018} -weights_nmda_d = {'L2_basket': 0.0032, 'L2_pyramidal': 0.0018, - 
'L5_pyramidal': 0.0017} -synaptic_delays_d = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, - 'L5_pyramidal': 0.1} +weights_ampa_d = {'L2_basket': 0.0041, 'L2_pyramidal': 0.0019, 'L5_pyramidal': 0.0018} +weights_nmda_d = {'L2_basket': 0.0032, 'L2_pyramidal': 0.0018, 'L5_pyramidal': 0.0017} +synaptic_delays_d = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, 'L5_pyramidal': 0.1} net.add_evoked_drive( - 'evdist2', mu=84., sigma=4.5, numspikes=1, location='distal', - n_drive_cells=1, cell_specific=False, weights_ampa=weights_ampa_d, - weights_nmda=weights_nmda_d, synaptic_delays=synaptic_delays_d, - event_seed=275) + 'evdist2', + mu=84.0, + sigma=4.5, + numspikes=1, + location='distal', + n_drive_cells=1, + cell_specific=False, + weights_ampa=weights_ampa_d, + weights_nmda=weights_nmda_d, + synaptic_delays=synaptic_delays_d, + event_seed=275, +) ############################################################################### # Now we run the simulation over 2 trials so that we can plot the average @@ -224,7 +295,7 @@ n_trials = 2 # n_trials = 25 with JoblibBackend(n_jobs=2): - dpls = simulate_dipole(net, tstop=170., n_trials=n_trials) + dpls = simulate_dipole(net, tstop=170.0, n_trials=n_trials) ############################################################################### # Since the model is a reduced representation of the larger network @@ -243,11 +314,10 @@ ############################################################################### # Finally, we plot the driving spike histogram, empirical and simulated median # nerve evoked response waveforms, and output spike histogram. -fig, axes = plt.subplots(3, 1, sharex=True, figsize=(6, 6), - constrained_layout=True) -net.cell_response.plot_spikes_hist(ax=axes[0], - spike_types=['evprox', 'evdist'], - show=False) +fig, axes = plt.subplots(3, 1, sharex=True, figsize=(6, 6), constrained_layout=True) +net.cell_response.plot_spikes_hist( + ax=axes[0], spike_types=['evprox', 'evdist'], show=False +) axes[1].axhline(0, c='k', ls=':', label='_nolegend_') axes[1].plot(1e3 * stc.times, dipole_tc, 'r--') average_dipoles(dpls).plot(ax=axes[1], show=False) diff --git a/hnn_core/__init__.py b/hnn_core/__init__.py index eb2ad7afd..cc3c2978e 100644 --- a/hnn_core/__init__.py +++ b/hnn_core/__init__.py @@ -1,4 +1,10 @@ -from .dipole import simulate_dipole, read_dipole, average_dipoles, Dipole,_read_dipole_txt +from .dipole import ( + simulate_dipole, + read_dipole, + average_dipoles, + Dipole, + _read_dipole_txt, +) from .params import Params, read_params, convert_to_json from .network import Network, pick_connection from .network_models import jones_2009_model, law_2021_model, calcium_model diff --git a/hnn_core/batch_simulate.py b/hnn_core/batch_simulate.py index 4f153b40e..695bdfaf7 100644 --- a/hnn_core/batch_simulate.py +++ b/hnn_core/batch_simulate.py @@ -16,15 +16,29 @@ class BatchSimulate(object): - def __init__(self, set_params, net=jones_2009_model(), tstop=170, - dt=0.025, n_trials=1, record_vsec=False, - record_isec=False, postproc=False, save_outputs=False, - save_folder='./sim_results', batch_size=100, - overwrite=True, summary_func=None, - save_dpl=True, save_spiking=False, - save_lfp=False, save_voltages=False, - save_currents=False, save_calcium=False, - clear_cache=False): + def __init__( + self, + set_params, + net=jones_2009_model(), + tstop=170, + dt=0.025, + n_trials=1, + record_vsec=False, + record_isec=False, + postproc=False, + save_outputs=False, + save_folder='./sim_results', + batch_size=100, + overwrite=True, + summary_func=None, + save_dpl=True, 
+ save_spiking=False, + save_lfp=False, + save_voltages=False, + save_currents=False, + save_calcium=False, + clear_cache=False, + ): """Initialize the BatchSimulate class. Parameters @@ -129,10 +143,10 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, _validate_type(clear_cache, types=(bool,), item_name='clear_cache') if set_params is not None and not callable(set_params): - raise TypeError("set_params must be a callable function") + raise TypeError('set_params must be a callable function') if summary_func is not None and not callable(summary_func): - raise TypeError("summary_func must be a callable function") + raise TypeError('summary_func must be a callable function') self.net = net self.set_params = set_params @@ -155,9 +169,15 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, self.save_calcium = save_calcium self.clear_cache = clear_cache - def run(self, param_grid, return_output=True, - combinations=True, n_jobs=1, backend='loky', - verbose=50): + def run( + self, + param_grid, + return_output=True, + combinations=True, + n_jobs=1, + backend='loky', + verbose=50, + ): """Run batch simulations. Parameters @@ -194,12 +214,12 @@ def run(self, param_grid, return_output=True, """ _validate_type(param_grid, types=(dict,), item_name='param_grid') _validate_type(n_jobs, types='int', item_name='n_jobs') - _check_option('backend', backend, ['loky', 'threading', - 'multiprocessing', 'dask']) + _check_option( + 'backend', backend, ['loky', 'threading', 'multiprocessing', 'dask'] + ) _validate_type(verbose, types='int', item_name='verbose') - param_combinations = self._generate_param_combinations( - param_grid, combinations) + param_combinations = self._generate_param_combinations(param_grid, combinations) total_sims = len(param_combinations) num_sims_per_batch = max(total_sims // self.batch_size, 1) batch_size = min(self.batch_size, total_sims) @@ -215,7 +235,8 @@ def run(self, param_grid, return_output=True, param_combinations[start_idx:end_idx], n_jobs=n_jobs, backend=backend, - verbose=verbose) + verbose=verbose, + ) if self.save_outputs: self._save(batch_results, start_idx, end_idx) @@ -234,13 +255,11 @@ def run(self, param_grid, return_output=True, if return_output: if self.clear_cache: - return {"summary_statistics": results} + return {'summary_statistics': results} else: - return {"summary_statistics": results, - "simulated_data": simulated_data} + return {'summary_statistics': results, 'simulated_data': simulated_data} - def simulate_batch(self, param_combinations, n_jobs=1, - backend='loky', verbose=50): + def simulate_batch(self, param_combinations, n_jobs=1, backend='loky', verbose=50): """Simulate a batch of parameter sets in parallel. Parameters @@ -264,17 +283,19 @@ def simulate_batch(self, param_combinations, n_jobs=1, - `dpl`: The simulated dipole. - `param_values`: The parameter values used for the simulation. 
""" - _validate_type(param_combinations, types=(list,), - item_name='param_combinations') + _validate_type( + param_combinations, types=(list,), item_name='param_combinations' + ) _validate_type(n_jobs, types='int', item_name='n_jobs') - _check_option('backend', backend, ['loky', 'threading', - 'multiprocessing', 'dask']) + _check_option( + 'backend', backend, ['loky', 'threading', 'multiprocessing', 'dask'] + ) _validate_type(verbose, types='int', item_name='verbose') with parallel_config(backend=backend): res = Parallel(n_jobs=n_jobs, verbose=verbose)( - delayed(self._run_single_sim)( - params) for params in param_combinations) + delayed(self._run_single_sim)(params) for params in param_combinations + ) return res def _run_single_sim(self, param_values): @@ -301,20 +322,22 @@ def _run_single_sim(self, param_values): results = {'net': net, 'param_values': param_values} if self.save_dpl: - dpl = simulate_dipole(net, - tstop=self.tstop, - dt=self.dt, - n_trials=self.n_trials, - record_vsec=self.record_vsec, - record_isec=self.record_isec, - postproc=self.postproc) + dpl = simulate_dipole( + net, + tstop=self.tstop, + dt=self.dt, + n_trials=self.n_trials, + record_vsec=self.record_vsec, + record_isec=self.record_isec, + postproc=self.postproc, + ) results['dpl'] = dpl if self.save_spiking: results['spiking'] = { 'spike_times': net.cell_response.spike_times, 'spike_types': net.cell_response.spike_types, - 'spike_gids': net.cell_response.spike_gids + 'spike_gids': net.cell_response.spike_gids, } if self.save_lfp: @@ -352,11 +375,13 @@ def _generate_param_combinations(self, param_grid, combinations=True): keys, values = zip(*param_grid.items()) if combinations: - param_combinations = [dict(zip(keys, combination)) - for combination in product(*values)] + param_combinations = [ + dict(zip(keys, combination)) for combination in product(*values) + ] else: - param_combinations = [dict(zip(keys, combination)) - for combination in zip(*values)] + param_combinations = [ + dict(zip(keys, combination)) for combination in zip(*values) + ] return param_combinations def _save(self, results, start_idx, end_idx): @@ -378,22 +403,25 @@ def _save(self, results, start_idx, end_idx): if not os.path.exists(self.save_folder): os.makedirs(self.save_folder) - save_data = { - 'param_values': [result['param_values'] for result in results] - } + save_data = {'param_values': [result['param_values'] for result in results]} - attributes_to_save = ['dpl', 'spiking', 'lfp', - 'voltages', 'currents', 'calcium'] + attributes_to_save = [ + 'dpl', + 'spiking', + 'lfp', + 'voltages', + 'currents', + 'calcium', + ] for attr in attributes_to_save: if getattr(self, f'save_{attr}') and attr in results[0]: save_data[attr] = [result[attr] for result in results] - file_name = os.path.join(self.save_folder, - f'sim_run_{start_idx}-{end_idx}.npz') + file_name = os.path.join(self.save_folder, f'sim_run_{start_idx}-{end_idx}.npz') if os.path.exists(file_name) and not self.overwrite: raise FileExistsError( - f"File {file_name} already exists and " - "overwrite is set to False.") + f'File {file_name} already exists and ' 'overwrite is set to False.' 
+ ) np.savez(file_name, **save_data) @@ -429,8 +457,11 @@ def load_results(self, file_path, return_data=None): return_data.append('calcium') data = np.load(file_path, allow_pickle=True) - results = {key: data[key].tolist() for key in data.files - if key in return_data or key == 'param_values'} + results = { + key: data[key].tolist() + for key in data.files + if key in return_data or key == 'param_values' + } return results def load_all_results(self): diff --git a/hnn_core/cell.py b/hnn_core/cell.py index 7e533f087..1835dbb81 100644 --- a/hnn_core/cell.py +++ b/hnn_core/cell.py @@ -19,8 +19,9 @@ def _get_cos_theta(sections, sec_name_apical): """Get cos(theta) to compute dipole along the apical dendrite.""" - a = (np.array(sections[sec_name_apical].end_pts[1]) - - np.array(sections[sec_name_apical].end_pts[0])) + a = np.array(sections[sec_name_apical].end_pts[1]) - np.array( + sections[sec_name_apical].end_pts[0] + ) cos_thetas = dict() for sec_name, section in sections.items(): b = np.array(section.end_pts[1]) - np.array(section.end_pts[0]) @@ -55,8 +56,7 @@ def _calculate_gaussian(x_val, height, lamtha): return x_height -def _get_gaussian_connection(src_pos, target_pos, nc_dict, - inplane_distance=1.): +def _get_gaussian_connection(src_pos, target_pos, nc_dict, inplane_distance=1.0): """Calculate distance dependent connection properties. Parameters @@ -88,15 +88,13 @@ def _get_gaussian_connection(src_pos, target_pos, nc_dict, cell_dist = np.sqrt(x_dist**2 + y_dist**2) scaled_lamtha = nc_dict['lamtha'] * inplane_distance - weight = _calculate_gaussian( - cell_dist, nc_dict['A_weight'], scaled_lamtha) - delay = nc_dict['A_delay'] / _calculate_gaussian( - cell_dist, 1, scaled_lamtha) + weight = _calculate_gaussian(cell_dist, nc_dict['A_weight'], scaled_lamtha) + delay = nc_dict['A_delay'] / _calculate_gaussian(cell_dist, 1, scaled_lamtha) return weight, delay def node_to_str(node): - return node[0] + "," + str(node[1]) + return node[0] + ',' + str(node[1]) class _ArtificialCell: @@ -127,6 +125,7 @@ class _ArtificialCell: gid : int GID of the cell in a network (or None if not yet assigned) """ + def __init__(self, event_times, threshold, gid=None): # Convert event times into nrn vector self.nrn_eventvec = h.Vector() @@ -160,8 +159,8 @@ def gid(self, gid): def _get_nseg(L): nseg = 1 - if L > 100.: # 100 um - nseg = int(L / 50.) + if L > 100.0: # 100 um + nseg = int(L / 50.0) # make dend.nseg odd for all sections if not nseg % 2: nseg += 1 @@ -205,8 +204,8 @@ class Section: nseg : int Number of segments in the section """ - def __init__(self, L, diam, Ra, cm, end_pts=None): + def __init__(self, L, diam, Ra, cm, end_pts=None): self._L = L self._diam = diam self._Ra = Ra @@ -238,8 +237,7 @@ def __eq__(self, other): # Check end_pts for self_end_pt, other_end_pt in zip(self.end_pts, other.end_pts): - if np.testing.assert_almost_equal(self_end_pt, - other_end_pt, 5) is not None: + if np.testing.assert_almost_equal(self_end_pt, other_end_pt, 5) is not None: return False all_attrs = dir(self) @@ -378,14 +376,15 @@ class Cell: ) """ - def __init__(self, name, pos, sections, synapses, sect_loc, cell_tree, - gid=None): + def __init__(self, name, pos, sections, synapses, sect_loc, cell_tree, gid=None): self.name = name self.pos = pos for section in sections.values(): if not isinstance(section, Section): - raise ValueError(f'Items in section must be instances' - f' of Section. Got {type(section)}') + raise ValueError( + f'Items in section must be instances' + f' of Section. 
Got {type(section)}' + ) self.sections = sections self.synapses = synapses self.sect_loc = sect_loc @@ -419,12 +418,25 @@ def __eq__(self, other): all_attrs = dir(self) attrs_to_ignore = [x for x in all_attrs if x.startswith('_')] - attrs_to_ignore.extend(['build', 'copy', 'create_tonic_bias', - 'define_shape', 'distance_section', 'gid', - 'list_IClamp', 'modify_section', - 'parconnect_from_src', 'plot_morphology', - 'record', 'sections', 'setup_source_netcon', - 'syn_create', 'to_dict']) + attrs_to_ignore.extend( + [ + 'build', + 'copy', + 'create_tonic_bias', + 'define_shape', + 'distance_section', + 'gid', + 'list_IClamp', + 'modify_section', + 'parconnect_from_src', + 'plot_morphology', + 'record', + 'sections', + 'setup_source_netcon', + 'syn_create', + 'to_dict', + ] + ) attrs_to_check = [x for x in all_attrs if x not in attrs_to_ignore] # Check all other attributes @@ -508,8 +520,7 @@ def distance_section(self, target_sec_name, curr_node): # Python version of the Neuron distance function # https://nrn.readthedocs.io/en/latest/python/modelspec/programmatic/topology/geometry.html#distance # noqa if self.cell_tree is None: - raise TypeError("distance_section() " - "cannot work with cell_tree as None.") + raise TypeError('distance_section() ' 'cannot work with cell_tree as None.') if curr_node not in self.cell_tree: return np.nan @@ -530,9 +541,11 @@ def distance_section(self, target_sec_name, curr_node): # Recursion to find distance for node in self.cell_tree[curr_node]: - if (node[0] == curr_node[0]): - dist_temp = (self.distance_section(target_sec_name, node) + - self.sections[node[0]].L) + if node[0] == curr_node[0]: + dist_temp = ( + self.distance_section(target_sec_name, node) + + self.sections[node[0]].L + ) else: dist_temp = self.distance_section(target_sec_name, node) if np.isnan(dist) and np.isnan(dist_temp): @@ -576,18 +589,15 @@ def _compute_section_mechs(self): for attr, val in p_mech.items(): if hasattr(val, '__call__'): seg_xs, seg_vals = list(), list() - section_distance = self.distance_section(sec_name, - ('soma', 0)) - seg_centers = (np.linspace(0, 1, section.nseg * 2 + 1) - [1::2]) + section_distance = self.distance_section(sec_name, ('soma', 0)) + seg_centers = np.linspace(0, 1, section.nseg * 2 + 1)[1::2] for seg_x in seg_centers: # sec_end_dist is distance between 0 end of soma to # the 0 or 1 end of section (whichever is closer) sec_end_dist = section_distance - (section.L / 2) seg_xs.append(seg_x) - seg_vals.append(val(sec_end_dist + - (seg_x * section.L))) + seg_vals.append(val(sec_end_dist + (seg_x * section.L))) p_mech[attr] = [seg_xs, seg_vals] return self.sections @@ -597,8 +607,7 @@ def _create_synapses(self, sections, synapses): for receptor in sections[sec_name].syns: syn_key = f'{sec_name}_{receptor}' seg = self._nrn_sections[sec_name](0.5) - self._nrn_synapses[syn_key] = self.syn_create( - seg, **synapses[receptor]) + self._nrn_synapses[syn_key] = self.syn_create(seg, **synapses[receptor]) def _create_sections(self, sections, cell_tree): """Create soma and set geometry. @@ -662,9 +671,11 @@ def build(self, sec_name_apical=None): if sec_name_apical in self._nrn_sections: self._insert_dipole(sec_name_apical) elif sec_name_apical is not None: - raise ValueError(f'sec_name_apical must be an existing ' - f'section of the current cell or None. ' - f'Got {sec_name_apical}.') + raise ValueError( + f'sec_name_apical must be an existing ' + f'section of the current cell or None. ' + f'Got {sec_name_apical}.' 
+ ) def copy(self): """Return copy of instance.""" @@ -749,7 +760,7 @@ def create_tonic_bias(self, amplitude, t0, tstop, loc=0.5): self.tonic_biases.append(stim) def record(self, record_vsec=False, record_isec=False, record_ca=False): - """ Record current and voltage from all sections + """Record current and voltage from all sections Parameters ---------- @@ -775,8 +786,7 @@ def record(self, record_vsec=False, record_isec=False, record_ca=False): if record_vsec: for sec_name in self.vsec: self.vsec[sec_name] = h.Vector() - self.vsec[sec_name].record( - self._nrn_sections[sec_name](0.5)._ref_v) + self.vsec[sec_name].record(self._nrn_sections[sec_name](0.5)._ref_v) if record_isec == 'soma': self.isec = dict.fromkeys(['soma']) @@ -785,14 +795,18 @@ def record(self, record_vsec=False, record_isec=False, record_ca=False): if record_isec: for sec_name in self.isec: - list_syn = [key for key in self._nrn_synapses.keys() - if key.startswith(f'{sec_name}_')] + list_syn = [ + key + for key in self._nrn_synapses.keys() + if key.startswith(f'{sec_name}_') + ] self.isec[sec_name] = dict.fromkeys(list_syn) for syn_name in self.isec[sec_name]: self.isec[sec_name][syn_name] = h.Vector() self.isec[sec_name][syn_name].record( - self._nrn_synapses[syn_name]._ref_i) + self._nrn_synapses[syn_name]._ref_i + ) # calcium concentration if record_ca == 'soma': @@ -804,8 +818,7 @@ def record(self, record_vsec=False, record_isec=False, record_ca=False): for sec_name in self.ca: if hasattr(self._nrn_sections[sec_name](0.5), '_ref_cai'): self.ca[sec_name] = h.Vector() - self.ca[sec_name].record( - self._nrn_sections[sec_name](0.5)._ref_cai) + self.ca[sec_name].record(self._nrn_sections[sec_name](0.5)._ref_cai) def syn_create(self, secloc, e, tau1, tau2): """Create an h.Exp2Syn synapse. @@ -827,8 +840,9 @@ def syn_create(self, secloc, e, tau1, tau2): A two state kinetic scheme synapse. """ if not isinstance(secloc, nrn.Segment): - raise TypeError(f'secloc must be instance of' - f'nrn.Segment. Got {type(secloc)}') + raise TypeError( + f'secloc must be instance of' f'nrn.Segment. Got {type(secloc)}' + ) syn = h.Exp2Syn(secloc) syn.e = e syn.tau1 = tau1 @@ -843,13 +857,13 @@ def setup_source_netcon(self, threshold): threshold : float The voltage threshold for action potential. """ - nc = h.NetCon(self._nrn_sections['soma'](0.5)._ref_v, None, - sec=self._nrn_sections['soma']) + nc = h.NetCon( + self._nrn_sections['soma'](0.5)._ref_v, None, sec=self._nrn_sections['soma'] + ) nc.threshold = threshold return nc - def parconnect_from_src(self, gid_presyn, nc_dict, postsyn, - inplane_distance): + def parconnect_from_src(self, gid_presyn, nc_dict, postsyn, inplane_distance): """Parallel receptor-centric connect FROM presyn TO this cell, based on GID. @@ -878,14 +892,21 @@ def parconnect_from_src(self, gid_presyn, nc_dict, postsyn, # set props here. nc.threshold = nc_dict['threshold'] nc.weight[0], nc.delay = _get_gaussian_connection( - nc_dict['pos_src'], self.pos, nc_dict, - inplane_distance=inplane_distance) + nc_dict['pos_src'], self.pos, nc_dict, inplane_distance=inplane_distance + ) return nc - def plot_morphology(self, ax=None, color=None, pos=(0, 0, 0), - xlim=(-250, 150), ylim=(-100, 100), zlim=(-100, 1200), - show=True): + def plot_morphology( + self, + ax=None, + color=None, + pos=(0, 0, 0), + xlim=(-250, 150), + ylim=(-100, 100), + zlim=(-100, 1200), + show=True, + ): """Plot the cell morphology. 
Parameters @@ -916,8 +937,16 @@ def plot_morphology(self, ax=None, color=None, pos=(0, 0, 0), axes : instance of Axes3D The matplotlib 3D axis handle. """ - return plot_cell_morphology(self, ax=ax, color=color, pos=pos, - xlim=xlim, ylim=ylim, zlim=zlim, show=show) + return plot_cell_morphology( + self, + ax=ax, + color=color, + pos=pos, + xlim=xlim, + ylim=ylim, + zlim=zlim, + show=show, + ) def _update_section_end_pts_L(self, node, dpt): if self.cell_tree is None: @@ -966,8 +995,7 @@ def define_shape(self, node): node_opp_end = 0 pts = self.sections[node[0]].end_pts x0, y0, z0 = pts[node[1]][0], pts[node[1]][1], pts[node[1]][2] - x1, y1, z1 = (pts[node_opp_end][0], pts[node_opp_end][1], - pts[node_opp_end][2]) + x1, y1, z1 = (pts[node_opp_end][0], pts[node_opp_end][1], pts[node_opp_end][2]) # Find the factor by which length is changed end_1 = np.array((x0, y0, z0)) @@ -1017,13 +1045,7 @@ def _update_end_pts(self): end_pts = self.sections[sec_name].end_pts updated_end_pts = list() for pt in end_pts: - updated_end_pts.append( - [ - pt[0] + dx, - pt[1] + dy, - pt[2] + dz - ] - ) + updated_end_pts.append([pt[0] + dx, pt[1] + dy, pt[2] + dz]) self.sections[sec_name]._end_pts = updated_end_pts # Check and update all end pts starting from root according to length diff --git a/hnn_core/cell_response.py b/hnn_core/cell_response.py index e81b20b1d..a9bf27327 100644 --- a/hnn_core/cell_response.py +++ b/hnn_core/cell_response.py @@ -80,8 +80,14 @@ class CellResponse(object): Write spiking activity to a collection of spike trial files. """ - def __init__(self, spike_times=None, spike_gids=None, spike_types=None, - times=None, cell_type_names=None): + def __init__( + self, + spike_times=None, + spike_gids=None, + spike_types=None, + times=None, + cell_type_names=None, + ): if spike_times is None: spike_times = list() if spike_gids is None: @@ -92,28 +98,28 @@ def __init__(self, spike_times=None, spike_gids=None, spike_types=None, times = list() if cell_type_names is None: - cell_type_names = ['L2_basket', 'L2_pyramidal', - 'L5_basket', 'L5_pyramidal'] + cell_type_names = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal'] # Validate arguments arg_names = ['spike_times', 'spike_gids', 'spike_types'] for arg_idx, arg in enumerate([spike_times, spike_gids, spike_types]): # Validate outer list if not isinstance(arg, list): - raise TypeError('%s should be a list of lists' - % (arg_names[arg_idx],)) + raise TypeError('%s should be a list of lists' % (arg_names[arg_idx],)) # If arg is not an empty list, validate inner list for trial_list in arg: if not isinstance(trial_list, list): - raise TypeError('%s should be a list of lists' - % (arg_names[arg_idx],)) + raise TypeError( + '%s should be a list of lists' % (arg_names[arg_idx],) + ) # Set the length of 'spike_times' as a references and validate # uniform length if arg == spike_times: n_trials = len(spike_times) if len(arg) != n_trials: - raise ValueError('spike times, gids, and types should be ' - 'lists of the same length') + raise ValueError( + 'spike times, gids, and types should be ' 'lists of the same length' + ) self._spike_times = spike_times self._spike_gids = spike_gids self._spike_types = spike_types @@ -135,19 +141,21 @@ def __eq__(self, other): if not isinstance(other, CellResponse): return NotImplemented # Round each time element - times_self = [[round(time, 3) for time in trial] - for trial in self._spike_times] - times_other = [[round(time, 3) for time in trial] - for trial in other._spike_times] - return (times_self == times_other 
and - self._spike_gids == other._spike_gids and - self._spike_types == other._spike_types and - self._vsec == other._vsec and - self._isec == other._isec and - self._ca == other._ca and - self.vsec == other.vsec and - self.isec == other.isec and - self.ca == other.ca) + times_self = [[round(time, 3) for time in trial] for trial in self._spike_times] + times_other = [ + [round(time, 3) for time in trial] for trial in other._spike_times + ] + return ( + times_self == times_other + and self._spike_gids == other._spike_gids + and self._spike_types == other._spike_types + and self._vsec == other._vsec + and self._isec == other._isec + and self._ca == other._ca + and self.vsec == other.vsec + and self.isec == other.isec + and self.ca == other.ca + ) @property def spike_times(self): @@ -156,8 +164,7 @@ def spike_times(self): @property def cell_types(self): """Get unique cell types.""" - spike_types_data = np.concatenate(np.array(self.spike_types, - dtype=object)) + spike_types_data = np.concatenate(np.array(self.spike_types, dtype=object)) return np.unique(spike_types_data).tolist() @property @@ -166,8 +173,9 @@ def spike_times_by_type(self): spike_times = dict() for cell_type in self.cell_types: spike_times[cell_type] = list() - for trial_spike_times, trial_spike_types in zip(self.spike_times, - self.spike_types): + for trial_spike_times, trial_spike_types in zip( + self.spike_times, self.spike_types + ): mask = np.isin(trial_spike_types, cell_type) cell_spike_times = np.array(trial_spike_times)[mask].tolist() spike_times[cell_type].append(cell_spike_times) @@ -215,13 +223,15 @@ def update_types(self, gid_ranges): gid_set_1 = set(all_gid_ranges[item_idx_1]) gid_set_2 = set(all_gid_ranges[item_idx_2]) if not gid_set_1.isdisjoint(gid_set_2): - raise ValueError('gid_ranges should contain only disjoint ' - 'sets of gid values') + raise ValueError( + 'gid_ranges should contain only disjoint ' 'sets of gid values' + ) spike_types = list() for trial_idx in range(len(self._spike_times)): - spike_types_trial = np.empty_like(self._spike_times[trial_idx], - dtype=' apical_trunk etc. 
middle = section_name.replace('_', '') dend_prop[key] = params[f'{cell_type}_{middle}_{key}'] - sections[section_name] = Section(L=dend_prop['L'], - diam=dend_prop['diam'], - Ra=dend_prop['Ra'], - cm=dend_prop['cm']) + sections[section_name] = Section( + L=dend_prop['L'], + diam=dend_prop['diam'], + Ra=dend_prop['Ra'], + cm=dend_prop['cm'], + ) return sections @@ -48,22 +49,28 @@ def _get_pyr_soma(p_all, cell_type): L=p_all[f'{cell_type}_soma_L'], diam=p_all[f'{cell_type}_soma_diam'], cm=p_all[f'{cell_type}_soma_cm'], - Ra=p_all[f'{cell_type}_soma_Ra'] + Ra=p_all[f'{cell_type}_soma_Ra'], ) -def _cell_L2Pyr(override_params, pos=(0., 0., 0), gid=0.): +def _cell_L2Pyr(override_params, pos=(0.0, 0.0, 0), gid=0.0): """The geometry of the default sections in L2Pyr neuron.""" p_all = get_L2Pyr_params_default() if override_params is not None: assert isinstance(override_params, dict) p_all = compare_dictionaries(p_all, override_params) - section_names = ['apical_trunk', 'apical_1', 'apical_tuft', - 'apical_oblique', 'basal_1', 'basal_2', 'basal_3'] - - sections = _get_dends(p_all, cell_type='L2Pyr', - section_names=section_names) + section_names = [ + 'apical_trunk', + 'apical_1', + 'apical_tuft', + 'apical_oblique', + 'basal_1', + 'basal_2', + 'basal_3', + ] + + sections = _get_dends(p_all, cell_type='L2Pyr', section_names=section_names) sections['soma'] = _get_pyr_soma(p_all, 'L2Pyr') end_pts = { @@ -79,11 +86,9 @@ def _cell_L2Pyr(override_params, pos=(0., 0., 0), gid=0.): mechanisms = { 'km': ['gbar_km'], - 'hh2': ['gkbar_hh2', 'gnabar_hh2', - 'gl_hh2', 'el_hh2'] + 'hh2': ['gkbar_hh2', 'gnabar_hh2', 'gl_hh2', 'el_hh2'], } - p_mech = _get_mechanisms(p_all, 'L2Pyr', ['soma'] + section_names, - mechanisms) + p_mech = _get_mechanisms(p_all, 'L2Pyr', ['soma'] + section_names, mechanisms) for sec_name, section in sections.items(): section._end_pts = end_pts[sec_name] @@ -109,22 +114,27 @@ def _cell_L2Pyr(override_params, pos=(0., 0., 0), gid=0.): ('soma', 1): [('apical_trunk', 0)], ('apical_trunk', 1): [('apical_1', 0), ('apical_oblique', 0)], ('apical_1', 1): [('apical_tuft', 0)], - ('basal_1', 1): [('basal_2', 0), ('basal_3', 0)] + ('basal_1', 1): [('basal_2', 0), ('basal_3', 0)], } - sect_loc = {'proximal': ['apical_oblique', 'basal_2', 'basal_3'], - 'distal': ['apical_tuft']} + sect_loc = { + 'proximal': ['apical_oblique', 'basal_2', 'basal_3'], + 'distal': ['apical_tuft'], + } synapses = _get_pyr_syn_props(p_all, 'L2Pyr') - return Cell('L2Pyr', pos, - sections=sections, - synapses=synapses, - sect_loc=sect_loc, - cell_tree=cell_tree, - gid=gid) + return Cell( + 'L2Pyr', + pos, + sections=sections, + synapses=synapses, + sect_loc=sect_loc, + cell_tree=cell_tree, + gid=gid, + ) -def _cell_L5Pyr(override_params, pos=(0., 0., 0), gid=0.): +def _cell_L5Pyr(override_params, pos=(0.0, 0.0, 0), gid=0.0): """The geometry of the default sections in L5Pyr Neuron.""" p_all = get_L5Pyr_params_default() @@ -132,12 +142,18 @@ def _cell_L5Pyr(override_params, pos=(0., 0., 0), gid=0.): assert isinstance(override_params, dict) p_all = compare_dictionaries(p_all, override_params) - section_names = ['apical_trunk', 'apical_1', - 'apical_2', 'apical_tuft', - 'apical_oblique', 'basal_1', 'basal_2', 'basal_3'] - - sections = _get_dends(p_all, cell_type='L5Pyr', - section_names=section_names) + section_names = [ + 'apical_trunk', + 'apical_1', + 'apical_2', + 'apical_tuft', + 'apical_oblique', + 'basal_1', + 'basal_2', + 'basal_3', + ] + + sections = _get_dends(p_all, cell_type='L5Pyr', section_names=section_names) 
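# NB: the soma is built separately from the dendritic sections assembled by
# _get_dends above; its geometry comes from the `L5Pyr_soma_*` parameters.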
sections['soma'] = _get_pyr_soma(p_all, 'L5Pyr') end_pts = { @@ -149,22 +165,20 @@ def _cell_L5Pyr(override_params, pos=(0., 0., 0), gid=0.): 'apical_tuft': [[0, 0, 883], [0, 0, 1133]], 'basal_1': [[0, 0, 0], [0, 0, -50]], 'basal_2': [[0, 0, -50], [-106, 0, -156]], - 'basal_3': [[0, 0, -50], [106, 0, -156]] + 'basal_3': [[0, 0, -50], [106, 0, -156]], } # units = ['pS/um^2', 'S/cm^2', 'pS/um^2', '??', 'tau', '??'] mechanisms = { - 'hh2': ['gkbar_hh2', 'gnabar_hh2', - 'gl_hh2', 'el_hh2'], + 'hh2': ['gkbar_hh2', 'gnabar_hh2', 'gl_hh2', 'el_hh2'], 'ca': ['gbar_ca'], 'cad': ['taur_cad'], 'kca': ['gbar_kca'], 'km': ['gbar_km'], 'cat': ['gbar_cat'], - 'ar': ['gbar_ar'] + 'ar': ['gbar_ar'], } - p_mech = _get_mechanisms(p_all, 'L5Pyr', ['soma'] + section_names, - mechanisms) + p_mech = _get_mechanisms(p_all, 'L5Pyr', ['soma'] + section_names, mechanisms) for sec_name, section in sections.items(): section._end_pts = end_pts[sec_name] @@ -177,9 +191,9 @@ def _cell_L5Pyr(override_params, pos=(0., 0., 0), gid=0.): section.mechs = p_mech[sec_name] if sec_name != 'soma': - sections[sec_name].mechs['ar']['gbar_ar'] = \ - partial(_exp_g_at_dist, zero_val=1e-6, - exp_term=3e-3, offset=0.0) + sections[sec_name].mechs['ar']['gbar_ar'] = partial( + _exp_g_at_dist, zero_val=1e-6, exp_term=3e-3, offset=0.0 + ) cell_tree = { ('apical_trunk', 0): [('apical_trunk', 1)], @@ -196,30 +210,29 @@ def _cell_L5Pyr(override_params, pos=(0., 0., 0), gid=0.): ('apical_trunk', 1): [('apical_1', 0), ('apical_oblique', 0)], ('apical_1', 1): [('apical_2', 0)], ('apical_2', 1): [('apical_tuft', 0)], - ('basal_1', 1): [('basal_2', 0), ('basal_3', 0)] + ('basal_1', 1): [('basal_2', 0), ('basal_3', 0)], } - sect_loc = {'proximal': ['apical_oblique', 'basal_2', 'basal_3'], - 'distal': ['apical_tuft']} + sect_loc = { + 'proximal': ['apical_oblique', 'basal_2', 'basal_3'], + 'distal': ['apical_tuft'], + } synapses = _get_pyr_syn_props(p_all, 'L5Pyr') - return Cell('L5Pyr', pos, - sections=sections, - synapses=synapses, - sect_loc=sect_loc, - cell_tree=cell_tree, - gid=gid) + return Cell( + 'L5Pyr', + pos, + sections=sections, + synapses=synapses, + sect_loc=sect_loc, + cell_tree=cell_tree, + gid=gid, + ) def _get_basket_soma(cell_name): - end_pts = [[0, 0, 0], [0, 0, 39.]] - return Section( - L=39., - diam=20., - cm=0.85, - Ra=200., - end_pts=end_pts - ) + end_pts = [[0, 0, 0], [0, 0, 39.0]] + return Section(L=39.0, diam=20.0, cm=0.85, Ra=200.0, end_pts=end_pts) def _get_pyr_syn_props(p_all, cell_type): @@ -243,27 +256,15 @@ def _get_pyr_syn_props(p_all, cell_type): 'e': p_all['%s_gabab_e' % cell_type], 'tau1': p_all['%s_gabab_tau1' % cell_type], 'tau2': p_all['%s_gabab_tau2' % cell_type], - } + }, } def _get_basket_syn_props(): return { - 'ampa': { - 'e': 0, - 'tau1': 0.5, - 'tau2': 5. - }, - 'gabaa': { - 'e': -80, - 'tau1': 0.5, - 'tau2': 5. - }, - 'nmda': { - 'e': 0, - 'tau1': 1., - 'tau2': 20. 
- } + 'ampa': {'e': 0, 'tau1': 0.5, 'tau2': 5.0}, + 'gabaa': {'e': -80, 'tau1': 0.5, 'tau2': 5.0}, + 'nmda': {'e': 0, 'tau1': 1.0, 'tau2': 20.0}, } @@ -354,12 +355,15 @@ def basket(cell_name, pos=(0, 0, 0), gid=None): sections['soma'].mechs = {'hh2': dict()} cell_tree = None - return Cell(cell_name, pos, - sections=sections, - synapses=synapses, - sect_loc=sect_loc, - cell_tree=cell_tree, - gid=gid) + return Cell( + cell_name, + pos, + sections=sections, + synapses=synapses, + sect_loc=sect_loc, + cell_tree=cell_tree, + gid=gid, + ) def pyramidal(cell_name, pos=(0, 0, 0), override_params=None, gid=None): @@ -417,20 +421,24 @@ def pyramidal_ca(cell_name, pos, override_params=None, gid=None): override_params['L5Pyr_soma_gkbar_hh2'] = 0.06 override_params['L5Pyr_soma_gnabar_hh2'] = 0.32 - gbar_ca = partial( - _linear_g_at_dist, gsoma=10., gdend=40., xkink=1501) + gbar_ca = partial(_linear_g_at_dist, gsoma=10.0, gdend=40.0, xkink=1501) gbar_na = partial( - _linear_g_at_dist, gsoma=override_params['L5Pyr_soma_gnabar_hh2'], - gdend=28e-4, xkink=962) + _linear_g_at_dist, + gsoma=override_params['L5Pyr_soma_gnabar_hh2'], + gdend=28e-4, + xkink=962, + ) gbar_k = partial( - _exp_g_at_dist, zero_val=override_params['L5Pyr_soma_gkbar_hh2'], - exp_term=-0.006, offset=1e-4) + _exp_g_at_dist, + zero_val=override_params['L5Pyr_soma_gkbar_hh2'], + exp_term=-0.006, + offset=1e-4, + ) override_params['L5Pyr_dend_gbar_ca'] = gbar_ca override_params['L5Pyr_dend_gnabar_hh2'] = gbar_na override_params['L5Pyr_dend_gkbar_hh2'] = gbar_k - cell = pyramidal(cell_name, pos, override_params=override_params, - gid=gid) + cell = pyramidal(cell_name, pos, override_params=override_params, gid=gid) return cell diff --git a/hnn_core/check.py b/hnn_core/check.py index 7e0416fa8..e9e511a4d 100644 --- a/hnn_core/check.py +++ b/hnn_core/check.py @@ -8,8 +8,9 @@ def _check_gids(gids, gid_ranges, valid_cells, arg_name, same_type=True): """Format different gid specifications into list of gids""" - _validate_type(gids, (int, list, range, str, None), arg_name, - 'int list, range, str, or None') + _validate_type( + gids, (int, list, range, str, None), arg_name, 'int list, range, str, or None' + ) # Convert gids to list if gids is None: @@ -28,8 +29,7 @@ def _check_gids(gids, gid_ranges, valid_cells, arg_name, same_type=True): _validate_type(gid, int, arg_name) gid_type = _gid_to_type(gid, gid_ranges) if gid_type is None: - raise AssertionError( - f'{arg_name} {gid} not in net.gid_ranges') + raise AssertionError(f'{arg_name} {gid} not in net.gid_ranges') if same_type and gid_type != cell_type: raise AssertionError(f'All {arg_name} must be of the same type') diff --git a/hnn_core/dipole.py b/hnn_core/dipole.py index 408acfb31..6eb83b247 100644 --- a/hnn_core/dipole.py +++ b/hnn_core/dipole.py @@ -15,8 +15,16 @@ from .viz import plot_dipole, plot_psd, plot_tfr_morlet -def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False, - record_isec=False, record_ca=False, postproc=False): +def simulate_dipole( + net, + tstop, + dt=0.025, + n_trials=None, + record_vsec=False, + record_isec=False, + record_ca=False, + postproc=False, +): """Simulate a dipole given the experiment parameters. 
Parameters @@ -63,13 +71,15 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False, if n_trials is None: n_trials = net._params['N_trials'] if n_trials < 1: - raise ValueError("Invalid number of simulations: %d" % n_trials) + raise ValueError('Invalid number of simulations: %d' % n_trials) if not net.connectivity: - warnings.warn('No connections instantiated in network. Consider using ' - 'net = jones_2009_model() or net = law_2021_model() to ' - 'create a predefined network from published models.', - UserWarning) + warnings.warn( + 'No connections instantiated in network. Consider using ' + 'net = jones_2009_model() or net = law_2021_model() to ' + 'create a predefined network from published models.', + UserWarning, + ) # ADD DRIVE WARNINGS HERE if not net.external_drives and not net.external_biases: warnings.warn('No external drives or biases loaded', UserWarning) @@ -82,10 +92,10 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False, for cell_type, bias_cell_type in bias.items(): if bias_cell_type['tstop'] is None: bias_cell_type['tstop'] = tstop - if bias_cell_type['tstop'] < 0.: + if bias_cell_type['tstop'] < 0.0: raise ValueError('End time of tonic input cannot be negative') duration = bias_cell_type['tstop'] - bias_cell_type['t0'] - if duration < 0.: + if duration < 0.0: raise ValueError('Duration of tonic input cannot be negative') net._instantiate_drives(n_trials=n_trials, tstop=tstop) @@ -108,10 +118,12 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False, net._dt = dt if postproc: - warnings.warn('The postproc-argument is deprecated and will be removed' - ' in a future release of hnn-core. Please define ' - 'smoothing and scaling explicitly using Dipole methods.', - DeprecationWarning) + warnings.warn( + 'The postproc-argument is deprecated and will be removed' + ' in a future release of hnn-core. Please define ' + 'smoothing and scaling explicitly using Dipole methods.', + DeprecationWarning, + ) dpls = _BACKEND.simulate(net, tstop, dt, n_trials, postproc) return dpls @@ -132,14 +144,14 @@ def _read_dipole_txt(fname, extension='.txt'): """ if extension == '.csv': # read from a csv file ignoring the headers - dpl_data = np.genfromtxt(fname, delimiter=',', - skip_header=1, dtype=float) + dpl_data = np.genfromtxt(fname, delimiter=',', skip_header=1, dtype=float) else: dpl_data = np.loadtxt(fname, dtype=float) ncols = dpl_data.shape[1] if ncols not in (2, 4): raise ValueError( - f'Data are supposed to have 2 or 4 columns while we have {ncols}.') + f'Data are supposed to have 2 or 4 columns while we have {ncols}.' + ) dpl = Dipole(dpl_data[:, 0], dpl_data[:, 1:]) return dpl @@ -160,16 +172,18 @@ def _read_dipole_hdf5(fname): dpl_data = read_hdf5(fname) if 'object_type' not in dpl_data: - raise NameError('The given file is not compatible. ' - 'The file should contain information' - ' about object type to be read.') + raise NameError( + 'The given file is not compatible. ' + 'The file should contain information' + ' about object type to be read.' + ) if dpl_data['object_type'] != 'Dipole': - raise ValueError('The object should be of type Dipole. ' - 'The file contains object of ' - 'type %s' % (dpl_data['object_type'],)) - dpl = Dipole(times=dpl_data['times'], - data=dpl_data['data'], - nave=dpl_data['nave']) + raise ValueError( + 'The object should be of type Dipole. 
' + 'The file contains object of ' + 'type %s' % (dpl_data['object_type'],) + ) + dpl = Dipole(times=dpl_data['times'], data=dpl_data['data'], nave=dpl_data['nave']) dpl.sfreq = dpl_data['sfreq'] dpl.scale_applied = dpl_data['scale_applied'] return dpl @@ -199,8 +213,10 @@ def read_dipole(fname): elif file_extension == '.hdf5': return _read_dipole_hdf5(fname) else: - raise NameError('File extension should be either txt or hdf5, but the ' - 'given extension is %s' % (file_extension,)) + raise NameError( + 'File extension should be either txt or hdf5, but the ' + 'given extension is %s' % (file_extension,) + ) def average_dipoles(dpls): @@ -224,19 +240,19 @@ def average_dipoles(dpls): raise RuntimeError('All dipoles must be scaled equally!') if not isinstance(dpl, Dipole): raise ValueError( - f"All elements in the list should be instances of " - f"Dipole. Got {type(dpl)}") + f'All elements in the list should be instances of ' + f'Dipole. Got {type(dpl)}' + ) if dpl.nave > 1: - raise ValueError("Dipole at index %d was already an average of %d" - " trials. Cannot reaverage" % - (dpl_idx, dpl.nave)) + raise ValueError( + 'Dipole at index %d was already an average of %d' + ' trials. Cannot reaverage' % (dpl_idx, dpl.nave) + ) avg_data = list() layers = dpl.data.keys() for layer in layers: - avg_data.append( - np.mean(np.array([dpl.data[layer] for dpl in dpls]), axis=0) - ) + avg_data.append(np.mean(np.array([dpl.data[layer] for dpl in dpls]), axis=0)) avg_data = np.c_[avg_data].T avg_dpl = Dipole(dpls[0].times, avg_data) # The averaged scale should equal all scals in the input dpl list. @@ -249,7 +265,7 @@ def average_dipoles(dpls): def _rmse(dpl, exp_dpl, tstart=0.0, tstop=0.0, weights=None): - """ Calculates RMSE between data in dpl and exp_dpl + """Calculates RMSE between data in dpl and exp_dpl Parameters ---------- dpl : instance of Dipole @@ -302,13 +318,13 @@ def _rmse(dpl, exp_dpl, tstart=0.0, tstop=0.0, weights=None): dpl1 = dpl.data['agg'][sim_start_index:sim_end_index] dpl2 = exp_dpl.data['agg'][exp_start_index:exp_end_index] - if (sim_length > exp_length): + if sim_length > exp_length: # downsample simulation timeseries to match exp data dpl1 = signal.resample(dpl1, exp_length) weights = signal.resample(weights, exp_length) indices = np.where(weights < 1e-4) weights[indices] = 0 - elif (sim_length < exp_length): + elif sim_length < exp_length: # downsample exp timeseries to match simulation data dpl2 = signal.resample(dpl2, sim_length) @@ -359,13 +375,12 @@ def __init__(self, times, data, nave=1): # noqa: D102 if data.ndim == 1: data = data[:, None] if data.shape[1] == 3: - self.data = {'agg': data[:, 0], 'L2': data[:, 1], - 'L5': data[:, 2]} + self.data = {'agg': data[:, 0], 'L2': data[:, 1], 'L5': data[:, 2]} elif data.shape[1] == 1: self.data = {'agg': data[:, 0]} self.nave = nave - self.sfreq = 1000. / (times[1] - times[0]) # NB assumes len > 1 + self.sfreq = 1000.0 / (times[1] - times[0]) # NB assumes len > 1 self.scale_applied = 1 # for visualisation def copy(self): @@ -438,8 +453,7 @@ def smooth(self, window_len): from .utils import smooth_waveform for key in self.data.keys(): - self.data[key] = smooth_waveform(self.data[key], window_len, - self.sfreq) + self.data[key] = smooth_waveform(self.data[key], window_len, self.sfreq) return self @@ -468,19 +482,25 @@ def savgol_filter(self, h_freq): A copy of the modified Dipole instance. 
""" from .utils import _savgol_filter + if h_freq < 0: raise ValueError('h_freq cannot be negative') elif h_freq > 0.5 * self.sfreq: - raise ValueError( - 'h_freq must be less than half the sample rate') + raise ValueError('h_freq must be less than half the sample rate') for key in self.data.keys(): - self.data[key] = _savgol_filter(self.data[key], - h_freq, - self.sfreq) + self.data[key] = _savgol_filter(self.data[key], h_freq, self.sfreq) return self - def plot(self, tmin=None, tmax=None, layer='agg', decim=None, ax=None, - color='k', show=True): + def plot( + self, + tmin=None, + tmax=None, + layer='agg', + decim=None, + ax=None, + color='k', + show=True, + ): """Simple layer-specific plot function. Parameters @@ -502,11 +522,29 @@ def plot(self, tmin=None, tmax=None, layer='agg', decim=None, ax=None, The matplotlib figure handle. """ - return plot_dipole(self, tmin=tmin, tmax=tmax, ax=ax, layer=layer, - decim=decim, color=color, show=show) + return plot_dipole( + self, + tmin=tmin, + tmax=tmax, + ax=ax, + layer=layer, + decim=decim, + color=color, + show=show, + ) - def plot_psd(self, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg', - color=None, label=None, ax=None, show=True): + def plot_psd( + self, + fmin=0, + fmax=None, + tmin=None, + tmax=None, + layer='agg', + color=None, + label=None, + ax=None, + show=True, + ): """Plot power spectral density (PSD) of dipole time course Applies `~scipy.signal.periodogram` from SciPy with @@ -547,14 +585,34 @@ def plot_psd(self, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg', fig : instance of matplotlib Figure The matplotlib figure handle. """ - return plot_psd(self, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, - layer=layer, color=color, label=label, ax=ax, - show=show) - - def plot_tfr_morlet(self, freqs, n_cycles=7., tmin=None, tmax=None, - layer='agg', decim=None, padding='zeros', ax=None, - colormap='inferno', colorbar=True, - colorbar_inside=False, show=True): + return plot_psd( + self, + fmin=fmin, + fmax=fmax, + tmin=tmin, + tmax=tmax, + layer=layer, + color=color, + label=label, + ax=ax, + show=show, + ) + + def plot_tfr_morlet( + self, + freqs, + n_cycles=7.0, + tmin=None, + tmax=None, + layer='agg', + decim=None, + padding='zeros', + ax=None, + colormap='inferno', + colorbar=True, + colorbar_inside=False, + show=True, + ): """Plot Morlet time-frequency representation of dipole time course NB: Calls `~mne.time_frequency.tfr_array_morlet`, so ``mne`` must be @@ -603,10 +661,20 @@ def plot_tfr_morlet(self, freqs, n_cycles=7., tmin=None, tmax=None, The matplotlib figure handle. """ return plot_tfr_morlet( - self, freqs, n_cycles=n_cycles, tmin=tmin, tmax=tmax, - layer=layer, decim=decim, padding=padding, ax=ax, - colormap=colormap, colorbar=colorbar, - colorbar_inside=colorbar_inside, show=show) + self, + freqs, + n_cycles=n_cycles, + tmin=tmin, + tmax=tmax, + layer=layer, + decim=decim, + padding=padding, + ax=ax, + colormap=colormap, + colorbar=colorbar, + colorbar_inside=colorbar_inside, + show=show, + ) def _baseline_renormalize(self, N_pyr_x, N_pyr_y): """Only baseline renormalize if the units are fAm. 
@@ -631,7 +699,7 @@ def _baseline_renormalize(self, N_pyr_x, N_pyr_y): dpl_offset = { # these values will be subtracted 'L2': N_pyr * 0.0443, - 'L5': N_pyr * -49.0502 + 'L5': N_pyr * -49.0502, # 'L5': N_pyr * -48.3642, # will be calculated next, this is a placeholder # 'agg': None, @@ -647,15 +715,17 @@ def _baseline_renormalize(self, N_pyr_x, N_pyr_y): m = 3.4770508e-3 b = -51.231085 # these values were fit over the range [750., 5000] - t1 = 750. + t1 = 750.0 m1 = 1.01e-4 b1 = -48.412078 # piecewise normalization - self.data['L5'][self.times <= 37.] -= dpl_offset['L5'] - self.data['L5'][(self.times > 37.) & (self.times < t1)] -= N_pyr * \ - (m * self.times[(self.times > 37.) & (self.times < t1)] + b) - self.data['L5'][self.times >= t1] -= N_pyr * \ - (m1 * self.times[self.times >= t1] + b1) + self.data['L5'][self.times <= 37.0] -= dpl_offset['L5'] + self.data['L5'][(self.times > 37.0) & (self.times < t1)] -= N_pyr * ( + m * self.times[(self.times > 37.0) & (self.times < t1)] + b + ) + self.data['L5'][self.times >= t1] -= N_pyr * ( + m1 * self.times[self.times >= t1] + b1 + ) # recalculate the aggregate dipole based on the baseline # normalized ones self.data['agg'] = self.data['L2'] + self.data['L5'] @@ -678,13 +748,18 @@ def _write_txt(self, fname): 4) L5 current dipole (scaled nAm) """ - warnings.warn('Writing dipole to txt file is deprecated ' - 'and will be removed in future versions. ' - 'Please use hdf5', DeprecationWarning, stacklevel=2) + warnings.warn( + 'Writing dipole to txt file is deprecated ' + 'and will be removed in future versions. ' + 'Please use hdf5', + DeprecationWarning, + stacklevel=2, + ) if self.nave > 1: - warnings.warn("Saving Dipole to file that is an average of %d" - " trials" % self.nave) + warnings.warn( + 'Saving Dipole to file that is an average of %d' ' trials' % self.nave + ) X = [self.times] fmt = ['%3.3f'] @@ -708,7 +783,7 @@ def _write_hdf5(self, fname): """ print(f'Writing file {fname}') dpl_data = dict() - dpl_data['object_type'] = "Dipole" + dpl_data['object_type'] = 'Dipole' dpl_data['times'] = self.times dpl_data['sfreq'] = self.sfreq dpl_data['nave'] = self.nave @@ -741,13 +816,17 @@ def write(self, fname, overwrite=True): fname = str(fname) if overwrite is False and os.path.exists(fname): - raise FileExistsError('File already exists at path %s. Rename ' - 'the file or set overwrite=True.' % (fname,)) + raise FileExistsError( + 'File already exists at path %s. Rename ' + 'the file or set overwrite=True.' % (fname,) + ) file_extension = os.path.splitext(fname)[-1] if file_extension == '.txt': self._write_txt(fname) elif file_extension == '.hdf5': self._write_hdf5(fname) else: - raise NameError('File extension should be either txt or hdf5, but ' - 'the given extension is %s.' % (file_extension,)) + raise NameError( + 'File extension should be either txt or hdf5, but ' + 'the given extension is %s.' % (file_extension,) + ) diff --git a/hnn_core/docs.py b/hnn_core/docs.py index 93d67c33f..da7f7da54 100644 --- a/hnn_core/docs.py +++ b/hnn_core/docs.py @@ -3,47 +3,35 @@ docdict = dict() # Define docdicts -docdict[ - "net" -] = """ +docdict['net'] = """ net : Instance of Network object The Network object. """ -docdict[ - "fname" -] = """ +docdict['fname'] = """ fname : str | Path object Full path to the output file (.hdf5). """ -docdict[ - "overwrite" -] = """ +docdict['overwrite'] = """ overwrite : Boolean True : Overwrite existing file. False : Throw error if file already exists. 
""" -docdict[ - "write_output" -] = """ +docdict['write_output'] = """ write_output : Boolean True : Save the Network simulation output. False : Do not save the Network simulation output. """ -docdict[ - "read_output" -] = """ +docdict['read_output'] = """ read_output : Boolean True : Read network with simulation results. False : Read network without simulation results. """ -docdict[ - "read_drives" -] = """ +docdict['read_drives'] = """ read_output : Boolean True : Read drives from configuration file. False : Do not read drives from the configuration file. diff --git a/hnn_core/drives.py b/hnn_core/drives.py index 4b48b7613..989adc9f7 100644 --- a/hnn_core/drives.py +++ b/hnn_core/drives.py @@ -6,12 +6,15 @@ import numpy as np -from .params import (_extract_bias_specs_from_hnn_params, - _extract_drive_specs_from_hnn_params) +from .params import ( + _extract_bias_specs_from_hnn_params, + _extract_drive_specs_from_hnn_params, +) -def _get_target_properties(weights_ampa, weights_nmda, synaptic_delays, - location, probability=1.0): +def _get_target_properties( + weights_ampa, weights_nmda, synaptic_delays, location, probability=1.0 +): """Retrieve drive properties associated with each target cell type Note that target cell types of a drive are inferred from the synaptic @@ -24,8 +27,10 @@ def _get_target_properties(weights_ampa, weights_nmda, synaptic_delays, if weights_nmda is None: weights_nmda = dict() - weights_by_type = {cell_type: dict() for cell_type in - (set(weights_ampa.keys()) | set(weights_nmda.keys()))} + weights_by_type = { + cell_type: dict() + for cell_type in (set(weights_ampa.keys()) | set(weights_nmda.keys())) + } for cell_type in weights_ampa: weights_by_type[cell_type].update({'ampa': weights_ampa[cell_type]}) for cell_type in weights_nmda: @@ -33,77 +38,87 @@ def _get_target_properties(weights_ampa, weights_nmda, synaptic_delays, target_populations = set(weights_by_type) if not target_populations: - raise ValueError('No target cell types have been given a synaptic ' - 'weight for this drive.') + raise ValueError( + 'No target cell types have been given a synaptic ' 'weight for this drive.' + ) # Distal drives should not target L5 basket cells according to the # canonical Jones model if location == 'distal' and 'L5_basket' in target_populations: - raise ValueError('Due to physiological/anatomical constraints, ' - 'a distal drive cannot target L5_basket cell types. ' - 'L5_basket cell types must remain undefined by ' - 'the user in all synaptic weights dictionaries ' - 'for this drive. ' - 'Therefore, please remove the L5_basket entries ' - 'from the corresponding dictionaries.') + raise ValueError( + 'Due to physiological/anatomical constraints, ' + 'a distal drive cannot target L5_basket cell types. ' + 'L5_basket cell types must remain undefined by ' + 'the user in all synaptic weights dictionaries ' + 'for this drive. ' + 'Therefore, please remove the L5_basket entries ' + 'from the corresponding dictionaries.' 
+ ) if isinstance(synaptic_delays, float): - delays_by_type = {cell_type: synaptic_delays for cell_type in - target_populations} + delays_by_type = { + cell_type: synaptic_delays for cell_type in target_populations + } else: delays_by_type = synaptic_delays.copy() if set(delays_by_type.keys()) != target_populations: - raise ValueError('synaptic_delays is either a common float or needs ' - 'to be specified as a dict for each of the cell ' - 'types defined in weights_ampa and weights_nmda ' - f'({target_populations})') + raise ValueError( + 'synaptic_delays is either a common float or needs ' + 'to be specified as a dict for each of the cell ' + 'types defined in weights_ampa and weights_nmda ' + f'({target_populations})' + ) if isinstance(probability, float): - probability_by_type = {cell_type: probability for cell_type in - target_populations} + probability_by_type = { + cell_type: probability for cell_type in target_populations + } else: probability_by_type = probability.copy() if set(probability_by_type.keys()) != target_populations: - raise ValueError('probability is either a common float or needs ' - 'to be specified as a dict for each of the cell ' - 'types defined in weights_ampa and weights_nmda ' - f'({target_populations})') + raise ValueError( + 'probability is either a common float or needs ' + 'to be specified as a dict for each of the cell ' + 'types defined in weights_ampa and weights_nmda ' + f'({target_populations})' + ) - return (target_populations, weights_by_type, delays_by_type, - probability_by_type) + return (target_populations, weights_by_type, delays_by_type, probability_by_type) def _check_drive_parameter_values(drive_type, **kwargs): if 'tstop' in kwargs: if kwargs['tstop'] is not None: - if kwargs['tstop'] < 0.: - raise ValueError(f'End time of {drive_type} drive cannot be ' - 'negative') + if kwargs['tstop'] < 0.0: + raise ValueError( + f'End time of {drive_type} drive cannot be ' 'negative' + ) if 'tstart' in kwargs and kwargs['tstop'] < kwargs['tstart']: - raise ValueError(f'Duration of {drive_type} drive cannot be ' - 'negative') + raise ValueError( + f'Duration of {drive_type} drive cannot be ' 'negative' + ) if 'sigma' in kwargs: - if kwargs['sigma'] < 0.: + if kwargs['sigma'] < 0.0: raise ValueError('Standard deviation cannot be negative') if 'numspikes' in kwargs: if not kwargs['numspikes'] > 0: raise ValueError('Number of spikes must be greater than zero') if 'tstart' in kwargs: if kwargs['tstart'] < 0: - raise ValueError(f'Start time of {drive_type} drive cannot be ' - 'negative') + raise ValueError(f'Start time of {drive_type} drive cannot be ' 'negative') - if ('numspikes' in kwargs and 'spike_isi' in kwargs and - 'burst_rate' in kwargs): + if 'numspikes' in kwargs and 'spike_isi' in kwargs and 'burst_rate' in kwargs: n_spikes = kwargs['numspikes'] isi = kwargs['spike_isi'] - burst_period = 1000. 
/ kwargs['burst_rate']
+        burst_period = 1000.0 / kwargs['burst_rate']
         burst_duration = (n_spikes - 1) * isi
         if burst_duration > burst_period:
-            raise ValueError(f'Burst duration ({burst_duration}s) cannot'
-                             f' be greater than burst period ({burst_period}s)'
-                             'Consider increasing the spike ISI or burst rate')
+            raise ValueError(
+                f'Burst duration ({burst_duration} ms) cannot'
+                f' be greater than burst period ({burst_period} ms). '
+                'Consider increasing the spike ISI or burst rate.'
+            )


 def _check_poisson_rates(rate_constant, target_populations, all_cell_types):
@@ -111,63 +126,74 @@
         constants_provided = set(rate_constant.keys())
         if not target_populations.issubset(constants_provided):
             raise ValueError(
-                f"Rate constants not provided for all target cell "
-                f"populations ({target_populations})")
+                f'Rate constants not provided for all target cell '
+                f'populations ({target_populations})'
+            )
         if not constants_provided.issubset(all_cell_types):
             offending_keys = constants_provided.difference(all_cell_types)
             raise ValueError(
-                f"Rate constant provided for unknown target cell "
-                f"population: {offending_keys}")
+                f'Rate constant provided for unknown target cell '
+                f'population: {offending_keys}'
+            )
         for key, val in rate_constant.items():
-            if not val > 0.:
-                raise ValueError(
-                    f"Rate constant must be positive ({key}, {val})")
+            if not val > 0.0:
+                raise ValueError(f'Rate constant must be positive ({key}, {val})')
     else:
-        if not rate_constant > 0.:
-            raise ValueError(
-                f"Rate constant must be positive, got {rate_constant}")
+        if not rate_constant > 0.0:
+            raise ValueError(f'Rate constant must be positive, got {rate_constant}')


 def _add_drives_from_params(net):
     drive_specs = _extract_drive_specs_from_hnn_params(
-        net._params, list(net.cell_types.keys()), net._legacy_mode)
+        net._params, list(net.cell_types.keys()), net._legacy_mode
+    )
     bias_specs = _extract_bias_specs_from_hnn_params(
-        net._params, list(net.cell_types.keys()))
+        net._params, list(net.cell_types.keys())
+    )

     for drive_name in sorted(drive_specs.keys()):  # order matters
         specs = drive_specs[drive_name]
         if specs['type'] == 'evoked':
             net.add_evoked_drive(
-                drive_name, mu=specs['dynamics']['mu'],
+                drive_name,
+                mu=specs['dynamics']['mu'],
                 sigma=specs['dynamics']['sigma'],
                 numspikes=specs['dynamics']['numspikes'],
                 n_drive_cells=specs['dynamics']['n_drive_cells'],
                 cell_specific=specs['cell_specific'],
                 weights_ampa=specs['weights_ampa'],
                 weights_nmda=specs['weights_nmda'],
-                location=specs['location'], event_seed=specs['event_seed'],
+                location=specs['location'],
+                event_seed=specs['event_seed'],
                 synaptic_delays=specs['synaptic_delays'],
-                space_constant=specs['space_constant'])
+                space_constant=specs['space_constant'],
+            )
         elif specs['type'] == 'poisson':
             net.add_poisson_drive(
-                drive_name, tstart=specs['dynamics']['tstart'],
+                drive_name,
+                tstart=specs['dynamics']['tstart'],
                 tstop=specs['dynamics']['tstop'],
                 rate_constant=specs['dynamics']['rate_constant'],
                 weights_ampa=specs['weights_ampa'],
                 weights_nmda=specs['weights_nmda'],
-                location=specs['location'], event_seed=specs['event_seed'],
+                location=specs['location'],
+                event_seed=specs['event_seed'],
                 synaptic_delays=specs['synaptic_delays'],
-                space_constant=specs['space_constant'])
+                space_constant=specs['space_constant'],
+            )
         elif specs['type'] == 'gaussian':
             net.add_evoked_drive(  # 'gaussian' is just evoked
-                drive_name, mu=specs['dynamics']['mu'],
+                drive_name,
+                mu=specs['dynamics']['mu'],
sigma=specs['dynamics']['sigma'], numspikes=specs['dynamics']['numspikes'], weights_ampa=specs['weights_ampa'], weights_nmda=specs['weights_nmda'], - location=specs['location'], event_seed=specs['event_seed'], + location=specs['location'], + event_seed=specs['event_seed'], synaptic_delays=specs['synaptic_delays'], - space_constant=specs['space_constant']) + space_constant=specs['space_constant'], + ) elif specs['type'] == 'bursty': net.add_bursty_drive( drive_name, @@ -185,21 +211,20 @@ def _add_drives_from_params(net): location=specs['location'], space_constant=specs['space_constant'], synaptic_delays=specs['synaptic_delays'], - event_seed=specs['event_seed']) + event_seed=specs['event_seed'], + ) # add tonic biases if present in params if bias_specs['tonic']: _cell_types_amplitudes = dict() for cellname in bias_specs['tonic']: - _cell_types_amplitudes[cellname] = ( - bias_specs['tonic'][cellname]['amplitude']) + _cell_types_amplitudes[cellname] = bias_specs['tonic'][cellname][ + 'amplitude' + ] _t0 = bias_specs['tonic'][cellname]['t0'] _tstop = bias_specs['tonic'][cellname]['tstop'] - net.add_tonic_bias( - amplitude=_cell_types_amplitudes, - t0=_t0, - tstop=_tstop) + net.add_tonic_bias(amplitude=_cell_types_amplitudes, t0=_t0, tstop=_tstop) # in HNN-GUI, seed is determined by "absolute GID" instead of the # gid offset with respect to the first cell of a population. @@ -230,9 +255,16 @@ def _get_prng(seed, gid): return np.random.RandomState(seed + gid), np.random.RandomState(seed) -def _drive_cell_event_times(drive_type, dynamics, tstop, target_type='any', - trial_idx=0, drive_cell_gid=0, event_seed=0, - trial_seed_offset=0): +def _drive_cell_event_times( + drive_type, + dynamics, + tstop, + target_type='any', + trial_idx=0, + drive_cell_gid=0, + event_seed=0, + trial_seed_offset=0, +): """Generate event times for one artificial drive cell based on dynamics. Parameters @@ -265,8 +297,9 @@ def _drive_cell_event_times(drive_type, dynamics, tstop, target_type='any', event_times : list The event times at which spikes occur. """ - prng, prng2 = _get_prng(seed=event_seed + trial_idx * trial_seed_offset, - gid=drive_cell_gid) + prng, prng2 = _get_prng( + seed=event_seed + trial_idx * trial_seed_offset, gid=drive_cell_gid + ) # check drive name validity, allowing substring matches valid_drives = ['evoked', 'poisson', 'gaussian', 'bursty'] @@ -290,13 +323,15 @@ def _drive_cell_event_times(drive_type, dynamics, tstop, target_type='any', t0=dynamics['tstart'], T=dynamics['tstop'], lamtha=rate_constant, - prng=prng) + prng=prng, + ) elif drive_type == 'evoked' or drive_type == 'gaussian': event_times = _create_gauss( mu=dynamics['mu'], sigma=dynamics['sigma'], numspikes=dynamics['numspikes'], - prng=prng) + prng=prng, + ) elif drive_type == 'bursty': event_times = _create_bursty_input( t0=dynamics['tstart'], @@ -307,13 +342,13 @@ def _drive_cell_event_times(drive_type, dynamics, tstop, target_type='any', events_per_cycle=dynamics['numspikes'], cycle_events_isi=dynamics['spike_isi'], prng=prng, - prng2=prng2) + prng2=prng2, + ) # brute force remove non-zero times. Might result in fewer vals # than desired # values MUST be sorted for VecStim()! 
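Aside — illustration only, not part of the patch: the hunk continues below with the clipping and sorting of the generated event times. Both steps matter: the times feed NEURON's VecStim, which requires sorted input, and restricting to (0, tstop] can return fewer events than requested. A self-contained sketch of the evoked/'gaussian' path under those constraints, assuming only NumPy:

import numpy as np

def evoked_event_times(mu, sigma, numspikes, tstop, seed=0):
    # Gaussian spike times, as in _create_gauss
    prng = np.random.RandomState(seed)
    event_times = prng.normal(mu, sigma, numspikes)
    # keep only events in (0, tstop]; may yield fewer than numspikes
    event_times = event_times[(event_times > 0) & (event_times <= tstop)]
    event_times.sort()  # VecStim requires sorted times
    return event_times.tolist()

print(evoked_event_times(mu=20.0, sigma=3.0, numspikes=5, tstop=170.0, seed=1))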
-    event_times = event_times[np.logical_and(event_times > 0,
-                                             event_times <= tstop)]
+    event_times = event_times[np.logical_and(event_times > 0, event_times <= tstop)]
     event_times.sort()
     event_times = event_times.tolist()

@@ -341,18 +376,21 @@
     """
     # see: http://www.cns.nyu.edu/~david/handouts/poisson.pdf
     if t0 < 0:
-        raise ValueError('The start time for Poisson inputs must be'
-                         f'greater than 0. Got {t0}')
+        raise ValueError(
+            'The start time for Poisson inputs must be ' f'greater than 0. Got {t0}'
+        )
     if T < t0:
-        raise ValueError('The end time for Poisson inputs must be'
-                         f'greater than start time. Got ({t0}, {T})')
-    if lamtha <= 0.:
+        raise ValueError(
+            'The end time for Poisson inputs must be '
+            f'greater than start time. Got ({t0}, {T})'
+        )
+    if lamtha <= 0.0:
         raise ValueError(f'Rate must be > 0. Got {lamtha}')

     event_times = list()
     t_gen = t0
     while t_gen < T:
-        t_gen += prng.exponential(1. / lamtha) * 1000.
+        t_gen += prng.exponential(1.0 / lamtha) * 1000.0
         if t_gen < T:
             event_times.append(t_gen)

@@ -381,9 +419,18 @@
     return prng.normal(mu, sigma, numspikes)


-def _create_bursty_input(*, t0, t0_stdev, tstop, f_input,
-                         events_jitter_std, events_per_cycle=2,
-                         cycle_events_isi=10, prng, prng2):
+def _create_bursty_input(
+    *,
+    t0,
+    t0_stdev,
+    tstop,
+    f_input,
+    events_jitter_std,
+    events_per_cycle=2,
+    cycle_events_isi=10,
+    prng,
+    prng2,
+):
     """Creates the bursty ongoing external inputs.

-    Used for, e.g., for rhythmic inputs in alpha/beta generation.
+    Used, e.g., for rhythmic inputs in alpha/beta generation.
@@ -421,12 +468,14 @@
     if t0_stdev > 0.0:
         t0 = prng2.normal(t0, t0_stdev)

-    burst_period = 1000. / f_input
+    burst_period = 1000.0 / f_input
     burst_duration = (events_per_cycle - 1) * cycle_events_isi
     if burst_duration > burst_period:
-        raise ValueError(f'Burst duration ({burst_duration}s) cannot'
-                         f' be greater than burst period ({burst_period}s)'
-                         'Consider increasing the spike ISI or burst rate')
+        raise ValueError(
+            f'Burst duration ({burst_duration} ms) cannot'
+            f' be greater than burst period ({burst_period} ms). '
+            'Consider increasing the spike ISI or burst rate.'
+        )

     # array of mean stimulus times, starts at t0
     isi_array = np.arange(t0, tstop, burst_period)
@@ -434,7 +483,7 @@
     t_array = prng.normal(isi_array, events_jitter_std)

     if events_per_cycle > 1:
-        cycle = (np.arange(events_per_cycle) - (events_per_cycle - 1) / 2)
+        cycle = np.arange(events_per_cycle) - (events_per_cycle - 1) / 2
         t_array = np.ravel([t_array + cycle_events_isi * cyc for cyc in cycle])

     return t_array
diff --git a/hnn_core/externals/bayesopt.py b/hnn_core/externals/bayesopt.py
index 318bd38a3..bb82755ae 100644
--- a/hnn_core/externals/bayesopt.py
+++ b/hnn_core/externals/bayesopt.py
@@ -38,15 +38,14 @@ def expected_improvement(gp, best_f, all_x):
     """
     with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
+        warnings.simplefilter('ignore')
         # (n_samples, n_features)
         y, y_std = gp.predict(all_x, return_std=True)
     Z = (y - best_f) / (y_std + 1e-12)
     return (y - best_f) * st.norm.cdf(Z) + y_std * st.norm.pdf(Z)


-def bayes_opt(func, x0, cons, acquisition, maxfun=200,
-              debug=False, random_state=None):
+def bayes_opt(func, x0, cons, acquisition, maxfun=200, debug=False, random_state=None):
     """The actual bayesian optimization function.
Parameters @@ -86,14 +85,15 @@ def bayes_opt(func, x0, cons, acquisition, maxfun=200, gp = gaussian_process.GaussianProcessRegressor(random_state=random_state) if debug: - print("iter", -1, "best_x", best_x, best_f) + print('iter', -1, 'best_x', best_x, best_f) for i in range(maxfun): - # draw samples from distribution - all_x = np.random.uniform(low=[idx[0] for idx in cons], - high=[idx[1] for idx in cons], - size=(10000, len(cons))) + all_x = np.random.uniform( + low=[idx[0] for idx in cons], + high=[idx[1] for idx in cons], + size=(10000, len(cons)), + ) gp.fit(np.array(X), np.array(y)) # (n_samples, n_features) @@ -111,7 +111,7 @@ def bayes_opt(func, x0, cons, acquisition, maxfun=200, best_x = new_x if debug: - print("iter", i, "best_x", best_x, best_f) + print('iter', i, 'best_x', best_x, best_f) return best_x, best_f @@ -119,11 +119,14 @@ def bayes_opt(func, x0, cons, acquisition, maxfun=200, if __name__ == '__main__': from scipy.optimize import rosen - opt_params, obj_vals = bayes_opt(rosen, - [0.5, 0.6], [(-1, 1), (-1, 1)], - expected_improvement, - maxfun=200, - random_state=1) + opt_params, obj_vals = bayes_opt( + rosen, + [0.5, 0.6], + [(-1, 1), (-1, 1)], + expected_improvement, + maxfun=200, + random_state=1, + ) x = np.linspace(-1, 1, 50) y = np.linspace(-1, 1, 50) diff --git a/hnn_core/externals/mne.py b/hnn_core/externals/mne.py index 864278431..92196d0f0 100644 --- a/hnn_core/externals/mne.py +++ b/hnn_core/externals/mne.py @@ -50,20 +50,178 @@ def next_fast_len(target): Copied from SciPy with minor modifications. """ from bisect import bisect_left - hams = (8, 9, 10, 12, 15, 16, 18, 20, 24, 25, 27, 30, 32, 36, 40, 45, 48, - 50, 54, 60, 64, 72, 75, 80, 81, 90, 96, 100, 108, 120, 125, 128, - 135, 144, 150, 160, 162, 180, 192, 200, 216, 225, 240, 243, 250, - 256, 270, 288, 300, 320, 324, 360, 375, 384, 400, 405, 432, 450, - 480, 486, 500, 512, 540, 576, 600, 625, 640, 648, 675, 720, 729, - 750, 768, 800, 810, 864, 900, 960, 972, 1000, 1024, 1080, 1125, - 1152, 1200, 1215, 1250, 1280, 1296, 1350, 1440, 1458, 1500, 1536, - 1600, 1620, 1728, 1800, 1875, 1920, 1944, 2000, 2025, 2048, 2160, - 2187, 2250, 2304, 2400, 2430, 2500, 2560, 2592, 2700, 2880, 2916, - 3000, 3072, 3125, 3200, 3240, 3375, 3456, 3600, 3645, 3750, 3840, - 3888, 4000, 4050, 4096, 4320, 4374, 4500, 4608, 4800, 4860, 5000, - 5120, 5184, 5400, 5625, 5760, 5832, 6000, 6075, 6144, 6250, 6400, - 6480, 6561, 6750, 6912, 7200, 7290, 7500, 7680, 7776, 8000, 8100, - 8192, 8640, 8748, 9000, 9216, 9375, 9600, 9720, 10000) + + hams = ( + 8, + 9, + 10, + 12, + 15, + 16, + 18, + 20, + 24, + 25, + 27, + 30, + 32, + 36, + 40, + 45, + 48, + 50, + 54, + 60, + 64, + 72, + 75, + 80, + 81, + 90, + 96, + 100, + 108, + 120, + 125, + 128, + 135, + 144, + 150, + 160, + 162, + 180, + 192, + 200, + 216, + 225, + 240, + 243, + 250, + 256, + 270, + 288, + 300, + 320, + 324, + 360, + 375, + 384, + 400, + 405, + 432, + 450, + 480, + 486, + 500, + 512, + 540, + 576, + 600, + 625, + 640, + 648, + 675, + 720, + 729, + 750, + 768, + 800, + 810, + 864, + 900, + 960, + 972, + 1000, + 1024, + 1080, + 1125, + 1152, + 1200, + 1215, + 1250, + 1280, + 1296, + 1350, + 1440, + 1458, + 1500, + 1536, + 1600, + 1620, + 1728, + 1800, + 1875, + 1920, + 1944, + 2000, + 2025, + 2048, + 2160, + 2187, + 2250, + 2304, + 2400, + 2430, + 2500, + 2560, + 2592, + 2700, + 2880, + 2916, + 3000, + 3072, + 3125, + 3200, + 3240, + 3375, + 3456, + 3600, + 3645, + 3750, + 3840, + 3888, + 4000, + 4050, + 4096, + 4320, + 4374, + 4500, + 4608, + 4800, + 4860, + 5000, + 5120, + 
5184, + 5400, + 5625, + 5760, + 5832, + 6000, + 6075, + 6144, + 6250, + 6400, + 6480, + 6561, + 6750, + 6912, + 7200, + 7290, + 7500, + 7680, + 7776, + 8000, + 8100, + 8192, + 8640, + 8748, + 9000, + 9216, + 9375, + 9600, + 9720, + 10000, + ) if target <= 6: return target @@ -182,21 +340,34 @@ def _validate_type(item, types=None, item_name=None, type_name=None): The types to be checked against. If str, must be one of {'int', 'str', 'numeric', 'path-like'}. """ - if types == "int": + if types == 'int': _ensure_int(item, name=item_name) return # terminate prematurely if not isinstance(types, (list, tuple)): types = [types] - check_types = sum(((type(None),) if type_ is None else (type_,) - if not isinstance(type_, str) else _multi[type_] - for type_ in types), ()) + check_types = sum( + ( + (type(None),) + if type_ is None + else (type_,) + if not isinstance(type_, str) + else _multi[type_] + for type_ in types + ), + (), + ) if not isinstance(item, check_types): if type_name is None: - type_name = ['None' if cls_ is None else cls_.__name__ - if not isinstance(cls_, str) else cls_ - for cls_ in types] + type_name = [ + 'None' + if cls_ is None + else cls_.__name__ + if not isinstance(cls_, str) + else cls_ + for cls_ in types + ] if len(type_name) == 1: type_name = type_name[0] elif len(type_name) == 2: @@ -204,8 +375,14 @@ def _validate_type(item, types=None, item_name=None, type_name=None): else: type_name[-1] = 'or ' + type_name[-1] type_name = ', '.join(type_name) - raise TypeError('%s must be an instance of %s, got %s instead' - % (item_name, type_name, type(item),)) + raise TypeError( + '%s must be an instance of %s, got %s instead' + % ( + item_name, + type_name, + type(item), + ) + ) def _check_option(parameter, value, allowed_values, extra=''): @@ -237,8 +414,10 @@ def _check_option(parameter, value, allowed_values, extra=''): # Prepare a nice error message for the user extra = ' ' + extra if extra else extra - msg = ("Invalid value for the '{parameter}' parameter{extra}. " - '{options}, but got {value!r} instead.') + msg = ( + "Invalid value for the '{parameter}' parameter{extra}. " + '{options}, but got {value!r} instead.' 
+ ) allowed_values = list(allowed_values) # e.g., if a dict was given if len(allowed_values) == 1: options = f'The only allowed value is {repr(allowed_values[0])}' @@ -246,8 +425,9 @@ def _check_option(parameter, value, allowed_values, extra=''): options = 'Allowed values are ' options += ', '.join([f'{repr(v)}' for v in allowed_values[:-1]]) options += f', and {repr(allowed_values[-1])}' - raise ValueError(msg.format(parameter=parameter, options=options, - value=value, extra=extra)) + raise ValueError( + msg.format(parameter=parameter, options=options, value=value, extra=extra) + ) #################################################### @@ -304,12 +484,10 @@ def morlet(sfreq, freqs, n_cycles=7.0, sigma=None, zero_mean=False): freqs = np.array(freqs) if np.any(freqs <= 0): - raise ValueError("all frequencies in 'freqs' must be " - "greater than 0.") + raise ValueError("all frequencies in 'freqs' must be " 'greater than 0.') if (n_cycles.size != 1) and (n_cycles.size != len(freqs)): - raise ValueError("n_cycles should be fixed or defined for " - "each frequency.") + raise ValueError('n_cycles should be fixed or defined for ' 'each frequency.') for k, f in enumerate(freqs): if len(n_cycles) != 1: this_n_cycles = n_cycles[k] @@ -322,12 +500,12 @@ def morlet(sfreq, freqs, n_cycles=7.0, sigma=None, zero_mean=False): sigma_t = this_n_cycles / (2.0 * np.pi * sigma) # this scaling factor is proportional to (Tallon-Baudry 98): # (sigma_t*sqrt(pi))^(-1/2); - t = np.arange(0., 5. * sigma_t, 1.0 / sfreq) + t = np.arange(0.0, 5.0 * sigma_t, 1.0 / sfreq) t = np.r_[-t[::-1], t[1:]] oscillation = np.exp(2.0 * 1j * np.pi * f * t) - gaussian_enveloppe = np.exp(-t ** 2 / (2.0 * sigma_t ** 2)) + gaussian_enveloppe = np.exp(-(t**2) / (2.0 * sigma_t**2)) if zero_mean: # to make it zero mean - real_offset = np.exp(- 2 * (np.pi * f * sigma_t) ** 2) + real_offset = np.exp(-2 * (np.pi * f * sigma_t) ** 2) oscillation -= real_offset W = oscillation * gaussian_enveloppe W /= np.sqrt(0.5) * np.linalg.norm(W.ravel()) @@ -335,7 +513,7 @@ def morlet(sfreq, freqs, n_cycles=7.0, sigma=None, zero_mean=False): return Ws -def _cwt_gen(X, Ws, *, fsize=0, mode="same", decim=1, use_fft=True): +def _cwt_gen(X, Ws, *, fsize=0, mode='same', decim=1, use_fft=True): """Compute cwt with fft based convolutions or temporal convolutions. Parameters ---------- @@ -385,7 +563,7 @@ def _cwt_gen(X, Ws, *, fsize=0, mode="same", decim=1, use_fft=True): # Loop across wavelets for ii, W in enumerate(Ws): if use_fft: - ret = ifft(fft_x * fft_Ws[ii])[:n_times + W.size - 1] + ret = ifft(fft_x * fft_Ws[ii])[: n_times + W.size - 1] else: ret = np.convolve(x, W, mode=mode) @@ -393,8 +571,7 @@ def _cwt_gen(X, Ws, *, fsize=0, mode="same", decim=1, use_fft=True): if mode == 'valid': sz = int(abs(W.size - n_times)) + 1 offset = (n_times - sz) // 2 - this_slice = slice(offset // decim.step, - (offset + sz) // decim.step) + this_slice = slice(offset // decim.step, (offset + sz) // decim.step) if use_fft: ret = _centered(ret, sz) tfr[ii, this_slice] = ret[decim] @@ -452,8 +629,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): for W in Ws: # No need to check here, it's done earlier (outside parallel part) nfft = _get_nfft(W, X, use_fft, check=False) - coefs = _cwt_gen( - X, W, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft) + coefs = _cwt_gen(X, W, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft) # Inter-trial phase locking is apparently computed per taper... 
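Aside — not part of the patch: the `morlet` function reformatted above builds each kernel as a complex carrier under a Gaussian envelope of width sigma_t = n_cycles / (2*pi*f), then L2-normalizes it (the optional zero-mean correction is omitted here). A standalone sketch of that construction, assuming only NumPy:

import numpy as np

def morlet_kernel(sfreq, f, n_cycles=7.0):
    # temporal width grows with n_cycles and shrinks with frequency
    sigma_t = n_cycles / (2.0 * np.pi * f)
    t = np.arange(0.0, 5.0 * sigma_t, 1.0 / sfreq)
    t = np.r_[-t[::-1], t[1:]]                  # symmetric support around 0
    oscillation = np.exp(2.0j * np.pi * f * t)  # complex carrier
    envelope = np.exp(-(t**2) / (2.0 * sigma_t**2))
    W = oscillation * envelope
    return W / (np.sqrt(0.5) * np.linalg.norm(W.ravel()))

W = morlet_kernel(sfreq=1000.0, f=10.0)
print(W.size, np.linalg.norm(W))  # norm equals sqrt(2) by construction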
if 'itc' in output: @@ -469,7 +645,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): elif output == 'avg_power_itc': tfr_abs = np.abs(tfr) plf += tfr / tfr_abs # phase - tfr = tfr_abs ** 2 # power + tfr = tfr_abs**2 # power elif output == 'itc': plf += tfr / np.abs(tfr) # phase continue # not need to stack anything else than plf @@ -495,10 +671,20 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim): return tfrs -def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', - n_cycles=7.0, zero_mean=None, time_bandwidth=None, - use_fft=True, decim=1, output='complex', n_jobs=1, - verbose=None): +def _compute_tfr( + epoch_data, + freqs, + sfreq=1.0, + method='morlet', + n_cycles=7.0, + zero_mean=None, + time_bandwidth=None, + use_fft=True, + decim=1, + output='complex', + n_jobs=1, + verbose=None, +): """Compute time-frequency transforms. Parameters ---------- @@ -555,19 +741,30 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', # Check data epoch_data = np.asarray(epoch_data) if epoch_data.ndim != 3: - raise ValueError('epoch_data must be of shape (n_epochs, n_chans, ' - 'n_times), got %s' % (epoch_data.shape,)) + raise ValueError( + 'epoch_data must be of shape (n_epochs, n_chans, ' + 'n_times), got %s' % (epoch_data.shape,) + ) # Check params - freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim = \ - _check_tfr_param(freqs, sfreq, method, zero_mean, n_cycles, - time_bandwidth, use_fft, decim, output) + freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim = _check_tfr_param( + freqs, + sfreq, + method, + zero_mean, + n_cycles, + time_bandwidth, + use_fft, + decim, + output, + ) decim = _check_decim(decim) - if (freqs > sfreq / 2.).any(): - raise ValueError('Cannot compute freq above Nyquist freq of the data ' - '(%0.1f Hz), got %0.1f Hz' - % (sfreq / 2., freqs.max())) + if (freqs > sfreq / 2.0).any(): + raise ValueError( + 'Cannot compute freq above Nyquist freq of the data ' + '(%0.1f Hz), got %0.1f Hz' % (sfreq / 2.0, freqs.max()) + ) # We decimate *after* decomposition, so we need to create our kernels # for the original sfreq @@ -577,8 +774,10 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: - raise ValueError('At least one of the wavelets is longer than the ' - 'signal. Use a longer signal or shorter wavelets.') + raise ValueError( + 'At least one of the wavelets is longer than the ' + 'signal. Use a longer signal or shorter wavelets.' + ) # Initialize output n_freqs = len(freqs) @@ -603,7 +802,8 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', # Parallelization is applied across channels. tfrs = parallel( my_cwt(channel, Ws, output, use_fft, 'same', decim) - for channel in epoch_data.transpose(1, 0, 2)) + for channel in epoch_data.transpose(1, 0, 2) + ) # FIXME: to avoid overheads we should use np.array_split() for channel_idx, tfr in enumerate(tfrs): @@ -615,9 +815,18 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet', return out -def tfr_array_morlet(epoch_data, sfreq, freqs, n_cycles=7.0, - zero_mean=False, use_fft=True, decim=1, output='complex', - n_jobs=1, verbose=None): +def tfr_array_morlet( + epoch_data, + sfreq, + freqs, + n_cycles=7.0, + zero_mean=False, + use_fft=True, + decim=1, + output='complex', + n_jobs=1, + verbose=None, +): """Compute Time-Frequency Representation (TFR) using Morlet wavelets. 
Same computation as `~mne.time_frequency.tfr_morlet`, but operates on :class:`NumPy arrays ` instead of `~mne.Epochs` objects. @@ -675,22 +884,34 @@ def tfr_array_morlet(epoch_data, sfreq, freqs, n_cycles=7.0, ----- .. versionadded:: 0.14.0 """ - return _compute_tfr(epoch_data=epoch_data, freqs=freqs, - sfreq=sfreq, method='morlet', n_cycles=n_cycles, - zero_mean=zero_mean, time_bandwidth=None, - use_fft=use_fft, decim=decim, output=output, - n_jobs=n_jobs, verbose=verbose) + return _compute_tfr( + epoch_data=epoch_data, + freqs=freqs, + sfreq=sfreq, + method='morlet', + n_cycles=n_cycles, + zero_mean=zero_mean, + time_bandwidth=None, + use_fft=use_fft, + decim=decim, + output=output, + n_jobs=n_jobs, + verbose=verbose, + ) # Low level convolution + def _get_nfft(wavelets, X, use_fft=True, check=True): n_times = X.shape[-1] max_size = max(w.size for w in wavelets) if max_size > n_times: - msg = (f'At least one of the wavelets ({max_size}) is longer than the ' - f'signal ({n_times}). Consider using a longer signal or ' - 'shorter wavelets.') + msg = ( + f'At least one of the wavelets ({max_size}) is longer than the ' + f'signal ({n_times}). Consider using a longer signal or ' + 'shorter wavelets.' + ) if check: if use_fft: warn(msg) # warn(msg, UserWarning) @@ -701,35 +922,42 @@ def _get_nfft(wavelets, X, use_fft=True, check=True): return nfft -def _check_tfr_param(freqs, sfreq, method, zero_mean, n_cycles, - time_bandwidth, use_fft, decim, output): +def _check_tfr_param( + freqs, sfreq, method, zero_mean, n_cycles, time_bandwidth, use_fft, decim, output +): """Aux. function to _compute_tfr to check the params validity.""" # Check freqs if not isinstance(freqs, (list, np.ndarray)): - raise ValueError('freqs must be an array-like, got %s ' - 'instead.' % type(freqs)) + raise ValueError( + 'freqs must be an array-like, got %s ' 'instead.' % type(freqs) + ) freqs = np.asarray(freqs, dtype=float) if freqs.ndim != 1: - raise ValueError('freqs must be of shape (n_freqs,), got %s ' - 'instead.' % np.array(freqs.shape)) + raise ValueError( + 'freqs must be of shape (n_freqs,), got %s ' + 'instead.' % np.array(freqs.shape) + ) # Check sfreq if not isinstance(sfreq, (float, int)): - raise ValueError('sfreq must be a float or an int, got %s ' - 'instead.' % type(sfreq)) + raise ValueError( + 'sfreq must be a float or an int, got %s ' 'instead.' % type(sfreq) + ) sfreq = float(sfreq) # Default zero_mean = True if multitaper else False zero_mean = method == 'multitaper' if zero_mean is None else zero_mean if not isinstance(zero_mean, bool): - raise ValueError('zero_mean should be of type bool, got %s. instead' - % type(zero_mean)) + raise ValueError( + 'zero_mean should be of type bool, got %s. instead' % type(zero_mean) + ) freqs = np.asarray(freqs) if (method == 'multitaper') and (output == 'phase'): raise NotImplementedError( 'This function is not optimized to compute the phase using the ' - 'multitaper method. Use np.angle of the complex output instead.') + 'multitaper method. Use np.angle of the complex output instead.' + ) # Check n_cycles if isinstance(n_cycles, (int, float)): @@ -737,34 +965,40 @@ def _check_tfr_param(freqs, sfreq, method, zero_mean, n_cycles, elif isinstance(n_cycles, (list, np.ndarray)): n_cycles = np.array(n_cycles) if len(n_cycles) != len(freqs): - raise ValueError('n_cycles must be a float or an array of length ' - '%i frequencies, got %i cycles instead.' 
% - (len(freqs), len(n_cycles))) + raise ValueError( + 'n_cycles must be a float or an array of length ' + '%i frequencies, got %i cycles instead.' % (len(freqs), len(n_cycles)) + ) else: - raise ValueError('n_cycles must be a float or an array, got %s ' - 'instead.' % type(n_cycles)) + raise ValueError( + 'n_cycles must be a float or an array, got %s ' 'instead.' % type(n_cycles) + ) # Check time_bandwidth if (method == 'morlet') and (time_bandwidth is not None): raise ValueError('time_bandwidth only applies to "multitaper" method.') elif method == 'multitaper': - time_bandwidth = (4.0 if time_bandwidth is None - else float(time_bandwidth)) + time_bandwidth = 4.0 if time_bandwidth is None else float(time_bandwidth) # Check use_fft if not isinstance(use_fft, bool): - raise ValueError('use_fft must be a boolean, got %s ' - 'instead.' % type(use_fft)) + raise ValueError( + 'use_fft must be a boolean, got %s ' 'instead.' % type(use_fft) + ) # Check decim if isinstance(decim, int): decim = slice(None, None, decim) if not isinstance(decim, slice): - raise ValueError('decim must be an integer or a slice, ' - 'got %s instead.' % type(decim)) + raise ValueError( + 'decim must be an integer or a slice, ' 'got %s instead.' % type(decim) + ) # Check output - _check_option('output', output, ['complex', 'power', 'phase', - 'avg_power_itc', 'avg_power', 'itc']) + _check_option( + 'output', + output, + ['complex', 'power', 'phase', 'avg_power_itc', 'avg_power', 'itc'], + ) _check_option('method', method, ['multitaper', 'morlet']) return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim @@ -801,7 +1035,7 @@ def fill_doc(f): try: indented = docdict_indented[icount] except KeyError: - indent = " " * icount + indent = ' ' * icount docdict_indented[icount] = indented = {} for name, dstr in docdict.items(): lines = dstr.splitlines() @@ -809,15 +1043,15 @@ def fill_doc(f): newlines = [lines[0]] for line in lines[1:]: newlines.append(indent + line) - indented[name] = "\n".join(newlines) + indented[name] = '\n'.join(newlines) except IndexError: indented[name] = dstr try: f.__doc__ = docstring % indented except (TypeError, ValueError, KeyError) as exp: funcname = f.__name__ - funcname = docstring.split("\n")[0] if funcname is None else funcname - raise RuntimeError("Error documenting %s:\n%s" % (funcname, str(exp))) + funcname = docstring.split('\n')[0] if funcname is None else funcname + raise RuntimeError('Error documenting %s:\n%s' % (funcname, str(exp))) return f @@ -857,7 +1091,7 @@ def copy_doc(source): def wrapper(func): if source.__doc__ is None or len(source.__doc__) == 0: - raise ValueError("Cannot copy docstring: docstring was empty.") + raise ValueError('Cannot copy docstring: docstring was empty.') doc = source.__doc__ if func.__doc__ is not None: doc += func.__doc__ diff --git a/hnn_core/extracellular.py b/hnn_core/extracellular.py index d76355df1..285fa99d0 100644 --- a/hnn_core/extracellular.py +++ b/hnn_core/extracellular.py @@ -51,11 +51,12 @@ def calculate_csd2d(lfp_data, delta=1): csd[electrode] = -(LFP[electrode - 1] - 2*LFP[electrode] + LFP[electrode + 1]) / spacing ** 2 """ - csd2d = -np.diff(lfp_data, n=2, axis=0) / delta ** 2 + csd2d = -np.diff(lfp_data, n=2, axis=0) / delta**2 bottom_border = csd2d[-1, :] * 2 - csd2d[-2, :] top_border = csd2d[0, :] * 2 - csd2d[1, :] - csd2d = np.concatenate((top_border[None, ...], csd2d, - bottom_border[None, ...]), axis=0) + csd2d = np.concatenate( + (top_border[None, ...], csd2d, bottom_border[None, ...]), axis=0 + ) return csd2d @@ -79,25 +80,30 
@@ def _get_laminar_z_coords(electrode_positions): raise ValueError( 'Electrode array positions must contain more than 1 contact to be ' 'compatible with laminar profiling in a neocortical column. Got ' - f'{n_contacts} electrode contact positions.') + f'{n_contacts} electrode contact positions.' + ) displacements = np.diff(electrode_positions, axis=0) z_delta = np.abs(displacements[0, 2]) magnitudes = np.linalg.norm(displacements, axis=1) cross_prods = np.cross(displacements[:-1], displacements[1:]) - if not (np.allclose(magnitudes, magnitudes[0]) and # equally spaced - z_delta > 0 and # changes in z-direction - np.allclose(cross_prods, 0)): # colinear + if not ( + np.allclose(magnitudes, magnitudes[0]) # equally spaced + and z_delta > 0 # changes in z-direction + and np.allclose(cross_prods, 0) + ): # colinear raise ValueError( 'Electrode contacts are incompatible with laminar profiling ' 'in a neocortical column. Make sure the ' 'electrode positions are equispaced, colinear, and projecting ' - 'along the z-axis.') + 'along the z-axis.' + ) else: return np.array(electrode_positions)[:, 2], z_delta -def _transfer_resistance(section, electrode_pos, conductivity, method, - min_distance=0.5): +def _transfer_resistance( + section, electrode_pos, conductivity, method, min_distance=0.5 +): """Transfer resistance between section and electrode position. To arrive at the extracellular potential, the value returned by this @@ -152,16 +158,14 @@ def _transfer_resistance(section, electrode_pos, conductivity, method, line_lens = np.array([first_len] + list(line_lens[2:])) if method == 'psa': - # distance from segment midpoints to electrode - dis = norm(np.tile(electrode_pos, (section.nseg, 1)) - seg_ctr, - axis=1) + dis = norm(np.tile(electrode_pos, (section.nseg, 1)) - seg_ctr, axis=1) # To avoid very large values when electrode is placed close to a # segment junction, enforce minimal radial distance dis = np.maximum(dis, min_distance) - phi = 1. / dis + phi = 1.0 / dis elif method == 'lsa': # From: Appendix C (pp. 137) in Holt, G. R. 
A critical reexamination of @@ -191,21 +195,21 @@ def _transfer_resistance(section, electrode_pos, conductivity, method, # projection: H = a.cos(theta) = a.dot(b) / |a| H = np.dot(b, a) / norm_a # NB can be negative L = H + norm_a - R2 = np.dot(b, b) - H ** 2 # NB squares + R2 = np.dot(b, b) - H**2 # NB squares # To avoid very large values when electrode is placed (anywhere) on # the section axis, enforce minimal perpendicular distance - R2 = np.maximum(R2, min_distance ** 2) + R2 = np.maximum(R2, min_distance**2) if L < 0 and H < 0: # electrode is "behind" line segment - num = np.sqrt(H ** 2 + R2) - H # == norm(b) - H - denom = np.sqrt(L ** 2 + R2) - L + num = np.sqrt(H**2 + R2) - H # == norm(b) - H + denom = np.sqrt(L**2 + R2) - L elif L > 0 and H < 0: # electrode is "on top of" line segment - num = (np.sqrt(H ** 2 + R2) - H) * (L + np.sqrt(L ** 2 + R2)) + num = (np.sqrt(H**2 + R2) - H) * (L + np.sqrt(L**2 + R2)) denom = R2 else: # electrode is "ahead of" line segment - num = np.sqrt(L ** 2 + R2) + L - denom = np.sqrt(H ** 2 + R2) + H # == norm(b) + H + num = np.sqrt(L**2 + R2) + L + denom = np.sqrt(H**2 + R2) + H # == norm(b) + H phi[idx] = np.log(num / denom) / norm_a @@ -272,23 +276,32 @@ class ExtracellularArray: measured values of conductivity in rat cortex (note units there are mS/cm) """ - def __init__(self, positions, *, conductivity=0.3, method='psa', - min_distance=0.5, times=None, voltages=None): - + def __init__( + self, + positions, + *, + conductivity=0.3, + method='psa', + min_distance=0.5, + times=None, + voltages=None, + ): _validate_type(positions, (tuple, list), 'positions') if np.array(positions).shape == (3,): # a single coordinate given positions = [positions] for pos in positions: _validate_type(pos, (tuple, list), 'positions') if len(pos) != 3: - raise ValueError('positions should be provided as xyz ' - f'coordinate triplets, got: {positions}') + raise ValueError( + 'positions should be provided as xyz ' + f'coordinate triplets, got: {positions}' + ) _validate_type(conductivity, float, 'conductivity') - if not conductivity > 0.: + if not conductivity > 0.0: raise ValueError('conductivity must be a positive number') _validate_type(min_distance, float, 'min_distance') - if not min_distance > 0.: + if not min_distance > 0.0: raise ValueError('min_distance must be a positive number') if method is not None: # method allowed to be None for testing _validate_type(method, str, 'method') @@ -305,12 +318,16 @@ def __init__(self, positions, *, conductivity=0.3, method='psa', if voltages.size != 0: # voltages is not None n_trials, n_electrodes, n_times = voltages.shape if len(positions) != n_electrodes: - raise ValueError(f'number of voltage traces must match number' - f' of channels, got {n_electrodes} and ' - f'{len(positions)}') + raise ValueError( + f'number of voltage traces must match number' + f' of channels, got {n_electrodes} and ' + f'{len(positions)}' + ) if len(times) != n_times: - raise ValueError('length of times and voltages must match,' - f' got {len(times)} and {n_times} ') + raise ValueError( + 'length of times and voltages must match,' + f' got {len(times)} and {n_times} ' + ) self.positions = positions self.n_contacts = len(self.positions) @@ -330,21 +347,29 @@ def __getitem__(self, trial_no): elif isinstance(trial_no, (list, tuple)): return_data = [self._data[trial] for trial in trial_no] else: - raise TypeError(f'trial index must be int, slice or list-like,' - f' got: {trial_no} which is {type(trial_no)}') + raise TypeError( + f'trial index must be int, 
slice or list-like,' + f' got: {trial_no} which is {type(trial_no)}' + ) except IndexError: - raise IndexError(f'the data contain {len(self)} trials, the ' - f'indices provided are out of range: {trial_no}') - return ExtracellularArray(self.positions, - conductivity=self.conductivity, - method=self.method, - times=self.times, - voltages=return_data) + raise IndexError( + f'the data contain {len(self)} trials, the ' + f'indices provided are out of range: {trial_no}' + ) + return ExtracellularArray( + self.positions, + conductivity=self.conductivity, + method=self.method, + times=self.times, + voltages=return_data, + ) def __repr__(self): class_name = self.__class__.__name__ - msg = (f'{self.n_contacts} electrodes, ' - f'conductivity={self.conductivity}, method={self.method}') + msg = ( + f'{self.n_contacts} electrodes, ' + f'conductivity={self.conductivity}, method={self.method}' + ) if len(self._data) > 0: msg += f' | {len(self._data)} trials, {len(self.times)} times' else: @@ -360,9 +385,21 @@ def __eq__(self, other): all_attrs = dir(self) attrs_to_ignore = [x for x in all_attrs if x.startswith('_')] - attrs_to_ignore.extend(['conductivity', 'copy', 'n_contacts', - 'plot_csd', 'plot_lfp', 'sfreq', 'smooth', - 'voltages', 'to_dict', 'times', 'voltages']) + attrs_to_ignore.extend( + [ + 'conductivity', + 'copy', + 'n_contacts', + 'plot_csd', + 'plot_lfp', + 'sfreq', + 'smooth', + 'voltages', + 'to_dict', + 'times', + 'voltages', + ] + ) attrs_to_check = [x for x in all_attrs if x not in attrs_to_ignore] # Check all other attributes @@ -370,8 +407,10 @@ def __eq__(self, other): if getattr(self, attr) != getattr(other, attr): return False - if not ((self.times == other.times).all() and - (self.voltages == other.voltages).all()): + if not ( + (self.times == other.times).all() + and (self.voltages == other.voltages).all() + ): return False return True @@ -407,9 +446,10 @@ def sfreq(self): if np.abs(dT.max() - Tsamp) > 1e-3 or np.abs(dT.min() - Tsamp) > 1e-3: raise RuntimeError( 'Extracellular sampling times vary by more than 1 us. Check ' - 'times-attribute for errors.') + 'times-attribute for errors.' + ) - return 1000. / Tsamp # times are in in ms + return 1000.0 / Tsamp # times are in in ms def _reset(self): self._data = list() @@ -438,14 +478,25 @@ def smooth(self, window_len): for n_trial in range(len(self)): for n_contact in range(self.n_contacts): self._data[n_trial][n_contact] = smooth_waveform( - self._data[n_trial][n_contact], window_len, - self.sfreq) # XXX smooth_waveform returns ndarray + self._data[n_trial][n_contact], window_len, self.sfreq + ) # XXX smooth_waveform returns ndarray return self - def plot_lfp(self, *, trial_no=None, contact_no=None, tmin=None, tmax=None, - ax=None, decim=None, color='cividis', voltage_offset=50, - voltage_scalebar=200, show=True): + def plot_lfp( + self, + *, + trial_no=None, + contact_no=None, + tmin=None, + tmax=None, + ax=None, + decim=None, + color='cividis', + voltage_offset=50, + voltage_scalebar=200, + show=True, + ): """Plot laminar local field potential time series. One plot is created for each trial. 
Multiple trials can be overlaid @@ -490,12 +541,15 @@ def plot_lfp(self, *, trial_no=None, contact_no=None, tmin=None, tmax=None, if trial_no is None: plot_data = self.voltages elif isinstance(trial_no, (list, tuple, int, slice)): - plot_data = self.voltages[trial_no, ] + plot_data = self.voltages[trial_no,] else: raise ValueError(f'unknown trial number type, got {trial_no}') if isinstance(contact_no, (list, tuple, int, slice)): - plot_data = plot_data[:, contact_no, ] + plot_data = plot_data[ + :, + contact_no, + ] elif contact_no is not None: raise ValueError(f'unknown contact number type, got {contact_no}') @@ -503,16 +557,30 @@ def plot_lfp(self, *, trial_no=None, contact_no=None, tmin=None, tmax=None, for trial_data in plot_data: fig = plot_laminar_lfp( - self.times, trial_data, tmin=tmin, tmax=tmax, ax=ax, - decim=decim, color=color, + self.times, + trial_data, + tmin=tmin, + tmax=tmax, + ax=ax, + decim=decim, + color=color, voltage_offset=voltage_offset, voltage_scalebar=voltage_scalebar, contact_labels=contact_labels, - show=show) + show=show, + ) return fig - def plot_csd(self, vmin=None, vmax=None, interpolation='spline', - sink='b', colorbar=True, ax=None, show=True): + def plot_csd( + self, + vmin=None, + vmax=None, + interpolation='spline', + sink='b', + colorbar=True, + ax=None, + show=True, + ): """Plot laminar current source density (CSD) estimation Parameters @@ -542,17 +610,24 @@ def plot_csd(self, vmin=None, vmax=None, interpolation='spline', The matplotlib figure handle. """ from .viz import plot_laminar_csd + lfp = self.voltages[0] contact_labels, delta = _get_laminar_z_coords(self.positions) - csd_data = calculate_csd2d(lfp_data=lfp, - delta=delta) - - fig = plot_laminar_csd(self.times, csd_data, - contact_labels=contact_labels, ax=ax, - colorbar=colorbar, vmin=vmin, vmax=vmax, - interpolation=interpolation, sink=sink, - show=show) + csd_data = calculate_csd2d(lfp_data=lfp, delta=delta) + + fig = plot_laminar_csd( + self.times, + csd_data, + contact_labels=contact_labels, + ax=ax, + colorbar=colorbar, + vmin=vmin, + vmax=vmax, + interpolation=interpolation, + sink=sink, + show=show, + ) return fig @@ -584,6 +659,7 @@ class _ExtracellularArrayBuilder(object): The instance of :class:`hnn_core.extracellular.ExtracellularArray` to build in NEURON-Python """ + def __init__(self, array): self.array = array self.n_contacts = array.n_contacts @@ -615,11 +691,9 @@ def _build(self, cvode=None, include_celltypes='all'): """ secs_on_rank = h.allsec() # get all h.Sections known to this MPI rank _validate_type(include_celltypes, str) - _check_option('include_celltypes', include_celltypes, ['all', 'Pyr', - 'Basket']) + _check_option('include_celltypes', include_celltypes, ['all', 'Pyr', 'Basket']) if include_celltypes.lower() != 'all': - secs_on_rank = [s for s in secs_on_rank if - include_celltypes in s.name()] + secs_on_rank = [s for s in secs_on_rank if include_celltypes in s.name()] segment_counts = [sec.nseg for sec in secs_on_rank] n_total_segments = np.sum(segment_counts) @@ -633,12 +707,12 @@ def _build(self, cvode=None, include_celltypes='all'): for sec in secs_on_rank: for seg in sec: # section end points (0, 1) not included # set Nth pointer to the net membrane current at this segment - self._nrn_imem_ptrvec.pset( - ptr_idx, sec(seg.x)._ref_i_membrane_) + self._nrn_imem_ptrvec.pset(ptr_idx, sec(seg.x)._ref_i_membrane_) ptr_idx += 1 if ptr_idx != n_total_segments: - raise RuntimeError(f'Expected {n_total_segments} imem pointers, ' - f'got {ptr_idx}.') + raise RuntimeError( 
+ f'Expected {n_total_segments} imem pointers, ' f'got {ptr_idx}.' + ) # transfer resistances for each segment (keep in Neuron Matrix object) self._nrn_r_transfer = h.Matrix(self.n_contacts, n_total_segments) @@ -648,16 +722,18 @@ def _build(self, cvode=None, include_celltypes='all'): transfer_resistance = list() for sec in secs_on_rank: this_xfer_r = _transfer_resistance( - sec, pos, conductivity=self.array.conductivity, + sec, + pos, + conductivity=self.array.conductivity, method=self.array.method, - min_distance=self.array.min_distance) + min_distance=self.array.min_distance, + ) transfer_resistance.extend(this_xfer_r) self._nrn_r_transfer.setrow(row, h.Vector(transfer_resistance)) else: # for testing, make a matrix of ones - self._nrn_r_transfer.setrow(row, - h.Vector(n_total_segments, 1.)) + self._nrn_r_transfer.setrow(row, h.Vector(n_total_segments, 1.0)) # record time for each array self._nrn_times = h.Vector().record(h._ref_t) @@ -666,7 +742,7 @@ def _build(self, cvode=None, include_celltypes='all'): # potential at electrode (_PC.allreduce called in _simulate_dipole) # NB voltages of all contacts are initialised to 0 mV, i.e., the # potential at time 0.0 ms is defined to be zero. - self._nrn_voltages = h.Vector(self.n_contacts, 0.) + self._nrn_voltages = h.Vector(self.n_contacts, 0.0) # NB we must make a copy of the function reference, and keep it for # later decoupling using extra_scatter_gather_remove @@ -692,8 +768,7 @@ def _gather_nrn_voltages(self): # Calculate potentials by multiplying the _nrn_imem_vec by the matrix # _nrn_r_transfer. This is equivalent to a row-by-row dot-product: # V_i(t) = SUM_j ( R_i,j x I_j (t) ) - self._nrn_voltages.append( - self._nrn_r_transfer.mulv(self._nrn_imem_vec)) + self._nrn_voltages.append(self._nrn_r_transfer.mulv(self._nrn_imem_vec)) # NB all values appended to the h.Vector _nrn_voltages at current time # step. The vector will have size (n_contacts x n_samples, 1), which # will be reshaped later to (n_contacts, n_samples). 
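Aside — not part of the patch: `_gather_nrn_voltages` above appends one matrix-vector product per time step, i.e. V_i(t) = sum_j R_ij * I_j(t), into a single flat buffer that is reshaped at the end of the simulation. A NumPy-only sketch of that bookkeeping, with made-up array sizes:

import numpy as np

rng = np.random.default_rng(0)
n_contacts, n_segments, n_samples = 3, 100, 5

R = rng.random((n_contacts, n_segments))  # fixed transfer resistances
flat = []                                 # mimics the flat h.Vector
for _ in range(n_samples):
    i_membrane = rng.standard_normal(n_segments)  # per-segment currents
    flat.extend(R @ i_membrane)           # one column of potentials per step

# unpack the flat, sample-major buffer into (n_contacts, n_samples)
V = np.reshape(flat, (n_samples, n_contacts)).T
print(V.shape)  # (3, 5)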
@@ -702,23 +777,23 @@ def _gather_nrn_voltages(self): def _nrn_n_samples(self): """Return the length (in samples) of the extracellular data.""" if self._nrn_voltages.size() % self.n_contacts != 0: - raise RuntimeError(f'Something went wrong: have {self.n_contacts}' - f', but {self._nrn_voltages.size()} samples') + raise RuntimeError( + f'Something went wrong: have {self.n_contacts}' + f', but {self._nrn_voltages.size()} samples' + ) return int(self._nrn_voltages.size() / self.n_contacts) def _get_nrn_voltages(self): """The extracellular data (n_contacts x n_samples).""" if len(self._nrn_voltages) > 0: - assert (self._nrn_voltages.size() == - self.n_contacts * self._nrn_n_samples) + assert self._nrn_voltages.size() == self.n_contacts * self._nrn_n_samples # first reshape to a Neuron Matrix object extmat = h.Matrix(self.n_contacts, self._nrn_n_samples) extmat.from_vector(self._nrn_voltages) # then unpack into 2D python list and return - return [extmat.getrow(ii).to_python() for - ii in range(extmat.nrow())] + return [extmat.getrow(ii).to_python() for ii in range(extmat.nrow())] else: raise RuntimeError('Simulation not yet run!') diff --git a/hnn_core/gui/__init__.py b/hnn_core/gui/__init__.py index 3397d699b..3142afc2e 100644 --- a/hnn_core/gui/__init__.py +++ b/hnn_core/gui/__init__.py @@ -1 +1 @@ -from .gui import HNNGUI, launch \ No newline at end of file +from .gui import HNNGUI, launch diff --git a/hnn_core/gui/_logging.py b/hnn_core/gui/_logging.py index 7dfdeb2f4..5487849bb 100644 --- a/hnn_core/gui/_logging.py +++ b/hnn_core/gui/_logging.py @@ -1,5 +1,5 @@ import logging -_logger_name = "hnn_gui" +_logger_name = 'hnn_gui' logger = logging.getLogger(_logger_name) logger.setLevel(logging.INFO) diff --git a/hnn_core/gui/_viz_manager.py b/hnn_core/gui/_viz_manager.py index 6b383eee4..7793df253 100644 --- a/hnn_core/gui/_viz_manager.py +++ b/hnn_core/gui/_viz_manager.py @@ -10,8 +10,20 @@ import matplotlib.pyplot as plt import numpy as np from IPython.display import display -from ipywidgets import (Box, Button, Dropdown, BoundedFloatText, FloatText, - HBox, Label, Layout, Output, Tab, VBox, link) +from ipywidgets import ( + Box, + Button, + Dropdown, + BoundedFloatText, + FloatText, + HBox, + Label, + Layout, + Output, + Tab, + VBox, + link, +) from hnn_core.dipole import average_dipoles, _rmse from hnn_core.gui._logging import logger @@ -40,112 +52,96 @@ _ext_data_disabled_plot_types = ['spikes', 'input histogram', 'network'] _spectrogram_color_maps = [ - "viridis", - "plasma", - "inferno", - "magma", - "cividis", + 'viridis', + 'plasma', + 'inferno', + 'magma', + 'cividis', ] fig_templates = { - "[Blank] 2row x 1col (1:3)": { - "kwargs": { - "gridspec_kw": {"height_ratios": [1, 3]} - }, - "mosaic": "00\n11", + '[Blank] 2row x 1col (1:3)': { + 'kwargs': {'gridspec_kw': {'height_ratios': [1, 3]}}, + 'mosaic': '00\n11', }, - "[Blank] 2row x 1col (1:1)": { - "kwargs": { - "gridspec_kw": {"height_ratios": [1, 1]} - }, - "mosaic": "00\n11", + '[Blank] 2row x 1col (1:1)': { + 'kwargs': {'gridspec_kw': {'height_ratios': [1, 1]}}, + 'mosaic': '00\n11', }, - "[Blank] 1row x 2col (1:1)": { - "kwargs": { - "gridspec_kw": {"height_ratios": [1, 1]} - }, - "mosaic": "01\n01", + '[Blank] 1row x 2col (1:1)': { + 'kwargs': {'gridspec_kw': {'height_ratios': [1, 1]}}, + 'mosaic': '01\n01', }, - "[Blank] single figure": { - "kwargs": { - "gridspec_kw": "" - }, - "mosaic": "00\n00", + '[Blank] single figure': { + 'kwargs': {'gridspec_kw': ''}, + 'mosaic': '00\n00', + }, + '[Blank] 2row x 2col (1:1)': { + 
'kwargs': {'gridspec_kw': {'height_ratios': [1, 1]}}, + 'mosaic': '01\n23', }, - "[Blank] 2row x 2col (1:1)": { - "kwargs": { - "gridspec_kw": {"height_ratios": [1, 1]} - }, - "mosaic": "01\n23", - } } data_templates = { - "Drive-Dipole (2x1)": { - "kwargs": { - "gridspec_kw": {"height_ratios": [1, 3]} - }, - "mosaic": "00\n11", - "ax_plots": [("ax0", "input histogram"), ("ax1", "current dipole")] + 'Drive-Dipole (2x1)': { + 'kwargs': {'gridspec_kw': {'height_ratios': [1, 3]}}, + 'mosaic': '00\n11', + 'ax_plots': [('ax0', 'input histogram'), ('ax1', 'current dipole')], }, - "Dipole Layers (3x1)": { - "kwargs": { - "gridspec_kw": {"height_ratios": [1, 1, 1]} - }, - "mosaic": "0\n1\n2", - "ax_plots": [("ax0", "layer2 dipole"), ("ax1", "layer5 dipole"), - ("ax2", "current dipole")] + 'Dipole Layers (3x1)': { + 'kwargs': {'gridspec_kw': {'height_ratios': [1, 1, 1]}}, + 'mosaic': '0\n1\n2', + 'ax_plots': [ + ('ax0', 'layer2 dipole'), + ('ax1', 'layer5 dipole'), + ('ax2', 'current dipole'), + ], }, - "Drive-Spikes (2x1)": { - "kwargs": { - "gridspec_kw": {"height_ratios": [1, 3]} - }, - "mosaic": "00\n11", - "ax_plots": [("ax0", "input histogram"), ("ax1", "spikes")] + 'Drive-Spikes (2x1)': { + 'kwargs': {'gridspec_kw': {'height_ratios': [1, 3]}}, + 'mosaic': '00\n11', + 'ax_plots': [('ax0', 'input histogram'), ('ax1', 'spikes')], }, - "Dipole-Spectrogram (2x1)": { - "kwargs": { - "gridspec_kw": {"height_ratios": [1, 3]} - }, - "mosaic": "00\n11", - "ax_plots": [("ax0", "current dipole"), ("ax1", "spectrogram")] + 'Dipole-Spectrogram (2x1)': { + 'kwargs': {'gridspec_kw': {'height_ratios': [1, 3]}}, + 'mosaic': '00\n11', + 'ax_plots': [('ax0', 'current dipole'), ('ax1', 'spectrogram')], }, - "Dipole-Spikes (2x1)": { - "kwargs": { - "gridspec_kw": {"height_ratios": [1, 1]} - }, - "mosaic": "00\n11", - "ax_plots": [("ax0", "current dipole"), ("ax1", "spikes")] + 'Dipole-Spikes (2x1)': { + 'kwargs': {'gridspec_kw': {'height_ratios': [1, 1]}}, + 'mosaic': '00\n11', + 'ax_plots': [('ax0', 'current dipole'), ('ax1', 'spikes')], }, - "Drive-Dipole-Spectrogram (3x1)": { - "kwargs": { - "gridspec_kw": {"height_ratios": [1, 1, 2]} - }, - "mosaic": "0\n1\n2", - "ax_plots": [("ax0", "input histogram"), ("ax1", "current dipole"), - ("ax2", "spectrogram")] + 'Drive-Dipole-Spectrogram (3x1)': { + 'kwargs': {'gridspec_kw': {'height_ratios': [1, 1, 2]}}, + 'mosaic': '0\n1\n2', + 'ax_plots': [ + ('ax0', 'input histogram'), + ('ax1', 'current dipole'), + ('ax2', 'spectrogram'), + ], + }, + 'PSD Layers (3x1)': { + 'kwargs': {'gridspec_kw': {'height_ratios': [1, 1, 1]}}, + 'mosaic': '0\n1\n2', + 'ax_plots': [ + ('ax0', 'layer2 dipole'), + ('ax1', 'layer5 dipole'), + ('ax2', 'PSD'), + ], }, - "PSD Layers (3x1)": { - "kwargs": { - "gridspec_kw": {"height_ratios": [1, 1, 1]} - }, - "mosaic": "0\n1\n2", - "ax_plots": [("ax0", "layer2 dipole"), ("ax1", "layer5 dipole"), - ("ax2", "PSD")] - } } -def check_sim_plot_types( - new_sim_name, plot_type_selection, target_selection, data): - if not _is_simulation(data["simulations"][new_sim_name]): +def check_sim_plot_types(new_sim_name, plot_type_selection, target_selection, data): + if not _is_simulation(data['simulations'][new_sim_name]): plot_type_selection.options = [ pt for pt in _plot_types if pt not in _ext_data_disabled_plot_types ] else: plot_type_selection.options = _plot_types # deal with target data - all_possible_targets = list(data["simulations"].keys()) + all_possible_targets = list(data['simulations'].keys()) all_possible_targets.remove(new_sim_name) 
target_selection.options = all_possible_targets + ['None'] target_selection.value = 'None' @@ -157,8 +153,7 @@ def _check_template_type_is_data_dependant(template_name): def target_comparison_change(new_target_name, simulation_selection, data): - """Triggered when the target data is turned on or changed. - """ + """Triggered when the target data is turned on or changed.""" pass @@ -183,6 +178,7 @@ def unlink_relink(attribute): widgets """ + def _unlink_relink(f): @wraps(f) def wrapper(self, *args, **kwargs): @@ -197,16 +193,18 @@ def wrapper(self, *args, **kwargs): link_attribute.link() return result + return wrapper + return _unlink_relink def _idx2figname(idx): - return f"Figure {idx}" + return f'Figure {idx}' def _figname2idx(fname): - return int(fname.split(" ")[-1]) + return int(fname.split(' ')[-1]) def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): @@ -231,7 +229,8 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): for dpl in dpls_copied: if plot_config['dipole_smooth'] > 0: dpl.smooth(plot_config['dipole_smooth']).scale( - plot_config['dipole_scaling']) + plot_config['dipole_scaling'] + ) else: dpl.scale(plot_config['dipole_scaling']) @@ -283,7 +282,9 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): distal_drives.append(name) net_copied.cell_response.plot_spikes_hist( - ax=ax, show=False, spike_types=all_drives, + ax=ax, + show=False, + spike_types=all_drives, invert_spike_types=distal_drives, color=drive_colors, ) @@ -293,8 +294,9 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): min_f = plot_config['min_spectral_frequency'] max_f = plot_config['max_spectral_frequency'] color = ax._get_lines.get_next_color() - dpls_copied[0].plot_psd(fmin=min_f, fmax=max_f, ax=ax, color=color, - label=sim_name, show=False) + dpls_copied[0].plot_psd( + fmin=min_f, fmax=max_f, ax=ax, color=color, label=sim_name, show=False + ) elif plot_type == 'spectrogram': if len(dpls_copied) > 0: @@ -302,43 +304,49 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): max_f = plot_config['max_spectral_frequency'] step_f = 1.0 freqs = np.arange(min_f, max_f, step_f) - n_cycles = freqs / 2. 
+ n_cycles = freqs / 2.0 dpls_copied[0].plot_tfr_morlet( freqs, n_cycles=n_cycles, colormap=plot_config['spectrogram_cm'], - ax=ax, colorbar_inside=True, - show=False) + ax=ax, + colorbar_inside=True, + show=False, + ) elif 'dipole' in plot_type: if len(dpls_copied) > 0: if len(dpls_copied) > 1: - label = f"{sim_name}: average" + label = f'{sim_name}: average' else: label = sim_name color = ax._get_lines.get_next_color() if plot_type == 'current dipole': - plot_dipole(dpls_copied, - ax=ax, - label=label, - color=color, - average=True, - show=False) + plot_dipole( + dpls_copied, + ax=ax, + label=label, + color=color, + average=True, + show=False, + ) else: layer_namemap = { - "layer2": "L2", - "layer5": "L5", + 'layer2': 'L2', + 'layer5': 'L5', } - plot_dipole(dpls_copied, - ax=ax, - label=label, - color=color, - layer=layer_namemap[plot_type.split(" ")[0]], - average=True, - show=False) + plot_dipole( + dpls_copied, + ax=ax, + label=label, + color=color, + layer=layer_namemap[plot_type.split(' ')[0]], + average=True, + show=False, + ) else: - print("No dipole data") + print('No dipole data') elif plot_type == 'network': if net_copied: @@ -349,10 +357,10 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): io_buf = io.BytesIO() _fig.savefig(io_buf, format='raw') io_buf.seek(0) - img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), - dtype=np.uint8), - newshape=(int(_fig.bbox.bounds[3]), - int(_fig.bbox.bounds[2]), -1)) + img_arr = np.reshape( + np.frombuffer(io_buf.getvalue(), dtype=np.uint8), + newshape=(int(_fig.bbox.bounds[3]), int(_fig.bbox.bounds[2]), -1), + ) io_buf.close() _ = ax.imshow(img_arr) @@ -395,12 +403,25 @@ def _avg_dipole_check(dpls): return dpl -def _plot_on_axes(b, simulations_widget, widgets_plot_type, - data_widget, - spectrogram_colormap_selection, - min_spectral_frequency, max_spectral_frequency, - dipole_smooth, dipole_scaling, data_smooth, data_scaling, - widgets, data, fig_idx, fig, ax, existing_plots): +def _plot_on_axes( + b, + simulations_widget, + widgets_plot_type, + data_widget, + spectrogram_colormap_selection, + min_spectral_frequency, + max_spectral_frequency, + dipole_smooth, + dipole_scaling, + data_smooth, + data_scaling, + widgets, + data, + fig_idx, + fig, + ax, + existing_plots, +): """Plotting different types of data on the given axes. Now this function is also responsible for comparing multiple simulations, @@ -450,35 +471,37 @@ def _plot_on_axes(b, simulations_widget, widgets_plot_type, single_simulation = data['simulations'][sim_name] simulation_plot_config = { - "dipole_scaling": dipole_scaling.value, - "dipole_smooth": dipole_smooth.value, - "min_spectral_frequency": min_spectral_frequency.value, - "max_spectral_frequency": max_spectral_frequency.value, - "spectrogram_cm": spectrogram_colormap_selection.value + 'dipole_scaling': dipole_scaling.value, + 'dipole_smooth': dipole_smooth.value, + 'min_spectral_frequency': min_spectral_frequency.value, + 'max_spectral_frequency': max_spectral_frequency.value, + 'spectrogram_cm': spectrogram_colormap_selection.value, } - dpls_processed = _update_ax(fig, ax, single_simulation, sim_name, - plot_type, simulation_plot_config) + dpls_processed = _update_ax( + fig, ax, single_simulation, sim_name, plot_type, simulation_plot_config + ) # If target_simulations is not None and we are plotting a dipole, # we need to plot the target dipole as well. 
- if data_widget.value in data['simulations'].keys( - ) and plot_type == 'current dipole': - + if ( + data_widget.value in data['simulations'].keys() + and plot_type == 'current dipole' + ): target_sim_name = data_widget.value target_sim = data['simulations'][target_sim_name] data_plot_config = { - "dipole_scaling": data_scaling.value, - "dipole_smooth": data_smooth.value, - "min_spectral_frequency": min_spectral_frequency.value, - "max_spectral_frequency": max_spectral_frequency.value, - "spectrogram_cm": spectrogram_colormap_selection.value + 'dipole_scaling': data_scaling.value, + 'dipole_smooth': data_smooth.value, + 'min_spectral_frequency': min_spectral_frequency.value, + 'max_spectral_frequency': max_spectral_frequency.value, + 'spectrogram_cm': spectrogram_colormap_selection.value, } # plot the target dipole. target_dpl_processed = _update_ax( - fig, ax, target_sim, target_sim_name, plot_type, - data_plot_config)[0] # we assume there is only one dipole. + fig, ax, target_sim, target_sim_name, plot_type, data_plot_config + )[0] # we assume there is only one dipole. # calculate the RMSE between the two dipoles. t0 = 0.0 @@ -491,46 +514,63 @@ def _plot_on_axes(b, simulations_widget, widgets_plot_type, annotation_text = f'RMSE({sim_name}, {target_sim_name}): {rmse:.4f}' # find subplot's annotation - annotation = next((child for child in ax.get_children() - if isinstance(child, plt.Annotation)), None) + annotation = next( + (child for child in ax.get_children() if isinstance(child, plt.Annotation)), + None, + ) # if the subplot already has an annotation, update its text. # Otherwise, create a new one. if annotation is not None: annotation.set_text(annotation_text) else: - ax.annotate(annotation_text, - xy=(0.95, 0.05), - xycoords='axes fraction', - horizontalalignment='right', - verticalalignment='bottom', - fontsize=12) - - rmse_logger_text = (f'RMSE {rmse:.4f} (' - f'{sim_name} smooth:{dipole_smooth.value} ' - f'scale:{dipole_scaling.value} \n' - f'{target_sim_name} smooth:{data_smooth.value} ' - f'scale:{data_scaling.value})') + ax.annotate( + annotation_text, + xy=(0.95, 0.05), + xycoords='axes fraction', + horizontalalignment='right', + verticalalignment='bottom', + fontsize=12, + ) + + rmse_logger_text = ( + f'RMSE {rmse:.4f} (' + f'{sim_name} smooth:{dipole_smooth.value} ' + f'scale:{dipole_scaling.value} \n' + f'{target_sim_name} smooth:{data_smooth.value} ' + f'scale:{data_scaling.value})' + ) logger.info(rmse_logger_text) - existing_plots.children = (*existing_plots.children, - Label(f"{sim_name}: {plot_type}")) + existing_plots.children = ( + *existing_plots.children, + Label(f'{sim_name}: {plot_type}'), + ) if data['use_ipympl'] is False: _static_rerender(widgets, fig, fig_idx) else: _dynamic_rerender(fig) -def _clear_axis(b, widgets, data, fig_idx, fig, ax, widgets_plot_type, - existing_plots, add_plot_button): +def _clear_axis( + b, + widgets, + data, + fig_idx, + fig, + ax, + widgets_plot_type, + existing_plots, + add_plot_button, +): ax.clear() # Remove "plot_spikes_hist"'s inverted second axes object, if exists, and # if the axis you are clearing is the spike histogram - if ax._label == "Spike histogram": + if ax._label == 'Spike histogram': for axis in fig.axes: - if axis._label == "Inverted spike histogram": + if axis._label == 'Inverted spike histogram': axis.remove() # remove attached colorbar if exists @@ -551,18 +591,20 @@ def _clear_axis(b, widgets, data, fig_idx, fig, ax, widgets_plot_type, def _get_ax_control(widgets, data, fig_idx, fig, ax): analysis_style = 
{'description_width': '200px'} - layout = Layout(width="98%") + layout = Layout(width='98%') simulation_names = tuple(data['simulations'].keys()) sim_index = 0 if not simulation_names: - simulation_names = ("None",) + simulation_names = ('None',) else: # Find the last simulation with a non-None 'net' sim_index = next( - (idx for idx, sim_name in - reversed(list(enumerate(simulation_names))) - if _is_simulation(data["simulations"][sim_name])), - 0 # Default value if no such simulation is found + ( + idx + for idx, sim_name in reversed(list(enumerate(simulation_names))) + if _is_simulation(data['simulations'][sim_name]) + ), + 0, # Default value if no such simulation is found ) simulation_selection = Dropdown( @@ -599,8 +641,8 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax): # This will check the sim plot types dropdown available options # for the specific sim name in the simulation_selection dropdown options check_sim_plot_types( - simulation_names[sim_index], - plot_type_selection, target_data_selection, data) + simulation_names[sim_index], plot_type_selection, target_data_selection, data + ) spectrogram_colormap_selection = Dropdown( description='Spectrogram Colormap:', @@ -614,28 +656,32 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax): description='Dipole Smooth Window (ms):', disabled=False, layout=layout, - style=analysis_style) + style=analysis_style, + ) simulation_dipole_scaling = FloatText( value=3000, description='Simulation Dipole Scaling:', disabled=False, layout=layout, - style=analysis_style) + style=analysis_style, + ) data_dipole_smooth = FloatText( value=0, description='Data Smooth Window (ms):', disabled=False, layout=layout, - style=analysis_style) + style=analysis_style, + ) data_dipole_scaling = FloatText( value=1, description='Data Dipole Scaling:', disabled=False, layout=layout, - style=analysis_style) + style=analysis_style, + ) min_spectral_frequency = BoundedFloatText( value=10, @@ -644,7 +690,8 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax): description='Min Spectral Frequency (Hz):', disabled=False, layout=layout, - style=analysis_style) + style=analysis_style, + ) max_spectral_frequency = BoundedFloatText( value=100, @@ -653,7 +700,8 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax): description='Max Spectral Frequency (Hz):', disabled=False, layout=layout, - style=analysis_style) + style=analysis_style, + ) existing_plots = VBox([]) @@ -662,15 +710,14 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax): def _on_sim_data_change(new_sim_name): return check_sim_plot_types( - new_sim_name.new, plot_type_selection, target_data_selection, data) + new_sim_name.new, plot_type_selection, target_data_selection, data + ) def _on_target_comparison_change(new_target_name): - return target_comparison_change(new_target_name, simulation_selection, - data) + return target_comparison_change(new_target_name, simulation_selection, data) def _on_plot_type_change(new_plot_type): - return plot_type_coupled_change(new_plot_type.new, - target_data_selection) + return plot_type_coupled_change(new_plot_type.new, target_data_selection) simulation_selection.observe(_on_sim_data_change, 'value') target_data_selection.observe(_on_target_comparison_change, 'value') @@ -687,7 +734,8 @@ def _on_plot_type_change(new_plot_type): widgets_plot_type=plot_type_selection, existing_plots=existing_plots, add_plot_button=plot_button, - )) + ) + ) plot_button.on_click( partial( @@ -708,17 +756,29 @@ def _on_plot_type_change(new_plot_type): fig=fig, ax=ax, 
existing_plots=existing_plots, - )) - - vbox = VBox([ - plot_type_selection, simulation_selection, simulation_dipole_smooth, - simulation_dipole_scaling, target_data_selection, data_dipole_smooth, - data_dipole_scaling, min_spectral_frequency, max_spectral_frequency, - spectrogram_colormap_selection, - HBox( - [plot_button, clear_button], - layout=Layout(justify_content='space-between'), - ), existing_plots], layout=Layout(width="98%")) + ) + ) + + vbox = VBox( + [ + plot_type_selection, + simulation_selection, + simulation_dipole_smooth, + simulation_dipole_scaling, + target_data_selection, + data_dipole_smooth, + data_dipole_scaling, + min_spectral_frequency, + max_spectral_frequency, + spectrogram_colormap_selection, + HBox( + [plot_button, clear_button], + layout=Layout(justify_content='space-between'), + ), + existing_plots, + ], + layout=Layout(width='98%'), + ) return vbox @@ -732,7 +792,7 @@ def _close_figure(b, widgets, data, fig_idx): # Get the index based on the title tab_idx = titles.index(_idx2figname(fig_idx)) # Remove the child and title specified - print(f"Del fig_idx={fig_idx}, fig_idx={fig_idx}") + print(f'Del fig_idx={fig_idx}, fig_idx={fig_idx}') tab_children.pop(tab_idx) titles.pop(tab_idx) # Reset children and titles of the tab object @@ -773,16 +833,20 @@ def _add_axes_controls(widgets, data, fig, axd): for i in range(len(children)): controls.set_title(i, f'ax{i}') - close_fig_button = Button(description=f'Close {_idx2figname(fig_idx)}', - button_style='danger', - icon='close', - layout=Layout(width="98%")) + close_fig_button = Button( + description=f'Close {_idx2figname(fig_idx)}', + button_style='danger', + icon='close', + layout=Layout(width='98%'), + ) close_fig_button.on_click( - partial(_close_figure, widgets=widgets, data=data, fig_idx=fig_idx)) + partial(_close_figure, widgets=widgets, data=data, fig_idx=fig_idx) + ) n_tabs = len(widgets['axes_config_tabs'].children) - widgets['axes_config_tabs'].children = widgets[ - 'axes_config_tabs'].children + (VBox([close_fig_button, controls]), ) + widgets['axes_config_tabs'].children = widgets['axes_config_tabs'].children + ( + VBox([close_fig_button, controls]), + ) widgets['axes_config_tabs'].set_title(n_tabs, _idx2figname(fig_idx)) @@ -797,14 +861,16 @@ def _add_figure(b, widgets, data, template_type, scale=0.95, dpi=96): with widgets['figs_output']: display(widgets['figs_tabs']) - widgets['figs_tabs'].children = ( - [s for s in widgets['figs_tabs'].children] + [fig_outputs] - ) + widgets['figs_tabs'].children = [s for s in widgets['figs_tabs'].children] + [ + fig_outputs + ] widgets['figs_tabs'].set_title(n_tabs, _idx2figname(fig_idx)) with fig_outputs: - figsize = (scale * ((int(viz_output_layout.width[:-2]) - 10) / dpi), - scale * ((int(viz_output_layout.height[:-2]) - 10) / dpi)) + figsize = ( + scale * ((int(viz_output_layout.width[:-2]) - 10) / dpi), + scale * ((int(viz_output_layout.height[:-2]) - 10) / dpi), + ) mosaic = template_type['mosaic'] kwargs = template_type['kwargs'] with plt.ioff(): @@ -826,8 +892,7 @@ def _add_figure(b, widgets, data, template_type, scale=0.95, dpi=96): data['fig_idx']['idx'] += 1 -def _postprocess_template(template_name, fig, idx, - use_ipympl=True, widgets=None): +def _postprocess_template(template_name, fig, idx, use_ipympl=True, widgets=None): """Post-processes and re-renders plot templates with determined styles Templates are constructed on panel-by-panel basis. 
If adjustments need to @@ -870,7 +935,7 @@ class _VizManager: """ def __init__(self, gui_data, viz_layout): - plt.close("all") + plt.close('all') self.viz_layout = viz_layout self.use_ipympl = 'ipympl' in matplotlib.get_backend() @@ -894,14 +959,16 @@ def __init__(self, gui_data, viz_layout): options=template_names, value=template_names[0], style={'description_width': 'initial'}, - layout=Layout(width="98%")) + layout=Layout(width='98%'), + ) self.templates_dropdown.observe(self._layout_template_change, 'value') self.make_fig_button = Button( description='Make figure', - button_style="primary", + button_style='primary', style={'button_color': self.viz_layout['theme_color']}, - layout=self.viz_layout['btn']) + layout=self.viz_layout['btn'], + ) self.make_fig_button.on_click(self.add_figure) self.datasets_dropdown = Dropdown( @@ -909,32 +976,33 @@ def __init__(self, gui_data, viz_layout): options=[], value=None, style={'description_width': 'initial'}, - layout=Layout(width="98%")) + layout=Layout(width='98%'), + ) # data - self.fig_idx = {"idx": 1} + self.fig_idx = {'idx': 1} self.figs = {} self.gui_data = gui_data @property def widgets(self): return { - "figs_output": self.figs_output, - "axes_config_tabs": self.axes_config_tabs, - "figs_tabs": self.figs_tabs, - "templates_dropdown": self.templates_dropdown, - "dataset_dropdown": self.datasets_dropdown + 'figs_output': self.figs_output, + 'axes_config_tabs': self.axes_config_tabs, + 'figs_tabs': self.figs_tabs, + 'templates_dropdown': self.templates_dropdown, + 'dataset_dropdown': self.datasets_dropdown, } @property def data(self): """Provides easy access to visualization-related data.""" return { - "use_ipympl": self.use_ipympl, - "simulations": self.gui_data["simulation_data"], - "fig_idx": self.fig_idx, - "visualization_output": self.viz_layout['visualization_output'], - "figs": self.figs + 'use_ipympl': self.use_ipympl, + 'simulations': self.gui_data['simulation_data'], + 'fig_idx': self.fig_idx, + 'visualization_output': self.viz_layout['visualization_output'], + 'figs': self.figs, } def reset_fig_config_tabs(self, template_name=None): @@ -964,104 +1032,115 @@ def compose(self): display(Label(_fig_placeholder)) fig_output_container = VBox( - [self.figs_output], layout=self.viz_layout['visualization_window']) - - config_panel = VBox([ - Box( - [ - self.templates_dropdown, - self.datasets_dropdown, - self.make_fig_button, - ], - layout=Layout( - display='flex', - flex_flow='column', - align_items='stretch', + [self.figs_output], layout=self.viz_layout['visualization_window'] + ) + + config_panel = VBox( + [ + Box( + [ + self.templates_dropdown, + self.datasets_dropdown, + self.make_fig_button, + ], + layout=Layout( + display='flex', + flex_flow='column', + align_items='stretch', + ), ), - ), - Label("Figure config:"), - self.axes_config_output, - ]) + Label('Figure config:'), + self.axes_config_output, + ] + ) return config_panel, fig_output_container def _layout_template_change(self, template_type): # check if plot set type requires loaded sim-data if _check_template_type_is_data_dependant(template_type.new): # Add only simualated data - sim_names = [simulations for simulations, sim_name - in self.data["simulations"].items() - if sim_name['net'] is not None] + sim_names = [ + simulations + for simulations, sim_name in self.data['simulations'].items() + if sim_name['net'] is not None + ] if len(sim_names) == 0: - sim_names = [" "] + sim_names = [' '] self.datasets_dropdown.options = sim_names self.datasets_dropdown.value = sim_names[0] # 
show list of simulated to gui dropdown - self.datasets_dropdown.layout.visibility = "visible" + self.datasets_dropdown.layout.visibility = 'visible' else: # hide sim-data dropdown - self.datasets_dropdown.layout.visibility = "hidden" + self.datasets_dropdown.layout.visibility = 'hidden' @unlink_relink(attribute='figs_config_tab_link') def add_figure(self, b=None): - """Add a figure and corresponding config tabs to the dashboard. - """ - if len(self.data["simulations"]) == 0: - logger.error("No data has been loaded") + """Add a figure and corresponding config tabs to the dashboard.""" + if len(self.data['simulations']) == 0: + logger.error('No data has been loaded') return template_name = self.widgets['templates_dropdown'].value - is_data_template = (_check_template_type_is_data_dependant - (template_name)) + is_data_template = _check_template_type_is_data_dependant(template_name) if is_data_template: - sim_name = self.widgets["dataset_dropdown"].value - if sim_name not in self.data["simulations"]: - logger.error("No simulation data has been loaded") + sim_name = self.widgets['dataset_dropdown'].value + if sim_name not in self.data['simulations']: + logger.error('No simulation data has been loaded') return # Use data_templates dictionary if it's a data dependent layout - template_type = (data_templates[template_name] - if is_data_template - else fig_templates[template_name]) + template_type = ( + data_templates[template_name] + if is_data_template + else fig_templates[template_name] + ) # Add empty figure according to template arguments - _add_figure(None, - self.widgets, - self.data, - template_type, - scale=0.97, - dpi=self.viz_layout['dpi']) + _add_figure( + None, + self.widgets, + self.data, + template_type, + scale=0.97, + dpi=self.viz_layout['dpi'], + ) # Plot data if it is a data-dependent template if is_data_template: fig_name = _idx2figname(self.data['fig_idx']['idx'] - 1) # get figs per axis - ax_plots = data_templates[template_name]["ax_plots"] + ax_plots = data_templates[template_name]['ax_plots'] for ax_name, plot_type in ax_plots: # paint fig in axis - self._simulate_edit_figure(fig_name, ax_name, sim_name, - plot_type, {}, "plot") + self._simulate_edit_figure( + fig_name, ax_name, sim_name, plot_type, {}, 'plot' + ) # template post-processing fig_key = self.data['fig_idx']['idx'] - 1 - _postprocess_template(template_name, - fig=self.figs[fig_key], - idx=fig_key, - use_ipympl=self.use_ipympl, - widgets=self.widgets, - ) - - logger.info(f"Figure {template_name} for " - f"simulation {sim_name} " - "has been created" - ) + _postprocess_template( + template_name, + fig=self.figs[fig_key], + idx=fig_key, + use_ipympl=self.use_ipympl, + widgets=self.widgets, + ) + + logger.info( + f'Figure {template_name} for ' + f'simulation {sim_name} ' + 'has been created' + ) def _simulate_add_fig(self): self.make_fig_button.click() def _simulate_switch_fig_template(self, template_name): - assert (template_name in fig_templates.keys() or - data_templates.keys()), "No such template" + assert ( + template_name in fig_templates.keys() or data_templates.keys() + ), 'No such template' self.templates_dropdown.value = template_name def _simulate_delete_figure(self, fig_name): @@ -1074,8 +1153,15 @@ def _simulate_delete_figure(self, fig_name): close_button = self.axes_config_tabs.children[tab_idx].children[0] close_button.click() - def _simulate_edit_figure(self, fig_name, ax_name, simulation_name, - plot_type, preprocessing_config, operation): + def _simulate_edit_figure( + self, + fig_name, + ax_name, 
+ simulation_name, + plot_type, + preprocessing_config, + operation, + ): """Manipulate a certain figure. Parameters @@ -1101,19 +1187,19 @@ def _simulate_edit_figure(self, fig_name, ax_name, simulation_name, """ assert simulation_name in self.data['simulations'].keys() assert plot_type in _plot_types - assert operation in ("plot", "clear") + assert operation in ('plot', 'clear') # Select the figure tab tab = self.axes_config_tabs titles = tab.titles - assert fig_name in titles, "No such figure" + assert fig_name in titles, 'No such figure' tab_idx = titles.index(fig_name) self.axes_config_tabs.selected_index = tab_idx # Select the figure panel/ax tab ax_control_tabs = self.axes_config_tabs.children[tab_idx].children[1] ax_titles = ax_control_tabs.titles - assert ax_name in ax_titles, "No such axis" + assert ax_name in ax_titles, 'No such axis' ax_idx = ax_titles.index(ax_name) ax_control_tabs.selected_index = ax_idx @@ -1127,14 +1213,14 @@ def _simulate_edit_figure(self, fig_name, ax_name, simulation_name, # Set the plot configurations config_name_idx = { - "dipole_smooth": 2, - "dipole_scaling": 3, - "data_to_compare": 4, - "data_smooth": 5, - "data_scaling": 6, - "min_spectral_frequency": 7, - "max_spectral_frequency": 8, - "spectrogram_colormap_selection": 9, + 'dipole_smooth': 2, + 'dipole_scaling': 3, + 'data_to_compare': 4, + 'data_smooth': 5, + 'data_scaling': 6, + 'min_spectral_frequency': 7, + 'max_spectral_frequency': 8, + 'spectrogram_colormap_selection': 9, } for conf_key, conf_val in preprocessing_config.items(): assert conf_key in config_name_idx.keys() @@ -1143,9 +1229,9 @@ def _simulate_edit_figure(self, fig_name, ax_name, simulation_name, conf_widget.value = conf_val buttons = ax_control_tabs.children[ax_idx].children[-2] - if operation == "plot": + if operation == 'plot': buttons.children[0].click() - elif operation == "clear": + elif operation == 'clear': buttons.children[1].click() diff --git a/hnn_core/gui/gui.py b/hnn_core/gui/gui.py index 3b5c66350..f21025eaa 100644 --- a/hnn_core/gui/gui.py +++ b/hnn_core/gui/gui.py @@ -20,10 +20,25 @@ from datetime import datetime from functools import partial from IPython.display import IFrame, display -from ipywidgets import (HTML, Accordion, AppLayout, BoundedFloatText, - BoundedIntText, Button, Dropdown, FileUpload, VBox, - HBox, IntText, Layout, Output, RadioButtons, Tab, Text, - Checkbox) +from ipywidgets import ( + HTML, + Accordion, + AppLayout, + BoundedFloatText, + BoundedIntText, + Button, + Dropdown, + FileUpload, + VBox, + HBox, + IntText, + Layout, + Output, + RadioButtons, + Tab, + Text, + Checkbox, +) from ipywidgets.embed import embed_minimal_html import hnn_core from hnn_core import JoblibBackend, MPIBackend, simulate_dipole @@ -31,18 +46,15 @@ from hnn_core.gui._viz_manager import _VizManager, _idx2figname from hnn_core.network import pick_connection from hnn_core.dipole import _read_dipole_txt -from hnn_core.params_default import (get_L2Pyr_params_default, - get_L5Pyr_params_default) +from hnn_core.params_default import get_L2Pyr_params_default, get_L5Pyr_params_default from hnn_core.hnn_io import dict_to_network, write_network_configuration from hnn_core.cells_default import _exp_g_at_dist hnn_core_root = Path(hnn_core.__file__).parent -default_network_configuration = (hnn_core_root / 'param' / - 'jones2009_base.json') +default_network_configuration = hnn_core_root / 'param' / 'jones2009_base.json' cell_parameters_dict = { - "Geometry L2": - [ + 'Geometry L2': [ ('Soma length', 'micron', 'soma_L'), ('Soma 
diameter', 'micron', 'soma_diam'), ('Soma capacitive density', 'F/cm2', 'soma_cm'), @@ -62,11 +74,9 @@ ('Basal Dendrite 2 length', 'micron', 'basal2_L'), ('Basal Dendrite 2 diameter', 'micron', 'basal2_diam'), ('Basal Dendrite 3 length', 'micron', 'basal3_L'), - ('Basal Dendrite 3 diameter', 'micron', 'basal3_diam') + ('Basal Dendrite 3 diameter', 'micron', 'basal3_diam'), ], - - "Geometry L5": - [ + 'Geometry L5': [ ('Soma length', 'micron', 'soma_L'), ('Soma diameter', 'micron', 'soma_diam'), ('Soma capacitive density', 'F/cm2', 'soma_cm'), @@ -88,10 +98,9 @@ ('Basal Dendrite 2 length', 'micron', 'basal2_L'), ('Basal Dendrite 2 diameter', 'micron', 'basal2_diam'), ('Basal Dendrite 3 length', 'micron', 'basal3_L'), - ('Basal Dendrite 3 diameter', 'micron', 'basal3_diam') + ('Basal Dendrite 3 diameter', 'micron', 'basal3_diam'), ], - "Synapses": - [ + 'Synapses': [ ('AMPA reversal', 'mV', 'ampa_e'), ('AMPA rise time', 'ms', 'ampa_tau1'), ('AMPA decay time', 'ms', 'ampa_tau2'), @@ -103,10 +112,9 @@ ('GABAA decay time', 'ms', 'gabaa_tau2'), ('GABAB reversal', 'mV', 'gabab_e'), ('GABAB rise time', 'ms', 'gabab_tau1'), - ('GABAB decay time', 'ms', 'gabab_tau2') + ('GABAB decay time', 'ms', 'gabab_tau2'), ], - "Biophysics L2": - [ + 'Biophysics L2': [ ('Soma Kv channel density', 'S/cm2', 'soma_gkbar_hh2'), ('Soma Na channel density', 'S/cm2', 'soma_gnabar_hh2'), ('Soma leak reversal', 'mV', 'soma_el_hh2'), @@ -116,10 +124,9 @@ ('Dendrite Na channel density', 'S/cm2', 'dend_gnabar_hh2'), ('Dendrite leak reversal', 'mV', 'dend_el_hh2'), ('Dendrite leak channel density', 'S/cm2', 'dend_gl_hh2'), - ('Dendrite Km channel density', 'pS/micron2', 'dend_gbar_km') + ('Dendrite Km channel density', 'pS/micron2', 'dend_gbar_km'), ], - "Biophysics L5": - [ + 'Biophysics L5': [ ('Soma Kv channel density', 'S/cm2', 'soma_gkbar_hh2'), ('Soma Na channel density', 'S/cm2', 'soma_gnabar_hh2'), ('Soma leak reversal', 'mV', 'soma_el_hh2'), @@ -139,8 +146,8 @@ ('Dendrite KCa channel density', 'pS/micron2', 'dend_gbar_kca'), ('Dendrite Km channel density', 'pS/micron2', 'dend_gbar_km'), ('Dendrite CaT channel density', 'S/cm2', 'dend_gbar_cat'), - ('Dendrite HCN channel density', 'S/cm2', 'dend_gbar_ar') - ] + ('Dendrite HCN channel density', 'S/cm2', 'dend_gbar_ar'), + ], } @@ -154,9 +161,9 @@ def emit(self, record): new_output = { 'name': 'stdout', 'output_type': 'stream', - 'text': formatted_record + '\n' + 'text': formatted_record + '\n', } - self.out.outputs = (new_output, ) + self.out.outputs + self.out.outputs = (new_output,) + self.out.outputs class HNNGUI: @@ -237,19 +244,21 @@ class HNNGUI: in the network. """ - def __init__(self, theme_color="#802989", - total_height=800, - total_width=1300, - header_height=50, - button_height=30, - operation_box_height=60, - drive_widget_width=200, - left_sidebar_width=576, - log_window_height=150, - status_height=30, - dpi=96, - network_configuration=default_network_configuration, - ): + def __init__( + self, + theme_color='#802989', + total_height=800, + total_width=1300, + header_height=50, + button_height=30, + operation_box_height=60, + drive_widget_width=200, + left_sidebar_width=576, + log_window_height=150, + status_height=30, + dpi=96, + network_configuration=default_network_configuration, + ): # set up styling. 
self.total_height = total_height self.total_width = total_width @@ -257,63 +266,68 @@ def __init__(self, theme_color="#802989", viz_win_width = self.total_width - left_sidebar_width main_content_height = self.total_height - status_height - config_box_height = main_content_height - (log_window_height + - operation_box_height) + config_box_height = main_content_height - ( + log_window_height + operation_box_height + ) self.layout = { - "dpi": dpi, - "header_height": f"{header_height}px", - "theme_color": theme_color, - "btn": Layout(height=f"{button_height}px", width='auto'), - "run_btn": Layout(height=f"{button_height}px", width='10%'), - "btn_full_w": Layout(height=f"{button_height}px", width='100%'), - "del_fig_btn": Layout(height=f"{button_height}px", width='auto'), - "log_out": Layout(border='1px solid gray', - height=f"{log_window_height - 10}px", - overflow='auto'), - "viz_config": Layout(width='99%'), - "simulations_list": Layout(width=f'{left_sidebar_width - 50}px'), - "visualization_window": Layout( - width=f"{viz_win_width - 10}px", - height=f"{main_content_height - 10}px", + 'dpi': dpi, + 'header_height': f'{header_height}px', + 'theme_color': theme_color, + 'btn': Layout(height=f'{button_height}px', width='auto'), + 'run_btn': Layout(height=f'{button_height}px', width='10%'), + 'btn_full_w': Layout(height=f'{button_height}px', width='100%'), + 'del_fig_btn': Layout(height=f'{button_height}px', width='auto'), + 'log_out': Layout( + border='1px solid gray', + height=f'{log_window_height - 10}px', + overflow='auto', + ), + 'viz_config': Layout(width='99%'), + 'simulations_list': Layout(width=f'{left_sidebar_width - 50}px'), + 'visualization_window': Layout( + width=f'{viz_win_width - 10}px', + height=f'{main_content_height - 10}px', border='1px solid gray', - overflow='scroll'), - "visualization_output": Layout( - width=f"{viz_win_width - 50}px", - height=f"{main_content_height - 100}px", + overflow='scroll', + ), + 'visualization_output': Layout( + width=f'{viz_win_width - 50}px', + height=f'{main_content_height - 100}px', border='1px solid gray', - overflow='scroll'), - "left_sidebar": Layout(width=f"{left_sidebar_width}px", - height=f"{main_content_height}px"), - "left_tab": Layout(width=f"{left_sidebar_width}px", - height=f"{config_box_height}px"), - "operation_box": Layout(width=f"{left_sidebar_width}px", - height=f"{operation_box_height}px", - flex_wrap="wrap", - ), - "config_box": Layout(width=f"{left_sidebar_width}px", - height=f"{config_box_height - 100}px"), - "drive_widget": Layout(width="auto"), - "drive_textbox": Layout(width='270px', height='auto'), + overflow='scroll', + ), + 'left_sidebar': Layout( + width=f'{left_sidebar_width}px', height=f'{main_content_height}px' + ), + 'left_tab': Layout( + width=f'{left_sidebar_width}px', height=f'{config_box_height}px' + ), + 'operation_box': Layout( + width=f'{left_sidebar_width}px', + height=f'{operation_box_height}px', + flex_wrap='wrap', + ), + 'config_box': Layout( + width=f'{left_sidebar_width}px', height=f'{config_box_height - 100}px' + ), + 'drive_widget': Layout(width='auto'), + 'drive_textbox': Layout(width='270px', height='auto'), # simulation status related - "simulation_status_height": f"{status_height}px", - "simulation_status_common": "background:gray;padding-left:10px", - "simulation_status_running": "background:orange;padding-left:10px", - "simulation_status_failed": "background:red;padding-left:10px", - "simulation_status_finished": "background:green;padding-left:10px", + 'simulation_status_height': 
f'{status_height}px',
+            'simulation_status_common': 'background:gray;padding-left:10px',
+            'simulation_status_running': 'background:orange;padding-left:10px',
+            'simulation_status_failed': 'background:red;padding-left:10px',
+            'simulation_status_finished': 'background:green;padding-left:10px',
         }
 
         self._simulation_status_contents = {
-            "not_running":
-            f"""<div style='{self.layout['simulation_status_common']}'>Not running</div>""",
-            "running":
-            f"""<div style='{self.layout['simulation_status_running']}'>Running...</div>""",
-            "finished":
-            f"""<div style='{self.layout['simulation_status_finished']}'>Simulation finished</div>""",
-            "failed":
-            f"""<div style='{self.layout['simulation_status_failed']}'>Simulation failed</div>""",
+            'not_running': f"""<div style='{self.layout['simulation_status_common']}'>Not running</div>""",
+            'running': f"""<div style='{self.layout['simulation_status_running']}'>Running...</div>""",
+            'finished': f"""<div style='{self.layout['simulation_status_finished']}'>Simulation finished</div>""",
+            'failed': f"""<div style='{self.layout['simulation_status_failed']}'>Simulation failed</div>
""", } @@ -325,87 +339,123 @@ def __init__(self, theme_color="#802989", # Simulation parameters self.widget_tstop = BoundedFloatText( - value=170, description='tstop (ms):', min=0, max=1e6, step=1, - disabled=False) + value=170, description='tstop (ms):', min=0, max=1e6, step=1, disabled=False + ) self.widget_dt = BoundedFloatText( - value=0.025, description='dt (ms):', min=0, max=10, step=0.01, - disabled=False) - self.widget_ntrials = IntText(value=1, description='Trials:', - disabled=False) - self.widget_simulation_name = Text(value='default', - placeholder='ID of your simulation', - description='Name:', - disabled=False) - self.widget_backend_selection = Dropdown(options=[('Joblib', 'Joblib'), - ('MPI', 'MPI')], - value='Joblib', - description='Backend:') - self.widget_mpi_cmd = Text(value='mpiexec', - placeholder='Fill if applies', - description='MPI cmd:', disabled=False) - self.widget_n_jobs = BoundedIntText(value=1, min=1, - max=multiprocessing.cpu_count(), - description='Cores:', - disabled=False) + value=0.025, + description='dt (ms):', + min=0, + max=10, + step=0.01, + disabled=False, + ) + self.widget_ntrials = IntText(value=1, description='Trials:', disabled=False) + self.widget_simulation_name = Text( + value='default', + placeholder='ID of your simulation', + description='Name:', + disabled=False, + ) + self.widget_backend_selection = Dropdown( + options=[('Joblib', 'Joblib'), ('MPI', 'MPI')], + value='Joblib', + description='Backend:', + ) + self.widget_mpi_cmd = Text( + value='mpiexec', + placeholder='Fill if applies', + description='MPI cmd:', + disabled=False, + ) + self.widget_n_jobs = BoundedIntText( + value=1, + min=1, + max=multiprocessing.cpu_count(), + description='Cores:', + disabled=False, + ) self.load_data_button = FileUpload( - accept='.txt,.csv', multiple=False, + accept='.txt,.csv', + multiple=False, style={'button_color': self.layout['theme_color']}, layout=self.layout['btn'], description='Load data', - button_style='success') + button_style='success', + ) # Create save simulation widget wrapper self.save_simuation_button = self._init_html_download_button( - title='Save Simulation', mimetype='text/csv') + title='Save Simulation', mimetype='text/csv' + ) self.save_config_button = self._init_html_download_button( - title='Save Network', mimetype='application/json') + title='Save Network', mimetype='application/json' + ) - self.simulation_list_widget = Dropdown(options=[], - value=None, - description='', - layout={'width': '15%'}) + self.simulation_list_widget = Dropdown( + options=[], value=None, description='', layout={'width': '15%'} + ) # Drive selection self.widget_drive_type_selection = RadioButtons( options=['Evoked', 'Poisson', 'Rhythmic', 'Tonic'], value='Evoked', description='Drive:', disabled=False, - layout=self.layout['drive_widget']) + layout=self.layout['drive_widget'], + ) self.widget_location_selection = RadioButtons( - options=['proximal', 'distal'], value='proximal', - description='Location', disabled=False, - layout=self.layout['drive_widget']) + options=['proximal', 'distal'], + value='proximal', + description='Location', + disabled=False, + layout=self.layout['drive_widget'], + ) self.add_drive_button = create_expanded_button( - 'Add drive', 'primary', layout=self.layout['btn'], - button_color=self.layout['theme_color']) + 'Add drive', + 'primary', + layout=self.layout['btn'], + button_color=self.layout['theme_color'], + ) # Dashboard level buttons self.run_button = create_expanded_button( - 'Run', 'success', 
layout=self.layout['run_btn'], - button_color=self.layout['theme_color']) + 'Run', + 'success', + layout=self.layout['run_btn'], + button_color=self.layout['theme_color'], + ) self.load_connectivity_button = FileUpload( - accept='.json', multiple=False, + accept='.json', + multiple=False, style={'button_color': self.layout['theme_color']}, description='Load local network connectivity', - layout=self.layout['btn_full_w'], button_style='success') + layout=self.layout['btn_full_w'], + button_style='success', + ) self.load_drives_button = FileUpload( - accept='.json', multiple=False, + accept='.json', + multiple=False, style={'button_color': self.layout['theme_color']}, - description='Load external drives', layout=self.layout['btn'], - button_style='success') + description='Load external drives', + layout=self.layout['btn'], + button_style='success', + ) self.delete_drive_button = create_expanded_button( - 'Delete drives', 'success', layout=self.layout['btn'], - button_color=self.layout['theme_color']) + 'Delete drives', + 'success', + layout=self.layout['btn'], + button_color=self.layout['theme_color'], + ) self.cell_type_radio_buttons = RadioButtons( - options=['L2/3 Pyramidal', 'L5 Pyramidal'], - description='Cell type:') + options=['L2/3 Pyramidal', 'L5 Pyramidal'], description='Cell type:' + ) self.cell_layer_radio_buttons = RadioButtons( options=['Geometry', 'Synapses', 'Biophysics'], - description='Cell Properties:') + description='Cell Properties:', + ) # Plotting window @@ -430,37 +480,40 @@ def __init__(self, theme_color="#802989", def get_cell_parameters_dict(self): """Returns the number of elements in the - cell_parameters_dict dictionary. - This is for testing purposes """ + cell_parameters_dict dictionary. + This is for testing purposes""" return cell_parameters_dict def _init_html_download_button(self, title, mimetype): - b64 = base64.b64encode("".encode()) + b64 = base64.b64encode(''.encode()) payload = b64.decode() # Initialliting HTML code for download button - self.html_download_button = ''' + self.html_download_button = """ - ''' + """ # Create widget wrapper - return ( - HTML(self.html_download_button. 
- format(payload=payload, - filename={""}, - is_disabled="disabled", - btn_height=self.layout['run_btn'].height, - color_theme=self.layout['theme_color'], - title=title, - mimetype=mimetype))) + return HTML( + self.html_download_button.format( + payload=payload, + filename={''}, + is_disabled='disabled', + btn_height=self.layout['run_btn'].height, + color_theme=self.layout['theme_color'], + title=title, + mimetype=mimetype, + ) + ) def add_logging_window_logger(self): handler = _OutputWidgetHandler(self._log_out) handler.setFormatter( - logging.Formatter('%(asctime)s - [%(levelname)s] %(message)s')) + logging.Formatter('%(asctime)s - [%(levelname)s] %(message)s') + ) logger.addHandler(handler) def _init_ui_components(self): @@ -484,37 +537,44 @@ def _init_ui_components(self): # static parts # Running status self._simulation_status_bar = HTML( - value=self._simulation_status_contents['not_running']) + value=self._simulation_status_contents['not_running'] + ) self._log_window = HBox([self._log_out], layout=self.layout['log_out']) self._operation_buttons = HBox( - [self.run_button, self.load_data_button, - self.save_config_button, - self.save_simuation_button, - self.simulation_list_widget], - layout=self.layout['operation_box']) + [ + self.run_button, + self.load_data_button, + self.save_config_button, + self.save_simuation_button, + self.simulation_list_widget, + ], + layout=self.layout['operation_box'], + ) # title - self._header = HTML(value=f"""
<div style='background:{self.layout['theme_color']};
-            text-align:center;color:white;'>
-            HUMAN NEOCORTICAL NEUROSOLVER</div>
""") + HUMAN NEOCORTICAL NEUROSOLVER""" + ) @property def analysis_config(self): """Provides everything viz window needs except for the data.""" return { - "viz_style": self.layout['visualization_output'], + 'viz_style': self.layout['visualization_output'], # widgets - "plot_outputs": self.plot_outputs_dict, - "plot_dropdowns": self.plot_dropdown_types_dict, - "plot_sim_selections": self.plot_sim_selections_dict, - "current_sim_name": self.widget_simulation_name.value, + 'plot_outputs': self.plot_outputs_dict, + 'plot_dropdowns': self.plot_dropdown_types_dict, + 'plot_sim_selections': self.plot_sim_selections_dict, + 'current_sim_name': self.widget_simulation_name.value, } @property def data(self): """Provides easy access to simulation-related data.""" - return {"simulation_data": self.simulation_data} + return {'simulation_data': self.simulation_data} @staticmethod def load_parameters(params_fname): @@ -526,11 +586,14 @@ def load_parameters(params_fname): def _link_callbacks(self): """Link callbacks to UI components.""" + def _handle_backend_change(backend_type): - return handle_backend_change(backend_type.new, - self._backend_config_out, - self.widget_mpi_cmd, - self.widget_n_jobs) + return handle_backend_change( + backend_type.new, + self._backend_config_out, + self.widget_mpi_cmd, + self.widget_n_jobs, + ) def _add_drive_button_clicked(b): return self.add_drive_widget( @@ -548,95 +611,112 @@ def _delete_drives_clicked(b): def _on_upload_connectivity(change): new_params = self.on_upload_params_change( - change, self.layout['drive_textbox'], load_type="connectivity" + change, self.layout['drive_textbox'], load_type='connectivity' ) self.params = new_params def _on_upload_drives(change): _ = self.on_upload_params_change( - change, self.layout['drive_textbox'], load_type="drives" + change, self.layout['drive_textbox'], load_type='drives' ) def _on_upload_data(change): - return on_upload_data_change(change, self.data, self.viz_manager, - self._log_out) + return on_upload_data_change( + change, self.data, self.viz_manager, self._log_out + ) def _run_button_clicked(b): return run_button_clicked( - self.widget_simulation_name, self._log_out, self.drive_widgets, - self.data, self.widget_dt, self.widget_tstop, - self.widget_ntrials, self.widget_backend_selection, - self.widget_mpi_cmd, self.widget_n_jobs, self.params, - self._simulation_status_bar, self._simulation_status_contents, - self.connectivity_widgets, self.viz_manager, - self.simulation_list_widget, self.cell_pameters_widgets) + self.widget_simulation_name, + self._log_out, + self.drive_widgets, + self.data, + self.widget_dt, + self.widget_tstop, + self.widget_ntrials, + self.widget_backend_selection, + self.widget_mpi_cmd, + self.widget_n_jobs, + self.params, + self._simulation_status_bar, + self._simulation_status_contents, + self.connectivity_widgets, + self.viz_manager, + self.simulation_list_widget, + self.cell_pameters_widgets, + ) def _simulation_list_change(value): # Simulation Data - _simulation_data, file_extension = ( - _serialize_simulation(self._log_out, - self.data, - self.simulation_list_widget)) + _simulation_data, file_extension = _serialize_simulation( + self._log_out, self.data, self.simulation_list_widget + ) - result_file = f"{value.new}{file_extension}" - if file_extension == ".csv": + result_file = f'{value.new}{file_extension}' + if file_extension == '.csv': b64 = base64.b64encode(_simulation_data.encode()) else: b64 = base64.b64encode(_simulation_data) payload = b64.decode() - self.save_simuation_button.value = 
( - self.html_download_button.format( - payload=payload, filename=result_file, - is_disabled="", btn_height=self.layout['run_btn'].height, - color_theme=self.layout['theme_color'], - title='Save Simulation', mimetype='text/csv')) + self.save_simuation_button.value = self.html_download_button.format( + payload=payload, + filename=result_file, + is_disabled='', + btn_height=self.layout['run_btn'].height, + color_theme=self.layout['theme_color'], + title='Save Simulation', + mimetype='text/csv', + ) # Network Configuration - network_config = _serialize_config(self._log_out, - self.data, - self.simulation_list_widget) + network_config = _serialize_config( + self._log_out, self.data, self.simulation_list_widget + ) b64_net = base64.b64encode(network_config.encode()) - self.save_config_button.value = ( - self.html_download_button.format( - payload=b64_net.decode(), - filename=f"{value.new}.json", - is_disabled="", - btn_height=self.layout['run_btn'].height, - color_theme=self.layout['theme_color'], - title='Save Network', mimetype='application/json')) + self.save_config_button.value = self.html_download_button.format( + payload=b64_net.decode(), + filename=f'{value.new}.json', + is_disabled='', + btn_height=self.layout['run_btn'].height, + color_theme=self.layout['theme_color'], + title='Save Network', + mimetype='application/json', + ) def _driver_type_change(value): self.widget_location_selection.disabled = ( - True if value.new == "Tonic" else False) + True if value.new == 'Tonic' else False + ) def _cell_type_radio_change(value): - _update_cell_params_vbox(self._cell_params_out, - self.cell_pameters_widgets, - value.new, - self.cell_layer_radio_buttons.value) + _update_cell_params_vbox( + self._cell_params_out, + self.cell_pameters_widgets, + value.new, + self.cell_layer_radio_buttons.value, + ) def _cell_layer_radio_change(value): - _update_cell_params_vbox(self._cell_params_out, - self.cell_pameters_widgets, - self.cell_type_radio_buttons.value, - value.new) + _update_cell_params_vbox( + self._cell_params_out, + self.cell_pameters_widgets, + self.cell_type_radio_buttons.value, + value.new, + ) self.widget_backend_selection.observe(_handle_backend_change, 'value') self.add_drive_button.on_click(_add_drive_button_clicked) self.delete_drive_button.on_click(_delete_drives_clicked) - self.load_connectivity_button.observe(_on_upload_connectivity, - names='value') + self.load_connectivity_button.observe(_on_upload_connectivity, names='value') self.load_drives_button.observe(_on_upload_drives, names='value') self.run_button.on_click(_run_button_clicked) self.load_data_button.observe(_on_upload_data, names='value') self.simulation_list_widget.observe(_simulation_list_change, 'value') self.widget_drive_type_selection.observe(_driver_type_change, 'value') - self.cell_type_radio_buttons.observe(_cell_type_radio_change, - 'value') - self.cell_layer_radio_buttons.observe(_cell_layer_radio_change, - 'value') + self.cell_type_radio_buttons.observe(_cell_type_radio_change, 'value') + self.cell_layer_radio_buttons.observe(_cell_layer_radio_change, 'value') def _delete_single_drive(self, b): index = self.drive_accordion.selected_index @@ -666,74 +746,104 @@ def compose(self, return_layout=True): If the method returns the layout object which can be rendered by IPython.display.display() method. 
""" - simulation_box = VBox([ - VBox([ - self.widget_simulation_name, self.widget_tstop, self.widget_dt, - self.widget_ntrials, self.widget_backend_selection, - self._backend_config_out]), - ], layout=self.layout['config_box']) + simulation_box = VBox( + [ + VBox( + [ + self.widget_simulation_name, + self.widget_tstop, + self.widget_dt, + self.widget_ntrials, + self.widget_backend_selection, + self._backend_config_out, + ] + ), + ], + layout=self.layout['config_box'], + ) connectivity_configuration = Tab() - connectivity_box = VBox([ - HBox([self.load_connectivity_button, ]), - self._connectivity_out, - ]) - - cell_parameters = VBox([ - HBox([self.cell_type_radio_buttons, - self.cell_layer_radio_buttons]), - self._cell_params_out - ]) - - connectivity_configuration.children = [connectivity_box, - cell_parameters] - connectivity_configuration.titles = ['Connectivity', - 'Cell parameters'] - - drive_selections = VBox([ - self.add_drive_button, self.widget_drive_type_selection, - self.widget_location_selection], - layout=Layout(flex="1")) - - drives_options = VBox([ - HBox([ - VBox([self.load_drives_button, self.delete_drive_button], - layout=Layout(flex="1")), - drive_selections, - ]), self._drives_out - ]) + connectivity_box = VBox( + [ + HBox( + [ + self.load_connectivity_button, + ] + ), + self._connectivity_out, + ] + ) + + cell_parameters = VBox( + [ + HBox([self.cell_type_radio_buttons, self.cell_layer_radio_buttons]), + self._cell_params_out, + ] + ) + + connectivity_configuration.children = [connectivity_box, cell_parameters] + connectivity_configuration.titles = ['Connectivity', 'Cell parameters'] + + drive_selections = VBox( + [ + self.add_drive_button, + self.widget_drive_type_selection, + self.widget_location_selection, + ], + layout=Layout(flex='1'), + ) + + drives_options = VBox( + [ + HBox( + [ + VBox( + [self.load_drives_button, self.delete_drive_button], + layout=Layout(flex='1'), + ), + drive_selections, + ] + ), + self._drives_out, + ] + ) config_panel, figs_output = self.viz_manager.compose() # Tabs for left pane left_tab = Tab() left_tab.children = [ - simulation_box, connectivity_configuration, drives_options, + simulation_box, + connectivity_configuration, + drives_options, config_panel, ] - titles = ('Simulation', 'Network', 'External drives', - 'Visualization') + titles = ('Simulation', 'Network', 'External drives', 'Visualization') for idx, title in enumerate(titles): left_tab.set_title(idx, title) self.app_layout = AppLayout( header=self._header, - left_sidebar=VBox([ - VBox([left_tab], layout=self.layout['left_tab']), - self._operation_buttons, - self._log_window, - ], layout=self.layout['left_sidebar']), + left_sidebar=VBox( + [ + VBox([left_tab], layout=self.layout['left_tab']), + self._operation_buttons, + self._log_window, + ], + layout=self.layout['left_sidebar'], + ), right_sidebar=figs_output, footer=self._simulation_status_bar, pane_widths=[ - self.layout['left_sidebar'].width, '0px', - self.layout['visualization_window'].width + self.layout['left_sidebar'].width, + '0px', + self.layout['visualization_window'].width, ], pane_heights=[ self.layout['header_height'], self.layout['visualization_window'].height, - self.layout['simulation_status_height'] + self.layout['simulation_status_height'], ], ) @@ -776,7 +886,7 @@ def capture(self, width=None, height=None, extra_margin=100, render=True): height = self.total_height + extra_margin content = urllib.parse.quote(file.getvalue().encode('utf8')) - data_url = f"data:text/html,{content}" + data_url = 
f'data:text/html,{content}' screenshot = IFrame(data_url, width=width, height=height) if render: display(screenshot) @@ -867,10 +977,12 @@ def _simulate_left_tab_click(self, tab_title): # Simulate the user clicking on the tab left_tab.selected_index = left_tab.titles.index(tab_title) else: - raise ValueError("Tab title does not exist.") + raise ValueError('Tab title does not exist.') - def _simulate_make_figure(self,): - self._simulate_left_tab_click("Visualization") + def _simulate_make_figure( + self, + ): + self._simulate_left_tab_click('Visualization') self.viz_manager.make_fig_button.click() def _simulate_viz_action(self, action_name, *args, **kwargs): @@ -886,8 +998,8 @@ def _simulate_viz_action(self, action_name, *args, **kwargs): kwargs: dict Optional keyword parameters passed to the called method. """ - self._simulate_left_tab_click("Visualization") - action = getattr(self.viz_manager, f"_simulate_{action_name}") + self._simulate_left_tab_click('Visualization') + action = getattr(self.viz_manager, f'_simulate_{action_name}') action(*args, **kwargs) def _simulate_delete_single_drive(self, idx=0): @@ -898,61 +1010,80 @@ def load_drive_and_connectivity(self): """Add drive and connectivity ipywidgets from params.""" with self._log_out: # Add connectivity - add_connectivity_tab(self.params, - self._connectivity_out, - self.connectivity_widgets, - self._cell_params_out, - self.cell_pameters_widgets, - self.cell_layer_radio_buttons, - self.cell_type_radio_buttons, - self.layout) + add_connectivity_tab( + self.params, + self._connectivity_out, + self.connectivity_widgets, + self._cell_params_out, + self.cell_pameters_widgets, + self.cell_layer_radio_buttons, + self.cell_type_radio_buttons, + self.layout, + ) # Add drives self.add_drive_tab(self.params) - def add_drive_widget(self, - drive_type, - location, - prespecified_drive_name=None, - prespecified_drive_data=None, - prespecified_weights_ampa=None, - prespecified_weights_nmda=None, - prespecified_delays=None, - prespecified_n_drive_cells=None, - prespecified_cell_specific=None, - render=True, - expand_last_drive=True, - event_seed=14, ): + def add_drive_widget( + self, + drive_type, + location, + prespecified_drive_name=None, + prespecified_drive_data=None, + prespecified_weights_ampa=None, + prespecified_weights_nmda=None, + prespecified_delays=None, + prespecified_n_drive_cells=None, + prespecified_cell_specific=None, + render=True, + expand_last_drive=True, + event_seed=14, + ): """Add a widget for a new drive.""" # Check only adds 1 tonic input widget - if (drive_type == "Tonic" and - not _is_valid_add_tonic_input(self.drive_widgets)): + if drive_type == 'Tonic' and not _is_valid_add_tonic_input(self.drive_widgets): return # Build drive widget objects - name = (drive_type + str(len(self.drive_boxes)) - if not prespecified_drive_name - else prespecified_drive_name) + name = ( + drive_type + str(len(self.drive_boxes)) + if not prespecified_drive_name + else prespecified_drive_name + ) style = {'description_width': '125px'} - prespecified_drive_data = ({} if not prespecified_drive_data - else prespecified_drive_data) - prespecified_drive_data.update({"seedcore": max(event_seed, 2)}) + prespecified_drive_data = ( + {} if not prespecified_drive_data else prespecified_drive_data + ) + prespecified_drive_data.update({'seedcore': max(event_seed, 2)}) drive, drive_box = _build_drive_objects( - drive_type, name, self.widget_tstop, - self.layout['drive_textbox'], style, location, - prespecified_drive_data, prespecified_weights_ampa, - 
prespecified_weights_nmda, prespecified_delays, - prespecified_n_drive_cells, prespecified_cell_specific + drive_type, + name, + self.widget_tstop, + self.layout['drive_textbox'], + style, + location, + prespecified_drive_data, + prespecified_weights_ampa, + prespecified_weights_nmda, + prespecified_delays, + prespecified_n_drive_cells, + prespecified_cell_specific, ) # Add delete button and assign its call-back function - delete_button = Button(description='Delete', button_style='danger', - icon='close', layout=self.layout['del_fig_btn']) + delete_button = Button( + description='Delete', + button_style='danger', + icon='close', + layout=self.layout['del_fig_btn'], + ) delete_button.on_click(self._delete_single_drive) - drive_box.children += (HTML(value="
<br>"), # Adds blank space
-                               delete_button)
+        drive_box.children += (
+            HTML(value='<br>
'), # Adds blank space + delete_button, + ) self.drive_boxes.append(drive_box) self.drive_widgets.append(drive) @@ -1007,19 +1138,20 @@ def add_drive_tab(self, params): ) should_render = idx == (len(drive_names) - 1) - self.add_drive_widget(drive_type=specs['type'].capitalize(), - location=specs['location'], - prespecified_drive_name=drive_name, - render=should_render, - expand_last_drive=False, - **kwargs) + self.add_drive_widget( + drive_type=specs['type'].capitalize(), + location=specs['location'], + prespecified_drive_name=drive_name, + render=should_render, + expand_last_drive=False, + **kwargs, + ) def on_upload_params_change(self, change, layout, load_type): - if len(change['owner'].value) == 0: return param_dict = change['new'][0] - file_contents = codecs.decode(param_dict['content'], encoding="utf-8") + file_contents = codecs.decode(param_dict['content'], encoding='utf-8') with self._log_out: params = json.loads(file_contents) @@ -1033,10 +1165,15 @@ def on_upload_params_change(self, change, layout, load_type): # init network, add drives & connectivity if load_type == 'connectivity': add_connectivity_tab( - params, self._connectivity_out, self.connectivity_widgets, - self._cell_params_out, self.cell_pameters_widgets, + params, + self._connectivity_out, + self.connectivity_widgets, + self._cell_params_out, + self.cell_pameters_widgets, self.cell_layer_radio_buttons, - self.cell_type_radio_buttons, layout) + self.cell_type_radio_buttons, + layout, + ) elif load_type == 'drives': self.add_drive_tab(params) else: @@ -1054,37 +1191,41 @@ def _prepare_upload_file_from_local(path): content = memoryview(file.read()) last_modified = datetime.fromtimestamp(path.stat().st_mtime) - upload_structure = [{ - 'name': path.name, - 'type': mimetypes.guess_type(path)[0], - 'size': path.stat().st_size, - 'content': content, - 'last_modified': last_modified - }] + upload_structure = [ + { + 'name': path.name, + 'type': mimetypes.guess_type(path)[0], + 'size': path.stat().st_size, + 'content': content, + 'last_modified': last_modified, + } + ] return upload_structure def _prepare_upload_file_from_url(file_url): - file_name = file_url.split("/")[-1] + file_name = file_url.split('/')[-1] data = urllib.request.urlopen(file_url) content = bytearray() for line in data: content.extend(line) - upload_structure = [{ - 'name': file_name, - 'type': mimetypes.guess_type(file_url)[0], - 'size': len(content), - 'content': memoryview(content), - 'last_modified': datetime.now() - }] + upload_structure = [ + { + 'name': file_name, + 'type': mimetypes.guess_type(file_url)[0], + 'size': len(content), + 'content': memoryview(content), + 'last_modified': datetime.now(), + } + ] return upload_structure def _prepare_upload_file(path): - """ Simulates output of the FileUpload widget for testing. + """Simulates output of the FileUpload widget for testing. Unit tests for the GUI simulate user upload of files. File source can either be local or from a URL. This function returns the data structure @@ -1100,7 +1241,7 @@ def _prepare_upload_file(path): def _update_nested_dict(original, new, skip_none=True): - """ Updates dictionary values from another dictionary + """Updates dictionary values from another dictionary Will update nested dictionaries in the structure. 
New items from the update dictionary are added and omitted items are retained from the @@ -1124,9 +1265,11 @@ def _update_nested_dict(original, new, skip_none=True): """ updated = original.copy() for key, value in new.items(): - if (isinstance(value, dict) and - key in updated and - isinstance(updated[key], dict)): + if ( + isinstance(value, dict) + and key in updated + and isinstance(updated[key], dict) + ): updated[key] = _update_nested_dict(updated[key], value, skip_none) elif (value is not None) or (not skip_none): updated[key] = value @@ -1136,11 +1279,16 @@ def _update_nested_dict(original, new, skip_none=True): return updated -def create_expanded_button(description, button_style, layout, disabled=False, - button_color="#8A2BE2"): - return Button(description=description, button_style=button_style, - layout=layout, style={'button_color': button_color}, - disabled=disabled) +def create_expanded_button( + description, button_style, layout, disabled=False, button_color='#8A2BE2' +): + return Button( + description=description, + button_style=button_style, + layout=layout, + style={'button_color': button_color}, + disabled=disabled, + ) def _get_connectivity_widgets(conn_data): @@ -1151,21 +1299,32 @@ def _get_connectivity_widgets(conn_data): sliders = list() for receptor_name in conn_data.keys(): w_text_input = BoundedFloatText( - value=conn_data[receptor_name]['weight'], disabled=False, - continuous_update=False, min=0, max=1e6, step=0.01, - description="weight", style=style) + value=conn_data[receptor_name]['weight'], + disabled=False, + continuous_update=False, + min=0, + max=1e6, + step=0.01, + description='weight', + style=style, + ) - conn_widget = VBox([ - HTML(value=f"""
-            Receptor: {conn_data[receptor_name]['receptor']}
-            """),
-            w_text_input, HTML(value="")
-        ])
+        conn_widget = VBox(
+            [
+                HTML(
+                    value=f"""
+                    Receptor: {conn_data[receptor_name]['receptor']}
+                    """
+                ),
+                w_text_input,
+                HTML(value="
"), + ] + ) conn_widget._belongsto = { - "receptor": conn_data[receptor_name]['receptor'], - "location": conn_data[receptor_name]['location'], - "src_gids": conn_data[receptor_name]['src_gids'], - "target_gids": conn_data[receptor_name]['target_gids'], + 'receptor': conn_data[receptor_name]['receptor'], + 'location': conn_data[receptor_name]['location'], + 'src_gids': conn_data[receptor_name]['src_gids'], + 'target_gids': conn_data[receptor_name]['target_gids'], } sliders.append(conn_widget) @@ -1175,22 +1334,22 @@ def _get_connectivity_widgets(conn_data): def _get_drive_weight_widgets(layout, style, location, data=None): default_data = { 'weights_ampa': { - 'L5_pyramidal': 0., - 'L2_pyramidal': 0., - 'L5_basket': 0., - 'L2_basket': 0. + 'L5_pyramidal': 0.0, + 'L2_pyramidal': 0.0, + 'L5_basket': 0.0, + 'L2_basket': 0.0, }, 'weights_nmda': { - 'L5_pyramidal': 0., - 'L2_pyramidal': 0., - 'L5_basket': 0., - 'L2_basket': 0. + 'L5_pyramidal': 0.0, + 'L2_pyramidal': 0.0, + 'L5_basket': 0.0, + 'L2_basket': 0.0, }, 'delays': { 'L5_pyramidal': 0.1, 'L2_pyramidal': 0.1, 'L5_basket': 0.1, - 'L2_basket': 0.1 + 'L2_basket': 0.1, }, } if isinstance(data, dict): @@ -1198,32 +1357,49 @@ def _get_drive_weight_widgets(layout, style, location, data=None): kwargs = dict(layout=layout, style=style) cell_types = ['L5_pyramidal', 'L2_pyramidal', 'L5_basket', 'L2_basket'] - if location == "distal": + if location == 'distal': cell_types.remove('L5_basket') weights_ampa, weights_nmda, delays = dict(), dict(), dict() for cell_type in cell_types: weights_ampa[f'{cell_type}'] = BoundedFloatText( value=default_data['weights_ampa'][cell_type], - description=f'{cell_type}:', min=0, max=1e6, step=0.01, **kwargs) + description=f'{cell_type}:', + min=0, + max=1e6, + step=0.01, + **kwargs, + ) weights_nmda[f'{cell_type}'] = BoundedFloatText( value=default_data['weights_nmda'][cell_type], - description=f'{cell_type}:', min=0, max=1e6, step=0.01, **kwargs) + description=f'{cell_type}:', + min=0, + max=1e6, + step=0.01, + **kwargs, + ) delays[f'{cell_type}'] = BoundedFloatText( value=default_data['delays'][cell_type], - description=f'{cell_type}:', min=0, max=1e6, step=0.1, **kwargs) + description=f'{cell_type}:', + min=0, + max=1e6, + step=0.1, + **kwargs, + ) widgets_dict = { 'weights_ampa': weights_ampa, 'weights_nmda': weights_nmda, - 'delays': delays + 'delays': delays, } - widgets_list = ([HTML(value="AMPA weights")] + - list(weights_ampa.values()) + - [HTML(value="NMDA weights")] + - list(weights_nmda.values()) + - [HTML(value="Synaptic delays")] + - list(delays.values())) + widgets_list = ( + [HTML(value='AMPA weights')] + + list(weights_ampa.values()) + + [HTML(value='NMDA weights')] + + list(weights_nmda.values()) + + [HTML(value='Synaptic delays')] + + list(delays.values()) + ) return widgets_list, widgets_dict @@ -1234,14 +1410,22 @@ def _cell_spec_change(change, widget): widget.disabled = False -def _get_rhythmic_widget(name, tstop_widget, layout, style, location, - data={}, weights_ampa=None, - weights_nmda=None, delays=None, - n_drive_cells=None, cell_specific=None - ): +def _get_rhythmic_widget( + name, + tstop_widget, + layout, + style, + location, + data={}, + weights_ampa=None, + weights_nmda=None, + delays=None, + n_drive_cells=None, + cell_specific=None, +): default_data = { - 'tstart': 0., - 'tstart_std': 0., + 'tstart': 0.0, + 'tstart_std': 0.0, 'tstop': tstop_widget.value, 'burst_rate': 7.5, 'burst_std': 0, @@ -1250,17 +1434,24 @@ def _get_rhythmic_widget(name, tstop_widget, layout, style, location, 
'cell_specific': False, 'seedcore': 14, } - data.update({'n_drive_cells': n_drive_cells, - 'cell_specific': cell_specific}) + data.update({'n_drive_cells': n_drive_cells, 'cell_specific': cell_specific}) default_data = _update_nested_dict(default_data, data) kwargs = dict(layout=layout, style=style) tstart = BoundedFloatText( - value=default_data['tstart'], description='Start time (ms)', - min=0, max=1e6, **kwargs) + value=default_data['tstart'], + description='Start time (ms)', + min=0, + max=1e6, + **kwargs, + ) tstart_std = BoundedFloatText( - value=default_data['tstart_std'], description='Start time dev (ms)', - min=0, max=1e6, **kwargs) + value=default_data['tstart_std'], + description='Start time dev (ms)', + min=0, + max=1e6, + **kwargs, + ) tstop = BoundedFloatText( value=default_data['tstop'], description='Stop time (ms)', @@ -1268,24 +1459,36 @@ def _get_rhythmic_widget(name, tstop_widget, layout, style, location, **kwargs, ) burst_rate = BoundedFloatText( - value=default_data['burst_rate'], description='Burst rate (Hz)', - min=0, max=1e6, **kwargs) + value=default_data['burst_rate'], + description='Burst rate (Hz)', + min=0, + max=1e6, + **kwargs, + ) burst_std = BoundedFloatText( - value=default_data['burst_std'], description='Burst std dev (Hz)', - min=0, max=1e6, **kwargs) + value=default_data['burst_std'], + description='Burst std dev (Hz)', + min=0, + max=1e6, + **kwargs, + ) numspikes = BoundedIntText( - value=default_data['numspikes'], description='No. Spikes:', min=0, - max=int(1e6), **kwargs) - n_drive_cells = IntText(value=default_data['n_drive_cells'], - description='No. Drive Cells:', - disabled=default_data['cell_specific'], - **kwargs) - cell_specific = Checkbox(value=default_data['cell_specific'], - description='Cell-Specific', - **kwargs) - seedcore = IntText(value=default_data['seedcore'], - description='Seed', - **kwargs) + value=default_data['numspikes'], + description='No. Spikes:', + min=0, + max=int(1e6), + **kwargs, + ) + n_drive_cells = IntText( + value=default_data['n_drive_cells'], + description='No. 
Drive Cells:', + disabled=default_data['cell_specific'], + **kwargs, + ) + cell_specific = Checkbox( + value=default_data['cell_specific'], description='Cell-Specific', **kwargs + ) + seedcore = IntText(value=default_data['seedcore'], description='Seed', **kwargs) widgets_list, widgets_dict = _get_drive_weight_widgets( layout, @@ -1299,35 +1502,56 @@ def _get_rhythmic_widget(name, tstop_widget, layout, style, location, ) # Disable n_drive_cells widget based on cell_specific checkbox - cell_specific.observe(partial(_cell_spec_change, widget=n_drive_cells), - names='value') - - drive_box = VBox([tstart, tstart_std, tstop, - burst_rate, burst_std, numspikes, - n_drive_cells, cell_specific, - seedcore] + widgets_list) - - drive = dict(type='Rhythmic', - name=name, - tstart=tstart, - tstart_std=tstart_std, - burst_rate=burst_rate, - burst_std=burst_std, - numspikes=numspikes, - seedcore=seedcore, - location=location, - tstop=tstop, - n_drive_cells=n_drive_cells, - is_cell_specific=cell_specific, - ) + cell_specific.observe( + partial(_cell_spec_change, widget=n_drive_cells), names='value' + ) + + drive_box = VBox( + [ + tstart, + tstart_std, + tstop, + burst_rate, + burst_std, + numspikes, + n_drive_cells, + cell_specific, + seedcore, + ] + + widgets_list + ) + + drive = dict( + type='Rhythmic', + name=name, + tstart=tstart, + tstart_std=tstart_std, + burst_rate=burst_rate, + burst_std=burst_std, + numspikes=numspikes, + seedcore=seedcore, + location=location, + tstop=tstop, + n_drive_cells=n_drive_cells, + is_cell_specific=cell_specific, + ) drive.update(widgets_dict) return drive, drive_box -def _get_poisson_widget(name, tstop_widget, layout, style, location, data={}, - weights_ampa=None, weights_nmda=None, - delays=None, n_drive_cells=None, - cell_specific=None): +def _get_poisson_widget( + name, + tstop_widget, + layout, + style, + location, + data={}, + weights_ampa=None, + weights_nmda=None, + delays=None, + n_drive_cells=None, + cell_specific=None, +): default_data = { 'tstart': 0.0, 'tstop': tstop_widget.value, @@ -1335,19 +1559,23 @@ def _get_poisson_widget(name, tstop_widget, layout, style, location, data={}, 'cell_specific': True, 'seedcore': 14, 'rate_constant': { - 'L2_pyramidal': 40., - 'L5_pyramidal': 40., - 'L2_basket': 40., - 'L5_basket': 40., - } + 'L2_pyramidal': 40.0, + 'L5_pyramidal': 40.0, + 'L2_basket': 40.0, + 'L5_basket': 40.0, + }, } - data.update({'n_drive_cells': n_drive_cells, - 'cell_specific': cell_specific}) + data.update({'n_drive_cells': n_drive_cells, 'cell_specific': cell_specific}) default_data = _update_nested_dict(default_data, data) tstart = BoundedFloatText( - value=default_data['tstart'], description='Start time (ms)', - min=0, max=1e6, layout=layout, style=style) + value=default_data['tstart'], + description='Start time (ms)', + min=0, + max=1e6, + layout=layout, + style=style, + ) tstop = BoundedFloatText( value=default_data['tstop'], max=tstop_widget.value, @@ -1355,29 +1583,35 @@ def _get_poisson_widget(name, tstop_widget, layout, style, location, data={}, layout=layout, style=style, ) - n_drive_cells = IntText(value=default_data['n_drive_cells'], - description='No. 
Drive Cells:', - disabled=default_data['cell_specific'], - layout=layout, - style=style - ) - cell_specific = Checkbox(value=default_data['cell_specific'], - description='Cell-Specific', - layout=layout, - style=style - ) - seedcore = IntText(value=default_data['seedcore'], - description='Seed', - layout=layout, - style=style) + n_drive_cells = IntText( + value=default_data['n_drive_cells'], + description='No. Drive Cells:', + disabled=default_data['cell_specific'], + layout=layout, + style=style, + ) + cell_specific = Checkbox( + value=default_data['cell_specific'], + description='Cell-Specific', + layout=layout, + style=style, + ) + seedcore = IntText( + value=default_data['seedcore'], description='Seed', layout=layout, style=style + ) cell_types = ['L5_pyramidal', 'L2_pyramidal', 'L5_basket', 'L2_basket'] rate_constant = dict() for cell_type in cell_types: rate_constant[f'{cell_type}'] = BoundedFloatText( value=default_data['rate_constant'][cell_type], - description=f'{cell_type}:', min=0, max=1e6, step=0.01, - layout=layout, style=style) + description=f'{cell_type}:', + min=0, + max=1e6, + step=0.01, + layout=layout, + style=style, + ) widgets_list, widgets_dict = _get_drive_weight_widgets( layout, @@ -1390,15 +1624,19 @@ def _get_poisson_widget(name, tstop_widget, layout, style, location, data={}, }, ) widgets_dict.update({'rate_constant': rate_constant}) - widgets_list.extend([HTML(value="Rate constants")] + - list(widgets_dict['rate_constant'].values())) + widgets_list.extend( + [HTML(value='Rate constants')] + + list(widgets_dict['rate_constant'].values()) + ) # Disable n_drive_cells widget based on cell_specific checkbox - cell_specific.observe(partial(_cell_spec_change, widget=n_drive_cells), - names='value') + cell_specific.observe( + partial(_cell_spec_change, widget=n_drive_cells), names='value' + ) - drive_box = VBox([tstart, tstop, n_drive_cells, - cell_specific, seedcore] + widgets_list) + drive_box = VBox( + [tstart, tstop, n_drive_cells, cell_specific, seedcore] + widgets_list + ) drive = dict( type='Poisson', name=name, @@ -1414,9 +1652,18 @@ def _get_poisson_widget(name, tstop_widget, layout, style, location, data={}, return drive, drive_box -def _get_evoked_widget(name, layout, style, location, data={}, - weights_ampa=None, weights_nmda=None, - delays=None, n_drive_cells=None, cell_specific=None): +def _get_evoked_widget( + name, + layout, + style, + location, + data={}, + weights_ampa=None, + weights_nmda=None, + delays=None, + n_drive_cells=None, + cell_specific=None, +): default_data = { 'mu': 0, 'sigma': 1, @@ -1425,30 +1672,39 @@ def _get_evoked_widget(name, layout, style, location, data={}, 'cell_specific': True, 'seedcore': 14, } - data.update({'n_drive_cells': n_drive_cells, - 'cell_specific': cell_specific}) + data.update({'n_drive_cells': n_drive_cells, 'cell_specific': cell_specific}) default_data = _update_nested_dict(default_data, data) kwargs = dict(layout=layout, style=style) mu = BoundedFloatText( - value=default_data['mu'], description='Mean time:', min=0, max=1e6, - step=0.01, **kwargs) + value=default_data['mu'], + description='Mean time:', + min=0, + max=1e6, + step=0.01, + **kwargs, + ) sigma = BoundedFloatText( - value=default_data['sigma'], description='Std dev time:', min=0, - max=1e6, step=0.01, **kwargs) - numspikes = IntText(value=default_data['numspikes'], - description='No. Spikes:', - **kwargs) - n_drive_cells = IntText(value=default_data['n_drive_cells'], - description='No. 
Drive Cells:', - disabled=default_data['cell_specific'], - **kwargs) - cell_specific = Checkbox(value=default_data['cell_specific'], - description='Cell-Specific', - **kwargs) - seedcore = IntText(value=default_data['seedcore'], - description='Seed: ', - **kwargs) + value=default_data['sigma'], + description='Std dev time:', + min=0, + max=1e6, + step=0.01, + **kwargs, + ) + numspikes = IntText( + value=default_data['numspikes'], description='No. Spikes:', **kwargs + ) + n_drive_cells = IntText( + value=default_data['n_drive_cells'], + description='No. Drive Cells:', + disabled=default_data['cell_specific'], + **kwargs, + ) + cell_specific = Checkbox( + value=default_data['cell_specific'], description='Cell-Specific', **kwargs + ) + seedcore = IntText(value=default_data['seedcore'], description='Seed: ', **kwargs) widgets_list, widgets_dict = _get_drive_weight_widgets( layout, @@ -1462,33 +1718,40 @@ def _get_evoked_widget(name, layout, style, location, data={}, ) # Disable n_drive_cells widget based on cell_specific checkbox - cell_specific.observe(partial(_cell_spec_change, widget=n_drive_cells), - names='value') - - drive_box = VBox([mu, sigma, numspikes, n_drive_cells, - cell_specific, seedcore,] + - widgets_list) - drive = dict(type='Evoked', - name=name, - mu=mu, - sigma=sigma, - numspikes=numspikes, - seedcore=seedcore, - location=location, - sync_within_trial=False, - n_drive_cells=n_drive_cells, - is_cell_specific=cell_specific) + cell_specific.observe( + partial(_cell_spec_change, widget=n_drive_cells), names='value' + ) + + drive_box = VBox( + [ + mu, + sigma, + numspikes, + n_drive_cells, + cell_specific, + seedcore, + ] + + widgets_list + ) + drive = dict( + type='Evoked', + name=name, + mu=mu, + sigma=sigma, + numspikes=numspikes, + seedcore=seedcore, + location=location, + sync_within_trial=False, + n_drive_cells=n_drive_cells, + is_cell_specific=cell_specific, + ) drive.update(widgets_dict) return drive, drive_box def _get_tonic_widget(name, tstop_widget, layout, style, data=None): cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal'] - default_values = { - 'amplitude': 0, - 't0': 0, - 'tstop': tstop_widget.value - } + default_values = {'amplitude': 0, 't0': 0, 'tstop': tstop_widget.value} t0 = default_values['t0'] tstop = default_values['tstop'] default_data = {cell_type: default_values for cell_type in cell_types} @@ -1501,8 +1764,8 @@ def _get_tonic_widget(name, tstop_widget, layout, style, data=None): for cell_type in cell_types: amplitude = default_data[cell_type]['amplitude'] amplitudes[cell_type] = BoundedFloatText( - value=amplitude, description=cell_type, - min=0, max=1e6, step=0.01, **kwargs) + value=amplitude, description=cell_type, min=0, max=1e6, step=0.01, **kwargs + ) # Reset the global t0 and stop with values from the 'data' keyword. # It should be same across all the cell-types. 
if amplitude > 0: @@ -1510,38 +1773,47 @@ def _get_tonic_widget(name, tstop_widget, layout, style, data=None): tstop = default_data[cell_type]['tstop'] start_times = BoundedFloatText( - value=t0, description="Start time", - min=0, max=1e6, step=1.0, **kwargs) + value=t0, description='Start time', min=0, max=1e6, step=1.0, **kwargs + ) stop_times = BoundedFloatText( - value=tstop, description="Stop time", - min=-1, max=1e6, step=1.0, **kwargs) + value=tstop, description='Stop time', min=-1, max=1e6, step=1.0, **kwargs + ) - widgets_dict = { - 'amplitude': amplitudes, - 't0': start_times, - 'tstop': stop_times - } - widgets_list = ([HTML(value="Times (ms):")] + - [start_times, stop_times] + - [HTML(value="Amplitude (nA):")] + - list(amplitudes.values())) + widgets_dict = {'amplitude': amplitudes, 't0': start_times, 'tstop': stop_times} + widgets_list = ( + [HTML(value='Times (ms):')] + + [start_times, stop_times] + + [HTML(value='Amplitude (nA):')] + + list(amplitudes.values()) + ) drive_box = VBox(widgets_list) - drive = dict(type='Tonic', - name=name, - amplitude=amplitudes, - t0=start_times, - tstop=stop_times,) + drive = dict( + type='Tonic', + name=name, + amplitude=amplitudes, + t0=start_times, + tstop=stop_times, + ) drive.update(widgets_dict) return drive, drive_box -def _build_drive_objects(drive_type, name, tstop_widget, layout, style, - location, drive_data, weights_ampa, - weights_nmda, delays, n_drive_cells, - cell_specific): - +def _build_drive_objects( + drive_type, + name, + tstop_widget, + layout, + style, + location, + drive_data, + weights_ampa, + weights_nmda, + delays, + n_drive_cells, + cell_specific, +): if drive_type in ('Rhythmic', 'Bursty'): drive, drive_box = _get_rhythmic_widget( name, @@ -1585,11 +1857,7 @@ def _build_drive_objects(drive_type, name, tstop_widget, layout, style, ) elif drive_type == 'Tonic': drive, drive_box = _get_tonic_widget( - name, - tstop_widget, - layout, - style, - data=drive_data + name, tstop_widget, layout, style, data=drive_data ) else: raise ValueError(f'Unknown drive type {drive_type}') @@ -1597,26 +1865,34 @@ def _build_drive_objects(drive_type, name, tstop_widget, layout, style, return drive, drive_box -def add_connectivity_tab(params, connectivity_out, connectivity_textfields, - cell_params_out, cell_pameters_vboxes, - cell_layer_radio_button, cell_type_radio_button, - layout): +def add_connectivity_tab( + params, + connectivity_out, + connectivity_textfields, + cell_params_out, + cell_pameters_vboxes, + cell_layer_radio_button, + cell_type_radio_button, + layout, +): """Add all possible connectivity boxes to connectivity tab.""" net = dict_to_network(params) # build network connectivity tab - add_network_connectivity_tab(net, connectivity_out, - connectivity_textfields) + add_network_connectivity_tab(net, connectivity_out, connectivity_textfields) # build cell parameters tab - add_cell_parameters_tab(cell_params_out, cell_pameters_vboxes, - cell_layer_radio_button, cell_type_radio_button, - layout) + add_cell_parameters_tab( + cell_params_out, + cell_pameters_vboxes, + cell_layer_radio_button, + cell_type_radio_button, + layout, + ) return net -def add_network_connectivity_tab(net, connectivity_out, - connectivity_textfields): +def add_network_connectivity_tab(net, connectivity_out, connectivity_textfields): cell_types = [ct for ct in net.cell_types.keys()] receptors = ('ampa', 'nmda', 'gabaa', 'gabab') locations = ('proximal', 'distal', 'soma') @@ -1633,33 +1909,33 @@ def add_network_connectivity_tab(net, connectivity_out, # the 
connectivity list should be built on this level receptor_related_conn = {} for receptor in receptors: - conn_indices = pick_connection(net=net, - src_gids=src_gids, - target_gids=target_gids, - loc=location, - receptor=receptor) + conn_indices = pick_connection( + net=net, + src_gids=src_gids, + target_gids=target_gids, + loc=location, + receptor=receptor, + ) if len(conn_indices) > 0: assert len(conn_indices) == 1 conn_idx = conn_indices[0] - current_w = net.connectivity[ - conn_idx]['nc_dict']['A_weight'] - current_p = net.connectivity[ - conn_idx]['probability'] + current_w = net.connectivity[conn_idx]['nc_dict']['A_weight'] + current_p = net.connectivity[conn_idx]['probability'] # valid connection receptor_related_conn[receptor] = { - "weight": current_w, - "probability": current_p, + 'weight': current_w, + 'probability': current_p, # info used to identify connection - "receptor": receptor, - "location": location, - "src_gids": src_gids, - "target_gids": target_gids, + 'receptor': receptor, + 'location': location, + 'src_gids': src_gids, + 'target_gids': target_gids, } if len(receptor_related_conn) > 0: - connectivity_names.append( - f"{src_gids}→{target_gids} ({location})") + connectivity_names.append(f'{src_gids}→{target_gids} ({location})') connectivity_textfields.append( - _get_connectivity_widgets(receptor_related_conn)) + _get_connectivity_widgets(receptor_related_conn) + ) connectivity_boxes = [VBox(slider) for slider in connectivity_textfields] cell_connectivity = Accordion(children=connectivity_boxes) @@ -1672,38 +1948,48 @@ def add_network_connectivity_tab(net, connectivity_out, return net -def add_cell_parameters_tab(cell_params_out, cell_pameters_vboxes, - cell_layer_radio_button, cell_type_radio_button, - layout): +def add_cell_parameters_tab( + cell_params_out, + cell_pameters_vboxes, + cell_layer_radio_button, + cell_type_radio_button, + layout, +): L2_default_values = get_L2Pyr_params_default() L5_default_values = get_L5Pyr_params_default() - cell_types = [("L2", L2_default_values), ("L5", L5_default_values)] + cell_types = [('L2', L2_default_values), ('L5', L5_default_values)] style = {'description_width': '255px'} kwargs = dict(layout=layout, style=style) for cell_type in cell_types: layer_parameters = list() for layer in cell_parameters_dict.keys(): - if ('Biophysic' in layer or 'Geometry' in layer) and \ - cell_type[0] not in layer: + if ('Biophysic' in layer or 'Geometry' in layer) and cell_type[ + 0 + ] not in layer: continue for parameter in cell_parameters_dict[layer]: - param_name, param_units, params_key = (parameter[0], - parameter[1], - parameter[2]) + param_name, param_units, params_key = ( + parameter[0], + parameter[1], + parameter[2], + ) default_value = get_cell_param_default_value( - f'{cell_type[0]}Pyr_{params_key}', cell_type[1]) - description = f"{param_name} ({param_units})" + f'{cell_type[0]}Pyr_{params_key}', cell_type[1] + ) + description = f'{param_name} ({param_units})' min_value = -1000.0 if param_units not in 'ms' else 0 - text_field = BoundedFloatText(value=default_value, - min=min_value, - max=1000.0, - step=0.1, - description=description, - disabled=False, - **kwargs) - text_field.layout.width = "350px" + text_field = BoundedFloatText( + value=default_value, + min=min_value, + max=1000.0, + step=0.1, + description=description, + disabled=False, + **kwargs, + ) + text_field.layout.width = '350px' layer_parameters.append(text_field) cell_pameters_key = f'{cell_type[0]} Pyramidal_{layer}' cell_pameters_vboxes[cell_pameters_key] = 
VBox(layer_parameters) @@ -1713,10 +1999,12 @@ def add_cell_parameters_tab(cell_params_out, cell_pameters_vboxes, cell_params_out.clear_output() # Add cell parameters - _update_cell_params_vbox(cell_params_out, - cell_pameters_vboxes, - cell_type_radio_button.value, - cell_layer_radio_button.value) + _update_cell_params_vbox( + cell_params_out, + cell_pameters_vboxes, + cell_type_radio_button.value, + cell_layer_radio_button.value, + ) def get_cell_param_default_value(cell_type_key, param_dict): @@ -1731,39 +2019,39 @@ def on_upload_data_change(change, data, viz_manager, log_out): data_dict = change['new'][0] dict_name = data_dict['name'].rsplit('.', 1) data_fname = dict_name[0] - file_extension = f".{dict_name[1]}" + file_extension = f'.{dict_name[1]}' # If data was already loaded return if data_fname in data['simulation_data'].keys(): with log_out: - logger.error(f"Found existing data: {data_fname}.") + logger.error(f'Found existing data: {data_fname}.') return # Read the file ext_content = data_dict['content'] - ext_content = codecs.decode(ext_content, encoding="utf-8") - with (log_out): + ext_content = codecs.decode(ext_content, encoding='utf-8') + with log_out: # Write loaded data to data object data['simulation_data'][data_fname] = { - 'net': None, 'dpls': [_read_dipole_txt(io.StringIO(ext_content), - file_extension - ) - ]} + 'net': None, + 'dpls': [_read_dipole_txt(io.StringIO(ext_content), file_extension)], + } logger.info(f'External data {data_fname} loaded.') # Create a dipole plot - _template_name = "[Blank] single figure" + _template_name = '[Blank] single figure' viz_manager.reset_fig_config_tabs(template_name=_template_name) viz_manager.add_figure() fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1) process_configs = {'dipole_smooth': 0, 'dipole_scaling': 1} - viz_manager._simulate_edit_figure(fig_name, - ax_name='ax0', - simulation_name=data_fname, - plot_type="current dipole", - preprocessing_config=process_configs, - operation='plot' - ) + viz_manager._simulate_edit_figure( + fig_name, + ax_name='ax0', + simulation_name=data_fname, + plot_type='current dipole', + preprocessing_config=process_configs, + operation='plot', + ) # Reset the load file widget change['owner'].value = [] @@ -1786,22 +2074,24 @@ def _drive_widget_to_dict(drive, name): ------- """ - return { - k: v.value - for k, v in drive[name].items() - } - - -def _init_network_from_widgets(params, dt, tstop, single_simulation_data, - drive_widgets, connectivity_textfields, - cell_params_vboxes, - add_drive=True): + return {k: v.value for k, v in drive[name].items()} + + +def _init_network_from_widgets( + params, + dt, + tstop, + single_simulation_data, + drive_widgets, + connectivity_textfields, + cell_params_vboxes, + add_drive=True, +): """Construct network and add drives.""" - print("init network") - single_simulation_data['net'] = dict_to_network(params, - read_drives=False, - read_external_biases=False - ) + print('init network') + single_simulation_data['net'] = dict_to_network( + params, read_drives=False, read_external_biases=False + ) # adjust connectivity according to the connectivity_tab for connectivity_slider in connectivity_textfields: for vbox_key in connectivity_slider: @@ -1810,13 +2100,15 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data, src_gids=vbox_key._belongsto['src_gids'], target_gids=vbox_key._belongsto['target_gids'], loc=vbox_key._belongsto['location'], - receptor=vbox_key._belongsto['receptor']) + receptor=vbox_key._belongsto['receptor'], + ) if 
len(conn_indices) > 0: assert len(conn_indices) == 1 conn_idx = conn_indices[0] - single_simulation_data['net'].connectivity[conn_idx][ - 'nc_dict']['A_weight'] = vbox_key.children[1].value + single_simulation_data['net'].connectivity[conn_idx]['nc_dict'][ + 'A_weight' + ] = vbox_key.children[1].value # Update cell params @@ -1825,7 +2117,7 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data, 'L5 Geometry': _update_L5_geometry_cell_params, 'Synapses': _update_synapse_cell_params, 'L2 Pyramidal_Biophysics': _update_L2_biophysics_cell_params, - 'L5 Pyramidal_Biophysics': _update_L5_biophysics_cell_params + 'L5 Pyramidal_Biophysics': _update_L5_biophysics_cell_params, } # Update cell params @@ -1833,14 +2125,14 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data, for key, update_function in update_functions.items(): if key in vbox_key: cell_type = vbox_key.split()[0] - update_function(single_simulation_data['net'], cell_type, - cell_param_list.children) + update_function( + single_simulation_data['net'], cell_type, cell_param_list.children + ) break # update needed only once per vbox_key for cell_type in single_simulation_data['net'].cell_types.keys(): single_simulation_data['net'].cell_types[cell_type]._update_end_pts() - single_simulation_data['net'].cell_types[ - cell_type]._compute_section_mechs() + single_simulation_data['net'].cell_types[cell_type]._compute_section_mechs() if add_drive is False: return @@ -1850,20 +2142,23 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data, weights_amplitudes = _drive_widget_to_dict(drive, 'amplitude') single_simulation_data['net'].add_tonic_bias( amplitude=weights_amplitudes, - t0=drive["t0"].value, - tstop=drive["tstop"].value) + t0=drive['t0'].value, + tstop=drive['tstop'].value, + ) else: sync_inputs_kwargs = dict( - n_drive_cells=('n_cells' if drive['is_cell_specific'].value - else drive['n_drive_cells'].value), + n_drive_cells=( + 'n_cells' + if drive['is_cell_specific'].value + else drive['n_drive_cells'].value + ), cell_specific=drive['is_cell_specific'].value, ) weights_ampa = _drive_widget_to_dict(drive, 'weights_ampa') weights_nmda = _drive_widget_to_dict(drive, 'weights_nmda') synaptic_delays = _drive_widget_to_dict(drive, 'delays') - print( - f"drive type is {drive['type']}, location={drive['location']}") + print(f"drive type is {drive['type']}, location={drive['location']}") if drive['type'] == 'Poisson': rate_constant = _drive_widget_to_dict(drive, 'rate_constant') @@ -1878,7 +2173,8 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data, synaptic_delays=synaptic_delays, space_constant=100.0, event_seed=drive['seedcore'].value, - **sync_inputs_kwargs) + **sync_inputs_kwargs, + ) elif drive['type'] in ('Evoked', 'Gaussian'): single_simulation_data['net'].add_evoked_drive( name=drive['name'], @@ -1891,7 +2187,8 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data, synaptic_delays=synaptic_delays, space_constant=3.0, event_seed=drive['seedcore'].value, - **sync_inputs_kwargs) + **sync_inputs_kwargs, + ) elif drive['type'] in ('Rhythmic', 'Bursty'): single_simulation_data['net'].add_bursty_drive( name=drive['name'], @@ -1906,17 +2203,31 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data, weights_nmda=weights_nmda, synaptic_delays=synaptic_delays, event_seed=drive['seedcore'].value, - **sync_inputs_kwargs) + **sync_inputs_kwargs, + ) -def run_button_clicked(widget_simulation_name, log_out, drive_widgets, 
- all_data, dt, tstop, ntrials, backend_selection, - mpi_cmd, n_jobs, params, simulation_status_bar, - simulation_status_contents, connectivity_textfields, - viz_manager, simulations_list_widget, - cell_pameters_widgets): +def run_button_clicked( + widget_simulation_name, + log_out, + drive_widgets, + all_data, + dt, + tstop, + ntrials, + backend_selection, + mpi_cmd, + n_jobs, + params, + simulation_status_bar, + simulation_status_contents, + connectivity_textfields, + viz_manager, + simulations_list_widget, + cell_pameters_widgets, +): """Run the simulation and plot outputs.""" - simulation_data = all_data["simulation_data"] + simulation_data = all_data['simulation_data'] with log_out: # clear empty trash simulations for _name in tuple(simulation_data.keys()): @@ -1925,36 +2236,44 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets, _sim_name = widget_simulation_name.value if simulation_data[_sim_name]['net'] is not None: - print("Simulation with the same name exists!") - simulation_status_bar.value = simulation_status_contents[ - 'failed'] + print('Simulation with the same name exists!') + simulation_status_bar.value = simulation_status_contents['failed'] return - _init_network_from_widgets(params, dt, tstop, - simulation_data[_sim_name], drive_widgets, - connectivity_textfields, - cell_pameters_widgets) + _init_network_from_widgets( + params, + dt, + tstop, + simulation_data[_sim_name], + drive_widgets, + connectivity_textfields, + cell_pameters_widgets, + ) - print("start simulation") - if backend_selection.value == "MPI": + print('start simulation') + if backend_selection.value == 'MPI': backend = MPIBackend( - n_procs=multiprocessing.cpu_count() - 1, mpi_cmd=mpi_cmd.value) + n_procs=multiprocessing.cpu_count() - 1, mpi_cmd=mpi_cmd.value + ) else: backend = JoblibBackend(n_jobs=n_jobs.value) - print(f"Using Joblib with {n_jobs.value} core(s).") + print(f'Using Joblib with {n_jobs.value} core(s).') with backend: simulation_status_bar.value = simulation_status_contents['running'] simulation_data[_sim_name]['dpls'] = simulate_dipole( simulation_data[_sim_name]['net'], tstop=tstop.value, dt=dt.value, - n_trials=ntrials.value) + n_trials=ntrials.value, + ) - simulation_status_bar.value = simulation_status_contents[ - 'finished'] + simulation_status_bar.value = simulation_status_contents['finished'] - sim_names = [sim_name for sim_name in simulation_data - if simulation_data[sim_name]['net'] is not None] + sim_names = [ + sim_name + for sim_name in simulation_data + if simulation_data[sim_name]['net'] is not None + ] simulations_list_widget.options = sim_names simulations_list_widget.value = sim_names[0] @@ -1962,20 +2281,22 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets, viz_manager.reset_fig_config_tabs() viz_manager.add_figure() fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1) - ax_plots = [("ax0", "input histogram"), ("ax1", "current dipole")] + ax_plots = [('ax0', 'input histogram'), ('ax1', 'current dipole')] for ax_name, plot_type in ax_plots: - viz_manager._simulate_edit_figure(fig_name, ax_name, _sim_name, - plot_type, {}, "plot") + viz_manager._simulate_edit_figure( + fig_name, ax_name, _sim_name, plot_type, {}, 'plot' + ) -def _update_cell_params_vbox(cell_type_out, cell_parameters_list, - cell_type, cell_layer): - cell_parameters_key = f"{cell_type}_{cell_layer}" +def _update_cell_params_vbox( + cell_type_out, cell_parameters_list, cell_type, cell_layer +): + cell_parameters_key = f'{cell_type}_{cell_layer}' if 
cell_layer in ['Biophysics', 'Geometry']: cell_parameters_key += f" {cell_type.split(' ')[0]}" # Needed for the button to display L2/3, but the underlying data to use L2 - cell_parameters_key = cell_parameters_key.replace("L2/3", "L2") + cell_parameters_key = cell_parameters_key.replace('L2/3', 'L2') if cell_parameters_key in cell_parameters_list: cell_type_out.clear_output() @@ -1998,12 +2319,9 @@ def _update_L2_geometry_cell_params(net, cell_param_key, param_list): dendrite_cm = cell_params[4].value dendrite_Ra = cell_params[5].value - dendrite_sections = [name for name in sections.keys() - if name != 'soma' - ] + dendrite_sections = [name for name in sections.keys() if name != 'soma'] - param_indices = [ - (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19)] + param_indices = [(6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19)] # Dendrite for section, indices in zip(dendrite_sections, param_indices): @@ -2028,13 +2346,18 @@ def _update_L5_geometry_cell_params(net, cell_param_key, param_list): dendrite_cm = cell_params[4].value dendrite_Ra = cell_params[5].value - dendrite_sections = [name for name in sections.keys() - if name != 'soma' - ] + dendrite_sections = [name for name in sections.keys() if name != 'soma'] param_indices = [ - (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), - (16, 17), (18, 19), (20, 21)] + (6, 7), + (8, 9), + (10, 11), + (12, 13), + (14, 15), + (16, 17), + (18, 19), + (20, 21), + ] # Dentrite for section, indices in zip(dendrite_sections, param_indices): @@ -2050,8 +2373,7 @@ def _update_synapse_cell_params(net, cell_param_key, param_list): network_synapses = net.cell_types[cell_type].synapses synapse_sections = ['ampa', 'nmda', 'gabaa', 'gabab'] - param_indices = [ - (0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)] + param_indices = [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)] # Update Dendrite for section, indices in zip(synapse_sections, param_indices): @@ -2061,7 +2383,6 @@ def _update_synapse_cell_params(net, cell_param_key, param_list): def _update_L2_biophysics_cell_params(net, cell_param_key, param_list): - cell_type = f'{cell_param_key.split("_")[0]}_pyramidal' sections = net.cell_types[cell_type].sections # Soma @@ -2070,9 +2391,9 @@ def _update_L2_biophysics_cell_params(net, cell_param_key, param_list): 'gkbar_hh2': param_list[0].value, 'gnabar_hh2': param_list[1].value, 'el_hh2': param_list[2].value, - 'gl_hh2': param_list[3].value}, - 'km': { - 'gbar_km': param_list[4].value} + 'gl_hh2': param_list[3].value, + }, + 'km': {'gbar_km': param_list[4].value}, } sections['soma'].mechs.update(mechs_params) @@ -2082,9 +2403,9 @@ def _update_L2_biophysics_cell_params(net, cell_param_key, param_list): 'gkbar_hh2': param_list[5].value, 'gnabar_hh2': param_list[6].value, 'el_hh2': param_list[7].value, - 'gl_hh2': param_list[8].value} - mechs_params['km'] = { - 'gbar_km': param_list[9].value} + 'gl_hh2': param_list[8].value, + } + mechs_params['km'] = {'gbar_km': param_list[9].value} update_common_dendrite_sections(sections, mechs_params) @@ -2094,37 +2415,18 @@ def _update_L5_biophysics_cell_params(net, cell_param_key, param_list): sections = net.cell_types[cell_type].sections # Soma mechs_params = { - 'hh2': - { + 'hh2': { 'gkbar_hh2': param_list[0].value, 'gnabar_hh2': param_list[1].value, 'el_hh2': param_list[2].value, - 'gl_hh2': param_list[3].value - }, - 'ca': - { - 'gbar_ca': param_list[4].value - }, - 'cad': - { - 'taur_cad': param_list[5].value - }, - 'kca': - { - 'gbar_kca': param_list[6].value + 'gl_hh2': param_list[3].value, 
}, - 'km': - { - 'gbar_km': param_list[7].value - }, - 'cat': - { - 'gbar_cat': param_list[8].value - }, - 'ar': - { - 'gbar_ar': param_list[9].value - } + 'ca': {'gbar_ca': param_list[4].value}, + 'cad': {'taur_cad': param_list[5].value}, + 'kca': {'gbar_kca': param_list[6].value}, + 'km': {'gbar_km': param_list[7].value}, + 'cat': {'gbar_cat': param_list[8].value}, + 'ar': {'gbar_ar': param_list[9].value}, } sections['soma'].mechs.update(mechs_params) @@ -2134,24 +2436,25 @@ def _update_L5_biophysics_cell_params(net, cell_param_key, param_list): 'gkbar_hh2': param_list[10].value, 'gnabar_hh2': param_list[11].value, 'el_hh2': param_list[12].value, - 'gl_hh2': param_list[13].value} + 'gl_hh2': param_list[13].value, + } mechs_params['ca'] = {'gbar_ca': param_list[14].value} mechs_params['cad'] = {'taur_cad': param_list[15].value} mechs_params['kca'] = {'gbar_kca': param_list[16].value} mechs_params['km'] = {'gbar_km': param_list[17].value} mechs_params['cat'] = {'gbar_cat': param_list[18].value} - mechs_params['ar'] = {'gbar_ar': partial( - _exp_g_at_dist, zero_val=param_list[19].value, - exp_term=3e-3, offset=0.0)} + mechs_params['ar'] = { + 'gbar_ar': partial( + _exp_g_at_dist, zero_val=param_list[19].value, exp_term=3e-3, offset=0.0 + ) + } update_common_dendrite_sections(sections, mechs_params) def update_common_dendrite_sections(sections, mechs_params): - dendrite_sections = [ - name for name in sections.keys() if name != 'soma' - ] + dendrite_sections = [name for name in sections.keys() if name != 'soma'] for section in dendrite_sections: sections[section].mechs.update(deepcopy(mechs_params)) @@ -2167,11 +2470,11 @@ def _serialize_simulation(log_out, sim_data, simulation_list_widget): def serialize_simulation(simulations_data, simulation_name): """Serializes simulation data to CSV. - Creates a single CSV file or a ZIP file containing multiple CSVs, - depending on the number of trials in the simulation. + Creates a single CSV file or a ZIP file containing multiple CSVs, + depending on the number of trials in the simulation. 
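
    Usage sketch (editorial; `gui_data` stands for the GUI's data dict, the
    object that holds 'simulation_data'):

    >>> payload, ext = serialize_simulation(gui_data, 'default')  # doctest: +SKIP

    A single trial returns the CSV text with ext '.csv'; multiple trials
    return a ZIP archive with ext '.zip'.
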
""" - simulation_data = simulations_data["simulation_data"] + simulation_data = simulations_data['simulation_data'] csv_trials_output = [] # CSV file headers headers = 'times,agg,L2,L5' @@ -2179,25 +2482,26 @@ def serialize_simulation(simulations_data, simulation_name): for dpl_trial in simulation_data[simulation_name]['dpls']: # Combine all data columns at once - signals_matrix = np.column_stack(( - dpl_trial.times, - dpl_trial.data['agg'], - dpl_trial.data['L2'], - dpl_trial.data['L5'] - )) + signals_matrix = np.column_stack( + ( + dpl_trial.times, + dpl_trial.data['agg'], + dpl_trial.data['L2'], + dpl_trial.data['L5'], + ) + ) # Using StringIO to collect CSV data with io.StringIO() as output: - np.savetxt(output, signals_matrix, delimiter=',', - header=headers, fmt=fmt) + np.savetxt(output, signals_matrix, delimiter=',', header=headers, fmt=fmt) csv_trials_output.append(output.getvalue()) if len(csv_trials_output) == 1: # Return a single csv file - return csv_trials_output[0], ".csv" + return csv_trials_output[0], '.csv' else: # Create zip file - return _create_zip(csv_trials_output, simulation_name), ".zip" + return _create_zip(csv_trials_output, simulation_name), '.zip' def _serialize_config(log_out, sim_data, simulation_list_widget): @@ -2212,7 +2516,7 @@ def serialize_config(simulations_data, simulation_name): """Serializes Network configuration data to json.""" # Get network from data dictionary - net = simulations_data["simulation_data"][simulation_name]['net'] + net = simulations_data['simulation_data'][simulation_name]['net'] # Write to buffer with io.StringIO() as output: @@ -2234,9 +2538,9 @@ def handle_backend_change(backend_type, backend_config, mpi_cmd, n_jobs): """Switch backends between MPI and Joblib.""" backend_config.clear_output() with backend_config: - if backend_type == "MPI": + if backend_type == 'MPI': display(mpi_cmd) - elif backend_type == "Joblib": + elif backend_type == 'Joblib': display(n_jobs) @@ -2253,5 +2557,6 @@ def launch(): You can pass voila commandline parameters as usual. 
""" from voila.app import main + notebook_path = Path(__file__).parent / 'hnn_widget.ipynb' main([str(notebook_path.resolve()), *sys.argv[1:]]) diff --git a/hnn_core/hnn_io.py b/hnn_core/hnn_io.py index 667cf82ab..173fead27 100644 --- a/hnn_core/hnn_io.py +++ b/hnn_core/hnn_io.py @@ -59,8 +59,7 @@ def _conn_to_dict(conn): 'src_type': conn['src_type'], 'src_gids': list(conn['src_gids']), 'num_srcs': conn['num_srcs'], - 'gid_pairs': {str(key): val - for key, val in conn['gid_pairs'].items()}, + 'gid_pairs': {str(key): val for key, val in conn['gid_pairs'].items()}, 'loc': conn['loc'], 'receptor': conn['receptor'], 'nc_dict': conn['nc_dict'], @@ -101,11 +100,13 @@ def _read_cell_types(cell_types_data): sections_data = cell_data['sections'] for section_name in sections_data: section_data = sections_data[section_name] - sections[section_name] = Section(L=section_data['L'], - diam=section_data['diam'], - cm=section_data['cm'], - Ra=section_data['Ra'], - end_pts=section_data['end_pts']) + sections[section_name] = Section( + L=section_data['L'], + diam=section_data['diam'], + cm=section_data['cm'], + Ra=section_data['Ra'], + end_pts=section_data['end_pts'], + ) # Set section attributes sections[section_name].syns = section_data['syns'] sections[section_name].mechs = section_data['mechs'] @@ -120,13 +121,15 @@ def _read_cell_types(cell_types_data): value.append(_str_to_node(child)) cell_tree[key] = value - cell_types[cell_name] = Cell(name=cell_data['name'], - pos=tuple(cell_data['pos']), - sections=sections, - synapses=cell_data['synapses'], - cell_tree=cell_tree, - sect_loc=cell_data['sect_loc'], - gid=cell_data['gid']) + cell_types[cell_name] = Cell( + name=cell_data['name'], + pos=tuple(cell_data['pos']), + sections=sections, + synapses=cell_data['synapses'], + cell_tree=cell_tree, + sect_loc=cell_data['sect_loc'], + gid=cell_data['gid'], + ) # Setting cell attributes cell_types[cell_name].dipole_pp = cell_data['dipole_pp'] cell_types[cell_name].vsec = cell_data['vsec'] @@ -140,9 +143,11 @@ def _read_cell_response(cell_response_data, read_output): """Returns CellResponse from json encoded data""" if (not cell_response_data) or (not read_output): return None - cell_response = CellResponse(spike_times=cell_response_data['spike_times'], - spike_gids=cell_response_data['spike_gids'], - spike_types=cell_response_data['spike_types']) + cell_response = CellResponse( + spike_times=cell_response_data['spike_times'], + spike_gids=cell_response_data['spike_gids'], + spike_types=cell_response_data['spike_types'], + ) cell_response._times = cell_response_data['times'] cell_response._vsec = list() @@ -173,53 +178,57 @@ def _read_external_drive(net, drive_data, read_output): if (drive_data['type'] == 'evoked') or (drive_data['type'] == 'gaussian'): # Skipped n_drive_cells here - net.add_evoked_drive(name=drive_data['name'], - mu=drive_data['dynamics']['mu'], - sigma=drive_data['dynamics']['sigma'], - numspikes=drive_data['dynamics']['numspikes'], - location=drive_data['location'], - n_drive_cells=_set_from_cell_specific(drive_data), - cell_specific=drive_data['cell_specific'], - weights_ampa=drive_data['weights_ampa'], - weights_nmda=drive_data['weights_nmda'], - synaptic_delays=drive_data['synaptic_delays'], - probability=drive_data["probability"], - event_seed=drive_data['event_seed'], - conn_seed=drive_data['conn_seed']) + net.add_evoked_drive( + name=drive_data['name'], + mu=drive_data['dynamics']['mu'], + sigma=drive_data['dynamics']['sigma'], + numspikes=drive_data['dynamics']['numspikes'], + 
location=drive_data['location'], + n_drive_cells=_set_from_cell_specific(drive_data), + cell_specific=drive_data['cell_specific'], + weights_ampa=drive_data['weights_ampa'], + weights_nmda=drive_data['weights_nmda'], + synaptic_delays=drive_data['synaptic_delays'], + probability=drive_data['probability'], + event_seed=drive_data['event_seed'], + conn_seed=drive_data['conn_seed'], + ) elif drive_data['type'] == 'poisson': - net.add_poisson_drive(name=drive_data['name'], - tstart=drive_data['dynamics']['tstart'], - tstop=drive_data['dynamics']['tstop'], - rate_constant=(drive_data['dynamics'] - ['rate_constant']), - location=drive_data['location'], - n_drive_cells=( - _set_from_cell_specific(drive_data)), - cell_specific=drive_data['cell_specific'], - weights_ampa=drive_data['weights_ampa'], - weights_nmda=drive_data['weights_nmda'], - synaptic_delays=drive_data['synaptic_delays'], - probability=drive_data["probability"], - event_seed=drive_data['event_seed'], - conn_seed=drive_data['conn_seed']) + net.add_poisson_drive( + name=drive_data['name'], + tstart=drive_data['dynamics']['tstart'], + tstop=drive_data['dynamics']['tstop'], + rate_constant=(drive_data['dynamics']['rate_constant']), + location=drive_data['location'], + n_drive_cells=(_set_from_cell_specific(drive_data)), + cell_specific=drive_data['cell_specific'], + weights_ampa=drive_data['weights_ampa'], + weights_nmda=drive_data['weights_nmda'], + synaptic_delays=drive_data['synaptic_delays'], + probability=drive_data['probability'], + event_seed=drive_data['event_seed'], + conn_seed=drive_data['conn_seed'], + ) elif drive_data['type'] == 'bursty': - net.add_bursty_drive(name=drive_data['name'], - tstart=drive_data['dynamics']['tstart'], - tstart_std=drive_data['dynamics']['tstart_std'], - tstop=drive_data['dynamics']['tstop'], - burst_rate=drive_data['dynamics']['burst_rate'], - burst_std=drive_data['dynamics']['burst_std'], - numspikes=drive_data['dynamics']['numspikes'], - spike_isi=drive_data['dynamics']['spike_isi'], - location=drive_data['location'], - n_drive_cells=_set_from_cell_specific(drive_data), - cell_specific=drive_data['cell_specific'], - weights_ampa=drive_data['weights_ampa'], - weights_nmda=drive_data['weights_nmda'], - synaptic_delays=drive_data['synaptic_delays'], - probability=drive_data["probability"], - event_seed=drive_data['event_seed'], - conn_seed=drive_data['conn_seed']) + net.add_bursty_drive( + name=drive_data['name'], + tstart=drive_data['dynamics']['tstart'], + tstart_std=drive_data['dynamics']['tstart_std'], + tstop=drive_data['dynamics']['tstop'], + burst_rate=drive_data['dynamics']['burst_rate'], + burst_std=drive_data['dynamics']['burst_std'], + numspikes=drive_data['dynamics']['numspikes'], + spike_isi=drive_data['dynamics']['spike_isi'], + location=drive_data['location'], + n_drive_cells=_set_from_cell_specific(drive_data), + cell_specific=drive_data['cell_specific'], + weights_ampa=drive_data['weights_ampa'], + weights_nmda=drive_data['weights_nmda'], + synaptic_delays=drive_data['synaptic_delays'], + probability=drive_data['probability'], + event_seed=drive_data['event_seed'], + conn_seed=drive_data['conn_seed'], + ) net.external_drives[drive_data['name']]['events'] = drive_data['events'] if not read_output: @@ -233,32 +242,34 @@ def _read_connectivity(net, conns_data): for conn_data in conns_data: src_gids = [int(s) for s in conn_data['gid_pairs'].keys()] - target_gids_nested = [target_gid for target_gid - in conn_data['gid_pairs'].values()] + target_gids_nested = [ + target_gid for 
target_gid in conn_data['gid_pairs'].values() + ] conn_data['allow_autapses'] = bool(conn_data['allow_autapses']) - net.add_connection(src_gids=src_gids, - target_gids=target_gids_nested, - loc=conn_data['loc'], - receptor=conn_data['receptor'], - weight=conn_data['nc_dict']['A_weight'], - delay=conn_data['nc_dict']['A_delay'], - lamtha=conn_data['nc_dict']['lamtha'], - allow_autapses=conn_data['allow_autapses'], - probability=conn_data['probability']) + net.add_connection( + src_gids=src_gids, + target_gids=target_gids_nested, + loc=conn_data['loc'], + receptor=conn_data['receptor'], + weight=conn_data['nc_dict']['A_weight'], + delay=conn_data['nc_dict']['A_delay'], + lamtha=conn_data['nc_dict']['lamtha'], + allow_autapses=conn_data['allow_autapses'], + probability=conn_data['probability'], + ) def _read_rec_arrays(net, rec_arrays_data, read_output): """Adds rec arrays to Network from json data.""" for key in rec_arrays_data: rec_array = rec_arrays_data[key] - net.add_electrode_array(name=key, - electrode_pos=[ - tuple(pos) for - pos in rec_array['positions'] - ], - conductivity=rec_array['conductivity'], - method=rec_array['method'], - min_distance=rec_array['min_distance']) + net.add_electrode_array( + name=key, + electrode_pos=[tuple(pos) for pos in rec_array['positions']], + conductivity=rec_array['conductivity'], + method=rec_array['method'], + min_distance=rec_array['min_distance'], + ) net.rec_arrays[key]._times = rec_array['times'] net.rec_arrays[key]._data = rec_array['voltages'] if not read_output: @@ -296,23 +307,25 @@ def network_to_dict(net, write_output=False): 'N_pyr_x': net._N_pyr_x, 'N_pyr_y': net._N_pyr_y, 'celsius': net._params['celsius'], - 'cell_types': {name: template.to_dict() - for name, template in net.cell_types.items() - }, - 'gid_ranges': {cell: {'start': c_range.start, 'stop': c_range.stop} - for cell, c_range in net.gid_ranges.items() - }, + 'cell_types': { + name: template.to_dict() for name, template in net.cell_types.items() + }, + 'gid_ranges': { + cell: {'start': c_range.start, 'stop': c_range.stop} + for cell, c_range in net.gid_ranges.items() + }, 'pos_dict': {cell: pos for cell, pos in net.pos_dict.items()}, 'cell_response': _cell_response_to_dict(net, write_output), - 'external_drives': {drive: _external_drive_to_dict(params, - write_output) - for drive, params in net.external_drives.items() - }, + 'external_drives': { + drive: _external_drive_to_dict(params, write_output) + for drive, params in net.external_drives.items() + }, 'external_biases': net.external_biases, 'connectivity': [_conn_to_dict(conn) for conn in net.connectivity], - 'rec_arrays': {ra_name: _rec_array_to_dict(ex_array, write_output) - for ra_name, ex_array in net.rec_arrays.items() - }, + 'rec_arrays': { + ra_name: _rec_array_to_dict(ex_array, write_output) + for ra_name, ex_array in net.rec_arrays.items() + }, 'threshold': net.threshold, 'delay': net.delay, } @@ -346,9 +359,10 @@ def write_network_configuration(net, output, overwrite=True): if isinstance(output, (str, Path)): if overwrite is False and os.path.exists(output): - raise FileExistsError('File already exists at path %s. Rename ' - 'the file or set overwrite=True.' % (output,) - ) + raise FileExistsError( + 'File already exists at path %s. Rename ' + 'the file or set overwrite=True.' 
% (output,) + ) # Saving file with open(output, 'w', encoding='utf-8') as f: json.dump(net_data_converted, f, ensure_ascii=False, indent=4) @@ -378,10 +392,11 @@ def _order_drives(gid_ranges, external_drives): Ordered dict with drives by ascending gid ranges """ ordered_drives = OrderedDict() - min_gid_to_drive = {min(gid_range): name - for (name, gid_range) in gid_ranges.items() - if name in external_drives.keys() - } + min_gid_to_drive = { + min(gid_range): name + for (name, gid_range) in gid_ranges.items() + if name in external_drives.keys() + } min_gid_sorted = sorted(list(min_gid_to_drive.keys())) for min_gid in min_gid_sorted: drive_name = min_gid_to_drive[min_gid] @@ -390,9 +405,7 @@ def _order_drives(gid_ranges, external_drives): return ordered_drives -def dict_to_network(net_data, - read_drives=True, - read_external_biases=True): +def dict_to_network(net_data, read_drives=True, read_external_biases=True): """Converts a dict of network configurations to a Network Parameters @@ -411,6 +424,7 @@ def dict_to_network(net_data, # Importing Network. # Cannot do this globally due to circular import. from .network import Network + params = dict() params['celsius'] = net_data['celsius'] params['threshold'] = net_data['threshold'] @@ -418,10 +432,7 @@ def dict_to_network(net_data, mesh_shape = (net_data['N_pyr_x'], net_data['N_pyr_y']) # Instantiating network - net = Network(params, - mesh_shape=mesh_shape, - legacy_mode=net_data['legacy_mode'] - ) + net = Network(params, mesh_shape=mesh_shape, legacy_mode=net_data['legacy_mode']) # Setting attributes # Set cell types @@ -436,14 +447,13 @@ def dict_to_network(net_data, # Set pos_dict net.pos_dict = _read_pos_dict(net_data['pos_dict']) # Set cell_response - net.cell_response = _read_cell_response(net_data['cell_response'], - read_output=False) + net.cell_response = _read_cell_response( + net_data['cell_response'], read_output=False + ) # Set external drives - external_drive_data = _order_drives(net.gid_ranges, - net_data['external_drives']) + external_drive_data = _order_drives(net.gid_ranges, net_data['external_drives']) for key in external_drive_data.keys(): - _read_external_drive(net, external_drive_data[key], - read_output=False) + _read_external_drive(net, external_drive_data[key], read_output=False) # Set external biases if read_external_biases: net.external_biases = net_data['external_biases'] @@ -462,9 +472,7 @@ def dict_to_network(net_data, return net -def read_network_configuration(fname, - read_drives=True, - read_external_biases=True): +def read_network_configuration(fname, read_drives=True, read_external_biases=True): """Read network from a json configuration file. Parameters @@ -485,9 +493,11 @@ def read_network_configuration(fname, net_data = json.load(file) if net_data.get('object_type') != 'Network': - raise ValueError('The json should encode a Network object. ' - 'The file contains object of ' - 'type %s' % (net_data.get('object_type'))) + raise ValueError( + 'The json should encode a Network object. 
' + 'The file contains object of ' + 'type %s' % (net_data.get('object_type')) + ) net = dict_to_network(net_data, read_drives, read_external_biases) diff --git a/hnn_core/mpi_child.py b/hnn_core/mpi_child.py index a26ac5eae..f90dc4e39 100644 --- a/hnn_core/mpi_child.py +++ b/hnn_core/mpi_child.py @@ -26,15 +26,15 @@ def _str_to_net(input_str): data_str = _extract_data(input_str, 'net') if len(data_str) > 0: # get the size, but start the search after data - net_size = _extract_data_length(input_str[len(data_str):], - 'net') + net_size = _extract_data_length(input_str[len(data_str) :], 'net') # check the size if len(data_str) != net_size: - raise ValueError("Got incorrect network size: %d bytes " % - len(data_str) + "expected length: %d" % net_size) + raise ValueError( + 'Got incorrect network size: %d bytes ' % len(data_str) + + 'expected length: %d' % net_size + ) # unpickle the net - net = pickle.loads(base64.b64decode(data_str.encode(), - validate=True)) + net = pickle.loads(base64.b64decode(data_str.encode(), validate=True)) return net @@ -52,6 +52,7 @@ class MPISimulation(object): rank : int The rank for each processor part of the MPI communicator """ + def __init__(self, skip_mpi_import=False): self.skip_mpi_import = skip_mpi_import if skip_mpi_import: @@ -69,6 +70,7 @@ def __exit__(self, type, value, traceback): # skip Finalize() if we didn't import MPI on __init__ if hasattr(self, 'comm'): from mpi4py import MPI + MPI.Finalize() def _read_net(self): @@ -144,6 +146,7 @@ def run(self, net, tstop, dt, n_trials): """This file is called on command-line from nrniv""" import traceback + rc = 0 try: diff --git a/hnn_core/network.py b/hnn_core/network.py index c00de12b0..fdd07d79e 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -60,6 +60,7 @@ def _create_cell_coords(n_pyr_x, n_pyr_y, zdiff, inplane_distance): Common positions are all located at origin. Sort of a hack because of redundancy. 
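
    Editorial sketch: each pyramidal sheet is the Cartesian product of the
    in-plane coordinates at a fixed depth, e.g.

    >>> import itertools as it
    >>> list(it.product([0.0, 1.0], [0.0, 1.0], [0.0]))
    [(0.0, 0.0, 0.0), (0.0, 1.0, 0.0), (1.0, 0.0, 0.0), (1.0, 1.0, 0.0)]
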
""" + def _calc_pyramidal_coord(xxrange, yyrange, zdiff): list_coords = [pos for pos in it.product(xxrange, yyrange, [zdiff])] return list_coords @@ -71,13 +72,15 @@ def _calc_basket_coord(n_pyr_x, n_pyr_y, zdiff, inplane_distance, weight): yeven = np.arange(0, n_pyr_y, 2) * inplane_distance yodd = np.arange(1, n_pyr_y, 2) * inplane_distance # create general list of x,y coords and sort it - coords = [pos for pos in it.product( - xzero, yeven)] + [pos for pos in it.product(xone, yodd)] + coords = [pos for pos in it.product(xzero, yeven)] + [ + pos for pos in it.product(xone, yodd) + ] coords_sorted = sorted(coords, key=lambda pos: pos[1]) # append the z value for position - list_coords = [(pos_xy[0], pos_xy[1], weight * zdiff) - for pos_xy in coords_sorted] + list_coords = [ + (pos_xy[0], pos_xy[1], weight * zdiff) for pos_xy in coords_sorted + ] return list_coords @@ -97,12 +100,12 @@ def _calc_origin(xxrange, yyrange, zdiff): pos_dict = { 'L5_pyramidal': _calc_pyramidal_coord(xxrange, yyrange, zdiff=0), 'L2_pyramidal': _calc_pyramidal_coord(xxrange, yyrange, zdiff=zdiff), - 'L5_basket': _calc_basket_coord(n_pyr_x, n_pyr_y, zdiff, - inplane_distance, weight=0.2 - ), - 'L2_basket': _calc_basket_coord(n_pyr_x, n_pyr_y, zdiff, - inplane_distance, weight=0.8 - ), + 'L5_basket': _calc_basket_coord( + n_pyr_x, n_pyr_y, zdiff, inplane_distance, weight=0.2 + ), + 'L2_basket': _calc_basket_coord( + n_pyr_x, n_pyr_y, zdiff, inplane_distance, weight=0.8 + ), 'origin': _calc_origin(xxrange, yyrange, zdiff), } @@ -156,14 +159,14 @@ def _connection_probability(conn, probability, conn_seed=None): raise ValueError('probability must be in the range (0,1)') # Flatten connections into a list of targets. all_connections = np.concatenate( - [target_src_pair for - target_src_pair in conn['gid_pairs'].values()]) - n_connections = np.round( - len(all_connections) * probability).astype(int) + [target_src_pair for target_src_pair in conn['gid_pairs'].values()] + ) + n_connections = np.round(len(all_connections) * probability).astype(int) # Select a random subset of connections to retain. new_connections = rng.choice( - range(len(all_connections)), n_connections, replace=False) + range(len(all_connections)), n_connections, replace=False + ) remove_srcs = list() connection_idx = 0 for src_gid, target_src_pair in conn['gid_pairs'].items(): @@ -183,8 +186,7 @@ def _connection_probability(conn, probability, conn_seed=None): conn['gid_pairs'].pop(src_gid) -def pick_connection(net, src_gids=None, target_gids=None, - loc=None, receptor=None): +def pick_connection(net, src_gids=None, target_gids=None, loc=None, receptor=None): """Returns indices of connections that match search parameters. 
Parameters @@ -233,15 +235,15 @@ def pick_connection(net, src_gids=None, target_gids=None, # Convert src and target gids to lists valid_srcs = list(net.gid_ranges.keys()) # includes drives as srcs valid_targets = list(net.cell_types.keys()) - src_gids_checked = _check_gids(src_gids, net.gid_ranges, - valid_srcs, 'src_gids', same_type=False) - target_gids_checked = _check_gids(target_gids, net.gid_ranges, - valid_targets, 'target_gids', - same_type=False) + src_gids_checked = _check_gids( + src_gids, net.gid_ranges, valid_srcs, 'src_gids', same_type=False + ) + target_gids_checked = _check_gids( + target_gids, net.gid_ranges, valid_targets, 'target_gids', same_type=False + ) _validate_type(loc, (str, list, None), 'loc', 'str, list, or None') - _validate_type(receptor, (str, list, None), 'receptor', - 'str, list, or None') + _validate_type(receptor, (str, list, None), 'receptor', 'str, list, or None') valid_loc = ['proximal', 'distal', 'soma'] valid_receptor = ['ampa', 'nmda', 'gabaa', 'gabab'] @@ -270,11 +272,12 @@ def pick_connection(net, src_gids=None, target_gids=None, # Look up conn indices that match search terms and add to set. conn_set = set() - search_pairs = [(src_gids_checked, src_dict), - (target_gids_checked, target_dict), - (loc_list, loc_dict), - (receptor_list, receptor_dict), - ] + search_pairs = [ + (src_gids_checked, src_dict), + (target_gids_checked, target_dict), + (loc_list, loc_dict), + (receptor_list, receptor_dict), + ] for search_terms, search_dict in search_pairs: if search_terms: inner_set = set() @@ -365,8 +368,13 @@ class Network: connectivity information contained in ``params`` will be ignored. """ - def __init__(self, params, add_drives_from_params=False, - legacy_mode=False, mesh_shape=(10, 10)): + def __init__( + self, + params, + add_drives_from_params=False, + legacy_mode=False, + mesh_shape=(10, 10), + ): # Save the parameters used to create the Network _validate_type(params, dict, 'params') self._params = params @@ -387,15 +395,17 @@ def __init__(self, params, add_drives_from_params=False, warnings.warn( 'Legacy mode is used solely to maintain compatibility with' '.param files of the old HNN GUI. This feature will be ' - 'deprecrated in future releases.', DeprecationWarning, - stacklevel=1) + 'deprecated in future releases.', + DeprecationWarning, + stacklevel=1, + ) # Source dict of names, first real ones only! 
cell_types = { 'L2_basket': basket(cell_name=_short_name('L2_basket')), 'L2_pyramidal': pyramidal(cell_name=_short_name('L2_pyramidal')), 'L5_basket': basket(cell_name=_short_name('L5_basket')), - 'L5_pyramidal': pyramidal(cell_name=_short_name('L5_pyramidal')) + 'L5_pyramidal': pyramidal(cell_name=_short_name('L5_pyramidal')), } self.cell_response = None @@ -423,22 +433,26 @@ def __init__(self, params, add_drives_from_params=False, _validate_type(mesh_shape[1], int, 'mesh_shape[1]') if mesh_shape[0] < 1 or mesh_shape[1] < 1: - raise ValueError('mesh_shape must be a tuple of positive ' - f'integers, got: {mesh_shape}') + raise ValueError( + 'mesh_shape must be a tuple of positive ' f'integers, got: {mesh_shape}' + ) self._N_pyr_x = mesh_shape[0] self._N_pyr_y = mesh_shape[1] self._inplane_distance = 1.0 # XXX hard-coded default self._layer_separation = 1307.4 # XXX hard-coded default - self.set_cell_positions(inplane_distance=self._inplane_distance, - layer_separation=self._layer_separation) + self.set_cell_positions( + inplane_distance=self._inplane_distance, + layer_separation=self._layer_separation, + ) # populates self.gid_ranges for the 1st time: order matters for # NetworkBuilder! for cell_name in cell_types: - self._add_cell_type(cell_name, self.pos_dict[cell_name], - cell_template=cell_types[cell_name]) + self._add_cell_type( + cell_name, self.pos_dict[cell_name], cell_template=cell_types[cell_name] + ) if add_drives_from_params: _add_drives_from_params(self) @@ -448,11 +462,11 @@ def __init__(self, params, add_drives_from_params=False, def __repr__(self): class_name = self.__class__.__name__ - s = ("%d x %d Pyramidal cells (L2, L5)" - % (self._N_pyr_x, self._N_pyr_y)) - s += ("\n%d L2 basket cells\n%d L5 basket cells" - % (len(self.pos_dict['L2_basket']), - len(self.pos_dict['L5_basket']))) + s = '%d x %d Pyramidal cells (L2, L5)' % (self._N_pyr_x, self._N_pyr_y) + s += '\n%d L2 basket cells\n%d L5 basket cells' % ( + len(self.pos_dict['L2_basket']), + len(self.pos_dict['L5_basket']), + ) return '<%s | %s>' % (class_name, s) def __eq__(self, other): @@ -460,8 +474,9 @@ def __eq__(self, other): return NotImplemented # Check connectivity - if ((len(self.connectivity) != len(other.connectivity)) or - not (_compare_lists(self.connectivity, other.connectivity))): + if (len(self.connectivity) != len(other.connectivity)) or not ( + _compare_lists(self.connectivity, other.connectivity) + ): return False # Check all other attributes @@ -479,8 +494,7 @@ def __eq__(self, other): return True - def set_cell_positions(self, *, inplane_distance=None, - layer_separation=None): + def set_cell_positions(self, *, inplane_distance=None, layer_separation=None): """Set relative positions of cells arranged in a square grid Note that it is possible to change only a subset of the parameters @@ -499,20 +513,25 @@ def set_cell_positions(self, *, inplane_distance=None, if inplane_distance is None: inplane_distance = self._inplane_distance _validate_type(inplane_distance, (float, int), 'inplane_distance') - if not inplane_distance > 0.: - raise ValueError('In-plane distance must be positive, ' - f'got: {inplane_distance}') + if not inplane_distance > 0.0: + raise ValueError( + 'In-plane distance must be positive, ' f'got: {inplane_distance}' + ) if layer_separation is None: layer_separation = self._layer_separation _validate_type(layer_separation, (float, int), 'layer_separation') - if not layer_separation > 0.: - raise ValueError('Layer separation must be positive, ' - f'got: {layer_separation}') - - pos = 
_create_cell_coords(n_pyr_x=self._N_pyr_x, n_pyr_y=self._N_pyr_y, - zdiff=layer_separation, - inplane_distance=inplane_distance) + if not layer_separation > 0.0: + raise ValueError( + 'Layer separation must be positive, ' f'got: {layer_separation}' + ) + + pos = _create_cell_coords( + n_pyr_x=self._N_pyr_x, + n_pyr_y=self._N_pyr_y, + zdiff=layer_separation, + inplane_distance=inplane_distance, + ) # update positions of the real cells for key in pos.keys(): self.pos_dict[key] = pos[key] @@ -547,11 +566,24 @@ def copy(self): net_copy._reset_rec_arrays() return net_copy - def add_evoked_drive(self, name, *, mu, sigma, numspikes, location, - n_drive_cells='n_cells', cell_specific=True, - weights_ampa=None, weights_nmda=None, - space_constant=3., synaptic_delays=0.1, - probability=1.0, event_seed=2, conn_seed=3): + def add_evoked_drive( + self, + name, + *, + mu, + sigma, + numspikes, + location, + n_drive_cells='n_cells', + cell_specific=True, + weights_ampa=None, + weights_nmda=None, + space_constant=3.0, + synaptic_delays=0.1, + probability=1.0, + event_seed=2, + conn_seed=3, + ): """Add an 'evoked' external drive to the network Parameters @@ -632,8 +664,7 @@ def add_evoked_drive(self, name, *, mu, sigma, numspikes, location, probability < 1.0, the random subset of gids targeted is the same. """ if not self._legacy_mode: - _check_drive_parameter_values('evoked', sigma=sigma, - numspikes=numspikes) + _check_drive_parameter_values('evoked', sigma=sigma, numspikes=numspikes) drive = _NetworkDrive() drive['type'] = 'evoked' drive['location'] = location @@ -650,16 +681,37 @@ def add_evoked_drive(self, name, *, mu, sigma, numspikes, location, drive['synaptic_delays'] = synaptic_delays drive['probability'] = probability - self._attach_drive(name, drive, weights_ampa, weights_nmda, location, - space_constant, synaptic_delays, - n_drive_cells, cell_specific, probability) - - def add_poisson_drive(self, name, *, tstart=0, tstop=None, rate_constant, - location, n_drive_cells='n_cells', - cell_specific=True, weights_ampa=None, - weights_nmda=None, space_constant=100., - synaptic_delays=0.1, probability=1.0, event_seed=2, - conn_seed=3): + self._attach_drive( + name, + drive, + weights_ampa, + weights_nmda, + location, + space_constant, + synaptic_delays, + n_drive_cells, + cell_specific, + probability, + ) + + def add_poisson_drive( + self, + name, + *, + tstart=0, + tstop=None, + rate_constant, + location, + n_drive_cells='n_cells', + cell_specific=True, + weights_ampa=None, + weights_nmda=None, + space_constant=100.0, + synaptic_delays=0.1, + probability=1.0, + event_seed=2, + conn_seed=3, + ): """Add a Poisson-distributed external drive to the network Parameters @@ -730,25 +782,25 @@ def add_poisson_drive(self, name, *, tstart=0, tstop=None, rate_constant, Used to randomly remove connections when probability < 1.0. 
""" - _check_drive_parameter_values('Poisson', tstart=tstart, - tstop=tstop) - target_populations = _get_target_properties(weights_ampa, - weights_nmda, - synaptic_delays, - location)[0] - _check_poisson_rates(rate_constant, target_populations, - self.cell_types.keys()) + _check_drive_parameter_values('Poisson', tstart=tstart, tstop=tstop) + target_populations = _get_target_properties( + weights_ampa, weights_nmda, synaptic_delays, location + )[0] + _check_poisson_rates(rate_constant, target_populations, self.cell_types.keys()) if isinstance(rate_constant, dict): if not cell_specific: - raise ValueError(f"Drives specific to cell types are only " - f"possible with cell_specific=True and " - f"n_drive_cells='n_cells'. Got cell_specific" - f" cell_specific={cell_specific} and " - f"n_drive_cells={n_drive_cells}.") + raise ValueError( + f'Drives specific to cell types are only ' + f'possible with cell_specific=True and ' + f"n_drive_cells='n_cells'. Got " + f'cell_specific={cell_specific} and ' + f'n_drive_cells={n_drive_cells}.' + ) elif isinstance(rate_constant, (float, int)): if cell_specific: - rate_constant = {cell_type: rate_constant for cell_type in - target_populations} + rate_constant = { + cell_type: rate_constant for cell_type in target_populations + } drive = _NetworkDrive() drive['type'] = 'poisson' @@ -756,8 +808,9 @@ def add_poisson_drive(self, name, *, tstart=0, tstop=None, rate_constant, drive['n_drive_cells'] = n_drive_cells drive['event_seed'] = event_seed drive['conn_seed'] = conn_seed - drive['dynamics'] = dict(tstart=tstart, tstop=tstop, - rate_constant=rate_constant) + drive['dynamics'] = dict( + tstart=tstart, tstop=tstop, rate_constant=rate_constant + ) drive['events'] = list() # Need to save this information drive['weights_ampa'] = weights_ampa @@ -765,16 +818,41 @@ def add_poisson_drive(self, name, *, tstart=0, tstop=None, rate_constant, drive['synaptic_delays'] = synaptic_delays drive['probability'] = probability - self._attach_drive(name, drive, weights_ampa, weights_nmda, location, - space_constant, synaptic_delays, - n_drive_cells, cell_specific, probability) - - def add_bursty_drive(self, name, *, tstart=0, tstart_std=0, tstop=None, - location, burst_rate, burst_std=0, numspikes=2, - spike_isi=10, n_drive_cells=1, cell_specific=False, - weights_ampa=None, weights_nmda=None, - synaptic_delays=0.1, space_constant=100., - probability=1.0, event_seed=2, conn_seed=3): + self._attach_drive( + name, + drive, + weights_ampa, + weights_nmda, + location, + space_constant, + synaptic_delays, + n_drive_cells, + cell_specific, + probability, + ) + + def add_bursty_drive( + self, + name, + *, + tstart=0, + tstart_std=0, + tstop=None, + location, + burst_rate, + burst_std=0, + numspikes=2, + spike_isi=10, + n_drive_cells=1, + cell_specific=False, + weights_ampa=None, + weights_nmda=None, + synaptic_delays=0.1, + space_constant=100.0, + probability=1.0, + event_seed=2, + conn_seed=3, + ): """Add a bursty (rhythmic) external drive to all cells of the network Parameters @@ -852,12 +930,20 @@ def add_bursty_drive(self, name, *, tstart=0, tstart_std=0, tstop=None, Used to randomly remove connections when probability < 1.0. 
""" if not self._legacy_mode: - _check_drive_parameter_values('bursty', tstart=tstart, tstop=tstop, - sigma=tstart_std, location=location) - _check_drive_parameter_values('bursty', sigma=burst_std, - numspikes=numspikes, - spike_isi=spike_isi, - burst_rate=burst_rate) + _check_drive_parameter_values( + 'bursty', + tstart=tstart, + tstop=tstop, + sigma=tstart_std, + location=location, + ) + _check_drive_parameter_values( + 'bursty', + sigma=burst_std, + numspikes=numspikes, + spike_isi=spike_isi, + burst_rate=burst_rate, + ) drive = _NetworkDrive() drive['type'] = 'bursty' @@ -865,10 +951,15 @@ def add_bursty_drive(self, name, *, tstart=0, tstart_std=0, tstop=None, drive['n_drive_cells'] = n_drive_cells drive['event_seed'] = event_seed drive['conn_seed'] = conn_seed - drive['dynamics'] = dict(tstart=tstart, - tstart_std=tstart_std, tstop=tstop, - burst_rate=burst_rate, burst_std=burst_std, - numspikes=numspikes, spike_isi=spike_isi) + drive['dynamics'] = dict( + tstart=tstart, + tstart_std=tstart_std, + tstop=tstop, + burst_rate=burst_rate, + burst_std=burst_std, + numspikes=numspikes, + spike_isi=spike_isi, + ) drive['events'] = list() # Need to save this information drive['weights_ampa'] = weights_ampa @@ -876,13 +967,32 @@ def add_bursty_drive(self, name, *, tstart=0, tstart_std=0, tstop=None, drive['synaptic_delays'] = synaptic_delays drive['probability'] = probability - self._attach_drive(name, drive, weights_ampa, weights_nmda, location, - space_constant, synaptic_delays, - n_drive_cells, cell_specific, probability) - - def _attach_drive(self, name, drive, weights_ampa, weights_nmda, location, - space_constant, synaptic_delays, n_drive_cells, - cell_specific, probability): + self._attach_drive( + name, + drive, + weights_ampa, + weights_nmda, + location, + space_constant, + synaptic_delays, + n_drive_cells, + cell_specific, + probability, + ) + + def _attach_drive( + self, + name, + drive, + weights_ampa, + weights_nmda, + location, + space_constant, + synaptic_delays, + n_drive_cells, + cell_specific, + probability, + ): """Attach a drive to network based on connectivity information Parameters @@ -940,39 +1050,53 @@ def _attach_drive(self, name, drive, weights_ampa, weights_nmda, location, self.pos_dict is updated, and self._update_gid_ranges() called """ if name in self.external_drives: - raise ValueError(f"Drive {name} already defined") + raise ValueError(f'Drive {name} already defined') - _validate_type( - probability, (float, dict), 'probability', 'float or dict') + _validate_type(probability, (float, dict), 'probability', 'float or dict') # allow passing weights as None, convert to dict here - (target_populations, weights_by_type, delays_by_type, - probability_by_type) = \ - _get_target_properties(weights_ampa, weights_nmda, synaptic_delays, - location, probability) + (target_populations, weights_by_type, delays_by_type, probability_by_type) = ( + _get_target_properties( + weights_ampa, weights_nmda, synaptic_delays, location, probability + ) + ) # weights passed must correspond to cells in the network if not target_populations.issubset(set(self.cell_types.keys())): - raise ValueError('Allowed drive target cell types are: ', - f'{self.cell_types.keys()}') + raise ValueError( + 'Allowed drive target cell types are: ', f'{self.cell_types.keys()}' + ) # enforce the same order as in self.cell_types - necessary for # consistent source gid assignment - target_populations = [cell_type for cell_type in self.cell_types.keys() - if cell_type in target_populations] + target_populations = [ 
+ cell_type + for cell_type in self.cell_types.keys() + if cell_type in target_populations + ] # Ensure location exists for all target cells - cell_sections = [set(self.cell_types[cell_type].sections.keys()) for - cell_type in target_populations] - sect_locs = [set(self.cell_types[cell_type].sect_loc.keys()) for - cell_type in target_populations] + cell_sections = [ + set(self.cell_types[cell_type].sections.keys()) + for cell_type in target_populations + ] + sect_locs = [ + set(self.cell_types[cell_type].sect_loc.keys()) + for cell_type in target_populations + ] valid_cell_sections = set.intersection(*cell_sections) valid_sect_locs = set.intersection(*sect_locs) valid_loc = list(valid_cell_sections) + list(valid_sect_locs) - _check_option('location', location, valid_loc, - extra=(f" (the location '{location}' is not defined " - "for one of the targeted cells)")) + _check_option( + 'location', + location, + valid_loc, + extra=( + f" (the location '{location}' is not defined " + 'for one of the targeted cells)' + ), + ) if self._legacy_mode: # allows tests must match HNN GUI output by preserving original @@ -980,26 +1104,33 @@ def _attach_drive(self, name, drive, weights_ampa, weights_nmda, location, target_populations = list(self.cell_types.keys()) for target_type in target_populations: if target_type not in weights_by_type: - weights_by_type.update({target_type: {'ampa': 0.}}) + weights_by_type.update({target_type: {'ampa': 0.0}}) if target_type not in delays_by_type: delays_by_type.update({target_type: 0.1}) if target_type not in probability_by_type: probability_by_type.update({target_type: 1.0}) elif len(target_populations) == 0: - raise ValueError('No target populations have been specified for ' - 'this drive.') + raise ValueError( + 'No target populations have been specified for ' 'this drive.' + ) if cell_specific and n_drive_cells != 'n_cells': - raise ValueError(f"If cell_specific is True, n_drive_cells must" - f" equal 'n_cells'. Got {n_drive_cells}.") + raise ValueError( + f'If cell_specific is True, n_drive_cells must' + f" equal 'n_cells'. Got {n_drive_cells}." + ) elif not cell_specific: if not isinstance(n_drive_cells, int): - raise ValueError(f"If cell_specific is False, n_drive_cells " - f"must be of type int. Got " - f"{type(n_drive_cells)}.") + raise ValueError( + f'If cell_specific is False, n_drive_cells ' + f'must be of type int. Got ' + f'{type(n_drive_cells)}.' + ) if not n_drive_cells > 0: - raise ValueError('Number of drive cells must be greater than ' - f'0. Got {n_drive_cells}.') + raise ValueError( + 'Number of drive cells must be greater than ' + f'0. Got {n_drive_cells}.' 
+ ) drive['name'] = name # for easier for-looping later drive['target_types'] = target_populations # for _connect_celltypes @@ -1029,40 +1160,53 @@ def _attach_drive(self, name, drive, weights_ampa, weights_nmda, location, delays = delays_by_type[target_cell_type] probability = probability_by_type[target_cell_type] if cell_specific: - target_gids_nested = [[target_gid] for - target_gid in target_gids] + target_gids_nested = [[target_gid] for target_gid in target_gids] src_idx_end = src_idx + len(target_gids) - src_gids = (list(self.gid_ranges[name]) - [src_idx:src_idx_end]) + src_gids = list(self.gid_ranges[name])[src_idx:src_idx_end] src_idx = src_idx_end for receptor_idx, receptor in enumerate( - weights_by_type[target_cell_type]): + weights_by_type[target_cell_type] + ): weights = weights_by_type[target_cell_type][receptor] self.add_connection( - src_gids=src_gids, target_gids=target_gids_nested, - loc=location, receptor=receptor, weight=weights, - delay=delays, lamtha=space_constant, + src_gids=src_gids, + target_gids=target_gids_nested, + loc=location, + receptor=receptor, + weight=weights, + delay=delays, + lamtha=space_constant, probability=probability, - conn_seed=drive['conn_seed'] + seed_increment) + conn_seed=drive['conn_seed'] + seed_increment, + ) # Ensure that AMPA/NMDA connections target the same gids if receptor_idx > 0: - self.connectivity[-1]['src_gids'] = \ - self.connectivity[-2]['src_gids'] + self.connectivity[-1]['src_gids'] = self.connectivity[-2][ + 'src_gids' + ] else: for receptor_idx, receptor in enumerate( - weights_by_type[target_cell_type]): + weights_by_type[target_cell_type] + ): weights = weights_by_type[target_cell_type][receptor] self.add_connection( - src_gids=name, target_gids=target_gids, loc=location, - receptor=receptor, weight=weights, delay=delays, - lamtha=space_constant, probability=probability, - conn_seed=drive['conn_seed'] + seed_increment) + src_gids=name, + target_gids=target_gids, + loc=location, + receptor=receptor, + weight=weights, + delay=delays, + lamtha=space_constant, + probability=probability, + conn_seed=drive['conn_seed'] + seed_increment, + ) # Ensure that AMPA/NMDA connections target the same gids # when probability < 1 if receptor_idx > 0: - self.connectivity[-1]['src_gids'] = \ - self.connectivity[-2]['src_gids'] + self.connectivity[-1]['src_gids'] = self.connectivity[-2][ + 'src_gids' + ] def _reset_drives(self): # reset every time called again, e.g., from dipole.py or in self.copy() @@ -1096,27 +1240,32 @@ def _instantiate_drives(self, tstop, n_trials=1): for drive in self.external_drives.values(): event_times = list() # new list for each trial and drive for drive_cell_gid in self.gid_ranges[drive['name']]: - drive_cell_gid_offset = (drive_cell_gid - - self.gid_ranges[drive['name']][0]) + drive_cell_gid_offset = ( + drive_cell_gid - self.gid_ranges[drive['name']][0] + ) trial_seed_offset = self._n_gids if drive['cell_specific']: # loop over drives (one for each target cell # population) and create event times - conn_idxs = pick_connection(self, - src_gids=drive_cell_gid) - target_types = set([self.connectivity[conn_idx] - ['target_type'] for conn_idx in - conn_idxs]) + conn_idxs = pick_connection(self, src_gids=drive_cell_gid) + target_types = set( + [ + self.connectivity[conn_idx]['target_type'] + for conn_idx in conn_idxs + ] + ) for target_type in target_types: - event_times.append(_drive_cell_event_times( - drive['type'], - drive['dynamics'], - target_type=target_type, - trial_idx=trial_idx, - 
drive_cell_gid=drive_cell_gid_offset, - event_seed=drive['event_seed'], - tstop=tstop, - trial_seed_offset=trial_seed_offset) + event_times.append( + _drive_cell_event_times( + drive['type'], + drive['dynamics'], + target_type=target_type, + trial_idx=trial_idx, + drive_cell_gid=drive_cell_gid_offset, + event_seed=drive['event_seed'], + tstop=tstop, + trial_seed_offset=trial_seed_offset, + ) ) else: src_event_times = _drive_cell_event_times( @@ -1127,11 +1276,11 @@ def _instantiate_drives(self, tstop, n_trials=1): trial_idx=trial_idx, drive_cell_gid=drive_cell_gid_offset, event_seed=drive['event_seed'], - trial_seed_offset=trial_seed_offset) + trial_seed_offset=trial_seed_offset, + ) event_times.append(src_event_times) # 'events': nested list (n_trials x n_drive_cells x n_events) - self.external_drives[ - drive['name']]['events'].append(event_times) + self.external_drives[drive['name']]['events'].append(event_times) def add_tonic_bias(self, *, cell_type=None, amplitude, t0=0, tstop=None): """Attaches parameters of tonic bias input for given cell types Parameters @@ -1160,28 +1309,41 @@ def add_tonic_bias(self, *, cell_type=None, amplitude, t0=0, tstop=None): # old functionality single cell type - amplitude if cell_type is not None: - warnings.warn('cell_type argument will be deprecated and ' - 'removed in future releases. Use amplitude as a ' - 'cell_type:str,amplitude:float dictionary.' - 'Read the function docustring for more information', - DeprecationWarning, - stacklevel=1) + warnings.warn( + 'cell_type argument will be deprecated and ' + 'removed in future releases. Use amplitude as a ' + 'cell_type:str,amplitude:float dictionary. ' + 'Read the function docstring for more information', + DeprecationWarning, + stacklevel=1, + ) _validate_type(amplitude, (float, int), 'amplitude') - _add_cell_type_bias(network=self, cell_type=cell_type, - amplitude=float(amplitude), - t_0=t0, t_stop=tstop) + _add_cell_type_bias( + network=self, + cell_type=cell_type, + amplitude=float(amplitude), + t_0=t0, + t_stop=tstop, + ) else: _validate_type(amplitude, dict, 'amplitude') if len(amplitude) == 0: - warnings.warn('No bias have been defined, no action taken', - UserWarning, stacklevel=1) + warnings.warn( + 'No biases have been defined, no action taken', + UserWarning, + stacklevel=1, + ) return for _cell_type, _amplitude in amplitude.items(): - _add_cell_type_bias(network=self, cell_type=_cell_type, - amplitude=_amplitude, - t_0=t0, t_stop=tstop) + _add_cell_type_bias( + network=self, + cell_type=_cell_type, + amplitude=_amplitude, + t_0=t0, + t_stop=tstop, + ) def _add_cell_type(self, cell_name, pos, cell_template=None): """Add cell type by updating pos_dict and gid_ranges.""" @@ -1198,9 +1360,19 @@ def gid_to_type(self, gid): """Reverse lookup of gid to type.""" return _gid_to_type(gid, self.gid_ranges) - def add_connection(self, src_gids, target_gids, loc, receptor, - weight, delay, lamtha, allow_autapses=True, - probability=1.0, conn_seed=None): + def add_connection( + self, + src_gids, + target_gids, + loc, + receptor, + weight, + delay, + lamtha, + allow_autapses=True, + probability=1.0, + conn_seed=None, + ): """Appends connections to connectivity list Parameters @@ -1250,14 +1422,19 @@ def add_connection(self, src_gids, target_gids, loc, receptor, conn = _Connectivity() threshold = self.threshold - _validate_type(target_gids, (int, list, range, str), 'target_gids', - 'int list, range or str') + _validate_type( + target_gids, + (int, list, range, str), + 'target_gids', + 'int list, range or str', + ) 
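A minimal usage sketch of the Network.add_connection API whose signature is reformatted above; the cell-type names come from the model itself, while the weight, delay, lamtha, and probability values below are illustrative assumptions, not values from this diff:

    from hnn_core import jones_2009_model

    net = jones_2009_model()
    # connect every L2 basket cell to the somata of all L2 pyramidal cells
    net.add_connection(src_gids='L2_basket', target_gids='L2_pyramidal',
                       loc='soma', receptor='gabaa', weight=0.001,
                       delay=1.0, lamtha=50.0, probability=0.5)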
_validate_type(allow_autapses, bool, 'target_gids', 'bool') valid_source_cells = list(self.gid_ranges.keys()) # Convert src_gids to list - src_gids = _check_gids(src_gids, self.gid_ranges, - valid_source_cells, 'src_gids') + src_gids = _check_gids( + src_gids, self.gid_ranges, valid_source_cells, 'src_gids' + ) # Convert target_gids to list of list, one element for each src_gid valid_target_cells = list(self.cell_types.keys()) @@ -1265,20 +1442,22 @@ def add_connection(self, src_gids, target_gids, loc, receptor, target_gids = [[target_gids] for _ in range(len(src_gids))] elif isinstance(target_gids, str): _check_option('target_gids', target_gids, valid_target_cells) - target_gids = [list(self.gid_ranges[_long_name(target_gids)]) - for _ in range(len(src_gids))] + target_gids = [ + list(self.gid_ranges[_long_name(target_gids)]) + for _ in range(len(src_gids)) + ] elif isinstance(target_gids, range): target_gids = [list(target_gids) for _ in range(len(src_gids))] - elif isinstance(target_gids, list) and all(isinstance(t_gid, int) - for t_gid in target_gids): + elif isinstance(target_gids, list) and all( + isinstance(t_gid, int) for t_gid in target_gids + ): target_gids = [target_gids for _ in range(len(src_gids))] # Validate each target list - src pairs. # set() used to avoid redundant checks. target_set = set() for target_src_pair in target_gids: - _validate_type(target_src_pair, list, 'target_gids[idx]', - 'list or range') + _validate_type(target_src_pair, list, 'target_gids[idx]', 'list or range') for target_gid in target_src_pair: target_set.add(target_gid) target_type = self.gid_to_type(target_gids[0][0]) @@ -1287,11 +1466,9 @@ def add_connection(self, src_gids, target_gids, loc, receptor, # Ensure gids in range of Network.gid_ranges gid_type = self.gid_to_type(target_gid) if gid_type is None: - raise AssertionError( - f'target_gid {target_gid}''not in net.gid_ranges') + raise AssertionError(f'target_gid {target_gid} ' 'not in net.gid_ranges') elif gid_type != target_type: - raise AssertionError( - 'All target_gids must be of the same type') + raise AssertionError('All target_gids must be of the same type') conn['target_type'] = target_type conn['target_gids'] = target_set conn['num_targets'] = len(target_set) @@ -1319,12 +1496,14 @@ def add_connection(self, src_gids, target_gids, loc, receptor, target_sect_loc = self.cell_types[target_type].sect_loc target_sections = self.cell_types[target_type].sections - valid_loc = list( - target_sect_loc.keys()) + list(target_sections.keys()) - - _check_option('loc', loc, valid_loc, - extra=(f" (the loc '{loc}' is not defined " - f"for '{target_type}' cells)")) + valid_loc = list(target_sect_loc.keys()) + list(target_sections.keys()) + + _check_option( + 'loc', + loc, + valid_loc, + extra=(f" (the loc '{loc}' is not defined " f"for '{target_type}' cells)"), + ) conn['loc'] = loc # `loc` specifies a group of sections, all must contain the synapse @@ -1332,17 +1511,25 @@ def add_connection(self, src_gids, target_gids, loc, receptor, if loc in target_sect_loc: for sec_name in target_sect_loc[loc]: valid_receptor = target_sections[sec_name].syns - _check_option('receptor', receptor, valid_receptor, - extra=f" (the '{receptor}' receptor is not " - f"defined for the '{sec_name}' of" - f"'{target_type}' cells)") + _check_option( + 'receptor', + receptor, + valid_receptor, + extra=f" (the '{receptor}' receptor is not " + f"defined for the '{sec_name}' of " + f"'{target_type}' cells)", + ) # `loc` specifies an individual section else: valid_receptor = 
target_sections[loc].syns - _check_option('receptor', receptor, valid_receptor, - extra=f"(the '{receptor}' receptor is not " - f"defined for the '{loc}' of" - f"'{target_type}' cells)") + _check_option( + 'receptor', + receptor, + valid_receptor, + extra=f"(the '{receptor}' receptor is not " + f"defined for the '{loc}' of " + f"'{target_type}' cells)", + ) conn['receptor'] = receptor @@ -1367,8 +1554,7 @@ def add_connection(self, src_gids, target_gids, loc, receptor, self.connectivity.append(deepcopy(conn)) def clear_connectivity(self): - """Remove all connections defined in Network.connectivity - """ + """Remove all connections defined in Network.connectivity""" connectivity = list() for conn in self.connectivity: if conn['src_type'] in self.external_drives.keys(): @@ -1377,9 +1563,11 @@ def clear_connectivity(self): def clear_drives(self): """Remove all drives defined in Network.connectivity""" - self.connectivity = [conn for conn in self.connectivity if - conn['src_type'] not - in self.external_drives.keys()] + self.connectivity = [ + conn + for conn in self.connectivity + if conn['src_type'] not in self.external_drives.keys() + ] for cell_name in list(self.gid_ranges.keys()): if cell_name in self.external_drives: @@ -1389,8 +1577,9 @@ def clear_drives(self): self.external_drives = dict() - def add_electrode_array(self, name, electrode_pos, *, conductivity=0.3, - method='psa', min_distance=0.5): + def add_electrode_array( + self, name, electrode_pos, *, conductivity=0.3, method='psa', min_distance=0.5 + ): """Specify coordinates of electrode array for extracellular recording. Parameters @@ -1421,14 +1610,18 @@ def add_electrode_array(self, name, electrode_pos, *, conductivity=0.3, raise ValueError(f'{name} already exists, use another name!') # let ExtracellularArray perform all remaining argument checks - self.rec_arrays.update({ - name: ExtracellularArray(electrode_pos, - conductivity=conductivity, - method=method, - min_distance=min_distance)}) - - def update_weights(self, e_e=None, e_i=None, - i_e=None, i_i=None, copy=False): + self.rec_arrays.update( + { + name: ExtracellularArray( + electrode_pos, + conductivity=conductivity, + method=method, + min_distance=min_distance, + ) + } + ) + + def update_weights(self, e_e=None, e_i=None, i_e=None, i_i=None, copy=False): """Update synaptic weights of the network. Parameters @@ -1467,17 +1660,19 @@ def update_weights(self, e_e=None, e_i=None, net = self.copy() if copy else self e_conns = pick_connection(self, receptor=['ampa', 'nmda']) - e_cells = np.concatenate([list(net.connectivity[ - conn_idx]['src_gids']) for conn_idx in e_conns]).tolist() + e_cells = np.concatenate( + [list(net.connectivity[conn_idx]['src_gids']) for conn_idx in e_conns] + ).tolist() i_conns = pick_connection(self, receptor=['gabaa', 'gabab']) - i_cells = np.concatenate([list(net.connectivity[ - conn_idx]['src_gids']) for conn_idx in i_conns]).tolist() + i_cells = np.concatenate( + [list(net.connectivity[conn_idx]['src_gids']) for conn_idx in i_conns] + ).tolist() conn_types = { 'e_e': (e_e, e_cells, e_cells), 'e_i': (e_i, e_cells, i_cells), 'i_e': (i_e, i_cells, e_cells), - 'i_i': (i_i, i_cells, i_cells) + 'i_i': (i_i, i_cells, i_cells), } for conn_type, (gain, e_vals, i_vals) in conn_types.items(): @@ -1486,11 +1681,12 @@ def update_weights(self, e_e=None, e_i=None, _validate_type(gain, (int, float), conn_type, 'int or float') if gain < 0.0: - raise ValueError("Synaptic gains must be non-negative." 
- f"Got {gain} for '{conn_type}'.") + raise ValueError( + 'Synaptic gains must be non-negative. ' + f"Got {gain} for '{conn_type}'." + ) - conn_indices = pick_connection(net, src_gids=e_vals, - target_gids=i_vals) + conn_indices = pick_connection(net, src_gids=e_vals, target_gids=i_vals) for conn_idx in conn_indices: net.connectivity[conn_idx]['nc_dict']['gain'] = gain @@ -1585,7 +1781,7 @@ def __repr__(self): entr += f"\nweight: {self['nc_dict']['A_weight']}; " entr += f"delay: {self['nc_dict']['A_delay']}; " entr += f"lamtha: {self['nc_dict']['lamtha']}" - entr += "\n " + entr += '\n ' return entr @@ -1636,34 +1832,36 @@ def __repr__(self): entr += f"\ntarget cell types: {self['target_types']}" entr += f"\nnumber of drive cells: {self['n_drive_cells']}" entr += f"\ncell-specific: {self['cell_specific']}" - entr += "\ndynamic parameters:" + entr += '\ndynamic parameters:' for key, val in self['dynamics'].items(): - entr += f"\n\t{key}: {val}" + entr += f'\n\t{key}: {val}' if len(self['events']) > 0: plurl = 's' if len(self['events']) > 1 else '' - entr += ("\nevent times instantiated for " - f"{len(self['events'])} trial{plurl}") + entr += ( + "\nevent times instantiated for " f"{len(self['events'])} trial{plurl}" + ) entr += '>' return entr -def _add_cell_type_bias(network: Network, amplitude: Union[float, dict], - cell_type=None, - t_0=0, t_stop=None): - +def _add_cell_type_bias( + network: Network, amplitude: Union[float, dict], cell_type=None, t_0=0, t_stop=None +): if network is None: - raise ValueError('The "network" parameter is required ' - 'but was not provided') + raise ValueError('The "network" parameter is required ' 'but was not provided') if amplitude is None: - raise ValueError('The "amplitude" parameter is required ' - 'but was not provided') + raise ValueError( + 'The "amplitude" parameter is required ' 'but was not provided' + ) if cell_type is not None: # Validate cell_type value if cell_type not in network.cell_types: - raise ValueError(f'cell_type must be one of ' - f'{list(network.cell_types.keys())}. ' - f'Got {cell_type}') + raise ValueError( + f'cell_type must be one of ' + f'{list(network.cell_types.keys())}. 
' + f'Got {cell_type}' + ) if 'tonic' not in network.external_biases: network.external_biases['tonic'] = dict() @@ -1671,9 +1869,5 @@ def _add_cell_type_bias(network: Network, amplitude: Union[float, dict], if cell_type in network.external_biases['tonic']: raise ValueError(f'Tonic bias already defined for {cell_type}') - cell_type_bias = { - 'amplitude': amplitude, - 't0': t_0, - 'tstop': t_stop - } + cell_type_bias = {'amplitude': amplitude, 't0': t_0, 'tstop': t_stop} network.external_biases['tonic'][cell_type] = cell_type_bias diff --git a/hnn_core/network_builder.py b/hnn_core/network_builder.py index 230f20e06..a51df042b 100644 --- a/hnn_core/network_builder.py +++ b/hnn_core/network_builder.py @@ -13,6 +13,7 @@ # This is due to: https://github.com/neuronsimulator/nrn/pull/746 from neuron import __version__ + if int(__version__[0]) >= 8: h.nrnunit_use_legacy(1) @@ -44,7 +45,7 @@ def _simulate_single_trial(net, tstop, dt, trial_idx): global _PC, _CVODE - h.load_file("stdrun.hoc") + h.load_file('stdrun.hoc') rank = _get_rank() @@ -98,7 +99,8 @@ def simulation_time(): isec_py[gid] = dict() for sec_name, isec in isec_dict.items(): isec_py[gid][sec_name] = { - key: isec.to_python() for key, isec in isec.items()} + key: isec.to_python() for key, isec in isec.items() + } ca_py = dict() for gid, ca_dict in neuron_net._ca.items(): @@ -108,10 +110,10 @@ def simulation_time(): ca_py[gid][sec_name] = ca.to_python() dpl_data = np.c_[ - neuron_net._nrn_dipoles['L2_pyramidal'].as_numpy() + - neuron_net._nrn_dipoles['L5_pyramidal'].as_numpy(), + neuron_net._nrn_dipoles['L2_pyramidal'].as_numpy() + + neuron_net._nrn_dipoles['L5_pyramidal'].as_numpy(), neuron_net._nrn_dipoles['L2_pyramidal'].as_numpy(), - neuron_net._nrn_dipoles['L5_pyramidal'].as_numpy() + neuron_net._nrn_dipoles['L5_pyramidal'].as_numpy(), ] rec_arr_py = dict() @@ -120,16 +122,18 @@ def simulation_time(): rec_arr_py.update({arr_name: nrn_arr._get_nrn_voltages()}) rec_times_py.update({arr_name: nrn_arr._get_nrn_times()}) - data = {'dpl_data': dpl_data, - 'spike_times': neuron_net._all_spike_times.to_python(), - 'spike_gids': neuron_net._all_spike_gids.to_python(), - 'gid_ranges': net.gid_ranges, - 'vsec': vsec_py, - 'isec': isec_py, - 'ca': ca_py, - 'rec_data': rec_arr_py, - 'rec_times': rec_times_py, - 'times': times.to_python()} + data = { + 'dpl_data': dpl_data, + 'spike_times': neuron_net._all_spike_times.to_python(), + 'spike_gids': neuron_net._all_spike_gids.to_python(), + 'gid_ranges': net.gid_ranges, + 'vsec': vsec_py, + 'isec': isec_py, + 'ca': ca_py, + 'rec_data': rec_arr_py, + 'rec_times': rec_times_py, + 'times': times.to_python(), + } return data @@ -151,7 +155,6 @@ def _is_loaded_mechanisms(): def load_custom_mechanisms(): - if _is_loaded_mechanisms(): return @@ -337,10 +340,12 @@ def _build(self): record_vsec = self.net._params['record_vsec'] record_isec = self.net._params['record_isec'] record_ca = self.net._params['record_ca'] - self._create_cells_and_drives(threshold=self.net._params['threshold'], - record_vsec=record_vsec, - record_isec=record_isec, - record_ca=record_ca) + self._create_cells_and_drives( + threshold=self.net._params['threshold'], + record_vsec=record_vsec, + record_isec=record_isec, + record_ca=record_ca, + ) self.state_init() @@ -390,15 +395,17 @@ def _gid_assign(self, rank=None, n_hosts=None): conn_idxs = pick_connection(self.net, src_gids=src_gid) target_gids = set() for conn_idx in conn_idxs: - gid_pairs = self.net.connectivity[ - conn_idx]['gid_pairs'] + gid_pairs = 
self.net.connectivity[conn_idx]['gid_pairs'] if src_gid in gid_pairs: - target_gids.update(self.net.connectivity[conn_idx] - ['gid_pairs'][src_gid]) + target_gids.update( + self.net.connectivity[conn_idx]['gid_pairs'][src_gid] + ) for target_gid in target_gids: - if (target_gid in self._gid_list and - src_gid not in self._gid_list): + if ( + target_gid in self._gid_list + and src_gid not in self._gid_list + ): self._gid_list.append(src_gid) else: # round robin assignment of drive gids @@ -409,8 +416,9 @@ def _gid_assign(self, rank=None, n_hosts=None): # extremely important to get the gids in the right order self._gid_list.sort() - def _create_cells_and_drives(self, threshold, record_vsec=False, - record_isec=False, record_ca=False): + def _create_cells_and_drives( + self, threshold, record_vsec=False, record_isec=False, record_ca=False + ): """Parallel create cells AND external drives NB: _Cell.__init__ calls h.Section -> non-picklable! @@ -441,10 +449,13 @@ def _create_cells_and_drives(self, threshold, record_vsec=False, else: cell.build() # add tonic biases - if ('tonic' in self.net.external_biases and - src_type in self.net.external_biases['tonic']): - cell.create_tonic_bias(**self.net.external_biases - ['tonic'][src_type]) + if ( + 'tonic' in self.net.external_biases + and src_type in self.net.external_biases['tonic'] + ): + cell.create_tonic_bias( + **self.net.external_biases['tonic'][src_type] + ) cell.record(record_vsec, record_isec, record_ca) # this call could belong in init of a _Cell (with threshold)? @@ -455,8 +466,9 @@ def _create_cells_and_drives(self, threshold, record_vsec=False, # external driving inputs are special types of artificial-cells else: - event_times = self.net.external_drives[ - src_type]['events'][self.trial_idx][gid_idx] + event_times = self.net.external_drives[src_type]['events'][ + self.trial_idx + ][gid_idx] drive_cell = _ArtificialCell(event_times, threshold, gid=gid) _PC.cell(drive_cell.gid, drive_cell.nrn_netcon) self._drive_cells.append(drive_cell) @@ -500,14 +512,15 @@ def _connect_celltypes(self): src_type = self.net.gid_to_type(src_gid) target_type = self.net.gid_to_type(target_gid) target_cell = self._cells[target_filter[target_gid]] - connection_name = f'{_short_name(src_type)}_'\ - f'{_short_name(target_type)}_{receptor}' + connection_name = ( + f'{_short_name(src_type)}_' + f'{_short_name(target_type)}_{receptor}' + ) if connection_name not in self.ncs: self.ncs[connection_name] = list() pos_idx = src_gid - net.gid_ranges[_long_name(src_type)][0] # NB pos_dict for this drive must include ALL cell types! - nc_dict['pos_src'] = net.pos_dict[ - _long_name(src_type)][pos_idx] + nc_dict['pos_src'] = net.pos_dict[_long_name(src_type)][pos_idx] # get synapse locations syn_keys = list() @@ -521,9 +534,11 @@ def _connect_celltypes(self): for syn_key in syn_keys: nc = target_cell.parconnect_from_src( - src_gid, deepcopy(nc_dict), + src_gid, + deepcopy(nc_dict), target_cell._nrn_synapses[syn_key], - net._inplane_distance) + net._inplane_distance, + ) self.ncs[connection_name].append(nc) def _record_extracellular(self): @@ -567,10 +582,12 @@ def aggregate_data(self, n_samples): # add dipoles across neurons on the current thread if hasattr(cell, 'dipole'): if cell.dipole.size() != n_samples: - raise ValueError(f"n_samples does not match the size " - f"of at least one cell's dipole vector. " - f"Got n_samples={n_samples}, {cell.name}." 
- f"dipole.size()={cell.dipole.size()}.") + raise ValueError( + f'n_samples does not match the size ' + f"of at least one cell's dipole vector. " + f'Got n_samples={n_samples}, {cell.name}.' + f'dipole.size()={cell.dipole.size()}.' + ) nrn_dpl = self._nrn_dipoles[_long_name(cell.name)] nrn_dpl.add(cell.dipole) @@ -626,7 +643,7 @@ def state_init(self): elif sect.name() == 'L5Pyr_apical_tuft': seg.v = -67.30 else: - seg.v = -72. + seg.v = -72.0 elif cell.name == 'L2Basket': seg.v = -64.9737 elif cell.name == 'L5Basket': diff --git a/hnn_core/network_models.py b/hnn_core/network_models.py index 2f6caddfb..bddb07eb6 100644 --- a/hnn_core/network_models.py +++ b/hnn_core/network_models.py @@ -11,8 +11,9 @@ from .externals.mne import _validate_type -def jones_2009_model(params=None, add_drives_from_params=False, - legacy_mode=False, mesh_shape=(10, 10)): +def jones_2009_model( + params=None, add_drives_from_params=False, legacy_mode=False, mesh_shape=(10, 10) +): """Instantiate the network model described in Jones et al. J. of Neurophys. 2009 [1]_ @@ -60,8 +61,12 @@ def jones_2009_model(params=None, add_drives_from_params=False, if isinstance(params, str): params = read_params(params) - net = Network(params, add_drives_from_params=add_drives_from_params, - legacy_mode=legacy_mode, mesh_shape=mesh_shape) + net = Network( + params, + add_drives_from_params=add_drives_from_params, + legacy_mode=legacy_mode, + mesh_shape=mesh_shape, + ) delay = net.delay @@ -73,110 +78,119 @@ def jones_2009_model(params=None, add_drives_from_params=False, loc = 'proximal' for target_cell in ['L2_pyramidal', 'L5_pyramidal']: for receptor in ['nmda', 'ampa']: - key = f'gbar_{_short_name(target_cell)}_'\ - f'{_short_name(target_cell)}_{receptor}' + key = ( + f'gbar_{_short_name(target_cell)}_' + f'{_short_name(target_cell)}_{receptor}' + ) weight = net._params[key] net.add_connection( - target_cell, target_cell, loc, receptor, weight, - delay, lamtha, allow_autapses=False) + target_cell, + target_cell, + loc, + receptor, + weight, + delay, + lamtha, + allow_autapses=False, + ) # layer2 Basket -> layer2 Pyr src_cell = 'L2_basket' target_cell = 'L2_pyramidal' - lamtha = 50. + lamtha = 50.0 loc = 'soma' for receptor in ['gabaa', 'gabab']: key = f'gbar_L2Basket_L2Pyr_{receptor}' weight = net._params[key] - net.add_connection( - src_cell, target_cell, loc, receptor, weight, delay, lamtha) + net.add_connection(src_cell, target_cell, loc, receptor, weight, delay, lamtha) # layer5 Basket -> layer5 Pyr src_cell = 'L5_basket' target_cell = 'L5_pyramidal' - lamtha = 70. + lamtha = 70.0 loc = 'soma' for receptor in ['gabaa', 'gabab']: key = f'gbar_L5Basket_{_short_name(target_cell)}_{receptor}' weight = net._params[key] - net.add_connection( - src_cell, target_cell, loc, receptor, weight, delay, lamtha) + net.add_connection(src_cell, target_cell, loc, receptor, weight, delay, lamtha) # layer2 Pyr -> layer5 Pyr src_cell = 'L2_pyramidal' - lamtha = 3. + lamtha = 3.0 receptor = 'ampa' for loc in ['proximal', 'distal']: key = f'gbar_L2Pyr_{_short_name(target_cell)}' weight = net._params[key] - net.add_connection( - src_cell, target_cell, loc, receptor, weight, delay, lamtha) + net.add_connection(src_cell, target_cell, loc, receptor, weight, delay, lamtha) # layer2 Basket -> layer5 Pyr src_cell = 'L2_basket' - lamtha = 50. 
+ lamtha = 50.0 key = f'gbar_L2Basket_{_short_name(target_cell)}' weight = net._params[key] loc = 'distal' receptor = 'gabaa' - net.add_connection( - src_cell, target_cell, loc, receptor, weight, delay, lamtha) + net.add_connection(src_cell, target_cell, loc, receptor, weight, delay, lamtha) # xx -> layer2 Basket src_cell = 'L2_pyramidal' target_cell = 'L2_basket' - lamtha = 3. + lamtha = 3.0 key = f'gbar_L2Pyr_{_short_name(target_cell)}' weight = net._params[key] loc = 'soma' receptor = 'ampa' - net.add_connection( - src_cell, target_cell, loc, receptor, weight, delay, lamtha) + net.add_connection(src_cell, target_cell, loc, receptor, weight, delay, lamtha) src_cell = 'L2_basket' - lamtha = 20. + lamtha = 20.0 key = f'gbar_L2Basket_{_short_name(target_cell)}' weight = net._params[key] loc = 'soma' receptor = 'gabaa' - net.add_connection( - src_cell, target_cell, loc, receptor, weight, delay, lamtha) + net.add_connection(src_cell, target_cell, loc, receptor, weight, delay, lamtha) # xx -> layer5 Basket src_cell = 'L5_basket' target_cell = 'L5_basket' - lamtha = 20. + lamtha = 20.0 loc = 'soma' receptor = 'gabaa' key = f'gbar_L5Basket_{_short_name(target_cell)}' weight = net._params[key] net.add_connection( - src_cell, target_cell, loc, receptor, weight, delay, lamtha, - allow_autapses=False) + src_cell, + target_cell, + loc, + receptor, + weight, + delay, + lamtha, + allow_autapses=False, + ) src_cell = 'L5_pyramidal' - lamtha = 3. + lamtha = 3.0 key = f'gbar_L5Pyr_{_short_name(target_cell)}' weight = net._params[key] loc = 'soma' receptor = 'ampa' - net.add_connection( - src_cell, target_cell, loc, receptor, weight, delay, lamtha) + net.add_connection(src_cell, target_cell, loc, receptor, weight, delay, lamtha) src_cell = 'L2_pyramidal' - lamtha = 3. + lamtha = 3.0 key = f'gbar_L2Pyr_{_short_name(target_cell)}' weight = net._params[key] loc = 'soma' receptor = 'ampa' - net.add_connection( - src_cell, target_cell, loc, receptor, weight, delay, lamtha) + net.add_connection(src_cell, target_cell, loc, receptor, weight, delay, lamtha) return net -def law_2021_model(params=None, add_drives_from_params=False, - legacy_mode=False, mesh_shape=(10, 10)): +def law_2021_model( + params=None, add_drives_from_params=False, legacy_mode=False, mesh_shape=(10, 10) +): """Instantiate the expansion of Jones 2009 model to study beta modulated ERPs as described in Law et al. Cereb. Cortex 2021 [1]_ @@ -210,8 +224,9 @@ def law_2021_model(params=None, add_drives_from_params=False, Perception." Cerebral Cortex, 32, 668–688 (2022). """ - net = jones_2009_model(params, add_drives_from_params, legacy_mode, - mesh_shape=mesh_shape) + net = jones_2009_model( + params, add_drives_from_params, legacy_mode, mesh_shape=mesh_shape + ) # Update biophysics (increase gabab duration of inhibition) net.cell_types['L2_pyramidal'].synapses['gabab']['tau1'] = 45.0 @@ -228,8 +243,7 @@ def law_2021_model(params=None, add_drives_from_params=False, # Remove L5 pyramidal somatic and basal dendrite calcium channels for sec in ['soma', 'basal_1', 'basal_2', 'basal_3']: - del net.cell_types['L5_pyramidal'].sections[ - sec].mechs['ca'] + del net.cell_types['L5_pyramidal'].sections[sec].mechs['ca'] # Remove L2_basket -> L5_pyramidal gabaa connection del net.connectivity[10] # Original paper simply sets gbar to 0.0 @@ -238,32 +252,31 @@ def law_2021_model(params=None, add_drives_from_params=False, delay = net.delay src_cell = 'L2_basket' target_cell = 'L5_pyramidal' - lamtha = 50. 
+ lamtha = 50.0 weight = 0.0002 loc = 'distal' receptor = 'gabab' - net.add_connection( - src_cell, target_cell, loc, receptor, weight, delay, lamtha) + net.add_connection(src_cell, target_cell, loc, receptor, weight, delay, lamtha) # Add L5_basket -> L5_pyramidal distal connection # ("Martinotti-like recurrent tuft connection") src_cell = 'L5_basket' target_cell = 'L5_pyramidal' - lamtha = 70. + lamtha = 70.0 loc = 'distal' receptor = 'gabaa' key = f'gbar_L5Basket_L5Pyr_{receptor}' weight = net._params[key] - net.add_connection( - src_cell, target_cell, loc, receptor, weight, delay, lamtha) + net.add_connection(src_cell, target_cell, loc, receptor, weight, delay, lamtha) return net # Remove params argument after updating examples # (only relevant for Jones 2009 model) -def calcium_model(params=None, add_drives_from_params=False, - legacy_mode=False, mesh_shape=(10, 10)): +def calcium_model( + params=None, add_drives_from_params=False, legacy_mode=False, mesh_shape=(10, 10) +): """Instantiate the Jones 2009 model with improved calcium dynamics in L5 pyramidal neurons. For more details on changes to calcium dynamics see Kohl et al. Brain Topogr 2022 [1]_ @@ -297,14 +310,14 @@ def calcium_model(params=None, add_drives_from_params=False, if params is None: params = read_params(params_fname) - net = jones_2009_model(params, add_drives_from_params, legacy_mode, - mesh_shape=mesh_shape) + net = jones_2009_model( + params, add_drives_from_params, legacy_mode, mesh_shape=mesh_shape + ) # Replace L5 pyramidal cell template with updated calcium cell_name = 'L5_pyramidal' pos = net.cell_types[cell_name].pos - net.cell_types[cell_name] = pyramidal_ca( - cell_name=_short_name(cell_name), pos=pos) + net.cell_types[cell_name] = pyramidal_ca(cell_name=_short_name(cell_name), pos=pos) return net @@ -330,30 +343,67 @@ def add_erp_drives_to_jones_model(net, tstart=0.0): _validate_type(tstart, (float, int), 'tstart', 'float or int') # Add distal drive - weights_ampa_d1 = {'L2_basket': 0.006562, 'L2_pyramidal': 7e-6, - 'L5_pyramidal': 0.142300} - weights_nmda_d1 = {'L2_basket': 0.019482, 'L2_pyramidal': 0.004317, - 'L5_pyramidal': 0.080074} - synaptic_delays_d1 = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, - 'L5_pyramidal': 0.1} + weights_ampa_d1 = { + 'L2_basket': 0.006562, + 'L2_pyramidal': 7e-6, + 'L5_pyramidal': 0.142300, + } + weights_nmda_d1 = { + 'L2_basket': 0.019482, + 'L2_pyramidal': 0.004317, + 'L5_pyramidal': 0.080074, + } + synaptic_delays_d1 = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, 'L5_pyramidal': 0.1} net.add_evoked_drive( - 'evdist1', mu=63.53 + tstart, sigma=3.85, numspikes=1, - weights_ampa=weights_ampa_d1, weights_nmda=weights_nmda_d1, - location='distal', synaptic_delays=synaptic_delays_d1, event_seed=274) + 'evdist1', + mu=63.53 + tstart, + sigma=3.85, + numspikes=1, + weights_ampa=weights_ampa_d1, + weights_nmda=weights_nmda_d1, + location='distal', + synaptic_delays=synaptic_delays_d1, + event_seed=274, + ) # Add proximal drives - weights_ampa_p1 = {'L2_basket': 0.08831, 'L2_pyramidal': 0.01525, - 'L5_basket': 0.19934, 'L5_pyramidal': 0.00865} - synaptic_delays_prox = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, - 'L5_basket': 1., 'L5_pyramidal': 1.} + weights_ampa_p1 = { + 'L2_basket': 0.08831, + 'L2_pyramidal': 0.01525, + 'L5_basket': 0.19934, + 'L5_pyramidal': 0.00865, + } + synaptic_delays_prox = { + 'L2_basket': 0.1, + 'L2_pyramidal': 0.1, + 'L5_basket': 1.0, + 'L5_pyramidal': 1.0, + } net.add_evoked_drive( - 'evprox1', mu=26.61 + tstart, sigma=2.47, numspikes=1, - 
weights_ampa=weights_ampa_p1, weights_nmda=None, location='proximal', - synaptic_delays=synaptic_delays_prox, event_seed=544) - - weights_ampa_p2 = {'L2_basket': 0.000003, 'L2_pyramidal': 1.438840, - 'L5_basket': 0.008958, 'L5_pyramidal': 0.684013} + 'evprox1', + mu=26.61 + tstart, + sigma=2.47, + numspikes=1, + weights_ampa=weights_ampa_p1, + weights_nmda=None, + location='proximal', + synaptic_delays=synaptic_delays_prox, + event_seed=544, + ) + + weights_ampa_p2 = { + 'L2_basket': 0.000003, + 'L2_pyramidal': 1.438840, + 'L5_basket': 0.008958, + 'L5_pyramidal': 0.684013, + } net.add_evoked_drive( - 'evprox2', mu=137.12 + tstart, sigma=8.33, numspikes=1, - weights_ampa=weights_ampa_p2, location='proximal', - synaptic_delays=synaptic_delays_prox, event_seed=814) + 'evprox2', + mu=137.12 + tstart, + sigma=8.33, + numspikes=1, + weights_ampa=weights_ampa_p2, + location='proximal', + synaptic_delays=synaptic_delays_prox, + event_seed=814, + ) diff --git a/hnn_core/optimization/general_optimization.py b/hnn_core/optimization/general_optimization.py index 4bfb7efcd..d195da861 100644 --- a/hnn_core/optimization/general_optimization.py +++ b/hnn_core/optimization/general_optimization.py @@ -11,8 +11,16 @@ class Optimizer: - def __init__(self, initial_net, tstop, constraints, set_params, - solver='bayesian', obj_fun='dipole_rmse', max_iter=200): + def __init__( + self, + initial_net, + tstop, + constraints, + set_params, + solver='bayesian', + obj_fun='dipole_rmse', + max_iter=200, + ): """Parameter optimization. Parameters @@ -63,9 +71,11 @@ def __init__(self, initial_net, tstop, constraints, set_params, """ if initial_net.external_drives: - raise ValueError("The current Network instance has external " + - "drives, provide a Network object with no " + - "external drives.") + raise ValueError( + 'The current Network instance has external ' + + 'drives, provide a Network object with no ' + + 'external drives.' + ) self._initial_net = initial_net self.constraints = constraints self._set_params = set_params @@ -102,7 +112,7 @@ def __repr__(self): is_fit = True name = self.__class__.__name__ - return f"<{name}\nsolver={self.solver}\nfit={is_fit}>" + return f'<{name}\nsolver={self.solver}\nfit={is_fit}>' def fit(self, **obj_fun_kwargs): """Runs optimization routine. @@ -120,25 +130,27 @@ def fit(self, **obj_fun_kwargs): smooth_window_len : float, optional The smooth window length. 
""" - if (self.obj_fun_name == 'dipole_rmse' and - 'target' not in obj_fun_kwargs): + if self.obj_fun_name == 'dipole_rmse' and 'target' not in obj_fun_kwargs: raise Exception('target must be specified') - elif (self.obj_fun_name == 'maximize_psd' and - ('f_bands' not in obj_fun_kwargs or - 'relative_bandpower' not in obj_fun_kwargs)): + elif self.obj_fun_name == 'maximize_psd' and ( + 'f_bands' not in obj_fun_kwargs + or 'relative_bandpower' not in obj_fun_kwargs + ): raise Exception('f_bands and relative_bandpower must be specified') constraints = self._assemble_constraints(self.constraints) initial_params = _get_initial_params(self.constraints) - opt_params, obj, net_ = self._run_opt(self._initial_net, - self.tstop, - constraints, - self._set_params, - self.obj_fun, - initial_params, - self.max_iter, - obj_fun_kwargs) + opt_params, obj, net_ = self._run_opt( + self._initial_net, + self.tstop, + constraints, + self._set_params, + self.obj_fun, + initial_params, + self.max_iter, + obj_fun_kwargs, + ) self.net_ = net_ self.obj_ = obj @@ -199,8 +211,9 @@ def _get_initial_params(constraints): initial_params = dict() for cons_key in constraints: - initial_params.update({cons_key: ((constraints[cons_key][0] + - constraints[cons_key][1])) / 2}) + initial_params.update( + {cons_key: (constraints[cons_key][0] + constraints[cons_key][1]) / 2} + ) return initial_params @@ -241,10 +254,8 @@ def _assemble_constraints_cobyla(constraints): # assemble constraints in solver-specific format cons_cobyla = list() for idx, cons_key in enumerate(constraints): - cons_cobyla.append(lambda x, idx=idx: - float(constraints[cons_key][1]) - x[idx]) - cons_cobyla.append(lambda x, idx=idx: - x[idx] - float(constraints[cons_key][0])) + cons_cobyla.append(lambda x, idx=idx: float(constraints[cons_key][1]) - x[idx]) + cons_cobyla.append(lambda x, idx=idx: x[idx] - float(constraints[cons_key][0])) return cons_cobyla @@ -272,8 +283,16 @@ def _update_params(initial_params, predicted_params): return params -def _run_opt_bayesian(initial_net, tstop, constraints, set_params, obj_fun, - initial_params, max_iter, obj_fun_kwargs): +def _run_opt_bayesian( + initial_net, + tstop, + constraints, + set_params, + obj_fun, + initial_params, + max_iter, + obj_fun_kwargs, +): """Runs optimization routine with gp_minimize optimizer. 
Parameters @@ -308,20 +327,24 @@ def _run_opt_bayesian(initial_net, tstop, constraints, set_params, obj_fun, obj_values = list() def _obj_func(predicted_params): - return obj_fun(initial_net=initial_net, - initial_params=initial_params, - set_params=set_params, - predicted_params=predicted_params, - update_params=_update_params, - obj_values=obj_values, - tstop=tstop, - obj_fun_kwargs=obj_fun_kwargs) - - opt_results, _ = bayes_opt(func=_obj_func, - x0=list(initial_params.values()), - cons=constraints, - acquisition=expected_improvement, - maxfun=max_iter) + return obj_fun( + initial_net=initial_net, + initial_params=initial_params, + set_params=set_params, + predicted_params=predicted_params, + update_params=_update_params, + obj_values=obj_values, + tstop=tstop, + obj_fun_kwargs=obj_fun_kwargs, + ) + + opt_results, _ = bayes_opt( + func=_obj_func, + x0=list(initial_params.values()), + cons=constraints, + acquisition=expected_improvement, + maxfun=max_iter, + ) # get optimized params opt_params = opt_results @@ -337,8 +360,16 @@ def _obj_func(predicted_params): return opt_params, obj, net_ -def _run_opt_cobyla(initial_net, tstop, constraints, set_params, obj_fun, - initial_params, max_iter, obj_fun_kwargs): +def _run_opt_cobyla( + initial_net, + tstop, + constraints, + set_params, + obj_fun, + initial_params, + max_iter, + obj_fun_kwargs, +): """Runs optimization routine with fmin_cobyla optimizer. Parameters @@ -373,22 +404,26 @@ def _run_opt_cobyla(initial_net, tstop, constraints, set_params, obj_fun, obj_values = list() def _obj_func(predicted_params): - return obj_fun(initial_net=initial_net, - initial_params=initial_params, - set_params=set_params, - predicted_params=predicted_params, - update_params=_update_params, - obj_values=obj_values, - tstop=tstop, - obj_fun_kwargs=obj_fun_kwargs) - - opt_results = fmin_cobyla(_obj_func, - cons=constraints, - rhobeg=0.1, - rhoend=1e-4, - x0=list(initial_params.values()), - maxfun=max_iter, - catol=0.0) + return obj_fun( + initial_net=initial_net, + initial_params=initial_params, + set_params=set_params, + predicted_params=predicted_params, + update_params=_update_params, + obj_values=obj_values, + tstop=tstop, + obj_fun_kwargs=obj_fun_kwargs, + ) + + opt_results = fmin_cobyla( + _obj_func, + cons=constraints, + rhobeg=0.1, + rhoend=1e-4, + x0=list(initial_params.values()), + maxfun=max_iter, + catol=0.0, + ) # get optimized params opt_params = opt_results diff --git a/hnn_core/optimization/objective_functions.py b/hnn_core/optimization/objective_functions.py index 0d74a5c78..e97548831 100644 --- a/hnn_core/optimization/objective_functions.py +++ b/hnn_core/optimization/objective_functions.py @@ -9,8 +9,16 @@ from ..dipole import _rmse -def _rmse_evoked(initial_net, initial_params, set_params, predicted_params, - update_params, obj_values, tstop, obj_fun_kwargs): +def _rmse_evoked( + initial_net, + initial_params, + set_params, + predicted_params, + update_params, + obj_values, + tstop, + obj_fun_kwargs, +): """The objective function for evoked responses. Parameters @@ -56,8 +64,16 @@ def _rmse_evoked(initial_net, initial_params, set_params, predicted_params, return obj -def _maximize_psd(initial_net, initial_params, set_params, predicted_params, - update_params, obj_values, tstop, obj_fun_kwargs): +def _maximize_psd( + initial_net, + initial_params, + set_params, + predicted_params, + update_params, + obj_values, + tstop, + obj_fun_kwargs, +): """The objective function for PSDs. 
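    The score being minimized is the negative weighted relative band power
    of the simulated dipole,

        obj = -sum_k w_k * sum_{f in band_k} PSD(f) / sum_f PSD(f),

    where ``w_k`` is ``relative_bandpower[k]`` and the PSD is a
    Hamming-windowed periodogram of the aggregate dipole moment, so
    minimizing ``obj`` maximizes power in the requested bands.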
Parameters @@ -113,16 +129,19 @@ def _maximize_psd(initial_net, initial_params, set_params, predicted_params, # resample? # get psd of simulated dpl - freqs_simulated, psd_simulated = periodogram(dpl.data['agg'], dpl.sfreq, - window='hamming') + freqs_simulated, psd_simulated = periodogram( + dpl.data['agg'], dpl.sfreq, window='hamming' + ) # for each f band f_bands_psds = list() for idx, f_band in enumerate(obj_fun_kwargs['f_bands']): - f_band_idx = np.where(np.logical_and(freqs_simulated >= f_band[0], - freqs_simulated <= f_band[1]))[0] - f_bands_psds.append(-obj_fun_kwargs['relative_bandpower'][idx] * - sum(psd_simulated[f_band_idx])) + f_band_idx = np.where( + np.logical_and(freqs_simulated >= f_band[0], freqs_simulated <= f_band[1]) + )[0] + f_bands_psds.append( + -obj_fun_kwargs['relative_bandpower'][idx] * sum(psd_simulated[f_band_idx]) + ) # grand sum obj = sum(f_bands_psds) / sum(psd_simulated) diff --git a/hnn_core/optimization/optimize_evoked.py b/hnn_core/optimization/optimize_evoked.py index 381a4c1f0..1efc87a5b 100644 --- a/hnn_core/optimization/optimize_evoked.py +++ b/hnn_core/optimization/optimize_evoked.py @@ -18,16 +18,22 @@ def _get_range(val, multiplier): """Get range of values to sweep over.""" - range_min = max(0, val - val * multiplier / 100.) - range_max = val + val * multiplier / 100. + range_min = max(0, val - val * multiplier / 100.0) + range_max = val + val * multiplier / 100.0 ranges = {'initial': val, 'minval': range_min, 'maxval': range_max} return ranges -def _split_by_evinput(drive_names, drive_dynamics, drive_syn_weights, tstop, - sigma_range_multiplier, timing_range_multiplier, - synweight_range_multiplier): - """ Sorts parameter ranges by evoked inputs into a dictionary +def _split_by_evinput( + drive_names, + drive_dynamics, + drive_syn_weights, + tstop, + sigma_range_multiplier, + timing_range_multiplier, + synweight_range_multiplier, +): + """Sorts parameter ranges by evoked inputs into a dictionary Parameters ---------- @@ -77,12 +83,15 @@ def _split_by_evinput(drive_names, drive_dynamics, drive_syn_weights, tstop, # sigma of 0 will not produce a CDF timing_sigma = 0.01 - evinput_params[drive_name] = {'mean': timing_mean, - 'sigma': timing_sigma, - 'ranges': {}} + evinput_params[drive_name] = { + 'mean': timing_mean, + 'sigma': timing_sigma, + 'ranges': {}, + } - evinput_params[drive_name]['ranges'][f'{drive_name}_sigma'] = \ - _get_range(timing_sigma, sigma_range_multiplier) + evinput_params[drive_name]['ranges'][f'{drive_name}_sigma'] = _get_range( + timing_sigma, sigma_range_multiplier + ) # calculate range for time timing_bound = timing_sigma * timing_range_multiplier @@ -91,8 +100,11 @@ def _split_by_evinput(drive_names, drive_dynamics, drive_syn_weights, tstop, evinput_params[drive_name]['start'] = range_min evinput_params[drive_name]['end'] = range_max - evinput_params[drive_name]['ranges'][f'{drive_name}_mu'] = \ - {'initial': timing_mean, 'minval': range_min, 'maxval': range_max} + evinput_params[drive_name]['ranges'][f'{drive_name}_mu'] = { + 'initial': timing_mean, + 'minval': range_min, + 'maxval': range_max, + } # calculate ranges for syn. 
weights for syn_weight_key in drive_syn_weights[drive_idx]: @@ -104,8 +116,9 @@ ranges['maxval'] = 1.0 evinput_params[drive_name]['ranges'][new_key] = ranges - sorted_evinput_params = OrderedDict(sorted(evinput_params.items(), - key=lambda x: x[1]['start'])) + sorted_evinput_params = OrderedDict( + sorted(evinput_params.items(), key=lambda x: x[1]['start']) + ) return sorted_evinput_params @@ -125,7 +138,8 @@ def _generate_weights(evinput_params, tstop, dt, decay_multiplier): for evinput_this in evinput_params.values(): # calculate cdf using start time (minval of optimization range) evinput_this['cdf'] = stats.norm.cdf( - times, evinput_this['start'], evinput_this['sigma']) + times, evinput_this['start'], evinput_this['sigma'] + ) for input_name, evinput_this in evinput_params.items(): evinput_this['weights'] = evinput_this['cdf'].copy() @@ -133,29 +147,32 @@ for other_input, evinput_other in evinput_params.items(): # check ordering to only use inputs after us # and don't subtract our own cdf(s) - if (evinput_other['mean'] < evinput_this['mean'] or - input_name == other_input): + if ( + evinput_other['mean'] < evinput_this['mean'] + or input_name == other_input + ): continue - decay_factor = decay_multiplier * \ - (evinput_other['mean'] - evinput_this['mean']) / tstop + decay_factor = ( + decay_multiplier + * (evinput_other['mean'] - evinput_this['mean']) + / tstop + ) evinput_this['weights'] -= evinput_other['cdf'] * decay_factor # weights should not drop below 0 - np.clip(evinput_this['weights'], a_min=0, a_max=None, - out=evinput_this['weights']) + np.clip( + evinput_this['weights'], a_min=0, a_max=None, out=evinput_this['weights'] + ) # start and stop optimization where the weights are insignificant indices = np.where(evinput_this['weights'] > 0.01) - evinput_this['opt_start'] = min(evinput_this['start'], - times[indices][0]) - evinput_this['opt_end'] = max(evinput_this['end'], - times[indices][-1]) + evinput_this['opt_start'] = min(evinput_this['start'], times[indices][0]) + evinput_this['opt_end'] = max(evinput_this['end'], times[indices][-1]) # convert to multiples of dt evinput_this['opt_start'] = floor(evinput_this['opt_start'] / dt) * dt - evinput_params[input_name]['opt_end'] = ceil( - evinput_this['opt_end'] / dt) * dt + evinput_params[input_name]['opt_end'] = ceil(evinput_this['opt_end'] / dt) * dt for evinput_this in evinput_params.values(): del evinput_this['mean'], evinput_this['sigma'], evinput_this['cdf'] @@ -164,7 +181,7 @@ def _generate_weights(evinput_params, tstop, dt, decay_multiplier): def _create_last_chunk(input_chunks): - """ This creates a chunk that combines parameters for + """This creates a chunk that combines parameters for all chunks in input_chunks (final step) Parameters ---------- @@ -178,8 +195,7 @@ Dictionary with parameters for combined chunk (final step) """ - chunk = {'inputs': [], 'ranges': {}, 'opt_start': 0.0, - 'opt_end': 0.0} + chunk = {'inputs': [], 'ranges': {}, 'opt_start': 0.0, 'opt_end': 0.0} for evinput in input_chunks: chunk['inputs'].extend(evinput['inputs']) @@ -214,17 +230,14 @@ def _consolidate_chunks(inputs): input_dict = inputs[input_name].copy() input_dict['inputs'] = [input_name] - if (len(chunks) > 0 and - input_dict['start'] <= chunks[-1]['end']): + if len(chunks) > 0 and input_dict['start'] <= chunks[-1]['end']: # update previous chunk
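            # (the merge pools the two inputs' names and parameter ranges,
            # extends the chunk to the incoming input's 'end' and the larger
            # 'opt_end', and averages the per-timepoint weight vectors)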
chunks[-1]['inputs'].extend(input_dict['inputs']) chunks[-1]['end'] = input_dict['end'] chunks[-1]['ranges'].update(input_dict['ranges']) - chunks[-1]['opt_end'] = max(chunks[-1]['opt_end'], - input_dict['opt_end']) + chunks[-1]['opt_end'] = max(chunks[-1]['opt_end'], input_dict['opt_end']) # average the weights - chunks[-1]['weights'] = (chunks[-1]['weights'] + - input_dict['weights']) / 2 + chunks[-1]['weights'] = (chunks[-1]['weights'] + input_dict['weights']) / 2 else: # new chunk chunks.append(input_dict) @@ -237,9 +250,19 @@ def _consolidate_chunks(inputs): return chunks -def _optrun(drive_params_updated, drive_params_static, net, tstop, dt, - n_trials, opt_params, opt_dpls, scale_factor, smooth_window_len, - return_rmse): +def _optrun( + drive_params_updated, + drive_params_static, + net, + tstop, + dt, + n_trials, + opt_params, + opt_dpls, + scale_factor, + smooth_window_len, + return_rmse, +): """This is the function to run a simulation Parameters @@ -280,13 +303,16 @@ def _optrun(drive_params_updated, drive_params_static, net, tstop, dt, avg_rmse: float Weighted RMSE between data in dpl and exp_dpl """ - print("Optimization step %d, iteration %d" % (opt_params['cur_step'] + 1, - opt_params['optiter'] + 1)) + print( + 'Optimization step %d, iteration %d' + % (opt_params['cur_step'] + 1, opt_params['optiter'] + 1) + ) # match parameter values contained in list to their respective key names params_dict = dict() - for param_name, test_value in zip(opt_params['ranges'].keys(), - drive_params_updated): + for param_name, test_value in zip( + opt_params['ranges'].keys(), drive_params_updated + ): # tiny negative weights are possible. Clip them to 0. if test_value < 0: test_value = 0 @@ -294,24 +320,27 @@ def _optrun(drive_params_updated, drive_params_static, net, tstop, dt, # modify drives according to the drive names in the current chunk for drive_name in opt_params['inputs']: - # clear drive and its connectivity del net.external_drives[drive_name] conn_idxs = pick_connection(net, src_gids=drive_name) - net.connectivity = [conn for conn_idx, conn - in enumerate(net.connectivity) - if conn_idx not in conn_idxs] + net.connectivity = [ + conn + for conn_idx, conn in enumerate(net.connectivity) + if conn_idx not in conn_idxs + ] # extract syn weights: final weights dicts should have keys that # correspond to cell types - keys_ampa = fnmatch.filter(params_dict.keys(), - f'{drive_name}_gbar_ampa_*') - keys_nmda = fnmatch.filter(params_dict.keys(), - f'{drive_name}_gbar_nmda_*') - weights_ampa = {key.lstrip(f'{drive_name}_gbar_ampa_'): - params_dict[key] for key in keys_ampa} - weights_nmda = {key.lstrip(f'{drive_name}_gbar_nmda_'): - params_dict[key] for key in keys_nmda} + keys_ampa = fnmatch.filter(params_dict.keys(), f'{drive_name}_gbar_ampa_*') + keys_nmda = fnmatch.filter(params_dict.keys(), f'{drive_name}_gbar_nmda_*') + weights_ampa = { + key.lstrip(f'{drive_name}_gbar_ampa_'): params_dict[key] + for key in keys_ampa + } + weights_nmda = { + key.lstrip(f'{drive_name}_gbar_nmda_'): params_dict[key] + for key in keys_nmda + } net.add_evoked_drive( name=drive_name, @@ -327,7 +356,7 @@ def _optrun(drive_params_updated, drive_params_static, net, tstop, dt, synaptic_delays=drive_params_static[drive_name]['synaptic_delays'], probability=drive_params_static[drive_name]['probability'], event_seed=drive_params_static[drive_name]['event_seed'], - conn_seed=drive_params_static[drive_name]['conn_seed'] + conn_seed=drive_params_static[drive_name]['conn_seed'], ) # run the simulation @@ -338,21 
+367,30 @@ def _optrun(drive_params_updated, drive_params_static, net, tstop, dt, dpls = [dpl.smooth(smooth_window_len) for dpl in dpls] avg_dpl = average_dipoles(dpls) - avg_rmse = _rmse(avg_dpl, opt_dpls['target_dpl'], - tstart=opt_params['opt_start'], - tstop=opt_params['opt_end'], - weights=opt_params['weights']) - avg_rmse_unweighted = _rmse(avg_dpl, opt_dpls['target_dpl'], - tstart=opt_params['opt_start'], - tstop=tstop, weights=None) + avg_rmse = _rmse( + avg_dpl, + opt_dpls['target_dpl'], + tstart=opt_params['opt_start'], + tstop=opt_params['opt_end'], + weights=opt_params['weights'], + ) + avg_rmse_unweighted = _rmse( + avg_dpl, + opt_dpls['target_dpl'], + tstart=opt_params['opt_start'], + tstop=tstop, + weights=None, + ) if return_rmse: opt_params['iter_avg_rmse'].append(avg_rmse_unweighted) opt_params['stepminopterr'] = avg_rmse opt_dpls['best_dpl'] = avg_dpl - print("weighted RMSE: %.2e over range [%3.3f-%3.3f] ms" % - (avg_rmse, opt_params['opt_start'], opt_params['opt_end'])) + print( + 'weighted RMSE: %.2e over range [%3.3f-%3.3f] ms' + % (avg_rmse, opt_params['opt_start'], opt_params['opt_end']) + ) opt_params['optiter'] += 1 @@ -360,17 +398,21 @@ def _optrun(drive_params_updated, drive_params_static, net, tstop, dt, def _run_optimization(maxiter, param_ranges, optrun): - cons = list() x0 = list() for idx, param_name in enumerate(param_ranges): x0.append(param_ranges[param_name]['initial']) - cons.append( - lambda x, idx=idx: param_ranges[param_name]['maxval'] - x[idx]) - cons.append( - lambda x, idx=idx: x[idx] - param_ranges[param_name]['minval']) - result = fmin_cobyla(func=optrun, cons=cons, rhobeg=0.1, rhoend=1e-4, - x0=x0, maxfun=maxiter, catol=0.0) + cons.append(lambda x, idx=idx: param_ranges[param_name]['maxval'] - x[idx]) + cons.append(lambda x, idx=idx: x[idx] - param_ranges[param_name]['minval']) + result = fmin_cobyla( + func=optrun, + cons=cons, + rhobeg=0.1, + rhoend=1e-4, + x0=x0, + maxfun=maxiter, + catol=0.0, + ) return result @@ -397,11 +439,10 @@ def _get_drive_params(net, drive_names): # legacy_mode hack: don't include invalid connections that have # been added in Network when legacy_mode=True - if not (drive['location'] == 'distal' and - target_type == 'L5_basket'): - if target_receptor == "ampa": + if not (drive['location'] == 'distal' and target_type == 'L5_basket'): + if target_receptor == 'ampa': weights.update({f'ampa_{target_type}': weight}) - if target_receptor == "nmda": + if target_receptor == 'nmda': weights.update({f'nmda_{target_type}': weight}) # delay should be constant across AMPA and NMDA receptor types delay = net.connectivity[conn_idx]['nc_dict']['A_delay'] @@ -433,19 +474,31 @@ def _get_drive_params(net, drive_names): return drive_dynamics, drive_syn_weights, drive_static_params -def optimize_evoked(net, tstop, n_trials, target_dpl, initial_dpl, maxiter=50, - timing_range_multiplier=3.0, sigma_range_multiplier=50.0, - synweight_range_multiplier=500.0, decay_multiplier=1.6, - scale_factor=1., smooth_window_len=None, dt=0.025, - which_drives='all', return_rmse=False): +def optimize_evoked( + net, + tstop, + n_trials, + target_dpl, + initial_dpl, + maxiter=50, + timing_range_multiplier=3.0, + sigma_range_multiplier=50.0, + synweight_range_multiplier=500.0, + decay_multiplier=1.6, + scale_factor=1.0, + smooth_window_len=None, + dt=0.025, + which_drives='all', + return_rmse=False, +): """Optimize drives to generate evoked response. 
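    The optimization proceeds in steps: drives are sorted by mean activation
    time and consolidated into chunks of temporally overlapping inputs; each
    chunk is then fit with ``fmin_cobyla`` against an RMSE weighted toward
    that chunk's time window, and a final step over all inputs uses the
    plain RMSE to clean up overfitting introduced by the local weighted
    steps.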
Parameters ---------- net : Network instance - An instance of the Network object with attached evoked drives. Timing - and synaptic weight parameters will be optimized for each attached - evoked drive. Note that no new drives will be created or old drives + An instance of the Network object with attached evoked drives. Timing + and synaptic weight parameters will be optimized for each attached + evoked drive. Note that no new drives will be created or old drives destroyed. tstop : float The simulation stop time (ms). @@ -478,7 +531,7 @@ def optimize_evoked(net, tstop, n_trials, target_dpl, initial_dpl, maxiter=50, Evoked drives to optimize. If 'all', will optimize all evoked drives. If a subset list of evoked drives, will optimize only the evoked drives in the list. return_rmse : bool - Returns list of unweighted RMSEs between the simulated and experimental dipole + Returns list of unweighted RMSEs between the simulated and experimental dipole waveforms for each optimization step Returns @@ -499,24 +552,35 @@ def optimize_evoked(net, tstop, n_trials, target_dpl, initial_dpl, maxiter=50, net = net.copy() - evoked_drive_names = [key for key in net.external_drives.keys() - if net.external_drives[key]['type'] == 'evoked'] + evoked_drive_names = [ + key + for key in net.external_drives.keys() + if net.external_drives[key]['type'] == 'evoked' + ] if len(evoked_drive_names) == 0: - raise ValueError('The current Network instance lacks any evoked ' - 'drives. Consider adding drives using ' - 'net.add_evoked_drive') + raise ValueError( + 'The current Network instance lacks any evoked ' + 'drives. Consider adding drives using ' + 'net.add_evoked_drive' + ) elif which_drives == 'all': drive_names = evoked_drive_names else: - drive_names = [mydrive for mydrive in np.unique(which_drives) - if mydrive in evoked_drive_names] + drive_names = [ + mydrive + for mydrive in np.unique(which_drives) + if mydrive in evoked_drive_names + ] if len(drive_names) == 0: - raise ValueError('The drives selected to be optimized are not evoked ' - 'drives. Optimization works only evoked drives.') + raise ValueError( + 'The drives selected to be optimized are not evoked ' + 'drives. Optimization works only on evoked drives.' + ) - drive_dynamics, drive_syn_weights, drive_static_params = \ - _get_drive_params(net, drive_names) + drive_dynamics, drive_syn_weights, drive_static_params = _get_drive_params( + net, drive_names + ) # Create a sorted dictionary with the inputs and parameters # belonging to each. @@ -527,20 +591,21 @@ def optimize_evoked(net, tstop, n_trials, target_dpl, initial_dpl, maxiter=50, # the simulation timeframe to optimize. Chunks are consolidated if # more than one input should # be optimized at a time.
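    # (e.g. two proximal drives whose optimization windows overlap are
    # collapsed into a single chunk and have their parameters fit together)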
- evinput_params = _split_by_evinput(drive_names, - drive_dynamics, - drive_syn_weights, - tstop, - sigma_range_multiplier, - timing_range_multiplier, - synweight_range_multiplier) - evinput_params = _generate_weights(evinput_params, tstop, dt, - decay_multiplier) + evinput_params = _split_by_evinput( + drive_names, + drive_dynamics, + drive_syn_weights, + tstop, + sigma_range_multiplier, + timing_range_multiplier, + synweight_range_multiplier, + ) + evinput_params = _generate_weights(evinput_params, tstop, dt, decay_multiplier) param_chunks = _consolidate_chunks(evinput_params) best_rmse = _rmse(initial_dpl, target_dpl, tstop=tstop) opt_dpls = dict(best_dpl=initial_dpl, target_dpl=target_dpl) - print("Initial RMSE: %.2e" % best_rmse) + print('Initial RMSE: %.2e' % best_rmse) opt_params = dict() @@ -556,11 +621,10 @@ def optimize_evoked(net, tstop, n_trials, target_dpl, initial_dpl, maxiter=50, opt_params.update(param_chunks[step]) if maxiter == 0: - print("Skipping optimization step %d (0 simulations)" % (step + 1)) + print('Skipping optimization step %d (0 simulations)' % (step + 1)) continue - if (opt_params['cur_step'] > 0 and - opt_params['cur_step'] == total_steps - 1): + if opt_params['cur_step'] > 0 and opt_params['cur_step'] == total_steps - 1: # For the last step (all inputs), recalculate ranges and update # param_chunks. If previous optimization steps increased # std. dev. this could result in fewer optimization steps as @@ -570,58 +634,66 @@ def optimize_evoked(net, tstop, n_trials, target_dpl, initial_dpl, maxiter=50, # The purpose of the last step (with regular RMSE) is to clean up # overfitting introduced by local weighted RMSE optimization. - evinput_params = _split_by_evinput(drive_names, - drive_dynamics, - drive_syn_weights, - tstop, - sigma_range_multiplier, - timing_range_multiplier, - synweight_range_multiplier) - evinput_params = _generate_weights(evinput_params, tstop, dt, - decay_multiplier) + evinput_params = _split_by_evinput( + drive_names, + drive_dynamics, + drive_syn_weights, + tstop, + sigma_range_multiplier, + timing_range_multiplier, + synweight_range_multiplier, + ) + evinput_params = _generate_weights( + evinput_params, tstop, dt, decay_multiplier + ) param_chunks = _consolidate_chunks(evinput_params) # reload opt_params for the last step in case the number of # steps was changed by updateoptparams() opt_params.update(param_chunks[total_steps - 1]) - print("Starting optimization step %d/%d" % (step + 1, total_steps)) + print('Starting optimization step %d/%d' % (step + 1, total_steps)) opt_params['optiter'] = 0 - opt_params['stepminopterr'] = _rmse(opt_dpls['best_dpl'], - opt_dpls['target_dpl'], - tstart=opt_params['opt_start'], - tstop=opt_params['opt_end'], - weights=opt_params['weights']) + opt_params['stepminopterr'] = _rmse( + opt_dpls['best_dpl'], + opt_dpls['target_dpl'], + tstart=opt_params['opt_start'], + tstop=opt_params['opt_end'], + weights=opt_params['weights'], + ) net_opt = net.copy() # drive_params_updated must be a list for compatibility with the args # in the optimization engine, scipy.optimize.fmin_cobyla - _myoptrun = partial(_optrun, - drive_params_static=drive_static_params, - net=net_opt, - tstop=tstop, - dt=dt, - n_trials=n_trials, - opt_params=opt_params, - opt_dpls=opt_dpls, - scale_factor=scale_factor, - smooth_window_len=smooth_window_len, - return_rmse=return_rmse) - - print('Optimizing from [%3.3f-%3.3f] ms' % (opt_params['opt_start'], - opt_params['opt_end'])) - opt_results = _run_optimization(maxiter=maxiter, - 
param_ranges=opt_params['ranges'], - optrun=_myoptrun) + _myoptrun = partial( + _optrun, + drive_params_static=drive_static_params, + net=net_opt, + tstop=tstop, + dt=dt, + n_trials=n_trials, + opt_params=opt_params, + opt_dpls=opt_dpls, + scale_factor=scale_factor, + smooth_window_len=smooth_window_len, + return_rmse=return_rmse, + ) + + print( + 'Optimizing from [%3.3f-%3.3f] ms' + % (opt_params['opt_start'], opt_params['opt_end']) + ) + opt_results = _run_optimization( + maxiter=maxiter, param_ranges=opt_params['ranges'], optrun=_myoptrun + ) # tiny negative weights are possible. Clip them to 0. opt_results[opt_results < 0] = 0 # update opt_params for the next round if total rmse decreased - avg_rmse = _rmse(opt_dpls['best_dpl'], - opt_dpls['target_dpl'], - tstop=tstop, - weights=None) + avg_rmse = _rmse( + opt_dpls['best_dpl'], opt_dpls['target_dpl'], tstop=tstop, weights=None + ) if avg_rmse <= best_rmse: best_rmse = avg_rmse for var_name, value in zip(opt_params['ranges'], opt_results): @@ -629,7 +701,7 @@ def optimize_evoked(net, tstop, n_trials, target_dpl, initial_dpl, maxiter=50, net = net_opt - print("Final RMSE: %.2e" % best_rmse) + print('Final RMSE: %.2e' % best_rmse) if return_rmse is True: return net, opt_params['iter_avg_rmse'] diff --git a/hnn_core/parallel_backends.py b/hnn_core/parallel_backends.py index 9757414d0..7261cf054 100644 --- a/hnn_core/parallel_backends.py +++ b/hnn_core/parallel_backends.py @@ -41,12 +41,12 @@ def _gather_trial_data(sim_data, net, n_trials, postproc): # Create array of equally sampled time points for simulating currents cell_type_names = list(net.cell_types.keys()) - cell_response = CellResponse(times=sim_data[0]['times'], - cell_type_names=cell_type_names) + cell_response = CellResponse( + times=sim_data[0]['times'], cell_type_names=cell_type_names + ) net.cell_response = cell_response for idx in range(n_trials): - # cell response net.cell_response._spike_times.append(sim_data[idx]['spike_times']) net.cell_response._spike_gids.append(sim_data[idx]['spike_gids']) @@ -62,8 +62,7 @@ def _gather_trial_data(sim_data, net, n_trials, postproc): arr._times = sim_data[idx]['rec_times'][arr_name] # dipole - dpl = Dipole(times=sim_data[idx]['times'], - data=sim_data[idx]['dpl_data']) + dpl = Dipole(times=sim_data[idx]['times'], data=sim_data[idx]['dpl_data']) N_pyr_x = net._N_pyr_x N_pyr_y = net._N_pyr_y @@ -86,11 +85,11 @@ def _get_mpi_env(): my_env = os.environ.copy() # For Linux systems if sys.platform != 'win32': - my_env["OMPI_MCA_btl_base_warn_component_unused"] = '0' + my_env['OMPI_MCA_btl_base_warn_component_unused'] = '0' if 'darwin' in sys.platform: - my_env["PMIX_MCA_gds"] = "^ds12" # open-mpi/ompi/issues/7516 - my_env["TMPDIR"] = "/tmp" # open-mpi/ompi/issues/2956 + my_env['PMIX_MCA_gds'] = '^ds12' # open-mpi/ompi/issues/7516 + my_env['TMPDIR'] = '/tmp' # open-mpi/ompi/issues/2956 return my_env @@ -127,8 +126,7 @@ def run_subprocess(command, obj, timeout, proc_queue=None, *args, **kwargs): threads_started = False try: - proc = Popen(command, stdin=PIPE, stdout=PIPE, stderr=PIPE, *args, - **kwargs) + proc = Popen(command, stdin=PIPE, stdout=PIPE, stderr=PIPE, *args, **kwargs) # now that the process has started, add it to the queue # used by MPIBackend.terminate() @@ -138,10 +136,8 @@ def run_subprocess(command, obj, timeout, proc_queue=None, *args, **kwargs): # set up polling first so all of child's stdout/stderr # gets captured event = Event() - out_t = Thread(target=_thread_handler, - args=(event, proc.stdout, out_q)) - err_t = 
Thread(target=_thread_handler, - args=(event, proc.stderr, err_q)) + out_t = Thread(target=_thread_handler, args=(event, proc.stdout, out_q)) + err_t = Thread(target=_thread_handler, args=(event, proc.stderr, err_q)) out_t.start() err_t.start() threads_started = True @@ -166,7 +162,7 @@ def run_subprocess(command, obj, timeout, proc_queue=None, *args, **kwargs): elif child_terminated: # child terminated early, and we already # captured output left in queues - warn("Child process failed unexpectedly") + warn('Child process failed unexpectedly') kill_proc_name('nrniv') break @@ -178,8 +174,10 @@ def run_subprocess(command, obj, timeout, proc_queue=None, *args, **kwargs): # child failed during _write_net(). get the # output and break out of loop on the next # iteration - warn("Received BrokenPipeError exception. " - "Child process failed unexpectedly") + warn( + 'Received BrokenPipeError exception. ' + 'Child process failed unexpectedly' + ) continue else: sent_network = True @@ -192,14 +190,15 @@ def run_subprocess(command, obj, timeout, proc_queue=None, *args, **kwargs): # the network has been sent) break - if not child_terminated and \ - count_since_last_output > timeout_cycles: - warn("Timeout exceeded while waiting for child process output" - ". Terminating...") + if not child_terminated and count_since_last_output > timeout_cycles: + warn( + 'Timeout exceeded while waiting for child process output' + '. Terminating...' + ) kill_proc_name('nrniv') break except KeyboardInterrupt: - warn("Received KeyboardInterrupt. Stopping simulation process...") + warn('Received KeyboardInterrupt. Stopping simulation process...') if threads_started: # stop the threads @@ -228,12 +227,11 @@ def run_subprocess(command, obj, timeout, proc_queue=None, *args, **kwargs): try: proc.wait(1) # wait maximum of 1s except TimeoutExpired: - warn("Could not kill python subprocess: PID %d" % proc.pid) + warn('Could not kill python subprocess: PID %d' % proc.pid) if not proc.returncode == 0: # simulation failed with a numeric return code - raise RuntimeError("MPI simulation failed. Return code: %d" % - proc.returncode) + raise RuntimeError('MPI simulation failed. Return code: %d' % proc.returncode) child_data = _process_child_data(proc_data_bytes, data_len) @@ -261,8 +259,10 @@ def _process_child_data(data_bytes, data_len): """ if not data_len == len(data_bytes): # This is indicative of a failure. For debugging purposes. - warn("Length of received data unexpected. Expecting %d bytes, " - "got %d" % (data_len, len(data_bytes))) + warn( + 'Length of received data unexpected. Expecting %d bytes, ' + 'got %d' % (data_len, len(data_bytes)) + ) if len(data_bytes) == 0: raise RuntimeError("MPI simulation didn't return any data") @@ -274,9 +274,10 @@ def _process_child_data(data_bytes, data_len): # This is here for future debugging purposes. 
Unit tests can't # reproduce an incorrectly padded string, but this has been an # issue before - raise ValueError("Incorrect padding for data length %d bytes" % - len(data_len) + " (mod 4 = %d)" % - (len(data_len) % 4)) + raise ValueError( + 'Incorrect padding for data length %d bytes' % len(data_len) + + ' (mod 4 = %d)' % (len(data_len) % 4) + ) # unpickle the data return pickle.loads(data_pickled) @@ -351,10 +352,11 @@ def requires_mpi4py(function): try: import mpi4py + assert hasattr(mpi4py, '__version__') skip = False except (ImportError, ModuleNotFoundError) as err: - if "TRAVIS_OS_NAME" not in os.environ: + if 'TRAVIS_OS_NAME' not in os.environ: skip = True else: raise ImportError(err) @@ -368,10 +370,11 @@ def requires_psutil(function): try: import psutil + assert hasattr(psutil, '__version__') skip = False except (ImportError, ModuleNotFoundError) as err: - if "TRAVIS_OS_NAME" not in os.environ: + if 'TRAVIS_OS_NAME' not in os.environ: skip = True else: raise ImportError(err) @@ -380,8 +383,7 @@ def _extract_data_length(data_str, object_name): - data_len_match = re.search('@end_of_%s:' % object_name + r'(\d+)@', - data_str) + data_len_match = re.search('@end_of_%s:' % object_name + r'(\d+)@', data_str) if data_len_match is not None: return int(data_len_match.group(1)) else: @@ -431,12 +433,15 @@ def _get_procs_running(proc_name): from psutil import process_iter process_list = [] - for p in process_iter(attrs=["name", "exe", "cmdline"]): - if proc_name == p.info['name'] or \ - (p.info['exe'] is not None and - os.path.basename(p.info['exe']) == proc_name) or \ - (p.info['cmdline'] and - p.info['cmdline'][0] == proc_name): + for p in process_iter(attrs=['name', 'exe', 'cmdline']): + if ( + proc_name == p.info['name'] + or ( + p.info['exe'] is not None + and os.path.basename(p.info['exe']) == proc_name + ) + or (p.info['cmdline'] and p.info['cmdline'][0] == proc_name) + ): process_list.append(p) return process_list @@ -463,8 +468,7 @@ def kill_proc_name(proc_name): if len(running) < len(procs): killed_procs = True pids = [str(proc.pid) for proc in running] - warn("Failed to kill nrniv process(es) %s" % - ','.join(pids)) + warn('Failed to kill nrniv process(es) %s' % ','.join(pids)) else: killed_procs = True @@ -499,6 +503,7 @@ class JoblibBackend(object): n_jobs : int The number of jobs to start in parallel """ + def __init__(self, n_jobs=1): self.n_jobs = n_jobs @@ -554,14 +559,18 @@ def simulate(self, net, tstop, dt, n_trials, postproc=False): The Dipole results from each simulation trial """ - print(f"Joblib will run {n_trials} trial(s) in parallel by " - f"distributing trials over {self.n_jobs} jobs.") + print( + f'Joblib will run {n_trials} trial(s) in parallel by ' + f'distributing trials over {self.n_jobs} jobs.' + ) parallel, myfunc = self._parallel_func(_simulate_single_trial) - sim_data = parallel(myfunc(net, tstop, dt, trial_idx) for - trial_idx in range(n_trials)) + sim_data = parallel( + myfunc(net, tstop, dt, trial_idx) for trial_idx in range(n_trials) + ) - dpls = _gather_trial_data(sim_data, net=net, n_trials=n_trials, - postproc=postproc) + dpls = _gather_trial_data( + sim_data, net=net, n_trials=n_trials, postproc=postproc + ) return dpls @@ -596,6 +605,7 @@ class MPIBackend(object): There will be a valid process handle present in the queue when an MPI simulation is running.
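    Examples
    --------
    A hedged sketch of typical use (assumes MPI and mpi4py are installed
    and ``net`` is an instantiated Network)::

        >>> with MPIBackend(n_procs=2):
        ...     dpls = simulate_dipole(net, tstop=170.0, n_trials=1)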
""" + def __init__(self, n_procs=None, mpi_cmd='mpiexec'): self.expected_data_length = 0 self.proc = None @@ -643,10 +653,14 @@ def __init__(self, n_procs=None, mpi_cmd='mpiexec'): self.mpi_cmd += ' -np ' + str(self.n_procs) - self.mpi_cmd += ' nrniv -python -mpi -nobanner ' + \ - sys.executable + ' ' + \ - os.path.join(os.path.dirname(sys.modules[__name__].__file__), - 'mpi_child.py') + self.mpi_cmd += ( + ' nrniv -python -mpi -nobanner ' + + sys.executable + + ' ' + + os.path.join( + os.path.dirname(sys.modules[__name__].__file__), 'mpi_child.py' + ) + ) # Split the command into shell arguments for passing to Popen use_posix = True if sys.platform != 'win32' else False @@ -694,29 +708,39 @@ def simulate(self, net, tstop, dt, n_trials, postproc=False): # just use the joblib backend for a single core if self.n_procs == 1: - print("MPIBackend is set to use 1 core: transferring the " - "simulation to JoblibBackend....") - return JoblibBackend(n_jobs=1).simulate(net, tstop=tstop, - dt=dt, - n_trials=n_trials, - postproc=postproc) + print( + 'MPIBackend is set to use 1 core: transferring the ' + 'simulation to JoblibBackend....' + ) + return JoblibBackend(n_jobs=1).simulate( + net, tstop=tstop, dt=dt, n_trials=n_trials, postproc=postproc + ) if self.n_procs > net._n_cells: - raise ValueError(f'More MPI processes were assigned than there ' - f'are cells in the network. Please decrease ' - f'the number of parallel processes (got n_procs=' - f'{self.n_procs}) over which you will ' - f'distribute the {net._n_cells} network neurons.') - - print(f"MPI will run {n_trials} trial(s) sequentially by " - f"distributing network neurons over {self.n_procs} processes.") + raise ValueError( + f'More MPI processes were assigned than there ' + f'are cells in the network. Please decrease ' + f'the number of parallel processes (got n_procs=' + f'{self.n_procs}) over which you will ' + f'distribute the {net._n_cells} network neurons.' + ) + + print( + f'MPI will run {n_trials} trial(s) sequentially by ' + f'distributing network neurons over {self.n_procs} processes.' + ) env = _get_mpi_env() self.proc, sim_data = run_subprocess( - command=self.mpi_cmd, obj=[net, tstop, dt, n_trials], timeout=30, - proc_queue=self.proc_queue, env=env, cwd=os.getcwd(), - universal_newlines=True) + command=self.mpi_cmd, + obj=[net, tstop, dt, n_trials], + timeout=30, + proc_queue=self.proc_queue, + env=env, + cwd=os.getcwd(), + universal_newlines=True, + ) dpls = _gather_trial_data(sim_data, net, n_trials, postproc) return dpls @@ -731,12 +755,11 @@ def terminate(self): try: proc = self.proc_queue.get(timeout=1) except Empty: - warn("No currently running process to terminate") + warn('No currently running process to terminate') if proc is not None: proc.terminate() try: proc.wait(5) # wait maximum of 5s except TimeoutExpired: - warn("Could not kill python subprocess: PID %d" % - proc.pid) + warn('Could not kill python subprocess: PID %d' % proc.pid) diff --git a/hnn_core/params.py b/hnn_core/params.py index 9756991aa..66e2e4d92 100644 --- a/hnn_core/params.py +++ b/hnn_core/params.py @@ -90,8 +90,9 @@ def read_params(params_fname, file_contents=None): ext = split_fname[1] if ext not in ['.json', '.param']: - raise ValueError('Unrecognized extension, expected one of' + - ' .json, .param. Got %s' % ext) + raise ValueError( + 'Unrecognized extension, expected one of' + ' .json, .param. 
Got %s' % ext + ) if file_contents is None: with open(params_fname, 'r') as fp: @@ -101,8 +102,9 @@ params_dict = read_func[ext](file_contents) if len(params_dict) == 0: - raise ValueError("Failed to read parameters from file: %s" % - op.normpath(params_fname)) + raise ValueError( + 'Failed to read parameters from file: %s' % op.normpath(params_fname) + ) params = Params(params_dict) @@ -110,16 +112,24 @@ def _long_name(short_name): - long_name = dict(L2Basket='L2_basket', L5Basket='L5_basket', - L2Pyr='L2_pyramidal', L5Pyr='L5_pyramidal') + long_name = dict( + L2Basket='L2_basket', + L5Basket='L5_basket', + L2Pyr='L2_pyramidal', + L5Pyr='L5_pyramidal', + ) if short_name in long_name: return long_name[short_name] return short_name def _short_name(short_name): - long_name = dict(L2_basket='L2Basket', L5_basket='L5Basket', - L2_pyramidal='L2Pyr', L5_pyramidal='L5Pyr') + long_name = dict( + L2_basket='L2Basket', + L5_basket='L5Basket', + L2_pyramidal='L2Pyr', + L5_pyramidal='L5Pyr', + ) if short_name in long_name: return long_name[short_name] return short_name @@ -130,25 +140,26 @@ def _extract_bias_specs_from_hnn_params(params, cellname_list): bias_specs = {'tonic': {}} # currently only 'tonic' biases known for cellname in cellname_list: short_name = _short_name(cellname) - is_tonic_present = [f'Itonic_{p}_{short_name}_soma' in - params for p in ['A', 't0', 'T']] + is_tonic_present = [ + f'Itonic_{p}_{short_name}_soma' in params for p in ['A', 't0', 'T'] + ] if any(is_tonic_present): if not all(is_tonic_present): raise ValueError( f'Tonic input must have the amplitude, ' f'start time and end time specified. One ' f'or more parameters may be missing for ' - f'cell type {cellname}') + f'cell type {cellname}' + ) bias_specs['tonic'][cellname] = { 'amplitude': params[f'Itonic_A_{short_name}_soma'], 't0': params[f'Itonic_t0_{short_name}_soma'], - 'tstop': params[f'Itonic_T_{short_name}_soma'] + 'tstop': params[f'Itonic_T_{short_name}_soma'], } return bias_specs -def _extract_drive_specs_from_hnn_params( - params, cellname_list, legacy_mode=False): +def _extract_drive_specs_from_hnn_params(params, cellname_list, legacy_mode=False): """Create 'drive specification' dicts from saved parameters""" # convert legacy params-dict to legacy "feeds" dicts p_common, p_unique = create_pext(params, params['tstop']) @@ -162,14 +173,16 @@ drive = dict() drive['type'] = 'bursty' drive['cell_specific'] = False - drive['dynamics'] = {'tstart': par['t0'], - 'tstart_std': par['t0_stdev'], - 'tstop': par['tstop'], - 'burst_rate': par['f_input'], - 'burst_std': par['stdev'], - 'numspikes': par['events_per_cycle'], - 'n_drive_cells': par['n_drive_cells'], - 'spike_isi': 10} # not exposed in params-files + drive['dynamics'] = { + 'tstart': par['t0'], + 'tstart_std': par['t0_stdev'], + 'tstop': par['tstop'], + 'burst_rate': par['f_input'], + 'burst_std': par['stdev'], + 'numspikes': par['events_per_cycle'], + 'n_drive_cells': par['n_drive_cells'], + 'spike_isi': 10, + } # not exposed in params-files drive['location'] = par['loc'] drive['space_constant'] = par['lamtha'] drive['event_seed'] = par['prng_seedcore'] @@ -201,16 +214,14 @@ drive['weights_ampa'] = dict() drive['weights_nmda'] = dict() drive['synaptic_delays'] = dict() - if (feed_name.startswith('evprox') or - feed_name.startswith('evdist')): + if feed_name.startswith('evprox') or feed_name.startswith('evdist'):
drive['type'] = 'evoked' if feed_name.startswith('evprox'): drive['location'] = 'proximal' else: drive['location'] = 'distal' - cell_keys_present = [key for key in par if - key in cellname_list] + cell_keys_present = [key for key in par if key in cellname_list] sigma = par[cell_keys_present[0]][3] # IID for all cells! n_drive_cells = 'n_cells' @@ -218,10 +229,12 @@ def _extract_drive_specs_from_hnn_params( n_drive_cells = 1 drive['cell_specific'] = False - drive['dynamics'] = {'mu': par['t0'], - 'sigma': sigma, - 'numspikes': par['numspikes'], - 'n_drive_cells': n_drive_cells} + drive['dynamics'] = { + 'mu': par['t0'], + 'sigma': sigma, + 'numspikes': par['numspikes'], + 'n_drive_cells': n_drive_cells, + } drive['space_constant'] = par['lamtha'] drive['event_seed'] = par['prng_seedcore'] # XXX Force random states to be the same as HNN-gui for the default @@ -240,16 +253,17 @@ def _extract_drive_specs_from_hnn_params( # Skip drive if not in legacy mode elif feed_name.startswith('extgauss'): - if (not legacy_mode) and par[ - 'L2_basket'][3] > params['tstop']: + if (not legacy_mode) and par['L2_basket'][3] > params['tstop']: continue drive['type'] = 'gaussian' drive['location'] = par['loc'] - drive['dynamics'] = {'mu': par['L2_basket'][3], # NB IID - 'sigma': par['L2_basket'][4], - 'numspikes': 50, # NB hard-coded in GUI! - 'sync_within_trial': False} + drive['dynamics'] = { + 'mu': par['L2_basket'][3], # NB IID + 'sigma': par['L2_basket'][4], + 'numspikes': 50, # NB hard-coded in GUI! + 'sync_within_trial': False, + } drive['space_constant'] = par['lamtha'] drive['event_seed'] = par['prng_seedcore'] @@ -262,8 +276,7 @@ def _extract_drive_specs_from_hnn_params( drive['weights_nmda'] = dict() # no NMDA weights for Gaussians elif feed_name.startswith('extpois'): - if (not legacy_mode) and par['t_interval'][1] < par[ - 't_interval'][0]: + if (not legacy_mode) and par['t_interval'][1] < par['t_interval'][0]: continue drive['type'] = 'poisson' drive['location'] = par['loc'] @@ -287,9 +300,11 @@ def _extract_drive_specs_from_hnn_params( drive['synaptic_delays'][cellname] = synaptic_delays # do NOT allow negative times sometimes used in param-files - drive['dynamics'] = {'tstart': max(0, par['t_interval'][0]), - 'tstop': max(0, par['t_interval'][1]), - 'rate_constant': rate_params} + drive['dynamics'] = { + 'tstart': max(0, par['t_interval'][0]), + 'tstop': max(0, par['t_interval'][1]), + 'rate_constant': rate_params, + } drive_specs[feed_name] = drive return drive_specs @@ -305,7 +320,6 @@ class Params(dict): """ def __init__(self, params_input=None): - if params_input is None: params_input = dict() @@ -320,8 +334,9 @@ def __init__(self, params_input=None): else: self[key] = params_default[key] else: - raise ValueError('params_input must be dict or None. Got %s' - % type(params_input)) + raise ValueError( + 'params_input must be dict or None. Got %s' % type(params_input) + ) def __repr__(self): """Display the params nicely.""" @@ -400,9 +415,9 @@ def _validate_feed(p_ext_d, tstop): if not p_ext_d['stdev']: for key in p_ext_d.keys(): if key.endswith('Pyr'): - p_ext_d[key] = (p_ext_d[key][0] * 5., p_ext_d[key][1]) + p_ext_d[key] = (p_ext_d[key][0] * 5.0, p_ext_d[key][1]) elif key.endswith('Basket'): - p_ext_d[key] = (p_ext_d[key][0] * 5., p_ext_d[key][1]) + p_ext_d[key] = (p_ext_d[key][0] * 5.0, p_ext_d[key][1]) # if L5 delay is -1, use same delays as L2 unless L2 delay is 0.1 in # which case use 1. <<---- SN: WHAT IS THIS RULE!?!?!? 
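    # (concretely: with an L5 delay of -1, an L2 delay of 0.5 is copied over
    # to give an L5 delay of 0.5, while an L2 delay of exactly 0.1 yields an
    # L5 delay of 1.0 instead)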
@@ -412,7 +427,7 @@ def _validate_feed(p_ext_d, tstop): if p_ext_d['L2Pyr'][1] != 0.1: p_ext_d[key] = (p_ext_d[key][0], p_ext_d['L2Pyr'][1]) else: - p_ext_d[key] = (p_ext_d[key][0], 1.) + p_ext_d[key] = (p_ext_d[key][0], 1.0) return p_ext_d @@ -425,8 +440,9 @@ def check_evoked_synkeys(p, nprox, ndist): # evoked proximal target cell types lctdist = ['L2Pyr', 'L5Pyr', 'L2Basket'] lsy = ['ampa', 'nmda'] # synapse types used in evoked inputs - for nev, pref, lct in zip([nprox, ndist], ['evprox_', 'evdist_'], - [lctprox, lctdist]): + for nev, pref, lct in zip( + [nprox, ndist], ['evprox_', 'evdist_'], [lctprox, lctdist] + ): for i in range(nev): skey = pref + str(i + 1) for sy in lsy: @@ -437,6 +453,7 @@ def check_evoked_synkeys(p, nprox, ndist): if k not in p: p[k] = p['gbar_' + skey + '_' + ct] + # @@ -452,6 +469,7 @@ def check_pois_synkeys(p): if k not in p: p[k] = 0.0 + # creates the external feed params based on individual simulation params p @@ -478,29 +496,33 @@ def create_pext(p, tstop): 't0': p['t0_input_prox'], 'tstop': p['tstop_input_prox'], 'stdev': p['f_stdev_prox'], - 'L2Pyr_ampa': (p['input_prox_A_weight_L2Pyr_ampa'], - p['input_prox_A_delay_L2']), - 'L2Pyr_nmda': (p['input_prox_A_weight_L2Pyr_nmda'], - p['input_prox_A_delay_L2']), - 'L5Pyr_ampa': (p['input_prox_A_weight_L5Pyr_ampa'], - p['input_prox_A_delay_L5']), - 'L5Pyr_nmda': (p['input_prox_A_weight_L5Pyr_nmda'], - p['input_prox_A_delay_L5']), - 'L2Basket_ampa': (p['input_prox_A_weight_L2Basket_ampa'], - p['input_prox_A_delay_L2']), - 'L2Basket_nmda': (p['input_prox_A_weight_L2Basket_nmda'], - p['input_prox_A_delay_L2']), - 'L5Basket_ampa': (p['input_prox_A_weight_L5Basket_ampa'], - p['input_prox_A_delay_L5']), - 'L5Basket_nmda': (p['input_prox_A_weight_L5Basket_nmda'], - p['input_prox_A_delay_L5']), + 'L2Pyr_ampa': (p['input_prox_A_weight_L2Pyr_ampa'], p['input_prox_A_delay_L2']), + 'L2Pyr_nmda': (p['input_prox_A_weight_L2Pyr_nmda'], p['input_prox_A_delay_L2']), + 'L5Pyr_ampa': (p['input_prox_A_weight_L5Pyr_ampa'], p['input_prox_A_delay_L5']), + 'L5Pyr_nmda': (p['input_prox_A_weight_L5Pyr_nmda'], p['input_prox_A_delay_L5']), + 'L2Basket_ampa': ( + p['input_prox_A_weight_L2Basket_ampa'], + p['input_prox_A_delay_L2'], + ), + 'L2Basket_nmda': ( + p['input_prox_A_weight_L2Basket_nmda'], + p['input_prox_A_delay_L2'], + ), + 'L5Basket_ampa': ( + p['input_prox_A_weight_L5Basket_ampa'], + p['input_prox_A_delay_L5'], + ), + 'L5Basket_nmda': ( + p['input_prox_A_weight_L5Basket_nmda'], + p['input_prox_A_delay_L5'], + ), 'events_per_cycle': p['events_per_cycle_prox'], 'prng_seedcore': int(p['prng_seedcore_input_prox']), - 'lamtha': 100., + 'lamtha': 100.0, 'loc': 'proximal', 'n_drive_cells': p['repeats_prox'], 't0_stdev': p['t0_input_stdev_prox'], - 'threshold': p['threshold'] + 'threshold': p['threshold'], } # ensures time interval makes sense @@ -512,25 +534,25 @@ def create_pext(p, tstop): 't0': p['t0_input_dist'], 'tstop': p['tstop_input_dist'], 'stdev': p['f_stdev_dist'], - 'L2Pyr_ampa': (p['input_dist_A_weight_L2Pyr_ampa'], - p['input_dist_A_delay_L2']), - 'L2Pyr_nmda': (p['input_dist_A_weight_L2Pyr_nmda'], - p['input_dist_A_delay_L2']), - 'L5Pyr_ampa': (p['input_dist_A_weight_L5Pyr_ampa'], - p['input_dist_A_delay_L5']), - 'L5Pyr_nmda': (p['input_dist_A_weight_L5Pyr_nmda'], - p['input_dist_A_delay_L5']), - 'L2Basket_ampa': (p['input_dist_A_weight_L2Basket_ampa'], - p['input_dist_A_delay_L2']), - 'L2Basket_nmda': (p['input_dist_A_weight_L2Basket_nmda'], - p['input_dist_A_delay_L2']), + 'L2Pyr_ampa': 
(p['input_dist_A_weight_L2Pyr_ampa'], p['input_dist_A_delay_L2']), + 'L2Pyr_nmda': (p['input_dist_A_weight_L2Pyr_nmda'], p['input_dist_A_delay_L2']), + 'L5Pyr_ampa': (p['input_dist_A_weight_L5Pyr_ampa'], p['input_dist_A_delay_L5']), + 'L5Pyr_nmda': (p['input_dist_A_weight_L5Pyr_nmda'], p['input_dist_A_delay_L5']), + 'L2Basket_ampa': ( + p['input_dist_A_weight_L2Basket_ampa'], + p['input_dist_A_delay_L2'], + ), + 'L2Basket_nmda': ( + p['input_dist_A_weight_L2Basket_nmda'], + p['input_dist_A_delay_L2'], + ), 'events_per_cycle': p['events_per_cycle_dist'], 'prng_seedcore': int(p['prng_seedcore_input_dist']), - 'lamtha': 100., + 'lamtha': 100.0, 'loc': 'distal', 'n_drive_cells': p['repeats_dist'], 't0_stdev': p['t0_input_stdev_dist'], - 'threshold': p['threshold'] + 'threshold': p['threshold'], } p_common.append(_validate_feed(feed_dist, tstop)) @@ -554,24 +576,36 @@ def create_pext(p, tstop): skey = 'evprox_' + str(i + 1) p_unique['evprox' + str(i + 1)] = { 't0': p['t_' + skey], - 'L2_pyramidal': (p['gbar_' + skey + '_L2Pyr_ampa'], - p['gbar_' + skey + '_L2Pyr_nmda'], - 0.1, p['sigma_t_' + skey]), - 'L2_basket': (p['gbar_' + skey + '_L2Basket_ampa'], - p['gbar_' + skey + '_L2Basket_nmda'], - 0.1, p['sigma_t_' + skey]), - 'L5_pyramidal': (p['gbar_' + skey + '_L5Pyr_ampa'], - p['gbar_' + skey + '_L5Pyr_nmda'], - 1., p['sigma_t_' + skey]), - 'L5_basket': (p['gbar_' + skey + '_L5Basket_ampa'], - p['gbar_' + skey + '_L5Basket_nmda'], - 1., p['sigma_t_' + skey]), + 'L2_pyramidal': ( + p['gbar_' + skey + '_L2Pyr_ampa'], + p['gbar_' + skey + '_L2Pyr_nmda'], + 0.1, + p['sigma_t_' + skey], + ), + 'L2_basket': ( + p['gbar_' + skey + '_L2Basket_ampa'], + p['gbar_' + skey + '_L2Basket_nmda'], + 0.1, + p['sigma_t_' + skey], + ), + 'L5_pyramidal': ( + p['gbar_' + skey + '_L5Pyr_ampa'], + p['gbar_' + skey + '_L5Pyr_nmda'], + 1.0, + p['sigma_t_' + skey], + ), + 'L5_basket': ( + p['gbar_' + skey + '_L5Basket_ampa'], + p['gbar_' + skey + '_L5Basket_nmda'], + 1.0, + p['sigma_t_' + skey], + ), 'prng_seedcore': int(p['prng_seedcore_' + skey]), - 'lamtha': 3., + 'lamtha': 3.0, 'loc': 'proximal', 'sync_evinput': p['sync_evinput'], 'threshold': p['threshold'], - 'numspikes': p['numspikes_' + skey] + 'numspikes': p['numspikes_' + skey], } # Create distal evoked response parameters @@ -580,21 +614,30 @@ def create_pext(p, tstop): skey = 'evdist_' + str(i + 1) p_unique['evdist' + str(i + 1)] = { 't0': p['t_' + skey], - 'L2_pyramidal': (p['gbar_' + skey + '_L2Pyr_ampa'], - p['gbar_' + skey + '_L2Pyr_nmda'], - 0.1, p['sigma_t_' + skey]), - 'L5_pyramidal': (p['gbar_' + skey + '_L5Pyr_ampa'], - p['gbar_' + skey + '_L5Pyr_nmda'], - 0.1, p['sigma_t_' + skey]), - 'L2_basket': (p['gbar_' + skey + '_L2Basket_ampa'], - p['gbar_' + skey + '_L2Basket_nmda'], - 0.1, p['sigma_t_' + skey]), + 'L2_pyramidal': ( + p['gbar_' + skey + '_L2Pyr_ampa'], + p['gbar_' + skey + '_L2Pyr_nmda'], + 0.1, + p['sigma_t_' + skey], + ), + 'L5_pyramidal': ( + p['gbar_' + skey + '_L5Pyr_ampa'], + p['gbar_' + skey + '_L5Pyr_nmda'], + 0.1, + p['sigma_t_' + skey], + ), + 'L2_basket': ( + p['gbar_' + skey + '_L2Basket_ampa'], + p['gbar_' + skey + '_L2Basket_nmda'], + 0.1, + p['sigma_t_' + skey], + ), 'prng_seedcore': int(p['prng_seedcore_' + skey]), - 'lamtha': 3., + 'lamtha': 3.0, 'loc': 'distal', 'sync_evinput': p['sync_evinput'], 'threshold': p['threshold'], - 'numspikes': p['numspikes_' + skey] + 'numspikes': p['numspikes_' + skey], } # this needs to create many feeds @@ -603,23 +646,38 @@ def create_pext(p, tstop): # inputs p_unique['extgauss'] = { 
'stim': 'gaussian', - 'L2_basket': (p['L2Basket_Gauss_A_weight'], - p['L2Basket_Gauss_A_weight'], - 1., p['L2Basket_Gauss_mu'], - p['L2Basket_Gauss_sigma']), - 'L2_pyramidal': (p['L2Pyr_Gauss_A_weight'], - p['L2Pyr_Gauss_A_weight'], - 0.1, p['L2Pyr_Gauss_mu'], p['L2Pyr_Gauss_sigma']), - 'L5_basket': (p['L5Basket_Gauss_A_weight'], - p['L5Basket_Gauss_A_weight'], - 1., p['L5Basket_Gauss_mu'], p['L5Basket_Gauss_sigma']), - 'L5_pyramidal': (p['L5Pyr_Gauss_A_weight'], - p['L5Pyr_Gauss_A_weight'], - 1., p['L5Pyr_Gauss_mu'], p['L5Pyr_Gauss_sigma']), - 'lamtha': 100., + 'L2_basket': ( + p['L2Basket_Gauss_A_weight'], + p['L2Basket_Gauss_A_weight'], + 1.0, + p['L2Basket_Gauss_mu'], + p['L2Basket_Gauss_sigma'], + ), + 'L2_pyramidal': ( + p['L2Pyr_Gauss_A_weight'], + p['L2Pyr_Gauss_A_weight'], + 0.1, + p['L2Pyr_Gauss_mu'], + p['L2Pyr_Gauss_sigma'], + ), + 'L5_basket': ( + p['L5Basket_Gauss_A_weight'], + p['L5Basket_Gauss_A_weight'], + 1.0, + p['L5Basket_Gauss_mu'], + p['L5Basket_Gauss_sigma'], + ), + 'L5_pyramidal': ( + p['L5Pyr_Gauss_A_weight'], + p['L5Pyr_Gauss_A_weight'], + 1.0, + p['L5Pyr_Gauss_mu'], + p['L5Pyr_Gauss_sigma'], + ), + 'lamtha': 100.0, 'prng_seedcore': int(p['prng_seedcore_extgauss']), 'loc': 'proximal', - 'threshold': p['threshold'] + 'threshold': p['threshold'], } check_pois_synkeys(p) @@ -628,23 +686,35 @@ def create_pext(p, tstop): # NEW: setting up AMPA and NMDA for Poisson inputs; why delays differ? p_unique['extpois'] = { 'stim': 'poisson', - 'L2_basket': (p['L2Basket_Pois_A_weight_ampa'], - p['L2Basket_Pois_A_weight_nmda'], - 1., p['L2Basket_Pois_lamtha']), - 'L2_pyramidal': (p['L2Pyr_Pois_A_weight_ampa'], - p['L2Pyr_Pois_A_weight_nmda'], - 0.1, p['L2Pyr_Pois_lamtha']), - 'L5_basket': (p['L5Basket_Pois_A_weight_ampa'], - p['L5Basket_Pois_A_weight_nmda'], - 1., p['L5Basket_Pois_lamtha']), - 'L5_pyramidal': (p['L5Pyr_Pois_A_weight_ampa'], - p['L5Pyr_Pois_A_weight_nmda'], - 1., p['L5Pyr_Pois_lamtha']), - 'lamtha': 100., + 'L2_basket': ( + p['L2Basket_Pois_A_weight_ampa'], + p['L2Basket_Pois_A_weight_nmda'], + 1.0, + p['L2Basket_Pois_lamtha'], + ), + 'L2_pyramidal': ( + p['L2Pyr_Pois_A_weight_ampa'], + p['L2Pyr_Pois_A_weight_nmda'], + 0.1, + p['L2Pyr_Pois_lamtha'], + ), + 'L5_basket': ( + p['L5Basket_Pois_A_weight_ampa'], + p['L5Basket_Pois_A_weight_nmda'], + 1.0, + p['L5Basket_Pois_lamtha'], + ), + 'L5_pyramidal': ( + p['L5Pyr_Pois_A_weight_ampa'], + p['L5Pyr_Pois_A_weight_nmda'], + 1.0, + p['L5Pyr_Pois_lamtha'], + ), + 'lamtha': 100.0, 'prng_seedcore': int(p['prng_seedcore_extpois']), 't_interval': (p['t0_pois'], p['T_pois']), 'loc': 'proximal', - 'threshold': p['threshold'] + 'threshold': p['threshold'], } return p_common, p_unique @@ -663,9 +733,10 @@ def compare_dictionaries(d1, d2): def _any_positive_weights(drive): - """ Checks a drive for any positive weights. 
""" - weights = (list(drive['weights_ampa'].values()) + - list(drive['weights_nmda'].values())) + """Checks a drive for any positive weights.""" + weights = list(drive['weights_ampa'].values()) + list( + drive['weights_nmda'].values() + ) if any([val > 0 for val in weights]): return True else: @@ -702,18 +773,21 @@ def remove_nulled_drives(net): space_constant = net.connectivity[conn_indices[0]]['nc_dict']['lamtha'] probability = net.connectivity[conn_indices[0]]['probability'] - extras[drive_name] = {'space_constant': space_constant, - 'probability': probability} + extras[drive_name] = { + 'space_constant': space_constant, + 'probability': probability, + } net.clear_drives() for drive_name, drive in drives_copy.items(): # Do not add drive if tstart is > tstop, or negative t_start = drive['dynamics'].get('tstart') t_stop = drive['dynamics'].get('tstop') - if (t_start is not None and t_stop is not None and - ((t_start > t_stop) or - (t_start < 0) or - (t_stop < 0))): + if ( + t_start is not None + and t_stop is not None + and ((t_start > t_stop) or (t_start < 0) or (t_stop < 0)) + ): continue # Do not add if all 0 weights elif not _any_positive_weights(drive): @@ -722,19 +796,22 @@ def remove_nulled_drives(net): # Set n_drive_cells to 'n_cells' if equal to max number of cells if drive['cell_specific']: drive['n_drive_cells'] = 'n_cells' - net._attach_drive(drive['name'], drive, drive['weights_ampa'], - drive['weights_nmda'], drive['location'], - extras[drive_name]['space_constant'], - drive['synaptic_delays'], - drive['n_drive_cells'], drive['cell_specific'], - extras[drive_name]['probability']) + net._attach_drive( + drive['name'], + drive, + drive['weights_ampa'], + drive['weights_nmda'], + drive['location'], + extras[drive_name]['space_constant'], + drive['synaptic_delays'], + drive['n_drive_cells'], + drive['cell_specific'], + extras[drive_name]['probability'], + ) return net -def convert_to_json(params_fname, - out_fname, - include_drives=True, - overwrite=True): +def convert_to_json(params_fname, out_fname, include_drives=True, overwrite=True): """Converts legacy json or param format to hierarchical json format Parameters @@ -765,18 +842,19 @@ def convert_to_json(params_fname, if out_fname.suffix != '.json': out_fname = out_fname.with_suffix('.json') - net = jones_2009_model(params=read_params(params_fname), - add_drives_from_params=include_drives, - legacy_mode=(True if params_suffix == 'param' - else False), - ) + net = jones_2009_model( + params=read_params(params_fname), + add_drives_from_params=include_drives, + legacy_mode=(True if params_suffix == 'param' else False), + ) # Remove drives that have null attributes net = remove_nulled_drives(net) - net.write_configuration(fname=out_fname, - overwrite=overwrite, - ) + net.write_configuration( + fname=out_fname, + overwrite=overwrite, + ) return diff --git a/hnn_core/params_default.py b/hnn_core/params_default.py index 465149ba2..272d51f81 100644 --- a/hnn_core/params_default.py +++ b/hnn_core/params_default.py @@ -5,124 +5,109 @@ def get_params_default(nprox=2, ndist=1): - """ Note that nearly all start times are set BEYOND tstop for this file - Most values here are set to whatever default value - inactivates them, such as 0 for conductance - prng seed values are also set to 0 (non-random) - flat file of default values - will most often be overwritten + """Note that nearly all start times are set BEYOND tstop for this file + Most values here are set to whatever default value + inactivates them, such as 0 for conductance + prng 
+    flat file of default values
+    will most often be overwritten
     """
     # set default params
     p = {
         'sim_prefix': 'default',
-
         # simulation end time (ms)
-        'tstop': 250.,
-
+        'tstop': 250.0,
         # numbers of cells making up the pyramidal grids
         'N_pyr_x': 1,
         'N_pyr_y': 1,
-
         # amplitudes of individual Gaussian random inputs to L2Pyr and L5Pyr
         # L2 Basket params
-        'L2Basket_Gauss_A_weight': 0.,
-        'L2Basket_Gauss_mu': 2000.,
+        'L2Basket_Gauss_A_weight': 0.0,
+        'L2Basket_Gauss_mu': 2000.0,
         'L2Basket_Gauss_sigma': 3.6,
-        'L2Basket_Pois_A_weight_ampa': 0.,
-        'L2Basket_Pois_A_weight_nmda': 0.,
-        'L2Basket_Pois_lamtha': 0.,
-
+        'L2Basket_Pois_A_weight_ampa': 0.0,
+        'L2Basket_Pois_A_weight_nmda': 0.0,
+        'L2Basket_Pois_lamtha': 0.0,
         # L2 Pyr params
-        'L2Pyr_Gauss_A_weight': 0.,
-        'L2Pyr_Gauss_mu': 2000.,
+        'L2Pyr_Gauss_A_weight': 0.0,
+        'L2Pyr_Gauss_mu': 2000.0,
         'L2Pyr_Gauss_sigma': 3.6,
-        'L2Pyr_Pois_A_weight_ampa': 0.,
-        'L2Pyr_Pois_A_weight_nmda': 0.,
-        'L2Pyr_Pois_lamtha': 0.,
-
+        'L2Pyr_Pois_A_weight_ampa': 0.0,
+        'L2Pyr_Pois_A_weight_nmda': 0.0,
+        'L2Pyr_Pois_lamtha': 0.0,
         # L5 Pyr params
-        'L5Pyr_Gauss_A_weight': 0.,
-        'L5Pyr_Gauss_mu': 2000.,
+        'L5Pyr_Gauss_A_weight': 0.0,
+        'L5Pyr_Gauss_mu': 2000.0,
         'L5Pyr_Gauss_sigma': 4.8,
-        'L5Pyr_Pois_A_weight_ampa': 0.,
-        'L5Pyr_Pois_A_weight_nmda': 0.,
-        'L5Pyr_Pois_lamtha': 0.,
-
+        'L5Pyr_Pois_A_weight_ampa': 0.0,
+        'L5Pyr_Pois_A_weight_nmda': 0.0,
+        'L5Pyr_Pois_lamtha': 0.0,
         # L5 Basket params
-        'L5Basket_Gauss_A_weight': 0.,
-        'L5Basket_Gauss_mu': 2000.,
-        'L5Basket_Gauss_sigma': 2.,
-        'L5Basket_Pois_A_weight_ampa': 0.,
-        'L5Basket_Pois_A_weight_nmda': 0.,
-        'L5Basket_Pois_lamtha': 0.,
-
+        'L5Basket_Gauss_A_weight': 0.0,
+        'L5Basket_Gauss_mu': 2000.0,
+        'L5Basket_Gauss_sigma': 2.0,
+        'L5Basket_Pois_A_weight_ampa': 0.0,
+        'L5Basket_Pois_A_weight_nmda': 0.0,
+        'L5Basket_Pois_lamtha': 0.0,
         # maximal conductances for all synapses
         # max conductances TO L2Pyrs
-        'gbar_L2Pyr_L2Pyr_ampa': 0.,
-        'gbar_L2Pyr_L2Pyr_nmda': 0.,
-        'gbar_L2Basket_L2Pyr_gabaa': 0.,
-        'gbar_L2Basket_L2Pyr_gabab': 0.,
-
+        'gbar_L2Pyr_L2Pyr_ampa': 0.0,
+        'gbar_L2Pyr_L2Pyr_nmda': 0.0,
+        'gbar_L2Basket_L2Pyr_gabaa': 0.0,
+        'gbar_L2Basket_L2Pyr_gabab': 0.0,
         # max conductances TO L2Baskets
-        'gbar_L2Pyr_L2Basket': 0.,
-        'gbar_L2Basket_L2Basket': 0.,
-
+        'gbar_L2Pyr_L2Basket': 0.0,
+        'gbar_L2Basket_L2Basket': 0.0,
         # max conductances TO L5Pyr
-        'gbar_L5Pyr_L5Pyr_ampa': 0.,
-        'gbar_L5Pyr_L5Pyr_nmda': 0.,
-        'gbar_L2Pyr_L5Pyr': 0.,
-        'gbar_L2Basket_L5Pyr': 0.,
-        'gbar_L5Basket_L5Pyr_gabaa': 0.,
-        'gbar_L5Basket_L5Pyr_gabab': 0.,
-
+        'gbar_L5Pyr_L5Pyr_ampa': 0.0,
+        'gbar_L5Pyr_L5Pyr_nmda': 0.0,
+        'gbar_L2Pyr_L5Pyr': 0.0,
+        'gbar_L2Basket_L5Pyr': 0.0,
+        'gbar_L5Basket_L5Pyr_gabaa': 0.0,
+        'gbar_L5Basket_L5Pyr_gabab': 0.0,
         # max conductances TO L5Baskets
-        'gbar_L5Basket_L5Basket': 0.,
-        'gbar_L5Pyr_L5Basket': 0.,
-        'gbar_L2Pyr_L5Basket': 0.,
-
+        'gbar_L5Basket_L5Basket': 0.0,
+        'gbar_L5Pyr_L5Basket': 0.0,
+        'gbar_L2Pyr_L5Basket': 0.0,
         # Ongoing proximal alpha rhythm
         'distribution_prox': 'normal',
-        't0_input_prox': 1000.,
-        'tstop_input_prox': 250.,
-        'f_input_prox': 10.,
-        'f_stdev_prox': 20.,
+        't0_input_prox': 1000.0,
+        'tstop_input_prox': 250.0,
+        'f_input_prox': 10.0,
+        'f_stdev_prox': 20.0,
         'events_per_cycle_prox': 2,
         'repeats_prox': 10,
         't0_input_stdev_prox': 0.0,
-
         # Ongoing distal alpha rhythm
         'distribution_dist': 'normal',
-        't0_input_dist': 1000.,
-        'tstop_input_dist': 250.,
-        'f_input_dist': 10.,
-        'f_stdev_dist': 20.,
+        't0_input_dist': 1000.0,
+        'tstop_input_dist': 250.0,
+        'f_input_dist': 10.0,
+        'f_stdev_dist': 20.0,
         'events_per_cycle_dist': 2,
         'repeats_dist': 10,
         't0_input_stdev_dist': 0.0,
-
         # thalamic input amplitudes and delays
-        'input_prox_A_weight_L2Pyr_ampa': 0.,
-        'input_prox_A_weight_L2Pyr_nmda': 0.,
-        'input_prox_A_weight_L5Pyr_ampa': 0.,
-        'input_prox_A_weight_L5Pyr_nmda': 0.,
-        'input_prox_A_weight_L2Basket_ampa': 0.,
-        'input_prox_A_weight_L2Basket_nmda': 0.,
-        'input_prox_A_weight_L5Basket_ampa': 0.,
-        'input_prox_A_weight_L5Basket_nmda': 0.,
+        'input_prox_A_weight_L2Pyr_ampa': 0.0,
+        'input_prox_A_weight_L2Pyr_nmda': 0.0,
+        'input_prox_A_weight_L5Pyr_ampa': 0.0,
+        'input_prox_A_weight_L5Pyr_nmda': 0.0,
+        'input_prox_A_weight_L2Basket_ampa': 0.0,
+        'input_prox_A_weight_L2Basket_nmda': 0.0,
+        'input_prox_A_weight_L5Basket_ampa': 0.0,
+        'input_prox_A_weight_L5Basket_nmda': 0.0,
         'input_prox_A_delay_L2': 0.1,
         'input_prox_A_delay_L5': 1.0,
-
         # current values, not sure where these distal values come from, need to
         # check
-        'input_dist_A_weight_L2Pyr_ampa': 0.,
-        'input_dist_A_weight_L2Pyr_nmda': 0.,
-        'input_dist_A_weight_L5Pyr_ampa': 0.,
-        'input_dist_A_weight_L5Pyr_nmda': 0.,
-        'input_dist_A_weight_L2Basket_ampa': 0.,
-        'input_dist_A_weight_L2Basket_nmda': 0.,
-        'input_dist_A_delay_L2': 5.,
-        'input_dist_A_delay_L5': 5.,
-
+        'input_dist_A_weight_L2Pyr_ampa': 0.0,
+        'input_dist_A_weight_L2Pyr_nmda': 0.0,
+        'input_dist_A_weight_L5Pyr_ampa': 0.0,
+        'input_dist_A_weight_L5Pyr_nmda': 0.0,
+        'input_dist_A_weight_L2Basket_ampa': 0.0,
+        'input_dist_A_weight_L2Basket_nmda': 0.0,
+        'input_dist_A_delay_L2': 5.0,
+        'input_dist_A_delay_L5': 5.0,
         # times and stdevs for evoked responses
         'dt_evprox0_evdist': -1,  # not used in GUI
         'dt_evprox0_evprox1': -1,  # not used in GUI
@@ -131,10 +116,9 @@ def get_params_default(nprox=2, ndist=1):
         # increment (ms) for avg evoked input start (for trial n, avg start
         # time is n * evinputinc
         'inc_evinput': 0.0,
-
         # analysis
         'save_spec_data': 0,
-        'f_max_spec': 40.,
+        'f_max_spec': 40.0,
         'spec_cmap': 'jet',  # only used in GUI
         'dipole_scalefctr': 30e3,  # scale factor for dipole - default at 30e3
         # based on scaling needed to match model ongoing rhythms from
@@ -150,11 +134,9 @@ def get_params_default(nprox=2, ndist=1):
         'record_vsec': 0,  # whether to record voltages
         'record_isec': 0,  # whether to record currents
         'record_ca': 0,  # whether to record calcium concentration
-
         # numerics
         # N_trials of 1 means that seed is set by rank
         'N_trials': 1,
-
         # prng_state is a string for a filename containing the
         # random state one wants to use
         # prng seed cores are the base integer seed for the specific
@@ -164,13 +146,12 @@ def get_params_default(nprox=2, ndist=1):
         'prng_seedcore_input_dist': 0,
         'prng_seedcore_extpois': 0,
         'prng_seedcore_extgauss': 0,
-
         # default end time for pois inputs
-        't0_pois': 0.,
+        't0_pois': 0.0,
         'T_pois': -1,
         'dt': 0.025,
         'celsius': 37.0,
-        'threshold': 0.0  # firing threshold
+        'threshold': 0.0,  # firing threshold
     }

     # grab cell-specific params and update p accordingly
@@ -187,6 +168,7 @@ def get_params_default(nprox=2, ndist=1):
     return p

+
 # return dict with default params (empty) for evoked inputs;
 # n is number of evoked inputs
 # isprox == True iff proximal (otherwise distal)
@@ -206,10 +188,9 @@
         tystr = pref + '_' + str(i + 1)  # this string includes input number
         for ty in lty:
             for sy in lsy:
-                dout['gbar_' + tystr + '_' + ty +
-                     '_' + sy] = 0.  # feed strength
-                dout['t_' + tystr] = 0.  # times and stdevs for evoked responses
-                dout['sigma_t_' + tystr] = 0.
+                dout['gbar_' + tystr + '_' + ty + '_' + sy] = 0.0  # feed strength
+                dout['t_' + tystr] = 0.0  # times and stdevs for evoked responses
+                dout['sigma_t_' + tystr] = 0.0
                 # random number generator seed for this input
                 dout['prng_seedcore_' + tystr] = 0
                 # number of presynaptic spikes (postsynaptic inputs)
@@ -224,63 +205,49 @@ def get_L2Pyr_params_default():
         'L2Pyr_soma_L': 22.1,
         'L2Pyr_soma_diam': 23.4,
         'L2Pyr_soma_cm': 0.6195,
-        'L2Pyr_soma_Ra': 200.,
-
+        'L2Pyr_soma_Ra': 200.0,
         # Dendrites
         'L2Pyr_dend_cm': 0.6195,
-        'L2Pyr_dend_Ra': 200.,
-
+        'L2Pyr_dend_Ra': 200.0,
         'L2Pyr_apicaltrunk_L': 59.5,
         'L2Pyr_apicaltrunk_diam': 4.25,
-
-        'L2Pyr_apical1_L': 306.,
+        'L2Pyr_apical1_L': 306.0,
         'L2Pyr_apical1_diam': 4.08,
-
-        'L2Pyr_apicaltuft_L': 238.,
+        'L2Pyr_apicaltuft_L': 238.0,
         'L2Pyr_apicaltuft_diam': 3.4,
-
-        'L2Pyr_apicaloblique_L': 340.,
+        'L2Pyr_apicaloblique_L': 340.0,
         'L2Pyr_apicaloblique_diam': 3.91,
-
-        'L2Pyr_basal1_L': 85.,
+        'L2Pyr_basal1_L': 85.0,
         'L2Pyr_basal1_diam': 4.25,
-
-        'L2Pyr_basal2_L': 255.,
+        'L2Pyr_basal2_L': 255.0,
         'L2Pyr_basal2_diam': 2.72,
-
-        'L2Pyr_basal3_L': 255.,
+        'L2Pyr_basal3_L': 255.0,
         'L2Pyr_basal3_diam': 2.72,
-
         # Synapses
-        'L2Pyr_ampa_e': 0.,
+        'L2Pyr_ampa_e': 0.0,
         'L2Pyr_ampa_tau1': 0.5,
-        'L2Pyr_ampa_tau2': 5.,
-
-        'L2Pyr_nmda_e': 0.,
-        'L2Pyr_nmda_tau1': 1.,
-        'L2Pyr_nmda_tau2': 20.,
-
-        'L2Pyr_gabaa_e': -80.,
+        'L2Pyr_ampa_tau2': 5.0,
+        'L2Pyr_nmda_e': 0.0,
+        'L2Pyr_nmda_tau1': 1.0,
+        'L2Pyr_nmda_tau2': 20.0,
+        'L2Pyr_gabaa_e': -80.0,
         'L2Pyr_gabaa_tau1': 0.5,
-        'L2Pyr_gabaa_tau2': 5.,
-
-        'L2Pyr_gabab_e': -80.,
-        'L2Pyr_gabab_tau1': 1.,
-        'L2Pyr_gabab_tau2': 20.,
-
+        'L2Pyr_gabaa_tau2': 5.0,
+        'L2Pyr_gabab_e': -80.0,
+        'L2Pyr_gabab_tau1': 1.0,
+        'L2Pyr_gabab_tau2': 20.0,
         # Biophysics soma
         'L2Pyr_soma_gkbar_hh2': 0.01,
         'L2Pyr_soma_gnabar_hh2': 0.18,
-        'L2Pyr_soma_el_hh2': -65.,
+        'L2Pyr_soma_el_hh2': -65.0,
         'L2Pyr_soma_gl_hh2': 4.26e-5,
-        'L2Pyr_soma_gbar_km': 250.,
-
+        'L2Pyr_soma_gbar_km': 250.0,
         # Biophysics dends
         'L2Pyr_dend_gkbar_hh2': 0.01,
         'L2Pyr_dend_gnabar_hh2': 0.15,
-        'L2Pyr_dend_el_hh2': -65.,
+        'L2Pyr_dend_el_hh2': -65.0,
         'L2Pyr_dend_gl_hh2': 4.26e-5,
-        'L2Pyr_dend_gbar_km': 250.,
+        'L2Pyr_dend_gbar_km': 250.0,
     }
@@ -288,77 +255,62 @@ def get_L5Pyr_params_default():
     """Returns default params for L5 pyramidal cell."""
     return {
         # Soma
-        'L5Pyr_soma_L': 39.,
+        'L5Pyr_soma_L': 39.0,
         'L5Pyr_soma_diam': 28.9,
         'L5Pyr_soma_cm': 0.85,
-        'L5Pyr_soma_Ra': 200.,
-
+        'L5Pyr_soma_Ra': 200.0,
         # Dendrites
         'L5Pyr_dend_cm': 0.85,
-        'L5Pyr_dend_Ra': 200.,
-
-        'L5Pyr_apicaltrunk_L': 102.,
+        'L5Pyr_dend_Ra': 200.0,
+        'L5Pyr_apicaltrunk_L': 102.0,
         'L5Pyr_apicaltrunk_diam': 10.2,
-
-        'L5Pyr_apical1_L': 680.,
+        'L5Pyr_apical1_L': 680.0,
         'L5Pyr_apical1_diam': 7.48,
-
-        'L5Pyr_apical2_L': 680.,
+        'L5Pyr_apical2_L': 680.0,
         'L5Pyr_apical2_diam': 4.93,
-
-        'L5Pyr_apicaltuft_L': 425.,
+        'L5Pyr_apicaltuft_L': 425.0,
         'L5Pyr_apicaltuft_diam': 3.4,
-
-        'L5Pyr_apicaloblique_L': 255.,
+        'L5Pyr_apicaloblique_L': 255.0,
         'L5Pyr_apicaloblique_diam': 5.1,
-
-        'L5Pyr_basal1_L': 85.,
+        'L5Pyr_basal1_L': 85.0,
         'L5Pyr_basal1_diam': 6.8,
-
-        'L5Pyr_basal2_L': 255.,
+        'L5Pyr_basal2_L': 255.0,
         'L5Pyr_basal2_diam': 8.5,
-
-        'L5Pyr_basal3_L': 255.,
+        'L5Pyr_basal3_L': 255.0,
         'L5Pyr_basal3_diam': 8.5,
-
         # Synapses
-        'L5Pyr_ampa_e': 0.,
+        'L5Pyr_ampa_e': 0.0,
         'L5Pyr_ampa_tau1': 0.5,
-        'L5Pyr_ampa_tau2': 5.,
-
-        'L5Pyr_nmda_e': 0.,
-        'L5Pyr_nmda_tau1': 1.,
-        'L5Pyr_nmda_tau2': 20.,
-
-        'L5Pyr_gabaa_e': -80.,
+        'L5Pyr_ampa_tau2': 5.0,
+        'L5Pyr_nmda_e': 0.0,
+        'L5Pyr_nmda_tau1': 1.0,
+        'L5Pyr_nmda_tau2': 20.0,
+        'L5Pyr_gabaa_e': -80.0,
         'L5Pyr_gabaa_tau1': 0.5,
-        'L5Pyr_gabaa_tau2': 5.,
-
-        'L5Pyr_gabab_e': -80.,
-        'L5Pyr_gabab_tau1': 1.,
-        'L5Pyr_gabab_tau2': 20.,
-
+        'L5Pyr_gabaa_tau2': 5.0,
+        'L5Pyr_gabab_e': -80.0,
+        'L5Pyr_gabab_tau1': 1.0,
+        'L5Pyr_gabab_tau2': 20.0,
         # Biophysics soma
         'L5Pyr_soma_gkbar_hh2': 0.01,
         'L5Pyr_soma_gnabar_hh2': 0.16,
-        'L5Pyr_soma_el_hh2': -65.,
+        'L5Pyr_soma_el_hh2': -65.0,
         'L5Pyr_soma_gl_hh2': 4.26e-5,
-        'L5Pyr_soma_gbar_ca': 60.,
-        'L5Pyr_soma_taur_cad': 20.,
+        'L5Pyr_soma_gbar_ca': 60.0,
+        'L5Pyr_soma_taur_cad': 20.0,
         'L5Pyr_soma_gbar_kca': 2e-4,
-        'L5Pyr_soma_gbar_km': 200.,
+        'L5Pyr_soma_gbar_km': 200.0,
         'L5Pyr_soma_gbar_cat': 2e-4,
         'L5Pyr_soma_gbar_ar': 1e-6,
-
         # Biophysics dends
         'L5Pyr_dend_gkbar_hh2': 0.01,
         'L5Pyr_dend_gnabar_hh2': 0.14,
-        'L5Pyr_dend_el_hh2': -71.,
+        'L5Pyr_dend_el_hh2': -71.0,
         'L5Pyr_dend_gl_hh2': 4.26e-5,
-        'L5Pyr_dend_gbar_ca': 60.,
-        'L5Pyr_dend_taur_cad': 20.,
+        'L5Pyr_dend_gbar_ca': 60.0,
+        'L5Pyr_dend_taur_cad': 20.0,
         'L5Pyr_dend_gbar_kca': 2e-4,
-        'L5Pyr_dend_gbar_km': 200.,
+        'L5Pyr_dend_gbar_km': 200.0,
         'L5Pyr_dend_gbar_cat': 2e-4,
         'L5Pyr_dend_gbar_ar': 1e-6,
     }
diff --git a/hnn_core/tests/conftest.py b/hnn_core/tests/conftest.py
index c2560c341..66871f18a 100644
--- a/hnn_core/tests/conftest.py
+++ b/hnn_core/tests/conftest.py
@@ -1,4 +1,4 @@
-""" Example from pytest documentation
+"""Example from pytest documentation

 https://pytest.org/en/stable/example/simple.html#incremental-testing-test-steps
 """
@@ -18,8 +18,7 @@


 def pytest_runtest_makereport(item, call):
-
-    if "incremental" in item.keywords:
+    if 'incremental' in item.keywords:
         # incremental marker is used

         # The following condition was modified from the example linked above.
@@ -27,7 +26,7 @@
         # a previous test was marked "Skipped". For instance if MPI tests
         # are skipped because mpi4py is not installed, still continue with
         # all other tests that do not require mpi4py
-        if call.excinfo is not None and not call.excinfo.typename == "Skipped":
+        if call.excinfo is not None and not call.excinfo.typename == 'Skipped':
             # the test has failed, but was not skipped

             # retrieve the class name of the test
@@ -36,7 +35,7 @@
             # combination with incremental)
             parametrize_index = (
                 tuple(item.callspec.indices.values())
-                if hasattr(item, "callspec")
+                if hasattr(item, 'callspec')
                 else ()
             )
             # retrieve the name of the test function
@@ -49,7 +48,7 @@


 def pytest_runtest_setup(item):
-    if "incremental" in item.keywords:
+    if 'incremental' in item.keywords:
         # retrieve the class name of the test
         cls_name = str(item.cls)
         # check if a previous test has failed for this class
@@ -58,67 +57,88 @@ def pytest_runtest_setup(item):
         # combination with incremental)
         parametrize_index = (
             tuple(item.callspec.indices.values())
-            if hasattr(item, "callspec")
+            if hasattr(item, 'callspec')
             else ()
         )
         # retrieve the name of the first test function to fail for this
         # class name and index
-        test_name = _test_failed_incremental[cls_name].get(
-            parametrize_index, None)
+        test_name = _test_failed_incremental[cls_name].get(parametrize_index, None)
         # if name found, test has failed for the combination of class name
         # and test name
        if test_name is not None:
-            pytest.xfail("previous test failed ({})".format(test_name))
+            pytest.xfail('previous test failed ({})'.format(test_name))


 @pytest.fixture(scope='module')
 def run_hnn_core_fixture():
-    def _run_hnn_core_fixture(backend=None, n_procs=None, n_jobs=1,
-                              reduced=False, record_vsec=False,
-                              record_isec=False, record_ca=False,
-                              postproc=False, electrode_array=None):
+    def _run_hnn_core_fixture(
+        backend=None,
+        n_procs=None,
+        n_jobs=1,
+        reduced=False,
+        record_vsec=False,
+        record_isec=False,
+        record_ca=False,
+        postproc=False,
+        electrode_array=None,
+    ):
         hnn_core_root = op.dirname(hnn_core.__file__)

         # default params
         params_fname = op.join(hnn_core_root, 'param', 'default.json')
         params = read_params(params_fname)

-        tstop = 170.
+        tstop = 170.0
         legacy_mode = True
         if reduced:
             mesh_shape = (3, 3)
-            params.update({'t_evprox_1': 5,
-                           't_evdist_1': 10,
-                           't_evprox_2': 20,
-                           'N_trials': 2})
-            tstop = 40.
+            params.update(
+                {'t_evprox_1': 5, 't_evdist_1': 10, 't_evprox_2': 20, 'N_trials': 2}
+            )
+            tstop = 40.0
             legacy_mode = False
         else:
             mesh_shape = (10, 10)

         # Legacy mode necessary for exact dipole comparison test
-        net = jones_2009_model(params, add_drives_from_params=True,
-                               legacy_mode=legacy_mode, mesh_shape=mesh_shape)
+        net = jones_2009_model(
+            params,
+            add_drives_from_params=True,
+            legacy_mode=legacy_mode,
+            mesh_shape=mesh_shape,
+        )

         if electrode_array is not None:
             for name, positions in electrode_array.items():
                 net.add_electrode_array(name, positions)

         if backend == 'mpi':
             with MPIBackend(n_procs=n_procs, mpi_cmd='mpiexec'):
-                dpls = simulate_dipole(net, record_vsec=record_vsec,
-                                       record_isec=record_isec,
-                                       record_ca=record_ca,
-                                       postproc=postproc, tstop=tstop)
+                dpls = simulate_dipole(
+                    net,
+                    record_vsec=record_vsec,
+                    record_isec=record_isec,
+                    record_ca=record_ca,
+                    postproc=postproc,
+                    tstop=tstop,
+                )
         elif backend == 'joblib':
             with JoblibBackend(n_jobs=n_jobs):
-                dpls = simulate_dipole(net, record_vsec=record_vsec,
-                                       record_isec=record_isec,
-                                       record_ca=record_ca,
-                                       postproc=postproc, tstop=tstop)
+                dpls = simulate_dipole(
+                    net,
+                    record_vsec=record_vsec,
+                    record_isec=record_isec,
+                    record_ca=record_ca,
+                    postproc=postproc,
+                    tstop=tstop,
+                )
         else:
-            dpls = simulate_dipole(net, record_vsec=record_vsec,
-                                   record_isec=record_isec,
-                                   record_ca=record_ca,
-                                   postproc=postproc, tstop=tstop)
+            dpls = simulate_dipole(
+                net,
+                record_vsec=record_vsec,
+                record_isec=record_isec,
+                record_ca=record_ca,
+                postproc=postproc,
+                tstop=tstop,
+            )

         # check that the network object is picklable after the simulation
         pickle.dumps(net)
@@ -128,4 +148,5 @@ def _run_hnn_core_fixture(backend=None, n_procs=None, n_jobs=1,
             assert len(drive['events']) == params['N_trials']

         return dpls, net
+
     return _run_hnn_core_fixture
diff --git a/hnn_core/tests/test_batch_simulate.py b/hnn_core/tests/test_batch_simulate.py
index 3b61ca221..848e93099 100644
--- a/hnn_core/tests/test_batch_simulate.py
+++ b/hnn_core/tests/test_batch_simulate.py
@@ -14,30 +14,38 @@
 @pytest.fixture
 def batch_simulate_instance(tmp_path):
     """Fixture for creating a BatchSimulate instance with custom parameters."""
-    def set_params(param_values, net):
-        weights_ampa = {'L2_basket': param_values['weight_basket'],
-                        'L2_pyramidal': param_values['weight_pyr'],
-                        'L5_basket': param_values['weight_basket'],
-                        'L5_pyramidal': param_values['weight_pyr']}
-        synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                           'L5_basket': 1., 'L5_pyramidal': 1.}
+
+    def set_params(param_values, net):
+        weights_ampa = {
+            'L2_basket': param_values['weight_basket'],
+            'L2_pyramidal': param_values['weight_pyr'],
+            'L5_basket': param_values['weight_basket'],
+            'L5_pyramidal': param_values['weight_pyr'],
+        }
+
+        synaptic_delays = {
+            'L2_basket': 0.1,
+            'L2_pyramidal': 0.1,
+            'L5_basket': 1.0,
+            'L5_pyramidal': 1.0,
+        }

         mu = param_values['mu']
         sigma = param_values['sigma']
-        net.add_evoked_drive('evprox',
-                             mu=mu,
-                             sigma=sigma,
-                             numspikes=1,
-                             location='proximal',
-                             weights_ampa=weights_ampa,
-                             synaptic_delays=synaptic_delays)
+        net.add_evoked_drive(
+            'evprox',
+            mu=mu,
+            sigma=sigma,
+            numspikes=1,
+            location='proximal',
+            weights_ampa=weights_ampa,
+            synaptic_delays=synaptic_delays,
+        )

     net = jones_2009_model()
-    return BatchSimulate(net=net, set_params=set_params,
-                         tstop=1.,
-                         save_folder=tmp_path,
-                         batch_size=3)
+    return BatchSimulate(
+        net=net, set_params=set_params, tstop=1.0, save_folder=tmp_path, batch_size=3
+    )


 @pytest.fixture
@@ -48,7 +56,7 @@ def param_grid():
         'weight_basket': np.logspace(-4, -1, 2),
         'weight_pyr': np.logspace(-4, -1, 2),
         'mu': [40],
-        'sigma': [5]
+        'sigma': [5],
     }
@@ -62,7 +70,7 @@ def test_parameter_validation():
         'save_currents',
         'save_calcium',
         'clear_cache',
-        'summary_func'
+        'summary_func',
     ]

     for param in boolean_params:
@@ -72,30 +80,26 @@ def test_parameter_validation():
     with pytest.raises(TypeError, match='set_params must be'):
         BatchSimulate(set_params='invalid')

-    with pytest.raises(TypeError, match="net must be"):
-        BatchSimulate(net="invalid_network", set_params=lambda x: x)
+    with pytest.raises(TypeError, match='net must be'):
+        BatchSimulate(net='invalid_network', set_params=lambda x: x)


 def test_generate_param_combinations(batch_simulate_instance, param_grid):
     """Test generating parameter combinations."""
     param_combinations = batch_simulate_instance._generate_param_combinations(
-        param_grid)
+        param_grid
+    )
     assert len(param_combinations) == (
-        len(param_grid['weight_basket']) *
-        len(param_grid['weight_pyr']) *
-        len(param_grid['mu']) *
-        len(param_grid['sigma'])
+        len(param_grid['weight_basket'])
+        * len(param_grid['weight_pyr'])
+        * len(param_grid['mu'])
+        * len(param_grid['sigma'])
     )


 def test_run_single_sim(batch_simulate_instance):
     """Test running a single simulation."""
-    param_values = {
-        'weight_basket': -3,
-        'weight_pyr': -2,
-        'mu': 40,
-        'sigma': 20
-    }
+    param_values = {'weight_basket': -3, 'weight_pyr': -2, 'mu': 40, 'sigma': 20}
     result = batch_simulate_instance._run_single_sim(param_values)
     assert 'net' in result
     assert 'dpl' in result
@@ -107,10 +111,11 @@ def test_run_single_sim(batch_simulate_instance):
 def test_simulate_batch(batch_simulate_instance, param_grid):
     """Test simulating a batch of parameter sets."""
     param_combinations = batch_simulate_instance._generate_param_combinations(
-        param_grid)[:1]
-    results = batch_simulate_instance.simulate_batch(param_combinations,
-                                                     n_jobs=1,
-                                                     backend='threading')
+        param_grid
+    )[:1]
+    results = batch_simulate_instance.simulate_batch(
+        param_combinations, n_jobs=1, backend='threading'
+    )
     assert len(results) == len(param_combinations)
     for result in results:
         assert 'net' in result
@@ -123,43 +128,41 @@ def test_simulate_batch(batch_simulate_instance, param_grid):
         batch_simulate_instance.simulate_batch(invalid_param_combinations)

     with pytest.raises(TypeError, match='n_jobs must be'):
-        batch_simulate_instance.simulate_batch(param_combinations,
-                                               n_jobs='invalid')
+        batch_simulate_instance.simulate_batch(param_combinations, n_jobs='invalid')

     with pytest.raises(ValueError, match="Invalid value for the 'backend'"):
-        batch_simulate_instance.simulate_batch(param_combinations,
-                                               backend='invalid')
+        batch_simulate_instance.simulate_batch(param_combinations, backend='invalid')

     with pytest.raises(TypeError, match='verbose must be'):
-        batch_simulate_instance.simulate_batch(param_combinations,
-                                               verbose='invalid')
+        batch_simulate_instance.simulate_batch(param_combinations, verbose='invalid')


 def test_run(batch_simulate_instance, param_grid):
     """Test the run method of the batch_simulate_instance."""
-    results_without_cache = batch_simulate_instance.run(param_grid,
-                                                        n_jobs=2,
-                                                        return_output=True,
-                                                        combinations=False,
-                                                        backend='loky')
+    results_without_cache = batch_simulate_instance.run(
+        param_grid, n_jobs=2, return_output=True, combinations=False, backend='loky'
+    )

     total_combinations = len(
         batch_simulate_instance._generate_param_combinations(
-            param_grid, combinations=False))
+            param_grid, combinations=False
+        )
+    )

     assert results_without_cache is not None
     assert isinstance(results_without_cache, dict)
     assert 'simulated_data' in results_without_cache
-    assert len(results_without_cache['simulated_data']
-               ) == total_combinations
+    assert len(results_without_cache['simulated_data']) == total_combinations

     batch_simulate_instance.clear_cache = True
-    results_with_cache = batch_simulate_instance.run(param_grid,
-                                                     n_jobs=2,
-                                                     return_output=True,
-                                                     combinations=False,
-                                                     backend='loky',
-                                                     verbose=50)
+    results_with_cache = batch_simulate_instance.run(
+        param_grid,
+        n_jobs=2,
+        return_output=True,
+        combinations=False,
+        backend='loky',
+        verbose=50,
+    )

     assert results_with_cache is not None
     assert isinstance(results_with_cache, dict)
@@ -179,14 +182,12 @@ def test_run(batch_simulate_instance, param_grid):
         batch_simulate_instance.run(param_grid, verbose='invalid')


-def test_save_load_and_overwrite(batch_simulate_instance,
-                                 param_grid, tmp_path):
+def test_save_load_and_overwrite(batch_simulate_instance, param_grid, tmp_path):
     """Test the save method and its overwrite functionality."""
     param_combinations = batch_simulate_instance._generate_param_combinations(
-        param_grid)[:3]
-    results = batch_simulate_instance.simulate_batch(
-        param_combinations,
-        n_jobs=2)
+        param_grid
+    )[:3]
+    results = batch_simulate_instance.simulate_batch(param_combinations, n_jobs=2)

     start_idx = 0
     end_idx = len(results)
@@ -197,13 +198,10 @@ def test_save_load_and_overwrite(batch_simulate_instance, param_grid, tmp_path):
     assert os.path.exists(file_name)

     loaded_data = np.load(file_name, allow_pickle=True)
-    loaded_results = {key: loaded_data[key].tolist()
-                      for key in loaded_data.files}
+    loaded_results = {key: loaded_data[key].tolist() for key in loaded_data.files}

-    original_data = np.stack([result['dpl'][0].data['agg']
-                              for result in results])
-    loaded_data = np.stack([dpl[0].data['agg']
-                            for dpl in loaded_results['dpl']])
+    original_data = np.stack([result['dpl'][0].data['agg'] for result in results])
+    loaded_data = np.stack([dpl[0].data['agg'] for dpl in loaded_results['dpl']])

     assert (original_data == loaded_data).all()
@@ -218,13 +216,10 @@ def test_save_load_and_overwrite(batch_simulate_instance, param_grid, tmp_path):
     batch_simulate_instance._save(results, start_idx, end_idx)

     loaded_data = np.load(file_name, allow_pickle=True)
-    loaded_results = {key: loaded_data[key].tolist()
-                      for key in loaded_data.files}
+    loaded_results = {key: loaded_data[key].tolist() for key in loaded_data.files}

-    original_data = np.stack([result['dpl'][0].data['agg']
-                              for result in results])
-    loaded_data = np.stack([dpl[0].data['agg']
-                            for dpl in loaded_results['dpl']])
+    original_data = np.stack([result['dpl'][0].data['agg'] for result in results])
+    loaded_data = np.stack([dpl[0].data['agg'] for dpl in loaded_results['dpl']])

     assert (original_data == loaded_data).all()
@@ -242,10 +237,9 @@
 def test_load_results(batch_simulate_instance, param_grid, tmp_path):
     """Test loading results from a single file."""
     param_combinations = batch_simulate_instance._generate_param_combinations(
-        param_grid)[:3]
-    results = batch_simulate_instance.simulate_batch(
-        param_combinations,
-        n_jobs=2)
+        param_grid
+    )[:3]
+    results = batch_simulate_instance.simulate_batch(param_combinations, n_jobs=2)

     start_idx = 0
     end_idx = len(results)
@@ -260,10 +254,8 @@ def test_load_results(batch_simulate_instance, param_grid, tmp_path):
     assert 'dpl' in loaded_results
     assert len(loaded_results['dpl']) == len(results)

-    original_data = np.stack([result['dpl'][0].data['agg']
-                              for result in results])
-    loaded_data = np.stack([dpl[0].data['agg']
-                            for dpl in loaded_results['dpl']])
+    original_data = np.stack([result['dpl'][0].data['agg'] for result in results])
+    loaded_data = np.stack([dpl[0].data['agg'] for dpl in loaded_results['dpl']])
     assert np.array_equal(original_data, loaded_data)

     for key in ['spiking', 'lfp', 'voltages', 'currents', 'calcium']:
@@ -273,10 +265,11 @@ def test_load_results(batch_simulate_instance, param_grid, tmp_path):
     all_loaded_results = batch_simulate_instance.load_all_results()
     assert len(all_loaded_results) == 1

-    all_loaded_data = np.stack([dpl[0].data['agg']
-                                for dpl in all_loaded_results[0]['dpl']])
+    all_loaded_data = np.stack(
+        [dpl[0].data['agg'] for dpl in all_loaded_results[0]['dpl']]
+    )
     assert np.array_equal(original_data, all_loaded_data)

     # Validation Tests
     with pytest.raises(TypeError, match='results must be'):
-        batch_simulate_instance._save("invalid_results", start_idx, end_idx)
+        batch_simulate_instance._save('invalid_results', start_idx, end_idx)
diff --git a/hnn_core/tests/test_cell.py b/hnn_core/tests/test_cell.py
index c2ca9fb96..91819c4d8 100644
--- a/hnn_core/tests/test_cell.py
+++ b/hnn_core/tests/test_cell.py
@@ -16,64 +16,52 @@ def test_cell():
     load_custom_mechanisms()

     name = 'test'
-    pos = (0., 0., 0.)
-    sections = {'soma': Section(L=1, diam=5, Ra=3, cm=100,
-                                end_pts=[[0, 0, 0], [0, 39., 0]])}
-    synapses = {'ampa': dict(e=0, tau1=0.5, tau2=5.)}
-    cell_tree = {
-        ('soma', 0): [('soma', 1)]
+    pos = (0.0, 0.0, 0.0)
+    sections = {
+        'soma': Section(L=1, diam=5, Ra=3, cm=100, end_pts=[[0, 0, 0], [0, 39.0, 0]])
     }
+    synapses = {'ampa': dict(e=0, tau1=0.5, tau2=5.0)}
+    cell_tree = {('soma', 0): [('soma', 1)]}
     sect_loc = {'proximal': 'soma'}

     # GID is assigned exactly once for each cell, either at initialisation...
     cell = Cell(name, pos, sections, synapses, sect_loc, cell_tree, gid=42)
     assert cell.gid == 42
-    with pytest.raises(RuntimeError,
-                       match='Global ID for this cell already assigned!'):
+    with pytest.raises(RuntimeError, match='Global ID for this cell already assigned!'):
         cell.gid += 1
     # ... or later
     # cells can exist fine without gid
     cell = Cell(name, pos, sections, synapses, sect_loc, cell_tree)
     assert cell.gid is None  # check that it's initialised to None
-    with pytest.raises(ValueError,
-                       match='gid must be an integer'):
+    with pytest.raises(ValueError, match='gid must be an integer'):
         cell.gid = [1]
     cell.gid = 42
     assert cell.gid == 42
-    with pytest.raises(ValueError,
-                       match='gid must be an integer'):
+    with pytest.raises(ValueError, match='gid must be an integer'):
         # test init checks gid
-        cell = Cell(name, pos, sections, synapses, sect_loc, cell_tree,
-                    gid='one')
+        cell = Cell(name, pos, sections, synapses, sect_loc, cell_tree, gid='one')

     # test that ExpSyn always takes nrn.Segment, not float
     with pytest.raises(TypeError, match='secloc must be instance of'):
-        cell.syn_create(0.5, e=0., tau1=0.5, tau2=5.)
+        cell.syn_create(0.5, e=0.0, tau1=0.5, tau2=5.0)

     pickle.dumps(cell)  # check cell object is picklable until built

-    bad_sections = {'blah': Section(L=1, diam=5, Ra=3, cm=100,
-                                    end_pts=[[0, 0, 0], [0, 39., 0]])}
+    bad_sections = {
+        'blah': Section(L=1, diam=5, Ra=3, cm=100, end_pts=[[0, 0, 0], [0, 39.0, 0]])
+    }
     # Check soma must be included in sections
     with pytest.raises(KeyError, match='soma must be defined'):
         cell = Cell(name, pos, bad_sections, synapses, sect_loc, cell_tree)

     sections = {
         'soma': Section(
-            L=39,
-            diam=20,
-            cm=0.85,
-            Ra=200.,
-            end_pts=[[0, 0, 0], [0, 39., 0]]
+            L=39, diam=20, cm=0.85, Ra=200.0, end_pts=[[0, 0, 0], [0, 39.0, 0]]
         )
     }
     sections['soma'].syns = ['ampa']
     sections['soma'].mechs = {
-        'km': {
-            'gbar_km': 60
-        },
-        'ca': {
-            'gbar_ca': lambda x: 3e-3 * x
-        }
+        'km': {'gbar_km': 60},
+        'ca': {'gbar_ca': lambda x: 3e-3 * x},
     }
     cell = Cell(name, pos, sections, synapses, sect_loc, cell_tree)
@@ -82,8 +70,7 @@ def test_cell():
     cell.build()
     assert 'soma' in cell._nrn_sections
     assert cell._nrn_sections['soma'].L == sections['soma'].L
-    assert cell._nrn_sections['soma'].gbar_km == sections[
-        'soma'].mechs['km']['gbar_km']
+    assert cell._nrn_sections['soma'].gbar_km == sections['soma'].mechs['km']['gbar_km']
     # test building cell with a dipole oriented to a nonexitent section
     with pytest.raises(ValueError, match='sec_name_apical must be an'):
         cell.build(sec_name_apical='blah')
@@ -110,10 +97,10 @@ def test_cell():
     cell1 = pyramidal(cell_name='L5Pyr')

     # Test other not NotImplemented for Cell Class
-    assert (cell1 == "cell") is False
+    assert (cell1 == 'cell') is False

     # Test other not NotImplemented for Section Class
-    assert (cell1.sections['soma'] == "section") is False
+    assert (cell1.sections['soma'] == 'section') is False

     end_pts_original = list()
     end_pts_new = list()
@@ -131,8 +118,7 @@ def test_cell():
     cell1.plot_morphology(show=True)
     for end_pt_original, end_pt_new in zip(end_pts_original, end_pts_new):
         for pt_original, pt_new in zip(end_pt_original, end_pt_new):
-            np.testing.assert_almost_equal(list(np.array(pt_original) * 2),
-                                           pt_new, 5)
+            np.testing.assert_almost_equal(list(np.array(pt_original) * 2), pt_new, 5)

     for sec_name in cell1.sections.keys():
         section = cell1.sections[sec_name]
@@ -182,17 +168,17 @@ def test_artificial_cell():
     # GID is assigned exactly once for each cell, either at initialisation...
     cell = _ArtificialCell(event_times, threshold, gid=42)
     assert cell.gid == 42
-    with pytest.raises(RuntimeError,
-                       match='Global ID for this cell already assigned!'):
+    with pytest.raises(RuntimeError, match='Global ID for this cell already assigned!'):
         cell.gid += 1
-    with pytest.raises(ValueError,
-                       match='gid must be an integer'):
+    with pytest.raises(ValueError, match='gid must be an integer'):
         cell.gid = [1]
     # ... or later
     cell = _ArtificialCell(event_times, threshold)  # fine without gid
     assert cell.gid is None  # check that it's initialised to None
     cell.gid = 42
     assert cell.gid == 42
-    with pytest.raises(ValueError,  # test init checks gid
-                       match='gid must be an integer'):
+    with pytest.raises(
+        ValueError,  # test init checks gid
+        match='gid must be an integer',
+    ):
         cell = _ArtificialCell(event_times, threshold, gid='one')
diff --git a/hnn_core/tests/test_cell_response.py b/hnn_core/tests/test_cell_response.py
index 32eb61ea4..488e298e4 100644
--- a/hnn_core/tests/test_cell_response.py
+++ b/hnn_core/tests/test_cell_response.py
@@ -14,16 +14,21 @@ def test_cell_response(tmp_path):
     # Round-trip test
     spike_times = [[2.3456, 7.89], [4.2812, 93.2]]
     spike_gids = [[1, 3], [5, 7]]
-    spike_types = [['L2_pyramidal', 'L2_basket'],
-                   ['L5_pyramidal', 'L5_basket']]
-    tstart, tstop, fs = 0.1, 98.4, 1000.
+    spike_types = [['L2_pyramidal', 'L2_basket'], ['L5_pyramidal', 'L5_basket']]
+    tstart, tstop, fs = 0.1, 98.4, 1000.0
     sim_times = np.arange(tstart, tstop, 1 / fs)
-    gid_ranges = {'L2_pyramidal': range(1, 2), 'L2_basket': range(3, 4),
-                  'L5_pyramidal': range(5, 6), 'L5_basket': range(7, 8)}
-    cell_response = CellResponse(spike_times=spike_times,
-                                 spike_gids=spike_gids,
-                                 spike_types=spike_types,
-                                 times=sim_times)
+    gid_ranges = {
+        'L2_pyramidal': range(1, 2),
+        'L2_basket': range(3, 4),
+        'L5_pyramidal': range(5, 6),
+        'L5_basket': range(7, 8),
+    }
+    cell_response = CellResponse(
+        spike_times=spike_times,
+        spike_gids=spike_gids,
+        spike_types=spike_types,
+        times=sim_times,
+    )

     assert set(cell_response.cell_types) == set(gid_ranges.keys())
     assert cell_response.spike_times_by_type['L2_basket'] == [[7.89], []]
@@ -31,97 +36,115 @@ def test_cell_response(tmp_path):
     kwargs_hist = dict(alpha=0.25)
     fig = cell_response.plot_spikes_hist(show=False, **kwargs_hist)
-    assert all(patch.get_alpha() == kwargs_hist['alpha']
-               for patch in fig.axes[0].patches
-               ), "Alpha value not applied to all patches"
+    assert all(
+        patch.get_alpha() == kwargs_hist['alpha'] for patch in fig.axes[0].patches
+    ), 'Alpha value not applied to all patches'

     # Testing writing using txt files
-    with pytest.warns(DeprecationWarning,
-                      match="Writing cell response to txt files is "
-                      "deprecated"):
+    with pytest.warns(
+        DeprecationWarning, match='Writing cell response to txt files is ' 'deprecated'
+    ):
         cell_response.write(tmp_path / 'spk_%d.txt')

     # Testing reading from txt files
-    with pytest.warns(DeprecationWarning,
-                      match="Reading cell response from txt files is "
-                      "deprecated"):
+    with pytest.warns(
+        DeprecationWarning,
+        match='Reading cell response from txt files is ' 'deprecated',
+    ):
         assert cell_response == read_spikes(tmp_path / 'spk_*.txt')

-    assert ("CellResponse | 2 simulation trials" in repr(cell_response))
+    assert 'CellResponse | 2 simulation trials' in repr(cell_response)

     # reset clears all recorded variables, but leaves simulation time intact
     assert len(cell_response.times) == len(sim_times)
-    sim_attributes = ['_spike_times', '_spike_gids', '_spike_types',
-                      '_vsec', '_isec', '_ca']
+    sim_attributes = [
+        '_spike_times',
+        '_spike_gids',
+        '_spike_types',
+        '_vsec',
+        '_isec',
+        '_ca',
+    ]
     net_attributes = ['_times', '_cell_type_names']  # `Network.__init__`
     # creates these check that we always know which response attributes are
     # simulated see #291 for discussion; objective is to keep cell_response
     # size small
-    assert list(cell_response.__dict__.keys()) == \
-        sim_attributes + net_attributes
+    assert list(cell_response.__dict__.keys()) == sim_attributes + net_attributes

     # Test recovery of empty spike files
-    empty_spike = CellResponse(spike_times=[[], []], spike_gids=[[], []],
-                               spike_types=[[], []])
+    empty_spike = CellResponse(
+        spike_times=[[], []], spike_gids=[[], []], spike_types=[[], []]
+    )
     empty_spike.write(tmp_path / 'empty_spk_%d.txt')
     empty_spike.write(tmp_path / 'empty_spk.txt')
     empty_spike.write(tmp_path / 'empty_spk_{0}.txt')
     assert empty_spike == read_spikes(tmp_path / 'empty_spk_*.txt')
-    assert ("CellResponse | 2 simulation trials" in repr(empty_spike))
-
-    with pytest.raises(TypeError,
-                       match="spike_times should be a list of lists"):
-        cell_response = CellResponse(spike_times=([2.3456, 7.89],
-                                                  [4.2812, 93.2]),
-                                     spike_gids=spike_gids,
-                                     spike_types=spike_types)
-
-    with pytest.raises(TypeError,
-                       match="spike_times should be a list of lists"):
-        cell_response = CellResponse(spike_times=[1, 2], spike_gids=spike_gids,
-                                     spike_types=spike_types)
-
-    with pytest.raises(ValueError, match="spike times, gids, and types should "
-                       "be lists of the same length"):
-        cell_response = CellResponse(spike_times=[[2.3456, 7.89]],
-                                     spike_gids=spike_gids,
-                                     spike_types=spike_types)
-
-    cell_response = CellResponse(spike_times=spike_times,
-                                 spike_gids=spike_gids,
-                                 spike_types=spike_types)
-
-    with pytest.raises(TypeError, match="spike_types should be str, "
-                       "list, dict, or None"):
+    assert 'CellResponse | 2 simulation trials' in repr(empty_spike)
+
+    with pytest.raises(TypeError, match='spike_times should be a list of lists'):
+        cell_response = CellResponse(
+            spike_times=([2.3456, 7.89], [4.2812, 93.2]),
+            spike_gids=spike_gids,
+            spike_types=spike_types,
+        )
+
+    with pytest.raises(TypeError, match='spike_times should be a list of lists'):
+        cell_response = CellResponse(
+            spike_times=[1, 2], spike_gids=spike_gids, spike_types=spike_types
+        )
+
+    with pytest.raises(
+        ValueError,
+        match='spike times, gids, and types should ' 'be lists of the same length',
+    ):
+        cell_response = CellResponse(
+            spike_times=[[2.3456, 7.89]], spike_gids=spike_gids, spike_types=spike_types
+        )
+
+    cell_response = CellResponse(
+        spike_times=spike_times, spike_gids=spike_gids, spike_types=spike_types
+    )
+
+    with pytest.raises(
+        TypeError, match='spike_types should be str, ' 'list, dict, or None'
+    ):
         cell_response.plot_spikes_hist(spike_types=1, show=False)

-    with pytest.raises(TypeError, match=r"spike_types\[ev\] must be a list\. "
-                       r"Got int\."):
+    with pytest.raises(
+        TypeError, match=r'spike_types\[ev\] must be a list\. ' r'Got int\.'
+    ):
         cell_response.plot_spikes_hist(spike_types={'ev': 1}, show=False)

-    with pytest.raises(ValueError, match=r"Elements of spike_types must map to"
-                       r" mutually exclusive input types\. L2_basket is found"
-                       r" more than once\."):
-        cell_response.plot_spikes_hist(spike_types={'ev':
-                                                    ['L2_basket', 'L2_b']},
-                                       show=False)
-
-    with pytest.raises(ValueError, match="No input types found for ABC"):
+    with pytest.raises(
+        ValueError,
+        match=r'Elements of spike_types must map to'
+        r' mutually exclusive input types\. L2_basket is found'
+        r' more than once\.',
+    ):
+        cell_response.plot_spikes_hist(
+            spike_types={'ev': ['L2_basket', 'L2_b']}, show=False
+        )
+
+    with pytest.raises(ValueError, match='No input types found for ABC'):
         cell_response.plot_spikes_hist(spike_types='ABC', show=False)

-    with pytest.raises(ValueError, match="tstart and tstop must be of type "
-                       "int or float"):
-        cell_response.mean_rates(tstart=0.1, tstop='ABC',
-                                 gid_ranges=gid_ranges)
+    with pytest.raises(
+        ValueError, match='tstart and tstop must be of type ' 'int or float'
+    ):
+        cell_response.mean_rates(tstart=0.1, tstop='ABC', gid_ranges=gid_ranges)

-    with pytest.raises(ValueError, match="tstop must be greater than tstart"):
+    with pytest.raises(ValueError, match='tstop must be greater than tstart'):
         cell_response.mean_rates(tstart=0.1, tstop=-1.0, gid_ranges=gid_ranges)

-    with pytest.raises(ValueError, match="Invalid mean_type. Valid "
-                       "arguments include 'all', 'trial', or 'cell'."):
-        cell_response.mean_rates(tstart=tstart, tstop=tstop,
-                                 gid_ranges=gid_ranges, mean_type='ABC')
+    with pytest.raises(
+        ValueError,
+        match='Invalid mean_type. Valid '
+        "arguments include 'all', 'trial', or 'cell'.",
+    ):
+        cell_response.mean_rates(
+            tstart=tstart, tstop=tstop, gid_ranges=gid_ranges, mean_type='ABC'
+        )

     test_rate = (1 / (tstop - tstart)) * 1000

@@ -129,19 +152,20 @@ def test_cell_response(tmp_path):
         'L5_pyramidal': test_rate / 2,
         'L5_basket': test_rate / 2,
         'L2_pyramidal': test_rate / 2,
-        'L2_basket': test_rate / 2}
-    assert cell_response.mean_rates(tstart, tstop, gid_ranges,
-                                    mean_type='trial') == {
+        'L2_basket': test_rate / 2,
+    }
+    assert cell_response.mean_rates(tstart, tstop, gid_ranges, mean_type='trial') == {
         'L5_pyramidal': [0.0, test_rate],
         'L5_basket': [0.0, test_rate],
         'L2_pyramidal': [test_rate, 0.0],
-        'L2_basket': [test_rate, 0.0]}
-    assert cell_response.mean_rates(tstart, tstop, gid_ranges,
-                                    mean_type='cell') == {
+        'L2_basket': [test_rate, 0.0],
+    }
+    assert cell_response.mean_rates(tstart, tstop, gid_ranges, mean_type='cell') == {
         'L5_pyramidal': [[0.0], [test_rate]],
         'L5_basket': [[0.0], [test_rate]],
         'L2_pyramidal': [[test_rate], [0.0]],
-        'L2_basket': [[test_rate], [0.0]]}
+        'L2_basket': [[test_rate], [0.0]],
+    }

     # Write spike file with no 'types' column
     for fname in sorted(glob(str(tmp_path / 'spk_*.txt'))):
@@ -153,13 +177,21 @@ def test_cell_response(tmp_path):
     assert cell_response.spike_types == spike_types

     # Check for gid_ranges errors
-    with pytest.raises(ValueError, match="gid_ranges must be provided if "
-                       "spike types are unspecified in the file "):
+    with pytest.raises(
+        ValueError,
+        match='gid_ranges must be provided if '
+        'spike types are unspecified in the file ',
+    ):
         cell_response = read_spikes(tmp_path / 'spk_*.txt')
-    with pytest.raises(ValueError, match="gid_ranges should contain only "
-                       "disjoint sets of gid values"):
-        gid_ranges = {'L2_pyramidal': range(3), 'L2_basket': range(2, 4),
-                      'L5_pyramidal': range(4, 6), 'L5_basket': range(6, 8)}
-        cell_response = read_spikes(tmp_path / 'spk_*.txt',
-                                    gid_ranges=gid_ranges)
+    with pytest.raises(
+        ValueError,
+        match='gid_ranges should contain only ' 'disjoint sets of gid values',
+    ):
+        gid_ranges = {
+            'L2_pyramidal': range(3),
+            'L2_basket': range(2, 4),
+            'L5_pyramidal': range(4, 6),
+            'L5_basket': range(6, 8),
+        }
+        cell_response = read_spikes(tmp_path / 'spk_*.txt', gid_ranges=gid_ranges)
     plt.close('all')
diff --git a/hnn_core/tests/test_cells_default.py b/hnn_core/tests/test_cells_default.py
index 39cd335d1..021730fd5 100644
--- a/hnn_core/tests/test_cells_default.py
+++ b/hnn_core/tests/test_cells_default.py
@@ -23,15 +23,21 @@ def test_cells_default():
     # specified in get_L5Pyr_params_default (or overridden in a params file).
     # Note that the lengths implied by _secs_L5Pyr are completely ignored:
     # NEURON extends the sections as needed to match the sec.L 's
-    vertical_secs = ['basal_1', 'soma', 'apical_trunk', 'apical_1', 'apical_2',
-                     'apical_tuft']
+    vertical_secs = [
+        'basal_1',
+        'soma',
+        'apical_trunk',
+        'apical_1',
+        'apical_2',
+        'apical_tuft',
+    ]
     for sec_name in vertical_secs:
         sec = l5p._nrn_sections[sec_name]
         vert_len = np.abs(sec.z3d(1) - sec.z3d(0))
         assert np.allclose(vert_len, sec.L)

     # smoke test to check if cell can be used in simulation
-    h.load_file("stdrun.hoc")
+    h.load_file('stdrun.hoc')
     h.tstop = 40
     h.dt = 0.025
     h.celsius = 37
@@ -42,8 +48,8 @@ def test_cells_default():
     stim = h.IClamp(l5p._nrn_sections['soma'](0.5))
     stim.delay = 5
-    stim.dur = 5.
-    stim.amp = 2.
+    stim.dur = 5.0
+    stim.amp = 2.0

     h.finitialize()
     h.fcurrent()
diff --git a/hnn_core/tests/test_dipole.py b/hnn_core/tests/test_dipole.py
index 9beea6b9d..f665c161d 100644
--- a/hnn_core/tests/test_dipole.py
+++ b/hnn_core/tests/test_dipole.py
@@ -38,9 +38,13 @@ def test_dipole(tmp_path, run_hnn_core_fixture):
     dipole.smooth(window_len=params['dipole_smooth_win'])
     with pytest.raises(AssertionError):
         assert_allclose(dipole.data['agg'], dipole_raw.data['agg'])
-    assert_allclose(dipole.data['agg'],
-                    (params['dipole_scalefctr'] * dipole_raw.smooth(
-                        params['dipole_smooth_win']).data['agg']))
+    assert_allclose(
+        dipole.data['agg'],
+        (
+            params['dipole_scalefctr']
+            * dipole_raw.smooth(params['dipole_smooth_win']).data['agg']
+        ),
+    )

     dipole.plot(show=False)
     plot_dipole([dipole, dipole], show=False)
@@ -48,78 +52,78 @@ def test_dipole(tmp_path, run_hnn_core_fixture):
     # Test wrong argument to plot_dipole()
     with pytest.raises(TypeError, match='dpl must be an instance of'):
         plot_dipole([dipole, 10], show=False)
-    with pytest.raises(AttributeError, match="'numpy.ndarray' object has no"
-                       " attribute 'append'"):
+    with pytest.raises(
+        AttributeError, match="'numpy.ndarray' object has no" " attribute 'append'"
+    ):
         plot_dipole(np.array([dipole, dipole]), average=True, show=False)

     # Test IO for txt files
-    with pytest.warns(DeprecationWarning,
-                      match="Writing dipole to txt file is "
-                      "deprecated"):
+    with pytest.warns(
+        DeprecationWarning, match='Writing dipole to txt file is ' 'deprecated'
+    ):
         dipole.write(dpl_out_fname)
     dipole_read = read_dipole(dpl_out_fname)
     assert_allclose(dipole_read.times, dipole.times, rtol=0, atol=0.00051)
     for dpl_key in dipole.data.keys():
-        assert_allclose(dipole_read.data[dpl_key],
-                        dipole.data[dpl_key], rtol=0, atol=0.000051)
+        assert_allclose(
+            dipole_read.data[dpl_key], dipole.data[dpl_key], rtol=0, atol=0.000051
+        )

     # Test IO for hdf5 files
     dipole.write(dpl_out_hdf5_fname)
     dipole_read_hdf5 = read_dipole(dpl_out_hdf5_fname)
     assert_allclose(dipole_read_hdf5.times, dipole.times, rtol=0.00051, atol=0)
     for dpl_key in dipole.data.keys():
-        assert_allclose(dipole_read_hdf5.data[dpl_key],
-                        dipole.data[dpl_key], rtol=0.000051, atol=0)
+        assert_allclose(
+            dipole_read_hdf5.data[dpl_key], dipole.data[dpl_key], rtol=0.000051, atol=0
+        )

     # Test read for hdf5 files using Path object
     dipole_read_hdf5 = read_dipole(Path(dpl_out_hdf5_fname))
     assert_allclose(dipole_read_hdf5.times, dipole.times, rtol=0.00051, atol=0)
     for dpl_key in dipole.data.keys():
-        assert_allclose(dipole_read_hdf5.data[dpl_key],
-                        dipole.data[dpl_key], rtol=0.000051, atol=0)
+        assert_allclose(
+            dipole_read_hdf5.data[dpl_key], dipole.data[dpl_key], rtol=0.000051, atol=0
+        )

     # Testing when overwrite is False and same filename is used
-    with pytest.raises(FileExistsError,
-                       match="File already exists at path "):
+    with pytest.raises(FileExistsError, match='File already exists at path '):
         dipole.write(dpl_out_hdf5_fname, overwrite=False)

     # Testing for wrong extension provided
-    with pytest.raises(NameError,
-                       match="File extension should be either txt or hdf5"):
+    with pytest.raises(NameError, match='File extension should be either txt or hdf5'):
         dipole.write(tmp_path / 'dpl.xls')

     # Testing File Not Found Error
-    with pytest.raises(FileNotFoundError,
-                       match="File not found at "):
+    with pytest.raises(FileNotFoundError, match='File not found at '):
         read_dipole(tmp_path / 'dpl1.hdf5')

     # dpls with different scale_applied should not be averaged.
-    with pytest.raises(RuntimeError,
-                       match="All dipoles must be scaled equally"):
+    with pytest.raises(RuntimeError, match='All dipoles must be scaled equally'):
         dipole_avg = average_dipoles([dipole, dipole_read])

     # Checking object type field not exists error
     dummy_data = dict()
-    dummy_data['objective'] = "Check Object type errors"
+    dummy_data['objective'] = 'Check Object type errors'
     write_hdf5(tmp_path / 'not_dpl.hdf5', dummy_data)
-    with pytest.raises(NameError,
-                       match="The given file is not compatible."):
+    with pytest.raises(NameError, match='The given file is not compatible.'):
         read_dipole(tmp_path / 'not_dpl.hdf5')

     # Checking wrong object type error
-    dummy_data['object_type'] = "dpl"
+    dummy_data['object_type'] = 'dpl'
     write_hdf5(tmp_path / 'not_dpl.hdf5', dummy_data, overwrite=True)
-    with pytest.raises(ValueError,
-                       match="The object should be of type Dipole."):
+    with pytest.raises(ValueError, match='The object should be of type Dipole.'):
         read_dipole(tmp_path / 'not_dpl.hdf5')

     # force the scale_applied to be identical across dpls to allow averaging.
     dipole.scale_applied = dipole_read.scale_applied

     # average two identical dipole objects
     dipole_avg = average_dipoles([dipole, dipole_read])
     for dpl_key in dipole_avg.data.keys():
-        assert_allclose(dipole_read.data[dpl_key],
-                        dipole_avg.data[dpl_key], rtol=0, atol=0.000051)
+        assert_allclose(
+            dipole_read.data[dpl_key], dipole_avg.data[dpl_key], rtol=0, atol=0.000051
+        )

-    with pytest.raises(ValueError, match="Dipole at index 0 was already an "
-                       "average of 2 trials"):
+    with pytest.raises(
+        ValueError, match='Dipole at index 0 was already an ' 'average of 2 trials'
+    ):
         dipole_avg = average_dipoles([dipole_avg, dipole_read])

     # average an n_of_1 dipole list
@@ -129,41 +133,56 @@ def test_dipole(tmp_path, run_hnn_core_fixture):
             dipole_read.data[dpl_key],
             single_dpl_avg.data[dpl_key],
             rtol=0,
-            atol=0.000051)
+            atol=0.000051,
+        )

     # average dipole list with one dipole object and a zero dipole object
     n_times = len(dipole_read.data['agg'])
-    dpl_null = Dipole(np.zeros(n_times, ), np.zeros((n_times, 3)))
+    dpl_null = Dipole(
+        np.zeros(
+            n_times,
+        ),
+        np.zeros((n_times, 3)),
+    )
     dpl_1 = [dipole, dpl_null]
     dpl_avg = average_dipoles(dpl_1)
     for dpl_key in dpl_avg.data.keys():
-        assert_allclose(dpl_1[0].data[dpl_key] / 2., dpl_avg.data[dpl_key])
+        assert_allclose(dpl_1[0].data[dpl_key] / 2.0, dpl_avg.data[dpl_key])

     # Test experimental dipole
     dipole_exp = Dipole(times, data[:, 1])
     dipole_exp.write(dpl_out_fname)
     dipole_exp_read = read_dipole(dpl_out_fname)
-    assert_allclose(dipole_exp.data['agg'], dipole_exp_read.data['agg'],
-                    rtol=1e-2)
+    assert_allclose(dipole_exp.data['agg'], dipole_exp_read.data['agg'], rtol=1e-2)
     dipole_exp_avg = average_dipoles([dipole_exp, dipole_exp])
     assert_allclose(dipole_exp.data['agg'], dipole_exp_avg.data['agg'])

     # XXX all below to be deprecated in 0.3
-    dpls_raw, net = run_hnn_core_fixture(backend='joblib', n_jobs=1,
-                                         reduced=True, record_isec='soma',
-                                         record_vsec='soma', record_ca='soma')
+    dpls_raw, net = run_hnn_core_fixture(
+        backend='joblib',
+        n_jobs=1,
+        reduced=True,
+        record_isec='soma',
+        record_vsec='soma',
+        record_ca='soma',
+    )
     # test deprecation of postproc
-    with pytest.warns(DeprecationWarning,
-                      match='The postproc-argument is deprecated'):
-        dpls, _ = run_hnn_core_fixture(backend='joblib', n_jobs=1,
-                                       reduced=True, record_isec='soma',
-                                       record_vsec='soma', record_ca='soma',
-                                       postproc=True)
+    with pytest.warns(DeprecationWarning, match='The postproc-argument is deprecated'):
+        dpls, _ = run_hnn_core_fixture(
+            backend='joblib',
+            n_jobs=1,
+            reduced=True,
+            record_isec='soma',
+            record_vsec='soma',
+            record_ca='soma',
+            postproc=True,
+        )
     with pytest.raises(AssertionError):
         assert_allclose(dpls[0].data['agg'], dpls_raw[0].data['agg'])

-    dpls_raw[0]._post_proc(net._params['dipole_smooth_win'],
-                           net._params['dipole_scalefctr'])
+    dpls_raw[0]._post_proc(
+        net._params['dipole_smooth_win'], net._params['dipole_scalefctr']
+    )
     assert_allclose(dpls_raw[0].data['agg'], dpls[0].data['agg'])

     plt.close('all')
@@ -174,25 +193,30 @@ def test_dipole_simulation():
     hnn_core_root = op.dirname(hnn_core.__file__)
     params_fname = op.join(hnn_core_root, 'param', 'default.json')
     params = read_params(params_fname)
-    params.update({'dipole_smooth_win': 5,
-                   't_evprox_1': 5,
-                   't_evdist_1': 10,
-                   't_evprox_2': 20})
-    net = jones_2009_model(params, add_drives_from_params=True,
-                           mesh_shape=(3, 3))
-    with pytest.raises(ValueError, match="Invalid number of simulations: 0"):
-        simulate_dipole(net, tstop=25., n_trials=0)
-    with pytest.raises(ValueError, match="Invalid value for the"):
value for the"): - simulate_dipole(net, tstop=25., n_trials=1, record_vsec='abc') - with pytest.raises(ValueError, match="Invalid value for the"): - simulate_dipole(net, tstop=25., n_trials=1, record_vsec=False, - record_isec='abc') - with pytest.raises(ValueError, match="Invalid value for the"): - simulate_dipole(net, tstop=25., n_trials=1, record_vsec=False, - record_isec=False, record_ca='abc') + params.update( + {'dipole_smooth_win': 5, 't_evprox_1': 5, 't_evdist_1': 10, 't_evprox_2': 20} + ) + net = jones_2009_model(params, add_drives_from_params=True, mesh_shape=(3, 3)) + with pytest.raises(ValueError, match='Invalid number of simulations: 0'): + simulate_dipole(net, tstop=25.0, n_trials=0) + with pytest.raises(ValueError, match='Invalid value for the'): + simulate_dipole(net, tstop=25.0, n_trials=1, record_vsec='abc') + with pytest.raises(ValueError, match='Invalid value for the'): + simulate_dipole( + net, tstop=25.0, n_trials=1, record_vsec=False, record_isec='abc' + ) + with pytest.raises(ValueError, match='Invalid value for the'): + simulate_dipole( + net, + tstop=25.0, + n_trials=1, + record_vsec=False, + record_isec=False, + record_ca='abc', + ) # test Network.copy() returns 'bare' network after simulating - dpl = simulate_dipole(net, tstop=25., n_trials=1)[0] + dpl = simulate_dipole(net, tstop=25.0, n_trials=1)[0] assert net._dt == 0.025 assert net._tstop == 25.0 net_copy = net.copy() @@ -218,12 +242,22 @@ def test_cell_response_backends(run_hnn_core_fixture): # reduced simulation has n_trials=2 trial_idx, n_trials, gid = 0, 2, 7 - _, joblib_net = run_hnn_core_fixture(backend='joblib', n_jobs=1, - reduced=True, record_vsec='all', - record_isec='soma', record_ca='all') - _, mpi_net = run_hnn_core_fixture(backend='mpi', n_procs=2, reduced=True, - record_vsec='all', record_isec='soma', - record_ca='all') + _, joblib_net = run_hnn_core_fixture( + backend='joblib', + n_jobs=1, + reduced=True, + record_vsec='all', + record_isec='soma', + record_ca='all', + ) + _, mpi_net = run_hnn_core_fixture( + backend='mpi', + n_procs=2, + reduced=True, + record_vsec='all', + record_isec='soma', + record_ca='all', + ) n_times = len(joblib_net.cell_response.times) @@ -232,31 +266,29 @@ def test_cell_response_backends(run_hnn_core_fixture): assert len(joblib_net.cell_response.ca) == n_trials assert len(joblib_net.cell_response.vsec[trial_idx][gid]) == 8 # num sec assert len(joblib_net.cell_response.isec[trial_idx][gid]) == 1 - assert len(joblib_net.cell_response.vsec[ - trial_idx][gid]['apical_1']) == n_times - assert len(joblib_net.cell_response.isec[ - trial_idx][gid]['soma']['soma_gabaa']) == n_times + assert len(joblib_net.cell_response.vsec[trial_idx][gid]['apical_1']) == n_times + assert ( + len(joblib_net.cell_response.isec[trial_idx][gid]['soma']['soma_gabaa']) + == n_times + ) assert len(mpi_net.cell_response.vsec) == n_trials assert len(mpi_net.cell_response.isec) == n_trials assert len(mpi_net.cell_response.vsec[trial_idx][gid]) == 8 # num sec assert len(mpi_net.cell_response.isec[trial_idx][gid]) == 1 - assert len(mpi_net.cell_response.vsec[ - trial_idx][gid]['apical_1']) == n_times - assert len(mpi_net.cell_response.isec[ - trial_idx][gid]['soma']['soma_gabaa']) == n_times + assert len(mpi_net.cell_response.vsec[trial_idx][gid]['apical_1']) == n_times + assert ( + len(mpi_net.cell_response.isec[trial_idx][gid]['soma']['soma_gabaa']) == n_times + ) assert mpi_net.cell_response.vsec == joblib_net.cell_response.vsec assert mpi_net.cell_response.isec == joblib_net.cell_response.isec # 
     gid = joblib_net.gid_ranges['L5_pyramidal'][0]
     assert len(joblib_net.cell_response.ca[trial_idx][gid]) == 9
-    assert len(joblib_net.cell_response.ca[
-        trial_idx][gid]['soma']) == n_times
+    assert len(joblib_net.cell_response.ca[trial_idx][gid]['soma']) == n_times
     assert len(mpi_net.cell_response.ca[trial_idx][gid]) == 9
-    assert len(mpi_net.cell_response.ca[
-        trial_idx][gid]['soma']) == n_times
-    assert len(mpi_net.cell_response.ca[
-        trial_idx][gid]['apical_1']) == n_times
+    assert len(mpi_net.cell_response.ca[trial_idx][gid]['soma']) == n_times
+    assert len(mpi_net.cell_response.ca[trial_idx][gid]['apical_1']) == n_times

     # Test if spike time falls within depolarization window above v_thresh
     v_thresh = 0.0
     times = np.array(joblib_net.cell_response.times)
@@ -265,38 +297,52 @@ def test_cell_response_backends(run_hnn_core_fixture):
     vsoma = np.array(joblib_net.cell_response.vsec[trial_idx][gid]['soma'])

     v_mask = vsoma > v_thresh
-    assert np.all([spike_times[spike_gids == gid] > times[v_mask][0],
-                   spike_times[spike_gids == gid] < times[v_mask][-1]])
+    assert np.all(
+        [
+            spike_times[spike_gids == gid] > times[v_mask][0],
+            spike_times[spike_gids == gid] < times[v_mask][-1],
+        ]
+    )

     # test that event times before and after simulation are the same
     for drive_name, drive in joblib_net.external_drives.items():
         gid_ran = joblib_net.gid_ranges[drive_name]
         for idx_drive, event_times in enumerate(drive['events'][trial_idx]):
-            net_ets = [spike_times[i] for i, g in enumerate(spike_gids) if
-                       g == gid_ran[idx_drive]]
+            net_ets = [
+                spike_times[i]
+                for i, g in enumerate(spike_gids)
+                if g == gid_ran[idx_drive]
+            ]
             assert_allclose(np.array(event_times), np.array(net_ets))


 def test_rmse():
     """Test to check RMSE calculation"""
-    data_url = ('https://raw.githubusercontent.com/jonescompneurolab/hnn/'
-                'master/data/MEG_detection_data/yes_trial_S1_ERP_all_avg.txt')
+    data_url = (
+        'https://raw.githubusercontent.com/jonescompneurolab/hnn/'
+        'master/data/MEG_detection_data/yes_trial_S1_ERP_all_avg.txt'
+    )
     if not op.exists('yes_trial_S1_ERP_all_avg.txt'):
         urlretrieve(data_url, 'yes_trial_S1_ERP_all_avg.txt')
     extdata = np.loadtxt('yes_trial_S1_ERP_all_avg.txt')
-    exp_dpl = Dipole(times=extdata[:, 0],
-                     data=np.c_[extdata[:, 1], extdata[:, 1], extdata[:, 1]])
+    exp_dpl = Dipole(
+        times=extdata[:, 0], data=np.c_[extdata[:, 1], extdata[:, 1], extdata[:, 1]]
+    )

     hnn_core_root = op.join(op.dirname(hnn_core.__file__))
     params_fname = op.join(hnn_core_root, 'param', 'default.json')
     params = read_params(params_fname)

     expected_rmse = 0.1
-    test_dpl = Dipole(times=extdata[:, 0],
-                      data=np.c_[extdata[:, 1] + expected_rmse,
-                                 extdata[:, 1] + expected_rmse,
-                                 extdata[:, 1] + expected_rmse])
+    test_dpl = Dipole(
+        times=extdata[:, 0],
+        data=np.c_[
+            extdata[:, 1] + expected_rmse,
+            extdata[:, 1] + expected_rmse,
+            extdata[:, 1] + expected_rmse,
+        ],
+    )
     avg_rmse = _rmse(test_dpl, exp_dpl, tstop=params['tstop'])
     assert_allclose(avg_rmse, expected_rmse)
diff --git a/hnn_core/tests/test_drives.py b/hnn_core/tests/test_drives.py
index 98efab2b2..c9f789b7a 100644
--- a/hnn_core/tests/test_drives.py
+++ b/hnn_core/tests/test_drives.py
@@ -8,11 +8,16 @@

 import hnn_core
 from hnn_core import Network, read_params
-from hnn_core.drives import (_drive_cell_event_times, _get_prng,
-                             _create_extpois, _create_bursty_input)
+from hnn_core.drives import (
+    _drive_cell_event_times,
+    _get_prng,
+    _create_extpois,
+    _create_bursty_input,
+)
 from hnn_core.network import pick_connection
import pick_connection from hnn_core.network_models import jones_2009_model from hnn_core import simulate_dipole + hnn_core_root = op.dirname(hnn_core.__file__) @@ -32,34 +37,41 @@ def test_external_drive_times(): drive_type = 'invalid_drive' dynamics = dict(mu=5, sigma=0.5, numspikes=1) tstop = 10 - pytest.raises(ValueError, _drive_cell_event_times, - 'invalid_drive', dynamics, tstop) - pytest.raises(ValueError, _drive_cell_event_times, - 'ss', dynamics, tstop) # ambiguous + pytest.raises(ValueError, _drive_cell_event_times, 'invalid_drive', dynamics, tstop) + pytest.raises( + ValueError, _drive_cell_event_times, 'ss', dynamics, tstop + ) # ambiguous # validate poisson input time interval drive_type = 'poisson' - dynamics = {'tstart': 0, 'tstop': 250.0, 'rate_constant': - {'L2_basket': 1, 'L2_pyramidal': 140.0, 'L5_basket': 1, - 'L5_pyramidal': 40.0}} + dynamics = { + 'tstart': 0, + 'tstop': 250.0, + 'rate_constant': { + 'L2_basket': 1, + 'L2_pyramidal': 140.0, + 'L5_basket': 1, + 'L5_pyramidal': 40.0, + }, + } with pytest.raises(ValueError, match='The end time for Poisson input'): dynamics['tstop'] = -1 - event_times = _drive_cell_event_times(drive_type=drive_type, - dynamics=dynamics, - tstop=tstop) + event_times = _drive_cell_event_times( + drive_type=drive_type, dynamics=dynamics, tstop=tstop + ) with pytest.raises(ValueError, match='The start time for Poisson'): dynamics['tstop'] = tstop dynamics['tstart'] = -1 - event_times = _drive_cell_event_times(drive_type=drive_type, - dynamics=dynamics, - tstop=tstop) + event_times = _drive_cell_event_times( + drive_type=drive_type, dynamics=dynamics, tstop=tstop + ) # checks the poisson spike train generation prng = np.random.RandomState() - lamtha = 50. + lamtha = 50.0 event_times = _create_extpois(t0=0, T=100000, lamtha=lamtha, prng=prng) event_intervals = np.diff(event_times) - assert pytest.approx(event_intervals.mean(), abs=1.) == 1000 * 1 / lamtha + assert pytest.approx(event_intervals.mean(), abs=1.0) == 1000 * 1 / lamtha with pytest.raises(ValueError, match='The start time for Poisson'): _create_extpois(t0=-5, T=5, lamtha=lamtha, prng=prng) @@ -72,46 +84,70 @@ def test_external_drive_times(): t0 = 0 t0_stdev = 5 tstop = 100 - f_input = 20. + f_input = 20.0 events_per_cycle = 3 cycle_events_isi = 7 - events_jitter_std = 5. 
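The ISI check a few lines up leans on a standard property of homogeneous Poisson processes: inter-event intervals are exponentially distributed with mean 1000 / lamtha ms for a rate constant lamtha in Hz. A minimal standalone sketch of that property in plain NumPy (it assumes, as the test does, that `_create_extpois` draws exponential intervals):

import numpy as np

prng = np.random.RandomState(0)
lamtha = 50.0  # rate constant in Hz
# Exponential inter-spike intervals have mean 1000 / lamtha milliseconds
isi_ms = prng.exponential(scale=1000.0 / lamtha, size=100_000)
# With 1e5 draws the sample mean sits well within 1 ms of the 20 ms target
assert abs(isi_ms.mean() - 1000.0 / lamtha) < 1.0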
+    events_jitter_std = 5.0
     prng, prng2 = _get_prng(seed=0, gid=5)
     event_times = _create_bursty_input(
-        t0=t0, t0_stdev=t0_stdev, tstop=tstop,
-        f_input=f_input, events_jitter_std=events_jitter_std,
-        events_per_cycle=events_per_cycle, cycle_events_isi=cycle_events_isi,
-        prng=prng, prng2=prng2)
+        t0=t0,
+        t0_stdev=t0_stdev,
+        tstop=tstop,
+        f_input=f_input,
+        events_jitter_std=events_jitter_std,
+        events_per_cycle=events_per_cycle,
+        cycle_events_isi=cycle_events_isi,
+        prng=prng,
+        prng2=prng2,
+    )

     events_per_cycle = 5
     cycle_events_isi = 20
-    with pytest.raises(ValueError,
-                       match=r'(?s)Burst duration .* cannot be greater than'):
-        _create_bursty_input(t0=t0, t0_stdev=t0_stdev,
-                             tstop=tstop, f_input=f_input,
-                             events_jitter_std=events_jitter_std,
-                             events_per_cycle=events_per_cycle,
-                             cycle_events_isi=cycle_events_isi,
-                             prng=prng, prng2=prng2)
+    with pytest.raises(
+        ValueError, match=r'(?s)Burst duration .* cannot be greater than'
+    ):
+        _create_bursty_input(
+            t0=t0,
+            t0_stdev=t0_stdev,
+            tstop=tstop,
+            f_input=f_input,
+            events_jitter_std=events_jitter_std,
+            events_per_cycle=events_per_cycle,
+            cycle_events_isi=cycle_events_isi,
+            prng=prng,
+            prng2=prng2,
+        )


 def test_drive_seeds(setup_net):
     """Test that unique spike times are generated across trials"""
     net = setup_net
-    weights_ampa = {'L2_basket': 0.3, 'L2_pyramidal': 0.3,
-                    'L5_basket': 0.3, 'L5_pyramidal': 0.3}
-    synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
-                       'L5_basket': 1., 'L5_pyramidal': 1.}
+    weights_ampa = {
+        'L2_basket': 0.3,
+        'L2_pyramidal': 0.3,
+        'L5_basket': 0.3,
+        'L5_pyramidal': 0.3,
+    }
+    synaptic_delays = {
+        'L2_basket': 0.1,
+        'L2_pyramidal': 0.1,
+        'L5_basket': 1.0,
+        'L5_pyramidal': 1.0,
+    }
     net.add_evoked_drive(
-        'prox', mu=40, sigma=8.33, numspikes=1,
-        weights_ampa=weights_ampa, location='proximal',
-        synaptic_delays=synaptic_delays, event_seed=1)
+        'prox',
+        mu=40,
+        sigma=8.33,
+        numspikes=1,
+        weights_ampa=weights_ampa,
+        location='proximal',
+        synaptic_delays=synaptic_delays,
+        event_seed=1,
+    )
     _ = simulate_dipole(net, tstop=100, dt=0.5, n_trials=2)

-    trial1_spikes = np.array(sorted(
-        net.external_drives['prox']['events'][0]))
-    trial2_spikes = np.array(sorted(
-        net.external_drives['prox']['events'][1]))
+    trial1_spikes = np.array(sorted(net.external_drives['prox']['events'][0]))
+    trial2_spikes = np.array(sorted(net.external_drives['prox']['events'][1]))
     # No two spikes should be perfectly identical across seeds
     assert not np.allclose(trial1_spikes, trial2_spikes)

@@ -120,19 +156,31 @@ def test_clear_drives(setup_net):
     """Test clearing drives updates Network"""
     net = setup_net
     weights_ampa = {'L5_pyramidal': 0.3}
-    synaptic_delays = {'L5_pyramidal': 1.}
+    synaptic_delays = {'L5_pyramidal': 1.0}

     # Test attributes after adding 2 drives
     n_gids = net._n_gids
     net.add_evoked_drive(
-        'prox', mu=40, sigma=8.33, numspikes=1,
-        weights_ampa=weights_ampa, location='proximal',
-        synaptic_delays=synaptic_delays, cell_specific=True)
+        'prox',
+        mu=40,
+        sigma=8.33,
+        numspikes=1,
+        weights_ampa=weights_ampa,
+        location='proximal',
+        synaptic_delays=synaptic_delays,
+        cell_specific=True,
+    )
     net.add_evoked_drive(
-        'dist', mu=40, sigma=8.33, numspikes=1,
-        weights_ampa=weights_ampa, location='distal',
-        synaptic_delays=synaptic_delays, cell_specific=True)
+        'dist',
+        mu=40,
+        sigma=8.33,
+        numspikes=1,
+        weights_ampa=weights_ampa,
+        location='distal',
+        synaptic_delays=synaptic_delays,
+        cell_specific=True,
+    )

     for drive_name in ['prox', 'dist']:
         assert len(net.external_drives) == 2
@@
-152,9 +200,15 @@ def test_clear_drives(setup_net): # Test attributes after adding 1 drive net.add_evoked_drive( - 'prox', mu=40, sigma=8.33, numspikes=1, - weights_ampa=weights_ampa, location='proximal', - synaptic_delays=synaptic_delays, cell_specific=True) + 'prox', + mu=40, + sigma=8.33, + numspikes=1, + weights_ampa=weights_ampa, + location='proximal', + synaptic_delays=synaptic_delays, + cell_specific=True, + ) assert len(net.external_drives) == 1 assert 'prox' in net.external_drives @@ -177,9 +231,13 @@ def test_add_drives(): n_drive_cells = 10 cell_specific = False # default for bursty drive net.add_bursty_drive( - 'bursty', location='distal', burst_rate=10, - weights_ampa=weights_ampa, synaptic_delays=syn_delays, - n_drive_cells=n_drive_cells) + 'bursty', + location='distal', + burst_rate=10, + weights_ampa=weights_ampa, + synaptic_delays=syn_delays, + n_drive_cells=n_drive_cells, + ) assert net.external_drives['bursty']['n_drive_cells'] == n_drive_cells assert net.external_drives['bursty']['cell_specific'] == cell_specific @@ -193,13 +251,18 @@ def test_add_drives(): n_drive_cells = 'n_cells' # default for evoked drive cell_specific = True net.add_evoked_drive( - 'evoked_dist', mu=1.0, sigma=1.0, numspikes=1, - weights_ampa=weights_ampa, location='distal', - synaptic_delays=syn_delays, cell_specific=True) + 'evoked_dist', + mu=1.0, + sigma=1.0, + numspikes=1, + weights_ampa=weights_ampa, + location='distal', + synaptic_delays=syn_delays, + cell_specific=True, + ) n_dist_targets = 235 # 270 with legacy mode - assert (net.external_drives['evoked_dist'] - ['n_drive_cells'] == n_dist_targets) + assert net.external_drives['evoked_dist']['n_drive_cells'] == n_dist_targets assert net.external_drives['evoked_dist']['cell_specific'] == cell_specific conn_idxs = pick_connection(net, src_gids='evoked_dist') for conn_idx in conn_idxs: @@ -211,13 +274,16 @@ def test_add_drives(): n_drive_cells = 'n_cells' # default for poisson drive cell_specific = True net.add_poisson_drive( - 'poisson', rate_constant=1.0, weights_ampa=weights_ampa, - location='distal', synaptic_delays=syn_delays, - cell_specific=cell_specific) + 'poisson', + rate_constant=1.0, + weights_ampa=weights_ampa, + location='distal', + synaptic_delays=syn_delays, + cell_specific=cell_specific, + ) n_dist_targets = 235 # 270 with non-legacy mode - assert (net.external_drives['poisson'] - ['n_drive_cells'] == n_dist_targets) + assert net.external_drives['poisson']['n_drive_cells'] == n_dist_targets assert net.external_drives['poisson']['cell_specific'] == cell_specific conn_idxs = pick_connection(net, src_gids='poisson') for conn_idx in conn_idxs: @@ -232,21 +298,29 @@ def test_add_drives(): weights_ampa_tuft = {'L2_pyramidal': 1.0, 'L5_pyramidal': 2.0} syn_delays_tuft = {'L2_pyramidal': 1.0, 'L5_pyramidal': 2.0} net.add_bursty_drive( - 'bursty_tuft', location=location, burst_rate=10, - weights_ampa=weights_ampa_tuft, synaptic_delays=syn_delays_tuft, - n_drive_cells=10) + 'bursty_tuft', + location=location, + burst_rate=10, + weights_ampa=weights_ampa_tuft, + synaptic_delays=syn_delays_tuft, + n_drive_cells=10, + ) assert net.connectivity[-1]['loc'] == location # Section not present on cells indicated location = 'apical_tuft' weights_ampa_no_tuft = {'L2_pyramidal': 1.0, 'L5_basket': 2.0} syn_delays_no_tuft = {'L2_pyramidal': 1.0, 'L5_basket': 2.0} - match = ('Invalid value for') + match = 'Invalid value for' with pytest.raises(ValueError, match=match): net.add_bursty_drive( - 'bursty_no_tuft', location=location, burst_rate=10, + 
'bursty_no_tuft', + location=location, + burst_rate=10, weights_ampa=weights_ampa_no_tuft, - synaptic_delays=syn_delays_no_tuft, n_drive_cells=n_drive_cells) + synaptic_delays=syn_delays_no_tuft, + n_drive_cells=n_drive_cells, + ) # Test probabilistic drive connections. # drive with cell_specific=False @@ -254,240 +328,340 @@ def test_add_drives(): probability = 0.5 # test that only half of possible connections are made weights_nmda = {'L2_basket': 1.0, 'L2_pyramidal': 3.0, 'L5_pyramidal': 4.0} net.add_bursty_drive( - 'bursty_prob', location='distal', burst_rate=10, - weights_ampa=weights_ampa, weights_nmda=weights_nmda, - synaptic_delays=syn_delays, n_drive_cells=n_drive_cells, - probability=probability) + 'bursty_prob', + location='distal', + burst_rate=10, + weights_ampa=weights_ampa, + weights_nmda=weights_nmda, + synaptic_delays=syn_delays, + n_drive_cells=n_drive_cells, + probability=probability, + ) for cell_type in weights_ampa.keys(): - conn_idxs = pick_connection( - net, src_gids='bursty_prob', target_gids=cell_type) + conn_idxs = pick_connection(net, src_gids='bursty_prob', target_gids=cell_type) gid_pairs_comparison = net.connectivity[conn_idxs[0]]['gid_pairs'] for conn_idx in conn_idxs: conn = net.connectivity[conn_idx] - num_connections = np.sum( - [len(gids) for gids in conn['gid_pairs'].values()]) + num_connections = np.sum([len(gids) for gids in conn['gid_pairs'].values()]) # Ensures that AMPA and NMDA connections target the same gids. # Necessary when weights of both are non-zero. assert gid_pairs_comparison == conn['gid_pairs'] - assert num_connections == \ - np.around(len(net.gid_ranges[cell_type]) * n_drive_cells * - probability).astype(int) + assert num_connections == np.around( + len(net.gid_ranges[cell_type]) * n_drive_cells * probability + ).astype(int) # drives with cell_specific=True probability = {'L2_basket': 0.1, 'L2_pyramidal': 0.25, 'L5_pyramidal': 0.5} net.add_evoked_drive( - 'evoked_prob', mu=1.0, sigma=1.0, numspikes=1, - weights_ampa=weights_ampa, weights_nmda=weights_nmda, - location='distal', synaptic_delays=syn_delays, cell_specific=True, - probability=probability) + 'evoked_prob', + mu=1.0, + sigma=1.0, + numspikes=1, + weights_ampa=weights_ampa, + weights_nmda=weights_nmda, + location='distal', + synaptic_delays=syn_delays, + cell_specific=True, + probability=probability, + ) for cell_type in weights_ampa.keys(): - conn_idxs = pick_connection( - net, src_gids='evoked_prob', target_gids=cell_type) + conn_idxs = pick_connection(net, src_gids='evoked_prob', target_gids=cell_type) gid_pairs_comparison = net.connectivity[conn_idxs[0]]['gid_pairs'] for conn_idx in conn_idxs: conn = net.connectivity[conn_idx] - num_connections = np.sum( - [len(gids) for gids in conn['gid_pairs'].values()]) + num_connections = np.sum([len(gids) for gids in conn['gid_pairs'].values()]) assert gid_pairs_comparison == conn['gid_pairs'] - assert num_connections == \ - np.around(len(net.gid_ranges[cell_type]) * - probability[cell_type]).astype(int) + assert num_connections == np.around( + len(net.gid_ranges[cell_type]) * probability[cell_type] + ).astype(int) # Test adding just the NMDA weights (no AMPA) net.add_evoked_drive( - 'evoked_nmda', mu=1.0, sigma=1.0, numspikes=1, + 'evoked_nmda', + mu=1.0, + sigma=1.0, + numspikes=1, weights_nmda=weights_nmda, - location='distal', synaptic_delays=syn_delays, cell_specific=True, - probability=probability) + location='distal', + synaptic_delays=syn_delays, + cell_specific=True, + probability=probability, + ) # Round trip test to ensure 
drives API produces a functioning Network simulate_dipole(net, tstop=1) # evoked - with pytest.raises(ValueError, - match='Standard deviation cannot be negative'): - net.add_evoked_drive('evdist1', mu=10, sigma=-1, numspikes=1, - location='distal') - with pytest.raises(ValueError, - match='Number of spikes must be greater than zero'): - net.add_evoked_drive('evdist1', mu=10, sigma=1, numspikes=0, - location='distal') + with pytest.raises(ValueError, match='Standard deviation cannot be negative'): + net.add_evoked_drive('evdist1', mu=10, sigma=-1, numspikes=1, location='distal') + with pytest.raises(ValueError, match='Number of spikes must be greater than zero'): + net.add_evoked_drive('evdist1', mu=10, sigma=1, numspikes=0, location='distal') # Test Network._attach_drive() - with pytest.raises(ValueError, - match='Invalid value for'): - net.add_evoked_drive('evdist1', mu=10, sigma=1, numspikes=1, - location='bogus_location', - weights_ampa={'L5_basket': 1.}, - synaptic_delays={'L5_basket': .1}) - with pytest.raises(ValueError, - match='Drive evoked_dist already defined'): - net.add_evoked_drive('evoked_dist', mu=10, sigma=1, numspikes=1, - location='distal') - with pytest.raises(ValueError, - match='No target cell types have been given a synaptic ' - 'weight'): - net.add_evoked_drive('evdist1', mu=10, sigma=1, numspikes=1, - location='distal') - with pytest.raises(ValueError, - match='Due to physiological/anatomical constraints, ' - 'a distal drive cannot target L5_basket cell types. '): - net.add_evoked_drive('evdist1', mu=10, sigma=1, numspikes=1, - location='distal', weights_ampa={'L5_basket': 1.}, - synaptic_delays={'L5_basket': .1}) - with pytest.raises(ValueError, - match='If cell_specific is True, n_drive_cells'): - net.add_evoked_drive('evdist1', mu=10, sigma=1, numspikes=1, - location='distal', n_drive_cells=10, - cell_specific=True, weights_ampa=weights_ampa, - synaptic_delays=syn_delays) - with pytest.raises(ValueError, - match='If cell_specific is False, n_drive_cells'): - net.add_evoked_drive('evdist1', mu=10, sigma=1, numspikes=1, - location='distal', n_drive_cells='n_cells', - cell_specific=False, weights_ampa=weights_ampa, - synaptic_delays=syn_delays) - with pytest.raises(ValueError, - match='Number of drive cells must be greater than 0'): - net.add_evoked_drive('evdist1', mu=10, sigma=1, numspikes=1, - location='distal', n_drive_cells=0, - cell_specific=False, weights_ampa=weights_ampa, - synaptic_delays=syn_delays) + with pytest.raises(ValueError, match='Invalid value for'): + net.add_evoked_drive( + 'evdist1', + mu=10, + sigma=1, + numspikes=1, + location='bogus_location', + weights_ampa={'L5_basket': 1.0}, + synaptic_delays={'L5_basket': 0.1}, + ) + with pytest.raises(ValueError, match='Drive evoked_dist already defined'): + net.add_evoked_drive( + 'evoked_dist', mu=10, sigma=1, numspikes=1, location='distal' + ) + with pytest.raises( + ValueError, match='No target cell types have been given a synaptic ' 'weight' + ): + net.add_evoked_drive('evdist1', mu=10, sigma=1, numspikes=1, location='distal') + with pytest.raises( + ValueError, + match='Due to physiological/anatomical constraints, ' + 'a distal drive cannot target L5_basket cell types. 
', + ): + net.add_evoked_drive( + 'evdist1', + mu=10, + sigma=1, + numspikes=1, + location='distal', + weights_ampa={'L5_basket': 1.0}, + synaptic_delays={'L5_basket': 0.1}, + ) + with pytest.raises(ValueError, match='If cell_specific is True, n_drive_cells'): + net.add_evoked_drive( + 'evdist1', + mu=10, + sigma=1, + numspikes=1, + location='distal', + n_drive_cells=10, + cell_specific=True, + weights_ampa=weights_ampa, + synaptic_delays=syn_delays, + ) + with pytest.raises(ValueError, match='If cell_specific is False, n_drive_cells'): + net.add_evoked_drive( + 'evdist1', + mu=10, + sigma=1, + numspikes=1, + location='distal', + n_drive_cells='n_cells', + cell_specific=False, + weights_ampa=weights_ampa, + synaptic_delays=syn_delays, + ) + with pytest.raises( + ValueError, match='Number of drive cells must be greater than 0' + ): + net.add_evoked_drive( + 'evdist1', + mu=10, + sigma=1, + numspikes=1, + location='distal', + n_drive_cells=0, + cell_specific=False, + weights_ampa=weights_ampa, + synaptic_delays=syn_delays, + ) # Poisson - with pytest.raises(ValueError, - match='End time of Poisson drive cannot be negative'): - net.add_poisson_drive('poisson1', tstart=0, tstop=-1, - location='distal', rate_constant=10.) - with pytest.raises(ValueError, - match='Start time of Poisson drive cannot be negative'): - net.add_poisson_drive('poisson1', tstart=-1, - location='distal', rate_constant=10.) - with pytest.raises(ValueError, - match='Duration of Poisson drive cannot be negative'): - net.add_poisson_drive('poisson1', tstart=10, tstop=1, - location='distal', rate_constant=10.) - with pytest.raises(ValueError, - match='Rate constant must be positive'): - net.add_poisson_drive('poisson1', location='distal', - rate_constant=0., - weights_ampa=weights_ampa, - synaptic_delays=syn_delays) - - with pytest.raises(ValueError, - match='Rate constants not provided for all target'): - net.add_poisson_drive('poisson1', location='distal', - rate_constant={'L2_pyramidal': 10.}, - weights_ampa=weights_ampa, - synaptic_delays=syn_delays) - with pytest.raises(ValueError, - match='Rate constant provided for unknown target cell'): - net.add_poisson_drive('poisson1', location='distal', - rate_constant={'L2_pyramidal': 10., - 'bogus_celltype': 20.}, - weights_ampa={'L2_pyramidal': .01, - 'bogus_celltype': .01}, - synaptic_delays=0.1) - - with pytest.raises(ValueError, - match='Drives specific to cell types are only ' - 'possible with cell_specific=True'): - net.add_poisson_drive('poisson1', location='distal', - rate_constant={'L2_basket': 10., - 'L2_pyramidal': 11., - 'L5_basket': 12., - 'L5_pyramidal': 13.}, - n_drive_cells=1, cell_specific=False, - weights_ampa=weights_ampa, - synaptic_delays=syn_delays) + with pytest.raises( + ValueError, match='End time of Poisson drive cannot be negative' + ): + net.add_poisson_drive( + 'poisson1', tstart=0, tstop=-1, location='distal', rate_constant=10.0 + ) + with pytest.raises( + ValueError, match='Start time of Poisson drive cannot be negative' + ): + net.add_poisson_drive( + 'poisson1', tstart=-1, location='distal', rate_constant=10.0 + ) + with pytest.raises( + ValueError, match='Duration of Poisson drive cannot be negative' + ): + net.add_poisson_drive( + 'poisson1', tstart=10, tstop=1, location='distal', rate_constant=10.0 + ) + with pytest.raises(ValueError, match='Rate constant must be positive'): + net.add_poisson_drive( + 'poisson1', + location='distal', + rate_constant=0.0, + weights_ampa=weights_ampa, + synaptic_delays=syn_delays, + ) + + with 
pytest.raises(ValueError, match='Rate constants not provided for all target'): + net.add_poisson_drive( + 'poisson1', + location='distal', + rate_constant={'L2_pyramidal': 10.0}, + weights_ampa=weights_ampa, + synaptic_delays=syn_delays, + ) + with pytest.raises( + ValueError, match='Rate constant provided for unknown target cell' + ): + net.add_poisson_drive( + 'poisson1', + location='distal', + rate_constant={'L2_pyramidal': 10.0, 'bogus_celltype': 20.0}, + weights_ampa={'L2_pyramidal': 0.01, 'bogus_celltype': 0.01}, + synaptic_delays=0.1, + ) + + with pytest.raises( + ValueError, + match='Drives specific to cell types are only ' + 'possible with cell_specific=True', + ): + net.add_poisson_drive( + 'poisson1', + location='distal', + rate_constant={ + 'L2_basket': 10.0, + 'L2_pyramidal': 11.0, + 'L5_basket': 12.0, + 'L5_pyramidal': 13.0, + }, + n_drive_cells=1, + cell_specific=False, + weights_ampa=weights_ampa, + synaptic_delays=syn_delays, + ) # bursty - with pytest.raises(ValueError, - match='End time of bursty drive cannot be negative'): - net.add_bursty_drive('bursty_drive', tstop=-1, - location='distal', burst_rate=10) - with pytest.raises(ValueError, - match='Start time of bursty drive cannot be negative'): - net.add_bursty_drive('bursty_drive', tstart=-1, - location='distal', burst_rate=10) - with pytest.raises(ValueError, - match='Duration of bursty drive cannot be negative'): - net.add_bursty_drive('bursty_drive', tstart=10, tstop=1, - location='distal', burst_rate=10) - - msg = (r'(?s)Burst duration .* cannot be greater than ' - 'burst period') + with pytest.raises(ValueError, match='End time of bursty drive cannot be negative'): + net.add_bursty_drive('bursty_drive', tstop=-1, location='distal', burst_rate=10) + with pytest.raises( + ValueError, match='Start time of bursty drive cannot be negative' + ): + net.add_bursty_drive( + 'bursty_drive', tstart=-1, location='distal', burst_rate=10 + ) + with pytest.raises(ValueError, match='Duration of bursty drive cannot be negative'): + net.add_bursty_drive( + 'bursty_drive', tstart=10, tstop=1, location='distal', burst_rate=10 + ) + + msg = r'(?s)Burst duration .* cannot be greater than ' 'burst period' with pytest.raises(ValueError, match=msg): - net.add_bursty_drive('bursty_drive', location='distal', - burst_rate=10, burst_std=20., numspikes=4, - spike_isi=50) + net.add_bursty_drive( + 'bursty_drive', + location='distal', + burst_rate=10, + burst_std=20.0, + numspikes=4, + spike_isi=50, + ) # attaching drives - with pytest.raises(ValueError, - match='Drive evoked_dist already defined'): - net.add_poisson_drive('evoked_dist', location='distal', - rate_constant=10., - weights_ampa=weights_ampa, - synaptic_delays=syn_delays) - with pytest.raises(ValueError, - match='Invalid value for the'): - net.add_poisson_drive('weird_poisson', location='inbetween', - rate_constant=10., - weights_ampa=weights_ampa, - synaptic_delays=syn_delays) - with pytest.raises(ValueError, - match='Allowed drive target cell types are:'): - net.add_poisson_drive('cell_unknown', location='proximal', - rate_constant=10., - weights_ampa={'CA1_pyramidal': 1.}, - synaptic_delays=.01) - with pytest.raises(ValueError, - match='synaptic_delays is either a common float or ' - 'needs to be specified as a dict for each of the cell'): - net.add_poisson_drive('cell_unknown', location='proximal', - rate_constant=10., - weights_ampa={'L2_pyramidal': 1.}, - synaptic_delays={'L5_pyramidal': 1.}) - with pytest.raises(ValueError, - match=r'probability must be in the range 
\(0\,1\)'):
+    with pytest.raises(ValueError, match=r'probability must be in the range \(0\,1\)'):
         net.add_bursty_drive(
-            'cell_unknown', location='distal', burst_rate=10,
-            weights_ampa={'L2_pyramidal': 1.},
-            synaptic_delays={'L2_pyramidal': 1.}, probability=2.0)
-
-    with pytest.raises(TypeError, match="probability must be an instance of "
-                       r"float or dict, got \<class 'str'\> instead"):
+            'cell_unknown',
+            location='distal',
+            burst_rate=10,
+            weights_ampa={'L2_pyramidal': 1.0},
+            synaptic_delays={'L2_pyramidal': 1.0},
+            probability=2.0,
+        )
+
+    with pytest.raises(
+        TypeError,
+        match='probability must be an instance of '
+        r"float or dict, got \<class 'str'\> instead",
+    ):
         net.add_bursty_drive(
-            'cell_unknown2', location='distal', burst_rate=10,
-            weights_ampa={'L2_pyramidal': 1.},
-            synaptic_delays={'L2_pyramidal': 1.}, probability='1.0')
-
-    with pytest.raises(ValueError, match='probability is either a common '
-                       'float or needs to be specified as a dict for '
-                       'each of the cell'):
+            'cell_unknown2',
+            location='distal',
+            burst_rate=10,
+            weights_ampa={'L2_pyramidal': 1.0},
+            synaptic_delays={'L2_pyramidal': 1.0},
+            probability='1.0',
+        )
+
+    with pytest.raises(
+        ValueError,
+        match='probability is either a common '
+        'float or needs to be specified as a dict for '
+        'each of the cell',
+    ):
         net.add_bursty_drive(
-            'cell_unknown2', location='distal', burst_rate=10,
-            weights_ampa={'L2_pyramidal': 1.},
-            synaptic_delays={'L2_pyramidal': 1.},
-            probability={'L5_pyramidal': 1.})
-
-    with pytest.raises(TypeError, match="probability must be an instance of "
-                       r"float, got \<class 'str'\> instead"):
+            'cell_unknown2',
+            location='distal',
+            burst_rate=10,
+            weights_ampa={'L2_pyramidal': 1.0},
+            synaptic_delays={'L2_pyramidal': 1.0},
+            probability={'L5_pyramidal': 1.0},
+        )
+
+    with pytest.raises(
+        TypeError,
+        match='probability must be an instance of '
+        r"float, got \<class 'str'\> instead",
+    ):
         net.add_bursty_drive(
-            'cell_unknown2', location='distal', burst_rate=10,
-            weights_ampa={'L2_pyramidal': 1.},
-            synaptic_delays={'L2_pyramidal': 1.},
-            probability={'L2_pyramidal': '1.0'})
-
-    with pytest.raises(ValueError,
-                       match=r'probability must be in the range \(0\,1\)'):
+            'cell_unknown2',
+            location='distal',
+            burst_rate=10,
+            weights_ampa={'L2_pyramidal': 1.0},
+            synaptic_delays={'L2_pyramidal': 1.0},
+            probability={'L2_pyramidal': '1.0'},
+        )
+
+    with pytest.raises(ValueError, match=r'probability must be in the range \(0\,1\)'):
         net.add_bursty_drive(
-            'cell_unknown3', location='distal', burst_rate=10,
-            weights_ampa={'L2_pyramidal': 1.},
-
synaptic_delays={'L2_pyramidal': 1.}, - probability={'L2_pyramidal': 2.0}) + 'cell_unknown3', + location='distal', + burst_rate=10, + weights_ampa={'L2_pyramidal': 1.0}, + synaptic_delays={'L2_pyramidal': 1.0}, + probability={'L2_pyramidal': 2.0}, + ) with pytest.warns(UserWarning, match='No external drives or biases load'): net.clear_drives() @@ -497,38 +671,67 @@ def test_add_drives(): def test_drive_random_state(): """Tests to check same random state always gives same spike times.""" - weights_ampa = {'L2_basket': 0.08, 'L2_pyramidal': 0.02, - 'L5_basket': 0.2, 'L5_pyramidal': 0.00865} - synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, - 'L5_basket': 1., 'L5_pyramidal': 1.} + weights_ampa = { + 'L2_basket': 0.08, + 'L2_pyramidal': 0.02, + 'L5_basket': 0.2, + 'L5_pyramidal': 0.00865, + } + synaptic_delays = { + 'L2_basket': 0.1, + 'L2_pyramidal': 0.1, + 'L5_basket': 1.0, + 'L5_pyramidal': 1.0, + } net = jones_2009_model() for drive_name in ['evprox1', 'evprox2']: net.add_evoked_drive( - drive_name, mu=137.12, sigma=8, numspikes=1, - weights_ampa=weights_ampa, weights_nmda=None, - location='proximal', synaptic_delays=synaptic_delays, event_seed=4) - - net._instantiate_drives(tstop=170.) - assert (net.external_drives['evprox1']['events'] == - net.external_drives['evprox2']['events']) - - -@pytest.mark.parametrize("rate_constant,cell_specific,n_drive_cells", - [(2, False, 1), (2.0, False, 1), - (2, True, 'n_cells'), (2.0, True, 'n_cells'), - ]) -def test_add_poisson_drive(setup_net, rate_constant, cell_specific, - n_drive_cells): + drive_name, + mu=137.12, + sigma=8, + numspikes=1, + weights_ampa=weights_ampa, + weights_nmda=None, + location='proximal', + synaptic_delays=synaptic_delays, + event_seed=4, + ) + + net._instantiate_drives(tstop=170.0) + assert ( + net.external_drives['evprox1']['events'] + == net.external_drives['evprox2']['events'] + ) + + +@pytest.mark.parametrize( + 'rate_constant,cell_specific,n_drive_cells', + [ + (2, False, 1), + (2.0, False, 1), + (2, True, 'n_cells'), + (2.0, True, 'n_cells'), + ], +) +def test_add_poisson_drive(setup_net, rate_constant, cell_specific, n_drive_cells): """Testing rate constant when adding non-cell-specific poisson drive""" net = setup_net - weights_ampa_noise = {'L2_basket': 0.01, 'L2_pyramidal': 0.002, - 'L5_pyramidal': 0.02} + weights_ampa_noise = { + 'L2_basket': 0.01, + 'L2_pyramidal': 0.002, + 'L5_pyramidal': 0.02, + } - net.add_poisson_drive('noise_global', rate_constant=rate_constant, - location='distal', weights_ampa=weights_ampa_noise, - space_constant=100, n_drive_cells=n_drive_cells, - cell_specific=cell_specific) + net.add_poisson_drive( + 'noise_global', + rate_constant=rate_constant, + location='distal', + weights_ampa=weights_ampa_noise, + space_constant=100, + n_drive_cells=n_drive_cells, + cell_specific=cell_specific, + ) simulate_dipole(net, tstop=5) diff --git a/hnn_core/tests/test_extracellular.py b/hnn_core/tests/test_extracellular.py index ae97083f1..7de709cc3 100644 --- a/hnn_core/tests/test_extracellular.py +++ b/hnn_core/tests/test_extracellular.py @@ -9,8 +9,11 @@ import hnn_core from hnn_core import read_params, jones_2009_model, simulate_dipole -from hnn_core.extracellular import (ExtracellularArray, calculate_csd2d, - _get_laminar_z_coords) +from hnn_core.extracellular import ( + ExtracellularArray, + calculate_csd2d, + _get_laminar_z_coords, +) from hnn_core.parallel_backends import requires_mpi4py, requires_psutil import matplotlib.pyplot as plt @@ -34,7 +37,7 @@ def test_extracellular_api(): 
     assert len(net.rec_arrays['arr1'].positions) == 2

     # Comparison with a non-ExtracellularArray object is False, not NotImplemented
-    assert (net.rec_arrays['arr1'] == "extArr") is False
+    assert (net.rec_arrays['arr1'] == 'extArr') is False

     # ensure unique names
     pytest.raises(ValueError, net.add_electrode_array, 'arr1', [(6, 6, 800)])
@@ -44,15 +47,13 @@
     # Added second string in the match pattern due to changes in python >=3.11
     # AttributeError message changed to "property X of object Y has no setter"
-    with pytest.raises(AttributeError,
-                       match="has no setter|can't set attribute"):
+    with pytest.raises(AttributeError, match="has no setter|can't set attribute"):
         rec_arr.times = [1, 2, 3]
-    with pytest.raises(AttributeError,
-                       match="has no setter|can't set attribute"):
+    with pytest.raises(AttributeError, match="has no setter|can't set attribute"):
         rec_arr.voltages = [1, 2, 3]
-    with pytest.raises(TypeError, match="trial index must be int"):
+    with pytest.raises(TypeError, match='trial index must be int'):
         _ = rec_arr['0']
-    with pytest.raises(IndexError, match="the data contain"):
+    with pytest.raises(IndexError, match='the data contain'):
         _ = rec_arr[42]

     # positions are 3-tuples
@@ -62,69 +63,99 @@
     good_positions = [(1, 2, 3), (100, 200, 300)]
     for cond in ['0.3', [0.3], -1]:  # conductivity is positive float
-        pytest.raises((TypeError, AssertionError), ExtracellularArray,
-                      good_positions, conductivity=cond)
+        pytest.raises(
+            (TypeError, AssertionError),
+            ExtracellularArray,
+            good_positions,
+            conductivity=cond,
+        )
     for meth in ['foo', 0.3]:  # method is 'psa' or 'lsa' (or None for test)
-        pytest.raises((TypeError, AssertionError, ValueError),
-                      ExtracellularArray, good_positions, method=meth)
+        pytest.raises(
+            (TypeError, AssertionError, ValueError),
+            ExtracellularArray,
+            good_positions,
+            method=meth,
+        )
     for mind in ['foo', -1, None]:  # minimum distance to segment boundary
-        pytest.raises((TypeError, AssertionError), ExtracellularArray,
-                      good_positions, min_distance=mind)
-
-    pytest.raises(ValueError, ExtracellularArray,  # more chans than voltages
-                  good_positions, times=[1], voltages=[[[42]]])
-    pytest.raises(ValueError, ExtracellularArray,  # less times than voltages
-                  good_positions, times=[1], voltages=[[[42, 42], [84, 84]]])
-
-    rec_arr = ExtracellularArray(good_positions,
-                                 times=[0, 0.1, 0.21, 0.3],  # uneven sampling
-                                 voltages=[[[0, 0, 0, 0], [0, 0, 0, 0]]])
-    with pytest.raises(RuntimeError, match="Extracellular sampling times"):
+        pytest.raises(
+            (TypeError, AssertionError),
+            ExtracellularArray,
+            good_positions,
+            min_distance=mind,
+        )
+
+    pytest.raises(
+        ValueError,
+        ExtracellularArray,  # more chans than voltages
+        good_positions,
+        times=[1],
+        voltages=[[[42]]],
+    )
+    pytest.raises(
+        ValueError,
+        ExtracellularArray,  # less times than voltages
+        good_positions,
+        times=[1],
+        voltages=[[[42, 42], [84, 84]]],
+    )
+
+    rec_arr = ExtracellularArray(
+        good_positions,
+        times=[0, 0.1, 0.21, 0.3],  # uneven sampling
+        voltages=[[[0, 0, 0, 0], [0, 0, 0, 0]]],
+    )
+    with pytest.raises(RuntimeError, match='Extracellular sampling times'):
         _ = rec_arr.sfreq
     rec_arr._reset()
     assert len(rec_arr.times) == len(rec_arr.voltages) == 0
     assert rec_arr.sfreq is None
-    rec_arr = ExtracellularArray(good_positions,
-                                 times=[0], voltages=[[[0], [0]]])
-    with pytest.raises(RuntimeError, match="Sampling rate is not defined"):
+    rec_arr = ExtracellularArray(good_positions, times=[0], voltages=[[[0], [0]]])
+    with pytest.raises(RuntimeError, match='Sampling rate is not defined'):
         _ = rec_arr.sfreq

     # test colinearity and equal spacing between electrode contacts for laminar
     # profiling (e.g., for plotting laminar LFP or CSD)
-    electrode_pos = [(1, 2, 1000), (2, 3, 3000), (3, 4, 5000),
-                     (4, 5, 7000)]
+    electrode_pos = [(1, 2, 1000), (2, 3, 3000), (3, 4, 5000), (4, 5, 7000)]
     z_coords, z_delta = _get_laminar_z_coords(electrode_pos)
     assert np.array_equal(z_coords, [1000, 3000, 5000, 7000])
     assert z_delta == 2000
-    with pytest.raises(ValueError, match='Electrode array positions must '
-                       'contain more than 1 contact'):
+    with pytest.raises(
+        ValueError,
+        match='Electrode array positions must ' 'contain more than 1 contact',
+    ):
         _, _ = _get_laminar_z_coords([(1, 2, 3)])
-    with pytest.raises(ValueError, match='Make sure the electrode positions '
-                       'are equispaced, colinear'):
+    with pytest.raises(
+        ValueError,
+        match='Make sure the electrode positions ' 'are equispaced, colinear',
+    ):
         _, _ = _get_laminar_z_coords([(1, 1, 3), (1, 1, 4), (1, 1, 3.5)])


 def test_transmembrane_currents():
     """Test that net transmembrane current is zero at all times."""
-    params.update({'N_pyr_x': 3,
-                   'N_pyr_y': 3,
-                   't_evprox_1': 5,
-                   't_evdist_1': 10,
-                   't_evprox_2': 20,
-                   'N_trials': 1})
+    params.update(
+        {
+            'N_pyr_x': 3,
+            'N_pyr_y': 3,
+            't_evprox_1': 5,
+            't_evdist_1': 10,
+            't_evprox_2': 20,
+            'N_trials': 1,
+        }
+    )
     net = jones_2009_model(params, add_drives_from_params=True)
     electrode_pos = (0, 0, 0)  # irrelevant where electrode is
     # all transfer resistances set to unity
     net.add_electrode_array('net_Im', electrode_pos, method=None)
-    _ = simulate_dipole(net, tstop=40.)
-    assert_allclose(net.rec_arrays['net_Im'].voltages, 0,
-                    rtol=1e-10, atol=1e-10)
+    _ = simulate_dipole(net, tstop=40.0)
+    assert_allclose(net.rec_arrays['net_Im'].voltages, 0, rtol=1e-10, atol=1e-10)


 def test_transfer_resistance():
     """Test transfer resistances calculated correctly"""
     from neuron import h
     from hnn_core.extracellular import _transfer_resistance
+
     sec = h.Section(name='dend')
     h.pt3dclear(sec=sec)
     h.pt3dadd(0, 0, 0, 1, sec=sec)
@@ -148,24 +179,29 @@ def test_transfer_resistance():
     target_vals = {'psa': list(), 'lsa': list()}
     for seg_idx in range(sec.nseg):
         # PSA: distance to middle segment == electrode x-position
-        var_r_psa = np.sqrt(elec_pos[0] ** 2 +
-                            (elec_pos[1] - seg_ctr_pts[seg_idx]) ** 2)
-        target_vals['psa'].append(
-            1000 / (4. * np.pi * conductivity * var_r_psa))
+        var_r_psa = np.sqrt(
+            elec_pos[0] ** 2 + (elec_pos[1] - seg_ctr_pts[seg_idx]) ** 2
+        )
+        target_vals['psa'].append(1000 / (4.0 * np.pi * conductivity * var_r_psa))

         # LSA: calculate L and H variables relative to segment endpoints
         var_l = elec_pos[1] - (seg_ctr_pts[seg_idx] - seg_lens[seg_idx])
         var_h = elec_pos[1] - (seg_ctr_pts[seg_idx] + seg_lens[seg_idx])
         var_r_lsa = elec_pos[0]  # just use the axial distance
         target_vals['lsa'].append(
-            1000 * np.log(np.abs(
-                (np.sqrt(var_h ** 2 + var_r_lsa ** 2) - var_h) /
-                (np.sqrt(var_l ** 2 + var_r_lsa ** 2) - var_l)
-            )) / (4. * np.pi * conductivity * 2 * seg_lens[seg_idx]))
+            1000
+            * np.log(
+                np.abs(
+                    (np.sqrt(var_h**2 + var_r_lsa**2) - var_h)
+                    / (np.sqrt(var_l**2 + var_r_lsa**2) - var_l)
+                )
+            )
+            / (4.0 * np.pi * conductivity * 2 * seg_lens[seg_idx])
+        )

     for method in ['psa', 'lsa']:
         res = _transfer_resistance(sec, elec_pos, conductivity, method)
-        assert_allclose(res, target_vals[method], rtol=1e-12, atol=0.)
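The target values computed above encode the two forward models under test: the point-source approximation (PSA) treats each segment as a monopole at distance r, giving a transfer resistance proportional to 1 / (4·pi·sigma·r), while the line-source approximation (LSA) integrates along the segment length, yielding the logarithmic term. A standalone restatement of the PSA expression (the helper name is illustrative; the factor 1000 mirrors the unit scaling used in the test):

import numpy as np

def psa_transfer_resistance(dist, conductivity):
    # Potential per unit current from a point source at distance `dist`
    # in a homogeneous medium of conductivity `conductivity`
    return 1000.0 / (4.0 * np.pi * conductivity * dist)

# e.g. an electrode 10 units from a segment in a 0.3 S/m medium
print(psa_transfer_resistance(dist=10.0, conductivity=0.3))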
+ assert_allclose(res, target_vals[method], rtol=1e-12, atol=0.0) @requires_mpi4py @@ -173,33 +209,52 @@ def test_transfer_resistance(): def test_extracellular_backends(run_hnn_core_fixture): """Test extracellular outputs across backends.""" # calculation of CSD requires >=4 electrode contacts - electrode_array = {'arr1': [(2, 2, 400), (2, 2, 600), (2, 2, 800), - (2, 2, 1000)]} + electrode_array = {'arr1': [(2, 2, 400), (2, 2, 600), (2, 2, 800), (2, 2, 1000)]} _, joblib_net = run_hnn_core_fixture( - backend='joblib', n_jobs=1, reduced=True, record_isec='soma', - record_vsec='soma', record_ca='soma', electrode_array=electrode_array) + backend='joblib', + n_jobs=1, + reduced=True, + record_isec='soma', + record_vsec='soma', + record_ca='soma', + electrode_array=electrode_array, + ) _, mpi_net = run_hnn_core_fixture( - backend='mpi', n_procs=2, reduced=True, record_isec='soma', - record_vsec='soma', record_ca='soma', electrode_array=electrode_array) - - assert (len(electrode_array['arr1']) == - len(joblib_net.rec_arrays['arr1'].positions) == - len(mpi_net.rec_arrays['arr1'].positions)) - assert (len(joblib_net.rec_arrays['arr1']) == - len(mpi_net.rec_arrays['arr1']) == - 2) # length == n.o. trials + backend='mpi', + n_procs=2, + reduced=True, + record_isec='soma', + record_vsec='soma', + record_ca='soma', + electrode_array=electrode_array, + ) + + assert ( + len(electrode_array['arr1']) + == len(joblib_net.rec_arrays['arr1'].positions) + == len(mpi_net.rec_arrays['arr1'].positions) + ) + assert ( + len(joblib_net.rec_arrays['arr1']) == len(mpi_net.rec_arrays['arr1']) == 2 + ) # length == n.o. trials # reduced simulation has n_trials=2 # trial_idx, n_trials = 0, 2 for tr_idx, el_idx in zip([0, 1], [0, 1]): - assert_allclose(joblib_net.rec_arrays['arr1']._data[tr_idx][el_idx], - mpi_net.rec_arrays['arr1']._data[tr_idx][el_idx]) + assert_allclose( + joblib_net.rec_arrays['arr1']._data[tr_idx][el_idx], + mpi_net.rec_arrays['arr1']._data[tr_idx][el_idx], + ) assert isinstance(joblib_net.rec_arrays['arr1'].voltages, np.ndarray) - assert_array_equal(joblib_net.rec_arrays['arr1'].voltages.shape, - [len(joblib_net.rec_arrays['arr1']._data), - len(joblib_net.rec_arrays['arr1']._data[0]), - len(joblib_net.rec_arrays['arr1']._data[0][0])]) + assert_array_equal( + joblib_net.rec_arrays['arr1'].voltages.shape, + [ + len(joblib_net.rec_arrays['arr1']._data), + len(joblib_net.rec_arrays['arr1']._data[0]), + len(joblib_net.rec_arrays['arr1']._data[0][0]), + ], + ) # make sure sampling rate is fixed (raises RuntimeError if not) _ = joblib_net.rec_arrays['arr1'].sfreq @@ -215,24 +270,23 @@ def test_rec_array_calculation(): hnn_core_root = op.dirname(hnn_core.__file__) params_fname = op.join(hnn_core_root, 'param', 'default.json') params = read_params(params_fname) - params.update({'t_evprox_1': 7, - 't_evdist_1': 17}) - net = jones_2009_model(params, mesh_shape=(3, 3), - add_drives_from_params=True) + params.update({'t_evprox_1': 7, 't_evdist_1': 17}) + net = jones_2009_model(params, mesh_shape=(3, 3), add_drives_from_params=True) # one electrode inside, one above the active elements of the network, # and two more to allow calculation of CSD (2nd spatial derivative) - electrode_pos = [(1, 2, 1000), (2, 3, 3000), (3, 4, 5000), - (4, 5, 7000)] + electrode_pos = [(1, 2, 1000), (2, 3, 3000), (3, 4, 5000), (4, 5, 7000)] net.add_electrode_array('arr1', electrode_pos) _ = simulate_dipole(net, tstop=5, n_trials=1) # test accessing simulated voltages - assert (len(net.rec_arrays['arr1']) == - 
len(net.rec_arrays['arr1'].voltages) == 1) # n_trials + assert ( + len(net.rec_arrays['arr1']) == len(net.rec_arrays['arr1'].voltages) == 1 + ) # n_trials assert len(net.rec_arrays['arr1'].voltages[0]) == 4 # n_contacts - assert (len(net.rec_arrays['arr1'].voltages[0][0]) == - len(net.rec_arrays['arr1'].times)) + assert len(net.rec_arrays['arr1'].voltages[0][0]) == len( + net.rec_arrays['arr1'].times + ) # test dimensionality of LFP and CSD matrices lfp_data = net.rec_arrays['arr1'].voltages[0] @@ -260,9 +314,12 @@ def test_rec_array_calculation(): for trial_idx in range(n_trials): # LSA and PSA should agree far away (second electrode) - assert_allclose(net.rec_arrays['arr1']._data[trial_idx][1], - net.rec_arrays['arr2']._data[trial_idx][1], - rtol=1e-3, atol=1e-3) + assert_allclose( + net.rec_arrays['arr1']._data[trial_idx][1], + net.rec_arrays['arr2']._data[trial_idx][1], + rtol=1e-3, + atol=1e-3, + ) def test_extracellular_viz(): @@ -271,8 +328,7 @@ def test_extracellular_viz(): params_fname = op.join(hnn_core_root, 'param', 'default.json') params = read_params(params_fname) params.update({'t_evprox_1': 7, 't_evdist_1': 17}) - net = jones_2009_model(params, mesh_shape=(3, 3), - add_drives_from_params=True) + net = jones_2009_model(params, mesh_shape=(3, 3), add_drives_from_params=True) # one electrode inside, one above the active elements of the network, # and two more to allow calculation of CSD (2nd spatial derivative) @@ -282,7 +338,10 @@ def test_extracellular_viz(): with pytest.deprecated_call(): net.rec_arrays['arr1'].plot_lfp(show=False, tmin=10, tmax=100) - with pytest.raises(RuntimeError, match='Please use sink = "b" or ' - 'sink = "r". Only colormap "jet" is supported ' - 'for CSD.'): + with pytest.raises( + RuntimeError, + match='Please use sink = "b" or ' + 'sink = "r". Only colormap "jet" is supported ' + 'for CSD.', + ): net.rec_arrays['arr1'].plot_csd(show=False, sink='g') diff --git a/hnn_core/tests/test_general_optimization.py b/hnn_core/tests/test_general_optimization.py index af04a6cb5..2c7225cc5 100644 --- a/hnn_core/tests/test_general_optimization.py +++ b/hnn_core/tests/test_general_optimization.py @@ -9,206 +9,262 @@ from hnn_core.optimization import Optimizer -@pytest.mark.parametrize("solver", ['bayesian', 'cobyla']) +@pytest.mark.parametrize('solver', ['bayesian', 'cobyla']) def test_optimize_evoked(solver): """Test optimization routines for evoked drives in a reduced network.""" max_iter = 2 - tstop = 10. + tstop = 10.0 n_trials = 1 # simulate a dipole to establish ground-truth drive parameters net_orig = jones_2009_model(mesh_shape=(3, 3)) - mu_orig = 2. 
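All of the optimization tests in this file follow one pattern: build a drive-free network, declare a constraints dict of (lower, upper) bounds per parameter, and hand the Optimizer a set_params callback that installs each candidate as a drive; fit() then minimizes the chosen objective. A condensed sketch assembled from the calls visible in this diff (solver and objective names as used here, not an exhaustive API reference):

from hnn_core import jones_2009_model
from hnn_core.optimization import Optimizer

def set_params(net, params):
    # Install one candidate point drawn from `constraints` as an evoked drive
    net.add_evoked_drive(
        'evprox', mu=params['mu'], sigma=params['sigma'], numspikes=1,
        location='proximal', weights_ampa={'L5_pyramidal': 0.5},
        synaptic_delays={'L5_pyramidal': 1.0},
    )

net = jones_2009_model(mesh_shape=(3, 3))  # must carry no external drives
constraints = {'mu': (1, 6), 'sigma': (1, 3)}
optim = Optimizer(net, tstop=10.0, constraints=constraints,
                  set_params=set_params, solver='cobyla',
                  obj_fun='dipole_rmse', max_iter=2)
# optim.fit(target=target_dipole) runs the loop; afterwards optim.opt_params_
# holds the best parameters and optim.obj_ the per-iteration objective values.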
- weights_ampa = {'L2_basket': 0.5, - 'L2_pyramidal': 0.5, - 'L5_basket': 0.5, - 'L5_pyramidal': 0.5} - synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, - 'L5_basket': 1., 'L5_pyramidal': 1.} - net_orig.add_evoked_drive('evprox', - mu=mu_orig, - sigma=1, - numspikes=1, - location='proximal', - weights_ampa=weights_ampa, - synaptic_delays=synaptic_delays) + mu_orig = 2.0 + weights_ampa = { + 'L2_basket': 0.5, + 'L2_pyramidal': 0.5, + 'L5_basket': 0.5, + 'L5_pyramidal': 0.5, + } + synaptic_delays = { + 'L2_basket': 0.1, + 'L2_pyramidal': 0.1, + 'L5_basket': 1.0, + 'L5_pyramidal': 1.0, + } + net_orig.add_evoked_drive( + 'evprox', + mu=mu_orig, + sigma=1, + numspikes=1, + location='proximal', + weights_ampa=weights_ampa, + synaptic_delays=synaptic_delays, + ) dpl_orig = simulate_dipole(net_orig, tstop=tstop, n_trials=n_trials)[0] # define set_params function and constraints net_offset = jones_2009_model(mesh_shape=(3, 3)) def set_params(net_offset, params): - weights_ampa = {'L2_basket': 0.5, - 'L2_pyramidal': 0.5, - 'L5_basket': 0.5, - 'L5_pyramidal': 0.5} - synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, - 'L5_basket': 1., 'L5_pyramidal': 1.} - net_offset.add_evoked_drive('evprox', - mu=params['mu'], - sigma=params['sigma'], - numspikes=1, - location='proximal', - weights_ampa=weights_ampa, - synaptic_delays=synaptic_delays) + weights_ampa = { + 'L2_basket': 0.5, + 'L2_pyramidal': 0.5, + 'L5_basket': 0.5, + 'L5_pyramidal': 0.5, + } + synaptic_delays = { + 'L2_basket': 0.1, + 'L2_pyramidal': 0.1, + 'L5_basket': 1.0, + 'L5_pyramidal': 1.0, + } + net_offset.add_evoked_drive( + 'evprox', + mu=params['mu'], + sigma=params['sigma'], + numspikes=1, + location='proximal', + weights_ampa=weights_ampa, + synaptic_delays=synaptic_delays, + ) # define constraints constraints = dict() - constraints.update({'mu': (1, 6), - 'sigma': (1, 3)}) - - optim = Optimizer(net_offset, tstop=tstop, constraints=constraints, - set_params=set_params, solver=solver, - obj_fun='dipole_rmse', max_iter=max_iter) + constraints.update({'mu': (1, 6), 'sigma': (1, 3)}) + + optim = Optimizer( + net_offset, + tstop=tstop, + constraints=constraints, + set_params=set_params, + solver=solver, + obj_fun='dipole_rmse', + max_iter=max_iter, + ) # test exception raised - with pytest.raises(ValueError, match='The current Network instance has ' - 'external drives, provide a Network object with no ' - 'external drives.'): + with pytest.raises( + ValueError, + match='The current Network instance has ' + 'external drives, provide a Network object with no ' + 'external drives.', + ): net_with_drives = net_orig.copy() - optim = Optimizer(net_with_drives, - tstop=tstop, - constraints=constraints, - set_params=set_params, - solver=solver, - obj_fun='dipole_rmse', - max_iter=max_iter) + optim = Optimizer( + net_with_drives, + tstop=tstop, + constraints=constraints, + set_params=set_params, + solver=solver, + obj_fun='dipole_rmse', + max_iter=max_iter, + ) # test repr before fitting - assert 'fit=False' in repr(optim), "optimizer is already fit" + assert 'fit=False' in repr(optim), 'optimizer is already fit' optim.fit(target=dpl_orig) # test repr after fitting - assert 'fit=True' in repr(optim), "optimizer was not fit" + assert 'fit=True' in repr(optim), 'optimizer was not fit' # the optimized parameter is in the range for param_idx, param in enumerate(optim.opt_params_): - assert (list(constraints.values())[param_idx][0] <= param <= - list(constraints.values())[param_idx][1]), ( - "Optimized parameter is not in user-defined 
range") + assert ( + list(constraints.values())[param_idx][0] + <= param + <= list(constraints.values())[param_idx][1] + ), 'Optimized parameter is not in user-defined range' obj = optim.obj_ # the number of returned rmse values should be the same as max_iter - assert (len(obj) <= max_iter), ( - "Number of rmse values should be the same as max_iter") + assert len(obj) <= max_iter, 'Number of rmse values should be the same as max_iter' # the returned rmse values should be positive - assert all(vals >= 0 for vals in obj), "rmse values should be positive" + assert all(vals >= 0 for vals in obj), 'rmse values should be positive' -@pytest.mark.parametrize("solver", ['bayesian', 'cobyla']) +@pytest.mark.parametrize('solver', ['bayesian', 'cobyla']) def test_rhythmic(solver): """Test optimization routines for rhythmic drives in a reduced network.""" max_iter = 2 - tstop = 10. + tstop = 10.0 # simulate a dipole to establish ground-truth drive parameters net_offset = jones_2009_model(mesh_shape=(3, 3)) # define set_params function and constraints def set_params(net_offset, params): - # Proximal (alpha) - weights_ampa_p = {'L2_pyramidal': params['alpha_prox_weight'], - 'L5_pyramidal': 4.4e-5} - syn_delays_p = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.} - - net_offset.add_bursty_drive('alpha_prox', - tstart=params['alpha_prox_tstart'], - burst_rate=params['alpha_prox_burst_rate'], - burst_std=params['alpha_prox_burst_std'], - numspikes=2, - spike_isi=10, - n_drive_cells=10, - location='proximal', - weights_ampa=weights_ampa_p, - synaptic_delays=syn_delays_p) + weights_ampa_p = { + 'L2_pyramidal': params['alpha_prox_weight'], + 'L5_pyramidal': 4.4e-5, + } + syn_delays_p = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.0} + + net_offset.add_bursty_drive( + 'alpha_prox', + tstart=params['alpha_prox_tstart'], + burst_rate=params['alpha_prox_burst_rate'], + burst_std=params['alpha_prox_burst_std'], + numspikes=2, + spike_isi=10, + n_drive_cells=10, + location='proximal', + weights_ampa=weights_ampa_p, + synaptic_delays=syn_delays_p, + ) # Distal (beta) - weights_ampa_d = {'L2_pyramidal': params['alpha_dist_weight'], - 'L5_pyramidal': 4.4e-5} - syn_delays_d = {'L2_pyramidal': 5., 'L5_pyramidal': 5.} - - net_offset.add_bursty_drive('alpha_dist', - tstart=params['alpha_dist_tstart'], - burst_rate=params['alpha_dist_burst_rate'], - burst_std=params['alpha_dist_burst_std'], - numspikes=2, - spike_isi=10, - n_drive_cells=10, - location='distal', - weights_ampa=weights_ampa_d, - synaptic_delays=syn_delays_d) + weights_ampa_d = { + 'L2_pyramidal': params['alpha_dist_weight'], + 'L5_pyramidal': 4.4e-5, + } + syn_delays_d = {'L2_pyramidal': 5.0, 'L5_pyramidal': 5.0} + + net_offset.add_bursty_drive( + 'alpha_dist', + tstart=params['alpha_dist_tstart'], + burst_rate=params['alpha_dist_burst_rate'], + burst_std=params['alpha_dist_burst_std'], + numspikes=2, + spike_isi=10, + n_drive_cells=10, + location='distal', + weights_ampa=weights_ampa_d, + synaptic_delays=syn_delays_d, + ) # define constraints constraints = dict() - constraints.update({'alpha_prox_weight': (4.4e-5, 6.4e-5), - 'alpha_prox_tstart': (45, 55), - 'alpha_prox_burst_rate': (8, 12), - 'alpha_prox_burst_std': (10, 25), - 'alpha_dist_weight': (4.4e-5, 6.4e-5), - 'alpha_dist_tstart': (45, 55), - 'alpha_dist_burst_rate': (8, 12), - 'alpha_dist_burst_std': (10, 25)}) + constraints.update( + { + 'alpha_prox_weight': (4.4e-5, 6.4e-5), + 'alpha_prox_tstart': (45, 55), + 'alpha_prox_burst_rate': (8, 12), + 'alpha_prox_burst_std': (10, 25), + 'alpha_dist_weight': (4.4e-5, 
6.4e-5), + 'alpha_dist_tstart': (45, 55), + 'alpha_dist_burst_rate': (8, 12), + 'alpha_dist_burst_std': (10, 25), + } + ) # Optimize - optim = Optimizer(net_offset, tstop=tstop, constraints=constraints, - set_params=set_params, solver=solver, - obj_fun='maximize_psd', max_iter=max_iter) + optim = Optimizer( + net_offset, + tstop=tstop, + constraints=constraints, + set_params=set_params, + solver=solver, + obj_fun='maximize_psd', + max_iter=max_iter, + ) # test exception raised - with pytest.raises(ValueError, match='The current Network instance has ' - 'external drives, provide a Network object with no ' - 'external drives.'): + with pytest.raises( + ValueError, + match='The current Network instance has ' + 'external drives, provide a Network object with no ' + 'external drives.', + ): net_with_drives = jones_2009_model(add_drives_from_params=True) - optim = Optimizer(net_with_drives, - tstop=tstop, - constraints=constraints, - set_params=set_params, - solver=solver, - obj_fun='maximize_psd', - max_iter=max_iter) + optim = Optimizer( + net_with_drives, + tstop=tstop, + constraints=constraints, + set_params=set_params, + solver=solver, + obj_fun='maximize_psd', + max_iter=max_iter, + ) # test repr before fitting - assert 'fit=False' in repr(optim), "optimizer is already fit" + assert 'fit=False' in repr(optim), 'optimizer is already fit' optim.fit(f_bands=[(8, 12), (18, 22)], relative_bandpower=(1, 2)) # test repr after fitting - assert 'fit=True' in repr(optim), "optimizer was not fit" + assert 'fit=True' in repr(optim), 'optimizer was not fit' # the optimized parameter is in the range for param_idx, param in enumerate(optim.opt_params_): - assert (list(constraints.values())[param_idx][0] <= param <= - list(constraints.values())[param_idx][1]), ( - "Optimized parameter is not in user-defined range") + assert ( + list(constraints.values())[param_idx][0] + <= param + <= list(constraints.values())[param_idx][1] + ), 'Optimized parameter is not in user-defined range' obj = optim.obj_ # the number of returned rmse values should be the same as max_iter - assert (len(obj) <= max_iter), ( - "Number of rmse values should be the same as max_iter") + assert len(obj) <= max_iter, 'Number of rmse values should be the same as max_iter' -@pytest.mark.parametrize("solver", ['bayesian', 'cobyla']) +@pytest.mark.parametrize('solver', ['bayesian', 'cobyla']) def test_user_obj_fun(solver): """Test optimization routines with a user-defined optimization function.""" max_iter = 2 - tstop = 10. 
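test_user_obj_fun below exercises the custom-objective path: the callable is invoked with the initial network and parameter dict, the set_params callback, the solver's predicted parameters, an update helper, the running list of objective values, tstop, and any keyword arguments forwarded by fit(). A minimal skeleton with that signature (the name and body are illustrative; the flow of merging, copying, and applying parameters mirrors maximize_csd as shown in this diff):

from hnn_core.optimization import _update_params

def my_obj_fun(initial_net, initial_params, set_params, predicted_params,
               update_params, obj_values, tstop, obj_fun_kwargs):
    # Merge the solver's candidate values into the full parameter dict
    params = _update_params(initial_params, predicted_params)
    # Apply them to a fresh copy of the drive-free network
    net = initial_net.copy()
    set_params(net, params)
    # ... simulate `net` up to `tstop` and compute a score here ...
    obj = 0.0  # placeholder; replace with a real metric of the simulation
    obj_values.append(obj)
    return obj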
+ tstop = 10.0 # simulate a dipole to establish ground-truth drive parameters net_offset = jones_2009_model(mesh_shape=(3, 3)) - def maximize_csd(initial_net, initial_params, set_params, predicted_params, - update_params, obj_values, tstop, obj_fun_kwargs): - + def maximize_csd( + initial_net, + initial_params, + set_params, + predicted_params, + update_params, + obj_values, + tstop, + obj_fun_kwargs, + ): import numpy as np from hnn_core.optimization import _update_params - from hnn_core.extracellular import (calculate_csd2d, - _get_laminar_z_coords) + from hnn_core.extracellular import calculate_csd2d, _get_laminar_z_coords params = _update_params(initial_params, predicted_params) @@ -239,13 +295,16 @@ def maximize_csd(initial_net, initial_params, set_params, predicted_params, for idx, t_band in enumerate(obj_fun_kwargs['t_bands']): t_min = np.argmax(potentials.times >= t_band[0]) t_max = np.argmax(potentials.times >= t_band[1]) - depth_min = np.argmax(contact_labels >= - obj_fun_kwargs['electrode_depths'][idx][0]) - depth_max = np.argmax(contact_labels >= - obj_fun_kwargs['electrode_depths'][idx][1]) + depth_min = np.argmax( + contact_labels >= obj_fun_kwargs['electrode_depths'][idx][0] + ) + depth_max = np.argmax( + contact_labels >= obj_fun_kwargs['electrode_depths'][idx][1] + ) - csd_subsets.append(sum(sum(csd[depth_min:depth_max + 1, - t_min:t_max + 1]))) + csd_subsets.append( + sum(sum(csd[depth_min : depth_max + 1, t_min : t_max + 1])) + ) obj = sum(csd_subsets) / sum(sum(csd)) obj_values.append(obj) @@ -253,53 +312,80 @@ def maximize_csd(initial_net, initial_params, set_params, predicted_params, return obj def set_params(net_offset, params): - weights_ampa = {'L2_basket': 0.5, - 'L2_pyramidal': 0.5, - 'L5_basket': 0.5, - 'L5_pyramidal': 0.5} - synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, - 'L5_basket': 1., 'L5_pyramidal': 1.} - net_offset.add_evoked_drive('evprox', - mu=params['mu'], - sigma=params['sigma'], - numspikes=1, - location='proximal', - weights_ampa=weights_ampa, - synaptic_delays=synaptic_delays) + weights_ampa = { + 'L2_basket': 0.5, + 'L2_pyramidal': 0.5, + 'L5_basket': 0.5, + 'L5_pyramidal': 0.5, + } + synaptic_delays = { + 'L2_basket': 0.1, + 'L2_pyramidal': 0.1, + 'L5_basket': 1.0, + 'L5_pyramidal': 1.0, + } + net_offset.add_evoked_drive( + 'evprox', + mu=params['mu'], + sigma=params['sigma'], + numspikes=1, + location='proximal', + weights_ampa=weights_ampa, + synaptic_delays=synaptic_delays, + ) # define constraints constraints = dict() - constraints.update({'mu': (1, 200), - 'sigma': (1, 15)}) - - optim = Optimizer(net_offset, tstop=tstop, constraints=constraints, - set_params=set_params, solver=solver, - obj_fun=maximize_csd, max_iter=max_iter) + constraints.update({'mu': (1, 200), 'sigma': (1, 15)}) + + optim = Optimizer( + net_offset, + tstop=tstop, + constraints=constraints, + set_params=set_params, + solver=solver, + obj_fun=maximize_csd, + max_iter=max_iter, + ) # test exception raised - with pytest.raises(ValueError, match='The current Network instance has ' - 'external drives, provide a Network object with no ' - 'external drives.'): + with pytest.raises( + ValueError, + match='The current Network instance has ' + 'external drives, provide a Network object with no ' + 'external drives.', + ): net_with_drives = jones_2009_model(add_drives_from_params=True) - optim = Optimizer(net_with_drives, - tstop=tstop, - constraints=constraints, - set_params=set_params, - solver=solver, - obj_fun=maximize_csd, - max_iter=max_iter) + optim = Optimizer( 
+ net_with_drives, + tstop=tstop, + constraints=constraints, + set_params=set_params, + solver=solver, + obj_fun=maximize_csd, + max_iter=max_iter, + ) # test repr before fitting - assert 'fit=False' in repr(optim), "optimizer is already fit" + assert 'fit=False' in repr(optim), 'optimizer is already fit' # increase power in infragranular layers (100-150 ms) - optim.fit(t_bands=[(100, 150),], electrode_depths=[(0, 200),]) + optim.fit( + t_bands=[ + (100, 150), + ], + electrode_depths=[ + (0, 200), + ], + ) # test repr after fitting - assert 'fit=True' in repr(optim), "optimizer was not fit" + assert 'fit=True' in repr(optim), 'optimizer was not fit' # the optimized parameter is in the range for param_idx, param in enumerate(optim.opt_params_): - assert (list(constraints.values())[param_idx][0] <= param <= - list(constraints.values())[param_idx][1]), ( - "Optimized parameter is not in user-defined range") + assert ( + list(constraints.values())[param_idx][0] + <= param + <= list(constraints.values())[param_idx][1] + ), 'Optimized parameter is not in user-defined range' diff --git a/hnn_core/tests/test_gui.py b/hnn_core/tests/test_gui.py index 1349c6629..6848998e0 100644 --- a/hnn_core/tests/test_gui.py +++ b/hnn_core/tests/test_gui.py @@ -14,15 +14,19 @@ from pathlib import Path from hnn_core import Dipole, Network from hnn_core.gui import HNNGUI -from hnn_core.gui._viz_manager import (_idx2figname, - _plot_types, - _no_overlay_plot_types, - unlink_relink) -from hnn_core.gui.gui import (_init_network_from_widgets, - _prepare_upload_file, - _update_nested_dict, - serialize_simulation, - serialize_config) +from hnn_core.gui._viz_manager import ( + _idx2figname, + _plot_types, + _no_overlay_plot_types, + unlink_relink, +) +from hnn_core.gui.gui import ( + _init_network_from_widgets, + _prepare_upload_file, + _update_nested_dict, + serialize_simulation, + serialize_config, +) from hnn_core.network import pick_connection, _compare_lists from hnn_core.parallel_backends import requires_mpi4py, requires_psutil from hnn_core.hnn_io import dict_to_network, read_network_configuration @@ -36,9 +40,7 @@ @pytest.fixture def setup_gui(): - gui = HNNGUI( - network_configuration=assets_path / 'jones2009_3x3_drives.json' - ) + gui = HNNGUI(network_configuration=assets_path / 'jones2009_3x3_drives.json') gui.compose() gui.widget_dt.value = 0.5 # speed up tests gui.widget_tstop.value = 70 # speed up tests @@ -53,6 +55,7 @@ def check_equal_networks(net1, net2): comparing GUI-derived networks to API-derived networks. 
This function adapts the __eq__ """ + def check_equality(item1, item2, message=None): assert item1 == item2, message @@ -61,8 +64,7 @@ def check_equality(item1, item2, message=None): def check_items(dict1, dict2, ignore_keys=[], message=''): for d_key, d_value in dict1.items(): if d_key not in ignore_keys: - check_equality(d_value, dict2[d_key], - f'{message}{d_key} not equal') + check_equality(d_value, dict2[d_key], f'{message}{d_key} not equal') def check_drive(drive1, drive2, keys): name = drive1['name'] @@ -70,28 +72,31 @@ def check_drive(drive1, drive2, keys): value1 = drive1[key] value2 = drive2[key] if key != 'dynamics': - check_equality(value1, value2, - f'>{name}>{key} not equal') + check_equality(value1, value2, f'>{name}>{key} not equal') else: - check_items(value1, value2, ignore_keys=['tstop'], - message=f'>{name}>{key}>') + check_items( + value1, value2, ignore_keys=['tstop'], message=f'>{name}>{key}>' + ) # Check connectivity assert len(net1.connectivity) == len(net2.connectivity) assert _compare_lists(net1.connectivity, net2.connectivity) # Check drives - for drive1, drive2 in zip(net1.external_drives.values(), - net2.external_drives.values()): + for drive1, drive2 in zip( + net1.external_drives.values(), net2.external_drives.values() + ): check_drive(drive1, drive2, keys=drive1.keys()) # Check external biases for bias_name, bias_dict in net1.external_biases.items(): for cell_type, bias_params in bias_dict.items(): - check_items(bias_params, - net2.external_biases[bias_name][cell_type], - ignore_keys=['tstop'], - message=f'{bias_name}>{cell_type}>') + check_items( + bias_params, + net2.external_biases[bias_name][cell_type], + ignore_keys=['tstop'], + message=f'{bias_name}>{cell_type}>', + ) # Check all other attributes attrs_to_ignore = ['connectivity', 'external_drives', 'external_biases'] @@ -99,8 +104,7 @@ def check_drive(drive1, drive2, keys): if attr.startswith('_') or attr in attrs_to_ignore: continue - check_equality(getattr(net1, attr), getattr(net2, attr), - f'{attr} not equal') + check_equality(getattr(net1, attr), getattr(net2, attr), f'{attr} not equal') def test_gui_load_params(): @@ -122,23 +126,20 @@ def test_gui_compose(): def test_prepare_upload_file(): """Tests that input files from local or url sources import correctly""" + def _import_json(content): - decode = codecs.decode(content, encoding="utf-8") + decode = codecs.decode(content, encoding='utf-8') json_content = json.load(io.StringIO(decode)) return json_content - url = "https://raw.githubusercontent.com/jonescompneurolab/hnn-core/master/hnn_core/param/default.json" # noqa + url = 'https://raw.githubusercontent.com/jonescompneurolab/hnn-core/master/hnn_core/param/default.json' # noqa file = Path(hnn_core_root, 'param', 'default.json') content_from_url = _prepare_upload_file(url)[0] content_from_local = _prepare_upload_file(file)[0] - assert (content_from_url['name'] == - content_from_local['name'] == - 'default.json') - assert (content_from_url['type'] == - content_from_local['type'] == - 'application/json') + assert content_from_url['name'] == content_from_local['name'] == 'default.json' + assert content_from_url['type'] == content_from_local['type'] == 'application/json' # Check that the size attribute is present. Cannot do an equivalency check # because file systems may add additional bytes when saving to disk. assert 'size' in content_from_url @@ -218,7 +219,7 @@ def test_gui_upload_drives(): # the drive tstop gets set to the widget value if the drive tstop is larger # than the widget tstop.
This may change in the future if tstop is saved to # the network configs. - assert gui.drive_widgets[0]['tstop'].value == 170. + assert gui.drive_widgets[0]['tstop'].value == 170.0 # Load connectivity and make sure drives did not change gui._simulate_upload_connectivity(file1_url) @@ -228,8 +229,14 @@ def test_gui_upload_drives(): gui.delete_drive_button.click() gui._simulate_upload_drives(file3_url) drive_names = [widget['name'] for widget in gui.drive_widgets] - assert drive_names == ['evdist1', 'evprox1', 'evprox2', - 'alpha_prox', 'poisson', 'tonic'] + assert drive_names == [ + 'evdist1', + 'evprox1', + 'evprox2', + 'alpha_prox', + 'poisson', + 'tonic', + ] # Check for correct tonic bias loading assert gui.drive_widgets[5]['type'] == 'Tonic' @@ -248,8 +255,8 @@ def test_gui_upload_data(): assert len(gui.viz_manager.data['figs']) == 0 assert len(gui.data['simulation_data']) == 0 - file1_url = "https://raw.githubusercontent.com/jonescompneurolab/hnn/master/data/MEG_detection_data/S1_SupraT.txt" # noqa - file2_url = "https://raw.githubusercontent.com/jonescompneurolab/hnn/master/data/MEG_detection_data/yes_trial_S1_ERP_all_avg.txt" # noqa + file1_url = 'https://raw.githubusercontent.com/jonescompneurolab/hnn/master/data/MEG_detection_data/S1_SupraT.txt' # noqa + file2_url = 'https://raw.githubusercontent.com/jonescompneurolab/hnn/master/data/MEG_detection_data/yes_trial_S1_ERP_all_avg.txt' # noqa gui._simulate_upload_data(file1_url) assert len(gui.data['simulation_data']) == 1 assert 'S1_SupraT' in gui.data['simulation_data'].keys() @@ -267,10 +274,9 @@ def test_gui_upload_data(): assert len(gui.viz_manager.data['figs']) == 2 # No data loading for legacy multi-trial data files. - file3_url = "https://raw.githubusercontent.com/jonescompneurolab/hnn/master/data/gamma_tutorial/100_trials.txt" # noqa + file3_url = 'https://raw.githubusercontent.com/jonescompneurolab/hnn/master/data/gamma_tutorial/100_trials.txt' # noqa with pytest.raises( - ValueError, - match="Data are supposed to have 2 or 4 columns while we have 101." + ValueError, match='Data are supposed to have 2 or 4 columns while we have 101.' 
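# A small sketch of the column-count rule this test pins down, assuming
# dipole text files carry either (times, agg) or (times, agg, L2, L5)
# columns; validate_dipole_columns is a hypothetical helper, not the GUI's
# actual loader.
import numpy as np

def validate_dipole_columns(data):
    n_cols = data.shape[1]
    if n_cols not in (2, 4):
        raise ValueError(
            f'Data are supposed to have 2 or 4 columns while we have {n_cols}.'
        )
    return data

validate_dipole_columns(np.zeros((10, 4)))  # accepted
try:
    validate_dipole_columns(np.zeros((10, 101)))  # legacy multi-trial layout
except ValueError as err:
    assert '101' in str(err)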
): gui._simulate_upload_data(file3_url) assert len(gui.data['simulation_data']) == 2 @@ -295,7 +301,8 @@ def test_gui_change_connectivity(): src_gids=vbox._belongsto['src_gids'], target_gids=vbox._belongsto['target_gids'], loc=vbox._belongsto['location'], - receptor=vbox._belongsto['receptor']) + receptor=vbox._belongsto['receptor'], + ) assert len(conn_indices) > 0 conn_idx = conn_indices[0] @@ -304,17 +311,24 @@ def test_gui_change_connectivity(): vbox.children[1].value = w_val # re initialize network - _init_network_from_widgets(gui.params, gui.widget_dt, - gui.widget_tstop, - _single_simulation, - gui.drive_widgets, - gui.connectivity_widgets, - gui.cell_pameters_widgets, - add_drive=False) + _init_network_from_widgets( + gui.params, + gui.widget_dt, + gui.widget_tstop, + _single_simulation, + gui.drive_widgets, + gui.connectivity_widgets, + gui.cell_pameters_widgets, + add_drive=False, + ) # test if the new value is reflected in the network - assert (_single_simulation['net'].connectivity[conn_idx] - ['nc_dict']['A_weight'] == w_val) + assert ( + _single_simulation['net'].connectivity[conn_idx]['nc_dict'][ + 'A_weight' + ] + == w_val + ) plt.close('all') @@ -323,8 +337,8 @@ def test_gui_add_drives(): gui = HNNGUI() _ = gui.compose() - for val_drive_type in ("Poisson", "Evoked", "Rhythmic"): - for val_location in ("distal", "proximal"): + for val_drive_type in ('Poisson', 'Evoked', 'Rhythmic'): + for val_location in ('distal', 'proximal'): gui.delete_drive_button.click() assert len(gui.drive_widgets) == 0 @@ -345,16 +359,21 @@ def test_gui_init_network(setup_gui): # now the default parameter has been loaded. _single_simulation = {} _single_simulation['net'] = dict_to_network(gui.params) - _init_network_from_widgets(gui.params, gui.widget_dt, gui.widget_tstop, - _single_simulation, gui.drive_widgets, - gui.connectivity_widgets, - gui.cell_pameters_widgets) + _init_network_from_widgets( + gui.params, + gui.widget_dt, + gui.widget_tstop, + _single_simulation, + gui.drive_widgets, + gui.connectivity_widgets, + gui.cell_pameters_widgets, + ) plt.close('all') net_from_gui = _single_simulation['net'] # copied from test_network.py - assert np.isclose(net_from_gui._inplane_distance, 1.) 
+ assert np.isclose(net_from_gui._inplane_distance, 1.0) assert np.isclose(net_from_gui._layer_separation, 1307.4) # Compare Network created from API @@ -373,13 +392,13 @@ def test_gui_run_simulation_mpi(): gui.widget_tstop.value = 70 gui.widget_dt.value = 0.5 - gui.widget_backend_selection.value = "MPI" + gui.widget_backend_selection.value = 'MPI' gui.widget_ntrials.value = 2 gui.run_button.click() default_name = gui.widget_simulation_name.value dpls = gui.simulation_data[default_name]['dpls'] - assert isinstance(gui.simulation_data[default_name]["net"], Network) + assert isinstance(gui.simulation_data[default_name]['net'], Network) assert isinstance(dpls, list) assert len(dpls) > 0 assert all([isinstance(dpl, Dipole) for dpl in dpls]) @@ -390,10 +409,8 @@ def test_gui_run_simulations(setup_gui): """Test if run button triggers multiple simulations correctly.""" gui = setup_gui - tstop_trials_tstep = [(10, 1, 0.25), - (10, 2, 0.5), - (12, 1, 0.5)] - assert gui.widget_backend_selection.value == "Joblib" + tstop_trials_tstep = [(10, 1, 0.25), (10, 2, 0.5), (12, 1, 0.5)] + assert gui.widget_backend_selection.value == 'Joblib' sim_count = 0 for val_tstop, val_ntrials, val_tstep in tstop_trials_tstep: @@ -406,18 +423,14 @@ def test_gui_run_simulations(setup_gui): sim_name = gui.widget_simulation_name.value dpls = gui.simulation_data[sim_name]['dpls'] - assert isinstance(gui.simulation_data[sim_name]["net"], - Network) + assert isinstance(gui.simulation_data[sim_name]['net'], Network) assert isinstance(dpls, list) assert all([isinstance(dpl, Dipole) for dpl in dpls]) assert len(dpls) == val_ntrials - assert all([ - pytest.approx(dpl.times[-1]) == val_tstop for dpl in dpls - ]) - assert all([ - pytest.approx(dpl.times[1] - dpl.times[0]) == val_tstep - for dpl in dpls - ]) + assert all([pytest.approx(dpl.times[-1]) == val_tstop for dpl in dpls]) + assert all( + [pytest.approx(dpl.times[1] - dpl.times[0]) == val_tstep for dpl in dpls] + ) sim_count += 1 @@ -425,23 +438,23 @@ def test_gui_run_simulations(setup_gui): def test_non_unique_name_error(setup_gui): - """ Checks that simulation fails if new name is not supplied. 
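# Worked check of the time-axis arithmetic asserted in
# test_gui_run_simulations above: each dipole is expected to end at tstop
# with samples spaced by the chosen dt. The np.arange grid here is an
# illustrative stand-in for the solver's actual time axis.
import numpy as np

for tstop, _ntrials, dt in [(10, 1, 0.25), (10, 2, 0.5), (12, 1, 0.5)]:
    times = np.arange(0.0, tstop + dt, dt)
    assert np.isclose(times[-1], tstop)
    assert np.isclose(times[1] - times[0], dt)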
""" + """Checks that simulation fails if new name is not supplied.""" gui = setup_gui sim_name = gui.widget_simulation_name.value gui.run_button.click() dpls = gui.simulation_data[sim_name]['dpls'] - assert isinstance(gui.simulation_data[sim_name]["net"], Network) + assert isinstance(gui.simulation_data[sim_name]['net'], Network) assert isinstance(dpls, list) - assert gui._simulation_status_bar.value == \ - gui._simulation_status_contents['finished'] + assert ( + gui._simulation_status_bar.value == gui._simulation_status_contents['finished'] + ) gui.widget_simulation_name.value = sim_name gui.run_button.click() assert len(gui.simulation_data) == 1 - assert gui._simulation_status_bar.value == \ - gui._simulation_status_contents['failed'] + assert gui._simulation_status_bar.value == gui._simulation_status_contents['failed'] plt.close('all') @@ -451,7 +464,7 @@ def test_gui_take_screenshots(): gui.compose(return_layout=False) screenshot = gui.capture(render=False) assert type(screenshot) is IFrame - gui._simulate_left_tab_click("External drives") + gui._simulate_left_tab_click('External drives') screenshot1 = gui.capture(render=False) assert screenshot._repr_html_() != screenshot1._repr_html_() plt.close('all') @@ -497,8 +510,7 @@ def test_gui_add_figure(setup_gui): fig_tabs.get_title(idx) for idx in range(len(fig_tabs.children)) ] remaining_titles2 = [ - axes_config_tabs.get_title(idx) - for idx in range(len(axes_config_tabs.children)) + axes_config_tabs.get_title(idx) for idx in range(len(axes_config_tabs.children)) ] correct_remaining_titles = [_idx2figname(idx) for idx in (1, 3, 4)] assert remaining_titles1 == remaining_titles2 == correct_remaining_titles @@ -522,13 +534,15 @@ def test_gui_add_data_dependent_figure(setup_gui): assert len(axes_config_tabs.children) == 1 assert gui.viz_manager.fig_idx['idx'] == 2 - template_names = [('Drive-Dipole (2x1)', 2), - ('Dipole Layers (3x1)', 3), - ('Drive-Spikes (2x1)', 2), - ('Dipole-Spectrogram (2x1)', 2), - ("Dipole-Spikes (2x1)", 2), - ('Drive-Dipole-Spectrogram (3x1)', 3), - ('PSD Layers (3x1)', 3)] + template_names = [ + ('Drive-Dipole (2x1)', 2), + ('Dipole Layers (3x1)', 3), + ('Drive-Spikes (2x1)', 2), + ('Dipole-Spectrogram (2x1)', 2), + ('Dipole-Spikes (2x1)', 2), + ('Drive-Dipole-Spectrogram (3x1)', 3), + ('PSD Layers (3x1)', 3), + ] n_fig = 1 for template_name, num_axes in template_names: @@ -552,7 +566,7 @@ def test_gui_edit_figure(setup_gui): axes_config_tabs = gui.viz_manager.axes_config_tabs # after each run we should have a default fig - sim_names = ["t1", "t2", "t3"] + sim_names = ['t1', 't2', 't3'] for sim_idx, sim_name in enumerate(sim_names): gui.widget_simulation_name.value = sim_name gui.run_button.click() @@ -586,13 +600,13 @@ def test_gui_synchronous_inputs(setup_gui): # Run simulation gui.run_button.click() - sim = (gui.viz_manager.data - ['simulations'][gui.widget_simulation_name.value]) + sim = gui.viz_manager.data['simulations'][gui.widget_simulation_name.value] # Filter connections for specific driver_name first network_connections = sim['net'].connectivity - driver_connections = [conn for conn in network_connections - if conn['src_type'] == drive_name] + driver_connections = [ + conn for conn in network_connections if conn['src_type'] == drive_name + ] # Check src_gids length for connectivity in driver_connections: @@ -616,13 +630,13 @@ def test_gui_cell_specific_drive(setup_gui): # Filter connections for specific driver_name first network_connections = sim['net'].connectivity - driver_connections = [conn for conn 
in network_connections - if conn['src_type'] == driver_name] + driver_connections = [ + conn for conn in network_connections if conn['src_type'] == driver_name + ] # Check src_gids length for connectivity in driver_connections: - assert (len(connectivity['src_gids']) == - len(connectivity['target_gids'])) + assert len(connectivity['src_gids']) == len(connectivity['target_gids']) def test_gui_figure_overlay(setup_gui): @@ -661,41 +675,48 @@ def test_gui_visualization(setup_gui): gui.widget_tstop.value = 500 gui.run_button.click() - gui._simulate_viz_action("switch_fig_template", "[Blank] single figure") - gui._simulate_viz_action("add_fig") + gui._simulate_viz_action('switch_fig_template', '[Blank] single figure') + gui._simulate_viz_action('add_fig') figid = 2 figname = f'Figure {figid}' axname = 'ax0' for viz_type in _plot_types: - gui._simulate_viz_action("edit_figure", figname, - axname, 'default', viz_type, {}, 'clear') + gui._simulate_viz_action( + 'edit_figure', figname, axname, 'default', viz_type, {}, 'clear' + ) # Check that extra axes have been successfully removed assert len(gui.viz_manager.figs[figid].axes) == 1 # Check if data on the axes has been successfully cleared assert not gui.viz_manager.figs[figid].axes[0].has_data() - gui._simulate_viz_action("edit_figure", figname, - axname, 'default', viz_type, {}, 'plot') + gui._simulate_viz_action( + 'edit_figure', figname, axname, 'default', viz_type, {}, 'plot' + ) # Check if data is plotted on the axes assert gui.viz_manager.figs[figid].axes[0].has_data() - if viz_type == "input histogram": + if viz_type == 'input histogram': # Check if the correct number of axes are present # "input histogram" is a special case due to "plot_spikes_hist" # using 2 axes assert len(gui.viz_manager.figs[figid].axes) == 2 - elif viz_type == "spectrogram": + elif viz_type == 'spectrogram': # make sure the colorbar is correctly added - assert any(['_cbar-ax-' in attr - for attr in dir(gui.viz_manager.figs[figid])]) is True + assert ( + any(['_cbar-ax-' in attr for attr in dir(gui.viz_manager.figs[figid])]) + is True + ) assert len(gui.viz_manager.figs[figid].axes) == 2 # make sure the colorbar is safely removed - gui._simulate_viz_action("edit_figure", figname, axname, 'default', - 'spectrogram', {}, 'clear') - assert any(['_cbar-ax-' in attr - for attr in dir(gui.viz_manager.figs[figid])]) is False + gui._simulate_viz_action( + 'edit_figure', figname, axname, 'default', 'spectrogram', {}, 'clear' + ) + assert ( + any(['_cbar-ax-' in attr for attr in dir(gui.viz_manager.figs[figid])]) + is False + ) assert len(gui.viz_manager.figs[figid].axes) == 1 else: @@ -721,12 +742,18 @@ def test_dipole_data_overlay(setup_gui): figid = 1 figname = f'Figure {figid}' axname = 'ax1' - gui._simulate_viz_action("edit_figure", figname, - axname, 'default', 'current dipole', {}, 'clear') - gui._simulate_viz_action("edit_figure", figname, - axname, 'default', 'current dipole', - {'data_to_compare': 'test_default'}, - 'plot') + gui._simulate_viz_action( + 'edit_figure', figname, axname, 'default', 'current dipole', {}, 'clear' + ) + gui._simulate_viz_action( + 'edit_figure', + figname, + axname, + 'default', + 'current dipole', + {'data_to_compare': 'test_default'}, + 'plot', + ) ax = gui.viz_manager.figs[figid].axes[1] # Check number of lines @@ -759,12 +786,10 @@ def __init__(self): def add_child(self, to_add=1): n_tabs = len(self.tab_group_2.children) + to_add # Add tab and select latest tab - self.tab_group_1.children = \ - [Text(f'Test{s}') for s in 
np.arange(n_tabs)] + self.tab_group_1.children = [Text(f'Test{s}') for s in np.arange(n_tabs)] self.tab_group_1.selected_index = n_tabs - 1 - self.tab_group_2.children = \ - [Text(f'Test{s}') for s in np.arange(n_tabs)] + self.tab_group_2.children = [Text(f'Test{s}') for s in np.arange(n_tabs)] self.tab_group_2.selected_index = n_tabs - 1 @unlink_relink(attribute='tab_link') @@ -799,37 +824,43 @@ def test_gui_download_simulation(setup_gui): gui.widget_ntrials.value = 2 # Initiate 1st simulation - sim_name = "sim1" + sim_name = 'sim1' gui.widget_simulation_name.value = sim_name # Run simulation gui.run_button.click() - _, file_extension = ( - serialize_simulation(gui.data, sim_name)) + _, file_extension = serialize_simulation(gui.data, sim_name) # result is a zip file - assert file_extension == ".zip" + assert file_extension == '.zip' # Run a simulation with 1 trial gui.widget_ntrials.value = 1 # Initiate 2nd simulation - sim_name2 = "sim2" + sim_name2 = 'sim2' gui.widget_simulation_name.value = sim_name2 # Run simulation gui.run_button.click() - _, file_extension = ( - serialize_simulation(gui.data, sim_name2)) + _, file_extension = serialize_simulation(gui.data, sim_name2) # result is a single csv file - assert file_extension == ".csv" + assert file_extension == '.csv' # Check no loaded data is listed in the sims dropdown list to download - file1_url = "https://raw.githubusercontent.com/jonescompneurolab/hnn/master/data/MEG_detection_data/S1_SupraT.txt" # noqa + file1_url = 'https://raw.githubusercontent.com/jonescompneurolab/hnn/master/data/MEG_detection_data/S1_SupraT.txt' # noqa gui._simulate_upload_data(file1_url) download_simulation_list = gui.simulation_list_widget.options - assert (len([sim_name for sim_name in download_simulation_list - if sim_name == "S1_SupraT"]) == 0) + assert ( + len( + [ + sim_name + for sim_name in download_simulation_list + if sim_name == 'S1_SupraT' + ] + ) + == 0 + ) def test_gui_upload_csv_simulation(setup_gui): @@ -854,20 +885,25 @@ def test_gui_upload_csv_simulation(setup_gui): # we are loading only 1 trial, # assume all the data we need is in the [0] position - data_lengh = ( - len(gui.data['simulation_data']['test_default']['dpls'][0].times)) + data_length = len(gui.data['simulation_data']['test_default']['dpls'][0].times) assert len(gui.data['simulation_data']) == 1 assert 'test_default' in gui.data['simulation_data'].keys() assert gui.data['simulation_data']['test_default']['net'] is None assert type(gui.data['simulation_data']['test_default']['dpls']) is list assert len(gui.viz_manager.data['figs']) == 1 - assert (len(gui.data['simulation_data']['test_default'] - ['dpls'][0].data['agg']) == data_lengh) - assert (len(gui.data['simulation_data']['test_default'] - ['dpls'][0].data['L2']) == data_lengh) - assert (len(gui.data['simulation_data']['test_default'] - ['dpls'][0].data['L5']) == data_lengh) + assert ( + len(gui.data['simulation_data']['test_default']['dpls'][0].data['agg']) + == data_length + ) + assert ( + len(gui.data['simulation_data']['test_default']['dpls'][0].data['L2']) + == data_length + ) + assert ( + len(gui.data['simulation_data']['test_default']['dpls'][0].data['L5']) + == data_length + ) def test_gui_download_configuration(setup_gui): @@ -876,7 +912,7 @@ def test_gui_download_configuration(setup_gui): gui = setup_gui # Initiate 1st simulation - sim_name = "sim1" + sim_name = 'sim1' gui.widget_simulation_name.value = sim_name # Run simulation @@ -902,19 +938,18 @@ def test_gui_add_tonic_input(): """Test if the GUI adds different types of
drives.""" gui = HNNGUI() _ = gui.compose() - assert 'tonic' not in [drive['type'].lower() - for drive in gui.drive_widgets] + assert 'tonic' not in [drive['type'].lower() for drive in gui.drive_widgets] _single_simulation = {} _single_simulation['net'] = dict_to_network(gui.params) # Add tonic input widget - gui.widget_drive_type_selection.value = "Tonic" + gui.widget_drive_type_selection.value = 'Tonic' gui.add_drive_button.click() # Check last drive (Tonic) last_drive = gui.drive_widgets[-1] - assert last_drive['type'] == "Tonic" + assert last_drive['type'] == 'Tonic' assert last_drive['t0'].value == 0.0 assert last_drive['tstop'].value == 170.0 assert last_drive['amplitude']['L5_pyramidal'].value == 0 @@ -926,14 +961,23 @@ def test_gui_add_tonic_input(): # Check that you can't add more than one tonic gui.add_drive_button.click() - assert ([drive['type'].lower() for drive in gui.drive_widgets] == - ['evoked', 'evoked', 'evoked', 'tonic']) + assert [drive['type'].lower() for drive in gui.drive_widgets] == [ + 'evoked', + 'evoked', + 'evoked', + 'tonic', + ] # Add tonic bias to the network - _init_network_from_widgets(gui.params, gui.widget_dt, gui.widget_tstop, - _single_simulation, gui.drive_widgets, - gui.connectivity_widgets, - gui.cell_pameters_widgets) + _init_network_from_widgets( + gui.params, + gui.widget_dt, + gui.widget_tstop, + _single_simulation, + gui.drive_widgets, + gui.connectivity_widgets, + gui.cell_pameters_widgets, + ) net = _single_simulation['net'] assert net.external_biases['tonic'] is not None @@ -948,17 +992,19 @@ def test_gui_cell_params_widgets(setup_gui): _single_simulation = {} _single_simulation['net'] = dict_to_network(gui.params) _single_simulation['net'].cell_types - pyramid_cell_types = [cell_type for cell_type - in _single_simulation['net'].cell_types - if "pyramidal" in cell_type] - assert (len(pyramid_cell_types) == 2) + pyramid_cell_types = [ + cell_type + for cell_type in _single_simulation['net'].cell_types + if 'pyramidal' in cell_type + ] + assert len(pyramid_cell_types) == 2 # Security check for if parameters have been added or removed from the cell # params dict. 
Any additions will need mappings added to the # update_{*}_cell_params functions layers = gui.cell_layer_radio_buttons.options - assert (len(layers) == 3) + assert len(layers) == 3 keys = gui.cell_pameters_widgets.keys() num_cell_params = 0 @@ -966,18 +1012,18 @@ def test_gui_cell_params_widgets(setup_gui): cell_type = pyramid_cell_type.split('_')[0] for cell_layer in layers: key = f'{cell_type} Pyramidal_{cell_layer}' - assert (any(key in k for k in keys)) + assert any(key in k for k in keys) num_cell_params += 1 - assert (len(keys) == num_cell_params) + assert len(keys) == num_cell_params # Check if the cell params dictionary has been updated cell_params = gui.get_cell_parameters_dict() - assert (len(cell_params['Geometry L2']) == 20) - assert (len(cell_params['Geometry L5']) == 22) - assert (len(cell_params['Synapses']) == 12) - assert (len(cell_params['Biophysics L2']) == 10) - assert (len(cell_params['Biophysics L5']) == 20) + assert len(cell_params['Geometry L2']) == 20 + assert len(cell_params['Geometry L5']) == 22 + assert len(cell_params['Synapses']) == 12 + assert len(cell_params['Biophysics L2']) == 10 + assert len(cell_params['Biophysics L5']) == 20 def test_fig_tabs_dropdown_lists(setup_gui): @@ -988,14 +1034,14 @@ def test_fig_tabs_dropdown_lists(setup_gui): gui.widget_ntrials.value = 1 # Initiate 1st simulation - sim_name = "sim1" + sim_name = 'sim1' gui.widget_simulation_name.value = sim_name # Run simulation gui.run_button.click() # Initiate 2nd simulation - sim_name2 = "sim2" + sim_name2 = 'sim2' gui.widget_simulation_name.value = sim_name2 # Run simulation @@ -1005,93 +1051,87 @@ for tab in viz_tabs: controls = tab.children[1] for ax_control in controls.children: - assert ax_control.children[1].description == "Simulation Data:" + assert ax_control.children[1].description == 'Simulation Data:' sim_names = ax_control.children[1].options # Check that dropdown has been updated with all simulation names assert all(sim in sim_names for sim in [sim_name, sim_name2]) - assert ax_control.children[4].description == "Data to Compare:" + assert ax_control.children[4].description == 'Data to Compare:' # Check the data to compare dropdown is enabled for # non "input histograms" plot type - if ax_control.children[0].value != "input histogram": + if ax_control.children[0].value != 'input histogram': assert not ax_control.children[4].disabled def test_update_nested_dict(): """Tests that nested dictionary updates values appropriately.""" - original = {'a': 1, - 'b': {'a2': 0, - 'b2': {'a3': 0 - } - }, - } + original = { + 'a': 1, + 'b': {'a2': 0, 'b2': {'a3': 0}}, + } # Changes at each level - changes = {'a': 2, - 'b': {'a2': 1, - 'b2': {'a3': 1 - } - }, - } + changes = { + 'a': 2, + 'b': {'a2': 1, 'b2': {'a3': 1}}, + } updated = _update_nested_dict(original, changes) expected = changes assert updated == expected # Omitted items should not be changed from the original - omission = {'a': 2, - 'b': {'a2': 0}, - } - expected = {'a': 2, - 'b': {'a2': 0, - 'b2': {'a3': 0 - } - }, - } + omission = { + 'a': 2, + 'b': {'a2': 0}, + } + expected = { + 'a': 2, + 'b': {'a2': 0, 'b2': {'a3': 0}}, + } updated = _update_nested_dict(original, omission) assert updated == expected # Additional items should be added - addition = {'a': 2, - 'b': {'a2': 0, - 'b2': {'a3': 0, - 'b3': 0, - }, - 'c2': 1 - }, - 'c': 1 - } + addition = { + 'a': 2, + 'b': { + 'a2': 0, + 'b2': { + 'a3': 0, + 'b3': 0, + }, + 'c2': 1, + }, + 'c': 1, + } expected = addition updated = 
_update_nested_dict(original, addition) assert updated == expected # Test passing of None values - has_none = {'a': 1, - 'b': {'a2': None}, - } + has_none = { + 'a': 1, + 'b': {'a2': None}, + } # Default behavior will not pass in None values to the update expected = original # No change expected updated = _update_nested_dict(original, has_none) assert updated == expected # skip_none set to False will pass in None values to the update updated = _update_nested_dict(original, has_none, skip_none=False) - expected = {'a': 1, - 'b': {'a2': None, - 'b2': {'a3': 0 - } - }, - } + expected = { + 'a': 1, + 'b': {'a2': None, 'b2': {'a3': 0}}, + } assert updated == expected # Values that evaluate to False but are not None should be passed # to the updated dict by default. - has_nulls = {'a': 0, - 'b': {'a2': np.nan, - 'b2': {'a3': False, - 'b3': '' - } - }, - } + has_nulls = { + 'a': 0, + 'b': {'a2': np.nan, 'b2': {'a3': False, 'b3': ''}}, + } # Default behavior will pass these non-None values to the update updated = _update_nested_dict(original, has_nulls) expected = has_nulls @@ -1102,17 +1142,21 @@ def test_delete_single_drive(setup_gui): """Deleting a single drive.""" gui = setup_gui assert len(gui.drive_accordion.children) == 6 - assert gui.drive_accordion.titles == ('evdist1 (distal)', - 'evprox1 (proximal)', - 'evprox2 (proximal)', - 'alpha_prox (proximal)', - 'poisson (proximal)', - 'tonic') + assert gui.drive_accordion.titles == ( + 'evdist1 (distal)', + 'evprox1 (proximal)', + 'evprox2 (proximal)', + 'alpha_prox (proximal)', + 'poisson (proximal)', + 'tonic', + ) gui._simulate_delete_single_drive(2) assert len(gui.drive_accordion.children) == 5 - assert gui.drive_accordion.titles == ('evdist1 (distal)', - 'evprox1 (proximal)', - 'alpha_prox (proximal)', - 'poisson (proximal)', - 'tonic') + assert gui.drive_accordion.titles == ( + 'evdist1 (distal)', + 'evprox1 (proximal)', + 'alpha_prox (proximal)', + 'poisson (proximal)', + 'tonic', + ) diff --git a/hnn_core/tests/test_io.py b/hnn_core/tests/test_io.py index 54d32f212..e421aa0df 100644 --- a/hnn_core/tests/test_io.py +++ b/hnn_core/tests/test_io.py @@ -7,15 +7,22 @@ import numpy as np import json -from hnn_core import (simulate_dipole, read_params, - jones_2009_model, calcium_model, - ) - -from hnn_core.hnn_io import (_cell_response_to_dict, _rec_array_to_dict, - _external_drive_to_dict, _str_to_node, - _conn_to_dict, _order_drives, - read_network_configuration - ) +from hnn_core import ( + simulate_dipole, + read_params, + jones_2009_model, + calcium_model, +) + +from hnn_core.hnn_io import ( + _cell_response_to_dict, + _rec_array_to_dict, + _external_drive_to_dict, + _str_to_node, + _conn_to_dict, + _order_drives, + read_network_configuration, +) hnn_core_root = Path(__file__).parents[1] assets_path = Path(hnn_core_root, 'tests', 'assets') @@ -33,42 +40,83 @@ def params(): @pytest.fixture def jones_2009_network(params): - # Instantiating network along with drives - net = jones_2009_model(params=params, add_drives_from_params=True, - mesh_shape=(3, 3)) + net = jones_2009_model( + params=params, add_drives_from_params=True, mesh_shape=(3, 3) + ) # Adding bias - tonic_bias = {'L2_pyramidal': 1.0, 'L5_pyramidal': 0.0, - 'L2_basket': 0.0, 'L5_basket': 0.0} + tonic_bias = { + 'L2_pyramidal': 1.0, + 'L5_pyramidal': 0.0, + 'L2_basket': 0.0, + 'L5_basket': 0.0, + } net.add_tonic_bias(amplitude=tonic_bias) # Add drives location = 'proximal' burst_std = 20 - weights_ampa_p = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5, - 'L2_basket': 
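# A hedged sketch of the merge semantics that test_update_nested_dict pins
# down above; the real _update_nested_dict lives in hnn_core.gui.gui and may
# differ in detail.
def update_nested_dict_sketch(original, changes, skip_none=True):
    updated = dict(original)
    for key, value in changes.items():
        if isinstance(value, dict) and isinstance(updated.get(key), dict):
            # recurse so omitted nested keys survive the update
            updated[key] = update_nested_dict_sketch(updated[key], value, skip_none)
        elif value is not None or not skip_none:
            # None is skipped by default; falsy non-None values pass through
            updated[key] = value
    return updated

assert update_nested_dict_sketch({'a': 1, 'b': {'a2': 0}}, {'b': {'a2': None}}) == {
    'a': 1,
    'b': {'a2': 0},
}
assert update_nested_dict_sketch({'a': 1}, {'a': None}, skip_none=False) == {'a': None}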
0.0, 'L5_basket': 0.0} - weights_nmda_p = {'L2_pyramidal': 0.0, 'L5_pyramidal': 0.0, - 'L2_basket': 0.0, 'L5_basket': 0.0} - syn_delays_p = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1., - 'L2_basket': 0.0, 'L5_basket': 0.0} + weights_ampa_p = { + 'L2_pyramidal': 5.4e-5, + 'L5_pyramidal': 5.4e-5, + 'L2_basket': 0.0, + 'L5_basket': 0.0, + } + weights_nmda_p = { + 'L2_pyramidal': 0.0, + 'L5_pyramidal': 0.0, + 'L2_basket': 0.0, + 'L5_basket': 0.0, + } + syn_delays_p = { + 'L2_pyramidal': 0.1, + 'L5_pyramidal': 1.0, + 'L2_basket': 0.0, + 'L5_basket': 0.0, + } net.add_bursty_drive( - 'alpha_prox', tstart=1., burst_rate=10, burst_std=burst_std, - numspikes=2, spike_isi=10, n_drive_cells=10, location=location, - weights_ampa=weights_ampa_p, weights_nmda=weights_nmda_p, - synaptic_delays=syn_delays_p, event_seed=284) - - weights_ampa = {'L2_pyramidal': 0.0008, 'L5_pyramidal': 0.0075, - 'L2_basket': 0.0, 'L5_basket': 0.0} - synaptic_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.0, - 'L2_basket': 0.0, 'L5_basket': 0.0} - rate_constant = {'L2_pyramidal': 140.0, 'L5_pyramidal': 40.0, - 'L2_basket': 40.0, 'L5_basket': 40.0} + 'alpha_prox', + tstart=1.0, + burst_rate=10, + burst_std=burst_std, + numspikes=2, + spike_isi=10, + n_drive_cells=10, + location=location, + weights_ampa=weights_ampa_p, + weights_nmda=weights_nmda_p, + synaptic_delays=syn_delays_p, + event_seed=284, + ) + + weights_ampa = { + 'L2_pyramidal': 0.0008, + 'L5_pyramidal': 0.0075, + 'L2_basket': 0.0, + 'L5_basket': 0.0, + } + synaptic_delays = { + 'L2_pyramidal': 0.1, + 'L5_pyramidal': 1.0, + 'L2_basket': 0.0, + 'L5_basket': 0.0, + } + rate_constant = { + 'L2_pyramidal': 140.0, + 'L5_pyramidal': 40.0, + 'L2_basket': 40.0, + 'L5_basket': 40.0, + } net.add_poisson_drive( - 'poisson', rate_constant=rate_constant, - weights_ampa=weights_ampa, weights_nmda=weights_nmda_p, - location='proximal', synaptic_delays=synaptic_delays, - event_seed=1349) + 'poisson', + rate_constant=rate_constant, + weights_ampa=weights_ampa, + weights_nmda=weights_nmda_p, + location='proximal', + synaptic_delays=synaptic_delays, + event_seed=1349, + ) # Adding electrode arrays electrode_pos = (1, 2, 3) @@ -82,13 +130,10 @@ def jones_2009_network(params): @pytest.fixture def calcium_network(params): # Instantiating network along with drives - net = calcium_model(params=params, add_drives_from_params=True, - mesh_shape=(3, 3)) + net = calcium_model(params=params, add_drives_from_params=True, mesh_shape=(3, 3)) # Adding bias - tonic_bias = { - 'L2_pyramidal': 1.0 - } + tonic_bias = {'L2_pyramidal': 1.0} net.add_tonic_bias(amplitude=tonic_bias) # Adding electrode arrays @@ -101,7 +146,7 @@ def calcium_network(params): def generate_test_files(jones_2009_network): - """ Generates files used in read-in tests """ + """Generates files used in read-in tests""" net = jones_2009_network net.write_configuration(Path('.', 'assets/jones2009_3x3_drives.json')) @@ -127,8 +172,7 @@ def test_eq(jones_2009_network, calcium_network): # Hardwired change in drive weights net1_hard_change_drive = net1.copy() - (net1_hard_change_drive.external_drives['evdist1']['weights_ampa'] - ['L2_basket']) = 0 + (net1_hard_change_drive.external_drives['evdist1']['weights_ampa']['L2_basket']) = 0 assert net1_hard_change_drive != net1 @@ -164,7 +208,7 @@ def test_eq_conn(jones_2009_network): def test_write_configuration(tmp_path, jones_2009_network): - """ Tests that a json file is written """ + """Tests that a json file is written""" net = jones_2009_network.copy() simulate_dipole(net, tstop=2, 
n_trials=1, dt=0.5) @@ -185,28 +229,21 @@ def test_write_configuration(tmp_path, jones_2009_network): assert last_mod_time1 < last_mod_time2 # No overwrite check - with pytest.raises(FileExistsError, - match="File already exists at path "): + with pytest.raises(FileExistsError, match='File already exists at path '): jones_2009_network.write_configuration(path_out, overwrite=False) # Check no outputs were written with open(path_out) as file: read_in = json.load(file) - assert not any([bool(val['times']) - for val in read_in['rec_arrays'].values()] - ) - assert not any([bool(val['voltages']) - for val in read_in['rec_arrays'].values()] - ) - assert not any([bool(val['events']) - for val in read_in['external_drives'].values()] - ) + assert not any([bool(val['times']) for val in read_in['rec_arrays'].values()]) + assert not any([bool(val['voltages']) for val in read_in['rec_arrays'].values()]) + assert not any([bool(val['events']) for val in read_in['external_drives'].values()]) assert read_in['cell_response'] == {} def test_cell_response_to_dict(jones_2009_network): - """ Tests _cell_response_to_dict function """ + """Tests _cell_response_to_dict function""" net = jones_2009_network # When a simulation hasn't been run, return an empty dict @@ -225,19 +262,26 @@ def test_cell_response_to_dict(jones_2009_network): def test_rec_array_to_dict(jones_2009_network): - """ Tests _rec_array_to_dict function """ + """Tests _rec_array_to_dict function""" net = jones_2009_network # Check rec array times and voltages are in dict after simulation simulate_dipole(net, tstop=2, n_trials=1, dt=0.5) result = _rec_array_to_dict(net.rec_arrays['el1'], write_output=True) assert isinstance(result, dict) - assert all([key in result for key in ['positions', 'conductivity', - 'method', 'min_distance', - 'times', 'voltages' - ] - ] - ) + assert all( + [ + key in result + for key in [ + 'positions', + 'conductivity', + 'method', + 'min_distance', + 'times', + 'voltages', + ] + ] + ) assert np.array_equal(result['times'], [0.0, 0.5, 1.0, 1.5, 2.0]) assert result['voltages'].shape == (1, 1, 5) @@ -248,58 +292,72 @@ def test_rec_array_to_dict(jones_2009_network): def test_conn_to_dict(jones_2009_network): - """ Tests _connectivity_to_list_of_dicts function """ + """Tests _connectivity_to_list_of_dicts function""" net = jones_2009_network result = _conn_to_dict(net.connectivity[0]) assert isinstance(result, dict) - assert result == {'target_type': 'L2_basket', - 'target_gids': [0, 1, 2], - 'num_targets': 3, - 'src_type': 'evdist1', - 'src_gids': [24, 25, 26], - 'num_srcs': 3, - 'gid_pairs': {'24': [0], '25': [1], '26': [2]}, - 'loc': 'distal', - 'receptor': 'ampa', - 'nc_dict': {'A_delay': 0.1, - 'A_weight': 0.006562, - 'lamtha': 3.0, - 'threshold': 0.0, - 'gain': 1.0}, - 'allow_autapses': 1, - 'probability': 1.0} + assert result == { + 'target_type': 'L2_basket', + 'target_gids': [0, 1, 2], + 'num_targets': 3, + 'src_type': 'evdist1', + 'src_gids': [24, 25, 26], + 'num_srcs': 3, + 'gid_pairs': {'24': [0], '25': [1], '26': [2]}, + 'loc': 'distal', + 'receptor': 'ampa', + 'nc_dict': { + 'A_delay': 0.1, + 'A_weight': 0.006562, + 'lamtha': 3.0, + 'threshold': 0.0, + 'gain': 1.0, + }, + 'allow_autapses': 1, + 'probability': 1.0, + } def test_external_drive_to_dict(jones_2009_network): - """ Tests _external_drive_to_dict function """ + """Tests _external_drive_to_dict function""" net = jones_2009_network simulate_dipole(net, tstop=2, n_trials=1, dt=0.5) first_key = list(net.external_drives.keys())[0] - result = 
_external_drive_to_dict(net.external_drives[first_key], - write_output=True - ) + result = _external_drive_to_dict(net.external_drives[first_key], write_output=True) assert isinstance(result, dict) - assert all([key in result for key in ['type', 'location', 'n_drive_cells', - 'event_seed', 'conn_seed', - 'dynamics', 'events', 'weights_ampa', - 'weights_nmda', 'synaptic_delays', - 'probability', 'name', - 'target_types', 'cell_specific' - ] - ] - ) + assert all( + [ + key in result + for key in [ + 'type', + 'location', + 'n_drive_cells', + 'event_seed', + 'conn_seed', + 'dynamics', + 'events', + 'weights_ampa', + 'weights_nmda', + 'synaptic_delays', + 'probability', + 'name', + 'target_types', + 'cell_specific', + ] + ] + ) assert len(result['events'][0]) == 21 - result2 = _external_drive_to_dict(net.external_drives[first_key], - write_output=False - ) + result2 = _external_drive_to_dict( + net.external_drives[first_key], write_output=False + ) assert len(result2['events']) == 0 def test_str_to_node(): - """ Creates a tuple (str,int) from string with a comma """ + """Creates a tuple (str,int) from string with a comma""" result = _str_to_node('cell_name,0') assert isinstance(result, tuple) assert isinstance(result[0], str) @@ -307,44 +365,52 @@ def test_str_to_node(): def test_order_drives(jones_2009_network): - """ Reorders drive dict by ascending range order """ + """Reorders drive dict by ascending range order""" drive_names = list(jones_2009_network.external_drives.keys()) drive_names_alpha = sorted(drive_names) - drives_reordered = {name: jones_2009_network.external_drives - for name in drive_names_alpha} + drives_reordered = { + name: jones_2009_network.external_drives for name in drive_names_alpha + } - assert (list(drives_reordered.keys()) == - ['alpha_prox', 'evdist1', 'evprox1', 'evprox2', 'poisson']) + assert list(drives_reordered.keys()) == [ + 'alpha_prox', + 'evdist1', + 'evprox1', + 'evprox2', + 'poisson', + ] - drives_by_range = _order_drives(jones_2009_network.gid_ranges, - drives_reordered) + drives_by_range = _order_drives(jones_2009_network.gid_ranges, drives_reordered) - assert (list(drives_by_range.keys()) == - ['evdist1', 'evprox1', 'evprox2', 'alpha_prox', 'poisson']) + assert list(drives_by_range.keys()) == [ + 'evdist1', + 'evprox1', + 'evprox2', + 'alpha_prox', + 'poisson', + ] def test_read_configuration_json(jones_2009_network): - """ Read-in of a hdf5 file """ - net = read_network_configuration(Path(assets_path, - 'jones2009_3x3_drives.json') - ) + """Read-in of a json file""" + net = read_network_configuration(Path(assets_path, 'jones2009_3x3_drives.json')) assert net == jones_2009_network # Read without drives net_no_drives = read_network_configuration( - Path(assets_path, 'jones2009_3x3_drives.json'), - read_drives=False + Path(assets_path, 'jones2009_3x3_drives.json'), read_drives=False ) # Check there are no external drives assert len(net_no_drives.external_drives) == 0 # Check there are no external drive connections - connection_src_types = [connection['src_type'] - for connection in net_no_drives.connectivity] + connection_src_types = [ + connection['src_type'] for connection in net_no_drives.connectivity + ] - assert not any([src_type in net.external_drives.keys() - for src_type in connection_src_types]) + assert not any( + [src_type in net.external_drives.keys() for src_type in connection_src_types] + ) # Read without external bias net_no_bias = read_network_configuration( - Path(assets_path, 'jones2009_3x3_drives.json'), - 
read_external_biases=False + Path(assets_path, 'jones2009_3x3_drives.json'), read_external_biases=False ) assert len(net_no_bias.external_biases) == 0 assert len(net_no_bias.external_drives) > 0 @@ -356,11 +422,10 @@ def test_read_incorrect_format(tmp_path): # Check the error raised when the object_type field does not exist dummy_data = dict() - dummy_data['object_type'] = "NotNetwork" + dummy_data['object_type'] = 'NotNetwork' file_path = tmp_path / 'not_net.json' with open(file_path, 'w') as file: json.dump(dummy_data, file) - with pytest.raises(ValueError, - match="The json should encode a Network object."): + with pytest.raises(ValueError, match='The json should encode a Network object.'): read_network_configuration(file_path) diff --git a/hnn_core/tests/test_mpi_child.py b/hnn_core/tests/test_mpi_child.py index 98d9f2788..68817b7e9 100644 --- a/hnn_core/tests/test_mpi_child.py +++ b/hnn_core/tests/test_mpi_child.py @@ -7,19 +7,22 @@ import hnn_core from hnn_core import read_params, Network, jones_2009_model -from hnn_core.mpi_child import (MPISimulation, _str_to_net, _pickle_data) -from hnn_core.parallel_backends import (_gather_trial_data, - _process_child_data, - _echo_child_output, - _get_data_from_child_err, - _extract_data, _extract_data_length) +from hnn_core.mpi_child import MPISimulation, _str_to_net, _pickle_data +from hnn_core.parallel_backends import ( + _gather_trial_data, + _process_child_data, + _echo_child_output, + _get_data_from_child_err, + _extract_data, + _extract_data_length, +) def test_get_data_from_child_err(): """Test _get_data_from_child_err for handling stderr""" # write data to queue err_q = Queue() - test_string = "this gets printed to stdout" + test_string = 'this gets printed to stdout' err_q.put(test_string) with io.StringIO() as buf_out, redirect_stdout(buf_out): @@ -32,7 +35,7 @@ def test_echo_child_output(): """Test _echo_child_output for handling stdout, i.e. 
status messages""" # write data to queue out_q = Queue() - test_string = "Test output" + test_string = 'Test output' out_q.put(test_string) with io.StringIO() as buf_out, redirect_stdout(buf_out): @@ -46,16 +49,16 @@ def test_extract_data(): """Test _extract_data for extraction between signals""" # no ending - test_string = "@start_of_data@start of data" + test_string = '@start_of_data@start of data' output = _extract_data(test_string, 'data') assert output == '' # valid end, but no start to data - test_string = "end of data@end_of_data:11@" + test_string = 'end of data@end_of_data:11@' output = _extract_data(test_string, 'data') assert output == '' - test_string = "@start_of_data@all data@end_of_data:8@" + test_string = '@start_of_data@all data@end_of_data:8@' output = _extract_data(test_string, 'data') assert output == 'all data' @@ -63,12 +66,11 @@ def test_extract_data(): def test_extract_data_length(): """Test _extract_data_length for data length in signal""" - test_string = "end of data@end_of_data:@" - with pytest.raises(ValueError, match="Couldn't find data length in " - "string"): + test_string = 'end of data@end_of_data:@' + with pytest.raises(ValueError, match="Couldn't find data length in " 'string'): _extract_data_length(test_string, 'data') - test_string = "all data@end_of_data:8@" + test_string = 'all data@end_of_data:8@' output = _extract_data_length(test_string, 'data') assert output == 8 @@ -85,19 +87,25 @@ def test_str_to_net(): pickled_net = _pickle_data(net) - input_str = '@start_of_net@' + pickled_net.decode() + \ - '@end_of_net:%d@\n' % (len(pickled_net)) + input_str = ( + '@start_of_net@' + + pickled_net.decode() + + '@end_of_net:%d@\n' % (len(pickled_net)) + ) received_net = _str_to_net(input_str) assert isinstance(received_net, Network) # muck with the data size in the signal - input_str = '@start_of_net@' + pickled_net.decode() + \ - '@end_of_net:%d@\n' % (len(pickled_net) + 1) + input_str = ( + '@start_of_net@' + + pickled_net.decode() + + '@end_of_net:%d@\n' % (len(pickled_net) + 1) + ) - expected_string = "Got incorrect network size: %d bytes " % \ - len(pickled_net) + "expected length: %d" % \ - (len(pickled_net) + 1) + expected_string = 'Got incorrect network size: %d bytes ' % len( + pickled_net + ) + 'expected length: %d' % (len(pickled_net) + 1) # process input from queue with pytest.raises(ValueError, match=expected_string): @@ -113,26 +121,26 @@ def test_child_run(): params_fname = op.join(hnn_core_root, 'param', 'default.json') params = read_params(params_fname) params_reduced = params.copy() - params_reduced.update({'t_evprox_1': 5, - 't_evdist_1': 10, - 't_evprox_2': 20}) + params_reduced.update({'t_evprox_1': 5, 't_evdist_1': 10, 't_evprox_2': 20}) tstop, n_trials = 25, 2 - net_reduced = jones_2009_model(params_reduced, add_drives_from_params=True, - mesh_shape=(3, 3)) + net_reduced = jones_2009_model( + params_reduced, add_drives_from_params=True, mesh_shape=(3, 3) + ) net_reduced._instantiate_drives(tstop=tstop, n_trials=n_trials) with MPISimulation(skip_mpi_import=True) as mpi_sim: with io.StringIO() as buf, redirect_stdout(buf): - sim_data = mpi_sim.run(net_reduced, tstop=tstop, dt=0.025, - n_trials=n_trials) + sim_data = mpi_sim.run( + net_reduced, tstop=tstop, dt=0.025, n_trials=n_trials + ) stdout = buf.getvalue() - assert "Trial 1: 0.03 ms..." in stdout + assert 'Trial 1: 0.03 ms...' 
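# A hedged sketch of the stderr framing these tests exercise: payloads are
# wrapped as @start_of_<label>@...@end_of_<label>:<nbytes>@. The regex
# helpers below are illustrative, not the hnn_core implementation.
import re

def extract_data_sketch(string, label):
    match = re.search(
        rf'@start_of_{label}@(.*)@end_of_{label}:\d+@', string, re.DOTALL
    )
    return match.group(1) if match else ''

def extract_data_length_sketch(string, label):
    match = re.search(rf'@end_of_{label}:(\d+)@', string)
    if match is None:
        raise ValueError("Couldn't find data length in string")
    return int(match.group(1))

assert extract_data_sketch('@start_of_data@all data@end_of_data:8@', 'data') == 'all data'
assert extract_data_length_sketch('all data@end_of_data:8@', 'data') == 8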
in stdout with io.StringIO() as buf_err, redirect_stderr(buf_err): mpi_sim._write_data_stderr(sim_data) stderr_str = buf_err.getvalue() - assert "@start_of_data@" in stderr_str - assert "@end_of_data:" in stderr_str + assert '@start_of_data@' in stderr_str + assert '@end_of_data:' in stderr_str # write data to queue err_q = Queue() @@ -150,13 +158,12 @@ def test_child_run(): def test_empty_data(): """Test that an empty string raises RuntimeError""" data_bytes = b'' - with pytest.raises(RuntimeError, - match="MPI simulation didn't return any data"): + with pytest.raises(RuntimeError, match="MPI simulation didn't return any data"): _process_child_data(data_bytes, len(data_bytes)) def test_data_len_mismatch(): - """Test that padded data can be unpickled with warning for length """ + """Test that padded data can be unpickled with warning for length""" pickled_bytes = _pickle_data({}) @@ -165,8 +172,10 @@ def test_data_len_mismatch(): with pytest.warns(UserWarning) as record: _process_child_data(pickled_bytes, expected_len) - expected_string = "Length of received data unexpected. " + \ - "Expecting %d bytes, got %d" % (expected_len, len(pickled_bytes)) + expected_string = ( + 'Length of received data unexpected. ' + + 'Expecting %d bytes, got %d' % (expected_len, len(pickled_bytes)) + ) assert len(record) == 1 assert record[0].message.args[0] == expected_string diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index 5b7d7a742..a60c9d872 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -17,9 +17,9 @@ params_fname = op.join(hnn_core_root, 'param', 'default.json') -@pytest.fixture(scope="class") +@pytest.fixture(scope='class') def base_network(): - """ Base Network with connections and drives """ + """Base Network with connections and drives""" params_fname = op.join(hnn_core_root, 'param', 'default.json') params = read_params(params_fname) net = Network(params, legacy_mode=False) @@ -29,41 +29,75 @@ def base_network(): for target_cell in ['L2_pyramidal', 'L5_pyramidal']: for receptor in ['nmda', 'ampa']: net.add_connection( - target_cell, target_cell, loc='proximal', receptor=receptor, - weight=5e-4, delay=net.delay, lamtha=3.0, allow_autapses=False) + target_cell, + target_cell, + loc='proximal', + receptor=receptor, + weight=5e-4, + delay=net.delay, + lamtha=3.0, + allow_autapses=False, + ) # layer2 Basket -> layer2 Pyr # layer5 Basket -> layer5 Pyr for receptor in ['gabaa', 'gabab']: net.add_connection( - src_gids='L2_basket', target_gids='L2_pyramidal', loc='soma', - receptor=receptor, weight=5e-4, delay=net.delay, lamtha=50.0) + src_gids='L2_basket', + target_gids='L2_pyramidal', + loc='soma', + receptor=receptor, + weight=5e-4, + delay=net.delay, + lamtha=50.0, + ) net.add_connection( - src_gids='L5_basket', target_gids='L2_pyramidal', loc='soma', - receptor=receptor, weight=5e-4, delay=net.delay, lamtha=70.0) + src_gids='L5_basket', + target_gids='L2_pyramidal', + loc='soma', + receptor=receptor, + weight=5e-4, + delay=net.delay, + lamtha=70.0, + ) # layer2 Basket -> layer2 Basket (autapses allowed) net.add_connection( - src_gids='L2_basket', target_gids='L2_basket', loc='soma', - receptor='gabaa', weight=5e-4, delay=net.delay, lamtha=20.0) + src_gids='L2_basket', + target_gids='L2_basket', + loc='soma', + receptor='gabaa', + weight=5e-4, + delay=net.delay, + lamtha=20.0, + ) # add arbitrary drives that contribute artificial cells to network - net.add_evoked_drive(name='evdist1', mu=5.0, sigma=1.0, - numspikes=1, 
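# A hedged sketch of the pickling and length check that test_data_len_mismatch
# exercises above; base64-over-pickle is an assumption (suggested by the
# .decode() calls on the pickled payload), not a confirmed detail of
# hnn_core's implementation.
import base64
import pickle
import warnings

def pickle_data_sketch(obj):
    return base64.b64encode(pickle.dumps(obj))

def process_child_data_sketch(data_bytes, expected_len):
    if not data_bytes:
        raise RuntimeError("MPI simulation didn't return any data")
    if len(data_bytes) != expected_len:
        warnings.warn(
            'Length of received data unexpected. '
            'Expecting %d bytes, got %d' % (expected_len, len(data_bytes))
        )
    return pickle.loads(base64.b64decode(data_bytes))

payload = pickle_data_sketch({})
assert process_child_data_sketch(payload, len(payload)) == {}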
location='distal', - weights_ampa={'L2_basket': 0.1, - 'L2_pyramidal': 0.1}) - net.add_evoked_drive(name='evprox1', mu=5.0, sigma=1.0, - numspikes=1, location='proximal', - weights_ampa={'L2_basket': 0.1, - 'L2_pyramidal': 0.1}) + net.add_evoked_drive( + name='evdist1', + mu=5.0, + sigma=1.0, + numspikes=1, + location='distal', + weights_ampa={'L2_basket': 0.1, 'L2_pyramidal': 0.1}, + ) + net.add_evoked_drive( + name='evprox1', + mu=5.0, + sigma=1.0, + numspikes=1, + location='proximal', + weights_ampa={'L2_basket': 0.1, 'L2_pyramidal': 0.1}, + ) return net, params def test_network_models(): - """"Test instantiations of the network object""" + """ "Test instantiations of the network object""" # Make sure critical biophysics for Law model are updated net_law = law_2021_model() # instantiate drive events for NetworkBuilder - net_law._instantiate_drives(tstop=net_law._params['tstop'], - n_trials=net_law._params['N_trials']) + net_law._instantiate_drives( + tstop=net_law._params['tstop'], n_trials=net_law._params['N_trials'] + ) for cell_name in ['L5_pyramidal', 'L2_pyramidal']: assert net_law.cell_types[cell_name].synapses['gabab']['tau1'] == 45.0 @@ -74,8 +108,7 @@ def test_network_models(): with pytest.raises(TypeError, match='net must be'): add_erp_drives_to_jones_model(net='invalid_input') with pytest.raises(TypeError, match='tstart must be'): - add_erp_drives_to_jones_model(net=net_default, - tstart='invalid_input') + add_erp_drives_to_jones_model(net=net_default, tstart='invalid_input') n_conn = len(net_default.connectivity) add_erp_drives_to_jones_model(net_default) for drive_name in ['evdist1', 'evprox1', 'evprox2']: @@ -87,19 +120,22 @@ def test_network_models(): # Ensure distant dependent calcium gbar net_calcium = calcium_model() # instantiate drive events for NetworkBuilder - net_calcium._instantiate_drives(tstop=net_calcium._params['tstop'], - n_trials=net_calcium._params['N_trials']) + net_calcium._instantiate_drives( + tstop=net_calcium._params['tstop'], n_trials=net_calcium._params['N_trials'] + ) network_builder = NetworkBuilder(net_calcium) gid = net_calcium.gid_ranges['L5_pyramidal'][0] - for section_name, section in \ - network_builder._cells[gid]._nrn_sections.items(): + for section_name, section in network_builder._cells[gid]._nrn_sections.items(): # Section endpoints where seg.x == 0.0 or 1.0 don't have 'ca' mech - ca_gbar = [seg.__getattribute__('ca').gbar for - seg in list(section.allseg())[1:-1]] - na_gbar = [seg.__getattribute__('hh2').gnabar for - seg in list(section.allseg())[1:-1]] - k_gbar = [seg.__getattribute__('hh2').gkbar for - seg in list(section.allseg())[1:-1]] + ca_gbar = [ + seg.__getattribute__('ca').gbar for seg in list(section.allseg())[1:-1] + ] + na_gbar = [ + seg.__getattribute__('hh2').gnabar for seg in list(section.allseg())[1:-1] + ] + k_gbar = [ + seg.__getattribute__('hh2').gkbar for seg in list(section.allseg())[1:-1] + ] # Ensure positive distance dependent calcium gbar with plateau if section_name == 'apical_tuft': @@ -122,33 +158,33 @@ def test_network_models(): def test_network_cell_positions(): - """"Test manipulation of cell positions in the network object""" + """ "Test manipulation of cell positions in the network object""" net = jones_2009_model() - assert np.isclose(net._inplane_distance, 1.) # default + assert np.isclose(net._inplane_distance, 1.0) # default assert np.isclose(net._layer_separation, 1307.4) # default # change both from their default values - net.set_cell_positions(inplane_distance=2.) 
+ net.set_cell_positions(inplane_distance=2.0) assert np.isclose(net._layer_separation, 1307.4) # still the default - net.set_cell_positions(layer_separation=1000.) - assert np.isclose(net._inplane_distance, 2.) # mustn't change + net.set_cell_positions(layer_separation=1000.0) + assert np.isclose(net._inplane_distance, 2.0) # mustn't change # check that in-plane distance is now 2. for the default 10 x 10 grid assert np.allclose( # x-coordinate jumps every 10th gid - np.diff(np.array(net.pos_dict['L5_pyramidal'])[9::10, 0], axis=0), 2.) + np.diff(np.array(net.pos_dict['L5_pyramidal'])[9::10, 0], axis=0), 2.0 + ) assert np.allclose( # test first 10 y-coordinates - np.diff(np.array(net.pos_dict['L5_pyramidal'])[:9, 1], axis=0), 2.) + np.diff(np.array(net.pos_dict['L5_pyramidal'])[:9, 1], axis=0), 2.0 + ) # check that layer separation has changed (L5 is zero) to 1000. assert np.isclose(net.pos_dict['L2_pyramidal'][0][2], 1000.) - with pytest.raises(ValueError, - match='In-plane distance must be positive'): - net.set_cell_positions(inplane_distance=0.) - with pytest.raises(ValueError, - match='Layer separation must be positive'): - net.set_cell_positions(layer_separation=0.) + with pytest.raises(ValueError, match='In-plane distance must be positive'): + net.set_cell_positions(inplane_distance=0.0) + with pytest.raises(ValueError, match='Layer separation must be positive'): + net.set_cell_positions(layer_separation=0.0) # Check that the origin of the drive cells matches the new 'origin' # when set_cell_positions is called after adding drives. @@ -157,13 +193,12 @@ def test_network_cell_positions(): # dependent weights and delays of the drives are calculated with respect to # this origin. add_erp_drives_to_jones_model(net) - net.set_cell_positions(inplane_distance=20.
+ net.set_cell_positions(inplane_distance=20.0) for drive_name, drive in net.external_drives.items(): assert len(net.pos_dict[drive_name]) == drive['n_drive_cells'] # just test the 0th index, assume all others then fine too for idx in range(3): # x,y,z coords - assert (net.pos_dict[drive_name][0][idx] == - net.pos_dict['origin'][idx]) + assert net.pos_dict[drive_name][0][idx] == net.pos_dict['origin'][idx] def test_network_drives(): @@ -181,11 +216,17 @@ def test_network_drives(): n_drive_cells = 'n_cells' n_drive_cells_list.append(n_drive_cells) drive_weights = dict() - drive_weights['evdist1'] = {'L2_basket': 0.01, 'L2_pyramidal': 0.02, - 'L5_pyramidal': 0.03} + drive_weights['evdist1'] = { + 'L2_basket': 0.01, + 'L2_pyramidal': 0.02, + 'L5_pyramidal': 0.03, + } drive_delays = dict() - drive_delays['evdist1'] = {'L2_basket': 0.1, 'L2_pyramidal': 0.2, - 'L5_pyramidal': 0.3} + drive_delays['evdist1'] = { + 'L2_basket': 0.1, + 'L2_pyramidal': 0.2, + 'L5_pyramidal': 0.3, + } net.add_evoked_drive( name='evdist1', mu=63.53, @@ -197,14 +238,23 @@ def test_network_drives(): weights_ampa=drive_weights['evdist1'], weights_nmda=drive_weights['evdist1'], synaptic_delays=drive_delays['evdist1'], - event_seed=274) + event_seed=274, + ) n_drive_cells = 'n_cells' n_drive_cells_list.append(n_drive_cells) - drive_weights['evprox1'] = {'L2_basket': 0.04, 'L2_pyramidal': 0.05, - 'L5_basket': 0.06, 'L5_pyramidal': 0.07} - drive_delays['evprox1'] = {'L2_basket': 0.4, 'L2_pyramidal': 0.5, - 'L5_basket': 0.6, 'L5_pyramidal': 0.7} + drive_weights['evprox1'] = { + 'L2_basket': 0.04, + 'L2_pyramidal': 0.05, + 'L5_basket': 0.06, + 'L5_pyramidal': 0.07, + } + drive_delays['evprox1'] = { + 'L2_basket': 0.4, + 'L2_pyramidal': 0.5, + 'L5_basket': 0.6, + 'L5_pyramidal': 0.7, + } net.add_evoked_drive( name='evprox1', mu=26.61, @@ -216,14 +266,23 @@ def test_network_drives(): weights_ampa=drive_weights['evprox1'], weights_nmda=drive_weights['evprox1'], synaptic_delays=drive_delays['evprox1'], - event_seed=544) + event_seed=544, + ) n_drive_cells = 'n_cells' n_drive_cells_list.append(n_drive_cells) - drive_weights['evprox2'] = {'L2_basket': 0.08, 'L2_pyramidal': 0.09, - 'L5_basket': 0.1, 'L5_pyramidal': 0.11} - drive_delays['evprox2'] = {'L2_basket': 0.8, 'L2_pyramidal': 0.9, - 'L5_basket': 1.0, 'L5_pyramidal': 1.1} + drive_weights['evprox2'] = { + 'L2_basket': 0.08, + 'L2_pyramidal': 0.09, + 'L5_basket': 0.1, + 'L5_pyramidal': 0.11, + } + drive_delays['evprox2'] = { + 'L2_basket': 0.8, + 'L2_pyramidal': 0.9, + 'L5_basket': 1.0, + 'L5_pyramidal': 1.1, + } net.add_evoked_drive( name='evprox2', mu=137.12, @@ -235,55 +294,71 @@ def test_network_drives(): weights_ampa=drive_weights['evprox2'], weights_nmda=drive_weights['evprox2'], synaptic_delays=drive_delays['evprox2'], - event_seed=814) + event_seed=814, + ) # add an bursty drive as well n_drive_cells = 10 n_drive_cells_list.append(n_drive_cells) - drive_weights['bursty1'] = {'L2_basket': 0.12, 'L2_pyramidal': 0.13, - 'L5_basket': 0.14, 'L5_pyramidal': 0.15} - drive_delays['bursty1'] = {'L2_basket': 1.2, 'L2_pyramidal': 1.3, - 'L5_basket': 1.4, 'L5_pyramidal': 1.5} + drive_weights['bursty1'] = { + 'L2_basket': 0.12, + 'L2_pyramidal': 0.13, + 'L5_basket': 0.14, + 'L5_pyramidal': 0.15, + } + drive_delays['bursty1'] = { + 'L2_basket': 1.2, + 'L2_pyramidal': 1.3, + 'L5_basket': 1.4, + 'L5_pyramidal': 1.5, + } net.add_bursty_drive( name='bursty1', - tstart=10., + tstart=10.0, tstart_std=0.5, - tstop=30., + tstop=30.0, location='proximal', - burst_rate=100., - burst_std=0., + 
burst_rate=100.0, + burst_std=0.0, numspikes=2, - spike_isi=1., + spike_isi=1.0, n_drive_cells=n_drive_cells, cell_specific=False, weights_ampa=drive_weights['bursty1'], weights_nmda=drive_weights['bursty1'], synaptic_delays=drive_delays['bursty1'], - event_seed=4) + event_seed=4, + ) # add poisson drive as well n_drive_cells = 'n_cells' n_drive_cells_list.append(n_drive_cells) - drive_weights['poisson1'] = {'L2_basket': 0.16, 'L2_pyramidal': 0.17, - 'L5_pyramidal': 0.18} - drive_delays['poisson1'] = {'L2_basket': 1.6, 'L2_pyramidal': 1.7, - 'L5_pyramidal': 1.8} + drive_weights['poisson1'] = { + 'L2_basket': 0.16, + 'L2_pyramidal': 0.17, + 'L5_pyramidal': 0.18, + } + drive_delays['poisson1'] = { + 'L2_basket': 1.6, + 'L2_pyramidal': 1.7, + 'L5_pyramidal': 1.8, + } net.add_poisson_drive( name='poisson1', - tstart=10., - tstop=30., - rate_constant=50., + tstart=10.0, + tstop=30.0, + rate_constant=50.0, location='distal', n_drive_cells=n_drive_cells, cell_specific=True, weights_ampa=drive_weights['poisson1'], weights_nmda=drive_weights['poisson1'], synaptic_delays=drive_delays['poisson1'], - event_seed=4) + event_seed=4, + ) # instantiate drive events for NetworkBuilder - net._instantiate_drives(tstop=params['tstop'], - n_trials=params['N_trials']) + net._instantiate_drives(tstop=params['tstop'], n_trials=params['N_trials']) network_builder = NetworkBuilder(net) # needed to instantiate cells # Assert that params are conserved across Network initialization @@ -294,10 +369,14 @@ def test_network_drives(): print(network_builder._cells[:2]) # Assert that proper number/types of gids are created for Network drives - dns_from_gids = [name for name in net.gid_ranges.keys() if - name not in net.cell_types] - assert (sorted(dns_from_gids) == sorted(net.external_drives.keys()) == - sorted(drive_weights.keys())) + dns_from_gids = [ + name for name in net.gid_ranges.keys() if name not in net.cell_types + ] + assert ( + sorted(dns_from_gids) + == sorted(net.external_drives.keys()) + == sorted(drive_weights.keys()) + ) for dn in dns_from_gids: n_drive_cells = net.external_drives[dn]['n_drive_cells'] assert len(net.gid_ranges[dn]) == n_drive_cells @@ -307,8 +386,7 @@ def test_network_drives(): # source gids by target type drive_src_list = list() for target_type in net.cell_types: - conn_idxs = pick_connection(net, src_gids='evprox1', - target_gids=target_type) + conn_idxs = pick_connection(net, src_gids='evprox1', target_gids=target_type) src_set = set() for conn_idx in conn_idxs: src_set.update(net.connectivity[conn_idx]['src_gids']) @@ -319,9 +397,13 @@ def test_network_drives(): for drive_idx, drive in enumerate(net.external_drives.values()): # Check that connectivity sources correspond to gid_ranges conn_idxs = pick_connection(net, src_gids=drive['name']) - this_src_gids = set([gid for conn_idx in conn_idxs - for gid in net.connectivity[conn_idx]['src_gids'] - ]) # NB set: globals + this_src_gids = set( + [ + gid + for conn_idx in conn_idxs + for gid in net.connectivity[conn_idx]['src_gids'] + ] + ) # NB set: globals assert sorted(this_src_gids) == list(net.gid_ranges[drive['name']]) # Check type-specific dynamics and events n_drive_cells = drive['n_drive_cells'] @@ -339,70 +421,96 @@ def test_network_drives(): assert kw in drive['dynamics'].keys() assert len(drive['events'][0]) == n_drive_cells elif drive['type'] == 'bursty': - for kw in ['tstart', 'tstart_std', 'tstop', - 'burst_rate', 'burst_std', 'numspikes']: + for kw in [ + 'tstart', + 'tstart_std', + 'tstop', + 'burst_rate', + 'burst_std', + 
'numspikes', + ]: assert kw in drive['dynamics'].keys() assert len(drive['events'][0]) == n_drive_cells n_events = ( - drive['dynamics']['numspikes'] * # 2 - (1 + (drive['dynamics']['tstop'] - - drive['dynamics']['tstart'] - 1) // - (1000. / drive['dynamics']['burst_rate']))) + drive['dynamics']['numspikes'] # 2 + * ( + 1 + + (drive['dynamics']['tstop'] - drive['dynamics']['tstart'] - 1) + // (1000.0 / drive['dynamics']['burst_rate']) + ) + ) assert len(drive['events'][0][0]) == n_events # 4 # make sure the PRNGs are consistent. - target_times = {'evdist1': [66.30498327062551, 66.33129889343446], - 'evprox1': [23.80641637082997, 30.857310915553647], - 'evprox2': [141.76252038319825, 137.73942375578602]} + target_times = { + 'evdist1': [66.30498327062551, 66.33129889343446], + 'evprox1': [23.80641637082997, 30.857310915553647], + 'evprox2': [141.76252038319825, 137.73942375578602], + } for drive_name in target_times: for idx in [0, -1]: # first and last - assert_allclose(net.external_drives[drive_name]['events'][0][idx], - target_times[drive_name][idx], rtol=1e-12) + assert_allclose( + net.external_drives[drive_name]['events'][0][idx], + target_times[drive_name][idx], + rtol=1e-12, + ) # check select excitatory (AMPA+NMDA) synaptic weights and delays for drive_name in drive_weights: for target_type in drive_weights[drive_name]: - conn_idxs = pick_connection(net, src_gids=drive_name, - target_gids=target_type) + conn_idxs = pick_connection( + net, src_gids=drive_name, target_gids=target_type + ) for conn_idx in conn_idxs: drive_conn = net.connectivity[conn_idx] # weights - assert_allclose(drive_conn['nc_dict']['A_weight'], - drive_weights[drive_name][target_type], - rtol=1e-12) + assert_allclose( + drive_conn['nc_dict']['A_weight'], + drive_weights[drive_name][target_type], + rtol=1e-12, + ) # delays - assert_allclose(drive_conn['nc_dict']['A_delay'], - drive_delays[drive_name][target_type], - rtol=1e-12) + assert_allclose( + drive_conn['nc_dict']['A_delay'], + drive_delays[drive_name][target_type], + rtol=1e-12, + ) # array of simulation times is created in Network.__init__, but passed # to CellResponse-constructor for storage (Network is agnostic of time) - with pytest.raises(TypeError, - match="'times' is an np.ndarray of simulation times"): + with pytest.raises(TypeError, match="'times' is an np.ndarray of simulation times"): _ = CellResponse(times='blah') # Check that all external drives are initialized with the expected amount # of artificial cells assuming legacy_mode=False (i.e., dependent on # drive targets). 
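As a quick arithmetic check of the bursty event count asserted earlier in this test, the sketch below plugs in the bursty1 parameters used above (numspikes=2, tstart=10 ms, tstop=30 ms, burst_rate=100 Hz); the variable names are illustrative only:

numspikes = 2
tstart, tstop = 10.0, 30.0  # ms
burst_rate = 100.0  # Hz, so the inter-burst period is 1000.0 / 100.0 = 10 ms
# one burst at tstart, plus every full period that fits in the remaining
# (tstop - tstart - 1) ms window
n_bursts = 1 + (tstop - tstart - 1) // (1000.0 / burst_rate)  # 1 + 19 // 10 = 2
n_events = numspikes * n_bursts  # 2 * 2 = 4, the value asserted above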
- prox_targets = (len(net.gid_ranges['L2_basket']) + - len(net.gid_ranges['L2_pyramidal']) + - len(net.gid_ranges['L5_basket']) + - len(net.gid_ranges['L5_pyramidal'])) - dist_targets = (len(net.gid_ranges['L2_basket']) + - len(net.gid_ranges['L2_pyramidal']) + - len(net.gid_ranges['L5_pyramidal'])) + prox_targets = ( + len(net.gid_ranges['L2_basket']) + + len(net.gid_ranges['L2_pyramidal']) + + len(net.gid_ranges['L5_basket']) + + len(net.gid_ranges['L5_pyramidal']) + ) + dist_targets = ( + len(net.gid_ranges['L2_basket']) + + len(net.gid_ranges['L2_pyramidal']) + + len(net.gid_ranges['L5_pyramidal']) + ) n_evoked_sources = dist_targets + (2 * prox_targets) n_pois_sources = dist_targets n_bursty_sources = net.external_drives['bursty1']['n_drive_cells'] # test that expected number of external driving events are created - assert len(network_builder._drive_cells) == (n_evoked_sources + - n_pois_sources + - n_bursty_sources) - assert len(network_builder._gid_list) ==\ - len(network_builder._drive_cells) + net._n_cells + assert len(network_builder._drive_cells) == ( + n_evoked_sources + n_pois_sources + n_bursty_sources + ) + assert ( + len(network_builder._gid_list) + == len(network_builder._drive_cells) + net._n_cells + ) # first 'evoked drive' comes after real cells and bursty drive cells - assert network_builder._drive_cells[n_bursty_sources].gid ==\ - net._n_cells + n_bursty_sources + assert ( + network_builder._drive_cells[n_bursty_sources].gid + == net._n_cells + n_bursty_sources + ) # check that Network drive connectivity transfers to NetworkBuilder n_pyr = len(net.gid_ranges['L2_pyramidal']) @@ -428,12 +536,16 @@ def test_network_drives_legacy(): """Test manipulation of drives in the network object under legacy mode.""" params = read_params(params_fname) # add rhythmic inputs (i.e., a type of common input) - params.update({'input_dist_A_weight_L2Pyr_ampa': 1.4e-5, - 'input_dist_A_weight_L5Pyr_ampa': 2.4e-5, - 't0_input_dist': 50, - 'input_prox_A_weight_L2Pyr_ampa': 3.4e-5, - 'input_prox_A_weight_L5Pyr_ampa': 4.4e-5, - 't0_input_prox': 50}) + params.update( + { + 'input_dist_A_weight_L2Pyr_ampa': 1.4e-5, + 'input_dist_A_weight_L5Pyr_ampa': 2.4e-5, + 't0_input_dist': 50, + 'input_prox_A_weight_L2Pyr_ampa': 3.4e-5, + 'input_prox_A_weight_L5Pyr_ampa': 4.4e-5, + 't0_input_prox': 50, + } + ) # Test deprecation warning of legacy mode with pytest.warns(DeprecationWarning, match='Legacy mode'): @@ -442,12 +554,10 @@ def test_network_drives_legacy(): _ = calcium_model(legacy_mode=True) _ = Network(params, legacy_mode=True) - net = jones_2009_model(params, legacy_mode=True, - add_drives_from_params=True) + net = jones_2009_model(params, legacy_mode=True, add_drives_from_params=True) # instantiate drive events for NetworkBuilder - net._instantiate_drives(tstop=params['tstop'], - n_trials=params['N_trials']) + net._instantiate_drives(tstop=params['tstop'], n_trials=params['N_trials']) network_builder = NetworkBuilder(net) # needed to instantiate cells # Assert that params are conserved across Network initialization @@ -458,8 +568,9 @@ def test_network_drives_legacy(): print(network_builder._cells[:2]) # Assert that proper number/types of gids are created for Network drives - dns_from_gids = [name for name in net.gid_ranges.keys() if - name not in net.cell_types] + dns_from_gids = [ + name for name in net.gid_ranges.keys() if name not in net.cell_types + ] assert sorted(dns_from_gids) == sorted(net.external_drives.keys()) for dn in dns_from_gids: n_drive_cells = 
net.external_drives[dn]['n_drive_cells'] @@ -469,9 +580,13 @@ def test_network_drives_legacy(): for drive in net.external_drives.values(): # Check that connectivity sources correspond to gid_ranges conn_idxs = pick_connection(net, src_gids=drive['name']) - this_src_gids = set([gid for conn_idx in conn_idxs - for gid in net.connectivity[conn_idx]['src_gids'] - ]) # NB set: globals + this_src_gids = set( + [ + gid + for conn_idx in conn_idxs + for gid in net.connectivity[conn_idx]['src_gids'] + ] + ) # NB set: globals assert sorted(this_src_gids) == list(net.gid_ranges[drive['name']]) # Check type-specific dynamics and events n_drive_cells = drive['n_drive_cells'] @@ -491,68 +606,83 @@ def test_network_drives_legacy(): assert kw in drive['dynamics'].keys() assert len(drive['events'][0]) == n_drive_cells elif drive['type'] == 'bursty': - for kw in ['tstart', 'tstart_std', 'tstop', - 'burst_rate', 'burst_std', 'numspikes']: + for kw in [ + 'tstart', + 'tstart_std', + 'tstop', + 'burst_rate', + 'burst_std', + 'numspikes', + ]: assert kw in drive['dynamics'].keys() assert len(drive['events'][0]) == n_drive_cells n_events = ( - drive['dynamics']['numspikes'] * # 2 - (1 + (drive['dynamics']['tstop'] - - drive['dynamics']['tstart'] - 1) // - (1000. / drive['dynamics']['burst_rate']))) + drive['dynamics']['numspikes'] # 2 + * ( + 1 + + (drive['dynamics']['tstop'] - drive['dynamics']['tstart'] - 1) + // (1000.0 / drive['dynamics']['burst_rate']) + ) + ) assert len(drive['events'][0][0]) == n_events # 4 # make sure the PRNGs are consistent. - target_times = {'evdist1': [66.30498327062551, 61.54362532343694], - 'evprox1': [23.80641637082997, 30.857310915553647], - 'evprox2': [141.76252038319825, 137.73942375578602]} + target_times = { + 'evdist1': [66.30498327062551, 61.54362532343694], + 'evprox1': [23.80641637082997, 30.857310915553647], + 'evprox2': [141.76252038319825, 137.73942375578602], + } for drive_name in target_times: for idx in [0, -1]: # first and last - assert_allclose(net.external_drives[drive_name]['events'][0][idx], - target_times[drive_name][idx], rtol=1e-12) + assert_allclose( + net.external_drives[drive_name]['events'][0][idx], + target_times[drive_name][idx], + rtol=1e-12, + ) # check select AMPA weights - target_weights = {'evdist1': {'L2_basket': 0.006562, - 'L5_pyramidal': 0.142300}, - 'evprox1': {'L2_basket': 0.08831, - 'L5_pyramidal': 0.00865}, - 'evprox2': {'L2_basket': 0.000003, - 'L5_pyramidal': 0.684013}, - 'bursty1': {'L2_pyramidal': 0.000034, - 'L5_pyramidal': 0.000044}, - 'bursty2': {'L2_pyramidal': 0.000014, - 'L5_pyramidal': 0.000024} - } + target_weights = { + 'evdist1': {'L2_basket': 0.006562, 'L5_pyramidal': 0.142300}, + 'evprox1': {'L2_basket': 0.08831, 'L5_pyramidal': 0.00865}, + 'evprox2': {'L2_basket': 0.000003, 'L5_pyramidal': 0.684013}, + 'bursty1': {'L2_pyramidal': 0.000034, 'L5_pyramidal': 0.000044}, + 'bursty2': {'L2_pyramidal': 0.000014, 'L5_pyramidal': 0.000024}, + } for drive_name in target_weights: for target_type in target_weights[drive_name]: - conn_idxs = pick_connection(net, src_gids=drive_name, - target_gids=target_type, - receptor='ampa') + conn_idxs = pick_connection( + net, src_gids=drive_name, target_gids=target_type, receptor='ampa' + ) for conn_idx in conn_idxs: drive_conn = net.connectivity[conn_idx] - assert_allclose(drive_conn['nc_dict']['A_weight'], - target_weights[drive_name][target_type], - rtol=1e-12) + assert_allclose( + drive_conn['nc_dict']['A_weight'], + target_weights[drive_name][target_type], + rtol=1e-12, + ) # check select 
synaptic delays - target_delays = {'evdist1': {'L2_basket': 0.1, 'L5_pyramidal': 0.1}, - 'evprox1': {'L2_basket': 0.1, 'L5_pyramidal': 1.}, - 'evprox2': {'L2_basket': 0.1, 'L5_pyramidal': 1.}} + target_delays = { + 'evdist1': {'L2_basket': 0.1, 'L5_pyramidal': 0.1}, + 'evprox1': {'L2_basket': 0.1, 'L5_pyramidal': 1.0}, + 'evprox2': {'L2_basket': 0.1, 'L5_pyramidal': 1.0}, + } for drive_name in target_delays: for target_type in target_delays[drive_name]: - conn_idxs = pick_connection(net, src_gids=drive_name, - target_gids=target_type, - receptor='ampa') + conn_idxs = pick_connection( + net, src_gids=drive_name, target_gids=target_type, receptor='ampa' + ) for conn_idx in conn_idxs: drive_conn = net.connectivity[conn_idx] - assert_allclose(drive_conn['nc_dict']['A_delay'], - target_delays[drive_name][target_type], - rtol=1e-12) + assert_allclose( + drive_conn['nc_dict']['A_delay'], + target_delays[drive_name][target_type], + rtol=1e-12, + ) # array of simulation times is created in Network.__init__, but passed # to CellResponse-constructor for storage (Network is agnostic of time) - with pytest.raises(TypeError, - match="'times' is an np.ndarray of simulation times"): + with pytest.raises(TypeError, match="'times' is an np.ndarray of simulation times"): _ = CellResponse(times='blah') # Assert that all external drives are initialized @@ -561,21 +691,21 @@ def test_network_drives_legacy(): n_evoked_sources = 3 * net._n_cells n_pois_sources = net._n_cells n_gaus_sources = net._n_cells - n_bursty_sources = (net.external_drives['bursty1']['n_drive_cells'] + - net.external_drives['bursty2']['n_drive_cells']) + n_bursty_sources = ( + net.external_drives['bursty1']['n_drive_cells'] + + net.external_drives['bursty2']['n_drive_cells'] + ) # test that expected number of external driving events are created - assert len(network_builder._drive_cells) == (n_evoked_sources + - n_pois_sources + - n_gaus_sources + - n_bursty_sources) + assert len(network_builder._drive_cells) == ( + n_evoked_sources + n_pois_sources + n_gaus_sources + n_bursty_sources + ) def test_network_connectivity(base_network): net, params = base_network # instantiate drive events and artificial cells for NetworkBuilder - net._instantiate_drives(tstop=10.0, - n_trials=1) + net._instantiate_drives(tstop=10.0, n_trials=1) network_builder = NetworkBuilder(net) # start by checking that Network connectivity transfers to NetworkBuilder @@ -584,14 +714,14 @@ def test_network_connectivity(base_network): # Check basket-basket connection where allow_autapses=False assert 'L2Pyr_L2Pyr_nmda' in network_builder.ncs - n_connections = 3 * (n_pyr ** 2 - n_pyr) # 3 synapses / cell + n_connections = 3 * (n_pyr**2 - n_pyr) # 3 synapses / cell assert len(network_builder.ncs['L2Pyr_L2Pyr_nmda']) == n_connections nc = network_builder.ncs['L2Pyr_L2Pyr_nmda'][0] assert nc.threshold == params['threshold'] # Check basket-basket connection where allow_autapses=True assert 'L2Basket_L2Basket_gabaa' in network_builder.ncs - n_connections = n_basket ** 2 # 1 synapse / cell + n_connections = n_basket**2 # 1 synapse / cell assert len(network_builder.ncs['L2Basket_L2Basket_gabaa']) == n_connections nc = network_builder.ncs['L2Basket_L2Basket_gabaa'][0] assert nc.threshold == params['threshold'] @@ -601,10 +731,16 @@ def test_network_connectivity(base_network): n_conn_trunk = len(network_builder.ncs['L2Pyr_L2Pyr_nmda']) # add connections targeting single section and rebuild - kwargs_default = dict(src_gids=[35, 36], target_gids=[35, 36], - loc='proximal', 
receptor='ampa', - weight=5e-4, delay=1.0, lamtha=1e9, - probability=1.0) + kwargs_default = dict( + src_gids=[35, 36], + target_gids=[35, 36], + loc='proximal', + receptor='ampa', + weight=5e-4, + delay=1.0, + lamtha=1e9, + probability=1.0, + ) net.add_connection(**kwargs_default) # smoke test kwargs_trunk = kwargs_default.copy() kwargs_trunk['loc'] = 'apical_trunk' @@ -625,40 +761,62 @@ def test_network_connectivity(base_network): assert_allclose(nc.weight[0], kwargs_trunk['weight']) # Check that exactly 4 apical_trunk connections appended for idx in range(1, 5): - assert network_builder.ncs['L2Pyr_L2Pyr_nmda'][ - -idx].postseg().__str__() == 'L2Pyr_apical_trunk(0.5)' - assert network_builder.ncs['L2Pyr_L2Pyr_nmda'][ - -5].postseg().__str__() == 'L2Pyr_basal_3(0.5)' + assert ( + network_builder.ncs['L2Pyr_L2Pyr_nmda'][-idx].postseg().__str__() + == 'L2Pyr_apical_trunk(0.5)' + ) + assert ( + network_builder.ncs['L2Pyr_L2Pyr_nmda'][-5].postseg().__str__() + == 'L2Pyr_basal_3(0.5)' + ) kwargs_good = [ - ('src_gids', 0), ('src_gids', 'L2_pyramidal'), ('src_gids', range(2)), - ('target_gids', 35), ('target_gids', range(2)), + ('src_gids', 0), + ('src_gids', 'L2_pyramidal'), + ('src_gids', range(2)), + ('target_gids', 35), + ('target_gids', range(2)), ('target_gids', 'L2_pyramidal'), - ('target_gids', [[35, 36], [37, 38]]), ('probability', 0.5), - ('loc', 'apical_trunk')] + ('target_gids', [[35, 36], [37, 38]]), + ('probability', 0.5), + ('loc', 'apical_trunk'), + ] for arg, item in kwargs_good: kwargs = kwargs_default.copy() kwargs[arg] = item net.add_connection(**kwargs) kwargs_bad = [ - ('src_gids', 0.0), ('src_gids', [0.0]), - ('target_gids', 35.0), ('target_gids', [35.0]), - ('target_gids', [[35], [36.0]]), ('loc', 1.0), - ('receptor', 1.0), ('weight', '1.0'), ('delay', '1.0'), - ('lamtha', '1.0'), ('probability', '0.5'), ('allow_autapses', 1.0)] + ('src_gids', 0.0), + ('src_gids', [0.0]), + ('target_gids', 35.0), + ('target_gids', [35.0]), + ('target_gids', [[35], [36.0]]), + ('loc', 1.0), + ('receptor', 1.0), + ('weight', '1.0'), + ('delay', '1.0'), + ('lamtha', '1.0'), + ('probability', '0.5'), + ('allow_autapses', 1.0), + ] for arg, item in kwargs_bad: - match = ('must be an instance of') + match = 'must be an instance of' with pytest.raises(TypeError, match=match): kwargs = kwargs_default.copy() kwargs[arg] = item net.add_connection(**kwargs) kwargs_bad = [ - ('src_gids', -1), ('src_gids', [-1]), - ('target_gids', -1), ('target_gids', [-1]), - ('target_gids', [[35], [-1]]), ('target_gids', [[35]]), - ('src_gids', [0, 100]), ('target_gids', [0, 100])] + ('src_gids', -1), + ('src_gids', [-1]), + ('target_gids', -1), + ('target_gids', [-1]), + ('target_gids', [[35], [-1]]), + ('target_gids', [[35]]), + ('src_gids', [0, 100]), + ('target_gids', [0, 100]), + ] for arg, item in kwargs_bad: with pytest.raises(AssertionError): kwargs = kwargs_default.copy() @@ -679,11 +837,11 @@ def test_network_connectivity(base_network): kwargs['probability'] = 0.5 net.add_connection(**kwargs) n_connections = np.sum( - [len(t_gids) for - t_gids in net.connectivity[-2]['gid_pairs'].values()]) + [len(t_gids) for t_gids in net.connectivity[-2]['gid_pairs'].values()] + ) n_connections_new = np.sum( - [len(t_gids) for - t_gids in net.connectivity[-1]['gid_pairs'].values()]) + [len(t_gids) for t_gids in net.connectivity[-1]['gid_pairs'].values()] + ) assert n_connections_new == np.round(n_connections * 0.5).astype(int) assert net.connectivity[-1]['probability'] == 0.5 with pytest.raises(ValueError, 
                       match='probability must be'):
@@ -692,7 +850,7 @@
             net.add_connection(**kwargs)

     # Make sure warning raised if section targeted doesn't contain synapse
-    match = ('Invalid value for')
+    match = 'Invalid value for'
     with pytest.raises(ValueError, match=match):
         kwargs = kwargs_default.copy()
         kwargs['target_gids'] = 'L5_pyramidal'
@@ -716,8 +874,7 @@ def test_add_cell_type():
     params = read_params(params_fname)
     net = jones_2009_model(params)
     # instantiate drive events for NetworkBuilder
-    net._instantiate_drives(tstop=params['tstop'],
-                            n_trials=params['N_trials'])
+    net._instantiate_drives(tstop=params['tstop'], n_trials=params['N_trials'])

     n_total_cells = net._n_cells
     pos = [(0, idx, 0) for idx in range(10)]
@@ -729,9 +886,15 @@
     n_new_type = len(net.gid_ranges['new_type'])
     assert n_new_type == len(pos)

-    net.add_connection('L2_basket', 'new_type', loc='proximal',
-                       receptor='gabaa', weight=8e-3, delay=1,
-                       lamtha=2)
+    net.add_connection(
+        'L2_basket',
+        'new_type',
+        loc='proximal',
+        receptor='gabaa',
+        weight=8e-3,
+        delay=1,
+        lamtha=2,
+    )

     network_builder = NetworkBuilder(net)
     assert net._n_cells == n_total_cells + len(pos)
@@ -752,47 +915,43 @@ def test_tonic_biases():
     net = Network(params)
     # add arbitrary local network connectivity to avoid simulation warning
-    net.add_connection(src_gids='L2_pyramidal',
-                       target_gids='L2_basket',
-                       loc='soma', receptor='ampa', weight=1e-3,
-                       delay=1.0, lamtha=3.0)
-
-    tonic_bias_1 = {
-        'L2_pyramidal': 1.0,
-        'name_nonexistent': 1.0
-    }
+    net.add_connection(
+        src_gids='L2_pyramidal',
+        target_gids='L2_basket',
+        loc='soma',
+        receptor='ampa',
+        weight=1e-3,
+        delay=1.0,
+        lamtha=3.0,
+    )
+
+    tonic_bias_1 = {'L2_pyramidal': 1.0, 'name_nonexistent': 1.0}

     with pytest.raises(ValueError, match=r'cell_type must be one of .*$'):
-        net.add_tonic_bias(amplitude=tonic_bias_1, t0=0.0,
-                           tstop=4.0)
+        net.add_tonic_bias(amplitude=tonic_bias_1, t0=0.0, tstop=4.0)

     # The previous test only adds L2_pyramidal and ignores name_nonexistent
     # Testing that the first bias was added
     assert net.external_biases['tonic']['L2_pyramidal'] is not None
     net.external_biases = dict()

-    with pytest.raises(TypeError,
-                       match='amplitude must be an instance of dict'):
-        net.add_tonic_bias(amplitude=0.1,
-                           t0=5.0, tstop=-1.0)
+    with pytest.raises(TypeError, match='amplitude must be an instance of dict'):
+        net.add_tonic_bias(amplitude=0.1, t0=5.0, tstop=-1.0)

-    tonic_bias_2 = {
-        'L2_pyramidal': 1.0,
-        'L5_basket': 0.5
-    }
+    tonic_bias_2 = {'L2_pyramidal': 1.0, 'L5_basket': 0.5}

-    with pytest.raises(ValueError, match='Duration of tonic input cannot be'
-                       ' negative'):
-        net.add_tonic_bias(amplitude=tonic_bias_2,
-                           t0=5.0, tstop=4.0)
-        simulate_dipole(net, tstop=20.)
+    with pytest.raises(
+        ValueError, match='Duration of tonic input cannot be' ' negative'
+    ):
+        net.add_tonic_bias(amplitude=tonic_bias_2, t0=5.0, tstop=4.0)
+        simulate_dipole(net, tstop=20.0)
     net.external_biases = dict()

-    with pytest.raises(ValueError, match='End time of tonic input cannot be'
-                       ' negative'):
-        net.add_tonic_bias(amplitude=tonic_bias_2,
-                           t0=5.0, tstop=-1.0)
-        simulate_dipole(net, tstop=5.)
+ with pytest.raises( + ValueError, match='End time of tonic input cannot be' ' negative' + ): + net.add_tonic_bias(amplitude=tonic_bias_2, t0=5.0, tstop=-1.0) + simulate_dipole(net, tstop=5.0) net.external_biases = dict() with pytest.raises(ValueError, match='parameter may be missing'): @@ -802,48 +961,62 @@ def test_tonic_biases(): # test adding single cell_type - amplitude (old API) with pytest.raises(ValueError, match=r'cell_type must be one of .*$'): - with pytest.warns(DeprecationWarning, - match=r'cell_type argument will be deprecated'): - net.add_tonic_bias(cell_type='name_nonexistent', amplitude=1.0, - t0=0.0, tstop=4.0) - - with pytest.raises(TypeError, - match='amplitude must be an instance of float or int'): - with pytest.warns(DeprecationWarning, - match=r'cell_type argument will be deprecated'): - net.add_tonic_bias(cell_type='L5_pyramidal', - amplitude={'L2_pyramidal': 0.1}, - t0=5.0, tstop=-1.0) - - with pytest.raises(ValueError, match='Duration of tonic input cannot be' - ' negative'): - with pytest.warns(DeprecationWarning, - match=r'cell_type argument will be deprecated'): - net.add_tonic_bias(cell_type='L2_pyramidal', amplitude=1, - t0=5.0, tstop=4.0) - simulate_dipole(net, tstop=20.) + with pytest.warns( + DeprecationWarning, match=r'cell_type argument will be deprecated' + ): + net.add_tonic_bias( + cell_type='name_nonexistent', amplitude=1.0, t0=0.0, tstop=4.0 + ) + + with pytest.raises( + TypeError, match='amplitude must be an instance of float or int' + ): + with pytest.warns( + DeprecationWarning, match=r'cell_type argument will be deprecated' + ): + net.add_tonic_bias( + cell_type='L5_pyramidal', + amplitude={'L2_pyramidal': 0.1}, + t0=5.0, + tstop=-1.0, + ) + + with pytest.raises( + ValueError, match='Duration of tonic input cannot be' ' negative' + ): + with pytest.warns( + DeprecationWarning, match=r'cell_type argument will be deprecated' + ): + net.add_tonic_bias(cell_type='L2_pyramidal', amplitude=1, t0=5.0, tstop=4.0) + simulate_dipole(net, tstop=20.0) net.external_biases = dict() - with pytest.raises(ValueError, match='End time of tonic input cannot be' - ' negative'): - with pytest.warns(DeprecationWarning, - match=r'cell_type argument will be deprecated'): - net.add_tonic_bias(cell_type='L2_pyramidal', amplitude=1.0, - t0=5.0, tstop=-1.0) - simulate_dipole(net, tstop=5.) 
- - params.update({ - 'N_pyr_x': 3, 'N_pyr_y': 3, - 'N_trials': 1, - 'dipole_smooth_win': 5, - 't_evprox_1': 5, - 't_evdist_1': 10, - 't_evprox_2': 20, - # tonic inputs - 'Itonic_A_L2Pyr_soma': 1.0, - 'Itonic_t0_L2Pyr_soma': 5.0, - 'Itonic_T_L2Pyr_soma': 15.0 - }) + with pytest.raises( + ValueError, match='End time of tonic input cannot be' ' negative' + ): + with pytest.warns( + DeprecationWarning, match=r'cell_type argument will be deprecated' + ): + net.add_tonic_bias( + cell_type='L2_pyramidal', amplitude=1.0, t0=5.0, tstop=-1.0 + ) + simulate_dipole(net, tstop=5.0) + + params.update( + { + 'N_pyr_x': 3, + 'N_pyr_y': 3, + 'N_trials': 1, + 'dipole_smooth_win': 5, + 't_evprox_1': 5, + 't_evdist_1': 10, + 't_evprox_2': 20, + # tonic inputs + 'Itonic_A_L2Pyr_soma': 1.0, + 'Itonic_t0_L2Pyr_soma': 5.0, + 'Itonic_T_L2Pyr_soma': 15.0, + } + ) # old API net = Network(params, add_drives_from_params=True) assert 'tonic' in net.external_biases @@ -917,8 +1090,7 @@ def test_synaptic_gains(): # Single argument check with copy net_updated = net.update_weights(e_e=2.0, copy=True) for conn in net_updated.connectivity: - if (conn['src_type'] in e_cell_names and - conn['target_type'] in e_cell_names): + if conn['src_type'] in e_cell_names and conn['target_type'] in e_cell_names: assert conn['nc_dict']['gain'] == 2.0 else: assert conn['nc_dict']['gain'] == 1.0 @@ -929,8 +1101,7 @@ def test_synaptic_gains(): # Single argument with inplace change net.update_weights(i_e=0.5, copy=False) for conn in net.connectivity: - if (conn['src_type'] in i_cell_names and - conn['target_type'] in e_cell_names): + if conn['src_type'] in i_cell_names and conn['target_type'] in e_cell_names: assert conn['nc_dict']['gain'] == 0.5 else: assert conn['nc_dict']['gain'] == 1.0 @@ -938,11 +1109,9 @@ def test_synaptic_gains(): # Two argument check net.update_weights(i_e=0.5, i_i=0.25, copy=False) for conn in net.connectivity: - if (conn['src_type'] in i_cell_names and - conn['target_type'] in e_cell_names): + if conn['src_type'] in i_cell_names and conn['target_type'] in e_cell_names: assert conn['nc_dict']['gain'] == 0.5 - elif (conn['src_type'] in i_cell_names and - conn['target_type'] in i_cell_names): + elif conn['src_type'] in i_cell_names and conn['target_type'] in i_cell_names: assert conn['nc_dict']['gain'] == 0.25 else: assert conn['nc_dict']['gain'] == 1.0 @@ -953,29 +1122,34 @@ def _get_weight(nb, conn_name, idx=0): nb_updated = NetworkBuilder(net) # i_e check - assert (_get_weight(nb_updated, 'L2Basket_L2Pyr_gabaa') / - _get_weight(nb_base, 'L2Basket_L2Pyr_gabaa')) == 0.5 + assert ( + _get_weight(nb_updated, 'L2Basket_L2Pyr_gabaa') + / _get_weight(nb_base, 'L2Basket_L2Pyr_gabaa') + ) == 0.5 # i_i check - assert (_get_weight(nb_updated, 'L2Basket_L2Basket_gabaa') / - _get_weight(nb_base, 'L2Basket_L2Basket_gabaa')) == 0.25 + assert ( + _get_weight(nb_updated, 'L2Basket_L2Basket_gabaa') + / _get_weight(nb_base, 'L2Basket_L2Basket_gabaa') + ) == 0.25 # Unaltered check - assert (_get_weight(nb_updated, 'L2Pyr_L5Basket_ampa') / - _get_weight(nb_base, 'L2Pyr_L5Basket_ampa')) == 1 + assert ( + _get_weight(nb_updated, 'L2Pyr_L5Basket_ampa') + / _get_weight(nb_base, 'L2Pyr_L5Basket_ampa') + ) == 1 class TestPickConnection: """Tests for the pick_connection function.""" - @pytest.mark.parametrize("arg_name", - ["src_gids", "target_gids", "loc", "receptor"] - ) + + @pytest.mark.parametrize('arg_name', ['src_gids', 'target_gids', 'loc', 'receptor']) def test_1argument_none(self, base_network, arg_name): - """ Tests passing None as 
an argument value. """ + """Tests passing None as an argument value.""" net, _ = base_network kwargs = {'net': net, f'{arg_name}': None} indices = pick_connection(**kwargs) assert len(indices) == 0 - @pytest.mark.parametrize("arg_name", ["src_gids", "target_gids"]) + @pytest.mark.parametrize('arg_name', ['src_gids', 'target_gids']) def test_1argument_gids_range(self, base_network, arg_name): """Tests passing range as an argument value.""" net, _ = base_network @@ -984,16 +1158,17 @@ def test_1argument_gids_range(self, base_network, arg_name): indices = pick_connection(**kwargs) for conn_idx in indices: - assert set(test_range).issubset( - net.connectivity[conn_idx][arg_name] - ) - - @pytest.mark.parametrize("arg_name,value", - [("src_gids", 'L2_pyramidal'), - ("target_gids", 'L2_pyramidal'), - ("loc", 'soma'), - ("receptor", 'gabaa'), - ]) + assert set(test_range).issubset(net.connectivity[conn_idx][arg_name]) + + @pytest.mark.parametrize( + 'arg_name,value', + [ + ('src_gids', 'L2_pyramidal'), + ('target_gids', 'L2_pyramidal'), + ('loc', 'soma'), + ('receptor', 'gabaa'), + ], + ) def test_1argument_str(self, base_network, arg_name, value): """Tests passing string as an argument value.""" net, _ = base_network @@ -1003,17 +1178,20 @@ def test_1argument_str(self, base_network, arg_name, value): for conn_idx in indices: if arg_name in ('src_gids', 'target_gids'): # arg specifies a subset of item gids (within gid_ranges) - assert (net.connectivity[conn_idx][arg_name] - .issubset(net.gid_ranges[value]) - ) + assert net.connectivity[conn_idx][arg_name].issubset( + net.gid_ranges[value] + ) else: # arg and item specify equivalent string descriptors assert net.connectivity[conn_idx][arg_name] == value - @pytest.mark.parametrize("arg_name,value", - [("src_gids", 0), - ("target_gids", 35), - ]) + @pytest.mark.parametrize( + 'arg_name,value', + [ + ('src_gids', 0), + ('target_gids', 35), + ], + ) def test_1argument_gids_int(self, base_network, arg_name, value): """Tests that connections are not missing when passing one gid.""" net, _ = base_network @@ -1026,31 +1204,34 @@ def test_1argument_gids_int(self, base_network, arg_name, value): else: assert value not in net.connectivity[conn_idx][arg_name] - @pytest.mark.parametrize("arg_name,value", - [("src_gids", ['L2_basket', 'L5_basket']), - ("target_gids", ['L2_pyramidal', 'L5_pyramidal']) - ]) - def test_1argument_list_of_cell_types_str(self, - base_network, - arg_name, - value): + @pytest.mark.parametrize( + 'arg_name,value', + [ + ('src_gids', ['L2_basket', 'L5_basket']), + ('target_gids', ['L2_pyramidal', 'L5_pyramidal']), + ], + ) + def test_1argument_list_of_cell_types_str(self, base_network, arg_name, value): """Tests passing a list of valid strings""" net, _ = base_network kwargs = {'net': net, f'{arg_name}': value} indices = pick_connection(**kwargs) - true_gid_set = set(list(net.gid_ranges[value[0]]) + - list(net.gid_ranges[value[1]]) - ) + true_gid_set = set( + list(net.gid_ranges[value[0]]) + list(net.gid_ranges[value[1]]) + ) pick_gid_list = [] for idx in indices: pick_gid_list.extend(net.connectivity[idx][arg_name]) assert true_gid_set == set(pick_gid_list) - @pytest.mark.parametrize("arg_name,value", - [("src_gids", [0, 5]), - ("target_gids", [35, 34]), - ]) + @pytest.mark.parametrize( + 'arg_name,value', + [ + ('src_gids', [0, 5]), + ('target_gids', [35, 34]), + ], + ) def test_1argument_list_of_gids_int(self, base_network, arg_name, value): """Tests passing a list of valid ints.""" net, _ = base_network @@ -1064,61 +1245,69 @@ def 
test_1argument_list_of_gids_int(self, base_network, arg_name, value): assert indices == true_idx_list - @pytest.mark.parametrize("src_gids,target_gids,loc,receptor", - [("evdist1", None, "proximal", None), - ("evprox1", None, "distal", None), - (None, None, "distal", "gabab"), - ("L2_pyramidal", None, None, "gabab"), - ("L2_basket", "L2_basket", "proximal", "nmda"), - ("L2_pyramidal", "L2_basket", "distal", "gabab"), - ]) - def test_no_match(self, base_network, - src_gids, target_gids, loc, receptor): + @pytest.mark.parametrize( + 'src_gids,target_gids,loc,receptor', + [ + ('evdist1', None, 'proximal', None), + ('evprox1', None, 'distal', None), + (None, None, 'distal', 'gabab'), + ('L2_pyramidal', None, None, 'gabab'), + ('L2_basket', 'L2_basket', 'proximal', 'nmda'), + ('L2_pyramidal', 'L2_basket', 'distal', 'gabab'), + ], + ) + def test_no_match(self, base_network, src_gids, target_gids, loc, receptor): """Tests no matches returned for non-configured connections.""" net, _ = base_network - indices = pick_connection(net, - src_gids=src_gids, - target_gids=target_gids, - loc=loc, - receptor=receptor) + indices = pick_connection( + net, src_gids=src_gids, target_gids=target_gids, loc=loc, receptor=receptor + ) assert len(indices) == 0 - @pytest.mark.parametrize("src_gids,target_gids,loc,receptor", - [(0.0, None, None, None), - ([0.0], None, None, None), - (None, 35.0, None, None), - (None, [35.0], None, None), - (None, [35, [36.0]], None, None), - (None, None, 1.0, None), - (None, None, None, 1.0), - ]) - def test_type_error(self, base_network, - src_gids, target_gids, loc, receptor): + @pytest.mark.parametrize( + 'src_gids,target_gids,loc,receptor', + [ + (0.0, None, None, None), + ([0.0], None, None, None), + (None, 35.0, None, None), + (None, [35.0], None, None), + (None, [35, [36.0]], None, None), + (None, None, 1.0, None), + (None, None, None, 1.0), + ], + ) + def test_type_error(self, base_network, src_gids, target_gids, loc, receptor): """Tests TypeError when passing floats.""" net, _ = base_network - match = ('must be an instance of') + match = 'must be an instance of' with pytest.raises(TypeError, match=match): - pick_connection(net, - src_gids=src_gids, - target_gids=target_gids, - loc=loc, - receptor=receptor) - - @pytest.mark.parametrize("src_gids,target_gids", - [(-1, None), ([-1], None), - (None, -1), (None, [-1]), - ([35, -1], None), (None, [35, -1]), - ]) + pick_connection( + net, + src_gids=src_gids, + target_gids=target_gids, + loc=loc, + receptor=receptor, + ) + + @pytest.mark.parametrize( + 'src_gids,target_gids', + [ + (-1, None), + ([-1], None), + (None, -1), + (None, [-1]), + ([35, -1], None), + (None, [35, -1]), + ], + ) def test_invalid_gids_int(self, base_network, src_gids, target_gids): """Tests AssertionError when passing negative ints.""" net, _ = base_network - match = ('not in net.gid_ranges') + match = 'not in net.gid_ranges' with pytest.raises(AssertionError, match=match): pick_connection(net, src_gids=src_gids, target_gids=target_gids) - @pytest.mark.parametrize("arg_name", - ["src_gids", "target_gids", "loc", "receptor"] - ) + @pytest.mark.parametrize('arg_name', ['src_gids', 'target_gids', 'loc', 'receptor']) def test_invalid_str(self, base_network, arg_name): """Tests ValueError raises when passing unrecognized string.""" net, _ = base_network @@ -1127,21 +1316,20 @@ def test_invalid_str(self, base_network, arg_name): kwargs = {'net': net, f'{arg_name}': 'invalid_string'} pick_connection(**kwargs) - 
@pytest.mark.parametrize("src_gids,target_gids,expected", - [("evdist1", "L5_pyramidal", 2), - ("evprox1", "L2_basket", 2), - ("L2_basket", "L2_basket", 0), - ]) - def test_only_drives_specified(self, base_network, src_gids, - target_gids, expected): + @pytest.mark.parametrize( + 'src_gids,target_gids,expected', + [ + ('evdist1', 'L5_pyramidal', 2), + ('evprox1', 'L2_basket', 2), + ('L2_basket', 'L2_basket', 0), + ], + ) + def test_only_drives_specified(self, base_network, src_gids, target_gids, expected): """Tests searching a Network with only drive connections added. Only searches for drive connectivity should have results. """ _, param = base_network net = Network(param, add_drives_from_params=True) - indices = pick_connection(net, - src_gids=src_gids, - target_gids=target_gids - ) + indices = pick_connection(net, src_gids=src_gids, target_gids=target_gids) assert len(indices) == expected diff --git a/hnn_core/tests/test_optimize_evoked.py b/hnn_core/tests/test_optimize_evoked.py index fd6e07fdb..ffb82db22 100644 --- a/hnn_core/tests/test_optimize_evoked.py +++ b/hnn_core/tests/test_optimize_evoked.py @@ -6,11 +6,13 @@ import hnn_core from hnn_core import read_params, jones_2009_model, simulate_dipole -from hnn_core.optimization.optimize_evoked import (_consolidate_chunks, - _split_by_evinput, - _generate_weights, - _get_drive_params, - optimize_evoked) +from hnn_core.optimization.optimize_evoked import ( + _consolidate_chunks, + _split_by_evinput, + _generate_weights, + _get_drive_params, + optimize_evoked, +) def test_consolidate_chunks(): @@ -21,15 +23,15 @@ def test_consolidate_chunks(): 'end': 25, 'ranges': {'initial': 1e-10, 'minval': 1e-11, 'maxval': 1e-9}, 'opt_end': 90, - 'weights': np.array([5., 10.]) + 'weights': np.array([5.0, 10.0]), }, 'ev2': { 'start': 100, 'end': 120, 'ranges': {'initial': 1e-10, 'minval': 1e-11, 'maxval': 1e-9}, 'opt_end': 170, - 'weights': np.array([10., 5.]) - } + 'weights': np.array([10.0, 5.0]), + }, } chunks = _consolidate_chunks(inputs) assert len(chunks) == len(inputs) + 1 # extra last chunk?? @@ -43,39 +45,47 @@ def test_consolidate_chunks(): assert len(chunks) == 1 assert chunks[0]['start'] == inputs['ev1']['start'] assert chunks[0]['end'] == inputs['ev2']['end'] - assert np.allclose(chunks[0]['weights'], - (inputs['ev1']['weights'] + - inputs['ev2']['weights']) / 2.) + assert np.allclose( + chunks[0]['weights'], + (inputs['ev1']['weights'] + inputs['ev2']['weights']) / 2.0, + ) def test_split_by_evinput(): """Test splitting evoked input.""" drive_names = ['ev_drive_1', 'ev_drive_2'] - drive_dynamics = [{'mu': 5., 'sigma': .1}, {'mu': 10., 'sigma': .2}] - drive_syn_weights = [{'ampa_L2_pyramidal': 1.}, {'nmda_L5_basket': 2.}] - tstop = 20. 
+ drive_dynamics = [{'mu': 5.0, 'sigma': 0.1}, {'mu': 10.0, 'sigma': 0.2}] + drive_syn_weights = [{'ampa_L2_pyramidal': 1.0}, {'nmda_L5_basket': 2.0}] + tstop = 20.0 dt = 0.025 timing_range_multiplier = 3.0 sigma_range_multiplier = 50.0 synweight_range_multiplier = 500.0 decay_multiplier = 1.6 - evinput_params = _split_by_evinput(drive_names, drive_dynamics, - drive_syn_weights, tstop, - sigma_range_multiplier, - timing_range_multiplier, - synweight_range_multiplier) + evinput_params = _split_by_evinput( + drive_names, + drive_dynamics, + drive_syn_weights, + tstop, + sigma_range_multiplier, + timing_range_multiplier, + synweight_range_multiplier, + ) assert list(evinput_params.keys()) == drive_names for evinput in evinput_params.values(): - assert list(evinput.keys()) == ['mean', 'sigma', 'ranges', - 'start', 'end'] + assert list(evinput.keys()) == ['mean', 'sigma', 'ranges', 'start', 'end'] - evinput_params = _generate_weights(evinput_params, tstop, dt, - decay_multiplier) + evinput_params = _generate_weights(evinput_params, tstop, dt, decay_multiplier) for evinput in evinput_params.values(): - assert list(evinput.keys()) == ['ranges', 'start', 'end', - 'weights', 'opt_start', - 'opt_end'] + assert list(evinput.keys()) == [ + 'ranges', + 'start', + 'end', + 'weights', + 'opt_start', + 'opt_end', + ] def test_optimize_evoked(): @@ -84,70 +94,100 @@ def test_optimize_evoked(): params_fname = op.join(hnn_core_root, 'param', 'default.json') params = read_params(params_fname) - tstop = 10. + tstop = 10.0 n_trials = 1 # simulate a dipole to establish ground-truth drive parameters - mu_orig = 6. - params.update({'t_evprox_1': mu_orig, - 'sigma_t_evprox_1': 2., - 't_evdist_1': mu_orig + 2, - 'sigma_t_evdist_1': 2.}) - net_orig = jones_2009_model(params, add_drives_from_params=True, - mesh_shape=(3, 3)) + mu_orig = 6.0 + params.update( + { + 't_evprox_1': mu_orig, + 'sigma_t_evprox_1': 2.0, + 't_evdist_1': mu_orig + 2, + 'sigma_t_evdist_1': 2.0, + } + ) + net_orig = jones_2009_model(params, add_drives_from_params=True, mesh_shape=(3, 3)) del net_orig.external_drives['evprox2'] dpl_orig = simulate_dipole(net_orig, tstop=tstop, n_trials=n_trials)[0] # simulate a dipole with a time-shifted drive - mu_offset = 4. 
- params.update({'t_evprox_1': mu_offset, - 'sigma_t_evprox_1': 2., - 't_evdist_1': mu_offset + 2, - 'sigma_t_evdist_1': 2.}) - net_offset = jones_2009_model(params, add_drives_from_params=True, - mesh_shape=(3, 3)) + mu_offset = 4.0 + params.update( + { + 't_evprox_1': mu_offset, + 'sigma_t_evprox_1': 2.0, + 't_evdist_1': mu_offset + 2, + 'sigma_t_evdist_1': 2.0, + } + ) + net_offset = jones_2009_model( + params, add_drives_from_params=True, mesh_shape=(3, 3) + ) del net_offset.external_drives['evprox2'] dpl_offset = simulate_dipole(net_offset, tstop=tstop, n_trials=n_trials)[0] # get drive params from the pre-optimization Network instance _, _, drive_static_params_orig = _get_drive_params(net_offset, ['evprox1']) - with pytest.raises(ValueError, match='The current Network instance lacks ' - 'any evoked drives'): + with pytest.raises( + ValueError, match='The current Network instance lacks ' 'any evoked drives' + ): net_empty = net_offset.copy() del net_empty.external_drives['evprox1'] del net_empty.external_drives['evdist1'] - net_opt = optimize_evoked(net_empty, tstop=tstop, - n_trials=n_trials, target_dpl=dpl_orig, - initial_dpl=dpl_offset, maxiter=10) - - with pytest.raises(ValueError, match='The drives selected to be optimized ' - 'are not evoked drives'): + net_opt = optimize_evoked( + net_empty, + tstop=tstop, + n_trials=n_trials, + target_dpl=dpl_orig, + initial_dpl=dpl_offset, + maxiter=10, + ) + + with pytest.raises( + ValueError, match='The drives selected to be optimized ' 'are not evoked drives' + ): net_test_bursty = net_offset.copy() which_drives = ['bursty1'] - net_opt = optimize_evoked(net_test_bursty, tstop=tstop, - n_trials=n_trials, target_dpl=dpl_orig, - initial_dpl=dpl_offset, - which_drives=which_drives, maxiter=10) + net_opt = optimize_evoked( + net_test_bursty, + tstop=tstop, + n_trials=n_trials, + target_dpl=dpl_orig, + initial_dpl=dpl_offset, + which_drives=which_drives, + maxiter=10, + ) which_drives = ['evprox1'] # drive selected to optimize maxiter = 10 # try without returning iteration RMSE first - net_opt = optimize_evoked(net_offset, tstop=tstop, n_trials=n_trials, - target_dpl=dpl_orig, - initial_dpl=dpl_offset, - timing_range_multiplier=3., - sigma_range_multiplier=50., - synweight_range_multiplier=500., - maxiter=maxiter, which_drives=which_drives, - return_rmse=False) - net_opt, rmse = optimize_evoked(net_offset, tstop=tstop, n_trials=n_trials, - target_dpl=dpl_orig, - initial_dpl=dpl_offset, - timing_range_multiplier=3., - sigma_range_multiplier=50., - synweight_range_multiplier=500., - maxiter=maxiter, which_drives=which_drives, - return_rmse=True) + net_opt = optimize_evoked( + net_offset, + tstop=tstop, + n_trials=n_trials, + target_dpl=dpl_orig, + initial_dpl=dpl_offset, + timing_range_multiplier=3.0, + sigma_range_multiplier=50.0, + synweight_range_multiplier=500.0, + maxiter=maxiter, + which_drives=which_drives, + return_rmse=False, + ) + net_opt, rmse = optimize_evoked( + net_offset, + tstop=tstop, + n_trials=n_trials, + target_dpl=dpl_orig, + initial_dpl=dpl_offset, + timing_range_multiplier=3.0, + sigma_range_multiplier=50.0, + synweight_range_multiplier=500.0, + maxiter=maxiter, + which_drives=which_drives, + return_rmse=True, + ) # the number of returned rmse values should be the same as maxiter assert len(rmse) <= maxiter @@ -158,22 +198,32 @@ def test_optimize_evoked(): # the names of drives should be preserved during optimization assert net_offset.external_drives.keys() == net_opt.external_drives.keys() - drive_dynamics_opt, 
drive_syn_weights_opt, drive_static_params_opt = \ + drive_dynamics_opt, drive_syn_weights_opt, drive_static_params_opt = ( _get_drive_params(net_opt, ['evprox1']) + ) # ensure that params corresponding to only one evoked drive are discovered - assert (len(drive_dynamics_opt) == - len(drive_syn_weights_opt) == - len(drive_static_params_opt) == 1) + assert ( + len(drive_dynamics_opt) + == len(drive_syn_weights_opt) + == len(drive_static_params_opt) + == 1 + ) # static drive params should remain constant assert drive_static_params_opt == drive_static_params_orig # ensure that only the drive that we wanted to optimize over changed - drive_evdist1_dynamics_offset, drive_evdist1_syn_weights_offset, \ - drive_static_params_offset = _get_drive_params(net_offset, ['evdist1']) - drive_evdist1_dynamics_opt, drive_evdist1_syn_weights_opt, \ - drive_static_params_opt = _get_drive_params(net_opt, ['evdist1']) + ( + drive_evdist1_dynamics_offset, + drive_evdist1_syn_weights_offset, + drive_static_params_offset, + ) = _get_drive_params(net_offset, ['evdist1']) + ( + drive_evdist1_dynamics_opt, + drive_evdist1_syn_weights_opt, + drive_static_params_opt, + ) = _get_drive_params(net_opt, ['evdist1']) # assert that evdist1 did NOT change assert drive_evdist1_dynamics_opt == drive_evdist1_dynamics_offset diff --git a/hnn_core/tests/test_parallel_backends.py b/hnn_core/tests/test_parallel_backends.py index b6e9f7ed6..a0ca9f3ae 100644 --- a/hnn_core/tests/test_parallel_backends.py +++ b/hnn_core/tests/test_parallel_backends.py @@ -40,17 +40,28 @@ def test_gid_assignment(): net = jones_2009_model(add_drives_from_params=False) weights_ampa = {'L2_basket': 1.0, 'L2_pyramidal': 2.0, 'L5_pyramidal': 3.0} - syn_delays = {'L2_basket': .1, 'L2_pyramidal': .2, 'L5_pyramidal': .3} + syn_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.2, 'L5_pyramidal': 0.3} net.add_bursty_drive( - 'bursty_dist', location='distal', burst_rate=10, - weights_ampa=weights_ampa, synaptic_delays=syn_delays, - cell_specific=False, n_drive_cells=5) + 'bursty_dist', + location='distal', + burst_rate=10, + weights_ampa=weights_ampa, + synaptic_delays=syn_delays, + cell_specific=False, + n_drive_cells=5, + ) net.add_evoked_drive( - 'evoked_prox', mu=1.0, sigma=1.0, numspikes=1, - weights_ampa=weights_ampa, location='proximal', - synaptic_delays=syn_delays, cell_specific=True, - n_drive_cells='n_cells') + 'evoked_prox', + mu=1.0, + sigma=1.0, + numspikes=1, + weights_ampa=weights_ampa, + location='proximal', + synaptic_delays=syn_delays, + cell_specific=True, + n_drive_cells='n_cells', + ) net._instantiate_drives(tstop=20, n_trials=2) all_gids = list() @@ -77,7 +88,7 @@ def test_gid_assignment(): @pytest.mark.incremental -class TestParallelBackends(): +class TestParallelBackends: dpls_reduced_mpi = None dpls_reduced_default = None dpls_reduced_joblib = None @@ -87,20 +98,26 @@ def test_run_default(self, run_hnn_core_fixture): global dpls_reduced_default dpls_reduced_default, _ = run_hnn_core_fixture(None, reduced=True) # test consistency across all parallel backends for multiple trials - assert_raises(AssertionError, assert_array_equal, - dpls_reduced_default[0].data['agg'], - dpls_reduced_default[1].data['agg']) + assert_raises( + AssertionError, + assert_array_equal, + dpls_reduced_default[0].data['agg'], + dpls_reduced_default[1].data['agg'], + ) def test_run_joblibbackend(self, run_hnn_core_fixture): """Test consistency between joblib backend simulation with master""" global dpls_reduced_default, dpls_reduced_joblib - dpls_reduced_joblib, _ = 
run_hnn_core_fixture(backend='joblib',
-                                                       n_jobs=2, reduced=True)
+        dpls_reduced_joblib, _ = run_hnn_core_fixture(
+            backend='joblib', n_jobs=2, reduced=True
+        )
         for trial_idx in range(len(dpls_reduced_default)):
-            assert_array_equal(dpls_reduced_default[trial_idx].data['agg'],
-                               dpls_reduced_joblib[trial_idx].data['agg'])
+            assert_array_equal(
+                dpls_reduced_default[trial_idx].data['agg'],
+                dpls_reduced_joblib[trial_idx].data['agg'],
+            )

     @requires_mpi4py
     @requires_psutil
@@ -119,9 +136,12 @@ def test_run_mpibackend(self, run_hnn_core_fixture):
         dpls_reduced_mpi, _ = run_hnn_core_fixture(backend='mpi', reduced=True)
         for trial_idx in range(len(dpls_reduced_default)):
             # account for rounding error incurred during MPI parallelization
-            assert_allclose(dpls_reduced_default[trial_idx].data['agg'],
-                            dpls_reduced_mpi[trial_idx].data['agg'], rtol=0,
-                            atol=1e-14)
+            assert_allclose(
+                dpls_reduced_default[trial_idx].data['agg'],
+                dpls_reduced_mpi[trial_idx].data['agg'],
+                rtol=0,
+                atol=1e-14,
+            )

     @requires_mpi4py
     @requires_psutil
@@ -130,19 +150,16 @@ def test_terminate_mpibackend(self, run_hnn_core_fixture):
         hnn_core_root = op.dirname(hnn_core.__file__)
         params_fname = op.join(hnn_core_root, 'param', 'default.json')
         params = read_params(params_fname)
-        params.update({'t_evprox_1': 5,
-                       't_evdist_1': 10,
-                       't_evprox_2': 20,
-                       'N_trials': 2})
-        net = jones_2009_model(params, add_drives_from_params=True,
-                               mesh_shape=(3, 3))
+        params.update(
+            {'t_evprox_1': 5, 't_evdist_1': 10, 't_evprox_2': 20, 'N_trials': 2}
+        )
+        net = jones_2009_model(params, add_drives_from_params=True, mesh_shape=(3, 3))

         with MPIBackend() as backend:
             event = Event()
             # start background thread that will kill all MPIBackends
             # until event.set()
-            kill_t = Thread(target=_terminate_mpibackend,
-                            args=(event, backend))
+            kill_t = Thread(target=_terminate_mpibackend, args=(event, backend))
             # make thread a daemon in case we throw an exception
             # and don't run event.set() so that py.test will
             # not hang before exiting
@@ -151,12 +168,12 @@ def test_terminate_mpibackend(self, run_hnn_core_fixture):
             with pytest.warns(UserWarning) as record:
                 with pytest.raises(
-                        RuntimeError,
-                        match="MPI simulation failed. Return code: 1"):
+                    RuntimeError, match='MPI simulation failed. Return code: 1'
+                ):
                     simulate_dipole(net, tstop=40)
                 event.set()

-        expected_string = "Child process failed unexpectedly"
+        expected_string = 'Child process failed unexpectedly'
         assert expected_string in record[0].message.args[0]

     @requires_mpi4py
     @requires_psutil
@@ -166,18 +183,17 @@ def test_run_mpibackend_oversubscribed(self, run_hnn_core_fixture):
         hnn_core_root = op.dirname(hnn_core.__file__)
         params_fname = op.join(hnn_core_root, 'param', 'default.json')
         params = read_params(params_fname)
-        params.update({'t_evprox_1': 5,
-                       't_evdist_1': 10,
-                       't_evprox_2': 20,
-                       'N_trials': 2})
-        net = jones_2009_model(params, add_drives_from_params=True,
-                               mesh_shape=(3, 3))
+        params.update(
+            {'t_evprox_1': 5, 't_evdist_1': 10, 't_evprox_2': 20, 'N_trials': 2}
+        )
+        net = jones_2009_model(params, add_drives_from_params=True, mesh_shape=(3, 3))

         # try running with more procs than cells in the network (will probably
         # oversubscribe)
         too_many_procs = net._n_cells + 1

-        with pytest.raises(ValueError, match='More MPI processes were '
-                           'assigned than there are cells'):
+        with pytest.raises(
+            ValueError, match='More MPI processes were ' 'assigned than there are cells'
+        ):
             with MPIBackend(n_procs=too_many_procs) as backend:
                 simulate_dipole(net, tstop=40)

@@ -185,24 +201,26 @@
         # calculate n_procs to make sure there are always enough cells in the network
         oversubscribed_procs = cpu_count() + 1
         n_grid_1d = int(np.ceil(np.sqrt(oversubscribed_procs)))
-        params.update({'t_evprox_1': 5,
-                       't_evdist_1': 10,
-                       't_evprox_2': 20,
-                       'N_trials': 2})
-        net = jones_2009_model(params, add_drives_from_params=True,
-                               mesh_shape=(n_grid_1d, n_grid_1d))
+        params.update(
+            {'t_evprox_1': 5, 't_evdist_1': 10, 't_evprox_2': 20, 'N_trials': 2}
+        )
+        net = jones_2009_model(
+            params, add_drives_from_params=True, mesh_shape=(n_grid_1d, n_grid_1d)
+        )
         with MPIBackend(n_procs=oversubscribed_procs) as backend:
             assert backend.n_procs == oversubscribed_procs
             simulate_dipole(net, tstop=40)

-    @pytest.mark.parametrize("backend", ['mpi', 'joblib'])
+    @pytest.mark.parametrize('backend', ['mpi', 'joblib'])
     def test_compare_hnn_core(self, run_hnn_core_fixture, backend, n_jobs=1):
         """Test hnn-core does not break."""
         # small snippet of data on data branch for now. To be deleted
         # later. Data branch should have only one commit so it does not
         # pollute the history.
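For reference, the backend pattern these tests exercise is the context-manager API sketched below; this is a minimal sketch assuming a working MPI installation with mpi4py, and n_procs must not exceed the number of cells in the network (exactly what the oversubscription test above checks):

from hnn_core import MPIBackend, jones_2009_model, simulate_dipole

# a small 3 x 3 grid, matching the reduced networks used in these tests
net = jones_2009_model(add_drives_from_params=True, mesh_shape=(3, 3))
with MPIBackend(n_procs=2):
    dpls = simulate_dipole(net, tstop=40)  # NEURON runs across 2 MPI processes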
-        data_url = ('https://raw.githubusercontent.com/jonescompneurolab/'
-                    'hnn-core/test_data/dpl.txt')
+        data_url = (
+            'https://raw.githubusercontent.com/jonescompneurolab/'
+            'hnn-core/test_data/dpl.txt'
+        )
         if not op.exists('dpl.txt'):
             urlretrieve(data_url, 'dpl.txt')
         dpl_master = loadtxt('dpl.txt')
@@ -228,13 +246,15 @@ def test_compare_hnn_core(self, run_hnn_core_fixture, backend, n_jobs=1):
         assert 'common' not in spike_type_counts
         assert 'exgauss' not in spike_type_counts
         assert 'extpois' not in spike_type_counts
-        assert spike_type_counts == {'evprox1': 270,
-                                     'L2_basket': 55,
-                                     'L2_pyramidal': 114,
-                                     'L5_pyramidal': 396,
-                                     'L5_basket': 86,
-                                     'evdist1': 270,
-                                     'evprox2': 270}
+        assert spike_type_counts == {
+            'evprox1': 270,
+            'L2_basket': 55,
+            'L2_pyramidal': 114,
+            'L5_pyramidal': 396,
+            'L5_basket': 86,
+            'evdist1': 270,
+            'evprox2': 270,
+        }

 # there are no dependencies if this unit test fails; no need to be in
@@ -244,19 +264,18 @@ def test_compare_hnn_core(self, run_hnn_core_fixture, backend, n_jobs=1):
 def test_mpi_failure(run_hnn_core_fixture):
     """Test that an MPI failure is handled and messages are printed"""
     # this MPI parameter will cause a MPI job to fail
-    environ["OMPI_MCA_btl"] = "self"
+    environ['OMPI_MCA_btl'] = 'self'

     with pytest.warns(UserWarning) as record:
         with io.StringIO() as buf, redirect_stdout(buf):
-            with pytest.raises(RuntimeError, match="MPI simulation failed"):
-                run_hnn_core_fixture(backend='mpi', reduced=True,
-                                     postproc=False)
+            with pytest.raises(RuntimeError, match='MPI simulation failed'):
+                run_hnn_core_fixture(backend='mpi', reduced=True, postproc=False)
             stdout = buf.getvalue()
-        assert "MPI processes are unable to reach each other" in stdout
+        assert 'MPI processes are unable to reach each other' in stdout

-    expected_string = "Child process failed unexpectedly"
+    expected_string = 'Child process failed unexpectedly'
     assert len(record) == 1
     assert record[0].message.args[0] == expected_string

-    del environ["OMPI_MCA_btl"]
+    del environ['OMPI_MCA_btl']
diff --git a/hnn_core/tests/test_params.py b/hnn_core/tests/test_params.py
index 6155440a6..e7d7f5df9 100644
--- a/hnn_core/tests/test_params.py
+++ b/hnn_core/tests/test_params.py
@@ -22,10 +22,8 @@ def test_read_params():
     params_fname = op.join(hnn_core_root, 'param', 'default.json')
     params = read_params(params_fname)
     # Smoke test that network loads params
-    _ = jones_2009_model(
-        params, add_drives_from_params=True, legacy_mode=False)
-    _ = jones_2009_model(
-        params, add_drives_from_params=True, legacy_mode=True)
+    _ = jones_2009_model(params, add_drives_from_params=True, legacy_mode=False)
+    _ = jones_2009_model(params, add_drives_from_params=True, legacy_mode=True)
     print(params)
     print(params['L2Pyr*'])

@@ -43,8 +41,10 @@ def test_read_legacy_params():
     """Test reading of legacy .param file."""
-    param_url = ('https://raw.githubusercontent.com/hnnsolver/'
-                 'hnn-core/test_data/default.param')
+    param_url = (
+        'https://raw.githubusercontent.com/hnnsolver/'
+        'hnn-core/test_data/default.param'
+    )
     params_legacy_fname = op.join(hnn_core_root, 'param', 'default.param')
     if not op.exists(params_legacy_fname):
         urlretrieve(param_url, params_legacy_fname)
     params_new_fname = op.join(hnn_core_root, 'param', 'default.json')
     params_legacy = read_params(params_legacy_fname)
     params_new = read_params(params_new_fname)
-    params_new_seedless = {key: val for key, val in params_new.items()
-                           if key not in params_new['prng_seedcore*'].keys()}
-    params_legacy_seedless = {key: val for key, val in
params_legacy.items() - if key not in - params_legacy['prng_seedcore*'].keys()} + params_new_seedless = { + key: val + for key, val in params_new.items() + if key not in params_new['prng_seedcore*'].keys() + } + params_legacy_seedless = { + key: val + for key, val in params_legacy.items() + if key not in params_legacy['prng_seedcore*'].keys() + } assert params_new_seedless == params_legacy_seedless def test_base_params(): """Test default params object matches base params""" - param_url = ('https://raw.githubusercontent.com/jonescompneurolab/' - 'hnn-core/test_data/base.json') + param_url = ( + 'https://raw.githubusercontent.com/jonescompneurolab/' + 'hnn-core/test_data/base.json' + ) params_base_fname = op.join(hnn_core_root, 'param', 'base.json') if not op.exists(params_base_fname): urlretrieve(param_url, params_base_fname) @@ -79,37 +86,38 @@ def test_base_params(): def test_remove_nulled_drives(tmp_path): - param_url = ("https://raw.githubusercontent.com/jonescompneurolab/hnn/" - "master/param/ERPYes100Trials.param") + param_url = ( + 'https://raw.githubusercontent.com/jonescompneurolab/hnn/' + 'master/param/ERPYes100Trials.param' + ) params_fname = Path(hnn_core_root, 'param', 'ERPYes100Trials.param') if not op.exists(params_fname): urlretrieve(param_url, params_fname) - net = jones_2009_model(params=read_params(params_fname), - add_drives_from_params=True, - legacy_mode=True, - ) + net = jones_2009_model( + params=read_params(params_fname), + add_drives_from_params=True, + legacy_mode=True, + ) net_removed = remove_nulled_drives(net) assert net_removed != net # External drives were removed drives_removed = ['bursty1', 'bursty2', 'extgauss', 'extpois'] - assert all([drive not in net_removed.external_drives.keys() - for drive in drives_removed]) + assert all( + [drive not in net_removed.external_drives.keys() for drive in drives_removed] + ) # Connections were removed - conn_src_types = set([conn['src_type'] - for conn in net_removed.connectivity]) + conn_src_types = set([conn['src_type'] for conn in net_removed.connectivity]) assert all([drive not in conn_src_types for drive in drives_removed]) # gid ranges were updated - assert all([drive not in net_removed.gid_ranges.keys() - for drive in drives_removed]) + assert all([drive not in net_removed.gid_ranges.keys() for drive in drives_removed]) # position dictionary was updated - assert all([drive not in net_removed.pos_dict.keys() - for drive in drives_removed]) + assert all([drive not in net_removed.pos_dict.keys() for drive in drives_removed]) class TestConvertToJson: @@ -120,47 +128,44 @@ class TestConvertToJson: def test_default_network_connectivity(self, tmp_path): """Tests conversion with default parameters""" - net_params = jones_2009_model(params=read_params(self.path_default), - add_drives_from_params=True) + net_params = jones_2009_model( + params=read_params(self.path_default), add_drives_from_params=True + ) # Write json and check if constructed network is equal outpath = Path(tmp_path, 'default.json') - convert_to_json(self.path_default, - outpath - ) + convert_to_json(self.path_default, outpath) net_json = read_network_configuration(outpath) assert net_json == net_params # Write json without drives outpath_no_drives = Path(tmp_path, 'default_no_drives.json') - convert_to_json(self.path_default, - outpath_no_drives, - include_drives=False - ) + convert_to_json(self.path_default, outpath_no_drives, include_drives=False) net_json_no_drives = read_network_configuration(outpath_no_drives) assert net_json_no_drives != net_json 
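The param-to-JSON round trip these assertions exercise can be sketched outside of pytest as below; the import locations are assumptions, since the test module's import block is not part of this diff:

import os.path as op
from pathlib import Path

import hnn_core
from hnn_core import jones_2009_model, read_params

# assumed import locations for the two helpers under test
from hnn_core.hnn_io import read_network_configuration
from hnn_core.params import convert_to_json

params_fname = op.join(op.dirname(hnn_core.__file__), 'param', 'default.json')
out_fname = Path('default_net.json')
convert_to_json(params_fname, out_fname)  # write a network-configuration JSON
net_from_json = read_network_configuration(out_fname)
net_from_params = jones_2009_model(read_params(params_fname), add_drives_from_params=True)
assert net_from_json == net_from_params  # both construction paths agree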
assert bool(net_json_no_drives.external_drives) is False # Check that writing with no extension will add one outpath_no_ext = Path(tmp_path, 'default_no_ext') - convert_to_json(self.path_default, - outpath_no_ext - ) + convert_to_json(self.path_default, outpath_no_ext) assert outpath_no_ext.with_suffix('.json').exists() def test_convert_to_json_legacy(self, tmp_path): """Tests conversion of a param legacy file to json""" # Download params - param_url = ('https://raw.githubusercontent.com/hnnsolver/' - 'hnn-core/test_data/default.param') + param_url = ( + 'https://raw.githubusercontent.com/hnnsolver/' + 'hnn-core/test_data/default.param' + ) params_base_fname = Path(hnn_core_root, 'param', 'default.param') if not op.exists(params_base_fname): urlretrieve(param_url, params_base_fname) - net_params = jones_2009_model(read_params(params_base_fname), - add_drives_from_params=True, - legacy_mode=True - ) + net_params = jones_2009_model( + read_params(params_base_fname), + add_drives_from_params=True, + legacy_mode=True, + ) # Write json and check if constructed network is correct outpath = Path(tmp_path, 'default.json') @@ -180,22 +185,17 @@ def test_convert_to_json_bad_type(self): bad_path = 5 # Valid path and string, but not actual files - with pytest.raises( - ValueError, - match="Unrecognized extension, expected one of" - ): + with pytest.raises(ValueError, match='Unrecognized extension, expected one of'): convert_to_json(good_path, path_str) # Bad params_fname with pytest.raises( - TypeError, - match="params_fname must be an instance of str or Path" + TypeError, match='params_fname must be an instance of str or Path' ): convert_to_json(bad_path, good_path) # Bad out_fname with pytest.raises( - TypeError, - match="out_fname must be an instance of str or Path" + TypeError, match='out_fname must be an instance of str or Path' ): convert_to_json(good_path, bad_path) diff --git a/hnn_core/tests/test_utils.py b/hnn_core/tests/test_utils.py index 600905b29..e78ca23c8 100644 --- a/hnn_core/tests/test_utils.py +++ b/hnn_core/tests/test_utils.py @@ -14,28 +14,31 @@ def test_hamming_smoothing(): pytest.raises(RuntimeError, smooth_waveform, data, window_len, sfreq) # window_len is positive number, longer than data, and >1ms - data, sfreq = np.random.random((100, )), 1 + data, sfreq = np.random.random((100,)), 1 for window_len in [None, -1, 1e6, 1e-1]: pytest.raises(ValueError, smooth_waveform, data, window_len, sfreq) # sfreq is positive number - data, window_len = np.random.random((100, )), 1 + data, window_len = np.random.random((100,)), 1 for sfreq in [None, [1], -1]: - pytest.raises((TypeError, AssertionError), smooth_waveform, - data, window_len, sfreq) + pytest.raises( + (TypeError, AssertionError), smooth_waveform, data, window_len, sfreq + ) def test_savgol_filter(): """Test Savitzky-Golay smoothing""" - data, sfreq = np.random.random((100, )), 1 + data, sfreq = np.random.random((100,)), 1 # h_freq is positive number and less than half the sampling rate for h_freq in [None, [1], -1, sfreq / 2]: - pytest.raises((TypeError, AssertionError, ValueError), - _savgol_filter, data, h_freq, sfreq) + pytest.raises( + (TypeError, AssertionError, ValueError), _savgol_filter, data, h_freq, sfreq + ) h_freq = 0.6 # sfreq is positive number and at least twice the cutoff frequency for sfreq in [None, [1], -1, 1]: - pytest.raises((TypeError, AssertionError, ValueError), - _savgol_filter, data, h_freq, sfreq) + pytest.raises( + (TypeError, AssertionError, ValueError), _savgol_filter, data, h_freq, sfreq + ) diff 
--git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 3b9c64e7c..36048fea3 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -11,9 +11,15 @@ import hnn_core from hnn_core import read_params, jones_2009_model -from hnn_core.viz import (plot_cells, plot_dipole, plot_psd, plot_tfr_morlet, - plot_connectivity_matrix, plot_cell_connectivity, - NetworkPlotter) +from hnn_core.viz import ( + plot_cells, + plot_dipole, + plot_psd, + plot_tfr_morlet, + plot_connectivity_matrix, + plot_cell_connectivity, + NetworkPlotter, +) from hnn_core.dipole import simulate_dipole matplotlib.use('agg') @@ -33,8 +39,7 @@ def _fake_click(fig, ax, point, button=1): """Fake a click at a point within axes.""" x, y = ax.transData.transform_point(point) button_press_event = backend_bases.MouseEvent( - name='button_press_event', canvas=fig.canvas, - x=x, y=y, button=button + name='button_press_event', canvas=fig.canvas, x=x, y=y, button=button ) fig.canvas.callbacks.process('button_press_event', button_press_event) @@ -77,8 +82,7 @@ def test_network_visualization(setup_net): cell_type.plot_morphology(color='r') sections = list(cell_type.sections.keys()) - section_color = {sect_name: f'C{idx}' for - idx, sect_name in enumerate(sections)} + section_color = {sect_name: f'C{idx}' for idx, sect_name in enumerate(sections)} cell_type.plot_morphology(color=section_color) cell_type = net.cell_types['L2_basket'] @@ -91,8 +95,9 @@ def test_network_visualization(setup_net): # test for invalid Axes object to plot_cells fig, axes = plt.subplots(1, 1) - with pytest.raises(TypeError, - match="'ax' to be an instance of Axes3D, but got Axes"): + with pytest.raises( + TypeError, match="'ax' to be an instance of Axes3D, but got Axes" + ): plot_cells(net, ax=axes, show=False) cell_type.plot_morphology(pos=(1.0, 2.0, 3.0)) with pytest.raises(TypeError, match='pos must be'): @@ -107,10 +112,16 @@ def test_network_visualization(setup_net): # test interactive clicking updates the position of src_cell in plot del net.connectivity[-1] conn_idx = 15 - net.add_connection(net.gid_ranges['L2_pyramidal'][::2], - 'L5_basket', 'soma', - 'ampa', 0.00025, 1.0, lamtha=3.0, - probability=0.8) + net.add_connection( + net.gid_ranges['L2_pyramidal'][::2], + 'L5_basket', + 'soma', + 'ampa', + 0.00025, + 1.0, + lamtha=3.0, + probability=0.8, + ) fig = plot_cell_connectivity(net, conn_idx, show=False) ax_src, ax_target, _ = fig.axes @@ -126,26 +137,42 @@ def test_dipole_visualization(setup_net): net = setup_net # Test plotting of simulations with no spiking - dpls = simulate_dipole(net, tstop=100., n_trials=1) + dpls = simulate_dipole(net, tstop=100.0, n_trials=1) net.cell_response.plot_spikes_raster() net.cell_response.plot_spikes_hist() weights_ampa = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5} - syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.} + syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.0} net.add_bursty_drive( - 'beta_prox', tstart=0., burst_rate=25, burst_std=5, - numspikes=1, spike_isi=0, n_drive_cells=11, location='proximal', - weights_ampa=weights_ampa, synaptic_delays=syn_delays, - event_seed=14) + 'beta_prox', + tstart=0.0, + burst_rate=25, + burst_std=5, + numspikes=1, + spike_isi=0, + n_drive_cells=11, + location='proximal', + weights_ampa=weights_ampa, + synaptic_delays=syn_delays, + event_seed=14, + ) net.add_bursty_drive( - 'beta_dist', tstart=0., burst_rate=25, burst_std=5, - numspikes=1, spike_isi=0, n_drive_cells=11, location='distal', - weights_ampa=weights_ampa, 
synaptic_delays=syn_delays, - event_seed=14) + 'beta_dist', + tstart=0.0, + burst_rate=25, + burst_std=5, + numspikes=1, + spike_isi=0, + n_drive_cells=11, + location='distal', + weights_ampa=weights_ampa, + synaptic_delays=syn_delays, + event_seed=14, + ) - dpls = simulate_dipole(net, tstop=100., n_trials=2, record_vsec='all') + dpls = simulate_dipole(net, tstop=100.0, n_trials=2, record_vsec='all') fig = dpls[0].plot() # plot the first dipole alone axes = fig.get_axes()[0] dpls[0].copy().smooth(window_len=10).plot(ax=axes) # add smoothed versions @@ -153,9 +180,10 @@ def test_dipole_visualization(setup_net): # test decimation options plot_dipole(dpls[0], decim=2, show=False) - for dec in [-1, [2, 2.]]: - with pytest.raises(ValueError, - match='each decimation factor must be a positive'): + for dec in [-1, [2, 2.0]]: + with pytest.raises( + ValueError, match='each decimation factor must be a positive' + ): plot_dipole(dpls[0], decim=dec, show=False) # test plotting multiple dipoles as overlay @@ -172,30 +200,24 @@ def test_dipole_visualization(setup_net): fig, axes = plt.subplots(nrows=3, ncols=1) fig = plot_dipole(dpls, show=False, ax=axes, layer=['L2', 'L5', 'agg']) fig, axes = plt.subplots(nrows=3, ncols=1) - fig = plot_dipole(dpls, - show=False, - ax=[axes[0], axes[1], axes[2]], - layer=['L2', 'L5', 'agg']) + fig = plot_dipole( + dpls, show=False, ax=[axes[0], axes[1], axes[2]], layer=['L2', 'L5', 'agg'] + ) plt.close('all') - with pytest.raises(AssertionError, - match="ax and layer should have the same size"): + with pytest.raises(AssertionError, match='ax and layer should have the same size'): fig, axes = plt.subplots(nrows=3, ncols=1) fig = plot_dipole(dpls, show=False, ax=axes, layer=['L2', 'L5']) # multiple TFRs get averaged - fig = plot_tfr_morlet(dpls, freqs=np.arange(23, 26, 1.), n_cycles=3, - show=False) + fig = plot_tfr_morlet(dpls, freqs=np.arange(23, 26, 1.0), n_cycles=3, show=False) - with pytest.raises(RuntimeError, - match="All dipoles must be scaled equally!"): + with pytest.raises(RuntimeError, match='All dipoles must be scaled equally!'): plot_dipole([dpls[0].copy().scale(10), dpls[1].copy().scale(20)]) - with pytest.raises(RuntimeError, - match="All dipoles must be scaled equally!"): + with pytest.raises(RuntimeError, match='All dipoles must be scaled equally!'): plot_psd([dpls[0].copy().scale(10), dpls[1].copy().scale(20)]) - with pytest.raises(RuntimeError, - match="All dipoles must be sampled equally!"): + with pytest.raises(RuntimeError, match='All dipoles must be sampled equally!'): dpl_sfreq = dpls[0].copy() dpl_sfreq.sfreq /= 10 plot_psd([dpls[0], dpl_sfreq]) @@ -205,36 +227,34 @@ def test_dipole_visualization(setup_net): plot_dipole(dpls[0], show=False, tmin=10, tmax=100) # test cell response plotting - with pytest.raises(TypeError, match="trial_idx must be an instance of"): + with pytest.raises(TypeError, match='trial_idx must be an instance of'): net.cell_response.plot_spikes_raster(trial_idx='blah', show=False) net.cell_response.plot_spikes_raster(trial_idx=0, show=False) fig = net.cell_response.plot_spikes_raster(trial_idx=[0, 1], show=False) - assert len(fig.axes[0].collections) > 0, "No data plotted in raster plot" + assert len(fig.axes[0].collections) > 0, 'No data plotted in raster plot' - with pytest.raises(TypeError, match="trial_idx must be an instance of"): + with pytest.raises(TypeError, match='trial_idx must be an instance of'): net.cell_response.plot_spikes_hist(trial_idx='blah') net.cell_response.plot_spikes_hist(trial_idx=0, show=False) 
net.cell_response.plot_spikes_hist(trial_idx=[0, 1], show=False) net.cell_response.plot_spikes_hist(color='r') net.cell_response.plot_spikes_hist(color=['C0', 'C1']) - net.cell_response.plot_spikes_hist(color={'beta_prox': 'r', - 'beta_dist': 'g'}) + net.cell_response.plot_spikes_hist(color={'beta_prox': 'r', 'beta_dist': 'g'}) net.cell_response.plot_spikes_hist( - spike_types={'group1': ['beta_prox', 'beta_dist']}, - color={'group1': 'r'}) + spike_types={'group1': ['beta_prox', 'beta_dist']}, color={'group1': 'r'} + ) net.cell_response.plot_spikes_hist( - spike_types={'group1': ['beta']}, color={'group1': 'r'}) + spike_types={'group1': ['beta']}, color={'group1': 'r'} + ) - with pytest.raises(TypeError, match="color must be an instance of"): + with pytest.raises(TypeError, match='color must be an instance of'): net.cell_response.plot_spikes_hist(color=123) with pytest.raises(ValueError): net.cell_response.plot_spikes_hist(color='z') with pytest.raises(ValueError): - net.cell_response.plot_spikes_hist(color={'beta_prox': 'z', - 'beta_dist': 'g'}) - with pytest.raises(TypeError, match="Dictionary values of color must"): - net.cell_response.plot_spikes_hist(color={'beta_prox': 123, - 'beta_dist': 'g'}) + net.cell_response.plot_spikes_hist(color={'beta_prox': 'z', 'beta_dist': 'g'}) + with pytest.raises(TypeError, match='Dictionary values of color must'): + net.cell_response.plot_spikes_hist(color={'beta_prox': 123, 'beta_dist': 'g'}) with pytest.raises(ValueError, match="'beta_dist' must be"): net.cell_response.plot_spikes_hist(color={'beta_prox': 'r'}) plt.close('all') @@ -244,8 +264,18 @@ def test_network_plotter_init(setup_net): """Test init keywords of NetworkPlotter class.""" net = setup_net # test NetworkPlotter class - args = ['xlim', 'ylim', 'zlim', 'elev', 'azim', 'vmin', 'vmax', - 'trial_idx', 'time_idx', 'colorbar'] + args = [ + 'xlim', + 'ylim', + 'zlim', + 'elev', + 'azim', + 'vmin', + 'vmax', + 'trial_idx', + 'time_idx', + 'colorbar', + ] for arg in args: with pytest.raises(TypeError, match=f'{arg} must be'): net_plot = NetworkPlotter(net, **{arg: 'blah'}) @@ -298,17 +328,35 @@ def test_network_plotter_setter(setup_net): net = setup_net net_plot = NetworkPlotter(net) # Type check errors - args = ['xlim', 'ylim', 'zlim', 'elev', 'azim', 'vmin', 'vmax', - 'trial_idx', 'time_idx', 'colorbar'] + args = [ + 'xlim', + 'ylim', + 'zlim', + 'elev', + 'azim', + 'vmin', + 'vmax', + 'trial_idx', + 'time_idx', + 'colorbar', + ] for arg in args: with pytest.raises(TypeError, match=f'{arg} must be'): setattr(net_plot, arg, 'blah') # Check that the setters and getters work - arg_dict = {'xlim': (-100, 100), 'ylim': (-100, 100), 'zlim': (-100, 100), - 'elev': 10, 'azim': 10, 'vmin': 0, 'vmax': 100, - 'bgcolor': 'white', 'voltage_colormap': 'jet', - 'colorbar': False} + arg_dict = { + 'xlim': (-100, 100), + 'ylim': (-100, 100), + 'zlim': (-100, 100), + 'elev': 10, + 'azim': 10, + 'vmin': 0, + 'vmax': 100, + 'bgcolor': 'white', + 'voltage_colormap': 'jet', + 'colorbar': False, + } for arg, val in arg_dict.items(): setattr(net_plot, arg, val) assert getattr(net_plot, arg) == val @@ -328,8 +376,7 @@ def test_network_plotter_setter(setup_net): def test_network_plotter_export(tmp_path, setup_net): """Test NetworkPlotter class export methods.""" net = setup_net - _ = simulate_dipole(net, dt=0.5, tstop=10, n_trials=1, - record_vsec='all') + _ = simulate_dipole(net, dt=0.5, tstop=10, n_trials=1, record_vsec='all') net_plot = NetworkPlotter(net) # Check no file is already written @@ -349,21 +396,31 @@ 
def test_invert_spike_types(setup_net): net = setup_net weights_ampa = {'L2_pyramidal': 0.15, 'L5_pyramidal': 0.15} - syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.} + syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.0} net.add_evoked_drive( - 'evdist1', mu=63.53, sigma=3.85, numspikes=1, - weights_ampa=weights_ampa, location='distal', - synaptic_delays=syn_delays, event_seed=274 + 'evdist1', + mu=63.53, + sigma=3.85, + numspikes=1, + weights_ampa=weights_ampa, + location='distal', + synaptic_delays=syn_delays, + event_seed=274, ) net.add_evoked_drive( - 'evprox1', mu=26.61, sigma=2.47, numspikes=1, - weights_ampa=weights_ampa, location='proximal', - synaptic_delays=syn_delays, event_seed=274 + 'evprox1', + mu=26.61, + sigma=2.47, + numspikes=1, + weights_ampa=weights_ampa, + location='proximal', + synaptic_delays=syn_delays, + event_seed=274, ) - _ = simulate_dipole(net, dt=0.5, tstop=80., n_trials=1) + _ = simulate_dipole(net, dt=0.5, tstop=80.0, n_trials=1) # test string input net.cell_response.plot_spikes_hist( diff --git a/hnn_core/utils.py b/hnn_core/utils.py index bb888c558..930f727de 100644 --- a/hnn_core/utils.py +++ b/hnn_core/utils.py @@ -42,19 +42,18 @@ def _savgol_filter(data, h_freq, sfreq): from scipy.signal import savgol_filter _validate_type(sfreq, (float, int), 'sfreq') - assert sfreq > 0. + assert sfreq > 0.0 _validate_type(h_freq, (float, int), 'h_freq') - assert h_freq > 0. + assert h_freq > 0.0 h_freq = float(h_freq) - if h_freq >= sfreq / 2.: + if h_freq >= sfreq / 2.0: raise ValueError('h_freq must be less than half the sample rate') # savitzky-golay filtering window_length = (int(np.round(sfreq / h_freq)) // 2) * 2 + 1 # loop over 'agg', 'L2', and 'L5' - filt_data = savgol_filter(data, axis=-1, polyorder=5, - window_length=window_length) + filt_data = savgol_filter(data, axis=-1, polyorder=5, window_length=window_length) return filt_data @@ -76,8 +75,9 @@ def smooth_waveform(data, window_len, sfreq): data_filt : np.ndarray The filtered data """ - if ((isinstance(data, np.ndarray) and data.ndim > 1) or - (isinstance(data, list) and isinstance(data[0], list))): + if (isinstance(data, np.ndarray) and data.ndim > 1) or ( + isinstance(data, list) and isinstance(data[0], list) + ): raise RuntimeError('smoothing currently only supported for 1D-arrays') if not isinstance(window_len, (float, int)) or window_len < 0: @@ -86,12 +86,13 @@ def smooth_waveform(data, window_len, sfreq): raise ValueError('Window length less than 1 ms is not supported') _validate_type(sfreq, (float, int), 'sfreq') - assert sfreq > 0. 
+ assert sfreq > 0.0 # convolutional filter length is given in samples winsz = np.round(1e-3 * window_len * sfreq) if winsz > len(data): raise ValueError( f'Window length too long: {winsz} samples; data length is ' - f'{len(data)} samples') + f'{len(data)} samples' + ) return _hammfilt(data, winsz) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 19eb7e9a4..f653299c1 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -13,6 +13,7 @@ def _lighten_color(color, amount=0.5): import matplotlib.colors as mc + try: c = mc.cnames[color] except: @@ -41,13 +42,16 @@ def _get_plot_data_trange(times, data, tmin=None, tmax=None): def _decimate_plot_data(decim, data, times, sfreq=None): from scipy.signal import decimate + if not isinstance(decim, list): decim = [decim] for dec in decim: if not isinstance(dec, int) or dec < 1: - raise ValueError('each decimation factor must be a positive int, ' - f'but {dec} is a {type(dec)}') + raise ValueError( + 'each decimation factor must be a positive int, ' + f'but {dec} is a {type(dec)}' + ) data = decimate(data, dec) times = times[::dec] @@ -74,13 +78,24 @@ def plt_show(show=True, fig=None, **kwargs): """ from matplotlib import get_backend import matplotlib.pyplot as plt + if show and get_backend() != 'agg': (fig or plt).show(**kwargs) -def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, - ax=None, decim=None, color='cividis', - voltage_offset=50, voltage_scalebar=200, show=True): +def plot_laminar_lfp( + times, + data, + contact_labels, + tmin=None, + tmax=None, + ax=None, + decim=None, + color='cividis', + voltage_offset=50, + voltage_scalebar=200, + show=True, +): """Plot laminar extracellular electrode array voltage time series. Parameters @@ -121,6 +136,7 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, """ import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap + _validate_type(times, (list, np.ndarray), 'times') _validate_type(data, (list, np.ndarray), 'data') if isinstance(times, list): @@ -130,28 +146,31 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, if data.ndim != 2: raise ValueError(f'data must be 2D, got shape {data.shape}') if len(times) != data.shape[1]: - raise ValueError(f'length of times ({len(times)}) and data ' - f'({len(data)}) do not match') + raise ValueError( + f'length of times ({len(times)}) and data ' f'({len(data)}) do not match' + ) n_contacts = data.shape[0] if color is not None: - _validate_type(color, - (str, tuple, list, np.ndarray, ListedColormap), - 'color') + _validate_type(color, (str, tuple, list, np.ndarray, ListedColormap), 'color') if isinstance(color, (tuple, list)): - if (not np.all([isinstance(c, float) for c in color]) or - len(color) < 3 or len(color) > 4): - raise ValueError( - f'color must be length 3 or 4, got {color}') + if ( + not np.all([isinstance(c, float) for c in color]) + or len(color) < 3 + or len(color) > 4 + ): + raise ValueError(f'color must be length 3 or 4, got {color}') elif isinstance(color, np.ndarray): - if (color.shape[0] != n_contacts or - (color.shape[1] < 3 or color.shape[1] > 4)): - raise ValueError( - f'color must be n_contacts x (3 or 4), got {color}') + if color.shape[0] != n_contacts or ( + color.shape[1] < 3 or color.shape[1] > 4 + ): + raise ValueError(f'color must be n_contacts x (3 or 4), got {color}') elif isinstance(color, ListedColormap): if color.N != n_contacts: - raise ValueError(f'ListedColormap has N={color.N}, but ' - f'there are {n_contacts} contacts') + raise ValueError( + 
f'ListedColormap has N={color.N}, but ' + f'there are {n_contacts} contacts' + ) elif isinstance(color, str): color = plt.get_cmap(color, len(contact_labels)) @@ -168,8 +187,7 @@ plot_times = times if decim is not None: - plot_data, plot_times = _decimate_plot_data(decim, plot_data, - plot_times) + plot_data, plot_times = _decimate_plot_data(decim, plot_data, plot_times) if isinstance(color, np.ndarray): col = color[contact_no] @@ -177,16 +195,22 @@ col = color(contact_no) else: col = color - ax.plot(plot_times, plot_data + trace_offsets[contact_no], - label=f'C{contact_no}', color=col) + ax.plot( + plot_times, + plot_data + trace_offsets[contact_no], + label=f'C{contact_no}', + color=col, + ) # To be removed after deprecation cycle if tmin is not None or tmax is not None: ax.set_xlim(left=tmin, right=tmax) - warnings.warn('tmin and tmax are deprecated and will be ' - 'removed in future releases of hnn-core. Please' - 'use matplotlib plt.xlim to set tmin and tmax.', - DeprecationWarning) + warnings.warn( + 'tmin and tmax are deprecated and will be ' + 'removed in future releases of hnn-core. Please ' + 'use matplotlib plt.xlim to set tmin and tmax.', + DeprecationWarning, + ) else: ax.set_xlim(left=times[0], right=times[-1]) @@ -194,11 +218,14 @@ ax.set_ylim(-voltage_offset, n_offsets * voltage_offset) ylabel = 'Individual contact traces' if len(contact_labels) != n_offsets: - raise ValueError(f'contact_labels is length {len(contact_labels)},' - f' but {n_offsets} contacts to be plotted') + raise ValueError( + f'contact_labels is length {len(contact_labels)},' + f' but {n_offsets} contacts to be plotted' + ) else: - trace_ticks = np.arange(0, len(contact_labels) * voltage_offset, - voltage_offset) + trace_ticks = np.arange( + 0, len(contact_labels) * voltage_offset, voltage_offset + ) ax.set_yticks(trace_ticks) ax.set_yticklabels(contact_labels) @@ -207,14 +234,18 @@ if voltage_scalebar is not None: from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar - scalebar = AnchoredSizeBar(ax.transData, 1, - f'{voltage_scalebar:.0f} ' + r'$\mu V$', - 'upper left', - size_vertical=voltage_scalebar, - pad=0.1, - color='black', - label_top=False, - frameon=False) + + scalebar = AnchoredSizeBar( + ax.transData, + 1, + f'{voltage_scalebar:.0f} ' + r'$\mu V$', + 'upper left', + size_vertical=voltage_scalebar, + pad=0.1, + color='black', + label_top=False, + frameon=False, + ) ax.add_artist(scalebar) else: ylabel = r'Electric potential ($\mu V$)' @@ -227,8 +258,18 @@ return ax.get_figure() -def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, - color='k', label="average", average=False, show=True): +def plot_dipole( + dpl, + tmin=None, + tmax=None, + ax=None, + layer='agg', + decim=None, + color='k', + label='average', + average=False, + show=True, +): """Simple layer-specific plot function.
Parameters @@ -264,11 +305,9 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, layers = layer if isinstance(layer, list) else [layer] if ax is None: - _, ax = plt.subplots(len(layers), - 1, - constrained_layout=True, - sharex=True, - sharey=True) + _, ax = plt.subplots( + len(layers), 1, constrained_layout=True, sharex=True, sharey=True + ) axes = ax if isinstance(ax, (list, np.ndarray)) else [ax] if isinstance(dpl, Dipole): @@ -281,7 +320,7 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, scale_applied = dpl[0].scale_applied - assert len(layers) == len(axes), "ax and layer should have the same size" + assert len(layers) == len(axes), 'ax and layer should have the same size' for layer, ax in zip(layers, axes): for idx, dpl_trial in enumerate(dpl): @@ -289,7 +328,6 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, raise RuntimeError('All dipoles must be scaled equally!') if layer in dpl_trial.data.keys(): - # extract scaled data and times data = dpl_trial.data[layer] times = dpl_trial.times @@ -300,18 +338,25 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, # the average dpl ax.plot(times, data, color=color, label=label, lw=1.5) else: - alpha = 0.5 if average else 1. - ax.plot(times, data, color=_lighten_color(color, 0.5), - alpha=alpha, lw=1.) + alpha = 0.5 if average else 1.0 + ax.plot( + times, + data, + color=_lighten_color(color, 0.5), + alpha=alpha, + lw=1.0, + ) # To be removed after deprecation cycle if tmin is not None or tmax is not None: if tmin is not None or tmax is not None: - warnings.warn('tmin and tmax are deprecated and will be ' - 'removed in future releases of hnn-core. ' - 'Please use matplotlib plt.xlim to set tmin' - ' and tmax.', - DeprecationWarning) + warnings.warn( + 'tmin and tmax are deprecated and will be ' + 'removed in future releases of hnn-core. ' + 'Please use matplotlib plt.xlim to set tmin' + ' and tmax.', + DeprecationWarning, + ) ax.set_xlim(left=tmin, right=tmax) else: ax.set_xlim(left=0, right=times[-1]) @@ -323,8 +368,7 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, if scale_applied == 1: ylabel = 'Dipole moment (nAm)' else: - ylabel = 'Dipole moment\n(nAm ' +\ - r'$\times$ {:.0f})'.format(scale_applied) + ylabel = 'Dipole moment\n(nAm ' + r'$\times$ {:.0f})'.format(scale_applied) ax.set_ylabel(ylabel, multialignment='center') if layer == 'agg': title_str = 'Aggregate (L2/3 + L5)' @@ -338,9 +382,16 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, return axes[0].get_figure() -def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, - color=None, invert_spike_types=None, show=True, - **kwargs_hist): +def plot_spikes_hist( + cell_response, + trial_idx=None, + ax=None, + spike_types=None, + color=None, + invert_spike_types=None, + show=True, + **kwargs_hist, +): """Plot the histogram of spiking activity across trials. Parameters @@ -405,6 +456,7 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, The matplotlib figure handle. 
""" import matplotlib.pyplot as plt + n_trials = len(cell_response.spike_times) if trial_idx is None: trial_idx = list(range(n_trials)) @@ -416,16 +468,19 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, # Extract desired trials if len(cell_response._spike_times[0]) > 0: spike_times = np.concatenate( - np.array(cell_response._spike_times, dtype=object)[trial_idx]) + np.array(cell_response._spike_times, dtype=object)[trial_idx] + ) spike_types_data = np.concatenate( - np.array(cell_response._spike_types, dtype=object)[trial_idx]) + np.array(cell_response._spike_types, dtype=object)[trial_idx] + ) else: spike_times = np.array([]) spike_types_data = np.array([]) unique_types = np.unique(spike_types_data) - spike_types_mask = {s_type: np.isin(spike_types_data, s_type) - for s_type in unique_types} + spike_types_mask = { + s_type: np.isin(spike_types_data, s_type) for s_type in unique_types + } cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket'] input_types = np.setdiff1d(unique_types, cell_types) @@ -442,9 +497,11 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, if isinstance(spike_types, dict): for spike_label in spike_types: if not isinstance(spike_types[spike_label], list): - raise TypeError(f'spike_types[{spike_label}] must be a list. ' - f'Got ' - f'{type(spike_types[spike_label]).__name__}.') + raise TypeError( + f'spike_types[{spike_label}] must be a list. ' + f'Got ' + f'{type(spike_types[spike_label]).__name__}.' + ) if not isinstance(spike_types, dict): raise TypeError('spike_types should be str, list, dict, or None') @@ -456,10 +513,12 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, for unique_type in unique_types: if unique_type.startswith(spike_type): if unique_type in spike_labels: - raise ValueError(f'Elements of spike_types must map to' - f' mutually exclusive input types.' - f' {unique_type} is found more than' - f' once.') + raise ValueError( + f'Elements of spike_types must map to' + f' mutually exclusive input types.' + f' {unique_type} is found more than' + f' once.' 
+ ) spike_labels[unique_type] = spike_label n_found += 1 if n_found == 0: @@ -468,8 +527,7 @@ if ax is None: _, ax = plt.subplots(1, 1, constrained_layout=True) - _validate_type(color, (str, list, dict, None), - 'color', 'str, list of str, or dict') + _validate_type(color, (str, list, dict, None), 'color', 'str, list of str, or dict') if color is None: color_cycle = cycle(['r', 'g', 'b', 'y', 'm', 'c']) @@ -484,29 +542,32 @@ bins = np.linspace(0, spike_times[-1], 50) # Create dictionary to aggregate spike times that have the same spike_label - spike_type_times = {spike_label: list() for - spike_label in np.unique(list(spike_labels.values()))} + spike_type_times = { + spike_label: list() for spike_label in np.unique(list(spike_labels.values())) + } spike_color = dict() # Store colors specified for each spike_label for spike_type, spike_label in spike_labels.items(): if spike_label not in spike_color: if isinstance(color, dict): if spike_label not in color: raise ValueError( - f"'{spike_label}' must be defined in color dictionary") - _validate_type(color[spike_label], str, - 'Dictionary values of color', 'str') + f"'{spike_label}' must be defined in color dictionary" + ) + _validate_type( + color[spike_label], str, 'Dictionary values of color', 'str' + ) spike_color[spike_label] = color[spike_label] else: spike_color[spike_label] = next(color_cycle) - spike_type_times[spike_label].extend( - spike_times[spike_types_mask[spike_type]]) + spike_type_times[spike_label].extend(spike_times[spike_types_mask[spike_type]]) if invert_spike_types is None: invert_spike_types = list() else: if not isinstance(invert_spike_types, (str, list)): raise TypeError( - "'invert_spike_types' must be a string or a list of strings") + "'invert_spike_types' must be a string or a list of strings" + ) if isinstance(invert_spike_types, str): invert_spike_types = [invert_spike_types] @@ -516,8 +577,7 @@ check_intersection = unique_invert_inputs.intersection(unique_inputs) if not check_intersection == unique_invert_inputs: raise ValueError( - "Elements of 'invert_spike_types' must" - "map to valid input types" + "Elements of 'invert_spike_types' must " 'map to valid input types' ) # Initialize secondary axis @@ -529,14 +589,14 @@ # Plot on the primary y-axis if spike_label not in invert_spike_types: - ax.hist(plot_data, bins, - label=spike_label, color=hist_color, **kwargs_hist) + ax.hist(plot_data, bins, label=spike_label, color=hist_color, **kwargs_hist) # Plot on secondary y-axis else: if ax1 is None: ax1 = ax.twinx() - ax1.hist(plot_data, bins, - label=spike_label, color=hist_color, **kwargs_hist) + ax1.hist( + plot_data, bins, label=spike_label, color=hist_color, **kwargs_hist + ) # Need to add label for easy removal later # Set the y-limits based on the maximum across both axes @@ -548,15 +608,15 @@ ax.set_ylim(0, y_max) ax1.set_ylim(0, y_max) ax1.invert_yaxis() - ax1.set_label("Inverted spike histogram") + ax1.set_label('Inverted spike histogram') if len(cell_response.times) > 0: ax.set_xlim(left=0, right=cell_response.times[-1]) else: ax.set_xlim(left=0) - ax.set_ylabel("Counts") - ax.set_label("Spike histogram") +
ax.set_ylabel('Counts') + ax.set_label('Spike histogram') if ax1 is not None: # Combine legends @@ -594,6 +654,7 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): """ import matplotlib.pyplot as plt + n_trials = len(cell_response.spike_times) if trial_idx is None: trial_idx = list(range(n_trials)) @@ -604,15 +665,22 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): # Extract desired trials spike_times = np.concatenate( - np.array(cell_response._spike_times, dtype=object)[trial_idx]) + np.array(cell_response._spike_times, dtype=object)[trial_idx] + ) spike_types = np.concatenate( - np.array(cell_response._spike_types, dtype=object)[trial_idx]) + np.array(cell_response._spike_types, dtype=object)[trial_idx] + ) spike_gids = np.concatenate( - np.array(cell_response._spike_gids, dtype=object)[trial_idx]) + np.array(cell_response._spike_gids, dtype=object)[trial_idx] + ) cell_types = ['L2_basket', 'L2_pyramidal', 'L5_basket', 'L5_pyramidal'] - cell_type_colors = {'L5_pyramidal': 'r', 'L5_basket': 'b', - 'L2_pyramidal': 'g', 'L2_basket': 'w'} + cell_type_colors = { + 'L5_pyramidal': 'r', + 'L5_basket': 'b', + 'L2_pyramidal': 'g', + 'L2_basket': 'w', + } if ax is None: _, ax = plt.subplots(1, 1, constrained_layout=True) @@ -628,14 +696,24 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): if cell_type_times: events.append( - ax.eventplot(cell_type_times, lineoffsets=cell_type_ypos, - color=cell_type_colors[cell_type], - label=cell_type, linelengths=5)) + ax.eventplot( + cell_type_times, + lineoffsets=cell_type_ypos, + color=cell_type_colors[cell_type], + label=cell_type, + linelengths=5, + ) + ) else: events.append( - ax.eventplot([-1], lineoffsets=[-1], - color=cell_type_colors[cell_type], - label=cell_type, linelengths=5)) + ax.eventplot( + [-1], + lineoffsets=[-1], + color=cell_type_colors[cell_type], + label=cell_type, + linelengths=5, + ) + ) ax.legend(handles=[e[0] for e in events], loc=1) ax.set_facecolor('k') @@ -678,13 +756,22 @@ def plot_cells(net, ax=None, show=True): ax = fig.add_subplot(111, projection='3d') elif not isinstance(ax, Axes3D): - raise TypeError("Expected 'ax' to be an instance of Axes3D, " - f"but got {type(ax).__name__}") - - colors = {'L5_pyramidal': 'b', 'L2_pyramidal': 'c', - 'L5_basket': 'r', 'L2_basket': 'm'} - markers = {'L5_pyramidal': '^', 'L2_pyramidal': '^', - 'L5_basket': 'x', 'L2_basket': 'x'} + raise TypeError( + "Expected 'ax' to be an instance of Axes3D, " f'but got {type(ax).__name__}' + ) + + colors = { + 'L5_pyramidal': 'b', + 'L2_pyramidal': 'c', + 'L5_basket': 'r', + 'L2_basket': 'm', + } + markers = { + 'L5_pyramidal': '^', + 'L2_pyramidal': '^', + 'L5_basket': 'x', + 'L2_basket': 'x', + } for cell_type in net.cell_types: x = [pos[0] for pos in net.pos_dict[cell_type]] @@ -701,19 +788,30 @@ def plot_cells(net, ax=None, show=True): x = [p[0] for p in arr.positions] y = [p[1] for p in arr.positions] z = [p[2] for p in arr.positions] - ax.scatter(x, y, z, color=cols(ii + 1), s=25, marker='o', - label=arr_name) + ax.scatter(x, y, z, color=cols(ii + 1), s=25, marker='o', label=arr_name) - plt.legend(bbox_to_anchor=(-0.15, 1.025), loc="upper left") + plt.legend(bbox_to_anchor=(-0.15, 1.025), loc='upper left') plt_show(show) return ax.get_figure() -def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, - layer='agg', decim=None, padding='zeros', ax=None, - colormap='inferno', colorbar=True, colorbar_inside=False, - show=True): +def plot_tfr_morlet( + dpl, + 
freqs, + *, + n_cycles=7.0, + tmin=None, + tmax=None, + layer='agg', + decim=None, + padding='zeros', + ax=None, + colormap='inferno', + colorbar=True, + colorbar_inside=False, + show=True, +): """Plot Morlet time-frequency representation of dipole time course Parameters @@ -776,37 +874,37 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, if dpl_trial.sfreq != sfreq: raise RuntimeError('All dipoles must be sampled equally!') - data, times = _get_plot_data_trange(dpl_trial.times, - dpl_trial.data[layer], - tmin, tmax) + data, times = _get_plot_data_trange( + dpl_trial.times, dpl_trial.data[layer], tmin, tmax + ) sfreq = dpl_trial.sfreq if decim is not None: - data, times, sfreq = _decimate_plot_data(decim, data, times, - sfreq=sfreq) + data, times, sfreq = _decimate_plot_data(decim, data, times, sfreq=sfreq) if padding is not None: if not isinstance(padding, str): raise ValueError('padding must be a string (or None)') if padding == 'zeros': - data = np.r_[np.zeros((len(data) - 1,)), data.ravel(), - np.zeros((len(data) - 1,))] + data = np.r_[ + np.zeros((len(data) - 1,)), data.ravel(), np.zeros((len(data) - 1,)) + ] elif padding == 'mirror': data = np.r_[data[-1:0:-1], data, data[-2::-1]] # MNE expects an array of shape (n_trials, n_channels, n_times) data = data[None, None, :] - power = tfr_array_morlet(data, sfreq=sfreq, freqs=freqs, - n_cycles=n_cycles, output='power') + power = tfr_array_morlet( + data, sfreq=sfreq, freqs=freqs, n_cycles=n_cycles, output='power' + ) if padding is not None: # get the middle portion after padding - power = power[:, :, :, times.shape[0] - 1:2 * times.shape[0] - 1] + power = power[:, :, :, times.shape[0] - 1 : 2 * times.shape[0] - 1] trial_power.append(power) power = np.mean(trial_power, axis=0) - im = ax.pcolormesh(times, freqs, power[0, 0, ...], cmap=colormap, - shading='auto') + im = ax.pcolormesh(times, freqs, power[0, 0, ...], cmap=colormap, shading='auto') ax.set_xlabel('Time (ms)') ax.set_ylabel('Frequency (Hz)') @@ -818,11 +916,14 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, if colorbar_inside is False: cbar = fig.colorbar(im, ax=ax, format=xfmt, shrink=0.8, pad=0) cbar.ax.yaxis.set_ticks_position('left') - cbar.ax.set_ylabel(r'Power ([nAm $\times$ {:.0f}]$^2$)'.format( - scale_applied), rotation=-90, va="bottom") + cbar.ax.set_ylabel( + r'Power ([nAm $\times$ {:.0f}]$^2$)'.format(scale_applied), + rotation=-90, + va='bottom', + ) # put colorbar inside the heatmap. 
else: - cbar_color = "white" + cbar_color = 'white' cbar_fontsize = 6 ax_pos = ax.get_position() @@ -840,10 +941,17 @@ cbar.ax.set_ylabel( r'Power ([nAm $\times$ {:.0f}]$^2$)'.format(scale_applied), - rotation=-90, va="bottom", fontsize=cbar_fontsize, - color=cbar_color) - cbar.ax.tick_params(direction='in', labelsize=cbar_fontsize, - labelcolor=cbar_color, colors=cbar_color) + rotation=-90, + va='bottom', + fontsize=cbar_fontsize, + color=cbar_color, + ) + cbar.ax.tick_params( + direction='in', + labelsize=cbar_fontsize, + labelcolor=cbar_color, + colors=cbar_color, + ) plt.setp(cbar.ax.spines.values(), color=cbar_color) setattr(fig, f'_cbar-ax-{id(ax)}', cbar) @@ -851,8 +959,19 @@ return ax.get_figure() -def plot_psd(dpl, *, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg', - color=None, label=None, ax=None, show=True): +def plot_psd( + dpl, + *, + fmin=0, + fmax=None, + tmin=None, + tmax=None, + layer='agg', + color=None, + label=None, + ax=None, + show=True, +): """Plot power spectral density (PSD) of dipole time course Applies `~scipy.signal.periodogram` from SciPy with ``window='hamming'``. @@ -908,15 +1027,14 @@ if dpl_trial.sfreq != sfreq: raise RuntimeError('All dipoles must be sampled equally!') - data, _ = _get_plot_data_trange(dpl_trial.times, - dpl_trial.data[layer], - tmin, tmax) + data, _ = _get_plot_data_trange( + dpl_trial.times, dpl_trial.data[layer], tmin, tmax + ) freqs, Pxx = periodogram(data, sfreq, window='hamming', nfft=len(data)) trial_power.append(Pxx) - ax.plot(freqs, np.mean(np.array(Pxx, ndmin=2), axis=0), color=color, - label=label) + ax.plot(freqs, np.mean(np.array(trial_power, ndmin=2), axis=0), color=color, label=label) if label: ax.legend() if fmax is not None: @@ -926,9 +1044,11 @@ if scale_applied == 1: ylabel = 'Power spectral density\n(nAm' + r'$^2 \ Hz^{-1}$)' else: - ylabel = 'Power spectral density\n' +\ - r'([nAm$\times$ {:.0f}]'.format(scale_applied) +\ - r'$^2 \ Hz^{-1}$)' + ylabel = ( + 'Power spectral density\n' + + r'([nAm$\times$ {:.0f}]'.format(scale_applied) + + r'$^2 \ Hz^{-1}$)' + ) ax.set_ylabel(ylabel, multialignment='center') plt_show(show) @@ -946,8 +1066,15 @@ def _linewidth_from_data_units(ax, linewidth): def plot_cell_morphology( - cell, ax, color=None, pos=(0, 0, 0), xlim=(-250, 150), - ylim=(-100, 100), zlim=(-100, 1200), show=True): + cell, + ax, + color=None, + pos=(0, 0, 0), + xlim=(-250, 150), + ylim=(-100, 100), + zlim=(-100, 1200), + show=True, +): """Plot the cell morphology.
Parameters @@ -1021,8 +1148,7 @@ def plot_cell_morphology( xs.append(pt[0] + dx) ys.append(pt[1] + dz) zs.append(pt[2] + dy) - ax.plot(xs, ys, zs, '-', linewidth=linewidth, - color=section_colors[sec_name]) + ax.plot(xs, ys, zs, '-', linewidth=linewidth, color=section_colors[sec_name]) ax.view_init(0, -90) ax.axis('off') @@ -1031,9 +1157,9 @@ def plot_cell_morphology( return ax -def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, - colorbar=True, colormap='Greys', - show=True): +def plot_connectivity_matrix( + net, conn_idx, ax=None, show_weight=True, colorbar=True, colormap='Greys', show=True +): """Plot connectivity matrix with color bar for synaptic weights Parameters @@ -1093,8 +1219,8 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, # Identical calculation used in Cell.par_connect_from_src() if show_weight: weight, _ = _get_gaussian_connection( - src_pos, target_pos, nc_dict, - inplane_distance=net._inplane_distance) + src_pos, target_pos, nc_dict, inplane_distance=net._inplane_distance + ) else: weight = 1.0 @@ -1111,25 +1237,38 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, xfmt.set_powerlimits((-2, 2)) cbar = fig.colorbar(im, ax=ax, format=xfmt) cbar.ax.yaxis.set_ticks_position('right') - cbar.ax.set_ylabel('Weight', rotation=-90, va="bottom") + cbar.ax.set_ylabel('Weight', rotation=-90, va='bottom') - ax.set_xlabel(f"{conn['target_type']} target gids " - f"({target_range[0]}-{target_range[-1]})") + ax.set_xlabel( + f"{conn['target_type']} target gids " f"({target_range[0]}-{target_range[-1]})" + ) ax.set_xticklabels(list()) - ax.set_ylabel(f"{conn['src_type']} source gids " - f"({src_range[0]}-{src_range[-1]})") + ax.set_ylabel( + f"{conn['src_type']} source gids " f"({src_range[0]}-{src_range[-1]})" + ) ax.set_yticklabels(list()) - ax.set_title(f"{conn['src_type']} -> {conn['target_type']} " - f"({conn['loc']}, {conn['receptor']})") + ax.set_title( + f"{conn['src_type']} -> {conn['target_type']} " + f"({conn['loc']}, {conn['receptor']})" + ) plt.tight_layout() plt_show(show) return ax.get_figure() -def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, - src_range, target_range, nc_dict, colormap, - inplane_distance): +def _update_target_plot( + ax, + conn, + src_gid, + src_type_pos, + target_type_pos, + src_range, + target_range, + nc_dict, + colormap, + inplane_distance, +): from .cell import _get_gaussian_connection # Extract indices to get position in network @@ -1146,13 +1285,13 @@ def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, target_pos = target_type_pos[target_idx] target_x_pos.append(target_pos[0]) target_y_pos.append(target_pos[1]) - weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict, - inplane_distance) + weight, _ = _get_gaussian_connection( + src_pos, target_pos, nc_dict, inplane_distance + ) weights.append(weight) ax.clear() - im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50, - cmap=colormap) + im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50, cmap=colormap) x_pos = target_type_pos[:, 0] y_pos = target_type_pos[:, 1] ax.scatter(x_pos, y_pos, color='k', marker='x', zorder=-1, s=20) @@ -1162,8 +1301,9 @@ def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, return im -def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, - colorbar=True, colormap='viridis', show=True): +def plot_cell_connectivity( + net, conn_idx, src_gid=None, axes=None, colorbar=True, colormap='viridis', show=True 
+): """Plot synaptic weight of connections. This is an interactive plot with source cells shown in the left @@ -1229,8 +1369,10 @@ _validate_type(src_gid, int, 'src_gid', 'int') if src_gid not in valid_src_gids: - raise ValueError(f'src_gid {src_gid} not a valid cell ID for this ' - f'connection. Please select one of {valid_src_gids}') + raise ValueError( + f'src_gid {src_gid} not a valid cell ID for this ' + f'connection. Please select one of {valid_src_gids}' + ) target_range = np.array(net.gid_ranges[conn['target_type']]) @@ -1246,39 +1388,56 @@ else: ax = axes[0] - im = _update_target_plot(ax, conn, src_gid, src_type_pos, - target_type_pos, src_range, - target_range, nc_dict, colormap, - net._inplane_distance) + im = _update_target_plot( + ax, + conn, + src_gid, + src_type_pos, + target_type_pos, + src_range, + target_range, + nc_dict, + colormap, + net._inplane_distance, + ) x_src = src_type_pos[:, 0] y_src = src_type_pos[:, 1] x_src_valid = src_pos_valid[:, 0] y_src_valid = src_pos_valid[:, 1] if src_type in net.cell_types: - ax_src.scatter(x_src, y_src, marker='s', color='red', s=50, - alpha=0.2) - ax_src.scatter(x_src_valid, y_src_valid, marker='s', color='red', - s=50) + ax_src.scatter(x_src, y_src, marker='s', color='red', s=50, alpha=0.2) + ax_src.scatter(x_src_valid, y_src_valid, marker='s', color='red', s=50) - plt.suptitle(f"{conn['src_type']}-> {conn['target_type']}" - f" ({conn['loc']}, {conn['receptor']})") + plt.suptitle( + f"{conn['src_type']} -> {conn['target_type']}" + f" ({conn['loc']}, {conn['receptor']})" + ) def _onclick(event): if event.inaxes in [ax] or event.inaxes is None: return - dist = np.linalg.norm(src_type_pos[:, :2] - - np.array([event.xdata, event.ydata]), - axis=1) + dist = np.linalg.norm( + src_type_pos[:, :2] - np.array([event.xdata, event.ydata]), axis=1 + ) src_idx = np.argmin(dist) src_gid = src_range[src_idx] if src_gid not in valid_src_gids: return - _update_target_plot(ax, conn, src_gid, src_type_pos, - target_type_pos, src_range, target_range, - nc_dict, colormap, net._inplane_distance) + _update_target_plot( + ax, + conn, + src_gid, + src_type_pos, + target_type_pos, + src_range, + target_range, + nc_dict, + colormap, + net._inplane_distance, + ) fig.canvas.draw() @@ -1288,7 +1447,7 @@ def _onclick(event): xfmt.set_powerlimits((-2, 2)) cbar = fig.colorbar(im, ax=ax, format=xfmt) cbar.ax.yaxis.set_ticks_position('right') - cbar.ax.set_ylabel('Weight', rotation=-90, va="bottom") + cbar.ax.set_ylabel('Weight', rotation=-90, va='bottom') plt.tight_layout() @@ -1298,9 +1457,18 @@ def _onclick(event): return ax.get_figure() -def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, - vmin=None, vmax=None, sink='b', interpolation='spline', - show=True): +def plot_laminar_csd( + times, + data, + contact_labels, + ax=None, + colorbar=True, + vmin=None, + vmax=None, + sink='b', + interpolation='spline', + show=True, +): """Plot laminar current source density (CSD) estimation from LFP array. Parameters @@ -1344,19 +1512,24 @@ _, ax = plt.subplots(1, 1, constrained_layout=True) if sink[0].lower() == 'b': - cmap = "jet" + cmap = 'jet' elif sink[0].lower() == 'r': - cmap = "jet_r" + cmap = 'jet_r' elif sink[0].lower() != 'b' or sink[0].lower() != 'r': - raise RuntimeError('Please use sink = "b" or sink = "r".'
- ' Only colormap "jet" is supported for CSD.') + raise RuntimeError( + 'Please use sink = "b" or sink = "r".' + ' Only colormap "jet" is supported for CSD.' + ) if interpolation == 'spline': # create interpolation function interp_data = RectBivariateSpline(times, contact_labels, data.T) # increase number of contacts - new_depths = np.linspace(contact_labels[0], contact_labels[-1], - contact_labels[-1] - contact_labels[0]) + new_depths = np.linspace( + contact_labels[0], + contact_labels[-1], + contact_labels[-1] - contact_labels[0], + ) # interpolate data = interp_data(times, new_depths).T elif interpolation is None: @@ -1368,8 +1541,9 @@ vmin = -np.max(np.abs(data)) vmax = np.max(np.abs(data)) - im = ax.pcolormesh(times, new_depths, data, - cmap=cmap, shading='auto', vmin=vmin, vmax=vmax) + im = ax.pcolormesh( + times, new_depths, data, cmap=cmap, shading='auto', vmin=vmin, vmax=vmax + ) ax.set_xlabel('time (s)') ax.set_ylabel('electrode depth') if colorbar: @@ -1420,15 +1594,40 @@ class NetworkPlotter: time_idx : int Index of time point plotted. Default: 0 """ - def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', - colorbar=True, voltage_colormap='viridis', elev=10, azim=-500, - xlim=(-200, 3100), ylim=(-200, 3100), zlim=(-300, 2200), - trial_idx=0, time_idx=0): + + def __init__( + self, + net, + ax=None, + vmin=-100, + vmax=50, + bg_color='black', + colorbar=True, + voltage_colormap='viridis', + elev=10, + azim=-500, + xlim=(-200, 3100), + ylim=(-200, 3100), + zlim=(-300, 2200), + trial_idx=0, + time_idx=0, + ): from matplotlib import colormaps - self._validate_parameters(vmin, vmax, bg_color, voltage_colormap, - colorbar, elev, azim, xlim, ylim, zlim, - trial_idx, time_idx) + self._validate_parameters( + vmin, + vmax, + bg_color, + voltage_colormap, + colorbar, + elev, + azim, + xlim, + ylim, + zlim, + trial_idx, + time_idx, + ) # Set init arguments self.net = net @@ -1462,9 +1661,21 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', else: self._cbar = None - def _validate_parameters(self, vmin, vmax, bg_color, voltage_colormap, - colorbar, elev, azim, xlim, ylim, zlim, trial_idx, - time_idx): + def _validate_parameters( + self, + vmin, + vmax, + bg_color, + voltage_colormap, + colorbar, + elev, + azim, + xlim, + ylim, + zlim, + trial_idx, + time_idx, + ): _validate_type(vmin, (int, float), 'vmin') _validate_type(vmax, (int, float), 'vmax') _validate_type(bg_color, str, 'bg_color') @@ -1492,6 +1703,7 @@ def _check_network_simulation(self): def _initialize_plots(self): import matplotlib.pyplot as plt + # Create figure if self.ax is None: self.fig = plt.figure() @@ -1509,8 +1721,9 @@ def _get_voltages(self): cell = self.net.cell_types[cell_type] for sec_name in cell.sections.keys(): if self._vsec_recorded is True: - vsec = np.array(self.net.cell_response.vsec[ - self.trial_idx][gid][sec_name]) + vsec = np.array( + self.net.cell_response.vsec[self.trial_idx][gid][sec_name] + ) vsec_list.append(vsec) else: # Populate with zeros if no voltage recording vsec_list.append([0.0]) @@ -1521,9 +1734,11 @@ def _update_section_voltages(self, t_idx): if not self._vsec_recorded: - raise RuntimeError("Network must be simulated with" - "`simulate_dipole(record_vsec='all')` before" - "plotting voltages.") + raise RuntimeError( + 'Network must be simulated with ' + "`simulate_dipole(record_vsec='all')` before " + 'plotting voltages.'
+ ) color_list = self.color_array[:, t_idx] for line, color in zip(self.ax.lines, color_list): line.set_color(color) @@ -1532,15 +1747,19 @@ def _init_network_plot(self): for cell_type in self.net.cell_types: gid_range = self.net.gid_ranges[cell_type] for gid_idx, gid in enumerate(gid_range): cell = self.net.cell_types[cell_type] pos = self.net.pos_dict[cell_type][gid_idx] pos = (float(pos[0]), float(pos[2]), float(pos[1])) - cell.plot_morphology(ax=self.ax, show=False, - pos=pos, xlim=self.xlim, - ylim=self.ylim, zlim=self.zlim) + cell.plot_morphology( + ax=self.ax, + show=False, + pos=pos, + xlim=self.xlim, + ylim=self.ylim, + zlim=self.zlim, + ) def _update_axes(self): self.ax.set_xlim(self._xlim) @@ -1556,12 +1775,21 @@ def _update_colorbar(self): fig = self.ax.get_figure() sm = plt.cm.ScalarMappable( cmap=self.voltage_colormap, - norm=mc.Normalize(vmin=self.vmin, vmax=self.vmax)) + norm=mc.Normalize(vmin=self.vmin, vmax=self.vmax), + ) self._cbar = fig.colorbar(sm, ax=self.ax) - def export_movie(self, fname, fps=30, dpi=300, decim=10, - interval=30, frame_start=0, frame_stop=None, - writer='pillow'): + def export_movie( + self, + fname, + fps=30, + dpi=300, + decim=10, + interval=30, + frame_start=0, + frame_stop=None, + writer='pillow', + ): """Export movie of network activity Parameters @@ -1589,15 +1817,18 @@ def export_movie(self, fname, fps=30, dpi=300, decim=10, import matplotlib.animation as animation if not self._vsec_recorded: - raise RuntimeError("Network must be simulated with" - "`simulate_dipole(record_vsec='all')` before" - "plotting voltages.") + raise RuntimeError( + 'Network must be simulated with ' + "`simulate_dipole(record_vsec='all')` before " + 'plotting voltages.' + ) if frame_stop is None: frame_stop = len(self.times) - 1 frames = np.arange(frame_start, frame_stop, decim) ani = animation.FuncAnimation( - self.fig, self._set_time_idx, frames, interval=interval) + self.fig, self._set_time_idx, frames, interval=interval + ) writer = animation.writers[writer](fps=fps) ani.save(fname, writer=writer, dpi=dpi) @@ -1693,9 +1924,11 @@ def trial_idx(self): def trial_idx(self, trial_idx): _validate_type(trial_idx, int, 'trial_idx') if not self._vsec_recorded: - raise RuntimeError("Network must be simulated with" - "`simulate_dipole(record_vsec='all')` before" - "setting `trial_idx`.") + raise RuntimeError( + 'Network must be simulated with ' + "`simulate_dipole(record_vsec='all')` before " + 'setting `trial_idx`.' + ) self._trial_idx = trial_idx self.vsec_array = self._get_voltages() self.color_array = self._colormap(self.vsec_array) @@ -1709,9 +1942,11 @@ def time_idx(self): def time_idx(self, time_idx): _validate_type(time_idx, (int, np.integer), 'time_idx') if not self._vsec_recorded: - raise RuntimeError("Network must be simulated with" - "`simulate_dipole(record_vsec='all')` before" - "setting `time_idx`.") + raise RuntimeError( + 'Network must be simulated with ' + "`simulate_dipole(record_vsec='all')` before " + 'setting `time_idx`.'
+ ) self._time_idx = time_idx self._update_section_voltages(self._time_idx) diff --git a/pyproject.toml b/pyproject.toml index 25cd619b7..abfc924ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,3 +7,16 @@ check-hidden = true # in jupyter notebooks - images and also some embedded outputs ignore-regex = '^\s*"image/\S+": ".*|.*%22%3A%20.*' ignore-words-list = 'tha,nam,sherif,dout' + +[tool.ruff] +exclude = ["*.ipynb"] +[tool.ruff.format] +quote-style = "single" +[tool.ruff.lint] +exclude = ["__init__.py"] +# We don't include rule "W504", which was in our old flake8 "setup.cfg" file, for two reasons: it is +# not a valid rule for the ruff linter, and the ruff formatter already applies that rule. +ignore = [ + "E402", # Needed for notes at beginning of example scripts + "E722", # From original flake8 'setup.cfg' file, needed in viz.py +] diff --git a/setup.cfg b/setup.cfg index f6c32644d..cd3994bff 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,3 @@ -[flake8] -exclude = __init__.py -ignore = E722, W504 - [check-manifest] ignore = .circleci/* diff --git a/setup.py b/setup.py index 8eea61f7d..3978cbd8e 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ with open(os.path.join('hnn_core', '__init__.py'), 'r') as fid: for line in (line.strip() for line in fid): if line.startswith('__version__'): - version = line.split('=')[1].strip().strip('\'') + version = line.split('=')[1].strip().strip("'") break if version is None: raise RuntimeError('Could not determine version') @@ -47,7 +47,7 @@ def finalize_options(self): pass def run(self): - print("=> Building mod files ...") + print('=> Building mod files ...') if platform.system() == 'Windows': shell = True @@ -55,16 +55,20 @@ def run(self): shell = False mod_path = op.join(op.dirname(__file__), 'hnn_core', 'mod') - process = subprocess.Popen(['nrnivmodl'], cwd=mod_path, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, shell=shell) + process = subprocess.Popen( + ['nrnivmodl'], + cwd=mod_path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell, + ) outs, errs = process.communicate() print(outs) class build_py_mod(build_py): def run(self): - self.run_command("build_mod") + self.run_command('build_mod') build_dir = op.join(self.build_lib, 'hnn_core', 'mod') mod_path = op.join(op.dirname(__file__), 'hnn_core', 'mod') @@ -73,58 +77,76 @@ def run(self): build_py.run(self) -if __name__ == "__main__": +if __name__ == '__main__': extras = { 'opt': ['scikit-learn'], 'parallel': ['joblib', 'psutil'], - 'test': ['flake8', 'pytest', 'pytest-cov', ], - 'docs': ['mne', 'nibabel', 'pooch', 'tdqm', - 'sphinx', 'sphinx-gallery', - 'sphinx_bootstrap_theme', 'sphinx-copybutton', - 'pillow', 'numpydoc', - ], - 'gui': ['ipywidgets>=8.0.0', 'ipykernel', 'ipympl', 'voila', ], + 'test': [ + 'pytest', + 'pytest-cov', + 'ruff', + ], + 'docs': [ + 'mne', + 'nibabel', + 'pooch', + 'tqdm', + 'sphinx', + 'sphinx-gallery', + 'sphinx_bootstrap_theme', + 'sphinx-copybutton', + 'pillow', + 'numpydoc', + ], + 'gui': [ + 'ipywidgets>=8.0.0', + 'ipykernel', + 'ipympl', + 'voila', + ], } - extras['dev'] = (extras['opt'] + extras['parallel'] + extras['test'] + - extras['docs'] + extras['gui'] - ) - - - setup(name=DISTNAME, - maintainer=MAINTAINER, - maintainer_email=MAINTAINER_EMAIL, - description=DESCRIPTION, - license=LICENSE, - url=URL, - version=version, - download_url=DOWNLOAD_URL, - long_description=open('README.rst').read(), - classifiers=[ - 'Intended Audience :: Science/Research', - 'Intended Audience :: Developers', - 'License :: OSI Approved', -
'Programming Language :: Python', - 'Topic :: Software Development', - 'Topic :: Scientific/Engineering', - 'Operating System :: Microsoft :: Windows', - 'Operating System :: POSIX', - 'Operating System :: Unix', - 'Operating System :: MacOS', - ], - platforms='any', - install_requires=[ - 'numpy >=1.14', - 'NEURON >=7.7; platform_system != "Windows"', - 'matplotlib>=3.5.3', - 'scipy', - 'h5io' - ], - extras_require=extras, - python_requires='>=3.8', - packages=find_packages(), - package_data={'hnn_core': [ - 'param/*.json', - 'gui/*.ipynb']}, - cmdclass={'build_py': build_py_mod, 'build_mod': BuildMod}, - entry_points={'console_scripts': ['hnn-gui=hnn_core.gui.gui:launch']} - ) + extras['dev'] = ( + extras['opt'] + + extras['parallel'] + + extras['test'] + + extras['docs'] + + extras['gui'] + ) + + setup( + name=DISTNAME, + maintainer=MAINTAINER, + maintainer_email=MAINTAINER_EMAIL, + description=DESCRIPTION, + license=LICENSE, + url=URL, + version=version, + download_url=DOWNLOAD_URL, + long_description=open('README.rst').read(), + classifiers=[ + 'Intended Audience :: Science/Research', + 'Intended Audience :: Developers', + 'License :: OSI Approved', + 'Programming Language :: Python', + 'Topic :: Software Development', + 'Topic :: Scientific/Engineering', + 'Operating System :: Microsoft :: Windows', + 'Operating System :: POSIX', + 'Operating System :: Unix', + 'Operating System :: MacOS', + ], + platforms='any', + install_requires=[ + 'numpy >=1.14', + 'NEURON >=7.7; platform_system != "Windows"', + 'matplotlib>=3.5.3', + 'scipy', + 'h5io', + ], + extras_require=extras, + python_requires='>=3.8', + packages=find_packages(), + package_data={'hnn_core': ['param/*.json', 'gui/*.ipynb']}, + cmdclass={'build_py': build_py_mod, 'build_mod': BuildMod}, + entry_points={'console_scripts': ['hnn-gui=hnn_core.gui.gui:launch']}, + )
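For reference, two short sketches follow; they are not part of the patch above. The first is a minimal, hedged example of the plotting API whose call sites are reformatted in this diff, assuming a working hnn_core installation (including NEURON). Every function and keyword argument appears in the test code above; the specific drive parameters are illustrative values borrowed from the test suite, not recommendations.

# Minimal usage sketch of the reformatted plotting API (illustrative only).
import matplotlib

matplotlib.use('agg')  # headless backend, as in test_viz.py
import matplotlib.pyplot as plt

from hnn_core import jones_2009_model, simulate_dipole
from hnn_core.viz import plot_dipole, plot_psd

net = jones_2009_model()
net.add_bursty_drive(
    'beta_prox',
    tstart=0.0,
    burst_rate=25,
    burst_std=5,
    numspikes=1,
    spike_isi=0,
    n_drive_cells=11,
    location='proximal',
    weights_ampa={'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5},
    synaptic_delays={'L2_pyramidal': 0.1, 'L5_pyramidal': 1.0},
    event_seed=14,
)
dpls = simulate_dipole(net, tstop=100.0, n_trials=2)

# ax and layer must have the same length, or plot_dipole raises AssertionError
fig, axes = plt.subplots(nrows=3, ncols=1)
plot_dipole(dpls, ax=axes, layer=['L2', 'L5', 'agg'], show=False)

# dipoles must be scaled and sampled equally, or plot_psd raises RuntimeError
plot_psd(dpls, fmin=0, fmax=50, show=False)

The second sketch illustrates the smooth_waveform contract enforced by the validation code reflowed above: the window length is given in milliseconds and converted to samples as round(1e-3 * window_len * sfreq), so it must exceed 1 ms and must not exceed the data length.

# Hedged sketch of the Hamming-smoothing helper from hnn_core/utils.py.
import numpy as np

from hnn_core.utils import smooth_waveform

sfreq = 1000.0  # sampling rate in Hz
data = np.random.random((500,))  # 1D waveform; >1D input raises RuntimeError
window_len = 30.0  # ms -> round(1e-3 * 30.0 * 1000.0) = 30 samples
smoothed = smooth_waveform(data, window_len, sfreq)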