Skip to content

Commit

Permalink
bug: Make caller more explicit to create AxesArray
Browse files Browse the repository at this point in the history
replace AxesArray.__dict__ with AxesArray.axes
Correct the axes definitions where caller just was ok with
being wrong before
  • Loading branch information
Jacob-Stevens-Haas committed Jan 13, 2024
1 parent 0bd7182 commit f13d593
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 10 deletions.
9 changes: 4 additions & 5 deletions pysindy/feature_library/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,9 @@ def correct_shape(self, x: AxesArray):
return x

def calc_trajectory(self, diff_method, x, t):
axes = x.__dict__
x_dot = diff_method(x, t=t)
x = AxesArray(diff_method.smoothed_x_, axes)
return x, AxesArray(x_dot, axes)
x = AxesArray(diff_method.smoothed_x_, x.axes)
return x, AxesArray(x_dot, x.axes)

def get_spatial_grid(self):
return None
Expand Down Expand Up @@ -337,7 +336,7 @@ def __init__(
self.libraries = libraries
self.inputs_per_library = inputs_per_library

def _combinations(self, lib_i, lib_j):
def _combinations(self, lib_i: AxesArray, lib_j: AxesArray) -> AxesArray:
"""
Compute combinations of the numerical libraries.
Expand All @@ -351,7 +350,7 @@ def _combinations(self, lib_i, lib_j):
lib_i.shape[lib_i.ax_coord] * lib_j.shape[lib_j.ax_coord]
)
lib_full = np.reshape(
lib_i[..., :, np.newaxis] * lib_j[..., np.newaxis, :],
lib_i[..., :, "coord"] * lib_j[..., "coord", :],
shape,
)

Expand Down
2 changes: 1 addition & 1 deletion pysindy/feature_library/generalized_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def transform(self, x_full):
else:
xps.append(lib.transform([x])[0])

xp = AxesArray(np.concatenate(xps, axis=xps[0].ax_coord), xps[0].__dict__)
xp = AxesArray(np.concatenate(xps, axis=xps[0].ax_coord), xps[0].axes)
xp_full = xp_full + [xp]
return xp_full

Expand Down
2 changes: 1 addition & 1 deletion pysindy/feature_library/polynomial_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def transform(self, x_full):
dtype=x.dtype,
order=self.order,
),
x.__dict__,
x.axes,
)
for i, comb in enumerate(combinations):
xp[..., i] = x[..., comb].prod(-1)
Expand Down
2 changes: 1 addition & 1 deletion pysindy/feature_library/sindy_pi_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,5 +404,5 @@ def transform(self, x_full):
*[x[:, comb] for comb in f_combs]
) * f_dot(*[x_dot[:, comb] for comb in f_dot_combs])
library_idx += 1
xp_full = xp_full + [AxesArray(xp, x.__dict__)]
xp_full = xp_full + [AxesArray(xp, x.axes)]
return xp_full
3 changes: 2 additions & 1 deletion pysindy/optimizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def fit(self, x_, y, sample_weight=None, **reduce_kws):
self : returns an instance of self
"""
x_ = AxesArray(np.asarray(x_), {"ax_sample": 0, "ax_coord": 1})
y = AxesArray(np.asarray(y), {"ax_sample": 0, "ax_coord": 1})
y_axes = {"ax_sample": 0} if y.ndim == 1 else {"ax_sample": 0, "ax_coord": 1}
y = AxesArray(np.asarray(y), y_axes)
x_, y = drop_nan_samples(x_, y)
x_, y = check_X_y(x_, y, accept_sparse=[], y_numeric=True, multi_output=True)

Expand Down
1 change: 1 addition & 0 deletions test/test_feature_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def test_sindypi_library_bad_params(params):
pytest.lazy_fixture("ode_library"),
pytest.lazy_fixture("sindypi_library"),
],
ids=type,
)
def test_fit_transform(data_lorenz, library):
x, t = data_lorenz
Expand Down
2 changes: 1 addition & 1 deletion test/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def test_specific_bad_parameters(error, optimizer, params, data_lorenz):
def test_bad_optimizers(data_derivative_1d):
x, x_dot = data_derivative_1d
x = x.reshape(-1, 1)

x_dot = x_dot.reshape(-1, 1)
with pytest.raises(InvalidParameterError):
# Error: optimizer does not have a callable fit method
opt = WrappedOptimizer(DummyEmptyModel())
Expand Down

0 comments on commit f13d593

Please sign in to comment.