From 44269653831dca11540a04205db48e4b062ae818 Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 2 Jun 2024 16:12:06 -0700 Subject: [PATCH] cln: Split global/local math in _solve_m_relax_and_split Also: move m0 initialization to __init__, fix _objective calc --- pysindy/optimizers/trapping_sr3.py | 90 +++++++++++++++++++----------- 1 file changed, 58 insertions(+), 32 deletions(-) diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py index 0333ee841..941aeb3e3 100644 --- a/pysindy/optimizers/trapping_sr3.py +++ b/pysindy/optimizers/trapping_sr3.py @@ -21,6 +21,7 @@ AnyFloat = np.dtype[np.floating[NBitBase]] Int1D = np.ndarray[tuple[int], np.dtype[np.int_]] +Float1D = np.ndarray[tuple[int], AnyFloat] Float2D = np.ndarray[tuple[int, int], AnyFloat] Float3D = np.ndarray[tuple[int, int, int], AnyFloat] Float4D = np.ndarray[tuple[int, int, int, int], AnyFloat] @@ -333,6 +334,9 @@ def __post_init_guard(self): self.mod_matrix = np.eye(self._n_tgts) if self.A0 is None: self.A0 = np.diag(self.gamma * np.ones(self._n_tgts)) + if self.m0 is None: + np.random.seed(1) + self.m0 = (np.random.rand(self._n_tgts) - np.ones(self._n_tgts)) * 2 def set_params(self, **kwargs): super().set_params(**kwargs) @@ -533,10 +537,10 @@ def _objective(self, x, y, coef_sparse, A, PW, k): "{0:5d} ... {1:8.3e} ... {2:8.3e} ... {3:8.2e}" " ... {4:8.2e} ... {5:8.2e} ... {6:8.2e}".format(*row) ) - if self.method == "global": - return 0.5 * np.sum(R2) + 0.5 * np.sum(A2) / self.eta + L1 - else: - return R2 + stability_term + L1 + alpha_term + beta_term + obj = R2 + stability_term + L1 + if self.method == "local": + obj += alpha_term + beta_term + return obj def _update_coef_sparse_rs( self, n_tgts, n_features, var_len, x_expanded, y, Pmatrix, A, coef_prev @@ -585,35 +589,65 @@ def _update_coef_nonsparse_rs( return self._solve_nonsparse_relax_and_split(H, xTy, P_transpose_A, coef_prev) - def _solve_m_relax_and_split(self, m_prev, m, A, coef_sparse, tk_previous): - """ - If using the relaxation formulation of trapping SINDy, solves the - (m, A) algorithm update. + def _solve_m_relax_and_split( + self, + trap_ctr_prev: Float1D, + trap_ctr: Float1D, + A: Float2D, + coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat], + tk_previous: float, + ) -> tuple[Float1D, Float2D, float]: + """Solves the (m, A) algorithm update. + + TODO: explain the optimization this solves, add good names to variables, + and refactor/indirect the if global/local trapping conditionals + + Returns the new trap center (m), the new A, and the new acceleration weight """ # prox-grad for (A, m) # Accelerated prox gradient descent + # Calculate projection matrix from Quad terms to As if self.accel: tk = (1 + np.sqrt(1 + 4 * tk_previous**2)) / 2.0 - m_partial = m + (tk_previous - 1.0) / tk * (m - m_prev) + m_partial = trap_ctr + (tk_previous - 1.0) / tk * (trap_ctr - trap_ctr_prev) tk_previous = tk - mPM = np.tensordot(self.PM_, m_partial, axes=([2], [0])) + if self.method == "global": + mPQ = np.tensordot(m_partial, self.PQ_, axes=([0], [0])) + else: + mPM = np.tensordot(self.PM_, m_partial, axes=([2], [0])) else: - mPM = np.tensordot(self.PM_, m, axes=([2], [0])) - p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0])) - PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1])) - PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1])) - PMW = np.tensordot(self.mod_matrix, PMW, axes=([1], [0])) + if self.method == "global": + mPQ = np.tensordot(trap_ctr, self.PQ_, axes=([0], [0])) + else: + mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0])) + # Calculate As and its quad term components + if self.method == "global": + p = self.PL_ - mPQ + PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1])) + PQW = np.tensordot(self.PQ_, coef_sparse, axes=([4, 3], [0, 1])) + else: + p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0])) + PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1])) + PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1])) + PMW = np.tensordot(self.mod_matrix, PMW, axes=([1], [0])) + # Calculate error in quadratic balance, and adjust trap center A_b = (A - PW) / self.eta - PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1])) - if self.accel: - m_new = m_partial - self.alpha_m * PMT_PW + if self.method == "global": + PQWT_PW = np.tensordot(PQW, A_b, axes=([2, 1], [0, 1])) + if self.accel: + trap_new = m_partial - self.alpha_m * PQWT_PW + else: + trap_new = trap_ctr_prev - self.alpha_m * PQWT_PW else: - m_new = m_prev - self.alpha_m * PMT_PW - m_current = m_new + PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1])) + if self.accel: + trap_new = m_partial - self.alpha_m * PMT_PW + else: + trap_new = trap_ctr_prev - self.alpha_m * PMT_PW # Update A A_new = self._update_A(A - self.alpha_A * A_b, PW) - return m_current, m_new, A_new, tk_previous + return trap_new, A_new, tk_previous def _solve_nonsparse_relax_and_split(self, H, xTy, P_transpose_A, coef_prev): """Update for the coefficients if threshold = 0.""" @@ -668,7 +702,7 @@ def _reduce(self, x, y): # Set initial coefficients if self.use_constraints and self.constraint_order.lower() == "target": self.constraint_lhs = reorder_constraints( - self.constraint_lhs, n_features, output_order="target" + self.constraint_lhs, n_features, output_order="feature" ) coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat] = self.coef_.T @@ -690,13 +724,7 @@ def _reduce(self, x, y): A = self.A0 self.A_history_.append(A) - - # initial guess for m - if self.m0 is not None: - trap_ctr = self.m0 - else: - np.random.seed(1) - trap_ctr = (np.random.rand(n_tgts) - np.ones(n_tgts)) * 2 + trap_ctr = self.m0 self.m_history_.append(trap_ctr) # Precompute some objects for optimization @@ -723,6 +751,7 @@ def _reduce(self, x, y): mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0])) p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0])) Pmatrix = p.reshape(n_tgts * n_tgts, n_tgts * n_features) + self.p_history_.append(p) coef_prev = coef_sparse if (self.threshold > 0.0) or self.inequality_constraints: @@ -756,9 +785,6 @@ def _reduce(self, x, y): eigvals, eigvecs = np.linalg.eig(PW) self.PW_history_.append(PW) self.PWeigs_history_.append(np.sort(eigvals)) - mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0])) - p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0])) - self.p_history_.append(p) # update objective objective_history.append(self._objective(x, y, coef_sparse, A, PW, k))