-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor BatchSimulation and added unit test
Signed-off-by: samadpls <[email protected]>
- Loading branch information
Showing
4 changed files
with
154 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
""" | ||
================================================= | ||
==================== | ||
06. Batch Simulation | ||
================================================= | ||
==================== | ||
This example shows how to do batch simulations in HNN-core, allowing users to | ||
efficiently run multiple simulations with different parameters | ||
|
@@ -14,7 +14,6 @@ | |
# Mainak Jas <[email protected]> | ||
|
||
############################################################################### | ||
|
||
# Let us import ``hnn_core``. | ||
|
||
import numpy as np | ||
|
@@ -32,29 +31,26 @@ def set_params(param_values, net=None): | |
Parameters | ||
---------- | ||
param_grid : dict | ||
param_values : dict | ||
Dictionary of parameter values. | ||
net : instance of Network, optional | ||
If None, a new network is created using the specified model type. | ||
""" | ||
if net is None: | ||
net = jones_2009_model() | ||
|
||
weights_ampa = {'L2_basket': param_grid['weight_basket'], | ||
'L2_pyramidal': param_grid['weight_pyr'], | ||
'L5_basket': param_grid['weight_basket'], | ||
'L5_pyramidal': param_grid['weight_pyr']} | ||
weights_ampa = {'L2_basket': param_values['weight_basket'], | ||
'L2_pyramidal': param_values['weight_pyr'], | ||
'L5_basket': param_values['weight_basket'], | ||
'L5_pyramidal': param_values['weight_pyr']} | ||
|
||
synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, | ||
'L5_basket': 1., 'L5_pyramidal': 1.} | ||
|
||
mu = param_grid['mu'] | ||
sigma = param_grid['sigma'] | ||
|
||
# Add an evoked drive to the network | ||
# Add an evoked drive to the network. | ||
net.add_evoked_drive('evprox', | ||
mu=mu, | ||
sigma=sigma, | ||
mu=param_values['mu'], | ||
sigma=param_values['sigma'], | ||
numspikes=1, | ||
location='proximal', | ||
weights_ampa=weights_ampa, | ||
|
@@ -64,11 +60,13 @@ def set_params(param_values, net=None): | |
|
||
|
||
param_grid = { | ||
'weight_basket': np.logspace(-4 -1, 5), | ||
'weight_basket': np.logspace(-4 - 1, 5), | ||
'weight_pyr': np.logspace(-4, -1, 5), | ||
'mu': np.linspace(20, 80, 5), | ||
'sigma': np.linspace(1, 20, 5) | ||
} | ||
|
||
batch_simulation = BatchSimulate(set_params=set_params) | ||
simulation_results = batch_simulation.run(param_grid, n_jobs=n_jobs) | ||
simulation_results = batch_simulation.run(param_grid, | ||
n_jobs=n_jobs, | ||
combinations=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,10 @@ | ||
"""Batch simulation.""" | ||
|
||
# Authors: Abdul Samad Siddiqui <[email protected]> | ||
# Nick Tolley <[email protected]> | ||
# Ryan Thorpe <[email protected]> | ||
# Mainak Jas <[email protected]> | ||
|
||
from joblib import Parallel, delayed | ||
from hnn_core import simulate_dipole | ||
from hnn_core.network_models import (jones_2009_model, | ||
|
@@ -32,7 +38,7 @@ def __init__(self, set_params, net_name='jones', tstop=170, | |
self.n_trials = n_trials | ||
|
||
def run(self, param_grid, max_size=None, return_output=True, | ||
save_output=False, fpath='./', n_jobs=1): | ||
save_output=False, combinations=True, fpath='./', n_jobs=1): | ||
""" | ||
Run batch simulations. | ||
|
@@ -46,29 +52,31 @@ def run(self, param_grid, max_size=None, return_output=True, | |
Whether to return the simulation outputs. Default is True. | ||
save_output : bool, optional | ||
Whether to save the outputs to disk. Default is False. | ||
combinations : bool, optional | ||
Whether to generate the Cartesian product of the parameter ranges. | ||
If False, generate combinations based on corresponding indices. | ||
Default is True. | ||
fpath : str, optional | ||
File path for saving outputs. Default is './'. | ||
n_jobs : int, optional | ||
Number of parallel jobs. Default is -1. | ||
Number of parallel jobs. Default is 1. | ||
Returns | ||
------- | ||
list | ||
results : list | ||
List of simulation results if return_output is True. | ||
""" | ||
param_combinations = self._generate_param_combinations( | ||
param_grid, max_size) | ||
# print("param_combinations-->",param_combinations) | ||
param_grid, max_size, combinations) | ||
results = self.simulate_batch(param_combinations, n_jobs=n_jobs) | ||
print(results) | ||
|
||
# if save_output: | ||
# self.save(results, param_combinations, fpath, max_size) | ||
|
||
if return_output: | ||
return results | ||
|
||
def simulate_batch(self, param_combinations, n_jobs=-1): | ||
def simulate_batch(self, param_combinations, n_jobs=1): | ||
""" | ||
Simulate a batch of parameter sets in parallel. | ||
|
@@ -77,12 +85,16 @@ def simulate_batch(self, param_combinations, n_jobs=-1): | |
param_combinations : list | ||
List of parameter combinations. | ||
n_jobs : int, optional | ||
Number of parallel jobs. Default is -1. | ||
Number of parallel jobs. Default is 1. | ||
Returns | ||
------- | ||
list | ||
List of simulation results. | ||
res: list | ||
List of dictionaries containing simulation results. | ||
Each dictionary contains the following keys: | ||
- `net`: The network model used for the simulation. | ||
- `dpl`: The simulated dipole. | ||
- `param_values`: The parameter values used for the simulation. | ||
""" | ||
res = Parallel(n_jobs=n_jobs, verbose=50)( | ||
delayed(self._run_single_sim)( | ||
|
@@ -100,8 +112,12 @@ def _run_single_sim(self, param_values): | |
Returns | ||
------- | ||
instance of Dipole | ||
The simulated dipole. | ||
dict | ||
Dictionary containing the simulation results. | ||
The dictionary contains the following keys: | ||
- `net`: The network model used for the simulation. | ||
- `dpl`: The simulated dipole. | ||
- `param_values`: The parameter values used for the simulation. | ||
""" | ||
if self.net_name not in ['jones', 'law', 'calcium']: | ||
raise ValueError( | ||
|
@@ -114,13 +130,14 @@ def _run_single_sim(self, param_values): | |
net = law_2021_model() | ||
elif self.net_name == 'calcium': | ||
net = calcium_model() | ||
|
||
print(param_values) | ||
self.set_params(param_values, net) | ||
dpl = simulate_dipole(net, tstop=self.tstop, dt=self.dt, | ||
n_trials=self.n_trials) | ||
return dpl | ||
return {'net': net, 'dpl': dpl, 'param_values': param_values} | ||
|
||
def _generate_param_combinations(self, param_grid, max_size=None): | ||
def _generate_param_combinations(self, param_grid, max_size=None, | ||
combinations=True): | ||
""" | ||
Generate combinations of parameters from the grid. | ||
|
@@ -130,19 +147,27 @@ def _generate_param_combinations(self, param_grid, max_size=None): | |
Dictionary with parameter names and ranges. | ||
max_size : int, optional | ||
Maximum size of the batch. Default is None. | ||
combinations : bool, optional | ||
Whether to generate the Cartesian product of the parameter ranges. | ||
If False, generate combinations based on corresponding indices. | ||
Default is True. | ||
Returns | ||
------- | ||
list | ||
param_combinations: list | ||
List of parameter combinations. | ||
""" | ||
from itertools import product | ||
|
||
keys, values = zip(*param_grid.items()) | ||
combinations = [dict(zip(keys, combination)) | ||
for combination in product(*values)] | ||
if combinations: | ||
param_combinations = [dict(zip(keys, combination)) | ||
for combination in product(*values)] | ||
else: | ||
param_combinations = [dict(zip(keys, combination)) | ||
for combination in zip(*values)] | ||
|
||
if max_size is not None: | ||
combinations = combinations[:max_size] | ||
# if max_size is not None: | ||
# param_combinations = param_combinations[:max_size] | ||
|
||
return combinations | ||
return param_combinations |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Authors: Abdul Samad Siddiqui <[email protected]> | ||
# Nick Tolley <[email protected]> | ||
# Ryan Thorpe <[email protected]> | ||
# Mainak Jas <[email protected]> | ||
|
||
import pytest | ||
import numpy as np | ||
from hnn_core import BatchSimulate | ||
|
||
|
||
@pytest.fixture | ||
def batch_simulate_instance(): | ||
def set_params(param_values, net): | ||
weights_ampa = {'L2_basket': param_values['weight_basket'], | ||
'L2_pyramidal': param_values['weight_pyr'], | ||
'L5_basket': param_values['weight_basket'], | ||
'L5_pyramidal': param_values['weight_pyr']} | ||
|
||
synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, | ||
'L5_basket': 1., 'L5_pyramidal': 1.} | ||
|
||
mu = param_values['mu'] | ||
sigma = param_values['sigma'] | ||
net.add_evoked_drive('evprox', | ||
mu=mu, | ||
sigma=sigma, | ||
numspikes=1, | ||
location='proximal', | ||
weights_ampa=weights_ampa, | ||
synaptic_delays=synaptic_delays) | ||
|
||
return BatchSimulate(set_params=set_params, tstop=10.) | ||
|
||
|
||
@pytest.fixture | ||
def param_grid(): | ||
return { | ||
'weight_basket': np.logspace(-4 - 1, 2), | ||
'weight_pyr': np.logspace(-4, -1, 2), | ||
'mu': np.linspace(20, 80, 2), | ||
'sigma': np.linspace(1, 20, 2) | ||
} | ||
|
||
|
||
def test_generate_param_combinations(batch_simulate_instance, param_grid): | ||
"""Test generating parameter combinations.""" | ||
param_combinations = batch_simulate_instance._generate_param_combinations( | ||
param_grid) | ||
assert len(param_combinations) == ( | ||
len(param_grid['weight_basket']) * | ||
len(param_grid['weight_pyr']) * | ||
len(param_grid['mu']) * | ||
len(param_grid['sigma']) | ||
) | ||
|
||
|
||
def test_run_single_sim(batch_simulate_instance): | ||
"""Test running a single simulation.""" | ||
param_values = { | ||
'weight_basket': -3, | ||
'weight_pyr': -2, | ||
'mu': 40, | ||
'sigma': 20 | ||
} | ||
result = batch_simulate_instance._run_single_sim(param_values) | ||
assert 'net' in result | ||
assert 'dpl' in result | ||
assert 'param_values' in result | ||
assert result['param_values'] == param_values | ||
|
||
|
||
def test_simulate_batch(batch_simulate_instance, param_grid): | ||
"""Test simulating a batch of parameter sets.""" | ||
param_combinations = batch_simulate_instance._generate_param_combinations( | ||
param_grid)[:3] | ||
results = batch_simulate_instance.simulate_batch(param_combinations, | ||
n_jobs=2) | ||
assert len(results) == len(param_combinations) | ||
for result in results: | ||
assert 'net' in result | ||
assert 'dpl' in result | ||
assert 'param_values' in result | ||
|
||
|
||
def test_run(batch_simulate_instance, param_grid): | ||
results = batch_simulate_instance.run(param_grid, n_jobs=2, | ||
return_output=True, | ||
combinations=False) | ||
assert results is not None | ||
assert isinstance(results, list) | ||
assert len(results) == len( | ||
batch_simulate_instance._generate_param_combinations( | ||
param_grid, combinations=False)) |