Skip to content

Commit

Permalink
feat(trapping): Create trapping constraints automatically
Browse files Browse the repository at this point in the history
Previously, TrappingSR3 could only use constraints passed to it, and only then a
limited set of constraints.  It also didn't apply the trapping constraints
automatically, because constraints were required at __init__, and actually
shaping them requires knowledge about the number of features, typically not
known until fit (unless the user is a developer who knows how the feature
libraries work internally 😉)

WIP, Spawned issue #452
  • Loading branch information
Jacob-Stevens-Haas committed Jan 8, 2024
1 parent 1b436ca commit 7b2622b
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 27 deletions.
4 changes: 2 additions & 2 deletions pysindy/optimizers/constrained_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class ConstrainedSR3(SR3):
constraint_lhs : numpy ndarray, optional (default None)
Shape should be (n_constraints, n_features * n_targets),
The left hand side matrix C of Cw <= d.
There should be one row per constraint.
The left hand side matrix C of Cw <= d (Or Cw = d for equality
constraints). There should be one row per constraint.
constraint_rhs : numpy ndarray, shape (n_constraints,), optional (default None)
The right hand side vector d of Cw <= d.
Expand Down
47 changes: 40 additions & 7 deletions pysindy/optimizers/trapping_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,6 @@ class TrappingSR3(ConstrainedSR3):
If relax_optim = True, use the relax-and-split method. If False,
try a direct minimization on the largest eigenvalue.
inequality_constraints : bool, optional (default False)
If True, relax_optim must be false or relax_optim = True
AND threshold != 0, so that the CVXPY methods are used.
alpha_A :
Determines the step size in the prox-gradient descent over A.
For convergence, need alpha_A <= eta, so default
Expand Down Expand Up @@ -166,10 +162,11 @@ class TrappingSR3(ConstrainedSR3):
def __init__(
self,
*,
_n_tgts: int = None,
_include_bias: bool = True,
eta: Union[float, None] = None,
eps_solver: float = 1e-7,
relax_optim: bool = True,
inequality_constraints=False,
alpha_A: Union[float, None] = None,
alpha_m: Union[float, None] = None,
gamma: float = -0.1,
Expand All @@ -180,10 +177,46 @@ def __init__(
A0: Union[NDArray, None] = None,
**kwargs,
):
super().__init__(thresholder=thresholder, **kwargs)
# n_tgts, constraints, etc are data-dependent parameters and belong in
# _reduce/fit (). The following is a hack until we refactor how
# constraints are applied in ConstrainedSR3 and MIOSR
self._include_bias = _include_bias
self._n_tgts = _n_tgts
if _n_tgts is None:
warnings.warn(
"Trapping Optimizer initialized without _n_tgts. It will likely"
" be unable to fit data"
)
_n_tgts = 1
constraint_separation_index = kwargs.get("constraint_separation_index", 0)
constraint_rhs, constraint_lhs = _make_constraints(
_n_tgts, include_bias=_include_bias
)
constraint_order = kwargs.get("constraint_order", "feature")
if constraint_order == "target":
constraint_lhs = np.reshape(np.transpose(constraint_lhs, [1, 0, 2]))
constraint_lhs = np.reshape(constraint_lhs, (constraint_lhs.shape[0], -1))
try:
constraint_lhs = np.concatenate(
(kwargs.pop("constraint_lhs"), constraint_lhs), 0
)
constraint_rhs = np.concatenate(
(kwargs.pop("constraint_rhs"), constraint_rhs), 0
)
except KeyError:
pass

super().__init__(
constraint_lhs=constraint_lhs,
constraint_rhs=constraint_rhs,
constraint_separation_index=constraint_separation_index,
constraint_order=constraint_order,
equality_constraints=True,
thresholder=thresholder,
**kwargs,
)
self.eps_solver = eps_solver
self.relax_optim = relax_optim
self.inequality_constraints = inequality_constraints
self.m0 = m0
self.A0 = A0
self.alpha_A = alpha_A
Expand Down
3 changes: 2 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,8 @@ def data_linear_oscillator_corrupted():
@pytest.fixture(scope="session")
def data_linear_combination():
t = np.linspace(0, 5, 100)
x = np.stack((np.exp(t), np.sin(t), np.cos(t)), axis=-1)
lib = PolynomialLibrary(2)
x = lib.fit_transform(t)
y = np.stack((x[:, 0] + x[:, 1], x[:, 1] + x[:, 2]), axis=-1)

return x, y
Expand Down
57 changes: 40 additions & 17 deletions test/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest
from numpy.linalg import norm
from numpy.typing import NDArray
from scipy.integrate import solve_ivp
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
Expand All @@ -18,6 +19,7 @@
from pysindy import SINDy
from pysindy.feature_library import CustomLibrary
from pysindy.feature_library import SINDyPILibrary
from pysindy.optimizers import BaseOptimizer
from pysindy.optimizers import ConstrainedSR3
from pysindy.optimizers import EnsembleOptimizer
from pysindy.optimizers import FROLS
Expand Down Expand Up @@ -67,6 +69,18 @@ def predict(self, x):
return x


def _align_optimizer_and_1dfeatures(
opt: BaseOptimizer, features: NDArray
) -> tuple[BaseOptimizer, NDArray]:
# This is a hack until constraints are moved from init to fit
if isinstance(opt, TrappingSR3):
opt = TrappingSR3(_n_tgts=1, _include_bias=True)
features = np.hstack([features, features, features])
else:
features = features
return opt, features


@pytest.mark.parametrize(
"cls, support",
[
Expand Down Expand Up @@ -100,17 +114,19 @@ def data(request):
SR3(),
ConstrainedSR3(),
StableLinearSR3(),
TrappingSR3(),
TrappingSR3(_n_tgts=1),
Lasso(fit_intercept=False),
ElasticNet(fit_intercept=False),
DummyLinearModel(),
MIOSR(),
],
ids=lambda param: type(param),
)
def test_fit(data_derivative_1d, optimizer):
x, x_dot = data_derivative_1d
if len(x.shape) == 1:
x = x.reshape(-1, 1)
optimizer, x = _align_optimizer_and_1dfeatures(optimizer, x)
opt = WrappedOptimizer(optimizer, unbias=False)
opt.fit(x, x_dot)

Expand Down Expand Up @@ -167,12 +183,12 @@ def test_alternate_parameters(data_derivative_1d, kwargs):
],
)
def test_sample_weight_optimizers(data_1d, optimizer):
x, t = data_1d

