Skip to content

Commit

Permalink
bug(util): fix get_regularization calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Aug 5, 2024
1 parent 476b59f commit 273d984
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from itertools import repeat
from typing import Callable
from typing import Sequence
from typing import Union

import numpy as np
from scipy.optimize import bisect
Expand Down Expand Up @@ -225,19 +227,21 @@ def get_prox(regularization):
raise NotImplementedError("{} has not been implemented".format(regularization))


def get_regularization(regularization):
def get_regularization(
regularization: str,
) -> Callable[[np.ndarray, Union[float, np.ndarray]], float]:
if regularization.lower() == "l0":
return lambda x, lam: lam * np.count_nonzero(x)
elif regularization.lower() == "weighted_l0":
return lambda x, lam: np.sum(lam[np.nonzero(x)])
elif regularization.lower() == "l1":
return lambda x, lam: lam * np.sum(np.abs(x))
elif regularization.lower() == "weighted_l1":
return lambda x, lam: np.sum(np.abs(lam @ x))
return lambda x, lam: np.sum(np.abs(lam * x))
elif regularization.lower() == "l2":
return lambda x, lam: lam * np.sum(x**2)
elif regularization.lower() == "weighted_l2":
return lambda x, lam: np.sum(lam @ x**2)
return lambda x, lam: np.sum(lam * x**2)
elif regularization.lower() == "cad": # dummy function
return lambda x, lam: 0
else:
Expand Down
41 changes: 41 additions & 0 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest

from pysindy.utils import AxesArray
from pysindy.utils import get_regularization
from pysindy.utils import reorder_constraints
from pysindy.utils import validate_control_variables

Expand Down Expand Up @@ -59,3 +60,43 @@ def test_validate_controls():
validate_control_variables([arr], [arr[:1]])
u_mod = validate_control_variables([arr], [arr], trim_last_point=True)
assert u_mod[0].n_time == 1


@pytest.mark.parametrize(["thresholder", "expected"], [("l0", 4), ("l1", 6), ("l2", 10)])
def test_get_regularization_1d(thresholder, expected):
data = np.array([[0, 1, 2]]).T
lam = 2

reg = get_regularization(thresholder)
result = reg(data, lam)
assert result == expected


@pytest.mark.parametrize(["thresholder", "expected"], [("l0", 8), ("l1", 10), ("l2", 14)])
def test_get_regularization_2d(thresholder, expected):
data = np.array([[0, 1, 2], [1, 1, 0]]).T
lam = 2

reg = get_regularization(thresholder)
result = reg(data, lam)
assert result == expected


@pytest.mark.parametrize(["thresholder", "expected"], [("weighted_l0", 1.5), ("weighted_l1", 2), ("weighted_l2", 3)])
def test_get_weighted_regularization_1d(thresholder, expected):
data = np.array([[0, 1, 2]]).T
lam = np.array([[1, 1, 0.5]]).T

reg = get_regularization(thresholder)
result = reg(data, lam)
assert result == expected


@pytest.mark.parametrize(["thresholder", "expected"], [("weighted_l0", 2.5), ("weighted_l1", 3.5), ("weighted_l2", 5.5)])
def test_get_weighted_regularization_2d(thresholder, expected):
data = np.array([[0, 1, 2], [2, 1, 0]]).T
lam = np.array([[1, 1, 0.5], [0.5, 0.5, 1]]).T

reg = get_regularization(thresholder)
result = reg(data, lam)
assert result == expected

0 comments on commit 273d984

Please sign in to comment.