Skip to content

Commit

Permalink
ENH: unit tests for get_regularization
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Aug 14, 2024
1 parent b436eab commit 070d0c0
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pysindy.utils import AxesArray
from pysindy.utils import reorder_constraints
from pysindy.utils import validate_control_variables
from pysindy.utils import get_regularization


def test_reorder_constraints_1D():
Expand Down Expand Up @@ -59,3 +60,53 @@ def test_validate_controls():
validate_control_variables([arr], [arr[:1]])
u_mod = validate_control_variables([arr], [arr], trim_last_point=True)
assert u_mod[0].n_time == 1


@pytest.mark.parametrize(
["regularization", "expected"], [("l0", 4), ("l1", 16), ("l2", 68)]
)
def test_get_regularization_1d(regularization, expected):
data = np.array([[0, 3, 5]]).T
lam = np.array([[2]])

reg = get_regularization(regularization)
result = reg(data, lam)
assert result == expected


@pytest.mark.parametrize(
["regularization", "expected"], [("l0", 8), ("l1", 52), ("l2", 408)]
)
def test_get_regularization_2d(regularization, expected):
data = np.array([[0, 3, 5], [7, 11, 0]]).T
lam = np.array([[2]])

reg = get_regularization(regularization)
result = reg(data, lam)
assert result == expected


@pytest.mark.parametrize(
["regularization", "expected"],
[("weighted_l0", 2.5), ("weighted_l1", 8.5), ("weighted_l2", 30.5)],
)
def test_get_weighted_regularization_1d(regularization, expected):
data = np.array([[0, 3, 5]]).T
lam = np.array([[3, 2, 0.5]]).T

reg = get_regularization(regularization)
result = reg(data, lam)
assert result == expected


@pytest.mark.parametrize(
["regularization", "expected"],
[("weighted_l0", 16.5), ("weighted_l1", 158.5), ("weighted_l2", 1652.5)],
)
def test_get_weighted_regularization_2d(regularization, expected):
data = np.array([[0, 3, 5], [7, 11, 0]]).T
lam = np.array([[3, 2, 0.5], [1, 13, 17]]).T

reg = get_regularization(regularization)
result = reg(data, lam)
assert result == expected

0 comments on commit 070d0c0

Please sign in to comment.