From 22329ddbeb65f464357c8f7b3793b055ee11d706 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 17 Oct 2024 18:51:33 -0700 Subject: [PATCH] feat: Accommodate jax arrays This was thought to be easy, because in many cases jax arrays were an almost drop-in replacement for numpy arrays. However, they are far less amenable to subclassing. Why does this matter? The codebase gained a lot of readability with AxesArray allowing arrays to dynamically know what their axes meant, even after indexing changed their shape. However, extending AxesArray to dynamically subclass either numpy.ndarray or jax.Array is impossible - even a static subclass of the latter is impossible. Long term, we will need our own metadata type that carries around an array, its type package (numpy or jax.numpy or cvxpy.numpy), its bidirectional mapping between axis index and axis meaning, and maybe even something from sympy. Short term, we should expose our general expectations for axis definitions as global constants. This is still error-prone, as the constants are incorrect for arrays that have changed shape due to indexing, but will be far more readable than magic numbers. 
--- pyproject.toml | 2 +- pysindy/feature_library/base.py | 15 ++++++++++++--- pysindy/feature_library/polynomial_library.py | 5 +++-- pysindy/utils/_axis_conventions.py | 2 ++ 4 files changed, 18 insertions(+), 6 deletions(-) create mode 100644 pysindy/utils/_axis_conventions.py diff --git a/pyproject.toml b/pyproject.toml index 26e9d4ad..cd3603a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ ] readme = "README.rst" dependencies = [ + "jax>=0.4,<0.5", "scikit-learn>=1.1, !=1.5.0", "derivative>=0.6.2", "typing_extensions", @@ -64,7 +65,6 @@ cvxpy = [ ] sbr = [ "numpyro", - "jax", "arviz==0.17.1", "scipy<1.13.0" ] diff --git a/pysindy/feature_library/base.py b/pysindy/feature_library/base.py index 671386cb..8f24dc9a 100644 --- a/pysindy/feature_library/base.py +++ b/pysindy/feature_library/base.py @@ -8,6 +8,7 @@ from typing import Optional from typing import Sequence +import jax import numpy as np from scipy import sparse from sklearn.base import TransformerMixin @@ -144,20 +145,28 @@ def x_sequence_or_item(wrapped_func): @wraps(wrapped_func) def func(self, x, *args, **kwargs): if isinstance(x, Sequence): - xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x] + if isinstance(x[0], jax.Array): + xs = x + else: + xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x] result = wrapped_func(self, xs, *args, **kwargs) # if transform() is a normal "return x" if isinstance(result, Sequence) and isinstance(result[0], np.ndarray): return [AxesArray(xp, comprehend_axes(xp)) for xp in result] return result # e.g. 
fit() returns self else: - if not sparse.issparse(x): + if isinstance(x, jax.Array): + + def reconstructor(x): + return x + + elif not sparse.issparse(x) and isinstance(x, np.ndarray): x = AxesArray(x, comprehend_axes(x)) def reconstructor(x): return x - else: # sparse arrays + else: # sparse reconstructor = type(x) axes = comprehend_axes(x) wrap_axes(axes, x) diff --git a/pysindy/feature_library/polynomial_library.py b/pysindy/feature_library/polynomial_library.py index 72e357a7..496841e8 100644 --- a/pysindy/feature_library/polynomial_library.py +++ b/pysindy/feature_library/polynomial_library.py @@ -12,6 +12,7 @@ from ..utils import AxesArray from ..utils import comprehend_axes from ..utils import wrap_axes +from ..utils._axis_conventions import AX_COORD from .base import BaseFeatureLibrary from .base import x_sequence_or_item @@ -180,7 +181,7 @@ def fit(self, x_full: list[AxesArray], y=None): "Can't have include_interaction be False and interaction_only" " be True" ) - n_features = x_full[0].shape[x_full[0].ax_coord] + n_features = x_full[0].shape[AX_COORD] combinations = self._combinations( n_features, self.degree, @@ -217,7 +218,7 @@ def transform(self, x_full): axes = comprehend_axes(x) x = x.asformat("csc") wrap_axes(axes, x) - n_features = x.shape[x.ax_coord] + n_features = x.shape[AX_COORD] if n_features != self.n_features_in_: raise ValueError("x shape does not match training shape") diff --git a/pysindy/utils/_axis_conventions.py b/pysindy/utils/_axis_conventions.py new file mode 100644 index 00000000..98a7c582 --- /dev/null +++ b/pysindy/utils/_axis_conventions.py @@ -0,0 +1,2 @@ +AX_TIME = -2 +AX_COORD = -1