Skip to content

Commit

Permalink
clean up docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Aug 28, 2024
1 parent 2981f77 commit c2237bf
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def reorder_constraints(arr, n_features, output_order="feature"):
return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1))


def validate_prox_and_reg_inputs(func, regularization):
def _validate_prox_and_reg_inputs(func, regularization):
def wrapper(x, regularization_weight):
if regularization[:8] == "weighted":
if not isinstance(regularization_weight, np.ndarray):
Expand All @@ -176,7 +176,7 @@ def wrapper(x, regularization_weight):
def get_prox(
regularization: str,
) -> Callable[
[NDArray[np.float64], Union[np.float64, NDArray[np.float64]]], NDArray[np.float64]
[NDArray[np.float64], Union[float, NDArray[np.float64]]], NDArray[np.float64]
]:
"""
Args:
Expand All @@ -186,10 +186,10 @@ def get_prox(
Returns:
--------
proximal_operator: (x: np.array, reg_weight: float | np.array) -> np.array
A function that takes an input x of shape (n_targets, n_features)
A function that takes an input x of shape (m, n)
and regularization weight factor which can be a scalar or
an array of shape (n_targets, n_features),
and returns an array of shape (n_targets, n_features)
an array of shape (m, n),
and returns an array of shape (m, n)
"""

def prox_l0(
Expand Down Expand Up @@ -221,12 +221,12 @@ def prox_l2(
"weighted_l2": prox_l2,
}
regularization = regularization.lower()
return validate_prox_and_reg_inputs(prox[regularization], regularization)
return _validate_prox_and_reg_inputs(prox[regularization], regularization)


def get_regularization(
regularization: str,
) -> Callable[[NDArray[np.float64], Union[np.float64, NDArray[np.float64]]], float]:
) -> Callable[[NDArray[np.float64], Union[float, NDArray[np.float64]]], float]:
"""
Args:
-----
Expand All @@ -235,9 +235,9 @@ def get_regularization(
Returns:
--------
regularization_function: (x: np.array, reg_weight: float | np.array) -> np.array
A function that takes an input x of shape (n_targets, n_features)
A function that takes an input x of shape (m, n)
and regularization weight factor which can be a scalar or
an array of shape (n_targets, n_features),
an array of shape (m, n),
and returns a float
"""

Expand Down Expand Up @@ -271,7 +271,7 @@ def regularization_l2(
"weighted_l2": regularization_l2,
}
regularization = regularization.lower()
return validate_prox_and_reg_inputs(
return _validate_prox_and_reg_inputs(
regularization_fn[regularization], regularization
)

Expand Down

0 comments on commit c2237bf

Please sign in to comment.