Skip to content

Commit

Permalink
feat(pysindy): add keyword args to model.print
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Jul 15, 2024
1 parent b5798d1 commit dcd9083
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions pysindy/pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def equations(self, precision=3):
precision=precision,
)

def print(self, lhs=None, precision=3, flush=False):
def print(self, lhs=None, precision=3, **kwargs):
"""Print the SINDy model equations.
Parameters
Expand All @@ -363,16 +363,22 @@ def print(self, lhs=None, precision=3, flush=False):
precision: int, optional (default 3)
Precision to be used when printing out model coefficients.
flush: bool, optional (default = False)
If flush is true, the output stream is forcibly flushed.
**kwargs: Additional keyword arguments passed to the print function:
- sep: str, optional (default=' ')
string inserted between values, default a space.
- end: str, optional (default='\\n')
string appended after the last value, default a newline.
- file: str, optional (default = None)
a file-like object (stream); defaults to the current sys.stdout.
- flush: bool, optional (default = False)
whether to forcibly flush the stream.
"""
eqns = self.equations(precision)
if sindy_pi_flag and isinstance(self.optimizer, SINDyPI):
feature_names = self.get_feature_names()
else:
feature_names = self.feature_names
for i, eqn in enumerate(eqns):
names = None
if self.discrete_time:
names = f"({feature_names[i]})[k+1]"
elif lhs is None:
Expand All @@ -382,7 +388,7 @@ def print(self, lhs=None, precision=3, flush=False):
names = f"({feature_names[i]})"
else:
names = f"{lhs[i]}"
print(f"{names} = {eqn}", flush=flush)
print(f"{names} = {eqn}", **kwargs)

def score(self, x, t=None, x_dot=None, u=None, metric=r2_score, **metric_kws):
"""
Expand Down

0 comments on commit dcd9083

Please sign in to comment.