From f5b201594d90d1c47fd0e8d38e59a4dd78c2055c Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Sun, 14 Jan 2024 18:30:28 +0000 Subject: [PATCH] bug: Make axes explicit in PDEs --- pysindy/feature_library/pde_library.py | 8 +--- pysindy/feature_library/weak_pde_library.py | 41 +++++++++++++-------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/pysindy/feature_library/pde_library.py b/pysindy/feature_library/pde_library.py index d6c8666c9..8042a221b 100644 --- a/pysindy/feature_library/pde_library.py +++ b/pysindy/feature_library/pde_library.py @@ -276,13 +276,7 @@ def get_feature_names(self, input_features=None): def derivative_string(multiindex): ret = "" for axis in range(self.ind_range): - if self.implicit_terms and ( - axis - in [ - self.spatiotemporal_grid.ax_time, - self.spatiotemporal_grid.ax_sample, - ] - ): + if self.implicit_terms and (axis == self.spatiotemporal_grid.ax_time,): str_deriv = "t" else: str_deriv = str(axis + 1) diff --git a/pysindy/feature_library/weak_pde_library.py b/pysindy/feature_library/weak_pde_library.py index 5aa3cbbb6..02ed2851f 100644 --- a/pysindy/feature_library/weak_pde_library.py +++ b/pysindy/feature_library/weak_pde_library.py @@ -9,6 +9,7 @@ from sklearn.utils.validation import check_is_fitted from ..utils import AxesArray +from ..utils import comprehend_axes from .base import BaseFeatureLibrary from .base import x_sequence_or_item from pysindy.differentiation import FiniteDifference @@ -245,7 +246,10 @@ def __init__( self.num_derivatives = num_derivatives self.multiindices = multiindices - self.spatiotemporal_grid = spatiotemporal_grid + + self.spatiotemporal_grid = AxesArray( + spatiotemporal_grid, axes=comprehend_axes(spatiotemporal_grid) + ) # Weak form checks and setup self._weak_form_setup() @@ -255,12 +259,14 @@ def _weak_form_setup(self): L_xt = xt2 - xt1 if self.H_xt is not None: if np.isscalar(self.H_xt): - self.H_xt = np.array(self.grid_ndim * [self.H_xt]) + self.H_xt = AxesArray( + np.array(self.grid_ndim * [self.H_xt]), {"ax_coord": 0} + ) if self.grid_ndim != len(self.H_xt): raise ValueError( "The user-defined grid (spatiotemporal_grid) and " "the user-defined sizes of the subdomains for the " - "weak form, do not have the same # of spatiotemporal " + "weak form do not have the same # of spatiotemporal " "dimensions. For instance, if spatiotemporal_grid is 4D, " "then H_xt should be a 4D list of the subdomain lengths." ) @@ -285,8 +291,8 @@ def _weak_form_setup(self): self._set_up_weights() def _get_spatial_endpoints(self): - x1 = np.zeros(self.grid_ndim) - x2 = np.zeros(self.grid_ndim) + x1 = AxesArray(np.zeros(self.grid_ndim), {"ax_coord": 0}) + x2 = AxesArray(np.zeros(self.grid_ndim), {"ax_coord": 0}) for i in range(self.grid_ndim): inds = [slice(None)] * (self.grid_ndim + 1) for j in range(self.grid_ndim): @@ -306,7 +312,9 @@ def _set_up_weights(self): # Sample the random domain centers xt1, xt2 = self._get_spatial_endpoints() - domain_centers = np.zeros((self.K, self.grid_ndim)) + domain_centers = AxesArray( + np.zeros((self.K, self.grid_ndim)), {"ax_sample": 0, "ax_coord": 1} + ) for i in range(self.grid_ndim): domain_centers[:, i] = np.random.uniform( xt1[i] + self.H_xt[i], xt2[i] - self.H_xt[i], size=self.K @@ -321,15 +329,12 @@ def _set_up_weights(self): s = [0] * (self.grid_ndim + 1) s[i] = slice(None) s[-1] = i - newinds = np.intersect1d( - np.where( - self.spatiotemporal_grid[tuple(s)] - >= domain_centers[k][i] - self.H_xt[i] - ), - np.where( - self.spatiotemporal_grid[tuple(s)] - <= domain_centers[k][i] + self.H_xt[i] - ), + ax_vals = self.spatiotemporal_grid[tuple(s)] + cell_left = domain_centers[k][i] - self.H_xt[i] + cell_right = domain_centers[k][i] + self.H_xt[i] + newinds = AxesArray( + ((ax_vals > cell_left) & (ax_vals < cell_right)).nonzero()[0], + ax_vals.axes, ) # If less than two indices along any axis, resample if len(newinds) < 2: @@ -346,6 +351,7 @@ def _set_up_weights(self): self.inds_k = self.inds_k + [inds] k = k + 1 + # TODO: fix meaning of axes in XT_k # Values of the spatiotemporal grid on the domain cells XT_k = [ self.spatiotemporal_grid[np.ix_(*self.inds_k[k])] for k in range(self.K) @@ -468,6 +474,11 @@ def _set_up_weights(self): ) weights1 = weights1 + [weights2] + # TODO: get rest of code to work with AxesArray + deaxify = lambda arr_list: [np.asarray(arr) for arr in arr_list] + tweights = deaxify(tweights) + weights0 = deaxify(weights0) + weights1 = deaxify(weights1) # Product weights over the axes for time derivatives, shaped as inds_k self.fulltweights = [] deriv = np.zeros(self.grid_ndim)