From f737ae13d4236e439effb96daf1bdc677f3d6215 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Wed, 3 Jan 2024 17:07:13 +0000 Subject: [PATCH] BUG: Set mutable default function_library in __init__() body --- pysindy/feature_library/pde_library.py | 8 ++++---- pysindy/feature_library/weak_pde_library.py | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pysindy/feature_library/pde_library.py b/pysindy/feature_library/pde_library.py index 89b53a35c..44ce3a5c6 100644 --- a/pysindy/feature_library/pde_library.py +++ b/pysindy/feature_library/pde_library.py @@ -1,5 +1,6 @@ import warnings from itertools import product as iproduct +from typing import Optional import numpy as np from sklearn.utils.validation import check_is_fitted @@ -85,9 +86,7 @@ class PDELibrary(BaseFeatureLibrary): def __init__( self, - function_library: BaseFeatureLibrary = PolynomialLibrary( - degree=3, include_bias=False - ), + function_library: Optional[BaseFeatureLibrary] = None, derivative_order=0, spatial_grid=None, temporal_grid=None, @@ -110,7 +109,8 @@ def __init__( self.num_trajectories = 1 self.differentiation_method = differentiation_method self.diff_kwargs = diff_kwargs - + if function_library is None: + self.function_library = PolynomialLibrary(degree=3, include_bias=False) if derivative_order < 0: raise ValueError("The derivative order must be >0") diff --git a/pysindy/feature_library/weak_pde_library.py b/pysindy/feature_library/weak_pde_library.py index fe82518ff..16e2792a0 100644 --- a/pysindy/feature_library/weak_pde_library.py +++ b/pysindy/feature_library/weak_pde_library.py @@ -1,5 +1,6 @@ import warnings from itertools import product as iproduct +from typing import Optional import numpy as np from scipy.special import binom @@ -135,9 +136,7 @@ class WeakPDELibrary(BaseFeatureLibrary): def __init__( self, - function_library: BaseFeatureLibrary = PolynomialLibrary( - degree=3, include_bias=False - ), + function_library: Optional[BaseFeatureLibrary] = None, derivative_order=0, spatiotemporal_grid=None, interaction_only=True, @@ -166,6 +165,8 @@ def __init__( self.num_trajectories = 1 self.differentiation_method = differentiation_method self.diff_kwargs = diff_kwargs + if function_library is None: + self.function_library = PolynomialLibrary(degree=3, include_bias=False) if spatiotemporal_grid is None: raise ValueError(