Skip to content

Commit

Permalink
Refactor: Removed joblib from simulate_dipole, and added parallel exe…
Browse files Browse the repository at this point in the history
…cution test.

Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Nov 26, 2024
1 parent 2f79384 commit 74b16cd
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 12 deletions.
5 changes: 3 additions & 2 deletions examples/howto/plot_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from hnn_core import jones_2009_model

# The number of cores may need modifying depending on your current machine.
n_jobs = 10
n_jobs = 4
###############################################################################
# The `add_evoked_drive` function simulates external input to the network,
# mimicking sensory stimulation or other external events.
Expand Down Expand Up @@ -116,7 +116,8 @@ def summary_func(results):
net = jones_2009_model(mesh_shape=(3, 3))
batch_simulation = BatchSimulate(net=net,
set_params=set_params,
summary_func=summary_func)
summary_func=summary_func,
n_trials=10)
simulation_results = batch_simulation.run(param_grid,
n_jobs=n_jobs,
combinations=False,
Expand Down
18 changes: 8 additions & 10 deletions hnn_core/batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import os
from joblib import Parallel, delayed, parallel_config

from .parallel_backends import JoblibBackend
from .network import Network
from .externals.mne import _validate_type, _check_option
from .dipole import simulate_dipole
Expand Down Expand Up @@ -294,15 +293,14 @@ def _run_single_sim(self, param_values, n_jobs=1):
results = {'net': net, 'param_values': param_values}

if self.save_dpl:
with JoblibBackend(n_jobs=n_jobs):
dpl = simulate_dipole(net,
tstop=self.tstop,
dt=self.dt,
n_trials=self.n_trials,
record_vsec=self.record_vsec,
record_isec=self.record_isec,
postproc=self.postproc)
results['dpl'] = dpl
dpl = simulate_dipole(net,
tstop=self.tstop,
dt=self.dt,
n_trials=self.n_trials,
record_vsec=self.record_vsec,
record_isec=self.record_isec,
postproc=self.postproc)
results['dpl'] = dpl

if self.save_spiking:
results['spiking'] = {
Expand Down
25 changes: 25 additions & 0 deletions hnn_core/tests/test_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Mainak Jas <[email protected]>

from pathlib import Path
import time
import pytest
import numpy as np
import os
Expand Down Expand Up @@ -290,3 +291,27 @@ def test_load_results(batch_simulate_instance, param_grid, tmp_path):
# Validation Tests
with pytest.raises(TypeError, match='results must be'):
batch_simulate_instance._save("invalid_results", start_idx, end_idx)


def test_parallel_execution(batch_simulate_instance, param_grid):
"""Test parallel execution of simulations and ensure speedup."""

param_combinations = batch_simulate_instance._generate_param_combinations(
param_grid)

start_time = time.perf_counter()
results_serial = batch_simulate_instance.simulate_batch(
param_combinations, n_jobs=1, backend='loky')
end_time = time.perf_counter()
serial_time = end_time - start_time

start_time = time.perf_counter()
results_parallel = batch_simulate_instance.simulate_batch(
param_combinations,
n_jobs=4,
backend='loky')
end_time = time.perf_counter()
parallel_time = end_time - start_time

assert (serial_time > parallel_time
), "Parallel execution is not faster than serial execution!"

0 comments on commit 74b16cd

Please sign in to comment.