From 5c0bf1f4132cf04d948a6d2410c7b27d713e2b8d Mon Sep 17 00:00:00 2001
From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com>
Date: Thu, 30 Nov 2023 01:01:08 +0000
Subject: [PATCH] CLN: Move Trapping guards from __init__ to helper.

This allows them to be used by set_params, part of the scikit-learn API.

---
 pysindy/optimizers/base.py         | 19 ++++++---
 pysindy/optimizers/trapping_sr3.py | 62 ++++++++++++++++--------------
 2 files changed, 46 insertions(+), 35 deletions(-)

diff --git a/pysindy/optimizers/base.py b/pysindy/optimizers/base.py
index 45d4842b2..e44a23022 100644
--- a/pysindy/optimizers/base.py
+++ b/pysindy/optimizers/base.py
@@ -97,17 +97,24 @@ def __init__(
         unbias: bool = True,
     ):
         super().__init__(fit_intercept=False, copy_X=copy_X)
-
-        if max_iter <= 0:
-            raise ValueError("max_iter must be positive")
-
         self.max_iter = max_iter
         self.iters = 0
-        if np.ndim(initial_guess) == 1:
-            initial_guess = initial_guess.reshape(1, -1)
         self.initial_guess = initial_guess
         self.normalize_columns = normalize_columns
         self.unbias = unbias
+        self.__post_init_guard()
+
+    # See name mangling rules for double underscore rationale
+    def __post_init_guard(self):
+        """Validate parameters post-init, as required by the scikit-learn API."""
+        if np.ndim(self.initial_guess) == 1:
+            self.initial_guess = self.initial_guess.reshape(1, -1)
+        if self.max_iter <= 0:
+            raise ValueError("max_iter must be positive")
+
+    def set_params(self, **kwargs):
+        super().set_params(**kwargs)
+        self.__post_init_guard()
 
     # Force subclasses to implement this
     @abc.abstractmethod
diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py
index 3d2a29441..eea4aef36 100644
--- a/pysindy/optimizers/trapping_sr3.py
+++ b/pysindy/optimizers/trapping_sr3.py
@@ -174,53 +174,57 @@ def __init__(
         A0: NDArray | None = None,
         **kwargs,
     ):
-        super().__init__(
-            thresholder=thresholder,
-            **kwargs,
-        )
+        super().__init__(thresholder=thresholder, **kwargs)
+        self.eps_solver = eps_solver
+        self.relax_optim = relax_optim
+        self.inequality_constraints = inequality_constraints
+        self.m0 = m0
+        self.A0 = A0
+        self.alpha_A = alpha_A
+        self.alpha_m = alpha_m
+        self.eta = eta
+        self.gamma = gamma
+        self.tol_m = tol_m
+        self.accel = accel
+        self.__post_init_guard()
+
+    def __post_init_guard(self):
+        """Validate parameters post-init, as required by the scikit-learn API."""
         if self.thresholder.lower() not in ("l1", "l2", "weighted_l1", "weighted_l2"):
             raise ValueError("Regularizer must be (weighted) L1 or L2")
-        if eta is None:
+        if self.eta is None:
             warnings.warn(
                 "eta was not set, so defaulting to eta = 1e20 "
                 "with alpha_m = 1e-2 * eta, alpha_A = eta. Here eta is so "
                 "large that the stability term in the optimization "
                 "will be ignored."
             )
-            eta = 1e20
-            alpha_m = 1e18
-            alpha_A = 1e20
+            self.eta = 1e20
+            self.alpha_m = 1e18
+            self.alpha_A = 1e20
         else:
-            if alpha_m is None:
-                alpha_m = eta * 1e-2
-            if alpha_A is None:
-                alpha_A = eta
-        if eta <= 0:
+            if self.alpha_m is None:
+                self.alpha_m = self.eta * 1e-2
+            if self.alpha_A is None:
+                self.alpha_A = self.eta
+        if self.eta <= 0:
             raise ValueError("eta must be positive")
-        if alpha_m < 0 or alpha_m > eta:
+        if self.alpha_m < 0 or self.alpha_m > self.eta:
             raise ValueError("0 <= alpha_m <= eta")
-        if alpha_A < 0 or alpha_A > eta:
+        if self.alpha_A < 0 or self.alpha_A > self.eta:
             raise ValueError("0 <= alpha_A <= eta")
-        if gamma >= 0:
+        if self.gamma >= 0:
             raise ValueError("gamma must be negative")
-        if self.tol <= 0 or tol_m <= 0 or eps_solver <= 0:
+        if self.tol <= 0 or self.tol_m <= 0 or self.eps_solver <= 0:
             raise ValueError("tol and tol_m must be positive")
-        if inequality_constraints and relax_optim and self.threshold == 0.0:
+        if self.inequality_constraints and self.relax_optim and self.threshold == 0.0:
             raise ValueError(
                 "Ineq. constr. -> threshold!=0 + relax_optim=True or relax_optim=False."
             )
 
-        self.eps_solver = eps_solver
-        self.relax_optim = relax_optim
-        self.inequality_constraints = inequality_constraints
-        self.m0 = m0
-        self.A0 = A0
-        self.alpha_A = alpha_A
-        self.alpha_m = alpha_m
-        self.eta = eta
-        self.gamma = gamma
-        self.tol_m = tol_m
-        self.accel = accel
+    def set_params(self, **kwargs):
+        super().set_params(**kwargs)
+        self.__post_init_guard()
 
     def _set_Ptensors(
         self, n_targets: int
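
Why the guard must re-run in set_params: scikit-learn's BaseEstimator.set_params
only setattr's the new values onto the estimator, so any checks performed at
construction time are silently skipped. A minimal sketch of the pattern this
patch adopts (the Toy class and _guard name are hypothetical, for illustration
only, not part of the patch):

from sklearn.base import BaseEstimator


class Toy(BaseEstimator):
    def __init__(self, max_iter=100):
        self.max_iter = max_iter
        self._guard()

    def _guard(self):
        # The same check __init__ performs, factored out so it can re-run
        if self.max_iter <= 0:
            raise ValueError("max_iter must be positive")

    def set_params(self, **kwargs):
        super().set_params(**kwargs)  # plain setattr, no validation
        self._guard()  # re-apply the construction-time checks
        return self


Toy().set_params(max_iter=-1)  # raises ValueError rather than passing silently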
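
On the "See name mangling rules" comment in the base.py hunk: a method named
with two leading underscores is compiled to _ClassName__name inside that class's
body, so BaseOptimizer.set_params and TrappingSR3.set_params each invoke their
own guard, and a subclass guard cannot accidentally shadow the base one. A
standalone sketch of the mechanism, with hypothetical Base/Sub classes:

class Base:
    def __guard(self):  # mangled to _Base__guard
        print("base checks")

    def set_params(self):
        self.__guard()  # always resolves to _Base__guard, even on a Sub


class Sub(Base):
    def __guard(self):  # mangled to _Sub__guard; no clash with the base
        print("sub checks")

    def set_params(self):
        super().set_params()  # runs the base-class checks first
        self.__guard()  # then this class's own checks


Sub().set_params()  # prints "base checks", then "sub checks"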