Skip to content

Commit

Permalink
Fix a bunch of tests :)
Browse files Browse the repository at this point in the history
  • Loading branch information
ntolley authored and jasmainak committed May 19, 2021
1 parent fffab09 commit 44dc87d
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 16 deletions.
5 changes: 3 additions & 2 deletions hnn_core/tests/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import pytest

import hnn_core
from hnn_core import read_params, read_dipole, average_dipoles, Network
from hnn_core import read_params, read_dipole, average_dipoles
from hnn_core import Network, default_network
from hnn_core.viz import plot_dipole
from hnn_core.dipole import Dipole, simulate_dipole
from hnn_core.parallel_backends import requires_mpi4py, requires_psutil
Expand Down Expand Up @@ -87,7 +88,7 @@ def test_dipole_simulation():
't_evprox_1': 5,
't_evdist_1': 10,
't_evprox_2': 20})
net = Network(params, add_drives_from_params=True)
net = default_network(params, add_drives_from_params=True)
with pytest.raises(ValueError, match="Invalid number of simulations: 0"):
simulate_dipole(net, n_trials=0)
with pytest.raises(TypeError, match="record_vsoma must be bool, got int"):
Expand Down
6 changes: 3 additions & 3 deletions hnn_core/tests/test_mpi_child.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import hnn_core
from hnn_core import read_params, Network
from hnn_core import read_params, Network, default_network
from hnn_core.mpi_child import (MPISimulation, _str_to_net, _pickle_data)
from hnn_core.parallel_backends import (_gather_trial_data,
_process_child_data,
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_str_to_net():
# prepare network
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
net = Network(params, add_drives_from_params=True)
net = default_network(params, add_drives_from_params=True)

pickled_net = _pickle_data(net)

Expand Down Expand Up @@ -120,7 +120,7 @@ def test_child_run():
't_evdist_1': 10,
't_evprox_2': 20,
'N_trials': 2})
net_reduced = Network(params_reduced, add_drives_from_params=True)
net_reduced = default_network(params_reduced, add_drives_from_params=True)

with MPISimulation(skip_mpi_import=True) as mpi_sim:
with io.StringIO() as buf, redirect_stdout(buf):
Expand Down
10 changes: 5 additions & 5 deletions hnn_core/tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

import hnn_core
from hnn_core import read_params, Network, CellResponse
from hnn_core import read_params, default_network, CellResponse
from hnn_core.network_builder import NetworkBuilder


Expand All @@ -23,7 +23,7 @@ def test_network():
'input_prox_A_weight_L2Pyr_ampa': 5.4e-5,
'input_prox_A_weight_L5Pyr_ampa': 5.4e-5,
't0_input_prox': 50})
net = Network(deepcopy(params), add_drives_from_params=True)
net = default_network(deepcopy(params), add_drives_from_params=True)
network_builder = NetworkBuilder(net) # needed to populate net.cells

# Assert that params are conserved across Network initialization
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_network():
assert nc.threshold == params['threshold']

# create a new connection between cell types
net = Network(deepcopy(params), add_drives_from_params=True)
net = default_network(deepcopy(params), add_drives_from_params=True)
nc_dict = {'A_delay': 1, 'A_weight': 1e-5, 'lamtha': 20,
'threshold': 0.5}
net._all_to_all_connect('bursty1', 'L5_basket',
Expand All @@ -166,15 +166,15 @@ def test_network():
assert len(network_builder.ncs['bursty1_L5Basket_gabaa']) == n_conn

# try unique=True
net = Network(deepcopy(params), add_drives_from_params=True)
net = default_network(deepcopy(params), add_drives_from_params=True)
net._all_to_all_connect('extgauss', 'L5_basket',
'soma', 'gabaa', nc_dict, unique=True)
network_builder = NetworkBuilder(net)
n_conn = len(net.gid_ranges['L5_basket'])
assert len(network_builder.ncs['extgauss_L5Basket_gabaa']) == n_conn

# Test inputs for connectivity API
net = Network(deepcopy(params), add_drives_from_params=True)
net = default_network(deepcopy(params), add_drives_from_params=True)
n_conn = len(network_builder.ncs['L2Basket_L2Pyr_gabaa'])
kwargs_default = dict(src_gids=[0, 1], target_gids=[35, 36],
loc='soma', receptor='gabaa',
Expand Down
6 changes: 3 additions & 3 deletions hnn_core/tests/test_parallel_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mne.utils import _fetch_file

import hnn_core
from hnn_core import MPIBackend, Network, read_params
from hnn_core import MPIBackend, default_network, read_params
from hnn_core.dipole import simulate_dipole
from hnn_core.parallel_backends import requires_mpi4py, requires_psutil

Expand Down Expand Up @@ -99,7 +99,7 @@ def test_terminate_mpibackend(self, run_hnn_core_fixture):
't_evdist_1': 10,
't_evprox_2': 20,
'N_trials': 2})
net = Network(params, add_drives_from_params=True)
net = default_network(params, add_drives_from_params=True)

with MPIBackend() as backend:
event = Event()
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_run_mpibackend_oversubscribed(self, run_hnn_core_fixture):
't_evdist_1': 10,
't_evprox_2': 20,
'N_trials': 2})
net = Network(params, add_drives_from_params=True)
net = default_network(params, add_drives_from_params=True)

oversubscribed = round(cpu_count() * 1.5)
with MPIBackend(n_procs=oversubscribed) as backend:
Expand Down
6 changes: 3 additions & 3 deletions hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

import hnn_core
from hnn_core import read_params, Network
from hnn_core import read_params, default_network
from hnn_core.viz import (plot_cells, plot_dipole, plot_psd, plot_tfr_morlet,
plot_cell_morphology)
from hnn_core.dipole import simulate_dipole
Expand All @@ -20,7 +20,7 @@ def test_network_visualization():
params = read_params(params_fname)
params.update({'N_pyr_x': 3,
'N_pyr_y': 3})
net = Network(params)
net = default_network(params)
plot_cells(net)
with pytest.raises(ValueError, match='Unrecognized cell type'):
plot_cell_morphology(cell_types='blah')
Expand All @@ -37,7 +37,7 @@ def test_dipole_visualization():
params.update({'N_pyr_x': 3,
'N_pyr_y': 3,
'tstop': 100.})
net = Network(params)
net = default_network(params)
weights_ampa_p = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5}
syn_delays_p = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}

Expand Down

0 comments on commit 44dc87d

Please sign in to comment.