CLN: Move Trapping guards from __init__ to helper.
This allows the guards to be re-run by set_params, which is part of the scikit-learn API (see the sketch below).
Jacob-Stevens-Haas committed Nov 30, 2023
1 parent 4ff33f6 commit 5c0bf1f
Showing 2 changed files with 46 additions and 35 deletions.
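
For context: scikit-learn's set_params updates attributes directly and never re-enters __init__, so validation that lives only in __init__ is silently skipped when a parameter is changed after construction (e.g. by GridSearchCV). Below is a minimal sketch of the pattern this commit adopts, using a hypothetical Example estimator rather than pysindy's actual classes:

from sklearn.base import BaseEstimator

class Example(BaseEstimator):
    def __init__(self, max_iter: int = 20):
        self.max_iter = max_iter
        self.__post_init_guard()  # validate at construction time

    def __post_init_guard(self):
        # Shared guard: runs after both __init__ and set_params.
        if self.max_iter <= 0:
            raise ValueError("max_iter must be positive")

    def set_params(self, **kwargs):
        super().set_params(**kwargs)
        self.__post_init_guard()  # re-validate the updated parameters
        return self

Example().set_params(max_iter=-1)  # raises ValueError instead of passing silently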
19 changes: 13 additions & 6 deletions pysindy/optimizers/base.py
@@ -97,17 +97,24 @@ def __init__(
         unbias: bool = True,
     ):
         super().__init__(fit_intercept=False, copy_X=copy_X)
-
-        if max_iter <= 0:
-            raise ValueError("max_iter must be positive")
-
         self.max_iter = max_iter
         self.iters = 0
-        if np.ndim(initial_guess) == 1:
-            initial_guess = initial_guess.reshape(1, -1)
         self.initial_guess = initial_guess
         self.normalize_columns = normalize_columns
         self.unbias = unbias
+        self.__post_init_guard()
+
+    # See name mangling rules for double underscore rationale
+    def __post_init_guard(self):
+        """Conduct initialization post-init, as required by the scikit-learn API."""
+        if np.ndim(self.initial_guess) == 1:
+            self.initial_guess = self.initial_guess.reshape(1, -1)
+        if self.max_iter <= 0:
+            raise ValueError("max_iter must be positive")
+
+    def set_params(self, **kwargs):
+        super().set_params(**kwargs)
+        self.__post_init_guard()
 
     # Force subclasses to implement this
     @abc.abstractmethod
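
The "name mangling rules" comment above refers to Python's double-underscore rule: inside BaseOptimizer, self.__post_init_guard resolves to self._BaseOptimizer__post_init_guard, so a subclass that defines its own __post_init_guard (as TrappingSR3 does below) adds a separate attribute rather than overriding the base guard, and each class's set_params invokes its own checks. A toy illustration of the mechanism, unrelated to pysindy's classes:

class Base:
    def __init__(self):
        self.__guard()  # mangled to self._Base__guard

    def __guard(self):
        print("Base guard")

class Child(Base):
    def __guard(self):  # mangled to _Child__guard; does not override Base's call
        print("Child guard")

Child()  # prints "Base guard": Base.__init__ still reaches Base's own guard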
62 changes: 33 additions & 29 deletions pysindy/optimizers/trapping_sr3.py
@@ -174,53 +174,57 @@ def __init__(
         A0: NDArray | None = None,
         **kwargs,
     ):
-        super().__init__(
-            thresholder=thresholder,
-            **kwargs,
-        )
+        super().__init__(thresholder=thresholder, **kwargs)
+        self.eps_solver = eps_solver
+        self.relax_optim = relax_optim
+        self.inequality_constraints = inequality_constraints
+        self.m0 = m0
+        self.A0 = A0
+        self.alpha_A = alpha_A
+        self.alpha_m = alpha_m
+        self.eta = eta
+        self.gamma = gamma
+        self.tol_m = tol_m
+        self.accel = accel
+        self.__post_init_guard()
+
+    def __post_init_guard(self):
+        """Conduct initialization post-init, as required by the scikit-learn API."""
         if self.thresholder.lower() not in ("l1", "l2", "weighted_l1", "weighted_l2"):
             raise ValueError("Regularizer must be (weighted) L1 or L2")
-        if eta is None:
+        if self.eta is None:
             warnings.warn(
                 "eta was not set, so defaulting to eta = 1e20 "
                 "with alpha_m = 1e-2 * eta, alpha_A = eta. Here eta is so "
                 "large that the stability term in the optimization "
                 "will be ignored."
             )
-            eta = 1e20
-            alpha_m = 1e18
-            alpha_A = 1e20
+            self.eta = 1e20
+            self.alpha_m = 1e18
+            self.alpha_A = 1e20
         else:
-            if alpha_m is None:
-                alpha_m = eta * 1e-2
-            if alpha_A is None:
-                alpha_A = eta
-        if eta <= 0:
+            if self.alpha_m is None:
+                self.alpha_m = self.eta * 1e-2
+            if self.alpha_A is None:
+                self.alpha_A = self.eta
+        if self.eta <= 0:
             raise ValueError("eta must be positive")
-        if alpha_m < 0 or alpha_m > eta:
+        if self.alpha_m < 0 or self.alpha_m > self.eta:
             raise ValueError("0 <= alpha_m <= eta")
-        if alpha_A < 0 or alpha_A > eta:
+        if self.alpha_A < 0 or self.alpha_A > self.eta:
             raise ValueError("0 <= alpha_A <= eta")
-        if gamma >= 0:
+        if self.gamma >= 0:
             raise ValueError("gamma must be negative")
-        if self.tol <= 0 or tol_m <= 0 or eps_solver <= 0:
+        if self.tol <= 0 or self.tol_m <= 0 or self.eps_solver <= 0:
             raise ValueError("tol and tol_m must be positive")
-        if inequality_constraints and relax_optim and self.threshold == 0.0:
+        if self.inequality_constraints and self.relax_optim and self.threshold == 0.0:
             raise ValueError(
                 "Ineq. constr. -> threshold!=0 + relax_optim=True or relax_optim=False."
             )
 
-        self.eps_solver = eps_solver
-        self.relax_optim = relax_optim
-        self.inequality_constraints = inequality_constraints
-        self.m0 = m0
-        self.A0 = A0
-        self.alpha_A = alpha_A
-        self.alpha_m = alpha_m
-        self.eta = eta
-        self.gamma = gamma
-        self.tol_m = tol_m
-        self.accel = accel
+    def set_params(self, **kwargs):
+        super().set_params(**kwargs)
+        self.__post_init_guard()
 
     def _set_Ptensors(
         self, n_targets: int
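
With the guards moved into __post_init_guard and re-run by set_params, parameter updates made after construction are validated the same way the constructor validates them. A hedged usage sketch, assuming TrappingSR3's other defaults pass validation (values chosen only for illustration):

from pysindy.optimizers import TrappingSR3

opt = TrappingSR3(eta=1.0)
opt.set_params(gamma=1.0)  # gamma must be negative: __post_init_guard raises
                           # ValueError here, where previously the bad value
                           # was stored silently until fit time or later.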
