From 724d0014ece412779d50932792628cfd359b3cec Mon Sep 17 00:00:00 2001 From: Abdul Samad Siddiqui Date: Mon, 10 Jun 2024 12:06:52 +0500 Subject: [PATCH] Refactor BatchSimulation and added unit test Signed-off-by: samadpls --- examples/howto/batch_simulate.py | 26 +++--- hnn_core/__init__.py | 2 +- .../{batchsimulate.py => batch_simulate.py} | 67 ++++++++----- hnn_core/tests/test_batch_simulate.py | 93 +++++++++++++++++++ 4 files changed, 152 insertions(+), 36 deletions(-) rename hnn_core/{batchsimulate.py => batch_simulate.py} (64%) create mode 100644 hnn_core/tests/test_batch_simulate.py diff --git a/examples/howto/batch_simulate.py b/examples/howto/batch_simulate.py index 4323bcbb54..33d876e561 100644 --- a/examples/howto/batch_simulate.py +++ b/examples/howto/batch_simulate.py @@ -14,7 +14,6 @@ # Mainak Jas ############################################################################### - # Let us import ``hnn_core``. import numpy as np @@ -32,7 +31,7 @@ def set_params(param_values, net=None): Parameters ---------- - param_grid : dict + param_values : dict Dictionary of parameter values. net : instance of Network, optional If None, a new network is created using the specified model type. @@ -40,21 +39,18 @@ def set_params(param_values, net=None): if net is None: net = jones_2009_model() - weights_ampa = {'L2_basket': param_grid['weight_basket'], - 'L2_pyramidal': param_grid['weight_pyr'], - 'L5_basket': param_grid['weight_basket'], - 'L5_pyramidal': param_grid['weight_pyr']} + weights_ampa = {'L2_basket': param_values['weight_basket'], + 'L2_pyramidal': param_values['weight_pyr'], + 'L5_basket': param_values['weight_basket'], + 'L5_pyramidal': param_values['weight_pyr']} synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, 'L5_basket': 1., 'L5_pyramidal': 1.} - mu = param_grid['mu'] - sigma = param_grid['sigma'] - - # Add an evoked drive to the network + # Add an evoked drive to the network. net.add_evoked_drive('evprox', - mu=mu, - sigma=sigma, + mu=param_values['mu'], + sigma=param_values['sigma'], numspikes=1, location='proximal', weights_ampa=weights_ampa, @@ -64,11 +60,13 @@ def set_params(param_values, net=None): param_grid = { - 'weight_basket': np.logspace(-4 -1, 5), + 'weight_basket': np.logspace(-4 - 1, 5), 'weight_pyr': np.logspace(-4, -1, 5), 'mu': np.linspace(20, 80, 5), 'sigma': np.linspace(1, 20, 5) } batch_simulation = BatchSimulate(set_params=set_params) -simulation_results = batch_simulation.run(param_grid, n_jobs=n_jobs) +simulation_results = batch_simulation.run(param_grid, + n_jobs=n_jobs, + combinations=False) diff --git a/hnn_core/__init__.py b/hnn_core/__init__.py index beea9be1f3..ab6c961e3c 100644 --- a/hnn_core/__init__.py +++ b/hnn_core/__init__.py @@ -6,6 +6,6 @@ from .cell_response import CellResponse, read_spikes from .cells_default import pyramidal, basket from .parallel_backends import MPIBackend, JoblibBackend -from .batchsimulate import BatchSimulate +from .batch_simulate import BatchSimulate __version__ = '0.4.dev0' diff --git a/hnn_core/batchsimulate.py b/hnn_core/batch_simulate.py similarity index 64% rename from hnn_core/batchsimulate.py rename to hnn_core/batch_simulate.py index 4cc0edddf7..2b35064f92 100644 --- a/hnn_core/batchsimulate.py +++ b/hnn_core/batch_simulate.py @@ -1,4 +1,10 @@ """Batch simulation.""" + +# Authors: Abdul Samad Siddiqui +# Nick Tolley +# Ryan Thorpe +# Mainak Jas + from joblib import Parallel, delayed from hnn_core import simulate_dipole from hnn_core.network_models import (jones_2009_model, @@ -32,7 +38,7 @@ def __init__(self, set_params, net_name='jones', tstop=170, self.n_trials = n_trials def run(self, param_grid, max_size=None, return_output=True, - save_output=False, fpath='./', n_jobs=1): + save_output=False, combinations=True, fpath='./', n_jobs=1): """ Run batch simulations. @@ -46,21 +52,23 @@ def run(self, param_grid, max_size=None, return_output=True, Whether to return the simulation outputs. Default is True. save_output : bool, optional Whether to save the outputs to disk. Default is False. + combinations : bool, optional + Whether to generate the Cartesian product of the parameter ranges. + If False, generate combinations based on corresponding indices. + Default is True. fpath : str, optional File path for saving outputs. Default is './'. n_jobs : int, optional - Number of parallel jobs. Default is -1. + Number of parallel jobs. Default is 1. Returns ------- - list + results : list List of simulation results if return_output is True. """ param_combinations = self._generate_param_combinations( - param_grid, max_size) - # print("param_combinations-->",param_combinations) + param_grid, max_size, combinations) results = self.simulate_batch(param_combinations, n_jobs=n_jobs) - print(results) # if save_output: # self.save(results, param_combinations, fpath, max_size) @@ -68,7 +76,7 @@ def run(self, param_grid, max_size=None, return_output=True, if return_output: return results - def simulate_batch(self, param_combinations, n_jobs=-1): + def simulate_batch(self, param_combinations, n_jobs=1): """ Simulate a batch of parameter sets in parallel. @@ -77,12 +85,16 @@ def simulate_batch(self, param_combinations, n_jobs=-1): param_combinations : list List of parameter combinations. n_jobs : int, optional - Number of parallel jobs. Default is -1. + Number of parallel jobs. Default is 1. Returns ------- - list - List of simulation results. + res: list + List of dictionaries containing simulation results. + Each dictionary contains the following keys: + - `net`: The network model used for the simulation. + - `dpl`: The simulated dipole. + - `param_values`: The parameter values used for the simulation. """ res = Parallel(n_jobs=n_jobs, verbose=50)( delayed(self._run_single_sim)( @@ -100,8 +112,12 @@ def _run_single_sim(self, param_values): Returns ------- - instance of Dipole - The simulated dipole. + dict + Dictionary containing the simulation results. + The dictionary contains the following keys: + - `net`: The network model used for the simulation. + - `dpl`: The simulated dipole. + - `param_values`: The parameter values used for the simulation. """ if self.net_name not in ['jones', 'law', 'calcium']: raise ValueError( @@ -114,13 +130,14 @@ def _run_single_sim(self, param_values): net = law_2021_model() elif self.net_name == 'calcium': net = calcium_model() - + print(param_values) self.set_params(param_values, net) dpl = simulate_dipole(net, tstop=self.tstop, dt=self.dt, n_trials=self.n_trials) - return dpl + return {'net': net, 'dpl': dpl, 'param_values': param_values} - def _generate_param_combinations(self, param_grid, max_size=None): + def _generate_param_combinations(self, param_grid, max_size=None, + combinations=True): """ Generate combinations of parameters from the grid. @@ -130,19 +147,27 @@ def _generate_param_combinations(self, param_grid, max_size=None): Dictionary with parameter names and ranges. max_size : int, optional Maximum size of the batch. Default is None. + combinations : bool, optional + Whether to generate the Cartesian product of the parameter ranges. + If False, generate combinations based on corresponding indices. + Default is True. Returns ------- - list + param_combinations: list List of parameter combinations. """ from itertools import product keys, values = zip(*param_grid.items()) - combinations = [dict(zip(keys, combination)) - for combination in product(*values)] + if combinations: + param_combinations = [dict(zip(keys, combination)) + for combination in product(*values)] + else: + param_combinations = [dict(zip(keys, combination)) + for combination in zip(*values)] - if max_size is not None: - combinations = combinations[:max_size] + # if max_size is not None: + # param_combinations = param_combinations[:max_size] - return combinations + return param_combinations diff --git a/hnn_core/tests/test_batch_simulate.py b/hnn_core/tests/test_batch_simulate.py new file mode 100644 index 0000000000..01b502a6d4 --- /dev/null +++ b/hnn_core/tests/test_batch_simulate.py @@ -0,0 +1,93 @@ +# Authors: Abdul Samad Siddiqui +# Nick Tolley +# Ryan Thorpe +# Mainak Jas + +import pytest +import numpy as np +from hnn_core import BatchSimulate + + +@pytest.fixture +def batch_simulate_instance(): + def set_params(param_values, net): + weights_ampa = {'L2_basket': param_values['weight_basket'], + 'L2_pyramidal': param_values['weight_pyr'], + 'L5_basket': param_values['weight_basket'], + 'L5_pyramidal': param_values['weight_pyr']} + + synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1, + 'L5_basket': 1., 'L5_pyramidal': 1.} + + mu = param_values['mu'] + sigma = param_values['sigma'] + net.add_evoked_drive('evprox', + mu=mu, + sigma=sigma, + numspikes=1, + location='proximal', + weights_ampa=weights_ampa, + synaptic_delays=synaptic_delays) + + return BatchSimulate(set_params=set_params, tstop=10.) + + +@pytest.fixture +def param_grid(): + return { + 'weight_basket': np.logspace(-4 - 1, 2), + 'weight_pyr': np.logspace(-4, -1, 2), + 'mu': np.linspace(20, 80, 2), + 'sigma': np.linspace(1, 20, 2) + } + + +def test_generate_param_combinations(batch_simulate_instance, param_grid): + """Test generating parameter combinations.""" + param_combinations = batch_simulate_instance._generate_param_combinations( + param_grid) + assert len(param_combinations) == ( + len(param_grid['weight_basket']) * + len(param_grid['weight_pyr']) * + len(param_grid['mu']) * + len(param_grid['sigma']) + ) + + +def test_run_single_sim(batch_simulate_instance): + """Test running a single simulation.""" + param_values = { + 'weight_basket': -3, + 'weight_pyr': -2, + 'mu': 40, + 'sigma': 20 + } + result = batch_simulate_instance._run_single_sim(param_values) + assert 'net' in result + assert 'dpl' in result + assert 'param_values' in result + assert result['param_values'] == param_values + + +def test_simulate_batch(batch_simulate_instance, param_grid): + """Test simulating a batch of parameter sets.""" + param_combinations = batch_simulate_instance._generate_param_combinations( + param_grid)[:3] + results = batch_simulate_instance.simulate_batch(param_combinations, + n_jobs=2) + assert len(results) == len(param_combinations) + for result in results: + assert 'net' in result + assert 'dpl' in result + assert 'param_values' in result + + +def test_run(batch_simulate_instance, param_grid): + results = batch_simulate_instance.run(param_grid, n_jobs=2, + return_output=True, + combinations=False) + assert results is not None + assert isinstance(results, list) + assert len(results) == len( + batch_simulate_instance._generate_param_combinations( + param_grid, combinations=False))