-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Add dipole recording for individual cells #682
base: master
Are you sure you want to change the base?
Changes from 2 commits
adffb77
0b120fc
4345cad
276e802
da91881
05ab75d
95ff6ce
2de40b6
39dc47b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -99,6 +99,10 @@ def simulation_time(): | |
for sec_name, isec in isec_dict.items(): | ||
isec_py[gid][sec_name] = { | ||
key: isec.to_python() for key, isec in isec.items()} | ||
|
||
dcell_py = dict() | ||
for gid, dcell in neuron_net._dcell.items(): | ||
dcell_py[gid] = dcell.to_python() | ||
|
||
dpl_data = np.c_[ | ||
neuron_net._nrn_dipoles['L2_pyramidal'].as_numpy() + | ||
|
@@ -119,6 +123,7 @@ def simulation_time(): | |
'gid_ranges': net.gid_ranges, | ||
'vsec': vsec_py, | ||
'isec': isec_py, | ||
'dcell': dcell_py, | ||
'rec_data': rec_arr_py, | ||
'rec_times': rec_times_py, | ||
'times': times.to_python()} | ||
|
@@ -291,6 +296,7 @@ def __init__(self, net, trial_idx=0): | |
|
||
self._vsec = dict() | ||
self._isec = dict() | ||
self._dcell = dict() | ||
self._nrn_rec_arrays = dict() | ||
self._nrn_rec_callbacks = list() | ||
|
||
|
@@ -562,6 +568,9 @@ def aggregate_data(self, n_samples): | |
nrn_dpl = self._nrn_dipoles[_long_name(cell.name)] | ||
nrn_dpl.add(cell.dipole) | ||
|
||
if self.net._params['record_dcell']: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay so you may have opened a pandora's box. If you want to go down that road, we can provide There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about I switch this to a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with @jasmainak: cell-level dipoles should be accessible to the user via a similar API as other cell-level outputs. @ntolley, WDYT of hashing out an improved simulated_dipole API with existing features first in a separate MAINT PR and keep this an ENH PR. I'm imagining this getting super huge and unwieldy.... |
||
self._dcell[cell.gid] = cell.dipole | ||
|
||
self._vsec[cell.gid] = cell.vsec | ||
self._isec[cell.gid] = cell.isec | ||
|
||
|
@@ -574,6 +583,7 @@ def aggregate_data(self, n_samples): | |
# aggregate the currents and voltages independently on each proc | ||
vsec_list = _PC.py_gather(self._vsec, 0) | ||
isec_list = _PC.py_gather(self._isec, 0) | ||
dcell_list = _PC.py_gather(self._dcell, 0) | ||
|
||
# combine spiking data from each proc | ||
spike_times_list = _PC.py_gather(self._spike_times, 0) | ||
|
@@ -589,6 +599,8 @@ def aggregate_data(self, n_samples): | |
self._vsec.update(vsec) | ||
for isec in isec_list: | ||
self._isec.update(isec) | ||
for dcell in dcell_list: | ||
self._dcell.update(dcell) | ||
|
||
_PC.barrier() # get all nodes to this place before continuing | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -187,6 +187,8 @@ def test_dipole_simulation(): | |
with pytest.raises(ValueError, match="Invalid value for the"): | ||
simulate_dipole(net, tstop=25., n_trials=1, record_vsec=False, | ||
record_isec='abc') | ||
with pytest.raises(ValueError, match="Invalid value for the"): | ||
simulate_dipole(net, tstop=25., n_trials=1, record_dcell='abc') | ||
|
||
# test Network.copy() returns 'bare' network after simulating | ||
dpl = simulate_dipole(net, tstop=25., n_trials=1)[0] | ||
|
@@ -213,32 +215,39 @@ def test_cell_response_backends(run_hnn_core_fixture): | |
|
||
# reduced simulation has n_trials=2 | ||
trial_idx, n_trials, gid = 0, 2, 7 | ||
_, joblib_net = run_hnn_core_fixture(backend='joblib', n_jobs=1, | ||
joblib_dpl, joblib_net = run_hnn_core_fixture(backend='joblib', n_jobs=1, | ||
reduced=True, record_vsec='all', | ||
record_isec='soma') | ||
_, mpi_net = run_hnn_core_fixture(backend='mpi', n_procs=2, reduced=True, | ||
record_vsec='all', record_isec='soma') | ||
record_isec='soma', record_dcell=True) | ||
mpi_dpl, mpi_net = run_hnn_core_fixture(backend='mpi', n_procs=2, reduced=True, | ||
record_vsec='all', record_isec='soma', | ||
record_dcell=True) | ||
n_times = len(joblib_net.cell_response.times) | ||
|
||
assert len(joblib_net.cell_response.vsec) == n_trials | ||
assert len(joblib_net.cell_response.isec) == n_trials | ||
assert len(joblib_net.cell_response.dcell) == n_trials | ||
assert len(joblib_net.cell_response.vsec[trial_idx][gid]) == 8 # num sec | ||
assert len(joblib_net.cell_response.isec[trial_idx][gid]) == 1 | ||
assert len(joblib_net.cell_response.vsec[ | ||
trial_idx][gid]['apical_1']) == n_times | ||
assert len(joblib_net.cell_response.isec[ | ||
trial_idx][gid]['soma']['soma_gabaa']) == n_times | ||
assert len(joblib_net.cell_response.isec[trial_idx][gid]) == n_times | ||
|
||
|
||
assert len(mpi_net.cell_response.vsec) == n_trials | ||
assert len(mpi_net.cell_response.isec) == n_trials | ||
assert len(mpi_net.cell_response.dcell) == n_trials | ||
assert len(mpi_net.cell_response.vsec[trial_idx][gid]) == 8 # num sec | ||
assert len(mpi_net.cell_response.isec[trial_idx][gid]) == 1 | ||
assert len(mpi_net.cell_response.vsec[ | ||
trial_idx][gid]['apical_1']) == n_times | ||
assert len(mpi_net.cell_response.isec[ | ||
trial_idx][gid]['soma']['soma_gabaa']) == n_times | ||
assert len(mpi_net.cell_response.isec[trial_idx][gid]) == n_times | ||
assert mpi_net.cell_response.vsec == joblib_net.cell_response.vsec | ||
assert mpi_net.cell_response.isec == joblib_net.cell_response.isec | ||
assert mpi_net.cell_response.dcell == joblib_net.cell_response.dcell | ||
|
||
# Test if spike time falls within depolarization window above v_thresh | ||
v_thresh = 0.0 | ||
|
@@ -259,6 +268,28 @@ def test_cell_response_backends(run_hnn_core_fixture): | |
g == gid_ran[idx_drive]] | ||
assert_allclose(np.array(event_times), np.array(net_ets)) | ||
|
||
# test that individual cell dipoles match aggregate dipole | ||
L5_dipole = np.array([joblib_net.cell_response.dcell[0][gid] for | ||
gid in list(joblib_net.gid_ranges['L5_pyramidal'])]) | ||
L2_dipole = np.array([joblib_net.cell_response.dcell[0][gid] | ||
for gid in list(joblib_net.gid_ranges['L2_pyramidal'])]) | ||
agg_dipole = np.concatenate([L2_dipole, L5_dipole], axis=0) | ||
|
||
L5_dipole_sum = np.sum(L5_dipole, axis=0) | ||
L2_dipole_sum = np.sum(L2_dipole, axis=0) | ||
agg_dipole_sum = np.sum(agg_dipole, axis=0) | ||
|
||
dipole_data = np.stack([agg_dipole_sum, L2_dipole_sum, L5_dipole_sum], axis=1) | ||
|
||
test_dpl = Dipole(joblib_dpl[0].times, dipole_data) | ||
N_pyr_x = joblib_net._params['N_pyr_x'] | ||
N_pyr_y = joblib_net._params['N_pyr_y'] | ||
test_dpl._baseline_renormalize(N_pyr_x, N_pyr_y) | ||
test_dpl._convert_fAm_to_nAm() | ||
Comment on lines
+287
to
+288
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a little bit uncomfortable with two code paths to achieve the same objective ... the difference in the processing streams is likely to be a source of confusion for users in the future. Could they be harmonized somehow? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you mean testing wise? the main motivation here is that there are some very specialized transformations to take dipoles of individual cells to the aggregate dipole I figured a good test would be making sure that the data pulled out from earlier in the pipeline matches the data produced at the end of the pipeline when we apply the same transformations There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also I admit the baseline renormalize code is a mystery to me right now. It gets applied inside the simulation call automatically... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the |
||
|
||
assert np.all(test_dpl.data['agg'] == joblib_dpl[0].data['agg']) | ||
assert np.all(test_dpl.data['agg'] == mpi_dpl[0].data['agg']) | ||
|
||
|
||
def test_rmse(): | ||
"""Test to check RMSE calculation""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a computational hit when you do this? I think users should be warned ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there is, it is substantially less than
record_isec
andrecord_vsec
because the individual cell dipoles have to be recorded by default. The only difference here is that we save them at the end instead of just adding them together.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With that said I personally haven't noticed a performance hit (will need to followup with timed tests), the main concern would be taking up too much RAM if the recording is really long