diff --git a/examples/howto/plot_batch_simulate.py b/examples/howto/plot_batch_simulate.py index f683a6087..77e86c278 100644 --- a/examples/howto/plot_batch_simulate.py +++ b/examples/howto/plot_batch_simulate.py @@ -119,8 +119,8 @@ def summary_func(results): summary_func=summary_func) simulation_results = batch_simulation.run(param_grid, n_jobs=n_jobs, - combinations=False, - backend='multiprocessing') + combinations=True, + backend='loky') # backend='dask' if installed print("Simulation results:", simulation_results) ############################################################################### diff --git a/hnn_core/batch_simulate.py b/hnn_core/batch_simulate.py index b096819cc..880224949 100644 --- a/hnn_core/batch_simulate.py +++ b/hnn_core/batch_simulate.py @@ -9,6 +9,7 @@ import os from joblib import Parallel, delayed, parallel_config +from .parallel_backends import JoblibBackend from .network import Network from .externals.mne import _validate_type, _check_option from .dipole import simulate_dipole @@ -196,14 +197,14 @@ def run(self, param_grid, return_output=True, param_combinations = self._generate_param_combinations( param_grid, combinations) total_sims = len(param_combinations) - num_sims_per_batch = max(total_sims // self.batch_size, 1) + num_sims_per_batch = max(total_sims // n_jobs, 1) batch_size = min(self.batch_size, total_sims) results = [] simulated_data = [] - for i in range(batch_size): - start_idx = i * num_sims_per_batch - end_idx = start_idx + num_sims_per_batch + for i in range(0, total_sims, num_sims_per_batch): + start_idx = i + end_idx = min(i + num_sims_per_batch, total_sims) if i == batch_size - 1: end_idx = len(param_combinations) batch_results = self.simulate_batch( @@ -269,10 +270,10 @@ def simulate_batch(self, param_combinations, n_jobs=1, with parallel_config(backend=backend): res = Parallel(n_jobs=n_jobs, verbose=verbose)( delayed(self._run_single_sim)( - params) for params in param_combinations) + params, n_jobs) for params in param_combinations) return res - def _run_single_sim(self, param_values): + def _run_single_sim(self, param_values, n_jobs=1): """Run a single simulation. Parameters @@ -296,14 +297,15 @@ def _run_single_sim(self, param_values): results = {'net': net, 'param_values': param_values} if self.save_dpl: - dpl = simulate_dipole(net, - tstop=self.tstop, - dt=self.dt, - n_trials=self.n_trials, - record_vsec=self.record_vsec, - record_isec=self.record_isec, - postproc=self.postproc) - results['dpl'] = dpl + with JoblibBackend(n_jobs=n_jobs): + dpl = simulate_dipole(net, + tstop=self.tstop, + dt=self.dt, + n_trials=self.n_trials, + record_vsec=self.record_vsec, + record_isec=self.record_isec, + postproc=self.postproc) + results['dpl'] = dpl if self.save_spiking: results['spiking'] = {