Skip to content

Commit

Permalink
CLN: merge weighted and non-weighted prox/reg fn
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Aug 28, 2024
1 parent 05ee9e6 commit 2981f77
Showing 1 changed file with 28 additions and 43 deletions.
71 changes: 28 additions & 43 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,45 +192,33 @@ def get_prox(
and returns an array of shape (n_targets, n_features)
"""

def prox_l0(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for L0 regularization."""
threshold = np.sqrt(2 * regularization_weight)
return x * (np.abs(x) > threshold)

def prox_weighted_l0(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def prox_l0(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
"""Proximal operator for weighted l0 regularization."""
threshold = np.sqrt(2 * regularization_weight)
return x * (np.abs(x) > threshold)

def prox_l1(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for L1 regularization."""
return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)

def prox_weighted_l1(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def prox_l1(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
"""Proximal operator for weighted l1 regularization."""
return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)

def prox_l2(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * regularization_weight)
return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)

def prox_weighted_l2(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def prox_l2(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * regularization_weight)

prox = {
"l0": prox_l0,
"weighted_l0": prox_weighted_l0,
"weighted_l0": prox_l0,
"l1": prox_l1,
"weighted_l1": prox_weighted_l1,
"weighted_l1": prox_l1,
"l2": prox_l2,
"weighted_l2": prox_weighted_l2,
"weighted_l2": prox_l2,
}
regularization = regularization.lower()
return validate_prox_and_reg_inputs(prox[regularization], regularization)
Expand All @@ -253,37 +241,34 @@ def get_regularization(
and returns a float
"""

def regularization_l0(x: NDArray[np.float64], regularization_weight: np.float64):
return regularization_weight * np.count_nonzero(x)

def regualization_weighted_l0(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def regularization_l0(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
return np.sum(regularization_weight[np.nonzero(x)])

def regularization_l1(x: NDArray[np.float64], regularization_weight: np.float64):
return np.sum(regularization_weight * np.abs(x))
return np.sum(regularization_weight * (x != 0))

def regualization_weighted_l1(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def regularization_l1(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
return np.sum(regularization_weight * np.abs(x))

def regularization_l2(x: NDArray[np.float64], regularization_weight: np.float64):
return np.sum(regularization_weight * x**2)
return np.sum(regularization_weight * np.abs(x))

def regualization_weighted_l2(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def regularization_l2(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sum(regularization_weight * x**2)

regularization_fn = {
"l0": regularization_l0,
"weighted_l0": regualization_weighted_l0,
"weighted_l0": regularization_l0,
"l1": regularization_l1,
"weighted_l1": regualization_weighted_l1,
"weighted_l1": regularization_l1,
"l2": regularization_l2,
"weighted_l2": regualization_weighted_l2,
"weighted_l2": regularization_l2,
}
regularization = regularization.lower()
return validate_prox_and_reg_inputs(
Expand Down

0 comments on commit 2981f77

Please sign in to comment.