Skip to content

Commit

Permalink
cln: Split global/local math in _solve_m_relax_and_split
Browse files Browse the repository at this point in the history
Also: move m0 initialization to __init__, fix _objective calc
  • Loading branch information
Jacob-Stevens-Haas committed Jun 2, 2024
1 parent 70fbb67 commit a4e7c41
Showing 1 changed file with 56 additions and 28 deletions.
84 changes: 56 additions & 28 deletions pysindy/optimizers/trapping_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

AnyFloat = np.dtype[np.floating[NBitBase]]
Int1D = np.ndarray[tuple[int], np.dtype[np.int_]]
Float1D = np.ndarray[tuple[int], AnyFloat]
Float2D = np.ndarray[tuple[int, int], AnyFloat]
Float3D = np.ndarray[tuple[int, int, int], AnyFloat]
Float4D = np.ndarray[tuple[int, int, int, int], AnyFloat]
Expand Down Expand Up @@ -333,6 +334,9 @@ def __post_init_guard(self):
self.mod_matrix = np.eye(self._n_tgts)
if self.A0 is None:
self.A0 = np.diag(self.gamma * np.ones(self._n_tgts))
if self.m0 is None:
np.random.seed(1)
self.m0 = (np.random.rand(self._n_tgts) - np.ones(self._n_tgts)) * 2

def set_params(self, **kwargs):
super().set_params(**kwargs)
Expand Down Expand Up @@ -533,10 +537,10 @@ def _objective(self, x, y, coef_sparse, A, PW, k):
"{0:5d} ... {1:8.3e} ... {2:8.3e} ... {3:8.2e}"
" ... {4:8.2e} ... {5:8.2e} ... {6:8.2e}".format(*row)
)
if self.method == "global":
return 0.5 * np.sum(R2) + 0.5 * np.sum(A2) / self.eta + L1
else:
return R2 + stability_term + L1 + alpha_term + beta_term
obj = R2 + stability_term + L1
if self.method == "local":
obj += alpha_term + beta_term
return obj

def _update_coef_sparse_rs(
self, n_tgts, n_features, var_len, x_expanded, y, Pmatrix, A, coef_prev
Expand Down Expand Up @@ -585,35 +589,65 @@ def _update_coef_nonsparse_rs(

return self._solve_nonsparse_relax_and_split(H, xTy, P_transpose_A, coef_prev)

def _solve_m_relax_and_split(self, m_prev, m, A, coef_sparse, tk_previous):
"""
If using the relaxation formulation of trapping SINDy, solves the
(m, A) algorithm update.
def _solve_m_relax_and_split(
self,
trap_ctr_prev: Float1D,
trap_ctr: Float1D,
A: Float2D,
coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat],
tk_previous: float,
) -> tuple[Float1D, Float2D, float]:
"""Solves the (m, A) algorithm update.
TODO: explain the optimization this solves, add good names to variables,
and refactor/indirect the if global/local trapping conditionals
Returns the new trap center (m), the new A, and the new acceleration weight
"""
# prox-grad for (A, m)
# Accelerated prox gradient descent
# Calculate projection matrix from Quad terms to As
if self.accel:
tk = (1 + np.sqrt(1 + 4 * tk_previous**2)) / 2.0
m_partial = m + (tk_previous - 1.0) / tk * (m - m_prev)
m_partial = trap_ctr + (tk_previous - 1.0) / tk * (trap_ctr - trap_ctr_prev)
tk_previous = tk
mPM = np.tensordot(self.PM_, m_partial, axes=([2], [0]))
if self.method == "global":
mPQ = np.tensordot(m_partial, self.PQ_, axes=([0], [0]))
else:
mPM = np.tensordot(self.PM_, m_partial, axes=([2], [0]))
else:
if self.method == "global":
mPQ = np.tensordot(trap_ctr, self.PQ_, axes=([0], [0]))
else:
mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
# Calculate As and its quad term components
if self.method == "global":
p = self.PL_ - mPQ
PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
PQW = np.tensordot(self.PQ_, coef_sparse, axes=([4, 3], [0, 1]))
else:
mPM = np.tensordot(self.PM_, m, axes=([2], [0]))
p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1]))
PMW = np.tensordot(self.mod_matrix, PMW, axes=([1], [0]))
p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1]))
PMW = np.tensordot(self.mod_matrix, PMW, axes=([1], [0]))
# Calculate error in quadratic balance, and adjust trap center
A_b = (A - PW) / self.eta
PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))
if self.accel:
m_new = m_partial - self.alpha_m * PMT_PW
if self.method == "global":
PQWT_PW = np.tensordot(PQW, A_b, axes=([2, 1], [0, 1]))
if self.accel:
trap_new = m_partial - self.alpha_m * PQWT_PW
else:
trap_new = trap_ctr_prev - self.alpha_m * PQWT_PW
else:
m_new = m_prev - self.alpha_m * PMT_PW
m_current = m_new
PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))
if self.accel:
trap_new = m_partial - self.alpha_m * PMT_PW
else:
trap_new = trap_ctr_prev - self.alpha_m * PMT_PW

# Update A
A_new = self._update_A(A - self.alpha_A * A_b, PW)
return m_current, m_new, A_new, tk_previous
return trap_new, A_new, tk_previous

def _solve_nonsparse_relax_and_split(self, H, xTy, P_transpose_A, coef_prev):
"""Update for the coefficients if threshold = 0."""
Expand Down Expand Up @@ -690,13 +724,7 @@ def _reduce(self, x, y):

A = self.A0
self.A_history_.append(A)

# initial guess for m
if self.m0 is not None:
trap_ctr = self.m0
else:
np.random.seed(1)
trap_ctr = (np.random.rand(n_tgts) - np.ones(n_tgts)) * 2
trap_ctr = self.m0
self.m_history_.append(trap_ctr)

# Precompute some objects for optimization
Expand Down

0 comments on commit a4e7c41

Please sign in to comment.