Skip to content

Commit

Permalink
change shape validation
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Aug 28, 2024
1 parent 6928455 commit 05ee9e6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 36 deletions.
12 changes: 4 additions & 8 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,11 @@ def wrapper(x, regularization_weight):
weight_shape = regularization_weight.shape
if weight_shape != x.shape:
raise ValueError(
f"Invalid shape for 'regularization_weight': \
{weight_shape}. Must be the same shape as x: {x.shape}."
f"Invalid shape for 'regularization_weight':"
f"{weight_shape}. Must be the same shape as x: {x.shape}."
)
else:
if not isinstance(regularization_weight, (int, float)) and (
isinstance(regularization_weight, np.ndarray)
and regularization_weight.shape not in [(1, 1), (1,)]
):
raise ValueError("'regularization_weight' must be a scalar")
elif not isinstance(regularization_weight, (int, float)):
raise ValueError("'regularization_weight' must be a scalar")
return func(x, regularization_weight)

return wrapper
Expand Down
45 changes: 17 additions & 28 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,35 +83,16 @@ def test_get_regularization(regularization, lam, expected):
assert result == expected


@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"])
@pytest.mark.parametrize("lam", [1, np.array([1]), np.array([[1]])])
def test_get_prox_and_regularization_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
reg_result = reg(data, lam)
prox = get_prox(regularization)
prox_result = prox(data, lam)
assert reg_result is not None
assert prox_result is not None


@pytest.mark.parametrize(
"regularization", ["weighted_l0", "weighted_l1", "weighted_l2"]
)
@pytest.mark.parametrize("lam", [np.array([[1, 2]]).T])
def test_get_weighted_prox_and_regularization_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
reg_result = reg(data, lam)
prox = get_prox(regularization)
prox_result = prox(data, lam)
assert reg_result is not None
assert prox_result is not None


@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"])
@pytest.mark.parametrize(
"lam", [np.array([[1, 2]]), np.array([1, 2]), np.array([[1, 2]]).T]
"lam",
[
np.array([[1, 2]]),
np.array([1, 2]),
np.array([[1, 2]]).T,
np.array([1]),
np.array([[1]]),
],
)
def test_get_prox_and_regularization_bad_shape(regularization, lam):
data = np.array([[-2, 5]]).T
Expand All @@ -127,7 +108,15 @@ def test_get_prox_and_regularization_bad_shape(regularization, lam):
"regularization", ["weighted_l0", "weighted_l1", "weighted_l2"]
)
@pytest.mark.parametrize(
"lam", [np.array([[1, 2]]), np.array([1, 2, 3]), np.array([[1, 2, 3]]).T, 1]
"lam",
[
np.array([[1, 2]]),
np.array([1, 2, 3]),
np.array([[1, 2, 3]]).T,
1,
np.array([1]),
np.array([[1]]),
],
)
def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam):
data = np.array([[-2, 5]]).T
Expand Down

0 comments on commit 05ee9e6

Please sign in to comment.