Skip to content

Commit

Permalink
fix according to comments
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Sep 3, 2024
1 parent 870525d commit a0475aa
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 22 deletions.
12 changes: 5 additions & 7 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,9 @@ def get_prox(
Returns:
--------
proximal_operator: (x: np.array, reg_weight: float | np.array) -> np.array
A function that takes an input x of shape (m, n)
and regularization weight factor which can be a scalar or
an array of shape (m, n),
and returns an array of shape (m, n)
A function that takes an input array x and a regularization weight,
which can be either a scalar or array of the same shape,
and returns an array of the same shape
"""

def prox_l0(
Expand Down Expand Up @@ -235,9 +234,8 @@ def get_regularization(
Returns:
--------
regularization_function: (x: np.array, reg_weight: float | np.array) -> np.array
A function that takes an input x of shape (m, n)
and regularization weight factor which can be a scalar or
an array of shape (m, n),
A function that takes an input array x and a regularization weight,
which can be either a scalar or array of the same shape,
and returns a float
"""

Expand Down
21 changes: 6 additions & 15 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,8 @@ def test_get_regularization(regularization, lam, expected):
"lam",
[
np.array([[1, 2]]),
np.array([1, 2]),
np.array([[1, 2]]).T,
np.array([1]),
np.array([[1]]),
],
)
def test_get_prox_and_regularization_bad_shape(regularization, lam):
Expand All @@ -111,11 +109,7 @@ def test_get_prox_and_regularization_bad_shape(regularization, lam):
"lam",
[
np.array([[1, 2]]),
np.array([1, 2, 3]),
np.array([[1, 2, 3]]).T,
1,
np.array([1]),
np.array([[1]]),
],
)
def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam):
Expand All @@ -131,16 +125,13 @@ def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam):
@pytest.mark.parametrize(
["regularization", "lam", "expected"],
[
("l0", 3, np.array([[0, 5]]).T),
("l1", 3, np.array([[0, 2]]).T),
("l2", 3, np.array([[-2 / 7, 5 / 7]]).T),
("weighted_l0", np.array([[3, 2]]).T, np.array([[0, 5]]).T),
("weighted_l1", np.array([[3, 2]]).T, np.array([[0, 3]]).T),
("weighted_l2", np.array([[3, 2]]).T, np.array([[-2 / 7, 5 / 5]]).T),
("l0", 1, np.array([[2]])),
("l1", 0.5, np.array([[1.5]])),
("l2", 0.5, np.array([[1]])),
],
)
def test_get_prox(regularization, expected, lam):
data = np.array([[-2, 5]]).T
def test_get_prox(regularization, lam, expected):
data = np.array([[2]])

prox = get_prox(regularization)
result = prox(data, lam)
Expand Down

0 comments on commit a0475aa

Please sign in to comment.