From 60ab8793439cfc0cf3c7c1068889fd9e45c439dc Mon Sep 17 00:00:00 2001 From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 13 May 2024 15:18:28 -0700 Subject: [PATCH] CLN (trapping): No need to check matrix we just created --- pysindy/optimizers/trapping_sr3.py | 110 +---------------------------- 1 file changed, 3 insertions(+), 107 deletions(-) diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py index 650112f8b..1afb4fd1e 100644 --- a/pysindy/optimizers/trapping_sr3.py +++ b/pysindy/optimizers/trapping_sr3.py @@ -453,8 +453,7 @@ def _set_Ptensors( N = self.n_features # If bias term is included, need to shift the tensor index PC_tensor = np.zeros((n_targets, n_targets, N)) - if N > int((n_targets**2 + 3 * n_targets) / 2.0): - self._include_bias = True + if self._include_bias: PC_tensor[range(n_targets), range(n_targets), 0] = 1.0 lib = PolynomialLibrary(2, include_bias=self._include_bias).fit( @@ -470,109 +469,6 @@ def _set_Ptensors( return PC_tensor, PL_tensor_unsym, PL_tensor, PQ_tensor, PT_tensor, PM_tensor - def _bad_PL(self, PL): - """Check if PL tensor is properly defined""" - tol = 1e-10 - return np.any((np.transpose(PL, [1, 0, 2, 3]) - PL) > tol) - - def _bad_PQ(self, PQ): - """Check if PQ tensor is properly defined""" - tol = 1e-10 - return np.any((np.transpose(PQ, [0, 2, 1, 3, 4]) - PQ) > tol) - - def _bad_PT(self, PT): - """Check if PT tensor is properly defined""" - tol = 1e-10 - return np.any((np.transpose(PT, [2, 1, 0, 3, 4]) - PT) > tol) - - def _check_P_matrix(self, r, n_features, N): - """Check if P tensor is properly defined""" - # If these tensors are not passed, or incorrect shape, assume zeros - if self.PQ_ is None: - self.PQ_ = np.zeros((r, r, r, r, n_features)) - warnings.warn( - "The PQ tensor (a requirement for the stability promotion) was" - " not set, so setting this tensor to all zeros. " - ) - elif (self.PQ_).shape != (r, r, r, r, n_features) and (self.PQ_).shape != ( - r, - r, - r, - r, - N, - ): - self.PQ_ = np.zeros((r, r, r, r, n_features)) - warnings.warn( - "The PQ tensor (a requirement for the stability promotion) was" - " initialized with incorrect dimensions, " - "so setting this tensor to all zeros " - "(with the correct dimensions). " - ) - if self.PT_ is None: - self.PT_ = np.zeros((r, r, r, r, n_features)) - warnings.warn( - "The PT tensor (a requirement for the stability promotion) was" - " not set, so setting this tensor to all zeros. " - ) - elif (self.PT_).shape != (r, r, r, r, n_features) and (self.PT_).shape != ( - r, - r, - r, - r, - N, - ): - self.PT_ = np.zeros((r, r, r, r, n_features)) - warnings.warn( - "The PT tensor (a requirement for the stability promotion) was" - " initialized with incorrect dimensions, " - "so setting this tensor to all zeros " - "(with the correct dimensions). " - ) - if self.PL_ is None: - self.PL_ = np.zeros((r, r, r, n_features)) - warnings.warn( - "The PL tensor (a requirement for the stability promotion) was" - " not set, so setting this tensor to all zeros. " - ) - elif (self.PL_).shape != (r, r, r, n_features) and (self.PL_).shape != ( - r, - r, - r, - N, - ): - self.PL_ = np.zeros((r, r, r, n_features)) - warnings.warn( - "The PL tensor (a requirement for the stability promotion) was" - " initialized with incorrect dimensions, " - "so setting this tensor to all zeros " - "(with the correct dimensions). " - ) - - # Check if the tensor symmetries are properly defined - if self._bad_PL(self.PL_): - raise ValueError("PL tensor was passed but the symmetries are not correct") - if self._bad_PQ(self.PQ_): - raise ValueError("PQ tensor was passed but the symmetries are not correct") - if self._bad_PT(self.PT_): - raise ValueError("PT tensor was passed but the symmetries are not correct") - - # If PL/PQ/PT finite and correct, so trapping theorem is being used, - # then make sure library is quadratic and correct shape - if ( - np.any(self.PL_ != 0.0) - or np.any(self.PQ_ != 0.0) - or np.any(self.PT_ != 0.0) - ) and n_features != N: - print( - "The feature library is the wrong shape or not quadratic, " - "so please correct this if you are attempting to use the " - "trapping algorithm with the stability term included. Setting " - "PL and PQ tensors to zeros for now." - ) - self.PL_ = np.zeros((r, r, r, n_features)) - self.PQ_ = np.zeros((r, r, r, r, n_features)) - self.PT_ = np.zeros((r, r, r, r, n_features)) - def _update_coef_constraints(self, H, x_transpose_y, P_transpose_A, coef_sparse): """Solves the coefficient update analytically if threshold = 0""" g = x_transpose_y + P_transpose_A / self.eta @@ -786,6 +682,8 @@ def _reduce(self, x, y): self.n_features = n_features r = y.shape[1] N = n_features # int((r ** 2 + 3 * r) / 2.0) + if N > int((r**2 + 3 * r) / 2.0): + self._include_bias = True if self.mod_matrix is None: self.mod_matrix = np.eye(r) @@ -800,8 +698,6 @@ def _reduce(self, x, y): self.PT_, self.PM_, ) = self._set_Ptensors(r) - # make sure dimensions/symmetries are correct - self._check_P_matrix(r, n_features, N) # Set initial coefficients if self.use_constraints and self.constraint_order.lower() == "target":