Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add dipole recording for individual cells #682

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions hnn_core/cell_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class CellResponse(object):
isec : list (n_trials,) of dict, shape
Each element of the outer list is a trial.
Dictionary indexed by gids containing currents for cell sections.
dcell : list (n_trials,) of dict, shape
Each element of the outer list is a trial.
Dictionary indexed by gids containing dipoles for individual cells.
times : array-like, shape (n_times,)
Array of time points for samples in continuous data.
This includes vsoma and isoma.
Expand Down Expand Up @@ -115,6 +118,7 @@ def __init__(self, spike_times=None, spike_gids=None, spike_types=None,
self._spike_types = spike_types
self._vsec = list()
self._isec = list()
self._dcell = list()
if times is not None:
if not isinstance(times, (list, np.ndarray)):
raise TypeError("'times' is an np.ndarray of simulation times")
Expand Down Expand Up @@ -225,6 +229,10 @@ def vsec(self):
@property
def isec(self):
return self._isec

@property
def dcell(self):
return self._dcell

@property
def times(self):
Expand Down Expand Up @@ -423,6 +431,10 @@ def to_dict(self):
# Turn `int` gid keys into string values for hdf5 format
trial = dict((str(key), val) for key, val in trial.items())
cell_response_data['isec'].append(trial)
for trial in self.dcell:
# Turn `int` gid keys into string values for hdf5 format
trial = dict((str(key), val) for key, val in trial.items())
cell_response_data['dcell'].append(trial)
cell_response_data['times'] = self.times
return cell_response_data

Expand Down
8 changes: 7 additions & 1 deletion hnn_core/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False,
record_isec=False, postproc=False):
record_isec=False, record_dcell=False, postproc=False):
"""Simulate a dipole given the experiment parameters.

Parameters
Expand All @@ -37,6 +37,8 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False,
record_isec : 'all' | 'soma' | False
Option to record voltages from all sections ('all'), or just
the soma ('soma'). Default: False.
record_dcell : bool
Option to record dipole from individual cells. Default: False.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a computational hit when you do this? I think users should be warned ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is, it is substantially less than record_isec and record_vsec because the individual cell dipoles have to be recorded by default. The only difference here is that we save them at the end instead of just adding them together.

Copy link
Contributor Author

@ntolley ntolley Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With that said I personally haven't noticed a performance hit (will need to followup with timed tests), the main concern would be taking up too much RAM if the recording is really long

postproc : bool
If True, smoothing (``dipole_smooth_win``) and scaling
(``dipole_scalefctr``) values are read from the parameter file, and
Expand Down Expand Up @@ -96,6 +98,10 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False,

net._params['record_isec'] = record_isec

_check_option('record_dcell', record_dcell, [True, False])

net._params['record_dcell'] = record_dcell

if postproc:
warnings.warn('The postproc-argument is deprecated and will be removed'
' in a future release of hnn-core. Please define '
Expand Down
12 changes: 12 additions & 0 deletions hnn_core/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def simulation_time():
for sec_name, isec in isec_dict.items():
isec_py[gid][sec_name] = {
key: isec.to_python() for key, isec in isec.items()}

dcell_py = dict()
for gid, dcell in neuron_net._dcell.items():
dcell_py[gid] = dcell.to_python()

dpl_data = np.c_[
neuron_net._nrn_dipoles['L2_pyramidal'].as_numpy() +
Expand All @@ -119,6 +123,7 @@ def simulation_time():
'gid_ranges': net.gid_ranges,
'vsec': vsec_py,
'isec': isec_py,
'dcell': dcell_py,
'rec_data': rec_arr_py,
'rec_times': rec_times_py,
'times': times.to_python()}
Expand Down Expand Up @@ -291,6 +296,7 @@ def __init__(self, net, trial_idx=0):

self._vsec = dict()
self._isec = dict()
self._dcell = dict()
self._nrn_rec_arrays = dict()
self._nrn_rec_callbacks = list()

Expand Down Expand Up @@ -562,6 +568,9 @@ def aggregate_data(self, n_samples):
nrn_dpl = self._nrn_dipoles[_long_name(cell.name)]
nrn_dpl.add(cell.dipole)

if self.net._params['record_dcell']:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay so record_dcell does nothing to the neuron simulation ... the dipoles are recorded no matter what. The only difference is whether they are converted from neuron into python or not ...

you may have opened a pandora's box. If you want to go down that road, we can provide cell_response.dipole ... and then in a later pull request deprecate simulate_dipole to simply simulate

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about I switch this to a WIP flag and we use this as a PR to work out the simulate_dipole API? I agree this is a natural place to make this happen, but the pandora's box may need to stay half open until COSYNE abstracts are submitted 😉

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @jasmainak: cell-level dipoles should be accessible to the user via a similar API as other cell-level outputs.

@ntolley, WDYT of hashing out an improved simulated_dipole API with existing features first in a separate MAINT PR and keep this an ENH PR. I'm imagining this getting super huge and unwieldy....

self._dcell[cell.gid] = cell.dipole

self._vsec[cell.gid] = cell.vsec
self._isec[cell.gid] = cell.isec

Expand All @@ -574,6 +583,7 @@ def aggregate_data(self, n_samples):
# aggregate the currents and voltages independently on each proc
vsec_list = _PC.py_gather(self._vsec, 0)
isec_list = _PC.py_gather(self._isec, 0)
dcell_list = _PC.py_gather(self._dcell, 0)

# combine spiking data from each proc
spike_times_list = _PC.py_gather(self._spike_times, 0)
Expand All @@ -589,6 +599,8 @@ def aggregate_data(self, n_samples):
self._vsec.update(vsec)
for isec in isec_list:
self._isec.update(isec)
for dcell in dcell_list:
self._dcell.update(dcell)

_PC.barrier() # get all nodes to this place before continuing

Expand Down
1 change: 1 addition & 0 deletions hnn_core/parallel_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _gather_trial_data(sim_data, net, n_trials, postproc):
net.cell_response.update_types(net.gid_ranges)
net.cell_response._vsec.append(sim_data[idx]['vsec'])
net.cell_response._isec.append(sim_data[idx]['isec'])
net.cell_response._dcell.append(sim_data[idx]['dcell'])

# extracellular array
for arr_name, arr in net.rec_arrays.items():
Expand Down
39 changes: 35 additions & 4 deletions hnn_core/tests/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def test_dipole_simulation():
with pytest.raises(ValueError, match="Invalid value for the"):
simulate_dipole(net, tstop=25., n_trials=1, record_vsec=False,
record_isec='abc')
with pytest.raises(ValueError, match="Invalid value for the"):
simulate_dipole(net, tstop=25., n_trials=1, record_dcell='abc')

# test Network.copy() returns 'bare' network after simulating
dpl = simulate_dipole(net, tstop=25., n_trials=1)[0]
Expand All @@ -213,32 +215,39 @@ def test_cell_response_backends(run_hnn_core_fixture):

# reduced simulation has n_trials=2
trial_idx, n_trials, gid = 0, 2, 7
_, joblib_net = run_hnn_core_fixture(backend='joblib', n_jobs=1,
joblib_dpl, joblib_net = run_hnn_core_fixture(backend='joblib', n_jobs=1,
reduced=True, record_vsec='all',
record_isec='soma')
_, mpi_net = run_hnn_core_fixture(backend='mpi', n_procs=2, reduced=True,
record_vsec='all', record_isec='soma')
record_isec='soma', record_dcell=True)
mpi_dpl, mpi_net = run_hnn_core_fixture(backend='mpi', n_procs=2, reduced=True,
record_vsec='all', record_isec='soma',
record_dcell=True)
n_times = len(joblib_net.cell_response.times)

assert len(joblib_net.cell_response.vsec) == n_trials
assert len(joblib_net.cell_response.isec) == n_trials
assert len(joblib_net.cell_response.dcell) == n_trials
assert len(joblib_net.cell_response.vsec[trial_idx][gid]) == 8 # num sec
assert len(joblib_net.cell_response.isec[trial_idx][gid]) == 1
assert len(joblib_net.cell_response.vsec[
trial_idx][gid]['apical_1']) == n_times
assert len(joblib_net.cell_response.isec[
trial_idx][gid]['soma']['soma_gabaa']) == n_times
assert len(joblib_net.cell_response.isec[trial_idx][gid]) == n_times


assert len(mpi_net.cell_response.vsec) == n_trials
assert len(mpi_net.cell_response.isec) == n_trials
assert len(mpi_net.cell_response.dcell) == n_trials
assert len(mpi_net.cell_response.vsec[trial_idx][gid]) == 8 # num sec
assert len(mpi_net.cell_response.isec[trial_idx][gid]) == 1
assert len(mpi_net.cell_response.vsec[
trial_idx][gid]['apical_1']) == n_times
assert len(mpi_net.cell_response.isec[
trial_idx][gid]['soma']['soma_gabaa']) == n_times
assert len(mpi_net.cell_response.isec[trial_idx][gid]) == n_times
assert mpi_net.cell_response.vsec == joblib_net.cell_response.vsec
assert mpi_net.cell_response.isec == joblib_net.cell_response.isec
assert mpi_net.cell_response.dcell == joblib_net.cell_response.dcell

# Test if spike time falls within depolarization window above v_thresh
v_thresh = 0.0
Expand All @@ -259,6 +268,28 @@ def test_cell_response_backends(run_hnn_core_fixture):
g == gid_ran[idx_drive]]
assert_allclose(np.array(event_times), np.array(net_ets))

# test that individual cell dipoles match aggregate dipole
L5_dipole = np.array([joblib_net.cell_response.dcell[0][gid] for
gid in list(joblib_net.gid_ranges['L5_pyramidal'])])
L2_dipole = np.array([joblib_net.cell_response.dcell[0][gid]
for gid in list(joblib_net.gid_ranges['L2_pyramidal'])])
agg_dipole = np.concatenate([L2_dipole, L5_dipole], axis=0)

L5_dipole_sum = np.sum(L5_dipole, axis=0)
L2_dipole_sum = np.sum(L2_dipole, axis=0)
agg_dipole_sum = np.sum(agg_dipole, axis=0)

dipole_data = np.stack([agg_dipole_sum, L2_dipole_sum, L5_dipole_sum], axis=1)

test_dpl = Dipole(joblib_dpl[0].times, dipole_data)
N_pyr_x = joblib_net._params['N_pyr_x']
N_pyr_y = joblib_net._params['N_pyr_y']
test_dpl._baseline_renormalize(N_pyr_x, N_pyr_y)
test_dpl._convert_fAm_to_nAm()
Comment on lines +287 to +288
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little bit uncomfortable with two code paths to achieve the same objective ... the difference in the processing streams is likely to be a source of confusion for users in the future. Could they be harmonized somehow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean testing wise? the main motivation here is that there are some very specialized transformations to take dipoles of individual cells to the aggregate dipole

I figured a good test would be making sure that the data pulled out from earlier in the pipeline matches the data produced at the end of the pipeline when we apply the same transformations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also I admit the baseline renormalize code is a mystery to me right now. It gets applied inside the simulation call automatically...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the postproc flag must be respected by cell_response.dcell ... they should give consistent results without accessing private methods


assert np.all(test_dpl.data['agg'] == joblib_dpl[0].data['agg'])
assert np.all(test_dpl.data['agg'] == mpi_dpl[0].data['agg'])


def test_rmse():
"""Test to check RMSE calculation"""
Expand Down
Loading