diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index bf89bef7..c62a38db 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -153,7 +153,7 @@ def reorder_constraints(arr, n_features, output_order="feature"): return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1)) -def validate_prox_and_reg_inputs(func, regularization): +def _validate_prox_and_reg_inputs(func, regularization): def wrapper(x, regularization_weight): if regularization[:8] == "weighted": if not isinstance(regularization_weight, np.ndarray): @@ -176,7 +176,7 @@ def wrapper(x, regularization_weight): def get_prox( regularization: str, ) -> Callable[ - [NDArray[np.float64], Union[np.float64, NDArray[np.float64]]], NDArray[np.float64] + [NDArray[np.float64], Union[float, NDArray[np.float64]]], NDArray[np.float64] ]: """ Args: @@ -186,10 +186,10 @@ def get_prox( Returns: -------- proximal_operator: (x: np.array, reg_weight: float | np.array) -> np.array - A function that takes an input x of shape (n_targets, n_features) + A function that takes an input x of shape (m, n) and regularization weight factor which can be a scalar or - an array of shape (n_targets, n_features), - and returns an array of shape (n_targets, n_features) + an array of shape (m, n), + and returns an array of shape (m, n) """ def prox_l0( @@ -221,12 +221,12 @@ def prox_l2( "weighted_l2": prox_l2, } regularization = regularization.lower() - return validate_prox_and_reg_inputs(prox[regularization], regularization) + return _validate_prox_and_reg_inputs(prox[regularization], regularization) def get_regularization( regularization: str, -) -> Callable[[NDArray[np.float64], Union[np.float64, NDArray[np.float64]]], float]: +) -> Callable[[NDArray[np.float64], Union[float, NDArray[np.float64]]], float]: """ Args: ----- @@ -235,9 +235,9 @@ def get_regularization( Returns: -------- regularization_function: (x: np.array, reg_weight: float | np.array) -> np.array - A function that takes an input x of shape (n_targets, n_features) + A function that takes an input x of shape (m, n) and regularization weight factor which can be a scalar or - an array of shape (n_targets, n_features), + an array of shape (m, n), and returns a float """ @@ -271,7 +271,7 @@ def regularization_l2( "weighted_l2": regularization_l2, } regularization = regularization.lower() - return validate_prox_and_reg_inputs( + return _validate_prox_and_reg_inputs( regularization_fn[regularization], regularization )