From 273d984463043664f1705f00464bf0fb0b7c00ae Mon Sep 17 00:00:00 2001 From: himkwtn Date: Mon, 5 Aug 2024 11:47:54 -0700 Subject: [PATCH] bug(util): fix get_regularization calculation --- pysindy/utils/base.py | 10 +++++++--- test/utils/test_utils.py | 41 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index 114916e7f..761fe0289 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -1,6 +1,8 @@ import warnings from itertools import repeat +from typing import Callable from typing import Sequence +from typing import Union import numpy as np from scipy.optimize import bisect @@ -225,7 +227,9 @@ def get_prox(regularization): raise NotImplementedError("{} has not been implemented".format(regularization)) -def get_regularization(regularization): +def get_regularization( + regularization: str, +) -> Callable[[np.ndarray, Union[float, np.ndarray]], float]: if regularization.lower() == "l0": return lambda x, lam: lam * np.count_nonzero(x) elif regularization.lower() == "weighted_l0": @@ -233,11 +237,11 @@ def get_regularization(regularization): elif regularization.lower() == "l1": return lambda x, lam: lam * np.sum(np.abs(x)) elif regularization.lower() == "weighted_l1": - return lambda x, lam: np.sum(np.abs(lam @ x)) + return lambda x, lam: np.sum(np.abs(lam * x)) elif regularization.lower() == "l2": return lambda x, lam: lam * np.sum(x**2) elif regularization.lower() == "weighted_l2": - return lambda x, lam: np.sum(lam @ x**2) + return lambda x, lam: np.sum(lam * x**2) elif regularization.lower() == "cad": # dummy function return lambda x, lam: 0 else: diff --git a/test/utils/test_utils.py b/test/utils/test_utils.py index 4c4d2daf2..850ef3cb3 100644 --- a/test/utils/test_utils.py +++ b/test/utils/test_utils.py @@ -2,6 +2,7 @@ import pytest from pysindy.utils import AxesArray +from pysindy.utils import get_regularization from pysindy.utils import reorder_constraints from pysindy.utils import validate_control_variables @@ -59,3 +60,43 @@ def test_validate_controls(): validate_control_variables([arr], [arr[:1]]) u_mod = validate_control_variables([arr], [arr], trim_last_point=True) assert u_mod[0].n_time == 1 + + +@pytest.mark.parametrize(["thresholder", "expected"], [("l0", 4), ("l1", 6), ("l2", 10)]) +def test_get_regularization_1d(thresholder, expected): + data = np.array([[0, 1, 2]]).T + lam = 2 + + reg = get_regularization(thresholder) + result = reg(data, lam) + assert result == expected + + +@pytest.mark.parametrize(["thresholder", "expected"], [("l0", 8), ("l1", 10), ("l2", 14)]) +def test_get_regularization_2d(thresholder, expected): + data = np.array([[0, 1, 2], [1, 1, 0]]).T + lam = 2 + + reg = get_regularization(thresholder) + result = reg(data, lam) + assert result == expected + + +@pytest.mark.parametrize(["thresholder", "expected"], [("weighted_l0", 1.5), ("weighted_l1", 2), ("weighted_l2", 3)]) +def test_get_weighted_regularization_1d(thresholder, expected): + data = np.array([[0, 1, 2]]).T + lam = np.array([[1, 1, 0.5]]).T + + reg = get_regularization(thresholder) + result = reg(data, lam) + assert result == expected + + +@pytest.mark.parametrize(["thresholder", "expected"], [("weighted_l0", 2.5), ("weighted_l1", 3.5), ("weighted_l2", 5.5)]) +def test_get_weighted_regularization_2d(thresholder, expected): + data = np.array([[0, 1, 2], [2, 1, 0]]).T + lam = np.array([[1, 1, 0.5], [0.5, 0.5, 1]]).T + + reg = get_regularization(thresholder) + result = reg(data, lam) + assert result == expected