diff --git a/hnn_core/tests/test_params.py b/hnn_core/tests/test_params.py index bc272cde03..5923cb3794 100644 --- a/hnn_core/tests/test_params.py +++ b/hnn_core/tests/test_params.py @@ -8,9 +8,11 @@ import pytest -from hnn_core import (read_params, Params, jones_2009_model, convert_to_json, +from hnn_core import (read_params, Params, convert_to_json, Network) from hnn_core.hnn_io import read_network_configuration +from hnn_core.network_models import (jones_2009_model, law_2021_model, + calcium_model) hnn_core_root = Path(__file__).parents[1] @@ -77,80 +79,140 @@ def test_base_params(): assert params == params_base -def test_convert_to_json(tmp_path): - """Tests conversion of a flat json file to hierarchical json""" - # Download params - param_url = ('https://raw.githubusercontent.com/hnn-core/' - 'hnn_core/param/default.json') - params_base_fname = Path(hnn_core_root, 'param', 'default.json') - if not op.exists(params_base_fname): - urlretrieve(param_url, params_base_fname) - net_params = Network(read_params(params_base_fname), - add_drives_from_params=True, - ) - - # Write hdf5 and check if constructed network is equal - outpath = Path(tmp_path, 'default.json') - convert_to_json(params_base_fname, outpath) - net_hjson = read_network_configuration(outpath) - assert net_hjson == net_params - - # Write hdf5 without drives - outpath_no_drives = Path(tmp_path, 'default_no_drives.json') - convert_to_json(params_base_fname, outpath_no_drives, include_drives=False) - net_hjson_no_drives = read_network_configuration(outpath_no_drives) - assert net_hjson_no_drives != net_hjson - assert bool(net_hjson_no_drives.external_drives) is False - - # Check that writing with no extension will add one - outpath_no_ext = Path(tmp_path, 'default_no_ext') - convert_to_json(params_base_fname, outpath_no_ext) - assert outpath_no_ext.with_suffix('.json').exists() - - -def test_convert_to_hdf5_legacy(tmp_path): - """Tests conversion of a param legacy file to hdf5""" - # Download params - param_url = ('https://raw.githubusercontent.com/hnnsolver/' - 'hnn-core/test_data/default.param') - params_base_fname = Path(hnn_core_root, 'param', 'default.param') - if not op.exists(params_base_fname): - urlretrieve(param_url, params_base_fname) - net_params = Network(read_params(params_base_fname), - add_drives_from_params=True, - legacy_mode=True - ) - - # Write hdf5 and check if constructed network is equal - outpath = Path(tmp_path, 'default.json') - convert_to_json(params_base_fname, outpath) - net_hdf5 = read_network_configuration(outpath) - assert net_hdf5 == net_params - - -def test_convert_to_json_bad_type(): - """Tests type validation in convert_to_json function""" - good_path = hnn_core_root - path_str = str(good_path) - bad_path = 5 - - # Valid path and string, but not actual files - with pytest.raises( - ValueError, - match="Unrecognized extension, expected one of" - ): - convert_to_json(good_path, path_str) - - # Bad params_fname - with pytest.raises( - TypeError, - match="params_fname must be an instance of str or Path" - ): - convert_to_json(bad_path, good_path) - - # Bad out_fname - with pytest.raises( - TypeError, - match="out_fname must be an instance of str or Path" - ): - convert_to_json(good_path, bad_path) +class TestConvertToJson: + """Tests convert_to_json function""" + + path_default = Path(hnn_core_root, 'param', 'default.json') + + def test_default_network_connectivity(self, tmp_path): + """Tests conversion with default parameters""" + + net_params = jones_2009_model(params=read_params(self.path_default), + add_drives_from_params=True) + + # Write json and check if constructed network is equal + outpath = Path(tmp_path, 'default.json') + convert_to_json(self.path_default, + outpath + ) + net_json = read_network_configuration(outpath) + assert net_json == net_params + + # Write json without drives + outpath_no_drives = Path(tmp_path, 'default_no_drives.json') + convert_to_json(self.path_default, + outpath_no_drives, + include_drives=False + ) + net_json_no_drives = read_network_configuration(outpath_no_drives) + assert net_json_no_drives != net_json + assert bool(net_json_no_drives.external_drives) is False + + # Check that writing with no extension will add one + outpath_no_ext = Path(tmp_path, 'default_no_ext') + convert_to_json(self.path_default, + outpath_no_ext + ) + assert outpath_no_ext.with_suffix('.json').exists() + + def test_law_network_connectivity(self, tmp_path): + """Tests conversion with Law 2021 network connectivity model""" + + net_params = law_2021_model(read_params(self.path_default), + add_drives_from_params=True, + ) + + # Write json and check if constructed network is equal + outpath = Path(tmp_path, 'default.json') + convert_to_json(self.path_default, + outpath, + network_connectivity='law_2021_model') + net_json = read_network_configuration(outpath) + assert net_json == net_params + + def test_calcium_network_connectivity(self, tmp_path): + """Tests conversion with calcium network connectivity model""" + + net_params = calcium_model(read_params(self.path_default), + add_drives_from_params=True, + ) + + # Write json and check if constructed network is equal + outpath = Path(tmp_path, 'default.json') + convert_to_json(self.path_default, + outpath, + network_connectivity='calcium_model') + net_json = read_network_configuration(outpath) + assert net_json == net_params + + def test_no_network_connectivity(self, tmp_path): + """Tests conversion with no network connectivity model""" + + net_params = Network(read_params(self.path_default), + add_drives_from_params=True, + ) + + # Write json and check if constructed network is equal + outpath = Path(tmp_path, 'default.json') + convert_to_json(self.path_default, + outpath, + network_connectivity=None) + net_json = read_network_configuration(outpath) + assert net_json == net_params + + def test_convert_to_json_legacy(self, tmp_path): + """Tests conversion of a param legacy file to json""" + + # Download params + param_url = ('https://raw.githubusercontent.com/hnnsolver/' + 'hnn-core/test_data/default.param') + params_base_fname = Path(hnn_core_root, 'param', 'default.param') + if not op.exists(params_base_fname): + urlretrieve(param_url, params_base_fname) + net_params = jones_2009_model(read_params(params_base_fname), + add_drives_from_params=True, + legacy_mode=True + ) + + # Write json and check if constructed network is equal + outpath = Path(tmp_path, 'default.json') + convert_to_json(params_base_fname, outpath) + net_json = read_network_configuration(outpath) + assert net_json == net_params + + def test_convert_to_json_bad_type(self): + """Tests type validation in convert_to_json function""" + + good_path = hnn_core_root + path_str = str(good_path) + bad_path = 5 + bad_network_conn = 'bad_model' + + # Valid path and string, but not actual files + with pytest.raises( + ValueError, + match="Unrecognized extension, expected one of" + ): + convert_to_json(good_path, path_str) + + # Bad params_fname + with pytest.raises( + TypeError, + match="params_fname must be an instance of str or Path" + ): + convert_to_json(bad_path, good_path) + + # Bad out_fname + with pytest.raises( + TypeError, + match="out_fname must be an instance of str or Path" + ): + convert_to_json(good_path, bad_path) + + # Bad network_connectivity + with pytest.raises( + KeyError, + match="Invalid network connectivity:" + ): + convert_to_json(good_path, good_path, + network_connectivity=bad_network_conn)