Skip to content

Commit

Permalink
bug: Make axes explicit in PDEs
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jan 14, 2024
1 parent cc6025e commit f5b2015
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
8 changes: 1 addition & 7 deletions pysindy/feature_library/pde_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,7 @@ def get_feature_names(self, input_features=None):
def derivative_string(multiindex):
ret = ""
for axis in range(self.ind_range):
if self.implicit_terms and (
axis
in [
self.spatiotemporal_grid.ax_time,
self.spatiotemporal_grid.ax_sample,
]
):
if self.implicit_terms and (axis == self.spatiotemporal_grid.ax_time,):
str_deriv = "t"
else:
str_deriv = str(axis + 1)
Expand Down
41 changes: 26 additions & 15 deletions pysindy/feature_library/weak_pde_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.utils.validation import check_is_fitted

from ..utils import AxesArray
from ..utils import comprehend_axes
from .base import BaseFeatureLibrary
from .base import x_sequence_or_item
from pysindy.differentiation import FiniteDifference
Expand Down Expand Up @@ -245,7 +246,10 @@ def __init__(

self.num_derivatives = num_derivatives
self.multiindices = multiindices
self.spatiotemporal_grid = spatiotemporal_grid

self.spatiotemporal_grid = AxesArray(
spatiotemporal_grid, axes=comprehend_axes(spatiotemporal_grid)
)

# Weak form checks and setup
self._weak_form_setup()
Expand All @@ -255,12 +259,14 @@ def _weak_form_setup(self):
L_xt = xt2 - xt1
if self.H_xt is not None:
if np.isscalar(self.H_xt):
self.H_xt = np.array(self.grid_ndim * [self.H_xt])
self.H_xt = AxesArray(
np.array(self.grid_ndim * [self.H_xt]), {"ax_coord": 0}
)
if self.grid_ndim != len(self.H_xt):
raise ValueError(
"The user-defined grid (spatiotemporal_grid) and "
"the user-defined sizes of the subdomains for the "
"weak form, do not have the same # of spatiotemporal "
"weak form do not have the same # of spatiotemporal "
"dimensions. For instance, if spatiotemporal_grid is 4D, "
"then H_xt should be a 4D list of the subdomain lengths."
)
Expand All @@ -285,8 +291,8 @@ def _weak_form_setup(self):
self._set_up_weights()

def _get_spatial_endpoints(self):
x1 = np.zeros(self.grid_ndim)
x2 = np.zeros(self.grid_ndim)
x1 = AxesArray(np.zeros(self.grid_ndim), {"ax_coord": 0})
x2 = AxesArray(np.zeros(self.grid_ndim), {"ax_coord": 0})
for i in range(self.grid_ndim):
inds = [slice(None)] * (self.grid_ndim + 1)
for j in range(self.grid_ndim):
Expand All @@ -306,7 +312,9 @@ def _set_up_weights(self):

# Sample the random domain centers
xt1, xt2 = self._get_spatial_endpoints()
domain_centers = np.zeros((self.K, self.grid_ndim))
domain_centers = AxesArray(
np.zeros((self.K, self.grid_ndim)), {"ax_sample": 0, "ax_coord": 1}
)
for i in range(self.grid_ndim):
domain_centers[:, i] = np.random.uniform(
xt1[i] + self.H_xt[i], xt2[i] - self.H_xt[i], size=self.K
Expand All @@ -321,15 +329,12 @@ def _set_up_weights(self):
s = [0] * (self.grid_ndim + 1)
s[i] = slice(None)
s[-1] = i
newinds = np.intersect1d(
np.where(
self.spatiotemporal_grid[tuple(s)]
>= domain_centers[k][i] - self.H_xt[i]
),
np.where(
self.spatiotemporal_grid[tuple(s)]
<= domain_centers[k][i] + self.H_xt[i]
),
ax_vals = self.spatiotemporal_grid[tuple(s)]
cell_left = domain_centers[k][i] - self.H_xt[i]
cell_right = domain_centers[k][i] + self.H_xt[i]
newinds = AxesArray(
((ax_vals > cell_left) & (ax_vals < cell_right)).nonzero()[0],
ax_vals.axes,
)
# If less than two indices along any axis, resample
if len(newinds) < 2:
Expand All @@ -346,6 +351,7 @@ def _set_up_weights(self):
self.inds_k = self.inds_k + [inds]
k = k + 1

# TODO: fix meaning of axes in XT_k
# Values of the spatiotemporal grid on the domain cells
XT_k = [
self.spatiotemporal_grid[np.ix_(*self.inds_k[k])] for k in range(self.K)
Expand Down Expand Up @@ -468,6 +474,11 @@ def _set_up_weights(self):
)
weights1 = weights1 + [weights2]

# TODO: get rest of code to work with AxesArray
deaxify = lambda arr_list: [np.asarray(arr) for arr in arr_list]
tweights = deaxify(tweights)
weights0 = deaxify(weights0)
weights1 = deaxify(weights1)
# Product weights over the axes for time derivatives, shaped as inds_k
self.fulltweights = []
deriv = np.zeros(self.grid_ndim)
Expand Down

0 comments on commit f5b2015

Please sign in to comment.