Skip to content

Commit

Permalink
MAINT: Moved _add_cell_type_bias out of the Network class. Added more…
Browse files Browse the repository at this point in the history
… test chechecking data types and warnings
  • Loading branch information
kmilo9999 committed May 14, 2024
1 parent f770827 commit 6246cdf
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 50 deletions.
6 changes: 2 additions & 4 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,8 @@ API
convergence. User can constrain parameter ranges and specify solver,
by `Carolina Fernandez Pujol`_ in :gh:`652`

- :func:`network.add_tonic_bias` argument `amplitue` now accepts a
cell_type(str)-amplitude(float) dictionary in case `cell_type` argument is None,
otherwise `amplitue` is a float indicating the amplitude of the tonic input for the specific
`cell_type`
- :func:`network.add_tonic_bias` cell-specific tonic bias can now be
provided using the argument amplitude in network.add_tonic_bias`,
by `Camilo Diaz`_ in :gh:`766`

.. _0.3:
Expand Down
84 changes: 48 additions & 36 deletions hnn_core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .check import _check_gids, _gid_to_type, _string_input_to_list
from .hnn_io import write_network
from .externals.mne import copy_doc
from typing import Union


def _create_cell_coords(n_pyr_x, n_pyr_y, zdiff, inplane_distance):
Expand Down Expand Up @@ -298,7 +299,7 @@ def pick_connection(net, src_gids=None, target_gids=None,
return sorted(conn_set)


class Network(object):
class Network:
"""The Network class.
Parameters
Expand Down Expand Up @@ -1127,28 +1128,6 @@ def _instantiate_drives(self, tstop, n_trials=1):
self.external_drives[
drive['name']]['events'].append(event_times)

def _add_cell_type_bias(self, *, cell_type=None, amplitude,
t_0=0, t_stop=None):
if cell_type is not None:
# Validate cell_type value
if cell_type not in self.cell_types:
raise ValueError(f'cell_type must be one of '
f'{list(self.cell_types.keys())}. '
f'Got {cell_type}')

if 'tonic' not in self.external_biases:
self.external_biases['tonic'] = dict()

if cell_type in self.external_biases['tonic']:
raise ValueError(f'Tonic bias already defined for {cell_type}')

cell_type_bias = {
'amplitude': amplitude,
't0': t_0,
'tstop': t_stop
}
self.external_biases['tonic'][cell_type] = cell_type_bias

def add_tonic_bias(self, *, cell_type=None, amplitude, t0=0, tstop=None):
"""Attaches parameters of tonic bias input for given cell types
Expand All @@ -1175,28 +1154,29 @@ def add_tonic_bias(self, *, cell_type=None, amplitude, t0=0, tstop=None):
"""

# old functionality single cell type - amplitude
if (cell_type is not None):
if cell_type is not None:
warnings.warn('cell_type argument will be deprecated and '
'removed in future releases', DeprecationWarning,
'removed in future releases. Use amplitude as a '
'cell_type:str,amplitude:float dictionary.'
'Read the function docustring for more information',
DeprecationWarning,
stacklevel=1)
if (not isinstance(amplitude, float)):
raise ValueError('amplitude parameter is not float')
_validate_type(amplitude, float, 'amplitude')

self._add_cell_type_bias(cell_type=cell_type, amplitude=amplitude,
t_0=t0, t_stop=tstop)
_add_cell_type_bias(network=self, cell_type=cell_type,
amplitude=amplitude,
t_0=t0, t_stop=tstop)
else:
if (not isinstance(amplitude, dict)):
raise ValueError('amplitude parameter is not a dictionary')

if (len(amplitude) == 0):
_validate_type(amplitude, dict, 'amplitude')
if len(amplitude) == 0:
warnings.warn('No bias have been defined, no action taken',
UserWarning, stacklevel=1)
return

for _cell_type, _amplitude in amplitude.items():
self._add_cell_type_bias(cell_type=_cell_type,
amplitude=_amplitude,
t_0=t0, t_stop=tstop)
_add_cell_type_bias(network=self, cell_type=_cell_type,
amplitude=_amplitude,
t_0=t0, t_stop=tstop)

def _add_cell_type(self, cell_name, pos, cell_template=None):
"""Add cell type by updating pos_dict and gid_ranges."""
Expand Down Expand Up @@ -1578,3 +1558,35 @@ def __repr__(self):
f"{len(self['events'])} trial{plurl}")
entr += '>'
return entr


def _add_cell_type_bias(network: Network, amplitude: Union[float, dict],
cell_type=None,
t_0=0, t_stop=None):

if network is None:
raise ValueError('The "network" parameter is required '
'but was not provided')
if amplitude is None:
raise ValueError('The "amplitude" parameter is required '
'but was not provided')

if cell_type is not None:
# Validate cell_type value
if cell_type not in network.cell_types:
raise ValueError(f'cell_type must be one of '
f'{list(network.cell_types.keys())}. '
f'Got {cell_type}')

if 'tonic' not in network.external_biases:
network.external_biases['tonic'] = dict()

if cell_type in network.external_biases['tonic']:
raise ValueError(f'Tonic bias already defined for {cell_type}')

cell_type_bias = {
'amplitude': amplitude,
't0': t_0,
'tstop': t_stop
}
network.external_biases['tonic'][cell_type] = cell_type_bias
43 changes: 33 additions & 10 deletions hnn_core/tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,15 +758,23 @@ def test_tonic_biases():
delay=1.0, lamtha=3.0)

tonic_bias_1 = {
'L2_pyramidal': 1.0,
'name_nonexistent': 1.0
}

with pytest.raises(ValueError, match=r'cell_type must be one of .*$'):
net.add_tonic_bias(amplitude=tonic_bias_1, t0=0.0,
tstop=4.0)
net.external_biases = dict()

with pytest.raises(TypeError,
match='amplitude must be an instance of dict'):
net.add_tonic_bias(amplitude=0.1,
t0=5.0, tstop=-1.0)

tonic_bias_2 = {
'L2_pyramidal': 1.0
'L2_pyramidal': 1.0,
'L5_basket': 0.5
}

with pytest.raises(ValueError, match='Duration of tonic input cannot be'
Expand All @@ -781,29 +789,44 @@ def test_tonic_biases():
net.add_tonic_bias(amplitude=tonic_bias_2,
t0=5.0, tstop=-1.0)
simulate_dipole(net, tstop=5.)
net.external_biases = dict()

with pytest.raises(ValueError, match='parameter may be missing'):
params['Itonic_T_L2Pyr_soma'] = 5.0
net = Network(params, add_drives_from_params=True)
net.external_biases = dict()

# test adding single cell_type - amplitude (old API)
net.external_biases = dict()
with pytest.raises(ValueError, match=r'cell_type must be one of .*$'):
net.add_tonic_bias(cell_type='name_nonexistent', amplitude=1.0,
t0=0.0, tstop=4.0)
with pytest.warns(DeprecationWarning,
match=r'cell_type argument will be deprecated'):
net.add_tonic_bias(cell_type='name_nonexistent', amplitude=1.0,
t0=0.0, tstop=4.0)

with pytest.raises(TypeError,
match='amplitude must be an instance of float'):
with pytest.warns(DeprecationWarning,
match=r'cell_type argument will be deprecated'):
net.add_tonic_bias(cell_type='L5_pyramidal',
amplitude={'L2_pyramidal': 0.1},
t0=5.0, tstop=-1.0)

with pytest.raises(ValueError, match='Duration of tonic input cannot be'
' negative'):
net.add_tonic_bias(cell_type='L2_pyramidal', amplitude=1.0,
t0=5.0, tstop=4.0)
simulate_dipole(net, tstop=20.)
with pytest.warns(DeprecationWarning,
match=r'cell_type argument will be deprecated'):
net.add_tonic_bias(cell_type='L2_pyramidal', amplitude=1.0,
t0=5.0, tstop=4.0)
simulate_dipole(net, tstop=20.)
net.external_biases = dict()

with pytest.raises(ValueError, match='End time of tonic input cannot be'
' negative'):
net.add_tonic_bias(cell_type='L2_pyramidal', amplitude=1.0,
t0=5.0, tstop=-1.0)
simulate_dipole(net, tstop=5.)
with pytest.warns(DeprecationWarning,
match=r'cell_type argument will be deprecated'):
net.add_tonic_bias(cell_type='L2_pyramidal', amplitude=1.0,
t0=5.0, tstop=-1.0)
simulate_dipole(net, tstop=5.)

params.update({
'N_pyr_x': 3, 'N_pyr_y': 3,
Expand Down

0 comments on commit 6246cdf

Please sign in to comment.