Skip to content

Commit

Permalink
Merge pull request #153 from dynamicslab/weak_optimization
Browse files Browse the repository at this point in the history
Enhanced subdomain integration for the weak form library
  • Loading branch information
akaptano authored Apr 30, 2022
2 parents a63ffcf + 0748ec5 commit d4e64c4
Show file tree
Hide file tree
Showing 23 changed files with 1,172 additions and 697 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
hooks:
- id: reorder-python-imports
- repo: https://github.com/ambv/black
rev: 21.5b1
rev: 22.3.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
Expand Down
441 changes: 296 additions & 145 deletions examples/12_weakform_SINDy_examples.ipynb

Large diffs are not rendered by default.

46 changes: 27 additions & 19 deletions examples/1_feature_overview.ipynb

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion pysindy/feature_library/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,11 @@ def __init__(
)
self.libraries_ = libraries
self.inputs_per_library_ = inputs_per_library
for lib in self.libraries_:
if hasattr(lib, "spatiotemporal_grid"):
if lib.spatiotemporal_grid is not None:
self.n_samples = lib.K
self.spatiotemporal_grid = lib.spatiotemporal_grid

def _combinations(self, lib_i, lib_j):
"""
Expand Down Expand Up @@ -422,9 +427,12 @@ def transform(self, x):
generated from applying the custom functions to the inputs.
"""
n_samples = x.shape[0]
for lib in self.libraries_:
check_is_fitted(lib)
n_samples = x.shape[0]
if hasattr(lib, "spatiotemporal_grid"):
if lib.spatiotemporal_grid is not None: # check if weak form
n_samples = self.n_samples

# preallocate matrix
xp = np.zeros((n_samples, self.n_output_features_))
Expand Down
3 changes: 3 additions & 0 deletions pysindy/feature_library/generalized_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def transform(self, x):

n_samples, n_features = x.shape

if isinstance(self.libraries_[0], WeakPDELibrary):
n_samples = self.libraries_[0].K * self.libraries_[0].num_trajectories

if float(__version__[:3]) >= 1.0:
n_input_features = self.n_features_in_
else:
Expand Down
Loading

0 comments on commit d4e64c4

Please sign in to comment.