diff --git a/test/test_optimizers/test_optimizers.py b/test/test_optimizers/test_optimizers.py index 5e459c0e..a2518476 100644 --- a/test/test_optimizers/test_optimizers.py +++ b/test/test_optimizers/test_optimizers.py @@ -508,60 +508,38 @@ def test_stable_linear_sr3_linear_library(): assert np.allclose(opt.coef_.flatten(), 0.0) -@pytest.mark.parametrize( - "params", - [ - dict(regularizer="l1", reg_weight_lam=0, _include_bias=True), - dict(regularizer="l1", reg_weight_lam=1e-5, _include_bias=True), - dict( - regularizer="weighted_l1", - reg_weight_lam=np.zeros((1, 2)), - eta=1e5, - alpha_m=1e4, - alpha_A=1e5, - _include_bias=False, - ), - dict( - regularizer="weighted_l1", - reg_weight_lam=1e-5 * np.ones((1, 2)), - _include_bias=False, - ), - dict(regularizer="l2", reg_weight_lam=0, _include_bias=True), - dict(regularizer="l2", reg_weight_lam=1e-5, _include_bias=True), - dict( - regularizer="weighted_l2", - reg_weight_lam=np.zeros((1, 2)), - _include_bias=False, - ), - dict( - regularizer="weighted_l2", - reg_weight_lam=1e-5 * np.ones((1, 2)), - _include_bias=False, - ), - ], -) -def test_trapping_sr3_quadratic_library(params): +@pytest.mark.parametrize("bias", (True, False)) +@pytest.mark.parametrize("method", ("global", "local")) +@pytest.mark.parametrize("reg_weight", (0.0, 1e-1)) +def test_trapping_sr3_quadratic_library(bias, method, reg_weight): t = np.arange(0, 1, 0.1) x = np.exp(-t).reshape((-1, 1)) x_dot = -x features = np.hstack([x, x**2]) - if params.get("_include_bias"): + if bias: features = np.hstack([np.ones_like(x), features]) - opt = TrappingSR3(_n_tgts=1, **params) + params = { + "_n_tgts": 1, + "_include_bias": bias, + "method": method, + "reg_weight_lam": reg_weight, + } + + opt = TrappingSR3(**params) opt.fit(features, x_dot) check_is_fitted(opt) # Rerun with identity constraints r = x.shape[1] - N = 2 + params.get("_include_bias", 0) + N = 2 + bias params["constraint_rhs"] = np.zeros(r * N) params["constraint_lhs"] = np.eye(r * N, r * N) - opt = TrappingSR3(_n_tgts=1, **params) + opt = TrappingSR3(**params) opt.fit(features, x_dot) check_is_fitted(opt) - # check is solve was infeasible first + # check if solve was infeasible first if not np.allclose(opt.m_history_[-1], opt.m_history_[0]): assert np.allclose((opt.coef_.flatten())[0], 0.0, atol=1e-5)