diff --git a/examples/howto/plot_batch_simulate.py b/examples/howto/plot_batch_simulate.py index d6c378e65..35ac2a75c 100644 --- a/examples/howto/plot_batch_simulate.py +++ b/examples/howto/plot_batch_simulate.py @@ -107,10 +107,6 @@ def summary_func(results): ############################################################################### # Run the batch simulation and collect the results. -# Uncomment this code if dask and distributed Python packages are installed. -# from dask.distributed import Client -# client = Client(threads_per_worker=1, n_workers=5, processes=False) - # Initialize the network model and run the batch simulation. net = jones_2009_model(mesh_shape=(3, 3)) @@ -121,7 +117,7 @@ def summary_func(results): n_jobs=n_jobs, combinations=False, backend='loky') -# backend='dask' if installed + print("Simulation results:", simulation_results) ############################################################################### # This plot shows an overlay of all smoothed dipole waveforms from the diff --git a/hnn_core/batch_simulate.py b/hnn_core/batch_simulate.py index 9c09a5f88..168c3ea95 100644 --- a/hnn_core/batch_simulate.py +++ b/hnn_core/batch_simulate.py @@ -16,7 +16,7 @@ class BatchSimulate(object): - def __init__(self, set_params, net=None, + def __init__(self, set_params, net=jones_2009_model(), tstop=170, dt=0.025, n_trials=1, save_folder='./sim_results', batch_size=100, overwrite=True, save_outputs=False, save_dpl=True, @@ -106,7 +106,7 @@ def __init__(self, set_params, net=None, will be overwritten. """ - _validate_type(net, (Network, None), 'net', 'Network') + _validate_type(net, Network, 'net', 'Network') _validate_type(tstop, types='numeric', item_name='tstop') _validate_type(dt, types='numeric', item_name='dt') _validate_type(n_trials, types='int', item_name='n_trials') @@ -129,7 +129,7 @@ def __init__(self, set_params, net=None, if summary_func is not None and not callable(summary_func): raise TypeError("summary_func must be a callable function") - self.net = net if net is not None else jones_2009_model() + self.net = net self.set_params = set_params self.tstop = tstop self.dt = dt