Skip to content

Commit

Permalink
CLN: Simplify checking PQ/PL tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Nov 9, 2023
1 parent 6af6e4c commit e6aa50a
Showing 1 changed file with 25 additions and 70 deletions.
95 changes: 25 additions & 70 deletions pysindy/optimizers/trapping_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,76 +251,30 @@ def _set_Ptensors(

return PL_tensor_unsym, PL_tensor, PQ_tensor

def _bad_PL(self, PL):
"""Check if PL tensor is properly defined"""
tol = 1e-10
return np.any((np.transpose(PL, [1, 0, 2, 3]) - PL) > tol)

def _bad_PQ(self, PQ):
"""Check if PQ tensor is properly defined"""
tol = 1e-10
return np.any((np.transpose(PQ, [0, 2, 1, 3, 4]) - PQ) > tol)

def _check_P_matrix(self, r, n_features, N):
@staticmethod
def _check_P_matrix(
n_tgts: int, n_feat: int, n_feat_expected: int, PL: np.ndarray, PQ: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""Check if P tensor is properly defined"""
# If these tensors are not passed, or incorrect shape, assume zeros
if self.PQ_ is None:
self.PQ_ = np.zeros((r, r, r, r, n_features))
warnings.warn(
"The PQ tensor (a requirement for the stability promotion) was"
" not set, so setting this tensor to all zeros. "
)
elif (self.PQ_).shape != (r, r, r, r, n_features) and (self.PQ_).shape != (
r,
r,
r,
r,
N,
):
self.PQ_ = np.zeros((r, r, r, r, n_features))
warnings.warn(
"The PQ tensor (a requirement for the stability promotion) was"
" initialized with incorrect dimensions, "
"so setting this tensor to all zeros "
"(with the correct dimensions). "
)
if self.PL_ is None:
self.PL_ = np.zeros((r, r, r, n_features))
warnings.warn(
"The PL tensor (a requirement for the stability promotion) was"
" not set, so setting this tensor to all zeros. "
)
elif (self.PL_).shape != (r, r, r, n_features) and (self.PL_).shape != (
r,
r,
r,
N,
if (
PQ is None
or PL is None
or PQ.shape != (n_tgts, n_tgts, n_tgts, n_tgts, n_feat)
or PL.shape != (n_tgts, n_tgts, n_tgts, n_feat)
or n_feat != n_feat_expected # library is not quadratic/incorrect shape
):
self.PL_ = np.zeros((r, r, r, n_features))
PL = np.zeros((n_tgts, n_tgts, n_tgts, n_feat))
PQ = np.zeros((n_tgts, n_tgts, n_tgts, n_tgts, n_feat))
warnings.warn(
"The PL tensor (a requirement for the stability promotion) was"
" initialized with incorrect dimensions, "
"so setting this tensor to all zeros "
"(with the correct dimensions). "
"PQ and PL tensors not defined, wrong shape, or incompatible with "
"feature library shape. Ensure feature library is quadratic. "
"Setting tensors to zero"
)

# Check if the tensor symmetries are properly defined
if self._bad_PL(self.PL_):
raise ValueError("PL tensor was passed but the symmetries are not correct")
if self._bad_PQ(self.PQ_):
raise ValueError("PQ tensor was passed but the symmetries are not correct")

# If PL/PQ finite and correct, so trapping theorem is being used,
# then make sure library is quadratic and correct shape
if (np.any(self.PL_ != 0.0) or np.any(self.PQ_ != 0.0)) and n_features != N:
warnings.warn(
"The feature library is the wrong shape or not quadratic, "
"so please correct this if you are attempting to use the "
"trapping algorithm with the stability term included. Setting "
"PL and PQ tensors to zeros for now."
)
self.PL_ = np.zeros((r, r, r, n_features))
self.PQ_ = np.zeros((r, r, r, r, n_features))
if not np.allclose(
np.transpose(PL, [1, 0, 2, 3]), PL, atol=1e-10
) or not np.allclose(np.transpose(PQ, [0, 2, 1, 3, 4]), PQ, atol=1e-10):
raise ValueError("PQ/PL tensors were passed but have the wrong symmetry")
return PL, PQ

def _update_coef_constraints(self, H, x_transpose_y, P_transpose_A, coef_sparse):
"""Solves the coefficient update analytically if threshold = 0"""
Expand Down Expand Up @@ -504,13 +458,14 @@ def _reduce(self, x, y):
self.history_ = []
n_samples, n_features = x.shape
n_tgts = y.shape[1]
N = int((n_tgts**2 + 3 * n_tgts) / 2.0)
n_feat_expected = int((n_tgts**2 + 3 * n_tgts) / 2.0)

# Define PL and PQ tensors, only relevant if the stability term in
# trapping SINDy is turned on.
# Only relevant if the stability term is turned on.
self.PL_unsym_, self.PL_, self.PQ_ = self._set_Ptensors(n_tgts)
# make sure dimensions/symmetries are correct
self._check_P_matrix(n_tgts, n_features, N)
self.PL_, self.PQ_ = self._check_P_matrix(
n_tgts, n_features, n_feat_expected, self.PL_, self.PQ_
)

# Set initial coefficients
if self.use_constraints and self.constraint_order.lower() == "target":
Expand Down

0 comments on commit e6aa50a

Please sign in to comment.