diff --git a/pysindy/optimizers/ssr.py b/pysindy/optimizers/ssr.py index 48f57c0f..2270b07a 100644 --- a/pysindy/optimizers/ssr.py +++ b/pysindy/optimizers/ssr.py @@ -228,8 +228,24 @@ def _reduce(self, x, y): # each equation has one last term break - # err history is reverse of ordering in paper - self.err_history_ = self.err_history_[::-1] - err_ratio = np.array(self.err_history_[:-1]) / np.array(self.err_history_[1:]) - ind_err_inflection = np.argmax(err_ratio) + 1 - self.coef_ = np.asarray(self.history_)[ind_err_inflection, :, :] + if self.kappa is not None: + ind_best = np.argmin(self.err_history_) + else: + # err history is reverse of ordering in paper + ind_best = ( + len(self.err_history_) - 1 - _ind_inflection(self.err_history_[::-1]) + ) + self.coef_ = np.asarray(self.history_)[ind_best, :, :] + + +def _ind_inflection(err_descending: list[float]) -> int: + "Calculate the index of the inflection point in error" + if len(err_descending) == 1: + raise ValueError("Cannot find the inflection point of a single point") + err_descending = np.array(err_descending) + if np.any(err_descending < 0): + raise ValueError("SSR inflection point method requires nonnegative losses") + if np.any(err_descending == 0): + return np.argmin(err_descending) + err_ratio = err_descending[:-1] / err_descending[1:] + return np.argmax(err_ratio) + 1 diff --git a/test/test_optimizers/test_optimizers.py b/test/test_optimizers/test_optimizers.py index 25bfb621..0145f988 100644 --- a/test/test_optimizers/test_optimizers.py +++ b/test/test_optimizers/test_optimizers.py @@ -35,6 +35,7 @@ from pysindy.optimizers import STLSQ from pysindy.optimizers import TrappingSR3 from pysindy.optimizers import WrappedOptimizer +from pysindy.optimizers.ssr import _ind_inflection from pysindy.optimizers.stlsq import _remove_and_decrement from pysindy.utils import supports_multiple_targets from pysindy.utils.odes import enzyme @@ -1196,17 +1197,40 @@ def test_pickle(data_lorenz, opt_cls, opt_args): np.testing.assert_array_equal(result, expected) -def test_ssr_history(): - x = np.zeros((10, 8)) +@pytest.mark.parametrize("kappa", (None, 0.1), ids=["inflection", "L0"]) +def test_ssr_history_selection(kappa): rng = np.random.default_rng(1) - real_cols = rng.normal(scale=3, size=(10, 4)) - x[:, :4] = real_cols - expected = np.array([[1, 1, 1, 1, 0, 0, 0, 0]]) + x = rng.normal(size=(30, 8)) + expected = np.array([[1, 1, 1, 0, 0, 0, 0, 0]]) y = x @ expected.T - x += np.random.normal(size=(10, 8), scale=1e-2) - opt = SSR() + x += np.random.normal(size=(30, 8), scale=1e-2) + opt = SSR(kappa=kappa) result = opt.fit(x, y).coef_ assert len(opt.history_) == len(opt.err_history_) np.testing.assert_allclose(result, expected, atol=1e-2) + np.testing.assert_array_equal(result == 0, expected == 0) + + +@pytest.mark.parametrize( + ["errs", "expected"], + (([3, 1, 0.9], 1), ([1, 0, 0], 1)), + ids=["basic", "zero-error"], +) +def test_ssr_inflection(errs, expected): + result = _ind_inflection(errs) + assert result == expected + + +@pytest.mark.parametrize( + ["errs", "expected", "message"], + ( + ([1], ValueError, "single point"), + ([-1, 1, 1], ValueError, ""), + ), + ids=["length-1", "negative"], +) +def test_ssr_inflection_bad_args(errs, expected, message): + with pytest.raises(expected, match=message): + _ind_inflection(errs)