From e6013cdccdbc66d89f6260fd59be71949c9005bb Mon Sep 17 00:00:00 2001
From: Alan Kaptanoglu
Date: Fri, 28 Jun 2024 10:50:34 -0400
Subject: [PATCH] Removed the if statements between the global and local
 formulations so that both now use the same code for computing PW. Removed
 the option for accelerated prox-gradient descent, since it did not seem to
 work very well anyway. Updated the trapping paper examples script.
 Everything now looks good except that the constraint matrix is not quite
 right for the MHD case.

---
 .../trapping_sindy_paper_examples.py |  3 +-
 .../trapping_utils.py                |  5 --
 pysindy/optimizers/trapping_sr3.py   | 75 ++++---------
 3 files changed, 17 insertions(+), 66 deletions(-)

diff --git a/examples/8_trapping_sindy_examples/trapping_sindy_paper_examples.py b/examples/8_trapping_sindy_examples/trapping_sindy_paper_examples.py
index 649c3940e..54a926c13 100644
--- a/examples/8_trapping_sindy_examples/trapping_sindy_paper_examples.py
+++ b/examples/8_trapping_sindy_examples/trapping_sindy_paper_examples.py
@@ -376,7 +376,6 @@
 eta = 1.0e3
 alpha_m = 8e-1 * eta  # default is 1e-2 * eta so this speeds up the code here
-accel = False

 # run trapping SINDy
 sindy_opt = ps.TrappingSR3(
@@ -385,7 +384,6 @@
     threshold=threshold,
     eta=eta,
     alpha_m=alpha_m,
-    accel=accel,
     max_iter=max_iter,
     gamma=-1,
     verbose=True,
@@ -547,6 +545,7 @@
 L = np.tensordot(PL_tensor, Xi, axes=([3, 2], [0, 1]))
 Q = np.tensordot(PQ_tensor, Xi, axes=([4, 3], [0, 1]))
 Q_sum = np.max(np.abs((Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1]))))
+print((Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1])))
 print("Max deviation from the constraints = ", Q_sum)

 # plotting and analysis
diff --git a/examples/8_trapping_sindy_examples/trapping_utils.py b/examples/8_trapping_sindy_examples/trapping_utils.py
index 3ff39f238..8ad9c8d2e 100644
--- a/examples/8_trapping_sindy_examples/trapping_utils.py
+++ b/examples/8_trapping_sindy_examples/trapping_utils.py
@@ -268,7 +268,6 @@ def make_3d_plots(x_test, x_test_pred, filename):
     ax.set_zticklabels([])
     ax.set_axis_off()
     plt.legend(fontsize=14)
-    plt.show()


 # Plot the SINDy fits of X and Xdot against the ground truth
@@ -295,8 +294,6 @@ def make_fits(r, t, xdot_test, xdot_test_pred, x_test, x_test_pred, filename):
             if i == r - 1:
                 plt.xlabel("t", fontsize=18)
-    plt.show()
-

 # Plot errors between m_{k+1} and m_k and similarly for the model coefficients
 def make_progress_plots(r, sindy_opt):
@@ -653,7 +650,6 @@ def trapping_region(r, x_test_pred, Xi, sindy_opt, filename):
     ax.set_yticklabels([])
     ax.set_zticklabels([])
     ax.set_axis_off()
-    plt.show()


 # Make Lissajou figures with ground truth and SINDy model
@@ -683,4 +679,3 @@ def make_lissajou(r, x_train, x_test, x_train_pred, x_test_pred, filename):
             plt.ylabel(r"$x_" + str(i) + r"$", fontsize=18)
             if i == r - 1:
                 plt.xlabel(r"$x_" + str(j) + r"$", fontsize=18)
-    plt.show()
diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py
index e19c4f1e0..48c99edf7 100644
--- a/pysindy/optimizers/trapping_sr3.py
+++ b/pysindy/optimizers/trapping_sr3.py
@@ -116,10 +116,6 @@ class TrappingSR3(ConstrainedSR3):
         could be straightforwardly implemented, but L0 requires
         reformulation because of nonconvexity. (default 'L1')

-    accel :
-        Whether or not to use accelerated prox-gradient descent for (m, A).
-        (default False)
-
     m0 :
         Initial guess for trap center in the optimization. Default None
         initializes vector elements randomly in [-1, 1]. shape (n_targets)
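The print statement added to the examples script above surfaces the quantity behind "Max deviation from the constraints": for an energy-preserving quadratic nonlinearity, the identified quadratic tensor Q should have a vanishing totally symmetric part. A minimal standalone sketch of that check, using a random stand-in for Q (the shapes and names here are illustrative, not pysindy output):

    import numpy as np

    r = 3  # number of state variables (illustrative)
    rng = np.random.default_rng(0)
    Q = rng.standard_normal((r, r, r))  # stand-in for the identified quadratic tensor

    # Trapping constraint: Q[i, j, k] + Q[j, k, i] + Q[k, i, j] == 0 for all i, j, k
    Q_sym = Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1])
    print("Max deviation from the constraints =", np.max(np.abs(Q_sym)))

A model identified under exact constraints would print a value near zero; the nonzero residual in the MHD case is presumably the constraint-matrix issue flagged in the commit message.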
@@ -204,7 +200,6 @@ def __init__(
         gamma: float = -0.1,
         tol_m: float = 1e-5,
         thresholder: str = "l1",
-        accel: bool = False,
         m0: Union[NDArray, None] = None,
         A0: Union[NDArray, None] = None,
         **kwargs,
@@ -272,7 +267,6 @@ def __init__(
         self.beta = beta
         self.gamma = gamma
         self.tol_m = tol_m
-        self.accel = accel
         self.__post_init_guard()

     def __post_init_guard(self):
@@ -584,64 +578,34 @@ def _update_coef_nonsparse_rs(
     def _solve_m_relax_and_split(
         self,
-        trap_ctr_prev: Float1D,
         trap_ctr: Float1D,
         A: Float2D,
         coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat],
-        tk_previous: float,
-    ) -> tuple[Float1D, Float2D, float]:
+    ) -> tuple[Float1D, Float2D]:
        """Solves the (m, A) algorithm update.

         TODO: explain the optimization this solves, add good names to variables,
         and refactor/indirect the if global/local trapping conditionals

-        Returns the new trap center (m), the new A, and the new acceleration weight
+        Returns the new trap center (m) and the new A
         """
-        # prox-grad for (A, m)
-        # Accelerated prox gradient descent
+        # prox-gradient descent for (A, m)
         # Calculate projection matrix from Quad terms to As
-        if self.accel:
-            tk = (1 + np.sqrt(1 + 4 * tk_previous**2)) / 2.0
-            m_partial = trap_ctr + (tk_previous - 1.0) / tk * (trap_ctr - trap_ctr_prev)
-            tk_previous = tk
-            if self.method == "global":
-                mPQ = np.tensordot(m_partial, self.PQ_, axes=([0], [0]))
-            else:
-                mPM = np.tensordot(self.PM_, m_partial, axes=([2], [0]))
-        else:
-            if self.method == "global":
-                mPQ = np.tensordot(trap_ctr, self.PQ_, axes=([0], [0]))
-            else:
-                mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
+        mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
+        p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
+        PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))

         # Calculate As and its quad term components
-        if self.method == "global":
-            p = self.PL_ - mPQ
-            PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
-            PQW = np.tensordot(self.PQ_, coef_sparse, axes=([4, 3], [0, 1]))
-        else:
-            p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
-            PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
-            PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1]))
-            PMW = np.tensordot(self.mod_matrix, PMW, axes=([1], [0]))
+        PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1]))
+        PMW = np.tensordot(self.mod_matrix, PMW, axes=([1], [0]))

         # Calculate error in quadratic balance, and adjust trap center
         A_b = (A - PW) / self.eta
-        if self.method == "global":
-            # PQWT_PW is gradient of some loss in m
-            PQWT_PW = np.tensordot(PQW, A_b, axes=([2, 1], [0, 1]))
-            if self.accel:
-                trap_new = m_partial - self.alpha_m * PQWT_PW
-            else:
-                trap_new = trap_ctr - self.alpha_m * PQWT_PW
-        else:
-            PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))
-            if self.accel:
-                trap_new = m_partial - self.alpha_m * PMT_PW
-            else:
-                trap_new = trap_ctr - self.alpha_m * PMT_PW
+        # PMT_PW is the gradient of the loss with respect to m
+        PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))
+        trap_new = trap_ctr - self.alpha_m * PMT_PW

         # Update A
         A_new = self._update_A(A - self.alpha_A * A_b, PW)
-        return trap_new, A_new, tk_previous
+        return trap_new, A_new

     def _solve_nonsparse_relax_and_split(self, H, xTy, P_transpose_A, coef_prev):
         """Update for the coefficients if threshold = 0."""
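For reference, the removed `accel` branch implemented the standard Nesterov/FISTA extrapolation before taking the gradient step on m. A self-contained sketch of that momentum update, reconstructed from the deleted lines (the function name and toy usage are illustrative only, not pysindy API):

    import numpy as np

    def fista_extrapolate(m, m_prev, tk_prev):
        # Momentum weight and extrapolated iterate, as in the deleted branch:
        # tk = (1 + sqrt(1 + 4 * tk_prev^2)) / 2
        tk = (1 + np.sqrt(1 + 4 * tk_prev**2)) / 2.0
        m_partial = m + (tk_prev - 1.0) / tk * (m - m_prev)
        return m_partial, tk

    # Toy usage: with tk_prev = 1 (the initial value used in _reduce), the
    # momentum coefficient (tk_prev - 1) / tk is zero, so the first step is
    # identical to plain gradient descent.
    m_partial, tk = fista_extrapolate(np.ones(3), np.zeros(3), tk_prev=1.0)
    print(m_partial, tk)

Dropping this leaves plain (unaccelerated) proximal-gradient descent on (m, A), consistent with the commit message's observation that acceleration did not seem to work very well in practice.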
@@ -729,20 +693,15 @@ def _reduce(self, x, y):
         xTx = np.dot(x_expanded.T, x_expanded)
         xTy = np.dot(x_expanded.T, y.flatten())

-        # if using acceleration
-        tk_prev = 1
+        # keep track of the last solution, in case the method fails
         trap_prev_ctr = trap_ctr

         # Begin optimization loop
         objective_history = []
         for k in range(self.max_iter):
             # update P tensor from the newest trap center
-            if self.method == "global":
-                mPQ = np.tensordot(trap_ctr, self.PQ_, axes=([0], [0]))
-                p = self.PL_ - mPQ
-            else:
-                mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
-                p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
+            mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
+            p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
             Pmatrix = p.reshape(n_tgts * n_tgts, n_tgts * n_features)
             self.p_history_.append(p)
@@ -761,9 +720,7 @@
                 coef_sparse = coef_prev
                 break

-            trap_ctr, A, tk_prev = self._solve_m_relax_and_split(
-                trap_prev_ctr, trap_ctr, A, coef_sparse, tk_prev
-            )
+            trap_ctr, A = self._solve_m_relax_and_split(trap_ctr, A, coef_sparse)

             # If problem over m becomes infeasible, break out of the loop
             if trap_ctr is None:
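To make the now-unconditional update path concrete, here is a self-contained sketch of one (m, A) step using random stand-in tensors whose shapes are inferred from the tensordot contractions above (all names, shapes, and parameter values are illustrative, not the actual TrappingSR3 attributes):

    import numpy as np

    r, N = 3, 9  # n_targets, n_features (illustrative)
    rng = np.random.default_rng(1)
    PL = rng.standard_normal((r, r, r, N))     # stand-in linear-term projector
    PM = rng.standard_normal((r, r, r, r, N))  # stand-in trap-center projector
    mod_matrix = np.eye(r)
    coef_sparse = rng.standard_normal((N, r))
    trap_ctr = rng.standard_normal(r)
    A = -np.eye(r)
    eta, alpha_m = 1.0e3, 1.0

    # Fold the trap center into PM, add the linear part, and contract against
    # the coefficients to get PW, the current projected estimate of A
    mPM = np.tensordot(PM, trap_ctr, axes=([2], [0]))
    p = np.tensordot(mod_matrix, PL + mPM, axes=([1], [0]))
    PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))

    # Plain gradient step on m: A_b is the scaled mismatch between A and PW,
    # and PMT_PW is the resulting gradient with respect to m
    PMW = np.tensordot(PM, coef_sparse, axes=([4, 3], [0, 1]))
    PMW = np.tensordot(mod_matrix, PMW, axes=([1], [0]))
    A_b = (A - PW) / eta
    PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))
    trap_new = trap_ctr - alpha_m * PMT_PW
    print(trap_new)  # shape (r,)

Since `method == "global"` no longer branches, the global case must be reproduced through `mod_matrix` and `PM_` in this single path; per the commit message, that currently holds everywhere except for the MHD constraint matrix.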