diff --git a/exponax/stepper/_kuramoto_sivashinsky.py b/exponax/stepper/_kuramoto_sivashinsky.py index b4f84c4..fcc39ec 100644 --- a/exponax/stepper/_kuramoto_sivashinsky.py +++ b/exponax/stepper/_kuramoto_sivashinsky.py @@ -7,8 +7,8 @@ class KuramotoSivashinsky(BaseStepper): gradient_norm_scale: float - second_order_diffusivity: float - fourth_order_diffusivity: float + second_order_scale: float + fourth_order_scale: float dealiasing_fraction: float def __init__( @@ -19,8 +19,8 @@ def __init__( dt: float, *, gradient_norm_scale: float = 1.0, - second_order_diffusivity: float = 1.0, - fourth_order_diffusivity: float = 1.0, + second_order_scale: float = 1.0, + fourth_order_scale: float = 1.0, dealiasing_fraction: float = 2 / 3, order: int = 2, num_circle_points: int = 16, @@ -36,24 +36,25 @@ def __init__( In 1d, the KS equation is given by ``` - uₜ + b₂ 1/2 (uₓ)² + ν uₓₓ + μ uₓₓₓₓ = 0 + uₜ + b₂ 1/2 (uₓ)² + ψ₁ uₓₓ + ψ₂ uₓₓₓₓ = 0 ``` - with `b₂` the gradient-norm coefficient, `ν` the diffusivity and `μ` the - hyper viscosity. Note that both viscosity terms are on the left-hand - side. As such for `ν, μ > 0`, the second-order term acts destabilizing - (increases the energy of the system) and the fourth-order term acts - stabilizing (decreases the energy of the system). A common configuration - is `b₂ = ν = μ = 1` and the dynamics are only adapted using the - `domain_extent`. For this, we espect the KS equation to experience - spatio-temporal chaos roughly once `L > 60`. + with `b₂` the gradient-norm coefficient, `ψ₁` the second-order scale and + `ψ₂` the fourth-order. If the latter two terms were on the right-hand + side, they could be interpreted as diffusivity and hyper-diffusivity, + respectively. Here, the second-order term acts destabilizing (increases + the energy of the system) and the fourth-order term acts stabilizing + (decreases the energy of the system). A common configuration is `b₂ = ψ₁ + = ψ₂ = 1` and the dynamics are only adapted using the `domain_extent`. + For this, we espect the KS equation to experience spatio-temporal chaos + roughly once `L > 60`. In this combustion (=non-conservative) format, the number of channels does **not** grow with the spatial dimension. A 2d KS still only has a single channel. In higher dimensions, the equation reads ``` - uₜ + b₂ 1/2 ‖ ∇u ‖₂² + ν (∇ ⋅ ∇) u + μ ((∇ ⊗ ∇) ⋅ (∇ ⊗ ∇))u = 0 + uₜ + b₂ 1/2 ‖ ∇u ‖₂² + ψ₁ν (∇ ⋅ ∇) u + ψ₂ ((∇ ⊗ ∇) ⋅ (∇ ⊗ ∇))u = 0 ``` with `‖ ∇u ‖₂` the gradient norm, `∇ ⋅ ∇` effectively is the Laplace @@ -75,14 +76,9 @@ def __init__( - `gradient_norm_scale`: The gradient-norm coefficient `b₂`. Note that the gradient norm is already scaled by 1/2. This factor allows for further modification. Default: 1.0. - - `second_order_diffusivity`: The diffusivity `ν` in the KS - equation. The sign of this coefficient is interpreted as if the term - was on the left-hand side. Hence it should have a positive value to - act destabilizing. Default: 1.0. - - `fourth_order_diffusivity`: The hyper viscosity `μ` in the KS - equation. The sign of this coefficient is interpreted as if the term - was on the left-hand side. Hence it should have a positive value to - act stabilizing. Default: 1.0. + - `second_order_scale`: The "diffusivity" `ψ₁` in the KS equation. + - `fourth_order_diffusivity`: The "hyper-diffusivity" `ψ₂` in the KS + equation. - `order`: The order of the Exponential Time Differencing Runge Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only solves the linear part of the equation. Use higher values for higher @@ -132,8 +128,8 @@ def __init__( the transitional phase, after that the chaotic attractor is reached. """ self.gradient_norm_scale = gradient_norm_scale - self.second_order_diffusivity = second_order_diffusivity - self.fourth_order_diffusivity = fourth_order_diffusivity + self.second_order_scale = second_order_scale + self.fourth_order_scale = fourth_order_scale self.dealiasing_fraction = dealiasing_fraction super().__init__( num_spatial_dims=num_spatial_dims, @@ -150,9 +146,10 @@ def _build_linear_operator( self, derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> Complex[Array, "1 ... (N//2)+1"]: - linear_operator = -self.second_order_diffusivity * build_laplace_operator( + # Minuses are required to move the terms to the right-hand side + linear_operator = -self.second_order_scale * build_laplace_operator( derivative_operator, order=2 - ) - self.fourth_order_diffusivity * build_laplace_operator( + ) - self.fourth_order_scale * build_laplace_operator( derivative_operator, order=4 ) return linear_operator @@ -173,8 +170,8 @@ def _build_nonlinear_fun( class KuramotoSivashinskyConservative(BaseStepper): convection_scale: float - second_order_diffusivity: float - fourth_order_diffusivity: float + second_order_scale: float + fourth_order_scale: float single_channel: bool conservative: bool dealiasing_fraction: float @@ -187,8 +184,8 @@ def __init__( dt: float, *, convection_scale: float = 1.0, - second_order_diffusivity: float = 1.0, - fourth_order_diffusivity: float = 1.0, + second_order_scale: float = 1.0, + fourth_order_scale: float = 1.0, single_channel: bool = False, conservative: bool = True, dealiasing_fraction: float = 2 / 3, @@ -202,8 +199,8 @@ def __init__( the number of spatial dimensions. """ self.convection_scale = convection_scale - self.second_order_diffusivity = second_order_diffusivity - self.fourth_order_diffusivity = fourth_order_diffusivity + self.second_order_scale = second_order_scale + self.fourth_order_scale = fourth_order_scale self.single_channel = single_channel self.conservative = conservative self.dealiasing_fraction = dealiasing_fraction @@ -237,9 +234,10 @@ def _build_linear_operator( self, derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> Complex[Array, "1 ... (N//2)+1"]: - linear_operator = -self.second_order_diffusivity * build_laplace_operator( + # Minuses are required to move the terms to the right-hand side + linear_operator = -self.second_order_scale * build_laplace_operator( derivative_operator, order=2 - ) - self.fourth_order_diffusivity * build_laplace_operator( + ) - self.fourth_order_scale * build_laplace_operator( derivative_operator, order=4 ) return linear_operator diff --git a/tests/test_builtin_solvers.py b/tests/test_builtin_solvers.py index 9960920..9725efd 100644 --- a/tests/test_builtin_solvers.py +++ b/tests/test_builtin_solvers.py @@ -174,8 +174,8 @@ def test_specific_stepper_to_general_linear_stepper( 50, 0.1, convection_scale=1.0, - second_order_diffusivity=1.0, - fourth_order_diffusivity=1.0, + second_order_scale=1.0, + fourth_order_scale=1.0, ), 1.0, [0.0, 0.0, -1.0, 0.0, -1.0], @@ -254,8 +254,8 @@ def test_specific_stepper_to_general_convection_stepper( 50, 0.1, gradient_norm_scale=1.0, - second_order_diffusivity=1.0, - fourth_order_diffusivity=1.0, + second_order_scale=1.0, + fourth_order_scale=1.0, ), 1.0, [0.0, 0.0, -1.0, 0.0, -1.0],