diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index cc1733b3..6e80accb 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -163,15 +163,11 @@ def wrapper(x, regularization_weight): weight_shape = regularization_weight.shape if weight_shape != x.shape: raise ValueError( - f"Invalid shape for 'regularization_weight': \ - {weight_shape}. Must be the same shape as x: {x.shape}." + f"Invalid shape for 'regularization_weight':" + f"{weight_shape}. Must be the same shape as x: {x.shape}." ) - else: - if not isinstance(regularization_weight, (int, float)) and ( - isinstance(regularization_weight, np.ndarray) - and regularization_weight.shape not in [(1, 1), (1,)] - ): - raise ValueError("'regularization_weight' must be a scalar") + elif not isinstance(regularization_weight, (int, float)): + raise ValueError("'regularization_weight' must be a scalar") return func(x, regularization_weight) return wrapper diff --git a/test/utils/test_utils.py b/test/utils/test_utils.py index 3ae24ed7..311fe949 100644 --- a/test/utils/test_utils.py +++ b/test/utils/test_utils.py @@ -83,35 +83,16 @@ def test_get_regularization(regularization, lam, expected): assert result == expected -@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"]) -@pytest.mark.parametrize("lam", [1, np.array([1]), np.array([[1]])]) -def test_get_prox_and_regularization_shape(regularization, lam): - data = np.array([[-2, 5]]).T - reg = get_regularization(regularization) - reg_result = reg(data, lam) - prox = get_prox(regularization) - prox_result = prox(data, lam) - assert reg_result is not None - assert prox_result is not None - - -@pytest.mark.parametrize( - "regularization", ["weighted_l0", "weighted_l1", "weighted_l2"] -) -@pytest.mark.parametrize("lam", [np.array([[1, 2]]).T]) -def test_get_weighted_prox_and_regularization_shape(regularization, lam): - data = np.array([[-2, 5]]).T - reg = get_regularization(regularization) - reg_result = reg(data, lam) - prox = get_prox(regularization) - prox_result = prox(data, lam) - assert reg_result is not None - assert prox_result is not None - - @pytest.mark.parametrize("regularization", ["l0", "l1", "l2"]) @pytest.mark.parametrize( - "lam", [np.array([[1, 2]]), np.array([1, 2]), np.array([[1, 2]]).T] + "lam", + [ + np.array([[1, 2]]), + np.array([1, 2]), + np.array([[1, 2]]).T, + np.array([1]), + np.array([[1]]), + ], ) def test_get_prox_and_regularization_bad_shape(regularization, lam): data = np.array([[-2, 5]]).T @@ -127,7 +108,15 @@ def test_get_prox_and_regularization_bad_shape(regularization, lam): "regularization", ["weighted_l0", "weighted_l1", "weighted_l2"] ) @pytest.mark.parametrize( - "lam", [np.array([[1, 2]]), np.array([1, 2, 3]), np.array([[1, 2, 3]]).T, 1] + "lam", + [ + np.array([[1, 2]]), + np.array([1, 2, 3]), + np.array([[1, 2, 3]]).T, + 1, + np.array([1]), + np.array([[1]]), + ], ) def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam): data = np.array([[-2, 5]]).T