Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make AxesArray handle slicing correctly #451

Merged
merged 64 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
e92360f
TST: Add tests for proper AxesArray warnings/slices
Jacob-Stevens-Haas Apr 25, 2023
47e8793
WIP making array slicing consistent
Jacob-Stevens-Haas Apr 26, 2023
647e6ec
WIP: Offload AxesArray construction logic to _AxisMapping
Jacob-Stevens-Haas Apr 30, 2023
06c5b9a
WIP Allow __array_function__ to let ufuncs pass through without change.
Jacob-Stevens-Haas Apr 30, 2023
218e1f4
WIP: begin __getitem__ work to id axes
Jacob-Stevens-Haas May 1, 2023
0d358de
ENH: add function to standardize basic indexing keys
Jacob-Stevens-Haas May 18, 2023
3393599
ENH: rename AxisMapping.reduce and apply to multiple axes
Jacob-Stevens-Haas May 21, 2023
6f6aec5
TST: Add test for inserting axes & mis-ordered axes
Jacob-Stevens-Haas May 21, 2023
a138e05
BUG: build insert_axis
Jacob-Stevens-Haas May 22, 2023
d5e046b
BUG: Allow reduce_axis to handle twisted axes.
Jacob-Stevens-Haas May 22, 2023
86c2e3d
ENH: Enable basic indexing on AxesArrays
Jacob-Stevens-Haas May 22, 2023
478cf52
TST: Build test for standardizing fancy indexers
Jacob-Stevens-Haas May 22, 2023
beef4a7
ENH: Allow _standardize_indexer to handle fancy indexes
Jacob-Stevens-Haas May 22, 2023
5397edc
BUG: make _standardize_indexer handle lists with numpy arrays
Jacob-Stevens-Haas May 23, 2023
353c9d3
TST: enhance advanced indexing test
Jacob-Stevens-Haas May 23, 2023
fb9e01a
TST: Enhance basic slicing test
Jacob-Stevens-Haas Jun 3, 2023
0acda3b
WIP: rearrange __getitem__ for advanced
Jacob-Stevens-Haas May 23, 2023
a54e684
WIP but not at a stable point
Jacob-Stevens-Haas Jun 3, 2023
6ccba03
CLN: Make AxesArray syntax a little clearer
Jacob-Stevens-Haas Jan 4, 2024
b8c8739
BUG: Sort axis argument when inserting or removing axes
Jacob-Stevens-Haas Jan 4, 2024
17842da
BUG: Fix everything about _squeeze_to_sublist with tests
Jacob-Stevens-Haas Jan 4, 2024
efbfac2
CLN: explain _standardize_indexer
Jacob-Stevens-Haas Jan 4, 2024
c101a5a
TST: Test _determine_adv_broadcasting
Jacob-Stevens-Haas Jan 4, 2024
4ac044f
CLN: Remove name mangling from AxesArray _ax_map
Jacob-Stevens-Haas Jan 4, 2024
48634df
CLN: Simplify advanced indexing broadcast calculation
Jacob-Stevens-Haas Jan 4, 2024
91064dd
CLN: Extract function on the basic indexing
Jacob-Stevens-Haas Jan 4, 2024
700521a
CLN: Type the return of determine_adv_broadcasting
Jacob-Stevens-Haas Jan 4, 2024
204223f
ENH: Enable fancy indexing in AxesArray
Jacob-Stevens-Haas Jan 4, 2024
49ac5a0
TST: Update tests for new helper function values
Jacob-Stevens-Haas Jan 4, 2024
42088a1
ENH: Enable boolean advanced indexing in AxesArray
Jacob-Stevens-Haas Jan 5, 2024
62d12ea
ENH: Allow inserting axes by adding strings to index
Jacob-Stevens-Haas Jan 5, 2024
341f15c
CLN: Remove default name for a new axis
Jacob-Stevens-Haas Jan 5, 2024
0ae9f6a
CLN: Remove default name for a new axis
Jacob-Stevens-Haas Jan 6, 2024
80452a6
CLN: Remove unused type expressions
Jacob-Stevens-Haas Jan 7, 2024
51bae67
bug(axes): Only pre-standardize tuple indexers
Jacob-Stevens-Haas Jan 12, 2024
23817f0
feat(_AxisMapping): create an ndim property
Jacob-Stevens-Haas Jan 12, 2024
c11c0d6
bug(axes): enable 0-degree arrays
Jacob-Stevens-Haas Jan 12, 2024
bb1c73d
feat(axes): Enable np.reshape on AxesArrays
Jacob-Stevens-Haas Jan 13, 2024
0bd7182
Merge branch 'master' into axesarray-indexing
Jacob-Stevens-Haas Jan 13, 2024
f13d593
bug: Make caller more explicit to create AxesArray
Jacob-Stevens-Haas Jan 13, 2024
996d555
feat(axes): Make np.transpose work on AxesArray
Jacob-Stevens-Haas Jan 13, 2024
3298f5f
feat(axes): Make np.einsum work on AxesArray
Jacob-Stevens-Haas Jan 14, 2024
18d449e
bug: clean up callers of AxesArray
Jacob-Stevens-Haas Jan 14, 2024
a30220a
Merge remote-tracking branch 'origin/master' into axesarray-indexing
Jacob-Stevens-Haas Jan 14, 2024
cc6025e
feat(axes): Support numpy.ix_
Jacob-Stevens-Haas Jan 14, 2024
f5b2015
bug: Make axes explicit in PDEs
Jacob-Stevens-Haas Jan 14, 2024
e0eb87e
fix(axes): Prevent inf recursive einsum
Jacob-Stevens-Haas Jan 14, 2024
c93ca16
feat(axes): Add np.linalg.solve
Jacob-Stevens-Haas Jan 14, 2024
690aa06
bug: Enable AxesArray in FiniteDifference internals
Jacob-Stevens-Haas Jan 15, 2024
264f8b3
bug(axes): Fix transpose and einsum bugs
Jacob-Stevens-Haas Jan 15, 2024
f0fc6b3
fix(axes): Pass correct ndim to _AxisMapping()
Jacob-Stevens-Haas Jan 15, 2024
8084fd4
feat(axes): Add tensordot function for AxesArrays
Jacob-Stevens-Haas Jan 15, 2024
8f1e4bc
test(axes): Add linalg.solve() tests for AxesArray
Jacob-Stevens-Haas Jan 15, 2024
3dacb89
bug(axes) Change axis alignment linalg_solve + test
Jacob-Stevens-Haas Jan 15, 2024
21c3b01
test(axes): Add tensordot tests
Jacob-Stevens-Haas Jan 15, 2024
36ae58c
fix(axes): pass ts-to-einsum tests
Jacob-Stevens-Haas Jan 16, 2024
730a582
fix(einsum): Replace lstrip with removeprefix in renaming axes
Jacob-Stevens-Haas Jan 16, 2024
93f14a7
bug(finite_difference): Wrap internal arrays as AxesArrays
Jacob-Stevens-Haas Jan 16, 2024
323d115
bug(AxesArray.tensordot): Adapt int index to list of lists
Jacob-Stevens-Haas Jan 16, 2024
cef7167
clean: downgrade typing syntax and stdlibrary use to python 3.8
Jacob-Stevens-Haas Jan 16, 2024
2c56053
doc: Fix doc build errors. Upgrade sphinx
Jacob-Stevens-Haas Jan 17, 2024
3ede6d0
feat/doc(axes): Make helpers public so docs pick them up
Jacob-Stevens-Haas Jan 17, 2024
9c76e79
tst(axes): Cover more lines!
Jacob-Stevens-Haas Jan 17, 2024
bc3fbd2
Merge branch 'master' into axesarray-indexing
Jacob-Stevens-Haas Jan 31, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ venv/
ENV/
env.bak/
venv.bak/
env8

