From a0475aa1e478255d21be23958e834413eeeeb7f4 Mon Sep 17 00:00:00 2001 From: himkwtn Date: Tue, 3 Sep 2024 16:33:27 -0700 Subject: [PATCH] fix according to comments --- pysindy/utils/base.py | 12 +++++------- test/utils/test_utils.py | 21 ++++++--------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index c62a38db..a1be22b6 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -186,10 +186,9 @@ def get_prox( Returns: -------- proximal_operator: (x: np.array, reg_weight: float | np.array) -> np.array - A function that takes an input x of shape (m, n) - and regularization weight factor which can be a scalar or - an array of shape (m, n), - and returns an array of shape (m, n) + A function that takes an input array x and a regularization weight, + which can be either a scalar or array of the same shape, + and returns an array of the same shape """ def prox_l0( @@ -235,9 +234,8 @@ def get_regularization( Returns: -------- regularization_function: (x: np.array, reg_weight: float | np.array) -> np.array - A function that takes an input x of shape (m, n) - and regularization weight factor which can be a scalar or - an array of shape (m, n), + A function that takes an input array x and a regularization weight, + which can be either a scalar or array of the same shape, and returns a float """ diff --git a/test/utils/test_utils.py b/test/utils/test_utils.py index 311fe949..5c3e32a7 100644 --- a/test/utils/test_utils.py +++ b/test/utils/test_utils.py @@ -88,10 +88,8 @@ def test_get_regularization(regularization, lam, expected): "lam", [ np.array([[1, 2]]), - np.array([1, 2]), - np.array([[1, 2]]).T, - np.array([1]), np.array([[1]]), + ], ) def test_get_prox_and_regularization_bad_shape(regularization, lam): @@ -111,11 +109,7 @@ def test_get_prox_and_regularization_bad_shape(regularization, lam): "lam", [ np.array([[1, 2]]), - np.array([1, 2, 3]), - np.array([[1, 2, 3]]).T, 1, - np.array([1]), - np.array([[1]]), ], ) def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam): @@ -131,16 +125,13 @@ def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam): @pytest.mark.parametrize( ["regularization", "lam", "expected"], [ - ("l0", 3, np.array([[0, 5]]).T), - ("l1", 3, np.array([[0, 2]]).T), - ("l2", 3, np.array([[-2 / 7, 5 / 7]]).T), - ("weighted_l0", np.array([[3, 2]]).T, np.array([[0, 5]]).T), - ("weighted_l1", np.array([[3, 2]]).T, np.array([[0, 3]]).T), - ("weighted_l2", np.array([[3, 2]]).T, np.array([[-2 / 7, 5 / 5]]).T), + ("l0", 1, np.array([[2]])), + ("l1", 0.5, np.array([[1.5]])), + ("l2", 0.5, np.array([[1]])), ], ) -def test_get_prox(regularization, expected, lam): - data = np.array([[-2, 5]]).T +def test_get_prox(regularization, lam, expected): + data = np.array([[2]]) prox = get_prox(regularization) result = prox(data, lam)