bug(constrained_sr3): fix penalty term calculation
himkwtn committed Aug 6, 2024
1 parent 273d984 commit 6d4b20b
Showing 6 changed files with 53 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pysindy/optimizers/base.py
@@ -92,7 +92,7 @@ def __init__(
self,
max_iter=20,
normalize_columns=False,
initial_guess=None,
initial_guess: np.ndarray = None,
copy_X=True,
unbias: bool = True,
):
27 changes: 15 additions & 12 deletions pysindy/optimizers/constrained_sr3.py
@@ -192,7 +192,6 @@ def __init__(
)

self.verbose_cvxpy = verbose_cvxpy
self.reg = get_regularization(thresholder)
self.constraint_lhs = constraint_lhs
self.constraint_rhs = constraint_rhs
self.constraint_order = constraint_order
@@ -271,20 +270,24 @@ def _update_full_coef_constraints(self, H, x_transpose_y, coef_sparse):
rhs = rhs.reshape(g.shape)
return inv1.dot(rhs)

def _calculate_penalty(self, xi: cp.Variable) -> cp.Expression:
thresholder = self.thresholder.lower()
if thresholder == "l1":
return self.threshold * cp.norm1(xi)
elif thresholder == "weighted_l1":
return cp.norm1(cp.multiply(np.ravel(self.thresholds), xi))
elif thresholder == "l2":
return self.threshold * cp.norm2(xi) ** 2
elif thresholder == "weighted_l2":
return cp.norm2(cp.multiply(np.ravel(self.thresholds), xi)) ** 2

def _create_var_and_part_cost(
self, var_len: int, x_expanded: np.ndarray, y: np.ndarray
) -> Tuple[cp.Variable, cp.Expression]:
xi = cp.Variable(var_len)
cost = cp.sum_squares(x_expanded @ xi - y.flatten())
if self.thresholder.lower() == "l1":
cost = cost + self.threshold * cp.norm1(xi)
elif self.thresholder.lower() == "weighted_l1":
cost = cost + cp.norm1(np.ravel(self.thresholds) @ xi)
elif self.thresholder.lower() == "l2":
cost = cost + self.threshold * cp.norm2(xi) ** 2
elif self.thresholder.lower() == "weighted_l2":
cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi) ** 2
return xi, cost
penalty = self._calculate_penalty(xi)
return xi, cost + penalty

def _update_coef_cvxpy(self, xi, cost, var_len, coef_prev, tol):
if self.use_constraints:
@@ -362,7 +365,7 @@ def _objective(self, x, y, q, coef_full, coef_sparse, trimming_array=None):
R2 *= trimming_array.reshape(x.shape[0], 1)

if self.thresholds is None:
regularization = self.reg(coef_full, self.threshold**2 / self.nu)
regularization = super().reg(coef_full, self.threshold**2 / self.nu)
if print_ind == 0 and self.verbose:
row = [
q,
@@ -377,7 +380,7 @@ def _objective(self, x, y, q, coef_full, coef_sparse, trimming_array=None):
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu
else:
regularization = self.reg(coef_full, self.thresholds**2 / self.nu)
regularization = super().reg(coef_full, self.thresholds**2 / self.nu)
if print_ind == 0 and self.verbose:
row = [
q,
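Note on the fix: in the old _create_var_and_part_cost, the weighted penalties were built with a matrix product, e.g. cp.norm1(np.ravel(self.thresholds) @ xi). That expression is the absolute value of a single dot product, so sign cancellations between coefficients can drive the penalty to zero. The new _calculate_penalty uses cp.multiply, which weights each coefficient elementwise. A minimal sketch of the difference (cvxpy and numpy only; the weight and coefficient values below are made up for illustration):

import cvxpy as cp
import numpy as np

w = np.ones(4)                                # flattened per-coefficient thresholds
xi = cp.Variable(4)
xi.value = np.array([0.5, -0.5, 0.5, -0.5])   # example coefficient values

old_penalty = cp.norm1(w @ xi)                # |w . xi|: a single scalar inside the norm
new_penalty = cp.norm1(cp.multiply(w, xi))    # sum_i |w_i * xi_i|: elementwise weighting

print(old_penalty.value)  # 0.0 -- the signs cancel inside the dot product
print(new_penalty.value)  # 2.0 -- every coefficient contributes

This elementwise form is what _calculate_penalty applies for the weighted_l1 and weighted_l2 thresholders.
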
4 changes: 2 additions & 2 deletions pysindy/optimizers/sr3.py
@@ -138,15 +138,15 @@ class SR3(BaseOptimizer):
def __init__(
self,
threshold=0.1,
thresholds=None,
thresholds: np.ndarray = None,
nu=1.0,
tol=1e-5,
thresholder="L0",
trimming_fraction=0.0,
trimming_step_size=1.0,
max_iter=30,
copy_X=True,
initial_guess=None,
initial_guess: np.ndarray = None,
normalize_columns=False,
verbose=False,
unbias=False,
2 changes: 1 addition & 1 deletion pysindy/optimizers/trapping_sr3.py
@@ -542,7 +542,7 @@ def _update_coef_sparse_rs(
self, n_tgts, n_features, var_len, x_expanded, y, Pmatrix, A, coef_prev
):
"""Solve coefficient update with CVXPY if threshold != 0"""
xi, cost = self._create_var_and_part_cost(var_len, x_expanded, y)
xi, cost = super()._create_var_and_part_cost(var_len, x_expanded, y)
cost = cost + cp.sum_squares(Pmatrix @ xi - A.flatten()) / self.eta

# new terms minimizing quadratic piece ||P^Q @ xi||_2^2
20 changes: 20 additions & 0 deletions test/test_optimizers.py
@@ -3,6 +3,7 @@
"""
import pickle

import cvxpy as cp
import numpy as np
import pytest
from numpy.linalg import norm
@@ -470,6 +471,25 @@ def test_constrained_sr3_quadratic_library(params):
assert np.allclose((model.coefficients().flatten())[:p], 0.0)


@pytest.mark.parametrize(
"params",
[
dict(thresholder="l1", threshold=1, expected=2),
dict(thresholder="weighted_l1", thresholds=np.ones((4, 1)), expected=2),
dict(thresholder="l2", threshold=1, expected=1),
dict(thresholder="weighted_l2", thresholds=np.ones((4, 1)), expected=1),
],
ids=lambda d: d["thresholder"],
)
def test_constrained_sr3_penalty_term(params):
expected = params.pop("expected")
opt = ConstrainedSR3(**params)
xi = cp.Variable(4)
cost = opt._calculate_penalty(xi)
xi.value = 0.5 * np.ones(4)
np.testing.assert_allclose(cost.value, expected)
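For reference, the expected values in test_constrained_sr3_penalty_term follow directly from the penalty definitions: with xi = 0.5 * np.ones(4) and unit threshold(s), the l1-type penalties evaluate to 1 * (4 * 0.5) = 2, while the l2-type penalties evaluate to the squared 2-norm, 4 * 0.25 = 1.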


@pytest.mark.parametrize(
"params",
[
18 changes: 14 additions & 4 deletions test/utils/test_utils.py
@@ -62,7 +62,9 @@ def test_validate_controls():
assert u_mod[0].n_time == 1


@pytest.mark.parametrize(["thresholder", "expected"], [("l0", 4), ("l1", 6), ("l2", 10)])
@pytest.mark.parametrize(
["thresholder", "expected"], [("l0", 4), ("l1", 6), ("l2", 10)]
)
def test_get_regularization_1d(thresholder, expected):
data = np.array([[0, 1, 2]]).T
lam = 2
@@ -72,7 +74,9 @@ def test_get_regularization_1d(thresholder, expected):
assert result == expected


@pytest.mark.parametrize(["thresholder", "expected"], [("l0", 8), ("l1", 10), ("l2", 14)])
@pytest.mark.parametrize(
["thresholder", "expected"], [("l0", 8), ("l1", 10), ("l2", 14)]
)
def test_get_regularization_2d(thresholder, expected):
data = np.array([[0, 1, 2], [1, 1, 0]]).T
lam = 2
@@ -82,7 +86,10 @@ def test_get_regularization_2d(thresholder, expected):
assert result == expected


@pytest.mark.parametrize(["thresholder", "expected"], [("weighted_l0", 1.5), ("weighted_l1", 2), ("weighted_l2", 3)])
@pytest.mark.parametrize(
["thresholder", "expected"],
[("weighted_l0", 1.5), ("weighted_l1", 2), ("weighted_l2", 3)],
)
def test_get_weighted_regularization_1d(thresholder, expected):
data = np.array([[0, 1, 2]]).T
lam = np.array([[1, 1, 0.5]]).T
@@ -92,7 +99,10 @@ def test_get_weighted_regularization_1d(thresholder, expected):
assert result == expected


@pytest.mark.parametrize(["thresholder", "expected"], [("weighted_l0", 2.5), ("weighted_l1", 3.5), ("weighted_l2", 5.5)])
@pytest.mark.parametrize(
["thresholder", "expected"],
[("weighted_l0", 2.5), ("weighted_l1", 3.5), ("weighted_l2", 5.5)],
)
def test_get_weighted_regularization_2d(thresholder, expected):
data = np.array([[0, 1, 2], [2, 1, 0]]).T
lam = np.array([[1, 1, 0.5], [0.5, 0.5, 1]]).T
