Skip to content

Commit

Permalink
fix: some function calls
Browse files Browse the repository at this point in the history
ALso align some comments and organization with trapping-resolve
  • Loading branch information
Jacob-Stevens-Haas committed Jun 1, 2024
1 parent 02b5a8a commit 64882e7
Showing 1 changed file with 16 additions and 22 deletions.
38 changes: 16 additions & 22 deletions pysindy/optimizers/trapping_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ def __post_init_guard(self):
raise ValueError("Inequality constraints requires threshold!=0")
if self.mod_matrix is None:
self.mod_matrix = np.eye(self._n_tgts)
if self.A0 is None:
self.A0 = np.diag(self.gamma * np.ones(self._n_tgts))

def set_params(self, **kwargs):
super().set_params(**kwargs)
Expand Down Expand Up @@ -534,19 +536,23 @@ def _objective(self, x, y, coef_sparse, A, PW, k):
return R2 + stability_term + L1 + alpha_term + beta_term

def _update_coef_sparse_rs(
self, r, N, var_len, x_expanded, y, Pmatrix, A, coef_prev
self, n_tgts, n_features, var_len, x_expanded, y, Pmatrix, A, coef_prev
):
"""Solve coefficient update with CVXPY if threshold != 0"""
xi, cost = self._create_var_and_part_cost(var_len, x_expanded, y)
cost = cost + cp.sum_squares(Pmatrix @ xi - A.flatten()) / self.eta

# new terms minimizing quadratic piece ||P^Q @ xi||_2^2
if self.method == "local":
Q = np.reshape(self.PQ_, (r * r * r, N * r), "F")
Q = np.reshape(
self.PQ_, (n_tgts * n_tgts * n_tgts, n_features * n_tgts), "F"
)
cost = cost + cp.sum_squares(Q @ xi) / self.alpha
Q = np.reshape(self.PQ_, (r, r, r, N * r), "F")
Q = np.reshape(self.PQ_, (n_tgts, n_tgts, n_tgts, n_features * n_tgts), "F")
Q_ep = Q + np.transpose(Q, [1, 2, 0, 3]) + np.transpose(Q, [2, 0, 1, 3])
Q_ep = np.reshape(Q_ep, (r * r * r, N * r), "F")
Q_ep = np.reshape(
Q_ep, (n_tgts * n_tgts * n_tgts, n_features * n_tgts), "F"
)
cost = cost + cp.sum_squares(Q_ep @ xi) / self.beta

return self._update_coef_cvxpy(xi, cost, var_len, coef_prev, self.eps_solver)
Expand Down Expand Up @@ -647,8 +653,6 @@ def _reduce(self, x, y):
)
var_len = n_features * n_tgts

# Define PL, PQ, PT and PM tensors, only relevant if the stability term in
# trapping SINDy is turned on.
(
self.PC_,
self.PL_unsym_,
Expand Down Expand Up @@ -681,13 +685,7 @@ def _reduce(self, x, y):
" ... {: >8} ... {: >10} ... {: >8}".format(*row)
)

# initial A
if self.A0 is not None:
A = self.A0
elif np.any(self.PM_ != 0.0):
A = np.diag(self.gamma * np.ones(n_tgts))
else:
A = np.diag(np.zeros(n_tgts))
A = self.A0
self.A_history_.append(A)

# initial guess for m
Expand All @@ -708,7 +706,7 @@ def _reduce(self, x, y):

# if using acceleration
tk_prev = 1
m_prev = trap_ctr
trap_prev_ctr = trap_ctr

# Begin optimization loop
objective_history = []
Expand All @@ -726,26 +724,22 @@ def _reduce(self, x, y):
n_tgts, n_features, var_len, x_expanded, y, Pmatrix, A, coef_prev
)
else:
# if threshold = 0, there is analytic expression
# for the solve over the coefficients,
# which is coded up here separately
coef_sparse = self._update_coef_nonsparse_rs(
Pmatrix, A, coef_prev, xTx, xTy
n_tgts, n_features, Pmatrix, A, coef_prev, xTx, xTy
)

# If problem over xi becomes infeasible, break out of the loop
if coef_sparse is None:
coef_sparse = coef_prev
break

# Now solve optimization for m and A
m_prev, trap_ctr, A, tk_prev = self._solve_m_relax_and_split(
n_tgts, n_features, m_prev, trap_ctr, A, coef_sparse, tk_prev
trap_prev_ctr, trap_ctr, A, tk_prev = self._solve_m_relax_and_split(
n_tgts, n_features, trap_prev_ctr, trap_ctr, A, coef_sparse, tk_prev
)

# If problem over m becomes infeasible, break out of the loop
if trap_ctr is None:
trap_ctr = m_prev
trap_ctr = trap_prev_ctr
break
self.history_.append(coef_sparse.T)
PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
Expand Down

0 comments on commit 64882e7

Please sign in to comment.