diff --git a/pysindy/pysindy.py b/pysindy/pysindy.py index e3b8de0b..35d9cda0 100644 --- a/pysindy/pysindy.py +++ b/pysindy/pysindy.py @@ -18,6 +18,7 @@ from sklearn.utils.validation import check_is_fitted from typing_extensions import Self +from .differentiation import BaseDifferentiation from .differentiation import FiniteDifference from .feature_library import BaseFeatureLibrary from .feature_library import PolynomialLibrary @@ -54,6 +55,18 @@ class _BaseSINDy(BaseEstimator, ABC): def fit(self, x, t, *args, **kwargs) -> Self: ... + def _fit_shape(self): + """Assign shape attributes for the system that are used post-fit""" + self.n_features_in_ = self.feature_library.n_features_in_ + self.n_output_features_ = self.feature_library.n_output_features_ + if self.feature_names is None: + feature_names = [] + for i in range(self.n_features_in_ - self.n_control_features_): + feature_names.append("x" + str(i)) + for i in range(self.n_control_features_): + feature_names.append("u" + str(i)) + self.feature_names = feature_names + def equations(self, precision: int = 3) -> list[str]: """ Get the right hand sides of the SINDy model equations. @@ -242,12 +255,12 @@ class SINDy(_BaseSINDy): def __init__( self, - optimizer=None, - feature_library=None, - differentiation_method=None, - feature_names=None, - t_default=1, - discrete_time=False, + optimizer: Optional[BaseOptimizer] = None, + feature_library: Optional[BaseFeatureLibrary] = None, + differentiation_method: Optional[BaseDifferentiation] = None, + feature_names: Optional[list[str]] = None, + t_default: float = 1, + discrete_time: bool = False, ): if optimizer is None: optimizer = STLSQ() @@ -349,17 +362,7 @@ def fit( x_dot = concat_sample_axis(x_dot) self.model = Pipeline(steps) self.model.fit(x, x_dot) - - self.n_features_in_ = self.feature_library.n_features_in_ - self.n_output_features_ = self.feature_library.n_output_features_ - - if self.feature_names is None: - feature_names = [] - for i in range(self.n_features_in_ - self.n_control_features_): - feature_names.append("x" + str(i)) - for i in range(self.n_control_features_): - feature_names.append("u" + str(i)) - self.feature_names = feature_names + self._fit_shape() return self