diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py index c25803c7c..dee8937ad 100644 --- a/pysindy/optimizers/trapping_sr3.py +++ b/pysindy/optimizers/trapping_sr3.py @@ -1,4 +1,7 @@ import warnings +from itertools import combinations_with_replacement as combo_wr +from itertools import product +from typing import Tuple import cvxpy as cp import numpy as np @@ -223,41 +226,28 @@ def __init__( self.accel = accel self.objective_history = objective_history - def _set_Ptensors(self, r): + def _set_Ptensors(self, n_targets: int) -> Tuple[np.ndarray, np.ndarray]: """Make the projection tensors used for the algorithm.""" - N = int((r**2 + 3 * r) / 2.0) + N = int((n_targets**2 + 3 * n_targets) / 2.0) # delta_{il}delta_{jk} - PL_tensor = np.zeros((r, r, r, N)) - PL_tensor_unsym = np.zeros((r, r, r, N)) - for i in range(r): - for j in range(r): - for k in range(r): - for kk in range(N): - if i == k and j == kk: - PL_tensor_unsym[i, j, k, kk] = 1.0 + PL_tensor_unsym = np.zeros((n_targets, n_targets, n_targets, N)) + for i, j in combo_wr(range(n_targets), 2): + PL_tensor_unsym[i, j, i, j] = 1.0 # Now symmetrize PL - for i in range(r): - for j in range(N): - PL_tensor[:, :, i, j] = 0.5 * ( - PL_tensor_unsym[:, :, i, j] + PL_tensor_unsym[:, :, i, j].T - ) + PL_tensor = (PL_tensor_unsym + np.transpose(PL_tensor_unsym, [1, 0, 2, 3])) / 2 # if j == k, delta_{il}delta_{N-r+j,n} # if j != k, delta_{il}delta_{r+j+k-1,n} - PQ_tensor = np.zeros((r, r, r, r, N)) - for i in range(r): - for j in range(r): - for k in range(r): - for kk in range(r): - for n in range(N): - if (j == k) and (n == N - r + j) and (i == kk): - PQ_tensor[i, j, k, kk, n] = 1.0 - if (j != k) and (n == r + j + k - 1) and (i == kk): - PQ_tensor[i, j, k, kk, n] = 1 / 2 + PQ_tensor = np.zeros((n_targets, n_targets, n_targets, n_targets, N)) + for (i, j, k, kk), n in product(combo_wr(range(n_targets), 4), range(N)): + if (j == k) and (n == N - n_targets + j) and (i == kk): + PQ_tensor[i, j, k, kk, n] = 1.0 + if (j != k) and (n == n_targets + j + k - 1) and (i == kk): + PQ_tensor[i, j, k, kk, n] = 1 / 2 - return PL_tensor_unsym, PL_tensor, PQ_tensor + return PL_tensor, PQ_tensor def _bad_PL(self, PL): """Check if PL tensor is properly defined""" @@ -511,14 +501,14 @@ def _reduce(self, x, y): self.PWeigs_history_ = [] self.history_ = [] n_samples, n_features = x.shape - r = y.shape[1] - N = int((r**2 + 3 * r) / 2.0) + n_tgts = y.shape[1] + N = int((n_tgts**2 + 3 * n_tgts) / 2.0) # Define PL and PQ tensors, only relevant if the stability term in # trapping SINDy is turned on. - self.PL_unsym_, self.PL_, self.PQ_ = self._set_Ptensors(r) + self.PL_, self.PQ_ = self._set_Ptensors(n_tgts) # make sure dimensions/symmetries are correct - self._check_P_matrix(r, n_features, N) + self._check_P_matrix(n_tgts, n_features, N) # Set initial coefficients if self.use_constraints and self.constraint_order.lower() == "target": @@ -544,9 +534,9 @@ def _reduce(self, x, y): if self.A0 is not None: A = self.A0 elif np.any(self.PQ_ != 0.0): - A = np.diag(self.gamma * np.ones(r)) + A = np.diag(self.gamma * np.ones(n_tgts)) else: - A = np.diag(np.zeros(r)) + A = np.diag(np.zeros(n_tgts)) self.A_history_.append(A) # initial guess for m @@ -554,14 +544,14 @@ def _reduce(self, x, y): m = self.m0 else: np.random.seed(1) - m = (np.random.rand(r) - np.ones(r)) * 2 + m = (np.random.rand(n_tgts) - np.ones(n_tgts)) * 2 self.m_history_.append(m) # Precompute some objects for optimization - x_expanded = np.zeros((n_samples, r, n_features, r)) - for i in range(r): + x_expanded = np.zeros((n_samples, n_tgts, n_features, n_tgts)) + for i in range(n_tgts): x_expanded[:, i, :, i] = x - x_expanded = np.reshape(x_expanded, (n_samples * r, r * n_features)) + x_expanded = np.reshape(x_expanded, (n_samples * n_tgts, n_tgts * n_features)) xTx = np.dot(x_expanded.T, x_expanded) xTy = np.dot(x_expanded.T, y.flatten()) @@ -576,7 +566,7 @@ def _reduce(self, x, y): # update P tensor from the newest m mPQ = np.tensordot(m, self.PQ_, axes=([0], [0])) p = self.PL_ - mPQ - Pmatrix = p.reshape(r * r, r * n_features) + Pmatrix = p.reshape(n_tgts * n_tgts, n_tgts * n_features) # update w coef_prev = coef_sparse @@ -584,14 +574,14 @@ def _reduce(self, x, y): if self.relax_optim: if self.threshold > 0.0: xi, cost = self._create_var_and_part_cost( - n_features * r, x_expanded, y + n_features * n_tgts, x_expanded, y ) cost = ( cost + cp.sum_squares(Pmatrix @ xi - A.flatten()) / self.eta ) # sparse relax_and_split coef_sparse = self._update_coef_cvxpy( - xi, cost, r * n_features, coef_prev, self.eps_solver + xi, cost, n_tgts * n_features, coef_prev, self.eps_solver ) else: pTp = np.dot(Pmatrix.T, Pmatrix) @@ -602,7 +592,7 @@ def _reduce(self, x, y): ) else: m, coef_sparse = self._solve_direct_cvxpy( - r, n_features, x_expanded, y, Pmatrix, coef_prev + n_tgts, n_features, x_expanded, y, Pmatrix, coef_prev ) # If problem over xi becomes infeasible, break out of the loop @@ -612,7 +602,7 @@ def _reduce(self, x, y): if self.relax_optim: m_prev, m, A, tk_prev = self._solve_m_relax_and_split( - r, n_features, m_prev, m, A, coef_sparse, tk_prev + n_tgts, n_features, m_prev, m, A, coef_sparse, tk_prev ) # If problem over m becomes infeasible, break out of the loop diff --git a/test/test_optimizers.py b/test/test_optimizers.py index e4d13c4fa..f208c7a6e 100644 --- a/test/test_optimizers.py +++ b/test/test_optimizers.py @@ -483,7 +483,6 @@ def test_trapping_sr3_quadratic_library(params, trapping_sr3_params, quadratic_l opt = TrappingSR3(**params) opt.fit(features, x_dot) - assert opt.PL_unsym_.shape == (1, 1, 1, 2) assert opt.PL_.shape == (1, 1, 1, 2) assert opt.PQ_.shape == (1, 1, 1, 1, 2) check_is_fitted(opt) @@ -497,7 +496,6 @@ def test_trapping_sr3_quadratic_library(params, trapping_sr3_params, quadratic_l opt = TrappingSR3(**params) opt.fit(features, x_dot) - assert opt.PL_unsym_.shape == (1, 1, 1, 2) assert opt.PL_.shape == (1, 1, 1, 2) assert opt.PQ_.shape == (1, 1, 1, 1, 2) check_is_fitted(opt) diff --git a/test/test_optimizers_complexity.py b/test/test_optimizers_complexity.py index 8a6486d83..f027f12f5 100644 --- a/test/test_optimizers_complexity.py +++ b/test/test_optimizers_complexity.py @@ -45,7 +45,7 @@ def test_complexity_parameter( optimizers = [ WrappedOptimizer(opt_cls(**{reg_name: reg_value}), normalize_columns=True) - for reg_value in [10, 1, 0.1, 0.01] + for reg_value in [10, 1, 0.1, 0.001] ] for opt in optimizers: