Skip to content

Commit

Permalink
added some test
Browse files Browse the repository at this point in the history
  • Loading branch information
znicolaou committed Nov 29, 2023
1 parent 8b6e67c commit 0f2e54c
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 14 deletions.
9 changes: 5 additions & 4 deletions pysindy/differentiation/finite_difference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from scipy.special import factorial

from .base import BaseDifferentiation

Expand Down Expand Up @@ -98,7 +99,7 @@ def _coefficients(self, t):
]
)[:, np.newaxis, :] ** pows
b = np.zeros(self.n_stencil)
b[self.d] = np.math.factorial(self.d)
b[self.d] = factorial(self.d)
return np.linalg.solve(matrices, [b])

def _coefficients_boundary_forward(self, t):
Expand Down Expand Up @@ -134,7 +135,7 @@ def _coefficients_boundary_forward(self, t):
)

b = np.zeros(self.stencil_inds.shape).T
b[:, self.d] = np.math.factorial(self.d)
b[:, self.d] = factorial(self.d)
return np.linalg.solve(matrices, b)

def _coefficients_boundary_periodic(self, t):
Expand Down Expand Up @@ -187,7 +188,7 @@ def _coefficients_boundary_periodic(self, t):
)

b = np.zeros(self.stencil_inds.shape).T
b[:, self.d] = np.math.factorial(self.d)
b[:, self.d] = factorial(self.d)
return np.linalg.solve(matrices, b)

def _constant_coefficients(self, dt):
Expand All @@ -196,7 +197,7 @@ def _constant_coefficients(self, dt):
np.newaxis, :
] ** pows
b = np.zeros(self.n_stencil)
b[self.d] = np.math.factorial(self.d)
b[self.d] = factorial(self.d)
return np.linalg.solve(matrices, b)

def _accumulate(self, coeffs, x):
Expand Down
10 changes: 5 additions & 5 deletions pysindy/feature_library/weak_pde_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def __init__(
" library_functions or function_names"
)

if library_functions is None and derivative_order == 0:
if library is None and len(library_functions) == 0 and derivative_order == 0:
raise ValueError(
"No library functions were specified, and no "
"derivatives were asked for. The library is empty."
Expand Down Expand Up @@ -504,7 +504,7 @@ def _set_up_weights(self):
)

self.fulltweights = self.fulltweights + [
ret * np.product(H_xt_k[k] ** (1.0 - deriv))
ret * np.prod(H_xt_k[k] ** (1.0 - deriv))
]

# Product weights over the axes for pure derivative terms, shaped as inds_k
Expand All @@ -522,7 +522,7 @@ def _set_up_weights(self):
weights0[i][lefts[i][k] : rights[i][k] + 1], dims
)

self.fullweights0 = self.fullweights0 + [ret * np.product(H_xt_k[k])]
self.fullweights0 = self.fullweights0 + [ret * np.prod(H_xt_k[k])]

# Product weights over the axes for mixed derivative terms, shaped as inds_k
self.fullweights1 = []
Expand All @@ -546,7 +546,7 @@ def _set_up_weights(self):
dims,
)

weights2 = weights2 + [ret * np.product(H_xt_k[k] ** (1.0 - deriv))]
weights2 = weights2 + [ret * np.prod(H_xt_k[k] ** (1.0 - deriv))]
self.fullweights1 = self.fullweights1 + [weights2]

@staticmethod
Expand Down Expand Up @@ -1034,7 +1034,7 @@ def transform(self, x_full):
tuple(np.arange(self.grid_ndim)),
tuple(np.arange(self.grid_ndim)),
),
) * np.product(
) * np.prod(
binom(derivs_mixed, deriv)
)
# collect the results
Expand Down
54 changes: 49 additions & 5 deletions test/test_feature_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,33 +115,59 @@ def test_pde_library_bad_parameters(params):
@pytest.mark.parametrize(
"params",
[
dict(spatiotemporal_grid=range(10), p=-1),
dict(spatiotemporal_grid=range(10), H_xt=-1),
dict(spatiotemporal_grid=range(10), H_xt=11),
dict(spatiotemporal_grid=range(10), K=-1),
dict(
spatiotemporal_grid=range(10),
p=-1,
library_functions=[lambda x: x, lambda x: x**2, lambda x: 0 * x],
),
dict(
spatiotemporal_grid=range(10),
H_xt=-1,
library_functions=[lambda x: x, lambda x: x**2, lambda x: 0 * x],
),
dict(
spatiotemporal_grid=range(10),
H_xt=11,
library_functions=[lambda x: x, lambda x: x**2, lambda x: 0 * x],
),
dict(
spatiotemporal_grid=range(10),
K=-1,
library_functions=[lambda x: x, lambda x: x**2, lambda x: 0 * x],
),
dict(),
dict(
spatiotemporal_grid=np.asarray(np.meshgrid(range(10), range(10))).T,
H_xt=-1,
library_functions=[lambda x: x, lambda x: x**2, lambda x: 0 * x],
),
dict(
spatiotemporal_grid=np.transpose(
np.asarray(np.meshgrid(range(10), range(10), range(10), indexing="ij")),
axes=[1, 2, 3, 0],
),
H_xt=-1,
library_functions=[lambda x: x, lambda x: x**2, lambda x: 0 * x],
),
dict(
spatiotemporal_grid=np.transpose(
np.asarray(np.meshgrid(range(10), range(10), range(10), indexing="ij")),
axes=[1, 2, 3, 0],
),
H_xt=11,
library_functions=[lambda x: x, lambda x: x**2, lambda x: 0 * x],
),
dict(
library=PolynomialLibrary(degree=1, include_bias=False),
library_functions=[lambda x: x, lambda x: x**2, lambda x: 0 * x],
),
dict(
library_functions=[lambda x: x, lambda x: x**2, lambda x: 0 * x],
function_names=[lambda x: x],
),
],
)
def test_weak_pde_library_bad_parameters(params):
params["library_functions"] = [lambda x: x, lambda x: x**2, lambda x: 0 * x]
with pytest.raises(ValueError):
WeakPDELibrary(**params)

Expand Down Expand Up @@ -752,7 +778,15 @@ def test_1D_weak_pdes():
H_xt=2,
include_bias=True,
)
pde_lib2 = WeakPDELibrary(
library=PolynomialLibrary(degree=2, include_bias=False),
derivative_order=4,
spatiotemporal_grid=spatiotemporal_grid,
H_xt=2,
include_bias=True,
)
pde_library_helper(pde_lib, u)
pde_library_helper(pde_lib2, u)


def test_2D_weak_pdes():
Expand All @@ -777,6 +811,16 @@ def test_2D_weak_pdes():
)
pde_library_helper(pde_lib, u)

pde_lib2 = WeakPDELibrary(
library=PolynomialLibrary(degree=2, include_bias=False),
derivative_order=2,
spatiotemporal_grid=spatiotemporal_grid,
H_xt=4,
K=10,
include_bias=True,
)
pde_library_helper(pde_lib2, u)


def test_3D_weak_pdes():
n = 5
Expand Down

0 comments on commit 0f2e54c

Please sign in to comment.