forked from jonescompneurolab/hnn-core
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: samadpls <[email protected]>
- Loading branch information
Showing
2 changed files
with
41 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
# Ryan Thorpe <[email protected]> | ||
# Mainak Jas <[email protected]> | ||
|
||
import json | ||
import numpy as np | ||
import os | ||
from joblib import Parallel, delayed, parallel_config | ||
|
@@ -13,6 +14,7 @@ | |
from .externals.mne import _validate_type, _check_option | ||
from .dipole import simulate_dipole | ||
from .network_models import jones_2009_model | ||
from .hnn_io import dict_to_network | ||
|
||
|
||
class BatchSimulate(object): | ||
|
@@ -24,7 +26,7 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, | |
save_dpl=True, save_spiking=False, | ||
save_lfp=False, save_voltages=False, | ||
save_currents=False, save_calcium=False, | ||
clear_cache=False): | ||
clear_cache=False, net_json=None): | ||
"""Initialize the BatchSimulate class. | ||
Parameters | ||
|
@@ -100,6 +102,9 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, | |
clear_cache : bool, optional | ||
Whether to clear the results cache after saving each batch. | ||
Default is False. | ||
net_json : str, optional | ||
The path to a JSON file to create the network model. If provided, | ||
this will override the `net` parameter. Default is None. | ||
Notes | ||
----- | ||
When `save_output=True`, the saved files will appear as | ||
|
@@ -127,6 +132,8 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, | |
_validate_type(save_currents, types=(bool,), item_name='save_currents') | ||
_validate_type(save_calcium, types=(bool,), item_name='save_calcium') | ||
_validate_type(clear_cache, types=(bool,), item_name='clear_cache') | ||
_validate_type(net_json, types=('path-like', None), | ||
item_name='net_json') | ||
|
||
if set_params is not None and not callable(set_params): | ||
raise TypeError("set_params must be a callable function") | ||
|
@@ -154,6 +161,7 @@ def __init__(self, set_params, net=jones_2009_model(), tstop=170, | |
self.save_currents = save_currents | ||
self.save_calcium = save_calcium | ||
self.clear_cache = clear_cache | ||
self.net_json = net_json | ||
|
||
def run(self, param_grid, return_output=True, | ||
combinations=True, n_jobs=1, backend='loky', | ||
|
@@ -295,7 +303,14 @@ def _run_single_sim(self, param_values): | |
- `param_values`: The parameter values used for the simulation. | ||
""" | ||
|
||
net = self.net.copy() | ||
if isinstance(self.net_json, str): | ||
with open(self.net_json, 'r') as file: | ||
net_data = json.load(file) | ||
net = dict_to_network(net_data) | ||
else: | ||
net = self.net | ||
net = net.copy() | ||
|
||
self.set_params(param_values, net) | ||
|
||
results = {'net': net, 'param_values': param_values} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,13 +3,17 @@ | |
# Ryan Thorpe <[email protected]> | ||
# Mainak Jas <[email protected]> | ||
|
||
from pathlib import Path | ||
import pytest | ||
import numpy as np | ||
import os | ||
|
||
from hnn_core.batch_simulate import BatchSimulate | ||
from hnn_core import jones_2009_model | ||
|
||
hnn_core_root = Path(__file__).parents[1] | ||
assets_path = Path(hnn_core_root, 'tests', 'assets') | ||
|
||
|
||
@pytest.fixture | ||
def batch_simulate_instance(tmp_path): | ||
|
@@ -33,9 +37,9 @@ def set_params(param_values, net): | |
weights_ampa=weights_ampa, | ||
synaptic_delays=synaptic_delays) | ||
|
||
net = jones_2009_model() | ||
net = jones_2009_model(mesh_shape=(3, 3)) | ||
return BatchSimulate(net=net, set_params=set_params, | ||
tstop=1., | ||
tstop=10, | ||
save_folder=tmp_path, | ||
batch_size=3) | ||
|
||
|
@@ -75,6 +79,9 @@ def test_parameter_validation(): | |
with pytest.raises(TypeError, match="net must be"): | ||
BatchSimulate(net="invalid_network", set_params=lambda x: x) | ||
|
||
with pytest.raises(TypeError, match="net_json must be"): | ||
BatchSimulate(net_json=123, set_params=lambda x: x) | ||
|
||
|
||
def test_generate_param_combinations(batch_simulate_instance, param_grid): | ||
"""Test generating parameter combinations.""" | ||
|
@@ -104,6 +111,21 @@ def test_run_single_sim(batch_simulate_instance): | |
assert isinstance(result['net'], type(batch_simulate_instance.net)) | ||
|
||
|
||
def test_net_json_loading(param_grid): | ||
"""Test loading the network from a JSON file.""" | ||
json_path = assets_path / 'jones2009_3x3_drives.json' | ||
|
||
batch_simulate = BatchSimulate(net_json=str(json_path), | ||
set_params=lambda x, y: x, | ||
tstop=70) | ||
|
||
result = batch_simulate._run_single_sim(param_grid) | ||
assert isinstance(result, dict) | ||
assert 'net' in result | ||
assert 'param_values' in result | ||
assert 'dpl' in result | ||
|
||
|
||
def test_simulate_batch(batch_simulate_instance, param_grid): | ||
"""Test simulating a batch of parameter sets.""" | ||
param_combinations = batch_simulate_instance._generate_param_combinations( | ||
|