Skip to content

Commit

Permalink
fix(ssr): Select model by error inflection point
Browse files Browse the repository at this point in the history
Closes #532

This commit implements the inflection point criteria from the SSR paper,
although still measuring training loss, rather than cross validation loss.
  • Loading branch information
Jacob-Stevens-Haas committed Oct 16, 2024
1 parent 31f909b commit 6747bf7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
8 changes: 6 additions & 2 deletions pysindy/optimizers/ssr.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,5 +227,9 @@ def _reduce(self, x, y):
if np.all(np.sum(np.asarray(inds, dtype=int), axis=1) <= 1):
# each equation has one last term
break
err_min = np.argmin(self.err_history_)
self.coef_ = np.asarray(self.history_)[err_min, :, :]

# err history is reverse of ordering in paper
self.err_history_ = self.err_history_[::-1]
err_ratio = np.array(self.err_history_[:-1]) / np.array(self.err_history_[1:])
ind_err_inflection = np.argmax(err_ratio) + 1
self.coef_ = np.asarray(self.history_)[ind_err_inflection, :, :]
15 changes: 9 additions & 6 deletions test/test_optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,13 +1197,16 @@ def test_pickle(data_lorenz, opt_cls, opt_args):


def test_ssr_history():
x = np.zeros((10, 3))
y = np.ones((10,))
x[:, 0] = y
x += np.random.normal(size=(10, 3), scale=1e-2)
x = np.zeros((10, 8))
rng = np.random.default_rng(1)
real_cols = rng.normal(scale=3, size=(10, 4))
x[:, :4] = real_cols
expected = np.array([[1, 1, 1, 1, 0, 0, 0, 0]])
y = x @ expected.T

x += np.random.normal(size=(10, 8), scale=1e-2)
opt = SSR()
result = opt.fit(x, y).coef_
expected = np.array([[1, 0, 0]])

assert len(opt.history_) == len(opt.err_history_)
np.testing.assert_allclose(result, expected)
np.testing.assert_allclose(result, expected, atol=1e-2)

0 comments on commit 6747bf7

Please sign in to comment.