From 2981f77d3ff4ce4a381c228add9a4d10727b7097 Mon Sep 17 00:00:00 2001 From: himkwtn Date: Wed, 28 Aug 2024 12:46:51 -0700 Subject: [PATCH] CLN: merge weighted and non-weighted prox/reg fn --- pysindy/utils/base.py | 71 +++++++++++++++++-------------------------- 1 file changed, 28 insertions(+), 43 deletions(-) diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index 6e80accb..bf89bef7 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -192,45 +192,33 @@ def get_prox( and returns an array of shape (n_targets, n_features) """ - def prox_l0(x: NDArray[np.float64], regularization_weight: np.float64): - """Proximal operator for L0 regularization.""" - threshold = np.sqrt(2 * regularization_weight) - return x * (np.abs(x) > threshold) - - def prox_weighted_l0( - x: NDArray[np.float64], regularization_weight: NDArray[np.float64] + def prox_l0( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], ): - """Proximal operator for weighted l0 regularization.""" threshold = np.sqrt(2 * regularization_weight) return x * (np.abs(x) > threshold) - def prox_l1(x: NDArray[np.float64], regularization_weight: np.float64): - """Proximal operator for L1 regularization.""" - return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0) - - def prox_weighted_l1( - x: NDArray[np.float64], regularization_weight: NDArray[np.float64] + def prox_l1( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], ): - """Proximal operator for weighted l1 regularization.""" - return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0) - def prox_l2(x: NDArray[np.float64], regularization_weight: np.float64): - """Proximal operator for ridge regularization.""" - return x / (1 + 2 * regularization_weight) + return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0) - def prox_weighted_l2( - x: NDArray[np.float64], regularization_weight: NDArray[np.float64] + def prox_l2( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], ): - """Proximal operator for ridge regularization.""" return x / (1 + 2 * regularization_weight) prox = { "l0": prox_l0, - "weighted_l0": prox_weighted_l0, + "weighted_l0": prox_l0, "l1": prox_l1, - "weighted_l1": prox_weighted_l1, + "weighted_l1": prox_l1, "l2": prox_l2, - "weighted_l2": prox_weighted_l2, + "weighted_l2": prox_l2, } regularization = regularization.lower() return validate_prox_and_reg_inputs(prox[regularization], regularization) @@ -253,37 +241,34 @@ def get_regularization( and returns a float """ - def regularization_l0(x: NDArray[np.float64], regularization_weight: np.float64): - return regularization_weight * np.count_nonzero(x) - - def regualization_weighted_l0( - x: NDArray[np.float64], regularization_weight: NDArray[np.float64] + def regularization_l0( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], ): - return np.sum(regularization_weight[np.nonzero(x)]) - def regularization_l1(x: NDArray[np.float64], regularization_weight: np.float64): - return np.sum(regularization_weight * np.abs(x)) + return np.sum(regularization_weight * (x != 0)) - def regualization_weighted_l1( - x: NDArray[np.float64], regularization_weight: NDArray[np.float64] + def regularization_l1( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], ): - return np.sum(regularization_weight * np.abs(x)) - def regularization_l2(x: NDArray[np.float64], regularization_weight: np.float64): - return np.sum(regularization_weight * x**2) + return np.sum(regularization_weight * np.abs(x)) - def regualization_weighted_l2( - x: NDArray[np.float64], regularization_weight: NDArray[np.float64] + def regularization_l2( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], ): + return np.sum(regularization_weight * x**2) regularization_fn = { "l0": regularization_l0, - "weighted_l0": regualization_weighted_l0, + "weighted_l0": regularization_l0, "l1": regularization_l1, - "weighted_l1": regualization_weighted_l1, + "weighted_l1": regularization_l1, "l2": regularization_l2, - "weighted_l2": regualization_weighted_l2, + "weighted_l2": regularization_l2, } regularization = regularization.lower() return validate_prox_and_reg_inputs(