Skip to content

Commit

Permalink
added SINDy library input option to WeakPDELibrary
Browse files Browse the repository at this point in the history
  • Loading branch information
znicolaou committed Nov 29, 2023
1 parent 2633ee0 commit 26f7b30
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 27 deletions.
69 changes: 63 additions & 6 deletions examples/12_weakform_SINDy_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand All @@ -71,9 +71,9 @@
"(x0)' = -9.999 x0 + 9.999 x1\n",
"(x1)' = 27.992 x0 + -0.999 x1 + -1.000 x0 x2\n",
"(x2)' = -2.666 x2 + 1.000 x0 x1\n",
"(x0)' = -10.000 x0 + 10.000 x1\n",
"(x1)' = 28.000 x0 + -1.000 x1 + -1.000 x0x2\n",
"(x2)' = -2.667 x2 + 1.000 x0x1\n"
"(x0)' = -9.911 x0 + 9.930 x1\n",
"(x1)' = 27.910 x0 + -0.930 x1 + -0.949 x0x2\n",
"(x2)' = -2.612 x2 + 0.948 x0x1\n"
]
}
],
Expand All @@ -99,6 +99,7 @@
"# which allows weak form ODEs.\n",
"library_functions = [lambda x: x, lambda x: x * x, lambda x, y: x * y]\n",
"library_function_names = [lambda x: x, lambda x: x + x, lambda x, y: x + y]\n",
"\n",
"ode_lib = ps.WeakPDELibrary(\n",
" library_functions=library_functions,\n",
" function_names=library_function_names,\n",
Expand All @@ -107,15 +108,71 @@
" K=100,\n",
")\n",
"\n",
"\n",
"# Instantiate and fit the SINDy model with the integral of u_dot\n",
"optimizer = ps.SR3(\n",
" threshold=0.05, thresholder=\"l1\", max_iter=1000, normalize_columns=True, tol=1e-1\n",
" threshold=0.05, thresholder=\"l1\", max_iter=1000, normalize_columns=False, tol=1e-1\n",
")\n",
"model = ps.SINDy(feature_library=ode_lib, optimizer=optimizer)\n",
"model.fit(u_train)\n",
"model.print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also provide an existing SINDy library whose output features and names will be used in place of the library_functions and function_names."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(x0)' = -9.915 x0 + 9.931 x1\n",
"(x1)' = 27.914 x0 + -0.931 x1 + -0.949 x0 x2\n",
"(x2)' = -2.612 x2 + 0.948 x0 x1\n"
]
}
],
"source": [
"# Generate measurement data\n",
"dt = 0.002\n",
"t_train = np.arange(0, 10, dt)\n",
"t_train_span = (t_train[0], t_train[-1])\n",
"u0_train = [-8, 8, 27]\n",
"u_train = solve_ivp(\n",
" lorenz, t_train_span, u0_train, t_eval=t_train, **integrator_keywords\n",
").y.T\n",
"\n",
"# Define weak form ODE library\n",
"# defaults to derivative_order = 0 if not specified,\n",
"# and if spatial_grid is not specified, defaults to None,\n",
"# which allows weak form ODEs.\n",
"poly_lib = ps.PolynomialLibrary(\n",
" degree=2,\n",
" include_bias=False\n",
")\n",
"\n",
"ode_lib2 = ps.WeakPDELibrary(\n",
" library=poly_lib,\n",
" spatiotemporal_grid=t_train,\n",
" is_uniform=True,\n",
" K=100,\n",
")\n",
"\n",
"\n",
"# Instantiate and fit the SINDy model with the integral of u_dot\n",
"model = ps.SINDy(feature_library=ode_lib2, optimizer=optimizer)\n",
"model.fit(u_train)\n",
"model.print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -1621,7 +1678,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.11.5"
},
"toc": {
"base_numbering": 1,
Expand Down
74 changes: 53 additions & 21 deletions pysindy/feature_library/weak_pde_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class WeakPDELibrary(BaseFeatureLibrary):
Functions to include in the library. Each function will be
applied to each input variable (but not their derivatives)
library : BaseFeatureLibrary, optional (default None)
SINDy library with output features representing library_functions to include
in the library, in place of library_functions.
derivative_order : int, optional (default 0)
Order of derivative to take on each input variable,
can be arbitrary non-negative integer.
Expand Down Expand Up @@ -154,6 +158,7 @@ class WeakPDELibrary(BaseFeatureLibrary):
def __init__(
self,
library_functions=[],
library=None,
derivative_order=0,
spatiotemporal_grid=None,
function_names=None,
Expand All @@ -172,6 +177,7 @@ def __init__(
periodic=None,
):
self.functions = library_functions
self.library = library
self.derivative_order = derivative_order
self.function_names = function_names
self.interaction_only = interaction_only
Expand All @@ -185,11 +191,23 @@ def __init__(
self.differentiation_method = differentiation_method
self.diff_kwargs = diff_kwargs

if function_names and (len(library_functions) != len(function_names)):
if (
library is None
and function_names
and (len(library_functions) != len(function_names))
):
raise ValueError(
"library_functions and function_names must have the same"
" number of elements"
)
if library is not None and (
function_names is not None or len(library_functions) > 0
):
raise ValueError(

Check warning on line 206 in pysindy/feature_library/weak_pde_library.py

View check run for this annotation

Codecov / codecov/patch

pysindy/feature_library/weak_pde_library.py#L206

Added line #L206 was not covered by tests
"If providing a library, do not provide"
" library_functions or function_names"
)

if library_functions is None and derivative_order == 0:
raise ValueError(
"No library functions were specified, and no "
Expand Down Expand Up @@ -710,13 +728,16 @@ def get_feature_names(self, input_features=None):
feature_names.append("1")

# Include any non-derivative terms
for i, f in enumerate(self.functions):
for c in self._combinations(
n_features, f.__code__.co_argcount, self.interaction_only
):
feature_names.append(
self.function_names[i](*[input_features[j] for j in c])
)
if self.library is not None:
feature_names = self.library.get_feature_names()

Check warning on line 732 in pysindy/feature_library/weak_pde_library.py

View check run for this annotation

Codecov / codecov/patch

pysindy/feature_library/weak_pde_library.py#L732

Added line #L732 was not covered by tests
else:
for i, f in enumerate(self.functions):
for c in self._combinations(
n_features, f.__code__.co_argcount, self.interaction_only
):
feature_names.append(
self.function_names[i](*[input_features[j] for j in c])
)

if self.grid_ndim != 0:

Expand Down Expand Up @@ -779,11 +800,15 @@ def fit(self, x_full, y=None):
n_output_features = 0

# Count the number of non-derivative terms
for f in self.functions:
n_args = f.__code__.co_argcount
n_output_features += len(
list(self._combinations(n_features, n_args, self.interaction_only))
)
if self.library is None:
for f in self.functions:
n_args = f.__code__.co_argcount
n_output_features += len(
list(self._combinations(n_features, n_args, self.interaction_only))
)
else:
self.library.fit(x_full)
n_output_features = self.library.n_output_features_

Check warning on line 811 in pysindy/feature_library/weak_pde_library.py

View check run for this annotation

Codecov / codecov/patch

pysindy/feature_library/weak_pde_library.py#L810-L811

Added lines #L810 - L811 were not covered by tests

if self.grid_ndim != 0:
# Add the mixed derivative library_terms
Expand Down Expand Up @@ -832,6 +857,7 @@ def transform(self, x_full):
self.x_k = [x[np.ix_(*self.inds_k[k])] for k in range(self.K)]

# library function terms

n_library_terms = 0
for f in self.functions:
for c in self._combinations(
Expand All @@ -841,14 +867,20 @@ def transform(self, x_full):
library_functions = np.empty((self.K, n_library_terms), dtype=x.dtype)

# Evaluate the functions on the indices of domain cells
funcs = np.zeros((*x.shape[:-1], n_library_terms))
func_idx = 0
for f in self.functions:
for c in self._combinations(
n_features, f.__code__.co_argcount, self.interaction_only
):
funcs[..., func_idx] = f(*[x[..., j] for j in c])
func_idx += 1
if self.library is not None:
funcs = self.library.fit_transform(x)
n_library_terms = funcs.shape[-1]
library_functions = np.empty((self.K, n_library_terms), dtype=x.dtype)

Check warning on line 873 in pysindy/feature_library/weak_pde_library.py

View check run for this annotation

Codecov / codecov/patch

pysindy/feature_library/weak_pde_library.py#L871-L873

Added lines #L871 - L873 were not covered by tests
else:
funcs = np.zeros((*x.shape[:-1], n_library_terms))
func_idx = 0

for f in self.functions:
for c in self._combinations(
n_features, f.__code__.co_argcount, self.interaction_only
):
funcs[..., func_idx] = f(*[x[..., j] for j in c])
func_idx += 1

# library function terms
for k in range(self.K): # loop over domain cells
Expand Down

0 comments on commit 26f7b30

Please sign in to comment.