Skip to content

Commit

Permalink
Enable interp functions for simulating with control inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
briandesilva committed May 4, 2020
1 parent 556b94c commit fe286f5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
11 changes: 9 additions & 2 deletions pysindy/pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,9 +682,16 @@ def rhs(x, t):
return self.predict(x[newaxis, :])[0]

else:
if ndim(u(1)) == 1:

def rhs(x, t):
return self.predict(x[newaxis, :], u(t))[0]
def rhs(x, t):
print(t)
return self.predict(x[newaxis, :], u(t).reshape(1, -1))[0]

else:

def rhs(x, t):
return self.predict(x[newaxis, :], u(t))[0]

return integrator(rhs, x0, t, **integrator_kws)

Expand Down
22 changes: 19 additions & 3 deletions test/test_sindyc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pytest
from scipy.integrate import odeint
from scipy.interpolate import interp1d
from sklearn.exceptions import ConvergenceWarning
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import ElasticNet
Expand Down Expand Up @@ -126,15 +127,15 @@ def test_improper_shape_input(data_1d):
check_is_fitted(model)

model = SINDy()
model.fit(x.flatten(), u=u.reshape(-1, 1), t=t)
model.fit(x.flatten(), u=u.flatten(), t=t)
check_is_fitted(model)

model = SINDy()
model.fit(x.flatten(), u=u.reshape(-1, 1), t=t, x_dot=x.flatten())
model.fit(x.flatten(), u=u.flatten(), t=t, x_dot=x.flatten())
check_is_fitted(model)

model = SINDy()
model.fit(x, u=u.reshape(-1, 1), t=t, x_dot=x.flatten())
model.fit(x, u=u.flatten(), t=t, x_dot=x.flatten())
check_is_fitted(model)

# Should fail if x and u have incompatible numbers of rows
Expand Down Expand Up @@ -243,6 +244,21 @@ def test_simulate(data):
assert len(x1) == len(t)


@pytest.mark.parametrize(
"data",
[pytest.lazy_fixture("data_lorenz_c_1d"), pytest.lazy_fixture("data_lorenz_c_2d")],
)
def test_simulate_with_interp(data):
x, t, u, _ = data
model = SINDy()
model.fit(x, u=u, t=t)

u_fun = interp1d(t, u, axis=0)
x1 = model.simulate(x[0], t=t[:-1], u=u_fun)

assert len(x1) == len(t) - 1


@pytest.mark.parametrize(
"data",
[pytest.lazy_fixture("data_lorenz_c_1d"), pytest.lazy_fixture("data_lorenz_c_2d")],
Expand Down

0 comments on commit fe286f5

Please sign in to comment.