Skip to content

Commit

Permalink
ENH: Refactor BatchSimulate Example and Improve Documentation (jone…
Browse files Browse the repository at this point in the history
…scompneurolab#857)

* ENH BatchSimulate for JSON path handling

Signed-off-by: samadpls <[email protected]>

* Fix: Docstring of dic_to_network

Signed-off-by: samadpls <[email protected]>

* ENH: Reorder the init params

Signed-off-by: samadpls <[email protected]>

* Added record_isec vsec validation test

Signed-off-by: samadpls <[email protected]>

* Removed net_json param and update test

Signed-off-by: samadpls <[email protected]>

* Refactored  Example and initialize parameters

Co-authored-by: Nicholas Tolley <[email protected]>
Signed-off-by: samadpls <[email protected]>

* Enh: visualization of dipole responses in plot_batch_simulate

Co-authored-by: Nicholas Tolley <[email protected]>
Signed-off-by: samadpls <[email protected]>

* Refactor batch simulation parameters and backend

Signed-off-by: samadpls <[email protected]>

* [MRG] Fix indexing for batch simulations (jonescompneurolab#5)

* ENH BatchSimulate for JSON path handling

Signed-off-by: samadpls <[email protected]>

* Fix: Docstring of dic_to_network

Signed-off-by: samadpls <[email protected]>

* ENH: Reorder the init params

Signed-off-by: samadpls <[email protected]>

* Added record_isec vsec validation test

Signed-off-by: samadpls <[email protected]>

* Removed net_json param and update test

Signed-off-by: samadpls <[email protected]>

* Refactored  Example and initialize parameters

Co-authored-by: Nicholas Tolley <[email protected]>
Signed-off-by: samadpls <[email protected]>

* Enh: visualization of dipole responses in plot_batch_simulate

Co-authored-by: Nicholas Tolley <[email protected]>
Signed-off-by: samadpls <[email protected]>

* Refactor batch simulation parameters and backend

Signed-off-by: samadpls <[email protected]>

* batches run in parallel

---------

Signed-off-by: samadpls <[email protected]>
Co-authored-by: samadpls <[email protected]>

* Refactor: Removed joblib from simulate_dipole, and added parallel execution test.

Signed-off-by: samadpls <[email protected]>

* Refactor: Simplify BatchSimulate parameters by removing n_jobs

Signed-off-by: samadpls <[email protected]>

* Remove unused Dask code and simplify BatchSimulate initialization

Signed-off-by: samadpls <[email protected]>

---------

Signed-off-by: samadpls <[email protected]>
Co-authored-by: Nicholas Tolley <[email protected]>
  • Loading branch information
samadpls and ntolley authored Dec 12, 2024
1 parent fbe1981 commit c374c6a
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 66 deletions.
66 changes: 48 additions & 18 deletions examples/howto/plot_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,25 @@
from hnn_core import jones_2009_model

# The number of cores may need modifying depending on your current machine.
n_jobs = 10
n_jobs = 4
###############################################################################
# The `add_evoked_drive` function simulates external input to the network,
# mimicking sensory stimulation or other external events.
#
# - `evprox` indicates a proximal drive, targeting dendrites near the cell
# bodies.
# - `mu=40` and `sigma=5` define the timing (mean and spread) of the input.
# - `weights_ampa` and `synaptic_delays` control the strength and
# timing of the input.
#
# This evoked drive causes the initial positive deflection in the dipole
# signal, triggering a cascade of activity through the network and
# resulting in the complex waveforms observed.


def set_params(param_values, net=None):
"""
Set parameters in the network drives.
Set parameters for the network drives.
Parameters
----------
Expand All @@ -57,16 +69,16 @@ def set_params(param_values, net=None):
synaptic_delays=synaptic_delays)

###############################################################################
# Define a parameter grid for the batch simulation.
# Next, we define a parameter grid for the batch simulation.


param_grid = {
'weight_basket': np.logspace(-4 - 1, 10),
'weight_pyr': np.logspace(-4, -1, 10)
'weight_basket': np.logspace(-4, -1, 20),
'weight_pyr': np.logspace(-4, -1, 20)
}

###############################################################################
# Define a function to calculate summary statistics
# We then define a function to calculate summary statistics.


def summary_func(results):
Expand Down Expand Up @@ -95,36 +107,54 @@ def summary_func(results):
###############################################################################
# Run the batch simulation and collect the results.

# Comment off this code, if dask and distributed Python packages are installed
# from dask.distributed import Client
# client = Client(threads_per_worker=1, n_workers=5, processes=False)


# Run the batch simulation and collect the results.
# Initialize the network model and run the batch simulation.
net = jones_2009_model(mesh_shape=(3, 3))
batch_simulation = BatchSimulate(net=net,
set_params=set_params,
summary_func=summary_func)
simulation_results = batch_simulation.run(param_grid,
n_jobs=n_jobs,
combinations=False,
backend='multiprocessing')
# backend='dask' if installed
backend='loky')

print("Simulation results:", simulation_results)
###############################################################################
# This plot shows an overlay of all smoothed dipole waveforms from the
# batch simulation. Each line represents a different set of parameters,
# allowing us to visualize the range of responses across the parameter space.
# batch simulation. Each line represents a different set of synaptic strength
# parameters (`weight_basket`), allowing us to visualize the range of responses
# across the parameter space.
# The colormap represents synaptic strengths, from weaker (purple)
# to stronger (yellow).
#
# As drive strength increases, dipole responses show progressively larger
# amplitudes and more distinct features, reflecting heightened network
# activity. Weak drives (purple lines) produce smaller amplitude signals with
# simpler waveforms, while stronger drives (yellow lines) generate
# larger responses with more pronounced oscillatory features, indicating
# more robust network activity.
#
# The y-axis represents dipole amplitude in nAm (nanoAmpere-meters), which is
# the product of current flow and distance in the neural tissue.
#
# Stronger synaptic connections (yellow lines) generally show larger
# amplitude responses and more pronounced features throughout the simulation.

dpl_waveforms = []
dpl_waveforms, param_values = [], []
for data_list in simulation_results['simulated_data']:
for data in data_list:
dpl_smooth = data['dpl'][0].copy().smooth(window_len=30)
dpl_waveforms.append(dpl_smooth.data['agg'])
param_values.append(data['param_values']['weight_basket'])

plt.figure(figsize=(10, 6))
for waveform in dpl_waveforms:
plt.plot(waveform, alpha=0.5, linewidth=3)
cmap = plt.get_cmap('viridis')
log_param_values = np.log10(param_values)
norm = plt.Normalize(log_param_values.min(), log_param_values.max())

for waveform, log_param in zip(dpl_waveforms, log_param_values):
color = cmap(norm(log_param))
plt.plot(waveform, color=color, alpha=0.7, linewidth=2)
plt.title('Overlay of Dipole Waveforms')
plt.xlabel('Time (ms)')
plt.ylabel('Dipole Amplitude (nAm)')
Expand Down
78 changes: 39 additions & 39 deletions hnn_core/batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@


class BatchSimulate(object):
def __init__(self, set_params, net=jones_2009_model(), tstop=170,
dt=0.025, n_trials=1, record_vsec=False,
record_isec=False, postproc=False, save_outputs=False,
def __init__(self, set_params, net=jones_2009_model(),
tstop=170, dt=0.025, n_trials=1,
save_folder='./sim_results', batch_size=100,
overwrite=True, summary_func=None,
save_dpl=True, save_spiking=False,
save_lfp=False, save_voltages=False,
save_currents=False, save_calcium=False,
clear_cache=False):
overwrite=True, save_outputs=False, save_dpl=True,
save_spiking=False, save_lfp=False, save_voltages=False,
save_currents=False, save_calcium=False, record_vsec=False,
record_isec=False, postproc=False, clear_cache=False,
summary_func=None):
"""Initialize the BatchSimulate class.
Parameters
Expand All @@ -37,34 +36,17 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
where ``net`` is a Network object and ``params`` is a dictionary
of the parameters that will be set inside the function.
net : Network object, optional
The network model to use for simulations. Must be an instance of
jones_2009_model, law_2021_model, or calcium_model.
Default is jones_2009_model().
The network model to use for simulations. Examples include:
- `jones_2009_model`: A network model based on Jones et al. (2009).
- `law_2021_model`: A network model based on Law et al. (2021).
- `calcium_model`: A network model incorporating calcium dynamics.
Default is `jones_2009_model()`
tstop : float, optional
The stop time for the simulation. Default is 170 ms.
dt : float, optional
The time step for the simulation. Default is 0.025 ms.
n_trials : int, optional
The number of trials for the simulation. Default is 1.
record_vsec : 'all' | 'soma' | False
Option to record voltages from all sections ('all'), or just
the soma ('soma'). Default: False.
record_isec : 'all' | 'soma' | False
Option to record voltages from all sections ('all'), or just
the soma ('soma'). Default: False.
postproc : bool
If True, smoothing (``dipole_smooth_win``) and scaling
(``dipole_scalefctr``) values are read from the parameter file, and
applied to the dipole objects before returning.
Note that this setting
only affects the dipole waveforms, and not somatic voltages,
possible extracellular recordings etc.
The preferred way is to use the
:meth:`~hnn_core.dipole.Dipole.smooth` and
:meth:`~hnn_core.dipole.Dipole.scale` methods instead.
Default: False.
save_outputs : bool, optional
Whether to save the simulation outputs to files. Default is False.
save_folder : str, optional
The path to save the simulation outputs.
Default is './sim_results'.
Expand All @@ -74,9 +56,8 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
overwrite : bool, optional
Whether to overwrite existing files and create file paths
if they do not exist. Default is True.
summary_func : callable, optional
A function to calculate summary statistics from the simulation
results. Default is None.
save_outputs : bool, optional
Whether to save the simulation outputs to files. Default is False.
save_dpl : bool
If True, save dipole results. Note, `save_outputs` must be True.
Default: True.
Expand All @@ -97,9 +78,23 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170,
If True, save calcium concentrations.
Note, `save_outputs` must be True.
Default: False.
record_vsec : 'all' | 'soma' | False
Option to record voltages from all sections ('all'), or just
the soma ('soma'). Default: False.
record_isec : 'all' | 'soma' | False
Option to record voltages from all sections ('all'), or just
the soma ('soma'). Default: False.
postproc : bool
If True, smoothing (``dipole_smooth_win``) and scaling
(``dipole_scalefctr``) values are read from the parameter file, and
applied to the dipole objects before returning.
Default: False.
clear_cache : bool, optional
Whether to clear the results cache after saving each batch.
Default is False.
summary_func : callable, optional
A function to calculate summary statistics from the simulation
results. Default is None.
Notes
-----
When `save_output=True`, the saved files will appear as
Expand Down Expand Up @@ -201,16 +196,13 @@ def run(self, param_grid, return_output=True,
param_combinations = self._generate_param_combinations(
param_grid, combinations)
total_sims = len(param_combinations)
num_sims_per_batch = max(total_sims // self.batch_size, 1)
batch_size = min(self.batch_size, total_sims)

results = []
simulated_data = []
for i in range(batch_size):
start_idx = i * num_sims_per_batch
end_idx = start_idx + num_sims_per_batch
if i == batch_size - 1:
end_idx = len(param_combinations)
for i in range(0, total_sims, batch_size):
start_idx = i
end_idx = min(i + batch_size, total_sims)
batch_results = self.simulate_batch(
param_combinations[start_idx:end_idx],
n_jobs=n_jobs,
Expand Down Expand Up @@ -388,6 +380,14 @@ def _save(self, results, start_idx, end_idx):
if getattr(self, f'save_{attr}') and attr in results[0]:
save_data[attr] = [result[attr] for result in results]

metadata = {
'batch_size': self.batch_size,
'n_trials': self.n_trials,
'tstop': self.tstop,
'dt': self.dt
}
save_data['metadata'] = metadata

file_name = os.path.join(self.save_folder,
f'sim_run_{start_idx}-{end_idx}.npz')
if os.path.exists(file_name) and not self.overwrite:
Expand Down
12 changes: 6 additions & 6 deletions hnn_core/hnn_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,12 +397,12 @@ def dict_to_network(net_data,
Parameters
----------
fname : str or Path
Path to configuration file
read_drives : bool
Read-in drives to Network object
read_external_biases
Read-in external biases to Network object
net_data : dict
Dictionary containing network configurations.
read_drives : bool, optional
Read-in drives to Network object. Default is True.
read_external_biases : bool, optional
Read-in external biases to Network object. Default is True.
Returns : Network
-------
Expand Down
42 changes: 39 additions & 3 deletions hnn_core/tests/test_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
# Ryan Thorpe <[email protected]>
# Mainak Jas <[email protected]>

from pathlib import Path
import time
import pytest
import numpy as np
import os

from hnn_core.batch_simulate import BatchSimulate
from hnn_core import jones_2009_model

hnn_core_root = Path(__file__).parents[1]
assets_path = Path(hnn_core_root, 'tests', 'assets')


@pytest.fixture
def batch_simulate_instance(tmp_path):
Expand All @@ -33,11 +38,12 @@ def set_params(param_values, net):
weights_ampa=weights_ampa,
synaptic_delays=synaptic_delays)

net = jones_2009_model()
net = jones_2009_model(mesh_shape=(3, 3))
return BatchSimulate(net=net, set_params=set_params,
tstop=1.,
tstop=10,
save_folder=tmp_path,
batch_size=3)
batch_size=3,
n_trials=3,)


@pytest.fixture
Expand Down Expand Up @@ -75,6 +81,12 @@ def test_parameter_validation():
with pytest.raises(TypeError, match="net must be"):
BatchSimulate(net="invalid_network", set_params=lambda x: x)

with pytest.raises(ValueError, match="'record_vsec' parameter"):
BatchSimulate(set_params=lambda x: x, record_vsec="invalid")

with pytest.raises(ValueError, match="'record_isec' parameter"):
BatchSimulate(set_params=lambda x: x, record_isec="invalid")


def test_generate_param_combinations(batch_simulate_instance, param_grid):
"""Test generating parameter combinations."""
Expand Down Expand Up @@ -280,3 +292,27 @@ def test_load_results(batch_simulate_instance, param_grid, tmp_path):
# Validation Tests
with pytest.raises(TypeError, match='results must be'):
batch_simulate_instance._save("invalid_results", start_idx, end_idx)


def test_parallel_execution(batch_simulate_instance, param_grid):
"""Test parallel execution of simulations and ensure speedup."""

param_combinations = batch_simulate_instance._generate_param_combinations(
param_grid)

start_time = time.perf_counter()
_ = batch_simulate_instance.simulate_batch(
param_combinations, n_jobs=1, backend='loky')
end_time = time.perf_counter()
serial_time = end_time - start_time

start_time = time.perf_counter()
_ = batch_simulate_instance.simulate_batch(
param_combinations,
n_jobs=2,
backend='loky')
end_time = time.perf_counter()
parallel_time = end_time - start_time

assert (serial_time > parallel_time
), "Parallel execution is not faster than serial execution!"

0 comments on commit c374c6a

Please sign in to comment.