# automatically generated by setuptools-scm
pysindy/version.py
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ docs = [
"ipython",
"pandoc",
"sphinx-rtd-theme",
"sphinx==5.3.0",
"sphinx==7.1.2",
"sphinxcontrib-apidoc",
"nbsphinx"
]
Expand Down
58 changes: 38 additions & 20 deletions pysindy/differentiation/finite_difference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import List
from typing import Union

import numpy as np
from numpy.typing import NDArray
from scipy.special import factorial

from .base import BaseDifferentiation
from pysindy.utils.axes import AxesArray


class FiniteDifference(BaseDifferentiation):
Expand Down Expand Up @@ -55,7 +60,7 @@ class FiniteDifference(BaseDifferentiation):
def __init__(
self,
order=2,
d=1,
d: int = 1,
axis=0,
is_uniform=False,
drop_endpoints=False,
Expand Down Expand Up @@ -85,22 +90,27 @@ def __init__(

def _coefficients(self, t):
nt = len(t)
self.stencil_inds = np.array(
[np.arange(i, nt - self.n_stencil + i + 1) for i in range(self.n_stencil)]
self.stencil_inds = AxesArray(
np.array(
[
np.arange(i, nt - self.n_stencil + i + 1)
for i in range(self.n_stencil)
]
),
{"ax_offset": 0, "ax_ti": 1},
)
self.stencil = AxesArray(
np.transpose(t[self.stencil_inds]), {"ax_time": 0, "ax_offset": 1}
)
self.stencil = np.transpose(t[self.stencil_inds])

pows = np.arange(self.n_stencil)[np.newaxis, :, np.newaxis]
matrices = (
dt_endpoints = (
self.stencil
- t[
(self.n_stencil - 1) // 2 : -(self.n_stencil - 1) // 2,
np.newaxis,
]
)[:, np.newaxis, :] ** pows
b = np.zeros(self.n_stencil)
b[self.d] = factorial(self.d)
return np.linalg.solve(matrices, [b])
- t[(self.n_stencil - 1) // 2 : -(self.n_stencil - 1) // 2, "offset"]
)
matrices = dt_endpoints[:, "power", :] ** pows
b = AxesArray(np.zeros((1, self.n_stencil)), {"ax_time": 0, "ax_power": 1})
b[0, self.d] = factorial(self.d)
return np.linalg.solve(matrices, b)

def _coefficients_boundary_forward(self, t):
# use the same stencil for each boundary point,
Expand Down Expand Up @@ -202,23 +212,30 @@ def _constant_coefficients(self, dt):

def _accumulate(self, coeffs, x):
# slice to select the stencil indices
s = [slice(None)] * len(x.shape)
s = [slice(None)] * x.ndim
s[self.axis] = self.stencil_inds

# a new axis is introduced after self.axis for the stencil indices
# a new axis is introduced before self.axis for the stencil indices
# To contract with the coefficients, roll by -self.axis to put it first
# Then roll back by self.axis to return the order
trans = np.roll(np.arange(len(x.shape) + 1), -self.axis)
trans = np.roll(np.arange(x.ndim + 1), -self.axis)
# TODO: assign x's axes much earlier in the call stack
x = AxesArray(x, {"ax_unk": list(range(x.ndim))})
x_expanded = AxesArray(
np.transpose(x[tuple(s)], axes=trans), x.insert_axis(0, "ax_offset")
)
return np.transpose(
np.einsum(
"ij...,ij->j...",
np.transpose(x[tuple(s)], axes=trans),
x_expanded,
np.transpose(coeffs),
),
np.roll(np.arange(len(x.shape)), self.axis),
np.roll(np.arange(x.ndim), self.axis),
)

def _differentiate(self, x, t):
def _differentiate(
self, x: NDArray, t: Union[NDArray, float, List[float]]
) -> NDArray:
"""
Apply finite difference method.
"""
Expand Down Expand Up @@ -249,6 +266,7 @@ def _differentiate(self, x, t):
s[self.axis] = slice(start, stop)
interior = interior + x[tuple(s)] * coeffs[i]
else:
t = AxesArray(np.array(t), axes={"ax_time": 0})
coeffs = self._coefficients(t)
interior = self._accumulate(coeffs, x)
s[self.axis] = slice((self.n_stencil - 1) // 2, -(self.n_stencil - 1) // 2)
Expand Down
9 changes: 4 additions & 5 deletions pysindy/feature_library/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,9 @@ def correct_shape(self, x: AxesArray):
return x

def calc_trajectory(self, diff_method, x, t):
axes = x.__dict__
x_dot = diff_method(x, t=t)
x = AxesArray(diff_method.smoothed_x_, axes)
return x, AxesArray(x_dot, axes)
x = AxesArray(diff_method.smoothed_x_, x.axes)
return x, AxesArray(x_dot, x.axes)

def get_spatial_grid(self):
return None
Expand Down Expand Up @@ -337,7 +336,7 @@ def __init__(
self.libraries = libraries
self.inputs_per_library = inputs_per_library

def _combinations(self, lib_i, lib_j):
def _combinations(self, lib_i: AxesArray, lib_j: AxesArray) -> AxesArray:
"""
Compute combinations of the numerical libraries.

Expand All @@ -351,7 +350,7 @@ def _combinations(self, lib_i, lib_j):
lib_i.shape[lib_i.ax_coord] * lib_j.shape[lib_j.ax_coord]
)
lib_full = np.reshape(
lib_i[..., :, np.newaxis] * lib_j[..., np.newaxis, :],
lib_i[..., :, "coord"] * lib_j[..., "coord", :],
shape,
)

Expand Down
2 changes: 1 addition & 1 deletion pysindy/feature_library/generalized_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def transform(self, x_full):
else:
xps.append(lib.transform([x])[0])

xp = AxesArray(np.concatenate(xps, axis=xps[0].ax_coord), xps[0].__dict__)
xp = AxesArray(np.concatenate(xps, axis=xps[0].ax_coord), xps[0].axes)
xp_full = xp_full + [xp]
return xp_full

Expand Down
14 changes: 4 additions & 10 deletions pysindy/feature_library/pde_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,7 @@ def get_feature_names(self, input_features=None):
def derivative_string(multiindex):
ret = ""
for axis in range(self.ind_range):
if self.implicit_terms and (
axis
in [
self.spatiotemporal_grid.ax_time,
self.spatiotemporal_grid.ax_sample,
]
):
if self.implicit_terms and (axis == self.spatiotemporal_grid.ax_time,):
str_deriv = "t"
else:
str_deriv = str(axis + 1)
Expand Down Expand Up @@ -345,7 +339,7 @@ def transform(self, x_full):

# derivative terms
shape[-1] = n_features * self.num_derivatives
library_derivatives = np.empty(shape, dtype=x.dtype)
library_derivatives = AxesArray(np.empty(shape, dtype=x.dtype), x.axes)
library_idx = 0
for multiindex in self.multiindices:
derivs = x
Expand Down Expand Up @@ -395,8 +389,8 @@ def transform(self, x_full):
library_idx : library_idx
+ n_library_terms * self.num_derivatives * n_features,
] = np.reshape(
library_functions[..., np.newaxis, :]
* library_derivatives[..., :, np.newaxis],
library_functions[..., "coord", :]
* library_derivatives[..., :, "coord"],
shape,
)
library_idx += n_library_terms * self.num_derivatives * n_features
Expand Down
2 changes: 1 addition & 1 deletion pysindy/feature_library/polynomial_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def transform(self, x_full):
dtype=x.dtype,
order=self.order,
),
x.__dict__,
x.axes,
)
for i, comb in enumerate(combinations):
xp[..., i] = x[..., comb].prod(-1)
Expand Down
2 changes: 1 addition & 1 deletion pysindy/feature_library/sindy_pi_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,5 +404,5 @@ def transform(self, x_full):
*[x[:, comb] for comb in f_combs]
) * f_dot(*[x_dot[:, comb] for comb in f_dot_combs])
library_idx += 1
xp_full = xp_full + [AxesArray(xp, x.__dict__)]
xp_full = xp_full + [AxesArray(xp, x.axes)]
return xp_full
41 changes: 26 additions & 15 deletions pysindy/feature_library/weak_pde_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sklearn.utils.validation import check_is_fitted

from ..utils import AxesArray
from ..utils import comprehend_axes
from .base import BaseFeatureLibrary
from .base import x_sequence_or_item
from .polynomial_library import PolynomialLibrary
Expand Down Expand Up @@ -218,7 +219,10 @@ def __init__(

self.num_derivatives = num_derivatives
self.multiindices = multiindices
self.spatiotemporal_grid = spatiotemporal_grid

self.spatiotemporal_grid = AxesArray(
spatiotemporal_grid, axes=comprehend_axes(spatiotemporal_grid)
)

# Weak form checks and setup
self._weak_form_setup()
Expand All @@ -228,12 +232,14 @@ def _weak_form_setup(self):
L_xt = xt2 - xt1
if self.H_xt is not None:
if np.isscalar(self.H_xt):
self.H_xt = np.array(self.grid_ndim * [self.H_xt])
self.H_xt = AxesArray(
np.array(self.grid_ndim * [self.H_xt]), {"ax_coord": 0}
)
if self.grid_ndim != len(self.H_xt):
raise ValueError(
"The user-defined grid (spatiotemporal_grid) and "
"the user-defined sizes of the subdomains for the "
"weak form, do not have the same # of spatiotemporal "
"weak form do not have the same # of spatiotemporal "
"dimensions. For instance, if spatiotemporal_grid is 4D, "
"then H_xt should be a 4D list of the subdomain lengths."
)
Expand All @@ -258,8 +264,8 @@ def _weak_form_setup(self):
self._set_up_weights()

def _get_spatial_endpoints(self):
x1 = np.zeros(self.grid_ndim)
x2 = np.zeros(self.grid_ndim)
x1 = AxesArray(np.zeros(self.grid_ndim), {"ax_coord": 0})
x2 = AxesArray(np.zeros(self.grid_ndim), {"ax_coord": 0})
for i in range(self.grid_ndim):
inds = [slice(None)] * (self.grid_ndim + 1)
for j in range(self.grid_ndim):
Expand All @@ -279,7 +285,9 @@ def _set_up_weights(self):

# Sample the random domain centers
xt1, xt2 = self._get_spatial_endpoints()
domain_centers = np.zeros((self.K, self.grid_ndim))
domain_centers = AxesArray(
np.zeros((self.K, self.grid_ndim)), {"ax_sample": 0, "ax_coord": 1}
)
for i in range(self.grid_ndim):
domain_centers[:, i] = np.random.uniform(
xt1[i] + self.H_xt[i], xt2[i] - self.H_xt[i], size=self.K
Expand All @@ -294,15 +302,12 @@ def _set_up_weights(self):
s = [0] * (self.grid_ndim + 1)
s[i] = slice(None)
s[-1] = i
newinds = np.intersect1d(
np.where(
self.spatiotemporal_grid[tuple(s)]
>= domain_centers[k][i] - self.H_xt[i]
),
np.where(
self.spatiotemporal_grid[tuple(s)]
<= domain_centers[k][i] + self.H_xt[i]
),
ax_vals = self.spatiotemporal_grid[tuple(s)]
cell_left = domain_centers[k][i] - self.H_xt[i]
cell_right = domain_centers[k][i] + self.H_xt[i]
newinds = AxesArray(
((ax_vals > cell_left) & (ax_vals < cell_right)).nonzero()[0],
ax_vals.axes,
)
# If less than two indices along any axis, resample
if len(newinds) < 2:
Expand All @@ -319,6 +324,7 @@ def _set_up_weights(self):
self.inds_k = self.inds_k + [inds]
k = k + 1

# TODO: fix meaning of axes in XT_k
# Values of the spatiotemporal grid on the domain cells
XT_k = [
self.spatiotemporal_grid[np.ix_(*self.inds_k[k])] for k in range(self.K)
Expand Down Expand Up @@ -441,6 +447,11 @@ def _set_up_weights(self):
)
weights1 = weights1 + [weights2]

# TODO: get rest of code to work with AxesArray. Too unsure of
# which axis labels to use at this point to continue
tweights = [np.asarray(arr) for arr in tweights]
weights0 = [np.asarray(arr) for arr in weights0]
weights1 = [[np.asarray(arr) for arr in sublist] for sublist in weights1]
# Product weights over the axes for time derivatives, shaped as inds_k
self.fulltweights = []
deriv = np.zeros(self.grid_ndim)
Expand Down
3 changes: 2 additions & 1 deletion pysindy/optimizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def fit(self, x_, y, sample_weight=None, **reduce_kws):
self : returns an instance of self
"""
x_ = AxesArray(np.asarray(x_), {"ax_sample": 0, "ax_coord": 1})
y = AxesArray(np.asarray(y), {"ax_sample": 0, "ax_coord": 1})
y_axes = {"ax_sample": 0} if y.ndim == 1 else {"ax_sample": 0, "ax_coord": 1}
y = AxesArray(np.asarray(y), y_axes)
x_, y = drop_nan_samples(x_, y)
x_, y = check_X_y(x_, y, accept_sparse=[], y_numeric=True, multi_output=True)

Expand Down
Loading
Loading