From 070d0c0ccf7e99c3fb109ad64a0946b53ab33e3a Mon Sep 17 00:00:00 2001 From: himkwtn Date: Tue, 13 Aug 2024 17:48:45 -0700 Subject: [PATCH] ENH: unit tests for get_regularization --- test/utils/test_utils.py | 51 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/test/utils/test_utils.py b/test/utils/test_utils.py index 4c4d2daf2..8b68ca0ca 100644 --- a/test/utils/test_utils.py +++ b/test/utils/test_utils.py @@ -4,6 +4,7 @@ from pysindy.utils import AxesArray from pysindy.utils import reorder_constraints from pysindy.utils import validate_control_variables +from pysindy.utils import get_regularization def test_reorder_constraints_1D(): @@ -59,3 +60,53 @@ def test_validate_controls(): validate_control_variables([arr], [arr[:1]]) u_mod = validate_control_variables([arr], [arr], trim_last_point=True) assert u_mod[0].n_time == 1 + + +@pytest.mark.parametrize( + ["regularization", "expected"], [("l0", 4), ("l1", 16), ("l2", 68)] +) +def test_get_regularization_1d(regularization, expected): + data = np.array([[0, 3, 5]]).T + lam = np.array([[2]]) + + reg = get_regularization(regularization) + result = reg(data, lam) + assert result == expected + + +@pytest.mark.parametrize( + ["regularization", "expected"], [("l0", 8), ("l1", 52), ("l2", 408)] +) +def test_get_regularization_2d(regularization, expected): + data = np.array([[0, 3, 5], [7, 11, 0]]).T + lam = np.array([[2]]) + + reg = get_regularization(regularization) + result = reg(data, lam) + assert result == expected + + +@pytest.mark.parametrize( + ["regularization", "expected"], + [("weighted_l0", 2.5), ("weighted_l1", 8.5), ("weighted_l2", 30.5)], +) +def test_get_weighted_regularization_1d(regularization, expected): + data = np.array([[0, 3, 5]]).T + lam = np.array([[3, 2, 0.5]]).T + + reg = get_regularization(regularization) + result = reg(data, lam) + assert result == expected + + +@pytest.mark.parametrize( + ["regularization", "expected"], + [("weighted_l0", 16.5), ("weighted_l1", 158.5), ("weighted_l2", 1652.5)], +) +def test_get_weighted_regularization_2d(regularization, expected): + data = np.array([[0, 3, 5], [7, 11, 0]]).T + lam = np.array([[3, 2, 0.5], [1, 13, 17]]).T + + reg = get_regularization(regularization) + result = reg(data, lam) + assert result == expected