Skip to content

Commit

Permalink
feat: Accomodate jax arrays
Browse files Browse the repository at this point in the history
This was thought to be easy, because in many cases jax arrays were an
almost drop-in replacement for numpy arrays.  However, they are far less
amenable to subclassing.  Why does this matter?

The codebase gained a lot of readability with AxesArray allowing arrays
to dynamically know what their axes meant, even after indexing changed
their shape.  However, extending AxesArray to dynamically subclass either
numpy.ndarray or jax.Array is impossible - even a static subclass of the
latter is impossible.

Long term, we will need our own metadata type that carries around an array,
it's type package (numpy or jax.numpy or cvxpy.numpy), its bidirectional
mapping between axis index and axis meaning, and maybe even something from
sympy.

Short term, we should expose our general expectations for axis definitions
as global constants.  This is still error prone, as the constants are
incorrect for arrays that have changed shape due to indexing, but will
be far more readable than magic numbers.
  • Loading branch information
Jacob-Stevens-Haas committed Oct 18, 2024
1 parent f84df31 commit 22329dd
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ classifiers = [
]
readme = "README.rst"
dependencies = [
"jax>=0.4,<0.5",
"scikit-learn>=1.1, !=1.5.0",
"derivative>=0.6.2",
"typing_extensions",
Expand Down Expand Up @@ -64,7 +65,6 @@ cvxpy = [
]
sbr = [
"numpyro",
"jax",
"arviz==0.17.1",
"scipy<1.13.0"
]
Expand Down
15 changes: 12 additions & 3 deletions pysindy/feature_library/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional
from typing import Sequence

import jax
import numpy as np
from scipy import sparse
from sklearn.base import TransformerMixin
Expand Down Expand Up @@ -144,20 +145,28 @@ def x_sequence_or_item(wrapped_func):
@wraps(wrapped_func)
def func(self, x, *args, **kwargs):
if isinstance(x, Sequence):
xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x]
if isinstance(x[0], jax.Array):
xs = x
else:
xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x]
result = wrapped_func(self, xs, *args, **kwargs)
# if transform() is a normal "return x"
if isinstance(result, Sequence) and isinstance(result[0], np.ndarray):
return [AxesArray(xp, comprehend_axes(xp)) for xp in result]
return result # e.g. fit() returns self
else:
if not sparse.issparse(x):
if isinstance(x, jax.Array):

def reconstructor(x):
return x

elif not sparse.issparse(x) and isinstance(x, np.ndarray):
x = AxesArray(x, comprehend_axes(x))

def reconstructor(x):
return x

else: # sparse arrays
else: # sparse
reconstructor = type(x)
axes = comprehend_axes(x)
wrap_axes(axes, x)
Expand Down
5 changes: 3 additions & 2 deletions pysindy/feature_library/polynomial_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..utils import AxesArray
from ..utils import comprehend_axes
from ..utils import wrap_axes
from ..utils._axis_conventions import AX_COORD
from .base import BaseFeatureLibrary
from .base import x_sequence_or_item

Expand Down Expand Up @@ -180,7 +181,7 @@ def fit(self, x_full: list[AxesArray], y=None):
"Can't have include_interaction be False and interaction_only"
" be True"
)
n_features = x_full[0].shape[x_full[0].ax_coord]
n_features = x_full[0].shape[AX_COORD]
combinations = self._combinations(
n_features,
self.degree,
Expand Down Expand Up @@ -217,7 +218,7 @@ def transform(self, x_full):
axes = comprehend_axes(x)
x = x.asformat("csc")
wrap_axes(axes, x)
n_features = x.shape[x.ax_coord]
n_features = x.shape[AX_COORD]
if n_features != self.n_features_in_:
raise ValueError("x shape does not match training shape")

Expand Down
2 changes: 2 additions & 0 deletions pysindy/utils/_axis_conventions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
AX_TIME = -2
AX_COORD = -1

0 comments on commit 22329dd

Please sign in to comment.