forked from jonescompneurolab/hnn-core
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Mainak Jas <[email protected]>
- Loading branch information
Showing
6 changed files
with
304 additions
and
168 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
""" | ||
================================================= | ||
==================== | ||
06. Batch Simulation | ||
================================================= | ||
==================== | ||
This example shows how to do batch simulations in HNN-core, allowing users to | ||
efficiently run multiple simulations with different parameters | ||
|
@@ -13,17 +13,17 @@ | |
# Ryan Thorpe <[email protected]> | ||
# Mainak Jas <[email protected]> | ||
|
||
############################################################################### | ||
|
||
############################################################################# | ||
# Let us import ``hnn_core``. | ||
|
||
import hnn_core | ||
import numpy as np | ||
from hnn_core import BatchSimulate | ||
from hnn_core.network_models import jones_2009_model | ||
|
||
# The number of cores may need modifying depending on your current machine. | ||
n_jobs = 10 | ||
############################################################################### | ||
########################################################################### | ||
|
||
|
||
def set_params(param_values, net=None): | ||
|
@@ -32,43 +32,42 @@ def set_params(param_values, net=None): | |
Parameters | ||
---------- | ||
param_grid : dict | ||
param_values : dict | ||
Dictionary of parameter values. | ||
net : instance of Network, optional | ||
If None, a new network is created using the specified model type. | ||
""" | ||
if net is None: | ||
net = jones_2009_model() | ||
|
||
weights_ampa = {'L2_basket': param_grid['weight_basket'], | ||
'L2_pyramidal': param_grid['weight_pyr'], | ||
'L5_basket': param_grid['weight_basket'], | ||
'L5_pyramidal': param_grid['weight_pyr']} | ||
weights_ampa = {'L2_basket': param_values['weight_basket'], | ||
'L2_pyramidal': param_values['weight_pyr'], | ||
'L5_basket': param_values['weight_basket'], | ||
'L5_pyramidal': param_values['weight_pyr']} | ||
|
||
synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, | ||
'L5_basket': 1., 'L5_pyramidal': 1.} | ||
|
||
mu = param_grid['mu'] | ||
sigma = param_grid['sigma'] | ||
|
||
# Add an evoked drive to the network | ||
# Add an evoked drive to the network. | ||
net.add_evoked_drive('evprox', | ||
mu=mu, | ||
sigma=sigma, | ||
mu=param_values['mu'], | ||
sigma=param_values['sigma'], | ||
numspikes=1, | ||
location='proximal', | ||
weights_ampa=weights_ampa, | ||
synaptic_delays=synaptic_delays) | ||
|
||
############################################################################### | ||
########################################################################### | ||
|
||
|
||
param_grid = { | ||
'weight_basket': np.logspace(-4 -1, 5), | ||
'weight_basket': np.logspace(-4 - 1, 5), | ||
'weight_pyr': np.logspace(-4, -1, 5), | ||
'mu': np.linspace(20, 80, 5), | ||
'sigma': np.linspace(1, 20, 5) | ||
} | ||
|
||
batch_simulation = BatchSimulate(set_params=set_params) | ||
simulation_results = batch_simulation.run(param_grid, n_jobs=n_jobs) | ||
simulation_results = batch_simulation.run(param_grid, | ||
n_jobs=n_jobs, | ||
combinations=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
"""Batch simulation.""" | ||
|
||
# Authors: Abdul Samad Siddiqui <[email protected]> | ||
# Nick Tolley <[email protected]> | ||
# Ryan Thorpe <[email protected]> | ||
# Mainak Jas <[email protected]> | ||
|
||
from joblib import Parallel, delayed | ||
from hnn_core import simulate_dipole | ||
from hnn_core.network_models import (jones_2009_model, | ||
calcium_model, law_2021_model) | ||
|
||
|
||
class BatchSimulate: | ||
def __init__(self, set_params, net_name='jones', tstop=170, | ||
dt=0.025, n_trials=1, record_vsec=False, | ||
record_isec=False, postproc=False): | ||
""" | ||
Initialize the BatchSimulate class. | ||
Parameters | ||
---------- | ||
set_params : func | ||
User-defined function that sets parameters in network drives. | ||
`set_params(net, params) -> None` | ||
net_name : str | ||
The name of the network model to use. Default is `jones`. | ||
tstop : float, optional | ||
The stop time for the simulation. Default is 170 ms. | ||
dt : float, optional | ||
The time step for the simulation. Default is 0.025 ms. | ||
n_trials : int, optional | ||
The number of trials for the simulation. Default is 1. | ||
record_vsec : 'all' | 'soma' | False | ||
Option to record voltages from all sections ('all'), or just | ||
the soma ('soma'). Default: False. | ||
record_isec : 'all' | 'soma' | False | ||
Option to record voltages from all sections ('all'), or just | ||
the soma ('soma'). Default: False. | ||
postproc : bool | ||
If True, smoothing (``dipole_smooth_win``) and scaling | ||
(``dipole_scalefctr``) values are read from the parameter file, and | ||
applied to the dipole objects before returning. | ||
Note that this setting | ||
only affects the dipole waveforms, and not somatic voltages, | ||
possible extracellular recordings etc. | ||
The preferred way is to use the | ||
:meth:`~hnn_core.dipole.Dipole.smooth` and | ||
:meth:`~hnn_core.dipole.Dipole.scale` methods instead. | ||
Default: False. | ||
""" | ||
self.set_params = set_params | ||
self.net_name = net_name | ||
self.tstop = tstop | ||
self.dt = dt | ||
self.n_trials = n_trials | ||
self.record_vsec = record_vsec | ||
self.record_isec = record_isec | ||
self.postproc = postproc | ||
|
||
def run(self, param_grid, return_output=True, combinations=True, n_jobs=1): | ||
""" | ||
Run batch simulations. | ||
Parameters | ||
---------- | ||
param_grid : dict | ||
Dictionary with parameter names and ranges. | ||
return_output : bool, optional | ||
Whether to return the simulation outputs. Default is True. | ||
combinations : bool, optional | ||
Whether to generate the Cartesian product of the parameter ranges. | ||
If False, generate combinations based on corresponding indices. | ||
Default is True. | ||
n_jobs : int, optional | ||
Number of parallel jobs. Default is 1. | ||
Returns | ||
------- | ||
results : list | ||
List of simulation results if return_output is True. | ||
""" | ||
param_combinations = self._generate_param_combinations( | ||
param_grid, combinations) | ||
results = self.simulate_batch(param_combinations, n_jobs=n_jobs) | ||
|
||
if return_output: | ||
return results | ||
|
||
def simulate_batch(self, param_combinations, n_jobs=1): | ||
""" | ||
Simulate a batch of parameter sets in parallel. | ||
Parameters | ||
---------- | ||
param_combinations : list | ||
List of parameter combinations. | ||
n_jobs : int, optional | ||
Number of parallel jobs. Default is 1. | ||
Returns | ||
------- | ||
res: list | ||
List of dictionaries containing simulation results. | ||
Each dictionary contains the following keys: | ||
- `net`: The network model used for the simulation. | ||
- `dpl`: The simulated dipole. | ||
- `param_values`: The parameter values used for the simulation. | ||
""" | ||
res = Parallel(n_jobs=n_jobs, verbose=50)( | ||
delayed(self._run_single_sim)( | ||
params) for params in param_combinations) | ||
return res | ||
|
||
def _run_single_sim(self, param_values): | ||
""" | ||
Run a single simulation. | ||
Parameters | ||
---------- | ||
param_values : dict | ||
Dictionary of parameter values. | ||
Returns | ||
------- | ||
dict | ||
Dictionary containing the simulation results. | ||
The dictionary contains the following keys: | ||
- `net`: The network model used for the simulation. | ||
- `dpl`: The simulated dipole. | ||
- `param_values`: The parameter values used for the simulation. | ||
""" | ||
if self.net_name not in ['jones', 'law', 'calcium']: | ||
raise ValueError( | ||
f"Unknown network model: {self.net_name}. " | ||
"Valid options are 'jones', 'law', and 'calcium'." | ||
) | ||
elif self.net_name == 'jones': | ||
net = jones_2009_model() | ||
elif self.net_name == 'law': | ||
net = law_2021_model() | ||
elif self.net_name == 'calcium': | ||
net = calcium_model() | ||
print(param_values) | ||
self.set_params(param_values, net) | ||
dpl = simulate_dipole(net, | ||
tstop=self.tstop, | ||
dt=self.dt, | ||
n_trials=self.n_trials, | ||
record_vsec=self.record_vsec, | ||
record_isec=self.record_isec, | ||
postproc=self.postproc) | ||
|
||
return {'net': net, 'dpl': dpl, 'param_values': param_values} | ||
|
||
def _generate_param_combinations(self, param_grid, combinations=True): | ||
""" | ||
Generate combinations of parameters from the grid. | ||
Parameters | ||
---------- | ||
param_grid : dict | ||
Dictionary with parameter names and ranges. | ||
combinations : bool, optional | ||
Whether to generate the Cartesian product of the parameter ranges. | ||
If False, generate combinations based on corresponding indices. | ||
Default is True. | ||
Returns | ||
------- | ||
param_combinations: list | ||
List of parameter combinations. | ||
""" | ||
from itertools import product | ||
|
||
keys, values = zip(*param_grid.items()) | ||
if combinations: | ||
param_combinations = [dict(zip(keys, combination)) | ||
for combination in product(*values)] | ||
else: | ||
param_combinations = [dict(zip(keys, combination)) | ||
for combination in zip(*values)] | ||
return param_combinations |
Oops, something went wrong.