Skip to content

Commit

Permalink
fix: Apply correct constraints to enstrophy trapping
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jul 23, 2024
1 parent 958e246 commit 263dea1
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pysindy/optimizers/trapping_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,15 @@ def __init__(
self._include_bias = _include_bias
self._interaction_only = _interaction_only
self._n_tgts = _n_tgts
self.mod_matrix = mod_matrix
if _n_tgts is None:
warnings.warn(
"Trapping Optimizer initialized without _n_tgts. It will likely"
" be unable to fit data"
)
self._n_tgts = 1
if self.mod_matrix is None:
self.mod_matrix = np.eye(self._n_tgts)
if method == "global":
if hasattr(kwargs, "constraint_separation_index"):
constraint_separation_index = kwargs["constraint_separation_index"]
Expand All @@ -226,6 +229,7 @@ def __init__(
constraint_rhs, constraint_lhs = _make_constraints(
self._n_tgts, include_bias=_include_bias
)
constraint_lhs = np.tensordot(constraint_lhs, self.mod_matrix, axes=1)
constraint_order = kwargs.pop("constraint_order", "feature")
if constraint_order == "target":
constraint_lhs = np.transpose(constraint_lhs, [0, 2, 1])
Expand Down Expand Up @@ -256,7 +260,6 @@ def __init__(
else:
raise ValueError(f"Can either use 'global' or 'local' method, not {method}")

self.mod_matrix = mod_matrix
self.eps_solver = eps_solver
self.m0 = m0
self.A0 = A0
Expand Down Expand Up @@ -317,8 +320,6 @@ def __post_init_guard(self):
raise ValueError("tol and tol_m must be positive")
if self.inequality_constraints and self.threshold == 0.0:
raise ValueError("Inequality constraints requires threshold!=0")
if self.mod_matrix is None:
self.mod_matrix = np.eye(self._n_tgts)
if self.A0 is None:
self.A0 = np.diag(self.gamma * np.ones(self._n_tgts))
if self.m0 is None:
Expand Down

0 comments on commit 263dea1

Please sign in to comment.