From e6aa50a35aff0e30ca3211aaec128c97c2364b58 Mon Sep 17 00:00:00 2001
From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com>
Date: Thu, 9 Nov 2023 12:58:19 +0000
Subject: [PATCH] CLN: Simplify checking PQ/PL tensors

---
 pysindy/optimizers/trapping_sr3.py | 95 ++++++++----------------------
 1 file changed, 25 insertions(+), 70 deletions(-)

diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py
index 6024ba3d7..33eed6461 100644
--- a/pysindy/optimizers/trapping_sr3.py
+++ b/pysindy/optimizers/trapping_sr3.py
@@ -251,76 +251,30 @@ def _set_Ptensors(
         return PL_tensor_unsym, PL_tensor, PQ_tensor
 
-    def _bad_PL(self, PL):
-        """Check if PL tensor is properly defined"""
-        tol = 1e-10
-        return np.any((np.transpose(PL, [1, 0, 2, 3]) - PL) > tol)
-
-    def _bad_PQ(self, PQ):
-        """Check if PQ tensor is properly defined"""
-        tol = 1e-10
-        return np.any((np.transpose(PQ, [0, 2, 1, 3, 4]) - PQ) > tol)
-
-    def _check_P_matrix(self, r, n_features, N):
+    @staticmethod
+    def _check_P_matrix(
+        n_tgts: int, n_feat: int, n_feat_expected: int, PL: np.ndarray, PQ: np.ndarray
+    ) -> Tuple[np.ndarray, np.ndarray]:
         """Check if P tensor is properly defined"""
-        # If these tensors are not passed, or incorrect shape, assume zeros
-        if self.PQ_ is None:
-            self.PQ_ = np.zeros((r, r, r, r, n_features))
-            warnings.warn(
-                "The PQ tensor (a requirement for the stability promotion) was"
-                " not set, so setting this tensor to all zeros. "
-            )
-        elif (self.PQ_).shape != (r, r, r, r, n_features) and (self.PQ_).shape != (
-            r,
-            r,
-            r,
-            r,
-            N,
-        ):
-            self.PQ_ = np.zeros((r, r, r, r, n_features))
-            warnings.warn(
-                "The PQ tensor (a requirement for the stability promotion) was"
-                " initialized with incorrect dimensions, "
-                "so setting this tensor to all zeros "
-                "(with the correct dimensions). "
-            )
-        if self.PL_ is None:
-            self.PL_ = np.zeros((r, r, r, n_features))
-            warnings.warn(
-                "The PL tensor (a requirement for the stability promotion) was"
-                " not set, so setting this tensor to all zeros. "
-            )
-        elif (self.PL_).shape != (r, r, r, n_features) and (self.PL_).shape != (
-            r,
-            r,
-            r,
-            N,
+        if (
+            PQ is None
+            or PL is None
+            or PQ.shape != (n_tgts, n_tgts, n_tgts, n_tgts, n_feat)
+            or PL.shape != (n_tgts, n_tgts, n_tgts, n_feat)
+            or n_feat != n_feat_expected  # library is not quadratic/incorrect shape
         ):
-            self.PL_ = np.zeros((r, r, r, n_features))
+            PL = np.zeros((n_tgts, n_tgts, n_tgts, n_feat))
+            PQ = np.zeros((n_tgts, n_tgts, n_tgts, n_tgts, n_feat))
             warnings.warn(
-                "The PL tensor (a requirement for the stability promotion) was"
-                " initialized with incorrect dimensions, "
-                "so setting this tensor to all zeros "
-                "(with the correct dimensions). "
+                "PQ and PL tensors not defined, wrong shape, or incompatible with "
+                "feature library shape. Ensure feature library is quadratic. "
+                "Setting tensors to zero"
             )
-
-        # Check if the tensor symmetries are properly defined
-        if self._bad_PL(self.PL_):
-            raise ValueError("PL tensor was passed but the symmetries are not correct")
-        if self._bad_PQ(self.PQ_):
-            raise ValueError("PQ tensor was passed but the symmetries are not correct")
-
-        # If PL/PQ finite and correct, so trapping theorem is being used,
-        # then make sure library is quadratic and correct shape
-        if (np.any(self.PL_ != 0.0) or np.any(self.PQ_ != 0.0)) and n_features != N:
-            warnings.warn(
-                "The feature library is the wrong shape or not quadratic, "
-                "so please correct this if you are attempting to use the "
-                "trapping algorithm with the stability term included. Setting "
-                "PL and PQ tensors to zeros for now."
-            )
-            self.PL_ = np.zeros((r, r, r, n_features))
-            self.PQ_ = np.zeros((r, r, r, r, n_features))
+        if not np.allclose(
+            np.transpose(PL, [1, 0, 2, 3]), PL, atol=1e-10
+        ) or not np.allclose(np.transpose(PQ, [0, 2, 1, 3, 4]), PQ, atol=1e-10):
+            raise ValueError("PQ/PL tensors were passed but have the wrong symmetry")
+        return PL, PQ
 
     def _update_coef_constraints(self, H, x_transpose_y, P_transpose_A, coef_sparse):
         """Solves the coefficient update analytically if threshold = 0"""
@@ -504,13 +458,14 @@ def _reduce(self, x, y):
         self.history_ = []
         n_samples, n_features = x.shape
         n_tgts = y.shape[1]
-        N = int((n_tgts**2 + 3 * n_tgts) / 2.0)
+        n_feat_expected = int((n_tgts**2 + 3 * n_tgts) / 2.0)
 
-        # Define PL and PQ tensors, only relevant if the stability term in
-        # trapping SINDy is turned on.
+        # Only relevant if the stability term is turned on.
         self.PL_unsym_, self.PL_, self.PQ_ = self._set_Ptensors(n_tgts)
         # make sure dimensions/symmetries are correct
-        self._check_P_matrix(n_tgts, n_features, N)
+        self.PL_, self.PQ_ = self._check_P_matrix(
+            n_tgts, n_features, n_feat_expected, self.PL_, self.PQ_
+        )
 
         # Set initial coefficients
         if self.use_constraints and self.constraint_order.lower() == "target":