y, t = data_1d
opt = optimizer()
opt, x = _align_optimizer_and_1dfeatures(opt, y)
sample_weight = np.ones(x[:, 0].shape)
sample_weight[::2] = 0
opt = optimizer()
opt.fit(x, x, sample_weight=sample_weight)
opt.fit(x, y, sample_weight=sample_weight)
check_is_fitted(opt)


Expand Down Expand Up @@ -222,12 +238,12 @@ def test_sr3_bad_parameters(optimizer, params):
)
def test_trapping_bad_parameters(params):
with pytest.raises(ValueError):
TrappingSR3(**params)
TrappingSR3(_n_tgts=1, **params)


def test_trapping_objective_print():
# test error in verbose print logic when max_iter < 10
opt = TrappingSR3(max_iter=2, verbose=True)
opt = TrappingSR3(_n_tgts=1, max_iter=2, verbose=True)
arr = np.ones(1)
opt._objective(arr, arr, arr, arr, arr, 1)

Expand Down Expand Up @@ -481,7 +497,7 @@ def test_trapping_sr3_quadratic_library(params, trapping_sr3_params):

params.update(trapping_sr3_params)

opt = TrappingSR3(**params)
opt = TrappingSR3(_n_tgts=1, _include_bias=False, **params)
opt.fit(features, x_dot)
assert opt.PL_.shape == (1, 1, 1, 2)
assert opt.PQ_.shape == (1, 1, 1, 1, 2)
Expand All @@ -494,7 +510,7 @@ def test_trapping_sr3_quadratic_library(params, trapping_sr3_params):
params["constraint_rhs"] = np.zeros(p)
params["constraint_lhs"] = np.eye(p, r * N)

