Skip to content

Commit

Permalink
Merge pull request dynamicslab#476 from yb6599/dynamicslabgh-472-sind…
Browse files Browse the repository at this point in the history
…yderivative

Differentiation of Multidimensional Arrays in SINDyDerivative
  • Loading branch information
Jacob-Stevens-Haas authored Jul 15, 2024
2 parents 3503c15 + 234c3e3 commit 15e8093
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
readme = "README.rst"
dependencies = [
"scikit-learn>=1.1, !=1.5.0",
"derivative>=0.5.4",
"derivative>=0.6.2",
]

[project.optional-dependencies]
Expand Down
7 changes: 4 additions & 3 deletions pysindy/differentiation/sindy_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class SINDyDerivative(BaseDifferentiation):
for acceptable keywords.
"""

def __init__(self, save_smooth=True, **kwargs):
def __init__(self, axis=0, save_smooth=True, **kwargs):
self.axis = axis
self.kwargs = kwargs
self.save_smooth = save_smooth

Expand Down Expand Up @@ -76,9 +77,9 @@ def _differentiate(self, x, t=1):
differentiator = methods[self.kwargs["kind"]](
**{k: v for k, v in self.kwargs.items() if k != "kind"}
)
x_dot = differentiator.d(x, t, axis=0)
x_dot = differentiator.d(x, t, axis=self.axis)
if self.save_smooth:
self.smoothed_x_ = differentiator.x(x, t, axis=0)
self.smoothed_x_ = differentiator.x(x, t, axis=self.axis)
else:
self.smoothed_x_ = x
return x_dot
21 changes: 21 additions & 0 deletions test/test_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,24 @@ def test_centered_difference_noaxis_vs_axis(data_2d_resolved_pde):
slow_differences_t,
atol=atol,
)


@pytest.mark.parametrize(
"derivative, kwargs",
[
(SINDyDerivative, {"kind": "finite_difference", "k": 1, "axis": -2}),
(FiniteDifference, {"axis": -2}),
(
SmoothedFiniteDifference,
{"axis": -2, "smoother_kws": {"window_length": 2, "polyorder": 1}},
),
(SpectralDerivative, {"axis": -2}),
],
)
def test_nd_differentiation(derivative, kwargs):
t = np.arange(3)
x = np.random.random(size=(2, 3, 2))
x[1, :, 1] = 1
xdot = derivative(**kwargs)._differentiate(x, t)
expected = np.zeros(3)
np.testing.assert_array_almost_equal(xdot[1, :, 1], expected)

0 comments on commit 15e8093

Please sign in to comment.