diff --git a/pysindy/differentiation/finite_difference.py b/pysindy/differentiation/finite_difference.py index f29ab8ec5..17e828dc1 100644 --- a/pysindy/differentiation/finite_difference.py +++ b/pysindy/differentiation/finite_difference.py @@ -1,4 +1,7 @@ +from typing import Union + import numpy as np +from numpy.typing import NDArray from .base import BaseDifferentiation from pysindy.utils.axes import AxesArray @@ -218,7 +221,9 @@ def _accumulate(self, coeffs, x): np.roll(np.arange(x.ndim), self.axis), ) - def _differentiate(self, x, t): + def _differentiate( + self, x: NDArray, t: Union[NDArray, float, list[float]] + ) -> NDArray: """ Apply finite difference method. """ @@ -249,6 +254,7 @@ def _differentiate(self, x, t): s[self.axis] = slice(start, stop) interior = interior + x[tuple(s)] * coeffs[i] else: + t = AxesArray(np.array(t), axes={"ax_time": 0}) coeffs = self._coefficients(t) interior = self._accumulate(coeffs, x) s[self.axis] = slice((self.n_stencil - 1) // 2, -(self.n_stencil - 1) // 2)