Skip to content

Commit

Permalink
Refactor BatchSimulation and added unit test
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Jun 12, 2024
1 parent cbaeffa commit 724d001
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 36 deletions.
26 changes: 12 additions & 14 deletions examples/howto/batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# Mainak Jas <[email protected]>

###############################################################################

# Let us import ``hnn_core``.

import numpy as np
Expand All @@ -32,29 +31,26 @@ def set_params(param_values, net=None):
Parameters
----------
param_grid : dict
param_values : dict
Dictionary of parameter values.
net : instance of Network, optional
If None, a new network is created using the specified model type.
"""
if net is None:
net = jones_2009_model()

weights_ampa = {'L2_basket': param_grid['weight_basket'],
'L2_pyramidal': param_grid['weight_pyr'],
'L5_basket': param_grid['weight_basket'],
'L5_pyramidal': param_grid['weight_pyr']}
weights_ampa = {'L2_basket': param_values['weight_basket'],
'L2_pyramidal': param_values['weight_pyr'],
'L5_basket': param_values['weight_basket'],
'L5_pyramidal': param_values['weight_pyr']}

synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
'L5_basket': 1., 'L5_pyramidal': 1.}

mu = param_grid['mu']
sigma = param_grid['sigma']

# Add an evoked drive to the network
# Add an evoked drive to the network.
net.add_evoked_drive('evprox',
mu=mu,
sigma=sigma,
mu=param_values['mu'],
sigma=param_values['sigma'],
numspikes=1,
location='proximal',
weights_ampa=weights_ampa,
Expand All @@ -64,11 +60,13 @@ def set_params(param_values, net=None):


param_grid = {
'weight_basket': np.logspace(-4 -1, 5),
'weight_basket': np.logspace(-4 - 1, 5),
'weight_pyr': np.logspace(-4, -1, 5),
'mu': np.linspace(20, 80, 5),
'sigma': np.linspace(1, 20, 5)
}

batch_simulation = BatchSimulate(set_params=set_params)
simulation_results = batch_simulation.run(param_grid, n_jobs=n_jobs)
simulation_results = batch_simulation.run(param_grid,
n_jobs=n_jobs,
combinations=False)
2 changes: 1 addition & 1 deletion hnn_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
from .cell_response import CellResponse, read_spikes
from .cells_default import pyramidal, basket
from .parallel_backends import MPIBackend, JoblibBackend
from .batchsimulate import BatchSimulate
from .batch_simulate import BatchSimulate

__version__ = '0.4.dev0'
67 changes: 46 additions & 21 deletions hnn_core/batchsimulate.py → hnn_core/batch_simulate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
"""Batch simulation."""

# Authors: Abdul Samad Siddiqui <[email protected]>
# Nick Tolley <[email protected]>
# Ryan Thorpe <[email protected]>
# Mainak Jas <[email protected]>

from joblib import Parallel, delayed
from hnn_core import simulate_dipole
from hnn_core.network_models import (jones_2009_model,
Expand Down Expand Up @@ -32,7 +38,7 @@ def __init__(self, set_params, net_name='jones', tstop=170,
self.n_trials = n_trials

def run(self, param_grid, max_size=None, return_output=True,
save_output=False, fpath='./', n_jobs=1):
save_output=False, combinations=True, fpath='./', n_jobs=1):
"""
Run batch simulations.
Expand All @@ -46,29 +52,31 @@ def run(self, param_grid, max_size=None, return_output=True,
Whether to return the simulation outputs. Default is True.
save_output : bool, optional
Whether to save the outputs to disk. Default is False.
combinations : bool, optional
Whether to generate the Cartesian product of the parameter ranges.
If False, generate combinations based on corresponding indices.
Default is True.
fpath : str, optional
File path for saving outputs. Default is './'.
n_jobs : int, optional
Number of parallel jobs. Default is -1.
Number of parallel jobs. Default is 1.
Returns
-------
list
results : list
List of simulation results if return_output is True.
"""
param_combinations = self._generate_param_combinations(
param_grid, max_size)
# print("param_combinations-->",param_combinations)
param_grid, max_size, combinations)
results = self.simulate_batch(param_combinations, n_jobs=n_jobs)
print(results)

# if save_output:
# self.save(results, param_combinations, fpath, max_size)

if return_output:
return results

def simulate_batch(self, param_combinations, n_jobs=-1):
def simulate_batch(self, param_combinations, n_jobs=1):
"""
Simulate a batch of parameter sets in parallel.
Expand All @@ -77,12 +85,16 @@ def simulate_batch(self, param_combinations, n_jobs=-1):
param_combinations : list
List of parameter combinations.
n_jobs : int, optional
Number of parallel jobs. Default is -1.
Number of parallel jobs. Default is 1.
Returns
-------
list
List of simulation results.
res: list
List of dictionaries containing simulation results.
Each dictionary contains the following keys:
- `net`: The network model used for the simulation.
- `dpl`: The simulated dipole.
- `param_values`: The parameter values used for the simulation.
"""
res = Parallel(n_jobs=n_jobs, verbose=50)(
delayed(self._run_single_sim)(
Expand All @@ -100,8 +112,12 @@ def _run_single_sim(self, param_values):
Returns
-------
instance of Dipole
The simulated dipole.
dict
Dictionary containing the simulation results.
The dictionary contains the following keys:
- `net`: The network model used for the simulation.
- `dpl`: The simulated dipole.
- `param_values`: The parameter values used for the simulation.
"""
if self.net_name not in ['jones', 'law', 'calcium']:
raise ValueError(
Expand All @@ -114,13 +130,14 @@ def _run_single_sim(self, param_values):
net = law_2021_model()
elif self.net_name == 'calcium':
net = calcium_model()

print(param_values)
self.set_params(param_values, net)
dpl = simulate_dipole(net, tstop=self.tstop, dt=self.dt,
n_trials=self.n_trials)
return dpl
return {'net': net, 'dpl': dpl, 'param_values': param_values}

def _generate_param_combinations(self, param_grid, max_size=None):
def _generate_param_combinations(self, param_grid, max_size=None,
combinations=True):
"""
Generate combinations of parameters from the grid.
Expand All @@ -130,19 +147,27 @@ def _generate_param_combinations(self, param_grid, max_size=None):
Dictionary with parameter names and ranges.
max_size : int, optional
Maximum size of the batch. Default is None.
combinations : bool, optional
Whether to generate the Cartesian product of the parameter ranges.
If False, generate combinations based on corresponding indices.
Default is True.
Returns
-------
list
param_combinations: list
List of parameter combinations.
"""
from itertools import product

keys, values = zip(*param_grid.items())
combinations = [dict(zip(keys, combination))
for combination in product(*values)]
if combinations:
param_combinations = [dict(zip(keys, combination))
for combination in product(*values)]
else:
param_combinations = [dict(zip(keys, combination))
for combination in zip(*values)]

if max_size is not None:
combinations = combinations[:max_size]
# if max_size is not None:
# param_combinations = param_combinations[:max_size]

return combinations
return param_combinations
93 changes: 93 additions & 0 deletions hnn_core/tests/test_batch_simulate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Authors: Abdul Samad Siddiqui <[email protected]>
# Nick Tolley <[email protected]>
# Ryan Thorpe <[email protected]>
# Mainak Jas <[email protected]>

import pytest
import numpy as np
from hnn_core import BatchSimulate


@pytest.fixture
def batch_simulate_instance():
def set_params(param_values, net):
weights_ampa = {'L2_basket': param_values['weight_basket'],
'L2_pyramidal': param_values['weight_pyr'],
'L5_basket': param_values['weight_basket'],
'L5_pyramidal': param_values['weight_pyr']}

synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
'L5_basket': 1., 'L5_pyramidal': 1.}

mu = param_values['mu']
sigma = param_values['sigma']
net.add_evoked_drive('evprox',
mu=mu,
sigma=sigma,
numspikes=1,
location='proximal',
weights_ampa=weights_ampa,
synaptic_delays=synaptic_delays)

return BatchSimulate(set_params=set_params, tstop=10.)


@pytest.fixture
def param_grid():
return {
'weight_basket': np.logspace(-4 - 1, 2),
'weight_pyr': np.logspace(-4, -1, 2),
'mu': np.linspace(20, 80, 2),
'sigma': np.linspace(1, 20, 2)
}


def test_generate_param_combinations(batch_simulate_instance, param_grid):
"""Test generating parameter combinations."""
param_combinations = batch_simulate_instance._generate_param_combinations(
param_grid)
assert len(param_combinations) == (
len(param_grid['weight_basket']) *
len(param_grid['weight_pyr']) *
len(param_grid['mu']) *
len(param_grid['sigma'])
)


def test_run_single_sim(batch_simulate_instance):
"""Test running a single simulation."""
param_values = {
'weight_basket': -3,
'weight_pyr': -2,
'mu': 40,
'sigma': 20
}
result = batch_simulate_instance._run_single_sim(param_values)
assert 'net' in result
assert 'dpl' in result
assert 'param_values' in result
assert result['param_values'] == param_values


def test_simulate_batch(batch_simulate_instance, param_grid):
"""Test simulating a batch of parameter sets."""
param_combinations = batch_simulate_instance._generate_param_combinations(
param_grid)[:3]
results = batch_simulate_instance.simulate_batch(param_combinations,
n_jobs=2)
assert len(results) == len(param_combinations)
for result in results:
assert 'net' in result
assert 'dpl' in result
assert 'param_values' in result


def test_run(batch_simulate_instance, param_grid):
results = batch_simulate_instance.run(param_grid, n_jobs=2,
return_output=True,
combinations=False)
assert results is not None
assert isinstance(results, list)
assert len(results) == len(
batch_simulate_instance._generate_param_combinations(
param_grid, combinations=False))

0 comments on commit 724d001

Please sign in to comment.