Skip to content

Commit

Permalink
Squash me (make Trapping match Constrained interface)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Nov 8, 2023
1 parent 44a1c59 commit 9907583
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
1 change: 1 addition & 0 deletions pysindy/optimizers/trapping_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ def _solve_sparse_relax_and_split(self, xi, cost, var_len, coef_prev):
# default solver is OSQP here but switches to ECOS for L2
try:
prob.solve(
max_iter=self.max_iter,
eps_abs=self.eps_solver,
eps_rel=self.eps_solver,
verbose=self.verbose_cvxpy,
Expand Down
20 changes: 16 additions & 4 deletions test/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,10 @@ def test_fit_warn(data_derivative_1d, optimizer):
optimizer.fit(x, x_dot)


@pytest.mark.parametrize("optimizer", [ConstrainedSR3, TrappingSR3, MIOSR])
@pytest.mark.parametrize(
"optimizer",
[(ConstrainedSR3, {"max_iter": 80}), (TrappingSR3, {"max_iter": 100}), (MIOSR, {})],
)
@pytest.mark.parametrize("target_value", [0, -1, 3])
def test_row_format_constraints(data_linear_combination, optimizer, target_value):
# Solution is x_dot = x.dot(np.array([[1, 1, 0], [0, 1, 1]]))
Expand All @@ -805,10 +808,11 @@ def test_row_format_constraints(data_linear_combination, optimizer, target_value
constraint_lhs[0, 0] = 1
constraint_lhs[1, 3] = 1

model = optimizer(
model = optimizer[0](
constraint_lhs=constraint_lhs,
constraint_rhs=constraint_rhs,
constraint_order="feature",
**optimizer[1],
)
model.fit(x, x_dot)

Expand All @@ -818,7 +822,13 @@ def test_row_format_constraints(data_linear_combination, optimizer, target_value


@pytest.mark.parametrize(
"optimizer", [ConstrainedSR3, StableLinearSR3, TrappingSR3, MIOSR]
"optimizer",
[
(ConstrainedSR3, {"max_iter": 80}),
(StableLinearSR3, {}),
(TrappingSR3, {"max_iter": 100}),
(MIOSR, {}),
],
)
@pytest.mark.parametrize("target_value", [0, -1, 3])
def test_target_format_constraints(data_linear_combination, optimizer, target_value):
Expand All @@ -831,7 +841,9 @@ def test_target_format_constraints(data_linear_combination, optimizer, target_va
constraint_lhs[0, 1] = 1
constraint_lhs[1, 4] = 1

model = optimizer(constraint_lhs=constraint_lhs, constraint_rhs=constraint_rhs)
model = optimizer[0](
constraint_lhs=constraint_lhs, constraint_rhs=constraint_rhs, **optimizer[1]
)
model.fit(x, x_dot)
np.testing.assert_allclose(model.coef_[:, 1], target_value, atol=1e-8)

Expand Down

0 comments on commit 9907583

Please sign in to comment.