opt = TrappingSR3(**params)
opt = TrappingSR3(_n_tgts=1, _include_bias=False, **params)
opt.fit(features, x_dot)
assert opt.PL_.shape == (1, 1, 1, 2)
assert opt.PQ_.shape == (1, 1, 1, 1, 2)
Expand Down Expand Up @@ -631,7 +647,7 @@ def test_constrained_sr3_prox_functions(data_derivative_1d, thresholder):
(SR3, {"trimming_fraction": 0.1}),
(ConstrainedSR3, {"constraint_lhs": [1], "constraint_rhs": [1]}),
(ConstrainedSR3, {"trimming_fraction": 0.1}),
(TrappingSR3, {"constraint_lhs": [1], "constraint_rhs": [1]}),
(TrappingSR3, {"_n_tgts": 1, "constraint_lhs": [1], "constraint_rhs": [1]}),
(StableLinearSR3, {"constraint_lhs": [1], "constraint_rhs": [1]}),
(StableLinearSR3, {"trimming_fraction": 0.1}),
(SINDyPI, {}),
Expand Down Expand Up @@ -741,7 +757,7 @@ def test_sr3_enable_trimming(optimizer, data_linear_oscillator_corrupted):
SR3(max_iter=1),
ConstrainedSR3(max_iter=1),
StableLinearSR3(max_iter=1),
TrappingSR3(max_iter=1),
TrappingSR3(_n_tgts=1, max_iter=1),
],
)
def test_fit_warn(data_derivative_1d, optimizer):
Expand All @@ -755,7 +771,11 @@ def test_fit_warn(data_derivative_1d, optimizer):

@pytest.mark.parametrize(
"optimizer",
[(ConstrainedSR3, {"max_iter": 80}), (TrappingSR3, {"max_iter": 100}), (MIOSR, {})],
[
(ConstrainedSR3, {"max_iter": 80}),
(TrappingSR3, {"_n_tgts": 5, "max_iter": 100}),
(MIOSR, {}),
],
)
@pytest.mark.parametrize("target_value", [0, -1, 3])
def test_row_format_constraints(data_linear_combination, optimizer, target_value):
Expand Down Expand Up @@ -966,6 +986,7 @@ def test_normalize_columns(data_derivative_1d, optimizer):
if len(x.shape) == 1:
x = x.reshape(-1, 1)
opt = optimizer(normalize_columns=True)
opt, x = _align_optimizer_and_1dfeatures(opt, x)
opt.fit(x, x_dot)
check_is_fitted(opt)
assert opt.complexity >= 0
Expand Down Expand Up @@ -1027,9 +1048,11 @@ def test_ssr_criteria(data_lorenz):
],
)
def test_optimizers_verbose(data_1d, optimizer):
x, _ = data_1d
y, _ = data_1d
opt = optimizer(verbose=True)
opt.fit(x, x)
opt, x = _align_optimizer_and_1dfeatures(opt, y)
opt.verbose = True
opt.fit(x, y)
check_is_fitted(opt)


Expand All @@ -1043,10 +1066,10 @@ def test_optimizers_verbose(data_1d, optimizer):
],
)
def test_optimizers_verbose_cvxpy(data_1d, optimizer):
x, _ = data_1d

y, _ = data_1d
opt = optimizer(verbose_cvxpy=True)
opt.fit(x, x)
opt, x = _align_optimizer_and_1dfeatures(opt, y)
opt.fit(x, y)
check_is_fitted(opt)


Expand Down

0 comments on commit 7b2622b

Please sign in to comment.