Skip to content

Commit

Permalink
Added SBR optimizer to a bunch of tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mikkelbue committed Feb 5, 2024
1 parent 0b25e65 commit 57f4178
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions test/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pysindy.optimizers import EnsembleOptimizer
from pysindy.optimizers import FROLS
from pysindy.optimizers import MIOSR
from pysindy.optimizers import SBR
from pysindy.optimizers import SINDyPI
from pysindy.optimizers import SR3
from pysindy.optimizers import SSR
Expand Down Expand Up @@ -77,6 +78,7 @@ def predict(self, x):
(StableLinearSR3, True),
(TrappingSR3, True),
(DummyLinearModel, False),
(SBR, True),
],
)
def test_supports_multiple_targets(cls, support):
Expand All @@ -103,6 +105,7 @@ def data(request):
ElasticNet(fit_intercept=False),
DummyLinearModel(),
MIOSR(),
SBR(),
],
)
def test_fit(data_derivative_1d, optimizer):
Expand All @@ -122,14 +125,14 @@ def test_fit(data_derivative_1d, optimizer):

@pytest.mark.parametrize(
"optimizer",
[STLSQ(), SSR(), SSR(criteria="model_residual"), FROLS(), SR3(), MIOSR()],
[STLSQ(), SSR(), SSR(criteria="model_residual"), FROLS(), SR3(), MIOSR(), SBR()],
)
def test_not_fitted(optimizer):
with pytest.raises(NotFittedError):
optimizer.predict(np.ones((1, 3)))


@pytest.mark.parametrize("optimizer", [STLSQ(), SR3()])
@pytest.mark.parametrize("optimizer", [STLSQ(), SR3(), SBR()])
def test_complexity_not_fitted(optimizer, data_derivative_2d):
with pytest.raises(NotFittedError):
optimizer.complexity
Expand Down Expand Up @@ -665,6 +668,7 @@ def test_constrained_sr3_prox_functions(data_derivative_1d, thresholder):
(StableLinearSR3, {"trimming_fraction": 0.1}),
(SINDyPI, {}),
(MIOSR, {"constraint_lhs": [1]}),
(SBR, {}),
),
)
def test_illegal_unbias(data_derivative_1d, opt_cls, opt_args):
Expand Down Expand Up @@ -986,6 +990,7 @@ def test_inequality_constraints_reqs():
StableLinearSR3,
TrappingSR3,
MIOSR,
SBR,
],
)
def test_normalize_columns(data_derivative_1d, optimizer):
Expand Down

0 comments on commit 57f4178

Please sign in to comment.