fix: Cleanup relax_optim
Jacob-Stevens-Haas committed May 29, 2024
1 parent 652514f · commit 3ddb3ba
Showing 3 changed files with 6 additions and 23 deletions.
@@ -520,9 +520,6 @@
 alpha_m = 5.0e-1 * eta
 
 
-# run trapping SINDy with "relax_optim = False" here, uses CVXPY now
-# so this tends to be much slower but often need far fewer algorithm iterations.
-# For this problem, a single (very slow) update is all that is needed!
 sindy_opt = ps.TrappingSR3(
     _n_tgts=6,
     _include_bias=True,
@@ -594,7 +591,6 @@
 #
 # $$
 # \begin{align}
-# \label{eq:burgers}
 # \dot{u} &= -(U + u)\partial_x u + \nu \partial_{xx}^2u + g(x,t),
 # \end{align}
 # $$
@@ -605,7 +601,6 @@
 #
 # $$
 # \begin{equation}
-# \label{eq:burgers_galerkin}
 # \dot{a}_k = \left( \delta_{|k|1} \sigma - \nu k^2 - ikU \right) a_k - \sum_{\ell=-r}^{r} i \ell a_{\ell} a_{k - \ell}.
 # \end{equation}
 # $$
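
For readers following the example, here is a minimal NumPy sketch of the Galerkin-projected right-hand side shown above. The parameter values (U, nu, sigma, r) and the mode indexing are illustrative assumptions, not values taken from the notebook.

import numpy as np

def galerkin_rhs(a, U=1.0, nu=0.1, sigma=0.1, r=4):
    # a holds complex modes a_k for k = -r..r, stored at index k + r
    ks = np.arange(-r, r + 1)
    a_dot = np.zeros(2 * r + 1, dtype=complex)
    for i, k in enumerate(ks):
        # linear part: (delta_{|k|,1} sigma - nu k^2 - i k U) a_k
        linear = (sigma * (abs(k) == 1) - nu * k**2 - 1j * k * U) * a[i]
        # quadratic part: sum over ell of i ell a_ell a_{k-ell}, truncated to |k - ell| <= r
        quad = sum(
            1j * ell * a[ell + r] * a[(k - ell) + r]
            for ell in ks
            if -r <= k - ell <= r
        )
        a_dot[i] = linear - quad
    return a_dot

# example: r = 4 gives 9 modes; perturb mode k = 1 and evaluate the right-hand side
a0 = np.zeros(9, dtype=complex)
a0[1 + 4] = 1e-3
print(galerkin_rhs(a0))
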
9 changes: 5 additions & 4 deletions pysindy/optimizers/trapping_sr3.py
@@ -201,6 +201,9 @@ def __init__(
         A0: Union[NDArray, None] = None,
         **kwargs,
     ):
+        self.alpha = alpha
+        self.beta = beta
+        self.mod_matrix = mod_matrix
         # n_tgts, constraints, etc are data-dependent parameters and belong in
         # _reduce/fit (). The following is a hack until we refactor how
         # constraints are applied in ConstrainedSR3 and MIOSR
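
The comment above describes keeping data-independent hyperparameters in __init__ while resolving data-dependent quantities at fit time. A hypothetical standalone sketch of that pattern (the class and names are illustrative, not pysindy internals):

import numpy as np

class DeferredShapeOptimizer:
    """Toy illustration: data-dependent shapes are resolved in fit(), not __init__()."""

    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        self.alpha = alpha  # data-independent, safe to store at construction time
        self.beta = beta

    def fit(self, x: np.ndarray, y: np.ndarray) -> "DeferredShapeOptimizer":
        n_tgts = y.shape[1]  # data-dependent, known only once y is seen
        self.coef_ = np.zeros((n_tgts, x.shape[1]))
        return self
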
@@ -285,10 +288,8 @@ def __post_init_guard(self):
             raise ValueError("gamma must be negative")
         if self.tol <= 0 or self.tol_m <= 0 or self.eps_solver <= 0:
             raise ValueError("tol and tol_m must be positive")
-        if self.inequality_constraints and self.relax_optim and self.threshold == 0.0:
-            raise ValueError(
-                "Ineq. constr. -> threshold!=0 + relax_optim=True or relax_optim=False."
-            )
+        if self.inequality_constraints and self.threshold == 0.0:
+            raise ValueError("Inequality constraints requires threshold!=0")
 
     def set_params(self, **kwargs):
         super().set_params(**kwargs)
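
With relax_optim gone, the rule enforced above reduces to a single condition. A hypothetical standalone version of that check, for illustration only (not the pysindy implementation):

def check_inequality_settings(inequality_constraints: bool, threshold: float) -> None:
    # Inequality constraints are only meaningful with a nonzero threshold,
    # regardless of which solver path is used.
    if inequality_constraints and threshold == 0.0:
        raise ValueError("Inequality constraints requires threshold!=0")

check_inequality_settings(False, 0.0)  # ok: no inequality constraints
check_inequality_settings(True, 0.5)   # ok: nonzero threshold
# check_inequality_settings(True, 0.0) would raise ValueError
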
15 changes: 1 addition & 14 deletions test/test_optimizers.py
@@ -510,7 +510,6 @@ def test_stable_linear_sr3_linear_library():
     [
         dict(),
         dict(accel=True),
-        dict(relax_optim=False),
     ],
 )
 @pytest.mark.parametrize(
@@ -938,7 +937,7 @@ def test_constrained_inequality_constraints(data_lorenz, params):
 )
 def test_trapping_cost_function(params):
     expected = params.pop("expected")
-    opt = TrappingSR3(relax_optim=True, **params)
+    opt = TrappingSR3(**params)
     x = np.eye(2)
     y = np.ones(2)
     xi, cost = opt._create_var_and_part_cost(2, x, y)
@@ -960,17 +959,6 @@ def test_trapping_inequality_constraints():
         constraint_rhs=constraint_rhs,
         constraint_order="feature",
         inequality_constraints=True,
-        relax_optim=True,
     )
     opt.fit(x, y)
     assert np.all(np.dot(constraint_matrix, (opt.coef_).flatten()) <= constraint_rhs)
-    # Run Trapping SR3 with CVXPY for the m solve
-    opt = TrappingSR3(
-        constraint_lhs=constraint_matrix,
-        constraint_rhs=constraint_rhs,
-        constraint_order="feature",
-        inequality_constraints=True,
-        relax_optim=False,
-    )
-    opt.fit(x, y)
-    assert np.all(np.dot(constraint_matrix, (opt.coef_).flatten()) <= constraint_rhs)
@@ -1023,7 +1011,6 @@ def test_inequality_constraints_reqs():
         constraint_rhs=constraint_rhs,
         constraint_order="feature",
         inequality_constraints=True,
-        relax_optim=True,
     )
 
 
