From 7893d43b7cca2b778be593547e1351678ef4c76e Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Tue, 17 Sep 2024 17:13:11 -0700 Subject: [PATCH] tst: Test local trapping Not necessarily the best test, but coverage should at least find any shape errors that arise. Also, remove tests for different regularizers from trapping, now that that regularization is fully abstracted to superclass, with exception of reg == 0 vs reg != 0 --- test/test_optimizers/test_optimizers.py | 54 ++++++++----------------- 1 file changed, 16 insertions(+), 38 deletions(-) diff --git a/test/test_optimizers/test_optimizers.py b/test/test_optimizers/test_optimizers.py index 5e459c0e..a2518476 100644 --- a/test/test_optimizers/test_optimizers.py +++ b/test/test_optimizers/test_optimizers.py @@ -508,60 +508,38 @@ def test_stable_linear_sr3_linear_library(): assert np.allclose(opt.coef_.flatten(), 0.0) -@pytest.mark.parametrize( - "params", - [ - dict(regularizer="l1", reg_weight_lam=0, _include_bias=True), - dict(regularizer="l1", reg_weight_lam=1e-5, _include_bias=True), - dict( - regularizer="weighted_l1", - reg_weight_lam=np.zeros((1, 2)), - eta=1e5, - alpha_m=1e4, - alpha_A=1e5, - _include_bias=False, - ), - dict( - regularizer="weighted_l1", - reg_weight_lam=1e-5 * np.ones((1, 2)), - _include_bias=False, - ), - dict(regularizer="l2", reg_weight_lam=0, _include_bias=True), - dict(regularizer="l2", reg_weight_lam=1e-5, _include_bias=True), - dict( - regularizer="weighted_l2", - reg_weight_lam=np.zeros((1, 2)), - _include_bias=False, - ), - dict( - regularizer="weighted_l2", - reg_weight_lam=1e-5 * np.ones((1, 2)), - _include_bias=False, - ), - ], -) -def test_trapping_sr3_quadratic_library(params): +@pytest.mark.parametrize("bias", (True, False)) +@pytest.mark.parametrize("method", ("global", "local")) +@pytest.mark.parametrize("reg_weight", (0.0, 1e-1)) +def test_trapping_sr3_quadratic_library(bias, method, reg_weight): t = np.arange(0, 1, 0.1) x = np.exp(-t).reshape((-1, 1)) x_dot = -x features = np.hstack([x, x**2]) - if params.get("_include_bias"): + if bias: features = np.hstack([np.ones_like(x), features]) - opt = TrappingSR3(_n_tgts=1, **params) + params = { + "_n_tgts": 1, + "_include_bias": bias, + "method": method, + "reg_weight_lam": reg_weight, + } + + opt = TrappingSR3(**params) opt.fit(features, x_dot) check_is_fitted(opt) # Rerun with identity constraints r = x.shape[1] - N = 2 + params.get("_include_bias", 0) + N = 2 + bias params["constraint_rhs"] = np.zeros(r * N) params["constraint_lhs"] = np.eye(r * N, r * N) - opt = TrappingSR3(_n_tgts=1, **params) + opt = TrappingSR3(**params) opt.fit(features, x_dot) check_is_fitted(opt) - # check is solve was infeasible first + # check if solve was infeasible first if not np.allclose(opt.m_history_[-1], opt.m_history_[0]): assert np.allclose((opt.coef_.flatten())[0], 0.0, atol=1e-5)