forked from jonescompneurolab/hnn-core
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Mainak Jas <[email protected]> Signed-off-by: samadpls <[email protected]>
- Loading branch information
Showing
5 changed files
with
162 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,8 @@ env/ | |
venv/ | ||
ENV/ | ||
VENV/ | ||
|
||
hnn_core/mod/* | ||
test.* | ||
# Sphinx documentation | ||
doc/_build/** | ||
pip-log.txt | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
""" | ||
================================================= | ||
==================== | ||
06. Batch Simulation | ||
================================================= | ||
==================== | ||
This example shows how to do batch simulations in HNN-core, allowing users to | ||
efficiently run multiple simulations with different parameters | ||
|
@@ -14,7 +14,6 @@ | |
# Mainak Jas <[email protected]> | ||
|
||
############################################################################### | ||
|
||
# Let us import ``hnn_core``. | ||
|
||
import numpy as np | ||
|
@@ -32,29 +31,26 @@ def set_params(param_values, net=None): | |
Parameters | ||
---------- | ||
param_grid : dict | ||
param_values : dict | ||
Dictionary of parameter values. | ||
net : instance of Network, optional | ||
If None, a new network is created using the specified model type. | ||
""" | ||
if net is None: | ||
net = jones_2009_model() | ||
|
||
weights_ampa = {'L2_basket': param_grid['weight_basket'], | ||
'L2_pyramidal': param_grid['weight_pyr'], | ||
'L5_basket': param_grid['weight_basket'], | ||
'L5_pyramidal': param_grid['weight_pyr']} | ||
weights_ampa = {'L2_basket': param_values['weight_basket'], | ||
'L2_pyramidal': param_values['weight_pyr'], | ||
'L5_basket': param_values['weight_basket'], | ||
'L5_pyramidal': param_values['weight_pyr']} | ||
|
||
synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, | ||
'L5_basket': 1., 'L5_pyramidal': 1.} | ||
|
||
mu = param_grid['mu'] | ||
sigma = param_grid['sigma'] | ||
|
||
# Add an evoked drive to the network | ||
# Add an evoked drive to the network. | ||
net.add_evoked_drive('evprox', | ||
mu=mu, | ||
sigma=sigma, | ||
mu=param_values['mu'], | ||
sigma=param_values['sigma'], | ||
numspikes=1, | ||
location='proximal', | ||
weights_ampa=weights_ampa, | ||
|
@@ -64,11 +60,13 @@ def set_params(param_values, net=None): | |
|
||
|
||
param_grid = { | ||
'weight_basket': np.logspace(-4 -1, 5), | ||
'weight_basket': np.logspace(-4 - 1, 5), | ||
'weight_pyr': np.logspace(-4, -1, 5), | ||
'mu': np.linspace(20, 80, 5), | ||
'sigma': np.linspace(1, 20, 5) | ||
} | ||
|
||
batch_simulation = BatchSimulate(set_params=set_params) | ||
simulation_results = batch_simulation.run(param_grid, n_jobs=n_jobs) | ||
simulation_results = batch_simulation.run(param_grid, | ||
n_jobs=n_jobs, | ||
combinations=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,10 @@ | ||
"""Batch simulation.""" | ||
|
||
# Authors: Abdul Samad Siddiqui <[email protected]> | ||
# Nick Tolley <[email protected]> | ||
# Ryan Thorpe <[email protected]> | ||
# Mainak Jas <[email protected]> | ||
|
||
from joblib import Parallel, delayed | ||
from hnn_core import simulate_dipole | ||
from hnn_core.network_models import (jones_2009_model, | ||
|
@@ -16,8 +22,8 @@ def __init__(self, set_params, net_name='jones', tstop=170, | |
set_params : func | ||
User-defined function that sets parameters in network drives. | ||
`set_params(net, params) -> None` | ||
net_name : str, optional | ||
The name of the network model to use. Default is 'jones'. | ||
net_name : str | ||
The name of the network model to use. Default is `jones`. | ||
tstop : float, optional | ||
The stop time for the simulation. Default is 170 ms. | ||
dt : float, optional | ||
|
@@ -31,44 +37,36 @@ def __init__(self, set_params, net_name='jones', tstop=170, | |
self.dt = dt | ||
self.n_trials = n_trials | ||
|
||
def run(self, param_grid, max_size=None, return_output=True, | ||
save_output=False, fpath='./', n_jobs=1): | ||
def run(self, param_grid, return_output=True, combinations=True, n_jobs=1): | ||
""" | ||
Run batch simulations. | ||
Parameters | ||
---------- | ||
param_grid : dict | ||
Dictionary with parameter names and ranges. | ||
max_size : int, optional | ||
Maximum size of the batch. Default is None. | ||
return_output : bool, optional | ||
Whether to return the simulation outputs. Default is True. | ||
save_output : bool, optional | ||
Whether to save the outputs to disk. Default is False. | ||
fpath : str, optional | ||
File path for saving outputs. Default is './'. | ||
combinations : bool, optional | ||
Whether to generate the Cartesian product of the parameter ranges. | ||
If False, generate combinations based on corresponding indices. | ||
Default is True. | ||
n_jobs : int, optional | ||
Number of parallel jobs. Default is -1. | ||
Number of parallel jobs. Default is 1. | ||
Returns | ||
------- | ||
list | ||
results : list | ||
List of simulation results if return_output is True. | ||
""" | ||
param_combinations = self._generate_param_combinations( | ||
param_grid, max_size) | ||
# print("param_combinations-->",param_combinations) | ||
param_grid, combinations) | ||
results = self.simulate_batch(param_combinations, n_jobs=n_jobs) | ||
print(results) | ||
|
||
# if save_output: | ||
# self.save(results, param_combinations, fpath, max_size) | ||
|
||
if return_output: | ||
return results | ||
|
||
def simulate_batch(self, param_combinations, n_jobs=-1): | ||
def simulate_batch(self, param_combinations, n_jobs=1): | ||
""" | ||
Simulate a batch of parameter sets in parallel. | ||
|
@@ -77,12 +75,16 @@ def simulate_batch(self, param_combinations, n_jobs=-1): | |
param_combinations : list | ||
List of parameter combinations. | ||
n_jobs : int, optional | ||
Number of parallel jobs. Default is -1. | ||
Number of parallel jobs. Default is 1. | ||
Returns | ||
------- | ||
list | ||
List of simulation results. | ||
res: list | ||
List of dictionaries containing simulation results. | ||
Each dictionary contains the following keys: | ||
- `net`: The network model used for the simulation. | ||
- `dpl`: The simulated dipole. | ||
- `param_values`: The parameter values used for the simulation. | ||
""" | ||
res = Parallel(n_jobs=n_jobs, verbose=50)( | ||
delayed(self._run_single_sim)( | ||
|
@@ -100,8 +102,12 @@ def _run_single_sim(self, param_values): | |
Returns | ||
------- | ||
instance of Dipole | ||
The simulated dipole. | ||
dict | ||
Dictionary containing the simulation results. | ||
The dictionary contains the following keys: | ||
- `net`: The network model used for the simulation. | ||
- `dpl`: The simulated dipole. | ||
- `param_values`: The parameter values used for the simulation. | ||
""" | ||
if self.net_name not in ['jones', 'law', 'calcium']: | ||
raise ValueError( | ||
|
@@ -114,35 +120,38 @@ def _run_single_sim(self, param_values): | |
net = law_2021_model() | ||
elif self.net_name == 'calcium': | ||
net = calcium_model() | ||
|
||
print(param_values) | ||
self.set_params(param_values, net) | ||
dpl = simulate_dipole(net, tstop=self.tstop, dt=self.dt, | ||
n_trials=self.n_trials) | ||
return dpl | ||
|
||
def _generate_param_combinations(self, param_grid, max_size=None): | ||
return {'net': net, 'dpl': dpl, 'param_values': param_values} | ||
|
||
def _generate_param_combinations(self, param_grid, combinations=True): | ||
""" | ||
Generate combinations of parameters from the grid. | ||
Parameters | ||
---------- | ||
param_grid : dict | ||
Dictionary with parameter names and ranges. | ||
max_size : int, optional | ||
Maximum size of the batch. Default is None. | ||
combinations : bool, optional | ||
Whether to generate the Cartesian product of the parameter ranges. | ||
If False, generate combinations based on corresponding indices. | ||
Default is True. | ||
Returns | ||
------- | ||
list | ||
param_combinations: list | ||
List of parameter combinations. | ||
""" | ||
from itertools import product | ||
|
||
keys, values = zip(*param_grid.items()) | ||
combinations = [dict(zip(keys, combination)) | ||
for combination in product(*values)] | ||
|
||
if max_size is not None: | ||
combinations = combinations[:max_size] | ||
|
||
return combinations | ||
if combinations: | ||
param_combinations = [dict(zip(keys, combination)) | ||
for combination in product(*values)] | ||
else: | ||
param_combinations = [dict(zip(keys, combination)) | ||
for combination in zip(*values)] | ||
return param_combinations |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Authors: Abdul Samad Siddiqui <[email protected]> | ||
# Nick Tolley <[email protected]> | ||
# Ryan Thorpe <[email protected]> | ||
# Mainak Jas <[email protected]> | ||
|
||
import pytest | ||
import numpy as np | ||
|
||
from hnn_core import BatchSimulate | ||
|
||
|
||
@pytest.fixture | ||
def batch_simulate_instance(): | ||
"""Fixture for creating a BatchSimulate instance with custom parameters.""" | ||
def set_params(param_values, net): | ||
weights_ampa = {'L2_basket': param_values['weight_basket'], | ||
'L2_pyramidal': param_values['weight_pyr'], | ||
'L5_basket': param_values['weight_basket'], | ||
'L5_pyramidal': param_values['weight_pyr']} | ||
|
||
synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, | ||
'L5_basket': 1., 'L5_pyramidal': 1.} | ||
|
||
mu = param_values['mu'] | ||
sigma = param_values['sigma'] | ||
net.add_evoked_drive('evprox', | ||
mu=mu, | ||
sigma=sigma, | ||
numspikes=1, | ||
location='proximal', | ||
weights_ampa=weights_ampa, | ||
synaptic_delays=synaptic_delays) | ||
|
||
return BatchSimulate(set_params=set_params, tstop=10.) | ||
|
||
|
||
@pytest.fixture | ||
def param_grid(): | ||
"""Returns a dictionary representing a parameter grid for | ||
batch simulation.""" | ||
return { | ||
'weight_basket': np.logspace(-4 - 1, 2), | ||
'weight_pyr': np.logspace(-4, -1, 2), | ||
'mu': np.linspace(20, 80, 2), | ||
'sigma': np.linspace(1, 20, 2) | ||
} | ||
|
||
|
||
def test_generate_param_combinations(batch_simulate_instance, param_grid): | ||
"""Test generating parameter combinations.""" | ||
param_combinations = batch_simulate_instance._generate_param_combinations( | ||
param_grid) | ||
assert len(param_combinations) == ( | ||
len(param_grid['weight_basket']) * | ||
len(param_grid['weight_pyr']) * | ||
len(param_grid['mu']) * | ||
len(param_grid['sigma']) | ||
) | ||
|
||
|
||
def test_run_single_sim(batch_simulate_instance): | ||
"""Test running a single simulation.""" | ||
param_values = { | ||
'weight_basket': -3, | ||
'weight_pyr': -2, | ||
'mu': 40, | ||
'sigma': 20 | ||
} | ||
result = batch_simulate_instance._run_single_sim(param_values) | ||
assert 'net' in result | ||
assert 'dpl' in result | ||
assert 'param_values' in result | ||
assert result['param_values'] == param_values | ||
|
||
|
||
def test_simulate_batch(batch_simulate_instance, param_grid): | ||
"""Test simulating a batch of parameter sets.""" | ||
param_combinations = batch_simulate_instance._generate_param_combinations( | ||
param_grid)[:3] | ||
results = batch_simulate_instance.simulate_batch(param_combinations, | ||
n_jobs=2) | ||
assert len(results) == len(param_combinations) | ||
for result in results: | ||
assert 'net' in result | ||
assert 'dpl' in result | ||
assert 'param_values' in result | ||
|
||
|
||
def test_run(batch_simulate_instance, param_grid): | ||
"""Test the run method of the batch_simulate_instance.""" | ||
results = batch_simulate_instance.run(param_grid, n_jobs=2, | ||
return_output=True, | ||
combinations=False) | ||
|
||
assert results is not None | ||
assert isinstance(results, list) | ||
assert len(results) == len( | ||
batch_simulate_instance._generate_param_combinations( | ||
param_grid, combinations=False)) |