Skip to content

Commit

Permalink
First pass at adding dipoles for individual cells
Browse files Browse the repository at this point in the history
  • Loading branch information
ntolley committed Oct 25, 2023
1 parent 7953101 commit adffb77
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 5 deletions.
12 changes: 12 additions & 0 deletions hnn_core/cell_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class CellResponse(object):
isec : list (n_trials,) of dict, shape
Each element of the outer list is a trial.
Dictionary indexed by gids containing currents for cell sections.
dcell : list (n_trials,) of dict, shape
Each element of the outer list is a trial.
Dictionary indexed by gids containing dipoles for individual cells.
times : array-like, shape (n_times,)
Array of time points for samples in continuous data.
This includes vsoma and isoma.
Expand Down Expand Up @@ -115,6 +118,7 @@ def __init__(self, spike_times=None, spike_gids=None, spike_types=None,
self._spike_types = spike_types
self._vsec = list()
self._isec = list()
self._dcell = list()
if times is not None:
if not isinstance(times, (list, np.ndarray)):
raise TypeError("'times' is an np.ndarray of simulation times")
Expand Down Expand Up @@ -225,6 +229,10 @@ def vsec(self):
@property
def isec(self):
return self._isec

@property
def dcell(self):
return self._dcell

@property
def times(self):
Expand Down Expand Up @@ -423,6 +431,10 @@ def to_dict(self):
# Turn `int` gid keys into string values for hdf5 format
trial = dict((str(key), val) for key, val in trial.items())
cell_response_data['isec'].append(trial)
for trial in self.dcell:
# Turn `int` gid keys into string values for hdf5 format
trial = dict((str(key), val) for key, val in trial.items())
cell_response_data['dcell'].append(trial)
cell_response_data['times'] = self.times
return cell_response_data

Expand Down
8 changes: 7 additions & 1 deletion hnn_core/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False,
record_isec=False, postproc=False):
record_isec=False, record_dcell=False, postproc=False):
"""Simulate a dipole given the experiment parameters.
Parameters
Expand All @@ -37,6 +37,8 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False,
record_isec : 'all' | 'soma' | False
Option to record voltages from all sections ('all'), or just
the soma ('soma'). Default: False.
record_dcell : bool
Option to record dipole from individual cells. Default: False.
postproc : bool
If True, smoothing (``dipole_smooth_win``) and scaling
(``dipole_scalefctr``) values are read from the parameter file, and
Expand Down Expand Up @@ -96,6 +98,10 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False,

net._params['record_isec'] = record_isec

_check_option('record_dcell', record_dcell, [True, False])

net._params['record_dcell'] = record_dcell

if postproc:
warnings.warn('The postproc-argument is deprecated and will be removed'
' in a future release of hnn-core. Please define '
Expand Down
12 changes: 12 additions & 0 deletions hnn_core/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def simulation_time():
for sec_name, isec in isec_dict.items():
isec_py[gid][sec_name] = {
key: isec.to_python() for key, isec in isec.items()}

dcell_py = dict()
for gid, dcell in neuron_net._dcell.items():
dcell_py[gid] = dcell.to_python()

dpl_data = np.c_[
neuron_net._nrn_dipoles['L2_pyramidal'].as_numpy() +
Expand All @@ -119,6 +123,7 @@ def simulation_time():
'gid_ranges': net.gid_ranges,
'vsec': vsec_py,
'isec': isec_py,
'dcell': dcell_py,
'rec_data': rec_arr_py,
'rec_times': rec_times_py,
'times': times.to_python()}
Expand Down Expand Up @@ -291,6 +296,7 @@ def __init__(self, net, trial_idx=0):

self._vsec = dict()
self._isec = dict()
self._dcell = dict()
self._nrn_rec_arrays = dict()
self._nrn_rec_callbacks = list()

Expand Down Expand Up @@ -562,6 +568,9 @@ def aggregate_data(self, n_samples):
nrn_dpl = self._nrn_dipoles[_long_name(cell.name)]
nrn_dpl.add(cell.dipole)

if self.net._params['record_dcell']:
self._dcell[cell.gid] = cell.dipole

self._vsec[cell.gid] = cell.vsec
self._isec[cell.gid] = cell.isec

Expand All @@ -574,6 +583,7 @@ def aggregate_data(self, n_samples):
# aggregate the currents and voltages independently on each proc
vsec_list = _PC.py_gather(self._vsec, 0)
isec_list = _PC.py_gather(self._isec, 0)
dcell_list = _PC.py_gather(self._dcell, 0)

# combine spiking data from each proc
spike_times_list = _PC.py_gather(self._spike_times, 0)
Expand All @@ -589,6 +599,8 @@ def aggregate_data(self, n_samples):
self._vsec.update(vsec)
for isec in isec_list:
self._isec.update(isec)
for dcell in dcell_list:
self._dcell.update(dcell)

_PC.barrier() # get all nodes to this place before continuing

Expand Down
1 change: 1 addition & 0 deletions hnn_core/parallel_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _gather_trial_data(sim_data, net, n_trials, postproc):
net.cell_response.update_types(net.gid_ranges)
net.cell_response._vsec.append(sim_data[idx]['vsec'])
net.cell_response._isec.append(sim_data[idx]['isec'])
net.cell_response._dcell.append(sim_data[idx]['dcell'])

# extracellular array
for arr_name, arr in net.rec_arrays.items():
Expand Down
39 changes: 35 additions & 4 deletions hnn_core/tests/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def test_dipole_simulation():
with pytest.raises(ValueError, match="Invalid value for the"):
simulate_dipole(net, tstop=25., n_trials=1, record_vsec=False,
record_isec='abc')
with pytest.raises(ValueError, match="Invalid value for the"):
simulate_dipole(net, tstop=25., n_trials=1, record_dcell='abc')

# test Network.copy() returns 'bare' network after simulating
dpl = simulate_dipole(net, tstop=25., n_trials=1)[0]
Expand All @@ -213,32 +215,39 @@ def test_cell_response_backends(run_hnn_core_fixture):

# reduced simulation has n_trials=2
trial_idx, n_trials, gid = 0, 2, 7
_, joblib_net = run_hnn_core_fixture(backend='joblib', n_jobs=1,
joblib_dpl, joblib_net = run_hnn_core_fixture(backend='joblib', n_jobs=1,
reduced=True, record_vsec='all',
record_isec='soma')
_, mpi_net = run_hnn_core_fixture(backend='mpi', n_procs=2, reduced=True,
record_vsec='all', record_isec='soma')
record_isec='soma', record_dcell=True)
mpi_dpl, mpi_net = run_hnn_core_fixture(backend='mpi', n_procs=2, reduced=True,
record_vsec='all', record_isec='soma',
record_dcell=True)
n_times = len(joblib_net.cell_response.times)

assert len(joblib_net.cell_response.vsec) == n_trials
assert len(joblib_net.cell_response.isec) == n_trials
assert len(joblib_net.cell_response.dcell) == n_trials
assert len(joblib_net.cell_response.vsec[trial_idx][gid]) == 8 # num sec
assert len(joblib_net.cell_response.isec[trial_idx][gid]) == 1
assert len(joblib_net.cell_response.vsec[
trial_idx][gid]['apical_1']) == n_times
assert len(joblib_net.cell_response.isec[
trial_idx][gid]['soma']['soma_gabaa']) == n_times
assert len(joblib_net.cell_response.isec[trial_idx][gid]) == n_times


assert len(mpi_net.cell_response.vsec) == n_trials
assert len(mpi_net.cell_response.isec) == n_trials
assert len(mpi_net.cell_response.dcell) == n_trials
assert len(mpi_net.cell_response.vsec[trial_idx][gid]) == 8 # num sec
assert len(mpi_net.cell_response.isec[trial_idx][gid]) == 1
assert len(mpi_net.cell_response.vsec[
trial_idx][gid]['apical_1']) == n_times
assert len(mpi_net.cell_response.isec[
trial_idx][gid]['soma']['soma_gabaa']) == n_times
assert len(mpi_net.cell_response.isec[trial_idx][gid]) == n_times
assert mpi_net.cell_response.vsec == joblib_net.cell_response.vsec
assert mpi_net.cell_response.isec == joblib_net.cell_response.isec
assert mpi_net.cell_response.dcell == joblib_net.cell_response.dcell

# Test if spike time falls within depolarization window above v_thresh
v_thresh = 0.0
Expand All @@ -259,6 +268,28 @@ def test_cell_response_backends(run_hnn_core_fixture):
g == gid_ran[idx_drive]]
assert_allclose(np.array(event_times), np.array(net_ets))

# test that individual cell dipoles match aggregate dipole
L5_dipole = np.array([joblib_net.cell_response.dcell[0][gid] for
gid in list(joblib_net.gid_ranges['L5_pyramidal'])])
L2_dipole = np.array([joblib_net.cell_response.dcell[0][gid]
for gid in list(joblib_net.gid_ranges['L2_pyramidal'])])
agg_dipole = np.concatenate([L2_dipole, L5_dipole], axis=0)

L5_dipole_sum = np.sum(L5_dipole, axis=0)
L2_dipole_sum = np.sum(L2_dipole, axis=0)
agg_dipole_sum = np.sum(agg_dipole, axis=0)

dipole_data = np.stack([agg_dipole_sum, L2_dipole_sum, L5_dipole_sum], axis=1)

test_dpl = Dipole(dpl[0].times, dipole_data)
N_pyr_x = joblib_net._params['N_pyr_x']
N_pyr_y = joblib_net._params['N_pyr_y']
test_dpl._baseline_renormalize(N_pyr_x, N_pyr_y)
test_dpl._convert_fAm_to_nAm()

assert np.all(test_dpl.data['agg'] == joblib_dpl[0].data['agg'])
assert np.all(test_dpl.data['agg'] == mpi_dpl[0].data['agg'])


def test_rmse():
"""Test to check RMSE calculation"""
Expand Down

0 comments on commit adffb77

Please sign in to comment.