Skip to content

Commit

Permalink
allow user to define obj_fun as callable
Browse files Browse the repository at this point in the history
  • Loading branch information
carolinafernandezp committed Jul 26, 2024
1 parent fb6885c commit 10e4e7f
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 4 deletions.
2 changes: 1 addition & 1 deletion hnn_core/optimization/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .optimize_evoked import optimize_evokedfrom .general_optimization import Optimizer
from .optimize_evoked import optimize_evokedfrom .general_optimization import Optimizer, _update_params
Expand Down
13 changes: 10 additions & 3 deletions hnn_core/optimization/general_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ def __init__(self, initial_net, tstop, constraints, set_params,
of the parameters that will be set inside the function.
solver : str
The optimizer, 'bayesian' or 'cobyla'.
obj_fun : str
The objective function to be minimized.
obj_fun : str | func
The objective function to be minimized. Can be 'dipole_rmse',
'maximize_psd', or a user-defined function. The default is
'dipole_rmse'.
max_iter : int, optional
The max number of calls to the objective function. The default is
200.
Expand Down Expand Up @@ -87,7 +89,8 @@ def __init__(self, initial_net, tstop, constraints, set_params,
self.obj_fun = _maximize_psd
self.obj_fun_name = 'maximize_psd'
else:
raise ValueError("obj_fun must be 'dipole_rmse' or 'maximize_psd'")
self.obj_fun = obj_fun # user-defined function
self.obj_fun_name = None
self.tstop = tstop
self.net_ = None
self.obj_ = list()
Expand All @@ -112,6 +115,10 @@ def fit(self, **obj_fun_kwargs):
Lower and higher limit for each frequency band.
relative_bandpower : tuple (if obj_fun='maximize_psd')
Weight for each frequency band.
scale_factor : float, optional
The dipole scale factor.
smooth_window_len : float, optional
The smooth window length.
"""
if (self.obj_fun_name == 'dipole_rmse' and
'target' not in obj_fun_kwargs):
Expand Down
119 changes: 119 additions & 0 deletions hnn_core/tests/test_general_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,122 @@ def set_params(net_offset, params):
# the number of returned rmse values should be the same as max_iter
assert (len(obj) <= max_iter), (
"Number of rmse values should be the same as max_iter")


@pytest.mark.parametrize("solver", ['bayesian', 'cobyla'])
def test_user_obj_fun(solver):
"""Test optimization routines with a user-defined optimization function."""

max_iter = 11
tstop = 200.

# simulate a dipole to establish ground-truth drive parameters
hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)

params.update({'N_pyr_x': 3,
'N_pyr_y': 3})
net_offset = jones_2009_model(params)

def maximize_csd(initial_net, initial_params, set_params, predicted_params,
update_params, obj_values, tstop, obj_fun_kwargs):

import numpy as np
from hnn_core.optimization import _update_params
from hnn_core.extracellular import (calculate_csd2d,
_get_laminar_z_coords)

params = _update_params(initial_params, predicted_params)

# simulate dpl with predicted params
new_net = initial_net.copy()
set_params(new_net, params)

# set electrode array
depths = list(range(-325, 2150, 100))
electrode_pos = [(135, 135, dep) for dep in depths]
new_net.add_electrode_array('shank1', electrode_pos)

dpl = simulate_dipole(new_net, tstop=tstop, n_trials=1)[0]

potentials = new_net.rec_arrays['shank1'][0]

# smooth
if 'smooth_window_len' in obj_fun_kwargs:
potentials.smooth(window_len=obj_fun_kwargs['smooth_window_len'])

# get csd of simulated potentials
lfp = potentials.voltages[0] # n_contacts, n_times
contact_labels, delta = _get_laminar_z_coords(potentials.positions)
csd = calculate_csd2d(lfp_data=lfp, delta=delta) # n_contacts, n_times

# for each tuple
csd_subsets = list() # band, n_contacts, n_times
for idx, t_band in enumerate(obj_fun_kwargs['t_bands']):
t_min = np.argmax(potentials.times >= t_band[0])
t_max = np.argmax(potentials.times >= t_band[1])
depth_min = np.argmax(contact_labels >=
obj_fun_kwargs['electrode_depths'][idx][0])
depth_max = np.argmax(contact_labels >=
obj_fun_kwargs['electrode_depths'][idx][1])

csd_subsets.append(sum(sum(csd[depth_min:depth_max+1,
t_min:t_max+1])))

obj = sum(csd_subsets) / sum(sum(csd))
obj_values.append(obj)

return obj

def set_params(net_offset, params):
weights_ampa = {'L2_basket': 0.5,
'L2_pyramidal': 0.5,
'L5_basket': 0.5,
'L5_pyramidal': 0.5}
synaptic_delays = {'L2_basket': 0.1, 'L2_pyramidal': 0.1,
'L5_basket': 1., 'L5_pyramidal': 1.}
net_offset.add_evoked_drive('evprox',
mu=params['mu'],
sigma=params['sigma'],
numspikes=1,
location='proximal',
weights_ampa=weights_ampa,
synaptic_delays=synaptic_delays)

# define constraints
constraints = dict()
constraints.update({'mu': (1, 200),
'sigma': (1, 15)})

optim = Optimizer(net_offset, tstop=tstop, constraints=constraints,
set_params=set_params, solver=solver,
obj_fun=maximize_csd, max_iter=max_iter)

# test exception raised
with pytest.raises(ValueError, match='The current Network instance has '
'external drives, provide a Network object with no '
'external drives.'):
net_with_drives = jones_2009_model(params, add_drives_from_params=True)
optim = Optimizer(net_with_drives,
tstop=tstop,
constraints=constraints,
set_params=set_params,
solver=solver,
obj_fun=maximize_csd,
max_iter=max_iter)

# test repr before fitting
assert 'fit=False' in repr(optim), "optimizer is already fit"

# increase power in infragranular layers (100-150 ms)
optim.fit(t_bands=[(100, 150),], electrode_depths=[(0, 200),])

# test repr after fitting
assert 'fit=True' in repr(optim), "optimizer was not fit"

# the optimized parameter is in the range
for param_idx, param in enumerate(optim.opt_params_):
assert (list(constraints.values())[param_idx][0] <= param <=
list(constraints.values())[param_idx][1]), (
"Optimized parameter is not in user-defined range")

0 comments on commit 10e4e7f

Please sign in to comment.