From 6246cdf9176e34f49754d6b82f45ff797f704ce1 Mon Sep 17 00:00:00 2001 From: Camilo Diaz Date: Tue, 14 May 2024 17:06:07 -0400 Subject: [PATCH] MAINT: Moved _add_cell_type_bias out of the Network class. Added more test chechecking data types and warnings --- doc/whats_new.rst | 6 +-- hnn_core/network.py | 84 +++++++++++++++++++--------------- hnn_core/tests/test_network.py | 43 +++++++++++++---- 3 files changed, 83 insertions(+), 50 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index cb4bb31c0..7143d3260 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -86,10 +86,8 @@ API convergence. User can constrain parameter ranges and specify solver, by `Carolina Fernandez Pujol`_ in :gh:`652` -- :func:`network.add_tonic_bias` argument `amplitue` now accepts a - cell_type(str)-amplitude(float) dictionary in case `cell_type` argument is None, - otherwise `amplitue` is a float indicating the amplitude of the tonic input for the specific - `cell_type` +- :func:`network.add_tonic_bias` cell-specific tonic bias can now be + provided using the argument amplitude in network.add_tonic_bias`, by `Camilo Diaz`_ in :gh:`766` .. _0.3: diff --git a/hnn_core/network.py b/hnn_core/network.py index 7d2da081a..93b6c02c1 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -25,6 +25,7 @@ from .check import _check_gids, _gid_to_type, _string_input_to_list from .hnn_io import write_network from .externals.mne import copy_doc +from typing import Union def _create_cell_coords(n_pyr_x, n_pyr_y, zdiff, inplane_distance): @@ -298,7 +299,7 @@ def pick_connection(net, src_gids=None, target_gids=None, return sorted(conn_set) -class Network(object): +class Network: """The Network class. Parameters @@ -1127,28 +1128,6 @@ def _instantiate_drives(self, tstop, n_trials=1): self.external_drives[ drive['name']]['events'].append(event_times) - def _add_cell_type_bias(self, *, cell_type=None, amplitude, - t_0=0, t_stop=None): - if cell_type is not None: - # Validate cell_type value - if cell_type not in self.cell_types: - raise ValueError(f'cell_type must be one of ' - f'{list(self.cell_types.keys())}. ' - f'Got {cell_type}') - - if 'tonic' not in self.external_biases: - self.external_biases['tonic'] = dict() - - if cell_type in self.external_biases['tonic']: - raise ValueError(f'Tonic bias already defined for {cell_type}') - - cell_type_bias = { - 'amplitude': amplitude, - 't0': t_0, - 'tstop': t_stop - } - self.external_biases['tonic'][cell_type] = cell_type_bias - def add_tonic_bias(self, *, cell_type=None, amplitude, t0=0, tstop=None): """Attaches parameters of tonic bias input for given cell types @@ -1175,28 +1154,29 @@ def add_tonic_bias(self, *, cell_type=None, amplitude, t0=0, tstop=None): """ # old functionality single cell type - amplitude - if (cell_type is not None): + if cell_type is not None: warnings.warn('cell_type argument will be deprecated and ' - 'removed in future releases', DeprecationWarning, + 'removed in future releases. Use amplitude as a ' + 'cell_type:str,amplitude:float dictionary.' + 'Read the function docustring for more information', + DeprecationWarning, stacklevel=1) - if (not isinstance(amplitude, float)): - raise ValueError('amplitude parameter is not float') + _validate_type(amplitude, float, 'amplitude') - self._add_cell_type_bias(cell_type=cell_type, amplitude=amplitude, - t_0=t0, t_stop=tstop) + _add_cell_type_bias(network=self, cell_type=cell_type, + amplitude=amplitude, + t_0=t0, t_stop=tstop) else: - if (not isinstance(amplitude, dict)): - raise ValueError('amplitude parameter is not a dictionary') - - if (len(amplitude) == 0): + _validate_type(amplitude, dict, 'amplitude') + if len(amplitude) == 0: warnings.warn('No bias have been defined, no action taken', UserWarning, stacklevel=1) return for _cell_type, _amplitude in amplitude.items(): - self._add_cell_type_bias(cell_type=_cell_type, - amplitude=_amplitude, - t_0=t0, t_stop=tstop) + _add_cell_type_bias(network=self, cell_type=_cell_type, + amplitude=_amplitude, + t_0=t0, t_stop=tstop) def _add_cell_type(self, cell_name, pos, cell_template=None): """Add cell type by updating pos_dict and gid_ranges.""" @@ -1578,3 +1558,35 @@ def __repr__(self): f"{len(self['events'])} trial{plurl}") entr += '>' return entr + + +def _add_cell_type_bias(network: Network, amplitude: Union[float, dict], + cell_type=None, + t_0=0, t_stop=None): + + if network is None: + raise ValueError('The "network" parameter is required ' + 'but was not provided') + if amplitude is None: + raise ValueError('The "amplitude" parameter is required ' + 'but was not provided') + + if cell_type is not None: + # Validate cell_type value + if cell_type not in network.cell_types: + raise ValueError(f'cell_type must be one of ' + f'{list(network.cell_types.keys())}. ' + f'Got {cell_type}') + + if 'tonic' not in network.external_biases: + network.external_biases['tonic'] = dict() + + if cell_type in network.external_biases['tonic']: + raise ValueError(f'Tonic bias already defined for {cell_type}') + + cell_type_bias = { + 'amplitude': amplitude, + 't0': t_0, + 'tstop': t_stop + } + network.external_biases['tonic'][cell_type] = cell_type_bias diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index b2634ce00..206a67e94 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -758,15 +758,23 @@ def test_tonic_biases(): delay=1.0, lamtha=3.0) tonic_bias_1 = { + 'L2_pyramidal': 1.0, 'name_nonexistent': 1.0 } with pytest.raises(ValueError, match=r'cell_type must be one of .*$'): net.add_tonic_bias(amplitude=tonic_bias_1, t0=0.0, tstop=4.0) + net.external_biases = dict() + + with pytest.raises(TypeError, + match='amplitude must be an instance of dict'): + net.add_tonic_bias(amplitude=0.1, + t0=5.0, tstop=-1.0) tonic_bias_2 = { - 'L2_pyramidal': 1.0 + 'L2_pyramidal': 1.0, + 'L5_basket': 0.5 } with pytest.raises(ValueError, match='Duration of tonic input cannot be' @@ -781,29 +789,44 @@ def test_tonic_biases(): net.add_tonic_bias(amplitude=tonic_bias_2, t0=5.0, tstop=-1.0) simulate_dipole(net, tstop=5.) + net.external_biases = dict() with pytest.raises(ValueError, match='parameter may be missing'): params['Itonic_T_L2Pyr_soma'] = 5.0 net = Network(params, add_drives_from_params=True) + net.external_biases = dict() # test adding single cell_type - amplitude (old API) - net.external_biases = dict() with pytest.raises(ValueError, match=r'cell_type must be one of .*$'): - net.add_tonic_bias(cell_type='name_nonexistent', amplitude=1.0, - t0=0.0, tstop=4.0) + with pytest.warns(DeprecationWarning, + match=r'cell_type argument will be deprecated'): + net.add_tonic_bias(cell_type='name_nonexistent', amplitude=1.0, + t0=0.0, tstop=4.0) + + with pytest.raises(TypeError, + match='amplitude must be an instance of float'): + with pytest.warns(DeprecationWarning, + match=r'cell_type argument will be deprecated'): + net.add_tonic_bias(cell_type='L5_pyramidal', + amplitude={'L2_pyramidal': 0.1}, + t0=5.0, tstop=-1.0) with pytest.raises(ValueError, match='Duration of tonic input cannot be' ' negative'): - net.add_tonic_bias(cell_type='L2_pyramidal', amplitude=1.0, - t0=5.0, tstop=4.0) - simulate_dipole(net, tstop=20.) + with pytest.warns(DeprecationWarning, + match=r'cell_type argument will be deprecated'): + net.add_tonic_bias(cell_type='L2_pyramidal', amplitude=1.0, + t0=5.0, tstop=4.0) + simulate_dipole(net, tstop=20.) net.external_biases = dict() with pytest.raises(ValueError, match='End time of tonic input cannot be' ' negative'): - net.add_tonic_bias(cell_type='L2_pyramidal', amplitude=1.0, - t0=5.0, tstop=-1.0) - simulate_dipole(net, tstop=5.) + with pytest.warns(DeprecationWarning, + match=r'cell_type argument will be deprecated'): + net.add_tonic_bias(cell_type='L2_pyramidal', amplitude=1.0, + t0=5.0, tstop=-1.0) + simulate_dipole(net, tstop=5.) params.update({ 'N_pyr_x': 3, 'N_pyr_y': 3,