Skip to content

Commit

Permalink
fix (ssr): Re-implement l0-based model seletion option
Browse files Browse the repository at this point in the history
Also change test to use better-conditioned feature library, since l0 penalty
is based upon condition number.

Also, fix use of inflection point model selection.  Because error history is
reversed before calculating inflection index, index is from the end of the list

Also, extract calculating the index of inflection in SSR (+ tests)
  • Loading branch information
Jacob-Stevens-Haas committed Oct 16, 2024
1 parent 6747bf7 commit 9b5130a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
26 changes: 21 additions & 5 deletions pysindy/optimizers/ssr.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,24 @@ def _reduce(self, x, y):
# each equation has one last term
break

# err history is reverse of ordering in paper
self.err_history_ = self.err_history_[::-1]
err_ratio = np.array(self.err_history_[:-1]) / np.array(self.err_history_[1:])
ind_err_inflection = np.argmax(err_ratio) + 1
self.coef_ = np.asarray(self.history_)[ind_err_inflection, :, :]
if self.kappa is not None:
ind_best = np.argmin(self.err_history_)
else:
# err history is reverse of ordering in paper
ind_best = (
len(self.err_history_) - 1 - _ind_inflection(self.err_history_[::-1])
)
self.coef_ = np.asarray(self.history_)[ind_best, :, :]


def _ind_inflection(err_descending: list[float]) -> int:
"Calculate the index of the inflection point in error"
if len(err_descending) == 1:
raise ValueError("Cannot find the inflection point of a single point")
err_descending = np.array(err_descending)
if np.any(err_descending < 0):
raise ValueError("SSR inflection point method requires nonnegative losses")
if np.any(err_descending == 0):
return np.argmin(err_descending)
err_ratio = err_descending[:-1] / err_descending[1:]
return np.argmax(err_ratio) + 1
38 changes: 31 additions & 7 deletions test/test_optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from pysindy.optimizers import STLSQ
from pysindy.optimizers import TrappingSR3
from pysindy.optimizers import WrappedOptimizer
from pysindy.optimizers.ssr import _ind_inflection
from pysindy.optimizers.stlsq import _remove_and_decrement
from pysindy.utils import supports_multiple_targets
from pysindy.utils.odes import enzyme
Expand Down Expand Up @@ -1196,17 +1197,40 @@ def test_pickle(data_lorenz, opt_cls, opt_args):
np.testing.assert_array_equal(result, expected)


def test_ssr_history():
x = np.zeros((10, 8))
@pytest.mark.parametrize("kappa", (None, 0.1), ids=["inflection", "L0"])
def test_ssr_history_selection(kappa):
rng = np.random.default_rng(1)
real_cols = rng.normal(scale=3, size=(10, 4))
x[:, :4] = real_cols
expected = np.array([[1, 1, 1, 1, 0, 0, 0, 0]])
x = rng.normal(size=(30, 8))
expected = np.array([[1, 1, 1, 0, 0, 0, 0, 0]])
y = x @ expected.T

x += np.random.normal(size=(10, 8), scale=1e-2)
opt = SSR()
x += np.random.normal(size=(30, 8), scale=1e-2)
opt = SSR(kappa=kappa)
result = opt.fit(x, y).coef_

assert len(opt.history_) == len(opt.err_history_)
np.testing.assert_allclose(result, expected, atol=1e-2)
np.testing.assert_array_equal(result == 0, expected == 0)


@pytest.mark.parametrize(
["errs", "expected"],
(([3, 1, 0.9], 1), ([1, 0, 0], 1)),
ids=["basic", "zero-error"],
)
def test_ssr_inflection(errs, expected):
result = _ind_inflection(errs)
assert result == expected


@pytest.mark.parametrize(
["errs", "expected", "message"],
(
([1], ValueError, "single point"),
([-1, 1, 1], ValueError, ""),
),
ids=["length-1", "negative"],
)
def test_ssr_inflection_bad_args(errs, expected, message):
with pytest.raises(expected, match=message):
_ind_inflection(errs)

0 comments on commit 9b5130a

Please sign in to comment.