diff --git a/test/test_optimizers.py b/test/test_optimizers.py index 263eb3bf..b19044be 100644 --- a/test/test_optimizers.py +++ b/test/test_optimizers.py @@ -221,10 +221,10 @@ def test_sr3_bad_parameters(optimizer, params): @pytest.mark.parametrize( - ["thresholder", "expected"], [("l0", 4), ("l1", 6), ("l2", 10)] + ["thresholder", "expected"], [("l0", 4), ("l1", 16), ("l2", 68)] ) def test_get_regularization_1d(thresholder, expected): - data = np.array([[0, 1, 2]]).T + data = np.array([[0, 3, 5]]).T lam = 2 reg = SR3.get_regularization(thresholder) @@ -233,10 +233,10 @@ def test_get_regularization_1d(thresholder, expected): @pytest.mark.parametrize( - ["thresholder", "expected"], [("l0", 8), ("l1", 10), ("l2", 14)] + ["thresholder", "expected"], [("l0", 8), ("l1", 52), ("l2", 408)] ) def test_get_regularization_2d(thresholder, expected): - data = np.array([[0, 1, 2], [1, 1, 0]]).T + data = np.array([[0, 3, 5], [7, 11, 0]]).T lam = 2 reg = SR3.get_regularization(thresholder) @@ -246,11 +246,11 @@ def test_get_regularization_2d(thresholder, expected): @pytest.mark.parametrize( ["thresholder", "expected"], - [("weighted_l0", 1.5), ("weighted_l1", 2), ("weighted_l2", 3)], + [("weighted_l0", 2.5), ("weighted_l1", 8.5), ("weighted_l2", 30.5)], ) def test_get_weighted_regularization_1d(thresholder, expected): - data = np.array([[0, 1, 2]]).T - lam = np.array([[1, 1, 0.5]]).T + data = np.array([[0, 3, 5]]).T + lam = np.array([[3, 2, 0.5]]).T reg = SR3.get_regularization(thresholder) result = reg(data, lam) @@ -259,11 +259,11 @@ def test_get_weighted_regularization_1d(thresholder, expected): @pytest.mark.parametrize( ["thresholder", "expected"], - [("weighted_l0", 2.5), ("weighted_l1", 3.5), ("weighted_l2", 5.5)], + [("weighted_l0", 16.5), ("weighted_l1", 158.5), ("weighted_l2", 1652.5)], ) def test_get_weighted_regularization_2d(thresholder, expected): - data = np.array([[0, 1, 2], [2, 1, 0]]).T - lam = np.array([[1, 1, 0.5], [0.5, 0.5, 1]]).T + data = np.array([[0, 3, 5], [7, 11, 0]]).T + lam = np.array([[3, 2, 0.5], [1, 13, 17]]).T reg = SR3.get_regularization(thresholder) result = reg(data, lam)