Skip to content

Commit

Permalink
CLN(sr3): improve get_regularization test case
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Aug 7, 2024
1 parent 1cb0f03 commit f5cd33e
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions test/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ def test_sr3_bad_parameters(optimizer, params):


@pytest.mark.parametrize(
["thresholder", "expected"], [("l0", 4), ("l1", 6), ("l2", 10)]
["thresholder", "expected"], [("l0", 4), ("l1", 16), ("l2", 68)]
)
def test_get_regularization_1d(thresholder, expected):
data = np.array([[0, 1, 2]]).T
data = np.array([[0, 3, 5]]).T
lam = 2

reg = SR3.get_regularization(thresholder)
Expand All @@ -233,10 +233,10 @@ def test_get_regularization_1d(thresholder, expected):


@pytest.mark.parametrize(
["thresholder", "expected"], [("l0", 8), ("l1", 10), ("l2", 14)]
["thresholder", "expected"], [("l0", 8), ("l1", 52), ("l2", 408)]
)
def test_get_regularization_2d(thresholder, expected):
data = np.array([[0, 1, 2], [1, 1, 0]]).T
data = np.array([[0, 3, 5], [7, 11, 0]]).T
lam = 2

reg = SR3.get_regularization(thresholder)
Expand All @@ -246,11 +246,11 @@ def test_get_regularization_2d(thresholder, expected):

@pytest.mark.parametrize(
["thresholder", "expected"],
[("weighted_l0", 1.5), ("weighted_l1", 2), ("weighted_l2", 3)],
[("weighted_l0", 2.5), ("weighted_l1", 8.5), ("weighted_l2", 30.5)],
)
def test_get_weighted_regularization_1d(thresholder, expected):
data = np.array([[0, 1, 2]]).T
lam = np.array([[1, 1, 0.5]]).T
data = np.array([[0, 3, 5]]).T
lam = np.array([[3, 2, 0.5]]).T

reg = SR3.get_regularization(thresholder)
result = reg(data, lam)
Expand All @@ -259,11 +259,11 @@ def test_get_weighted_regularization_1d(thresholder, expected):

@pytest.mark.parametrize(
["thresholder", "expected"],
[("weighted_l0", 2.5), ("weighted_l1", 3.5), ("weighted_l2", 5.5)],
[("weighted_l0", 16.5), ("weighted_l1", 158.5), ("weighted_l2", 1652.5)],
)
def test_get_weighted_regularization_2d(thresholder, expected):
data = np.array([[0, 1, 2], [2, 1, 0]]).T
lam = np.array([[1, 1, 0.5], [0.5, 0.5, 1]]).T
data = np.array([[0, 3, 5], [7, 11, 0]]).T
lam = np.array([[3, 2, 0.5], [1, 13, 17]]).T

reg = SR3.get_regularization(thresholder)
result = reg(data, lam)
Expand Down

0 comments on commit f5cd33e

Please sign in to comment.