Skip to content

Commit

Permalink
Merge pull request #586 from dynamicslab/pickle-sr3
Browse files Browse the repository at this point in the history
Be explicit about setting the shape for SSR.
Pickle SR3 and subordinate classes
  • Loading branch information
Jacob-Stevens-Haas authored Jan 12, 2025
2 parents 2ca37cb + 8b8f0a5 commit cbb6863
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 79 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ repos:
hooks:
- id: end-of-file-fixer
exclude: (.txt|^docs/JOSS1|^docs/JOSS2|^examples/data/)
stages: [commit, merge-commit, push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
stages: [pre-commit, pre-merge-commit, pre-push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
- id: trailing-whitespace
stages: [commit, merge-commit, push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
stages: [pre-commit, pre-merge-commit, pre-push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
exclude: (.txt|^docs/JOSS1|^docs/JOSS2|^examples/data/)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
readme = "README.rst"
dependencies = [
"jax>=0.4,<0.5",
"scikit-learn>=1.1, !=1.5.0",
"scikit-learn>=1.1, !=1.5.0, !=1.6.0",
"derivative>=0.6.2",
"typing_extensions",
]
Expand Down
20 changes: 19 additions & 1 deletion pysindy/optimizers/ssr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from typing import cast
from typing import NewType
from typing import TypeVar

import numpy as np
from numpy.typing import NBitBase
from sklearn.linear_model import ridge_regression

from .base import BaseOptimizer

# Type variables standing for the two axes of a 2-D array shape; both are
# covariant and bounded by ``int``.
Rows = TypeVar("Rows", covariant=True, bound=int)
Cols = TypeVar("Cols", covariant=True, bound=int)
# Generic alias for an ndarray whose shape type is ``tuple[Rows, Cols]`` and
# whose dtype is ``np.floating`` parameterised by ``NBitBase``.
Float2D = np.ndarray[tuple[Rows, Cols], np.dtype[np.floating[NBitBase]]]
# Distinct ``int``-based types used to label array axes in annotations,
# e.g. ``Float2D[Samples, Features]``.
Features = NewType("Features", int)
Targets = NewType("Targets", int)
Samples = NewType("Samples", int)


class SSR(BaseOptimizer):
"""Stepwise sparse regression (SSR) greedy algorithm.
Expand Down Expand Up @@ -157,17 +169,23 @@ def _model_residual(self, x, y, coef, inds):
cc[total_ind] = 0.0
return cc, total_ind

def _regress(self, x, y):
def _regress(
    self, x: Float2D[Samples, Features], y: Float2D[Samples, Targets]
) -> Float2D[Targets, Features]:
    """Run one ridge regression of ``y`` on ``x`` and count it.

    The penalty is ``self.alpha``; extra keyword arguments are taken from
    ``self.ridge_kw`` (an empty dict is used when it is falsy).  The
    coefficients are promoted to at least two dimensions before being
    returned, and ``self.iters`` is incremented once per successful call.
    """
    extra_kwargs = self.ridge_kw or {}
    weights = np.atleast_2d(  # type: ignore
        ridge_regression(x, y, self.alpha, **extra_kwargs)
    )
    self.iters += 1
    return weights

def _reduce(self, x, y):
"""Performs at most ``self.max_iter`` iterations of the
SSR greedy algorithm.
"""
# Until static typing grows, use cast
x = cast(Float2D[Samples, Features], x)
y = cast(Float2D[Samples, Targets], y)
n_samples, n_features = x.shape
n_targets = y.shape[1]
cond_num = np.linalg.cond(x)
Expand Down
136 changes: 73 additions & 63 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from functools import wraps
from typing import Callable
from typing import Sequence
from typing import Union
Expand Down Expand Up @@ -152,17 +153,19 @@ def reorder_constraints(arr, n_features, output_order="feature"):
return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1))


def _validate_prox_and_reg_inputs(func, regularization):
def _validate_prox_and_reg_inputs(func):
"""Add guard code to ensure weight and argument have compatible shape/type
Decorates prox and regularization functions.
"""

@wraps(func)
def wrapper(x, regularization_weight):
if regularization[:8] == "weighted":
if not isinstance(regularization_weight, np.ndarray):
raise ValueError(
f"'regularization_weight' must be an array of shape {x.shape}."
)
if isinstance(regularization_weight, np.ndarray):
weight_shape = regularization_weight.shape
if weight_shape != x.shape:
raise ValueError(
f"Invalid shape for 'regularization_weight':"
f"Invalid shape for 'regularization_weight': "
f"{weight_shape}. Must be the same shape as x: {x.shape}."
)
elif not isinstance(regularization_weight, (int, float)):
Expand Down Expand Up @@ -190,36 +193,66 @@ def get_prox(
and returns an array of the same shape
"""

def prox_l0(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
threshold = np.sqrt(2 * regularization_weight)
return x * (np.abs(x) > threshold)
prox = {
"l0": _prox_l0,
"weighted_l0": _prox_l0,
"l1": _prox_l1,
"weighted_l1": _prox_l1,
"l2": _prox_l2,
"weighted_l2": _prox_l2,
}
regularization = regularization.lower()
return prox[regularization]

def prox_l1(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)
@_validate_prox_and_reg_inputs
def _prox_l0(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Hard-threshold ``x``.

    Entries whose magnitude does not exceed
    ``sqrt(2 * regularization_weight)`` are multiplied by zero; all other
    entries are returned unchanged.
    """
    keep = np.abs(x) > np.sqrt(2 * regularization_weight)
    return x * keep

def prox_l2(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
return x / (1 + 2 * regularization_weight)

prox = {
"l0": prox_l0,
"weighted_l0": prox_l0,
"l1": prox_l1,
"weighted_l1": prox_l1,
"l2": prox_l2,
"weighted_l2": prox_l2,
}
regularization = regularization.lower()
return _validate_prox_and_reg_inputs(prox[regularization], regularization)
@_validate_prox_and_reg_inputs
def _prox_l1(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Soft-threshold ``x``.

    Each magnitude is reduced by ``regularization_weight`` and clipped at
    zero, then the original sign is reapplied.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - regularization_weight, 0)
    return np.sign(x) * shrunk_magnitude


@_validate_prox_and_reg_inputs
def _prox_l2(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Scale ``x`` down by ``1 + 2 * regularization_weight``."""
    denominator = 1 + 2 * regularization_weight
    return x / denominator


@_validate_prox_and_reg_inputs
def _regularization_l0(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Return the weighted count of nonzero entries of ``x``.

    The boolean mask ``x != 0`` is multiplied by ``regularization_weight``
    and summed over all entries.
    """
    nonzero_mask = x != 0
    return np.sum(regularization_weight * nonzero_mask)


@_validate_prox_and_reg_inputs
def _regularization_l1(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Return the weighted sum of the absolute values of ``x``."""
    magnitudes = np.abs(x)
    return np.sum(regularization_weight * magnitudes)


@_validate_prox_and_reg_inputs
def _regularization_l2(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Return the weighted sum of the squared entries of ``x``."""
    squares = x**2
    return np.sum(regularization_weight * squares)


def get_regularization(
Expand All @@ -238,39 +271,16 @@ def get_regularization(
and returns a float
"""

def regularization_l0(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sum(regularization_weight * (x != 0))

def regularization_l1(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sum(regularization_weight * np.abs(x))

def regularization_l2(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sum(regularization_weight * x**2)

regularization_fn = {
"l0": regularization_l0,
"weighted_l0": regularization_l0,
"l1": regularization_l1,
"weighted_l1": regularization_l1,
"l2": regularization_l2,
"weighted_l2": regularization_l2,
"l0": _regularization_l0,
"weighted_l0": _regularization_l0,
"l1": _regularization_l1,
"weighted_l1": _regularization_l1,
"l2": _regularization_l2,
"weighted_l2": _regularization_l2,
}
regularization = regularization.lower()
return _validate_prox_and_reg_inputs(
regularization_fn[regularization], regularization
)
return regularization_fn[regularization]


def capped_simplex_projection(trimming_array, trimming_fraction):
Expand Down
4 changes: 3 additions & 1 deletion test/test_optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,12 +1183,14 @@ def test_remove_and_decrement():
(
(MIOSR, {"target_sparsity": 7}),
(SBR, {"num_warmup": 10, "num_samples": 10}),
(SR3, {}),
(TrappingSR3, {"_n_tgts": 3, "_include_bias": True}),
),
)
def test_pickle(data_lorenz, opt_cls, opt_args):
x, t = data_lorenz
y = PolynomialLibrary(degree=2).fit_transform(x)
opt = opt_cls(**opt_args).fit(x, y)
opt = opt_cls(**opt_args).fit(y, x)
expected = opt.coef_
new_opt = pickle.loads(pickle.dumps(opt))
result = new_opt.coef_
Expand Down
14 changes: 3 additions & 11 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,9 @@ def test_get_prox_and_regularization_bad_shape(regularization, lam):
prox(data, lam)


@pytest.mark.parametrize(
"regularization", ["weighted_l0", "weighted_l1", "weighted_l2"]
)
@pytest.mark.parametrize(
"lam",
[
np.array([[1, 2]]),
1,
],
)
def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam):
@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"])
def test_get_weighted_prox_and_regularization_bad_shape(regularization):
lam = np.array([[1, 2]])
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
with pytest.raises(ValueError):
Expand Down

0 comments on commit cbb6863

Please sign in to comment.