Commit

Removed the if statements between global and local so they use the same code for computing PW now. Removed the option for accelerated prox-gradient descent since this seemed not to work very well anyway. Updated the trapping paper examples script. Everything now looks good except the constraint matrix is not quite right for the MHD case.
akaptano committed Jun 28, 2024
1 parent 6444890 commit e6013cd
Showing 3 changed files with 17 additions and 66 deletions.
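For orientation, here is a minimal sketch of how the trapping optimizer is constructed in the example script after this change; `accel` is simply no longer passed. The keyword values are the ones visible in the diff below, the placeholder values are illustrative, and any other arguments hidden above the hunk are elided.

import pysindy as ps

# placeholder values; the real example script defines these earlier
threshold = 0.0
max_iter = 1000

eta = 1.0e3
alpha_m = 8e-1 * eta  # default is 1e-2 * eta, so this speeds up the example

# `accel=...` is gone: accelerated prox-gradient descent is no longer an option
sindy_opt = ps.TrappingSR3(
    threshold=threshold,
    eta=eta,
    alpha_m=alpha_m,
    max_iter=max_iter,
    gamma=-1,
    verbose=True,
)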
@@ -376,7 +376,6 @@
eta = 1.0e3

alpha_m = 8e-1 * eta # default is 1e-2 * eta so this speeds up the code here
accel = False

# run trapping SINDy
sindy_opt = ps.TrappingSR3(
@@ -385,7 +384,6 @@
threshold=threshold,
eta=eta,
alpha_m=alpha_m,
accel=accel,
max_iter=max_iter,
gamma=-1,
verbose=True,
@@ -547,6 +545,7 @@
L = np.tensordot(PL_tensor, Xi, axes=([3, 2], [0, 1]))
Q = np.tensordot(PQ_tensor, Xi, axes=([4, 3], [0, 1]))
Q_sum = np.max(np.abs((Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1]))))
print((Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1])))
print("Max deviation from the constraints = ", Q_sum)

# plotting and analysis
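The print added above checks the trapping-SINDy requirement that the identified quadratic coefficient tensor be energy-preserving, i.e. that its cyclic-symmetric part Q_ijk + Q_jki + Q_kij vanish. A standalone sketch of that check, with a random tensor standing in for the model's Q:

import numpy as np

r = 3  # number of state variables (illustrative)
Q = np.random.randn(r, r, r)  # stand-in for the identified quadratic coefficient tensor

# cyclic-symmetric part; trapping SINDy requires this to be numerically zero
Q_sym = Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1])
Q_sum = np.max(np.abs(Q_sym))
print("Max deviation from the constraints =", Q_sum)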
examples/8_trapping_sindy_examples/trapping_utils.py (5 changes: 0 additions & 5 deletions)
@@ -268,7 +268,6 @@ def make_3d_plots(x_test, x_test_pred, filename):
ax.set_zticklabels([])
ax.set_axis_off()
plt.legend(fontsize=14)
plt.show()


# Plot the SINDy fits of X and Xdot against the ground truth
@@ -295,8 +294,6 @@ def make_fits(r, t, xdot_test, xdot_test_pred, x_test, x_test_pred, filename):
if i == r - 1:
plt.xlabel("t", fontsize=18)

plt.show()


# Plot errors between m_{k+1} and m_k and similarly for the model coefficients
def make_progress_plots(r, sindy_opt):
@@ -653,7 +650,6 @@ def trapping_region(r, x_test_pred, Xi, sindy_opt, filename):
ax.set_yticklabels([])
ax.set_zticklabels([])
ax.set_axis_off()
plt.show()


# Make Lissajou figures with ground truth and SINDy model
@@ -683,4 +679,3 @@ def make_lissajou(r, x_train, x_test, x_train_pred, x_test_pred, filename):
plt.ylabel(r"$x_" + str(i) + r"$", fontsize=18)
if i == r - 1:
plt.xlabel(r"$x_" + str(j) + r"$", fontsize=18)
plt.show()
pysindy/optimizers/trapping_sr3.py (75 changes: 16 additions & 59 deletions)
@@ -116,10 +116,6 @@ class TrappingSR3(ConstrainedSR3):
could be straightforwardly implemented, but L0 requires
reformulation because of nonconvexity. (default 'L1')
accel :
Whether or not to use accelerated prox-gradient descent for (m, A).
(default False)
m0 :
Initial guess for trap center in the optimization. Default None
initializes vector elements randomly in [-1, 1]. shape (n_targets)
@@ -204,7 +200,6 @@ def __init__(
gamma: float = -0.1,
tol_m: float = 1e-5,
thresholder: str = "l1",
accel: bool = False,
m0: Union[NDArray, None] = None,
A0: Union[NDArray, None] = None,
**kwargs,
@@ -272,7 +267,6 @@ def __init__(
self.beta = beta
self.gamma = gamma
self.tol_m = tol_m
self.accel = accel
self.__post_init_guard()

def __post_init_guard(self):
@@ -584,64 +578,34 @@ def _update_coef_nonsparse_rs(

def _solve_m_relax_and_split(
self,
trap_ctr_prev: Float1D,
trap_ctr: Float1D,
A: Float2D,
coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat],
tk_previous: float,
) -> tuple[Float1D, Float2D, float]:
) -> tuple[Float1D, Float2D]:
"""Solves the (m, A) algorithm update.
TODO: explain the optimization this solves, add good names to variables,
and refactor/indirect the if global/local trapping conditionals
Returns the new trap center (m), the new A, and the new acceleration weight
Returns the new trap center (m) and the new A
"""
# prox-grad for (A, m)
# Accelerated prox gradient descent
# prox-gradient descent for (A, m)
# Calculate projection matrix from Quad terms to As
if self.accel:
tk = (1 + np.sqrt(1 + 4 * tk_previous**2)) / 2.0
m_partial = trap_ctr + (tk_previous - 1.0) / tk * (trap_ctr - trap_ctr_prev)
tk_previous = tk
if self.method == "global":
mPQ = np.tensordot(m_partial, self.PQ_, axes=([0], [0]))
else:
mPM = np.tensordot(self.PM_, m_partial, axes=([2], [0]))
else:
if self.method == "global":
mPQ = np.tensordot(trap_ctr, self.PQ_, axes=([0], [0]))
else:
mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
# Calculate As and its quad term components
if self.method == "global":
p = self.PL_ - mPQ
PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
PQW = np.tensordot(self.PQ_, coef_sparse, axes=([4, 3], [0, 1]))
else:
p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1]))
PMW = np.tensordot(self.mod_matrix, PMW, axes=([1], [0]))
PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1]))
PMW = np.tensordot(self.mod_matrix, PMW, axes=([1], [0]))
# Calculate error in quadratic balance, and adjust trap center
A_b = (A - PW) / self.eta
if self.method == "global":
# PQWT_PW is gradient of some loss in m
PQWT_PW = np.tensordot(PQW, A_b, axes=([2, 1], [0, 1]))
if self.accel:
trap_new = m_partial - self.alpha_m * PQWT_PW
else:
trap_new = trap_ctr - self.alpha_m * PQWT_PW
else:
PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))
if self.accel:
trap_new = m_partial - self.alpha_m * PMT_PW
else:
trap_new = trap_ctr - self.alpha_m * PMT_PW
# PMT_PW is the gradient of the loss with respect to m
PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))
trap_new = trap_ctr - self.alpha_m * PMT_PW

# Update A
A_new = self._update_A(A - self.alpha_A * A_b, PW)
return trap_new, A_new, tk_previous
return trap_new, A_new

def _solve_nonsparse_relax_and_split(self, H, xTy, P_transpose_A, coef_prev):
"""Update for the coefficients if threshold = 0."""
@@ -729,20 +693,15 @@ def _reduce(self, x, y):
xTx = np.dot(x_expanded.T, x_expanded)
xTy = np.dot(x_expanded.T, y.flatten())

# if using acceleration
tk_prev = 1
# keep track of last solution in case method fails
trap_prev_ctr = trap_ctr

# Begin optimization loop
objective_history = []
for k in range(self.max_iter):
# update P tensor from the newest trap center
if self.method == "global":
mPQ = np.tensordot(trap_ctr, self.PQ_, axes=([0], [0]))
p = self.PL_ - mPQ
else:
mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
Pmatrix = p.reshape(n_tgts * n_tgts, n_tgts * n_features)
self.p_history_.append(p)

Expand All @@ -761,9 +720,7 @@ def _reduce(self, x, y):
coef_sparse = coef_prev
break

trap_ctr, A, tk_prev = self._solve_m_relax_and_split(
trap_prev_ctr, trap_ctr, A, coef_sparse, tk_prev
)
trap_ctr, A = self._solve_m_relax_and_split(trap_ctr, A, coef_sparse)

# If problem over m becomes infeasible, break out of the loop
if trap_ctr is None:
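Both the global and local branches now funnel through the same tensor contractions when updating the trap center m and the auxiliary matrix A. Below is a self-contained NumPy sketch of one such non-accelerated prox-gradient step, using random stand-ins for the optimizer's internal tensors; the sizes, the hyperparameter values, and the plain gradient step on A (in place of the projection performed by `_update_A`) are illustrative assumptions, not the library's API.

import numpy as np

r, n_features = 3, 9  # illustrative sizes: number of states and library features
rng = np.random.default_rng(0)

# Random stand-ins for the optimizer's internal quantities (names mirror the diff)
PL = rng.standard_normal((r, r, r, n_features))     # projector onto the linear coefficients
PM = rng.standard_normal((r, r, r, r, n_features))  # projector carrying the trap-center shift
mod_matrix = np.eye(r)                              # identity in the unmodified (non-MHD) case
coef_sparse = rng.standard_normal((n_features, r))  # current SINDy coefficients W
trap_ctr = rng.standard_normal(r)                   # current trap center m
A = -np.eye(r)                                      # current auxiliary matrix
eta, alpha_m, alpha_A = 1.0e3, 1.0e2, 1.0e-2        # illustrative hyperparameters

# Project the current (m, W) onto the linear part of the model shifted to the trap center
mPM = np.tensordot(PM, trap_ctr, axes=([2], [0]))
p = np.tensordot(mod_matrix, PL + mPM, axes=([1], [0]))
PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))

# Gradient pieces shared by the global and local variants
PMW = np.tensordot(PM, coef_sparse, axes=([4, 3], [0, 1]))
PMW = np.tensordot(mod_matrix, PMW, axes=([1], [0]))
A_b = (A - PW) / eta                                # mismatch between A and its projection
grad_m = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))

# Plain (non-accelerated) gradient steps on m and A
trap_ctr = trap_ctr - alpha_m * grad_m
A = A - alpha_A * A_b  # the real optimizer additionally projects A via _update_A
print("updated trap center:", trap_ctr)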
