diff --git a/pysindy/optimizers/constrained_sr3.py b/pysindy/optimizers/constrained_sr3.py index 59288e77..f1d02fa4 100644 --- a/pysindy/optimizers/constrained_sr3.py +++ b/pysindy/optimizers/constrained_sr3.py @@ -276,8 +276,10 @@ def _calculate_penalty( """ Args: ----- - regularization: 'l0' | 'weighted_l0' | 'l1' | 'weighted_l1' | 'l2' | 'weighted_l2' - regularization_weight: float | np.array, can be a scalar or an array of shape (n_targets, n_features) + regularization: 'l0' | 'weighted_l0' | 'l1' | 'weighted_l1' | + 'l2' | 'weighted_l2' + regularization_weight: float | np.array, can be a scalar + or an array of shape (n_targets, n_features) xi: cp.Variable Returns: diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index 2f33b63b..cc1733b3 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -191,7 +191,7 @@ def get_prox( -------- proximal_operator: (x: np.array, reg_weight: float | np.array) -> np.array A function that takes an input x of shape (n_targets, n_features) - and regularization weight factor which can be a scalar or + and regularization weight factor which can be a scalar or an array of shape (n_targets, n_features), and returns an array of shape (n_targets, n_features) """ @@ -252,7 +252,7 @@ def get_regularization( -------- regularization_function: (x: np.array, reg_weight: float | np.array) -> np.array A function that takes an input x of shape (n_targets, n_features) - and regularization weight factor which can be a scalar or + and regularization weight factor which can be a scalar or an array of shape (n_targets, n_features), and returns a float """