diff --git a/pysindy/feature_library/base.py b/pysindy/feature_library/base.py index 16149b27c..54697da45 100644 --- a/pysindy/feature_library/base.py +++ b/pysindy/feature_library/base.py @@ -63,10 +63,9 @@ def correct_shape(self, x: AxesArray): return x def calc_trajectory(self, diff_method, x, t): - axes = x.__dict__ x_dot = diff_method(x, t=t) - x = AxesArray(diff_method.smoothed_x_, axes) - return x, AxesArray(x_dot, axes) + x = AxesArray(diff_method.smoothed_x_, x.axes) + return x, AxesArray(x_dot, x.axes) def get_spatial_grid(self): return None @@ -337,7 +336,7 @@ def __init__( self.libraries = libraries self.inputs_per_library = inputs_per_library - def _combinations(self, lib_i, lib_j): + def _combinations(self, lib_i: AxesArray, lib_j: AxesArray) -> AxesArray: """ Compute combinations of the numerical libraries. @@ -351,7 +350,7 @@ def _combinations(self, lib_i, lib_j): lib_i.shape[lib_i.ax_coord] * lib_j.shape[lib_j.ax_coord] ) lib_full = np.reshape( - lib_i[..., :, np.newaxis] * lib_j[..., np.newaxis, :], + lib_i[..., :, "coord"] * lib_j[..., "coord", :], shape, ) diff --git a/pysindy/feature_library/generalized_library.py b/pysindy/feature_library/generalized_library.py index 3e5e24055..29834c2a8 100644 --- a/pysindy/feature_library/generalized_library.py +++ b/pysindy/feature_library/generalized_library.py @@ -237,7 +237,7 @@ def transform(self, x_full): else: xps.append(lib.transform([x])[0]) - xp = AxesArray(np.concatenate(xps, axis=xps[0].ax_coord), xps[0].__dict__) + xp = AxesArray(np.concatenate(xps, axis=xps[0].ax_coord), xps[0].axes) xp_full = xp_full + [xp] return xp_full diff --git a/pysindy/feature_library/polynomial_library.py b/pysindy/feature_library/polynomial_library.py index 75dbf5637..e62af38bd 100644 --- a/pysindy/feature_library/polynomial_library.py +++ b/pysindy/feature_library/polynomial_library.py @@ -225,7 +225,7 @@ def transform(self, x_full): dtype=x.dtype, order=self.order, ), - x.__dict__, + x.axes, ) for i, comb in enumerate(combinations): xp[..., i] = x[..., comb].prod(-1) diff --git a/pysindy/feature_library/sindy_pi_library.py b/pysindy/feature_library/sindy_pi_library.py index 8d5f054a7..f45cf567f 100644 --- a/pysindy/feature_library/sindy_pi_library.py +++ b/pysindy/feature_library/sindy_pi_library.py @@ -404,5 +404,5 @@ def transform(self, x_full): *[x[:, comb] for comb in f_combs] ) * f_dot(*[x_dot[:, comb] for comb in f_dot_combs]) library_idx += 1 - xp_full = xp_full + [AxesArray(xp, x.__dict__)] + xp_full = xp_full + [AxesArray(xp, x.axes)] return xp_full diff --git a/pysindy/optimizers/base.py b/pysindy/optimizers/base.py index 45d4842b2..614341b54 100644 --- a/pysindy/optimizers/base.py +++ b/pysindy/optimizers/base.py @@ -144,7 +144,8 @@ def fit(self, x_, y, sample_weight=None, **reduce_kws): self : returns an instance of self """ x_ = AxesArray(np.asarray(x_), {"ax_sample": 0, "ax_coord": 1}) - y = AxesArray(np.asarray(y), {"ax_sample": 0, "ax_coord": 1}) + y_axes = {"ax_sample": 0} if y.ndim == 1 else {"ax_sample": 0, "ax_coord": 1} + y = AxesArray(np.asarray(y), y_axes) x_, y = drop_nan_samples(x_, y) x_, y = check_X_y(x_, y, accept_sparse=[], y_numeric=True, multi_output=True) diff --git a/test/test_feature_library.py b/test/test_feature_library.py index 8e98b1a0d..6fba611a5 100644 --- a/test/test_feature_library.py +++ b/test/test_feature_library.py @@ -247,6 +247,7 @@ def test_sindypi_library_bad_params(params): pytest.lazy_fixture("ode_library"), pytest.lazy_fixture("sindypi_library"), ], + ids=type, ) def test_fit_transform(data_lorenz, library): x, t = data_lorenz diff --git a/test/test_optimizers.py b/test/test_optimizers.py index c69ce9823..7bd657aa1 100644 --- a/test/test_optimizers.py +++ b/test/test_optimizers.py @@ -587,7 +587,7 @@ def test_specific_bad_parameters(error, optimizer, params, data_lorenz): def test_bad_optimizers(data_derivative_1d): x, x_dot = data_derivative_1d x = x.reshape(-1, 1) - + x_dot = x_dot.reshape(-1, 1) with pytest.raises(InvalidParameterError): # Error: optimizer does not have a callable fit method opt = WrappedOptimizer(DummyEmptyModel())