From bd95524455b67a0352cfb18c244addbdd48a763c Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 10 Jan 2024 01:36:40 +0000 Subject: [PATCH] feat(utils): Deprecate target format constraints And simplify reorder_constraints --- pysindy/utils/base.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index cc3a846ab..114916e7f 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -1,3 +1,4 @@ +import warnings from itertools import repeat from typing import Sequence @@ -136,24 +137,17 @@ def drop_nan_samples(x, y): return x, y -def reorder_constraints(c, n_features, output_order="row"): - """Reorder constraint matrix.""" - ret = c.copy() - - if ret.ndim == 1: - ret = ret.reshape(1, -1) - - n_targets = ret.shape[1] // n_features - shape = (n_targets, n_features) - - if output_order == "row": - for i in range(ret.shape[0]): - ret[i] = ret[i].reshape(shape).flatten(order="F") +def reorder_constraints(arr, n_features, output_order="feature"): + """Switch between 'feature' and 'target' constraint order.""" + warnings.warn("Target format constraints are deprecated.", stacklevel=2) + n_constraints = arr.shape[0] if arr.ndim > 1 else 1 + n_tgt = arr.size // n_features // n_constraints + if output_order == "feature": + starting_shape = (n_constraints, n_tgt, n_features) else: - for i in range(ret.shape[0]): - ret[i] = ret[i].reshape(shape, order="F").flatten() + starting_shape = (n_constraints, n_features, n_tgt) - return ret + return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1)) def prox_l0(x, threshold):