diff --git a/examples/8_trapping_sindy_examples/trapping_sindy_paper_examples.py b/examples/8_trapping_sindy_examples/trapping_sindy_paper_examples.py
index 649c3940..54a926c1 100644
--- a/examples/8_trapping_sindy_examples/trapping_sindy_paper_examples.py
+++ b/examples/8_trapping_sindy_examples/trapping_sindy_paper_examples.py
@@ -376,7 +376,6 @@
 eta = 1.0e3
 alpha_m = 8e-1 * eta  # default is 1e-2 * eta so this speeds up the code here
-accel = False
 
 # run trapping SINDy
 sindy_opt = ps.TrappingSR3(
@@ -385,7 +384,6 @@
     threshold=threshold,
     eta=eta,
     alpha_m=alpha_m,
-    accel=accel,
     max_iter=max_iter,
     gamma=-1,
     verbose=True,
@@ -547,6 +545,7 @@
 L = np.tensordot(PL_tensor, Xi, axes=([3, 2], [0, 1]))
 Q = np.tensordot(PQ_tensor, Xi, axes=([4, 3], [0, 1]))
 Q_sum = np.max(np.abs((Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1]))))
+print((Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1])))
 print("Max deviation from the constraints = ", Q_sum)
 
 # plotting and analysis
diff --git a/examples/8_trapping_sindy_examples/trapping_utils.py b/examples/8_trapping_sindy_examples/trapping_utils.py
index 3ff39f23..8ad9c8d2 100644
--- a/examples/8_trapping_sindy_examples/trapping_utils.py
+++ b/examples/8_trapping_sindy_examples/trapping_utils.py
@@ -268,7 +268,6 @@ def make_3d_plots(x_test, x_test_pred, filename):
     ax.set_zticklabels([])
     ax.set_axis_off()
     plt.legend(fontsize=14)
-    plt.show()
 
 
 # Plot the SINDy fits of X and Xdot against the ground truth
@@ -295,8 +294,6 @@ def make_fits(r, t, xdot_test, xdot_test_pred, x_test, x_test_pred, filename):
         if i == r - 1:
             plt.xlabel("t", fontsize=18)
 
-    plt.show()
-
 
 # Plot errors between m_{k+1} and m_k and similarly for the model coefficients
 def make_progress_plots(r, sindy_opt):
@@ -653,7 +650,6 @@ def trapping_region(r, x_test_pred, Xi, sindy_opt, filename):
     ax.set_yticklabels([])
     ax.set_zticklabels([])
     ax.set_axis_off()
-    plt.show()
 
 
 # Make Lissajou figures with ground truth and SINDy model
@@ -683,4 +679,3 @@ def make_lissajou(r, x_train, x_test, x_train_pred, x_test_pred, filename):
            plt.ylabel(r"$x_" + str(i) + r"$", fontsize=18)
            if i == r - 1:
                plt.xlabel(r"$x_" + str(j) + r"$", fontsize=18)
-    plt.show()
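Note: with plt.show() removed from these helpers, the plotting utilities no longer block script execution, and rendering becomes the caller's responsibility. A minimal caller-side sketch, assuming make_3d_plots keeps the signature shown above (the random stand-in trajectories are illustrative, not part of this patch):

    import matplotlib.pyplot as plt
    import numpy as np
    from trapping_utils import make_3d_plots

    # stand-in trajectories of shape (n_samples, 3); the real example
    # passes test data and the corresponding SINDy model predictions
    rng = np.random.default_rng(0)
    x_test = rng.normal(size=(100, 3))
    x_test_pred = x_test + 0.1 * rng.normal(size=(100, 3))

    make_3d_plots(x_test, x_test_pred, "trapping_3d")
    plt.show()  # display is now opt-in at the call site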
diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py
index e19c4f1e..48c99edf 100644
--- a/pysindy/optimizers/trapping_sr3.py
+++ b/pysindy/optimizers/trapping_sr3.py
@@ -116,10 +116,6 @@ class TrappingSR3(ConstrainedSR3):
         could be straightforwardly implemented, but L0 requires
         reformulation because of nonconvexity. (default 'L1')
 
-    accel :
-        Whether or not to use accelerated prox-gradient descent for (m, A).
-        (default False)
-
     m0 :
         Initial guess for trap center in the optimization. Default None
         initializes vector elements randomly in [-1, 1]. shape (n_targets)
@@ -204,7 +200,6 @@ def __init__(
         gamma: float = -0.1,
         tol_m: float = 1e-5,
         thresholder: str = "l1",
-        accel: bool = False,
         m0: Union[NDArray, None] = None,
         A0: Union[NDArray, None] = None,
         **kwargs,
@@ -272,7 +267,6 @@
         self.beta = beta
         self.gamma = gamma
         self.tol_m = tol_m
-        self.accel = accel
         self.__post_init_guard()
 
     def __post_init_guard(self):
@@ -584,64 +578,34 @@ def _update_coef_nonsparse_rs(
     def _solve_m_relax_and_split(
         self,
-        trap_ctr_prev: Float1D,
         trap_ctr: Float1D,
         A: Float2D,
         coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat],
-        tk_previous: float,
-    ) -> tuple[Float1D, Float2D, float]:
+    ) -> tuple[Float1D, Float2D]:
         """Solves the (m, A) algorithm update.
 
         TODO: explain the optimization this solves, add good names to
         variables, and refactor/indirect the if global/local trapping
         conditionals
 
-        Returns the new trap center (m), the new A, and the new acceleration weight
+        Returns the new trap center (m) and the new A
         """
-        # prox-grad for (A, m)
-        # Accelerated prox gradient descent
+        # prox-gradient descent for (A, m)
         # Calculate projection matrix from Quad terms to As
-        if self.accel:
-            tk = (1 + np.sqrt(1 + 4 * tk_previous**2)) / 2.0
-            m_partial = trap_ctr + (tk_previous - 1.0) / tk * (trap_ctr - trap_ctr_prev)
-            tk_previous = tk
-            if self.method == "global":
-                mPQ = np.tensordot(m_partial, self.PQ_, axes=([0], [0]))
-            else:
-                mPM = np.tensordot(self.PM_, m_partial, axes=([2], [0]))
-        else:
-            if self.method == "global":
-                mPQ = np.tensordot(trap_ctr, self.PQ_, axes=([0], [0]))
-            else:
-                mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
+        mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
+        p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
+        PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
 
         # Calculate As and its quad term components
-        if self.method == "global":
-            p = self.PL_ - mPQ
-            PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
-            PQW = np.tensordot(self.PQ_, coef_sparse, axes=([4, 3], [0, 1]))
-        else:
-            p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
-            PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
-            PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1]))
-            PMW = np.tensordot(self.mod_matrix, PMW, axes=([1], [0]))
+        PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1]))
+        PMW = np.tensordot(self.mod_matrix, PMW, axes=([1], [0]))
 
         # Calculate error in quadratic balance, and adjust trap center
         A_b = (A - PW) / self.eta
-        if self.method == "global":
-            # PQWT_PW is gradient of some loss in m
-            PQWT_PW = np.tensordot(PQW, A_b, axes=([2, 1], [0, 1]))
-            if self.accel:
-                trap_new = m_partial - self.alpha_m * PQWT_PW
-            else:
-                trap_new = trap_ctr - self.alpha_m * PQWT_PW
-        else:
-            PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))
-            if self.accel:
-                trap_new = m_partial - self.alpha_m * PMT_PW
-            else:
-                trap_new = trap_ctr - self.alpha_m * PMT_PW
+        # PMT_PW is the gradient with respect to m of the quadratic-balance loss
+        PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))
+        trap_new = trap_ctr - self.alpha_m * PMT_PW
 
         # Update A
         A_new = self._update_A(A - self.alpha_A * A_b, PW)
-        return trap_new, A_new, tk_previous
+        return trap_new, A_new
 
     def _solve_nonsparse_relax_and_split(self, H, xTy, P_transpose_A, coef_prev):
         """Update for the coefficients if threshold = 0."""
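Note: after this simplification, each call performs one plain gradient step in the trap center m against the quadratic-balance error A_b = (A - PW) / eta, followed by a relaxation step in A. A self-contained NumPy sketch of the same tensor contractions, with array shapes inferred from the tensordot axes above (the random projectors and the names PL, PM, mod_matrix, W are stand-ins for the optimizer's PL_, PM_, mod_matrix, and coef_sparse, not real PySINDy data):

    import numpy as np

    n_tgts, n_feat = 3, 9  # e.g. 3 state variables, 9 library features
    rng = np.random.default_rng(0)

    # stand-ins for the optimizer's projection tensors and state
    PL = rng.normal(size=(n_tgts, n_tgts, n_tgts, n_feat))
    PM = rng.normal(size=(n_tgts, n_tgts, n_tgts, n_tgts, n_feat))
    mod_matrix = np.eye(n_tgts)
    W = rng.normal(size=(n_feat, n_tgts))  # coef_sparse
    m = rng.normal(size=n_tgts)            # trap center
    A = -np.eye(n_tgts)                    # auxiliary negative-definite matrix
    eta, alpha_m = 1.0e3, 10.0

    # shift the linear projector by the trap center and contract with W
    mPM = np.tensordot(PM, m, axes=([2], [0]))               # (r, r, r, N)
    p = np.tensordot(mod_matrix, PL + mPM, axes=([1], [0]))
    PW = np.tensordot(p, W, axes=([3, 2], [0, 1]))           # (r, r) matrix A relaxes toward

    # gradient step in m on the quadratic-balance error
    PMW = np.tensordot(PM, W, axes=([4, 3], [0, 1]))
    PMW = np.tensordot(mod_matrix, PMW, axes=([1], [0]))
    A_b = (A - PW) / eta
    grad_m = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))   # shape (r,)
    m_new = m - alpha_m * grad_m
    # A is then relaxed toward PW and projected back to the stable set
    # via the optimizer's _update_A (not reproduced here)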
@@ -729,20 +693,15 @@ def _reduce(self, x, y):
         xTx = np.dot(x_expanded.T, x_expanded)
         xTy = np.dot(x_expanded.T, y.flatten())
 
-        # if using acceleration
-        tk_prev = 1
+        # keep track of the last solution in case the method fails
         trap_prev_ctr = trap_ctr
 
         # Begin optimization loop
         objective_history = []
         for k in range(self.max_iter):
             # update P tensor from the newest trap center
-            if self.method == "global":
-                mPQ = np.tensordot(trap_ctr, self.PQ_, axes=([0], [0]))
-                p = self.PL_ - mPQ
-            else:
-                mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
-                p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
+            mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
+            p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
             Pmatrix = p.reshape(n_tgts * n_tgts, n_tgts * n_features)
             self.p_history_.append(p)
@@ -761,9 +720,7 @@
                 coef_sparse = coef_prev
                 break
 
-            trap_ctr, A, tk_prev = self._solve_m_relax_and_split(
-                trap_prev_ctr, trap_ctr, A, coef_sparse, tk_prev
-            )
+            trap_ctr, A = self._solve_m_relax_and_split(trap_ctr, A, coef_sparse)
 
             # If problem over m becomes infeasible, break out of the loop
             if trap_ctr is None:
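Note: downstream code constructing TrappingSR3 must drop the accel keyword; the first hunk of the examples diff shows the updated call. A minimal sketch of the post-change constructor, assuming only the parameters used in that example (the threshold and max_iter values are illustrative placeholders, not values from this patch):

    import pysindy as ps

    eta = 1.0e3
    sindy_opt = ps.TrappingSR3(
        threshold=0.1,       # illustrative
        eta=eta,
        alpha_m=8e-1 * eta,  # larger than the 1e-2 * eta default, as in the example
        max_iter=1000,       # illustrative
        gamma=-1,
        verbose=True,
    )
    # accel=... is gone: the optimizer now always takes plain
    # (unaccelerated) prox-gradient steps for (m, A)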