diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py index def9d71ee..1d8410bcb 100644 --- a/pysindy/optimizers/trapping_sr3.py +++ b/pysindy/optimizers/trapping_sr3.py @@ -822,13 +822,15 @@ def _reduce(self, x, y): self.objective_history = objective_history -def _make_constraints(n_tgts: int): +def _make_constraints(n_tgts: int, **kwargs): """Create constraints for the Quadratic terms in TrappingSR3. These are the constraints from equation 5 of the Trapping SINDy paper. Args: n_tgts: number of coordinates or modes for which you're fitting an ODE. + kwargs: Keyword arguments to PolynomialLibrary such as + ``include_bias``. Returns: A tuple of the constraint zeros, and a constraint matrix to multiply @@ -841,7 +843,7 @@ def _make_constraints(n_tgts: int): reshaping. """ n_terms = n_poly_features(n_tgts, degree=2, include_bias=False) - lib = PolynomialLibrary(2, include_bias=False).fit(np.zeros((1, n_tgts))) + lib = PolynomialLibrary(2, **kwargs).fit(np.zeros((1, n_tgts))) terms = [(t_ind, exps) for t_ind, exps in enumerate(lib.powers_)] # index of tgt -> index of its pure quadratic term diff --git a/test/test_optimizers.py b/test/test_optimizers.py index 385d809c7..ac65049e2 100644 --- a/test/test_optimizers.py +++ b/test/test_optimizers.py @@ -1135,7 +1135,7 @@ def test_remove_and_decrement(): def test_trapping_constraints(): # x, y, x^2, xy, y^2 - constraint_rhs, constraint_lhs = _make_constraints(2) + constraint_rhs, constraint_lhs = _make_constraints(2, include_bias=False) stable_coefs = np.array([[0, 0, 0, 1, -1], [0, 0, -1, 1, 0]]) result = np.tensordot(constraint_lhs, stable_coefs, ((1, 2), (1, 0))) np.testing.assert_array_equal(constraint_rhs, result)