diff --git a/test/test_optimizers.py b/test/test_optimizers.py index c69ce9823..35d0bf1c6 100644 --- a/test/test_optimizers.py +++ b/test/test_optimizers.py @@ -22,6 +22,7 @@ from pysindy.optimizers import EnsembleOptimizer from pysindy.optimizers import FROLS from pysindy.optimizers import MIOSR +from pysindy.optimizers import SBR from pysindy.optimizers import SINDyPI from pysindy.optimizers import SR3 from pysindy.optimizers import SSR @@ -77,6 +78,7 @@ def predict(self, x): (StableLinearSR3, True), (TrappingSR3, True), (DummyLinearModel, False), + (SBR, True), ], ) def test_supports_multiple_targets(cls, support): @@ -103,6 +105,7 @@ def data(request): ElasticNet(fit_intercept=False), DummyLinearModel(), MIOSR(), + SBR(), ], ) def test_fit(data_derivative_1d, optimizer): @@ -122,14 +125,14 @@ def test_fit(data_derivative_1d, optimizer): @pytest.mark.parametrize( "optimizer", - [STLSQ(), SSR(), SSR(criteria="model_residual"), FROLS(), SR3(), MIOSR()], + [STLSQ(), SSR(), SSR(criteria="model_residual"), FROLS(), SR3(), MIOSR(), SBR()], ) def test_not_fitted(optimizer): with pytest.raises(NotFittedError): optimizer.predict(np.ones((1, 3))) -@pytest.mark.parametrize("optimizer", [STLSQ(), SR3()]) +@pytest.mark.parametrize("optimizer", [STLSQ(), SR3(), SBR()]) def test_complexity_not_fitted(optimizer, data_derivative_2d): with pytest.raises(NotFittedError): optimizer.complexity @@ -665,6 +668,7 @@ def test_constrained_sr3_prox_functions(data_derivative_1d, thresholder): (StableLinearSR3, {"trimming_fraction": 0.1}), (SINDyPI, {}), (MIOSR, {"constraint_lhs": [1]}), + (SBR, {}), ), ) def test_illegal_unbias(data_derivative_1d, opt_cls, opt_args): @@ -986,6 +990,7 @@ def test_inequality_constraints_reqs(): StableLinearSR3, TrappingSR3, MIOSR, + SBR, ], ) def test_normalize_columns(data_derivative_1d, optimizer):