diff --git a/hnn_core/cell_response.py b/hnn_core/cell_response.py index e36fe3b36..cd1500549 100644 --- a/hnn_core/cell_response.py +++ b/hnn_core/cell_response.py @@ -56,6 +56,9 @@ class CellResponse(object): isec : list (n_trials,) of dict, shape Each element of the outer list is a trial. Dictionary indexed by gids containing currents for cell sections. + dcell : list (n_trials,) of dict, shape + Each element of the outer list is a trial. + Dictionary indexed by gids containing dipoles for individual cells. times : array-like, shape (n_times,) Array of time points for samples in continuous data. This includes vsoma and isoma. @@ -115,6 +118,7 @@ def __init__(self, spike_times=None, spike_gids=None, spike_types=None, self._spike_types = spike_types self._vsec = list() self._isec = list() + self._dcell = list() if times is not None: if not isinstance(times, (list, np.ndarray)): raise TypeError("'times' is an np.ndarray of simulation times") @@ -225,6 +229,10 @@ def vsec(self): @property def isec(self): return self._isec + + @property + def dcell(self): + return self._dcell @property def times(self): @@ -423,6 +431,10 @@ def to_dict(self): # Turn `int` gid keys into string values for hdf5 format trial = dict((str(key), val) for key, val in trial.items()) cell_response_data['isec'].append(trial) + for trial in self.dcell: + # Turn `int` gid keys into string values for hdf5 format + trial = dict((str(key), val) for key, val in trial.items()) + cell_response_data['dcell'].append(trial) cell_response_data['times'] = self.times return cell_response_data diff --git a/hnn_core/dipole.py b/hnn_core/dipole.py index 28fe515a4..8e2abf3f6 100644 --- a/hnn_core/dipole.py +++ b/hnn_core/dipole.py @@ -16,7 +16,7 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False, - record_isec=False, postproc=False): + record_isec=False, record_dcell=False, postproc=False): """Simulate a dipole given the experiment parameters. Parameters @@ -37,6 +37,8 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False, record_isec : 'all' | 'soma' | False Option to record voltages from all sections ('all'), or just the soma ('soma'). Default: False. + record_dcell : bool + Option to record dipole from individual cells. Default: False. postproc : bool If True, smoothing (``dipole_smooth_win``) and scaling (``dipole_scalefctr``) values are read from the parameter file, and @@ -96,6 +98,10 @@ def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False, net._params['record_isec'] = record_isec + _check_option('record_dcell', record_dcell, [True, False]) + + net._params['record_dcell'] = record_dcell + if postproc: warnings.warn('The postproc-argument is deprecated and will be removed' ' in a future release of hnn-core. Please define ' diff --git a/hnn_core/network_builder.py b/hnn_core/network_builder.py index 0d7c33f46..438f68f8a 100644 --- a/hnn_core/network_builder.py +++ b/hnn_core/network_builder.py @@ -99,6 +99,10 @@ def simulation_time(): for sec_name, isec in isec_dict.items(): isec_py[gid][sec_name] = { key: isec.to_python() for key, isec in isec.items()} + + dcell_py = dict() + for gid, dcell in neuron_net._dcell.items(): + dcell_py[gid] = dcell.to_python() dpl_data = np.c_[ neuron_net._nrn_dipoles['L2_pyramidal'].as_numpy() + @@ -119,6 +123,7 @@ def simulation_time(): 'gid_ranges': net.gid_ranges, 'vsec': vsec_py, 'isec': isec_py, + 'dcell': dcell_py, 'rec_data': rec_arr_py, 'rec_times': rec_times_py, 'times': times.to_python()} @@ -291,6 +296,7 @@ def __init__(self, net, trial_idx=0): self._vsec = dict() self._isec = dict() + self._dcell = dict() self._nrn_rec_arrays = dict() self._nrn_rec_callbacks = list() @@ -562,6 +568,9 @@ def aggregate_data(self, n_samples): nrn_dpl = self._nrn_dipoles[_long_name(cell.name)] nrn_dpl.add(cell.dipole) + if self.net._params['record_dcell']: + self._dcell[cell.gid] = cell.dipole + self._vsec[cell.gid] = cell.vsec self._isec[cell.gid] = cell.isec @@ -574,6 +583,7 @@ def aggregate_data(self, n_samples): # aggregate the currents and voltages independently on each proc vsec_list = _PC.py_gather(self._vsec, 0) isec_list = _PC.py_gather(self._isec, 0) + dcell_list = _PC.py_gather(self._dcell, 0) # combine spiking data from each proc spike_times_list = _PC.py_gather(self._spike_times, 0) @@ -589,6 +599,8 @@ def aggregate_data(self, n_samples): self._vsec.update(vsec) for isec in isec_list: self._isec.update(isec) + for dcell in dcell_list: + self._dcell.update(dcell) _PC.barrier() # get all nodes to this place before continuing diff --git a/hnn_core/parallel_backends.py b/hnn_core/parallel_backends.py index b022e5965..c2d311e6a 100644 --- a/hnn_core/parallel_backends.py +++ b/hnn_core/parallel_backends.py @@ -53,6 +53,7 @@ def _gather_trial_data(sim_data, net, n_trials, postproc): net.cell_response.update_types(net.gid_ranges) net.cell_response._vsec.append(sim_data[idx]['vsec']) net.cell_response._isec.append(sim_data[idx]['isec']) + net.cell_response._dcell.append(sim_data[idx]['dcell']) # extracellular array for arr_name, arr in net.rec_arrays.items(): diff --git a/hnn_core/tests/test_dipole.py b/hnn_core/tests/test_dipole.py index 22af87b3c..fd763f1f3 100644 --- a/hnn_core/tests/test_dipole.py +++ b/hnn_core/tests/test_dipole.py @@ -187,6 +187,8 @@ def test_dipole_simulation(): with pytest.raises(ValueError, match="Invalid value for the"): simulate_dipole(net, tstop=25., n_trials=1, record_vsec=False, record_isec='abc') + with pytest.raises(ValueError, match="Invalid value for the"): + simulate_dipole(net, tstop=25., n_trials=1, record_dcell='abc') # test Network.copy() returns 'bare' network after simulating dpl = simulate_dipole(net, tstop=25., n_trials=1)[0] @@ -213,32 +215,39 @@ def test_cell_response_backends(run_hnn_core_fixture): # reduced simulation has n_trials=2 trial_idx, n_trials, gid = 0, 2, 7 - _, joblib_net = run_hnn_core_fixture(backend='joblib', n_jobs=1, + joblib_dpl, joblib_net = run_hnn_core_fixture(backend='joblib', n_jobs=1, reduced=True, record_vsec='all', - record_isec='soma') - _, mpi_net = run_hnn_core_fixture(backend='mpi', n_procs=2, reduced=True, - record_vsec='all', record_isec='soma') + record_isec='soma', record_dcell=True) + mpi_dpl, mpi_net = run_hnn_core_fixture(backend='mpi', n_procs=2, reduced=True, + record_vsec='all', record_isec='soma', + record_dcell=True) n_times = len(joblib_net.cell_response.times) assert len(joblib_net.cell_response.vsec) == n_trials assert len(joblib_net.cell_response.isec) == n_trials + assert len(joblib_net.cell_response.dcell) == n_trials assert len(joblib_net.cell_response.vsec[trial_idx][gid]) == 8 # num sec assert len(joblib_net.cell_response.isec[trial_idx][gid]) == 1 assert len(joblib_net.cell_response.vsec[ trial_idx][gid]['apical_1']) == n_times assert len(joblib_net.cell_response.isec[ trial_idx][gid]['soma']['soma_gabaa']) == n_times + assert len(joblib_net.cell_response.isec[trial_idx][gid]) == n_times + assert len(mpi_net.cell_response.vsec) == n_trials assert len(mpi_net.cell_response.isec) == n_trials + assert len(mpi_net.cell_response.dcell) == n_trials assert len(mpi_net.cell_response.vsec[trial_idx][gid]) == 8 # num sec assert len(mpi_net.cell_response.isec[trial_idx][gid]) == 1 assert len(mpi_net.cell_response.vsec[ trial_idx][gid]['apical_1']) == n_times assert len(mpi_net.cell_response.isec[ trial_idx][gid]['soma']['soma_gabaa']) == n_times + assert len(mpi_net.cell_response.isec[trial_idx][gid]) == n_times assert mpi_net.cell_response.vsec == joblib_net.cell_response.vsec assert mpi_net.cell_response.isec == joblib_net.cell_response.isec + assert mpi_net.cell_response.dcell == joblib_net.cell_response.dcell # Test if spike time falls within depolarization window above v_thresh v_thresh = 0.0 @@ -259,6 +268,28 @@ def test_cell_response_backends(run_hnn_core_fixture): g == gid_ran[idx_drive]] assert_allclose(np.array(event_times), np.array(net_ets)) + # test that individual cell dipoles match aggregate dipole + L5_dipole = np.array([joblib_net.cell_response.dcell[0][gid] for + gid in list(joblib_net.gid_ranges['L5_pyramidal'])]) + L2_dipole = np.array([joblib_net.cell_response.dcell[0][gid] + for gid in list(joblib_net.gid_ranges['L2_pyramidal'])]) + agg_dipole = np.concatenate([L2_dipole, L5_dipole], axis=0) + + L5_dipole_sum = np.sum(L5_dipole, axis=0) + L2_dipole_sum = np.sum(L2_dipole, axis=0) + agg_dipole_sum = np.sum(agg_dipole, axis=0) + + dipole_data = np.stack([agg_dipole_sum, L2_dipole_sum, L5_dipole_sum], axis=1) + + test_dpl = Dipole(dpl[0].times, dipole_data) + N_pyr_x = joblib_net._params['N_pyr_x'] + N_pyr_y = joblib_net._params['N_pyr_y'] + test_dpl._baseline_renormalize(N_pyr_x, N_pyr_y) + test_dpl._convert_fAm_to_nAm() + + assert np.all(test_dpl.data['agg'] == joblib_dpl[0].data['agg']) + assert np.all(test_dpl.data['agg'] == mpi_dpl[0].data['agg']) + def test_rmse(): """Test to check RMSE calculation"""