Skip to content

Commit

Permalink
doc: Improve variable names and docstrings for m, A update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Sep 4, 2024
1 parent d039913 commit 9cac0f9
Showing 1 changed file with 38 additions and 15 deletions.
53 changes: 38 additions & 15 deletions pysindy/optimizers/trapping_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,12 +463,19 @@ def _set_Ptensors(
PQ_tensor = self._build_PQ(polyterms)
PT_tensor = PQ_tensor.transpose([1, 0, 2, 3, 4])
# PM is the sum of PQ and PQ which projects out the sum of Qijk and Qjik
# These are the quadtratic terms of the energy growth
PM_tensor = cast(Float5D, PQ_tensor + PT_tensor)

return PC_tensor, PL_tensor_unsym, PL_tensor, PQ_tensor, PT_tensor, PM_tensor

def _update_A(self, A_old, PW):
"""Update the symmetrized A matrix"""
"""Update the proxy enstrophy quadratic form, :math:`A`?
Currently, this function projects a proxy of the quadratic form onto the
negative definite cone (w/tol gamma) and then "projects" the exitisting
quadratic form onto those same eigenvalues
"""
eigvals, eigvecs = np.linalg.eigh(A_old)
eigPW, eigvecsPW = np.linalg.eigh(PW)
r = A_old.shape[0]
Expand Down Expand Up @@ -596,32 +603,48 @@ def _solve_m_relax_and_split(
prev_A: Float2D,
coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat],
) -> tuple[Float1D, Float2D]:
"""Solves the (m, A) algorithm update.
r"""Updates the trap center
Ideally, the step would find a trap center that reduces the enstrophy
quadratic form as close as possible to the negative semidefinite cone.
.. math::
\underset{m, A\in \mathcal S^{--}}{\arg\min}||(L-Qm)^S - A||^2
TODO: explain the optimization this solves, add good names to variables,
and refactor/indirect the if global/local trapping conditionals
where the trap center is :math:`m`. However, the algorithm simply
performs one step of gradient update on the trap center and a
gradient-like step of the proxy enstrophy quadratic form.
Returns the new trap center (m) and the new A
TODO: improve variable names, test out variants such as completely
optimizing over trap center, limiting A update to projection onto
negative definite cone, or using updated trap center in A update.
See eqn 31-35 in Kaptanoglu et al 2021 and Algorithm 1
Returns:
new trap center (:math:`m`) and proxy enstrophy quadratic terms
(:math:`A`)
"""
# prox-gradient descent for (A, m)
# Calculate As
p_AS = _create_A_symm(self.PL_unsym_, self.PM_, trap_ctr, self.enstrophy)
PW = np.tensordot(p_AS, coef_sparse, axes=([3, 2], [0, 1]))
AS_coeff = np.tensordot(p_AS, coef_sparse, axes=([3, 2], [0, 1]))

# Calculate error in quadratic balance, and adjust trap center
relax_err_wrt_proxy = (prev_A - AS_coeff) / self.eta
# Calculate quadratic terms of As as a function of m
PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1]))
PMW = np.einsum(
"ya,abc,bz->yzc", self.enstrophy.P_root, PMW, self.enstrophy.P_root_inv
A_wrt_m = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1]))
A_wrt_m = np.einsum(
"ya,abc,bz->yzc", self.enstrophy.P_root, A_wrt_m, self.enstrophy.P_root_inv
)
PMW = (PMW + np.transpose(PMW, [1, 0, 2])) / 2
# Calculate error in quadratic balance, and adjust trap center
A_b = (prev_A - PW) / self.eta
# PQWT_PW is gradient of some loss in m
PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))
A_wrt_m = (A_wrt_m + np.transpose(A_wrt_m, [1, 0, 2])) / 2
# PMT_PW is gradient of relaxation wrt trap center (eqn 35)
PMT_PW = np.tensordot(A_wrt_m, relax_err_wrt_proxy, axes=([2, 1], [0, 1]))
trap_new = trap_ctr - self.alpha_m * PMT_PW

# Update A
A_new = self._update_A(prev_A - self.alpha_A * A_b, PW)
A_new = self._update_A(prev_A - self.alpha_A * relax_err_wrt_proxy, AS_coeff)
return trap_new, A_new

def _solve_nonsparse_relax_and_split(self, hess, gradient_constant):
Expand Down

0 comments on commit 9cac0f9

Please sign in to comment.