From 2e752f999870afd4ec9ef6dd8f264ceba02049d4 Mon Sep 17 00:00:00 2001 From: "qiangliu.7@outlook.com" Date: Mon, 21 Oct 2024 11:19:11 +0200 Subject: [PATCH 1/4] add non-conservative convection term --- exponax/nonlin_fun/_convection.py | 127 ++++++++++++++++++++--- exponax/stepper/_burgers.py | 6 ++ exponax/stepper/_korteweg_de_vries.py | 6 ++ exponax/stepper/_kuramoto_sivashinsky.py | 4 + exponax/stepper/generic/_convection.py | 13 +++ 5 files changed, 141 insertions(+), 15 deletions(-) diff --git a/exponax/nonlin_fun/_convection.py b/exponax/nonlin_fun/_convection.py index ea594ce..3923869 100644 --- a/exponax/nonlin_fun/_convection.py +++ b/exponax/nonlin_fun/_convection.py @@ -8,6 +8,7 @@ class ConvectionNonlinearFun(BaseNonlinearFun): derivative_operator: Complex[Array, "D ... (N//2)+1"] scale: float single_channel: bool + conservative: bool def __init__( self, @@ -18,13 +19,14 @@ def __init__( dealiasing_fraction: float = 2 / 3, scale: float = 1.0, single_channel: bool = False, + conservative: bool = False, ): """ Performs a pseudo-spectral evaluation of the nonlinear convection, e.g., found in the Burgers equation. In 1d and state space, this reads ``` - 𝒩(u) = - b₁ 1/2 (u²)ₓ + 𝒩(u) = -b₁ u (u)ₓ ``` with a scale `b₁`. The minus arises because `Exponax` follows the @@ -37,19 +39,41 @@ def __init__( channels as spatial dimensions and then gives ``` - 𝒩(u) = - b₁ 1/2 ∇ ⋅ (u ⊗ u) + 𝒩(u) = -b₁ u ⋅ ∇ u ``` - + with `∇ ⋅` the divergence operator and the outer product `u ⊗ u`. + + Meanwhile, if you use a conservative form, the convection term is given by + + ``` + 𝒩(u) = -b₁ u (u)ₓ + ``` + + for 1D and + + ``` + 𝒩(u) = -b₁ ∇ ⋅ (u ⊗ u) + ``` + + for 2D and 3D. + Another option is a "single-channel" hack requiring only one channel no matter the spatial dimensions. This reads ``` - 𝒩(u) = - b₁ 1/2 (1⃗ ⋅ ∇)(u²) + 𝒩(u) = -b₁ 1/2 (1⃗ ⋅ ∇)(u²) + ``` + + for the conservative form and + + ``` + 𝒩(u) = -b₁ 1/2 u (1⃗ ⋅ ∇)u ``` - **Arguments:** + for the non-conservative form. + **Arguments:** - `num_spatial_dims`: The number of spatial dimensions `d`. - `num_points`: The number of points `N` used to discretize the domain. This **includes** the left boundary point and **excludes** @@ -63,25 +87,27 @@ def __init__( - `scale`: The scale `b₁` of the convection term. Defaults to `1.0`. - `single_channel`: Whether to use the single-channel hack. Defaults to `False`. + - `conservative`: Whether to use the conservative form. Defaults to `False`. """ self.derivative_operator = derivative_operator self.scale = scale self.single_channel = single_channel + self.conservative=conservative super().__init__( num_spatial_dims, num_points, dealiasing_fraction=dealiasing_fraction, ) - def _multi_channel_eval( + def _multi_channel_conservative_eval( self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: """ - Evaluates the convection term for a multi-channel state `u_hat` in + Evaluates the conservative convection term for a multi-channel state `u_hat` in Fourier space. The convection term is given by ``` - 𝒩(u) = b₁ 1/2 ∇ ⋅ (u ⊗ u) + 𝒩(u) = -b₁ 1/2 ∇ ⋅ (u ⊗ u) ``` with `∇ ⋅` the divergence operator and the outer product `u ⊗ u`. @@ -101,24 +127,61 @@ def _multi_channel_eval( ) u_hat_dealiased = self.dealias(u_hat) u = self.ifft(u_hat_dealiased) - u_outer_product = u[:, None] * u[None, :] + u_outer_product = u[None, :] * u[:, None] u_outer_product_hat = self.fft(u_outer_product) - convection = 0.5 * jnp.sum( + convection = jnp.sum( self.derivative_operator[None, :] * u_outer_product_hat, axis=1, ) # Requires minus to move term to the rhs return -self.scale * convection + + def _multi_channel_nonconservative_eval( + self, u_hat: Complex[Array, "C ... (N//2)+1"] + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Evaluates the non-conservative convection term for a multi-channel state `u_hat` in + Fourier space. The convection term is given by + + ``` + 𝒩(u) = -b₁ u ⋅ ∇ u + ``` + + **Arguments:** - def _single_channel_eval( + - `u_hat`: The state in Fourier space. + + **Returns:** + + - `convection`: The evaluation of the convection term in Fourier space. + """ + num_channels = u_hat.shape[0] + if num_channels != self.num_spatial_dims: + raise ValueError( + "Number of channels in u_hat should match number of spatial dimensions" + ) + u_hat_dealiased = self.dealias(u_hat) + u = self.ifft(u_hat_dealiased) + nabla_u = self.ifft(self.derivative_operator[None, :] * u_hat[:, None]) + conv_u = jnp.sum( + u[None, :] * nabla_u, + axis=1, + ) + #conv_u=sum( + # [u[i:i+1]*self.ifft(self.derivative_operator[i:i+1]*u_hat) for i in range(num_channels)] + #) + # Requires minus to move term to the rhs + return -self.scale * self.fft(conv_u) + + def _single_channel_conservative_eval( self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: """ - Evaluates the convection term for a single-channel state `u_hat` in + Evaluates the conservative convection term for a single-channel state `u_hat` in Fourier space. The convection term is given by ``` - 𝒩(u) = b₁ 1/2 (1⃗ ⋅ ∇)(u²) + 𝒩(u) = -b₁ 1/2 (1⃗ ⋅ ∇)(u²) ``` with `∇ ⋅` the divergence operator and `1⃗` a vector of ones. @@ -142,10 +205,44 @@ def _single_channel_eval( # Requires minus to bring convection to the right-hand side return -self.scale * convection + def _single_channel_nonconservative_eval( + self, u_hat: Complex[Array, "C ... (N//2)+1"] + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Evaluates the non-conservative convection term for a single-channel state `u_hat` in + Fourier space. The convection term is given by + + ``` + 𝒩(u) = -b₁ 1/2 u (1⃗ ⋅ ∇)u + ``` + + **Arguments:** + + - `u_hat`: The state in Fourier space. + + **Returns:** + + - `convection`: The evaluation of the convection term in Fourier space. + """ + u_hat_dealiased = self.dealias(u_hat) + u = self.ifft(u_hat_dealiased) + nabla_u = self.ifft(self.derivative_operator * u_hat) + conv_u = jnp.sum( + u * nabla_u,axis=0,keepdims=True, + ) + # Requires minus to bring convection to the right-hand side + return -self.scale * self.fft(conv_u) + def __call__( self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: if self.single_channel: - return self._single_channel_eval(u_hat) + if self.conservative: + return self._single_channel_conservative_eval(u_hat) + else: + return self._single_channel_nonconservative_eval(u_hat) else: - return self._multi_channel_eval(u_hat) + if self.conservative: + return self._multi_channel_conservative_eval(u_hat) + else: + return self._multi_channel_nonconservative_eval(u_hat) diff --git a/exponax/stepper/_burgers.py b/exponax/stepper/_burgers.py index 866e983..ec4dfeb 100644 --- a/exponax/stepper/_burgers.py +++ b/exponax/stepper/_burgers.py @@ -10,6 +10,7 @@ class Burgers(BaseStepper): convection_scale: float dealiasing_fraction: float single_channel: bool + conservative: bool def __init__( self, @@ -21,6 +22,7 @@ def __init__( diffusivity: float = 0.1, convection_scale: float = 1.0, single_channel: bool = False, + conservative: bool = False, order=2, dealiasing_fraction: float = 2 / 3, num_circle_points: int = 16, @@ -75,6 +77,8 @@ def __init__( dimensions. In this case the the convection is `b₁ (∇ ⋅ 1)(u²)`. In this case, the state always has a single channel, no matter the spatial dimension. Default: False. + - `conservative`: Whether to use the conservative form of the convection + term. Default: False. - `order`: The order of the Exponential Time Differencing Runge Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only solves the linear part of the equation. Use higher values for higher @@ -112,6 +116,7 @@ def __init__( self.diffusivity = diffusivity self.convection_scale = convection_scale self.single_channel = single_channel + self.conservative = conservative self.dealiasing_fraction = dealiasing_fraction if single_channel: @@ -149,4 +154,5 @@ def _build_nonlinear_fun( dealiasing_fraction=self.dealiasing_fraction, scale=self.convection_scale, single_channel=self.single_channel, + conservative=False, ) diff --git a/exponax/stepper/_korteweg_de_vries.py b/exponax/stepper/_korteweg_de_vries.py index 0a54008..bbfe5c6 100644 --- a/exponax/stepper/_korteweg_de_vries.py +++ b/exponax/stepper/_korteweg_de_vries.py @@ -19,6 +19,7 @@ class KortewegDeVries(BaseStepper): advect_over_diffuse: bool diffuse_over_diffuse: bool single_channel: bool + conservaitve: bool def __init__( self, @@ -34,6 +35,7 @@ def __init__( advect_over_diffuse: bool = False, diffuse_over_diffuse: bool = False, single_channel: bool = False, + conservative: bool = False, order: int = 2, dealiasing_fraction: float = 2 / 3, num_circle_points: int = 16, @@ -108,6 +110,8 @@ def __init__( dimensions. In this case the the convection is `b₁ (∇ ⋅ 1)(u²)`. In this case, the state always has a single channel, no matter the spatial dimension. Default: False. + - `conservative`: Whether to use the conservative form of the convection + term. Default: False. - `order`: The order of the Exponential Time Differencing Runge Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only solves the linear part of the equation. Use higher values for higher @@ -145,6 +149,7 @@ def __init__( self.advect_over_diffuse = advect_over_diffuse self.diffuse_over_diffuse = diffuse_over_diffuse self.single_channel = single_channel + self.conservative = conservative self.dealiasing_fraction = dealiasing_fraction if single_channel: @@ -210,4 +215,5 @@ def _build_nonlinear_fun( dealiasing_fraction=self.dealiasing_fraction, scale=self.convection_scale, single_channel=self.single_channel, + conservative=self.conservative, ) diff --git a/exponax/stepper/_kuramoto_sivashinsky.py b/exponax/stepper/_kuramoto_sivashinsky.py index 8596bfa..7acabbf 100644 --- a/exponax/stepper/_kuramoto_sivashinsky.py +++ b/exponax/stepper/_kuramoto_sivashinsky.py @@ -176,6 +176,7 @@ class KuramotoSivashinskyConservative(BaseStepper): second_order_diffusivity: float fourth_order_diffusivity: float single_channel: bool + conservative: bool dealiasing_fraction: float def __init__( @@ -189,6 +190,7 @@ def __init__( second_order_diffusivity: float = 1.0, fourth_order_diffusivity: float = 1.0, single_channel: bool = False, + conservative: bool = False, dealiasing_fraction: float = 2 / 3, order: int = 2, num_circle_points: int = 16, @@ -203,6 +205,7 @@ def __init__( self.second_order_diffusivity = second_order_diffusivity self.fourth_order_diffusivity = fourth_order_diffusivity self.single_channel = single_channel + self.conservative = conservative self.dealiasing_fraction = dealiasing_fraction if num_spatial_dims > 1: @@ -252,4 +255,5 @@ def _build_nonlinear_fun( dealiasing_fraction=self.dealiasing_fraction, scale=self.convection_scale, single_channel=self.single_channel, + conservative=self.conservative, ) diff --git a/exponax/stepper/generic/_convection.py b/exponax/stepper/generic/_convection.py index 00df9c2..b4241f4 100644 --- a/exponax/stepper/generic/_convection.py +++ b/exponax/stepper/generic/_convection.py @@ -14,6 +14,7 @@ class GeneralConvectionStepper(BaseStepper): convection_scale: float dealiasing_fraction: float single_channel: bool + conservative: bool def __init__( self, @@ -25,6 +26,7 @@ def __init__( coefficients: tuple[float, ...] = (0.0, 0.0, 0.01), convection_scale: float = 1.0, single_channel: bool = False, + conservative: bool = False, order=2, dealiasing_fraction: float = 2 / 3, num_circle_points: int = 16, @@ -83,6 +85,8 @@ def __init__( dimensions. In this case the the convection is `b₁ (∇ ⋅ 1)(u²)`. In this case, the state always has a single channel, no matter the spatial dimension. Default: False. + - `conservative`: Whether to use the conservative form of the convection + term. Default: False. - `order`: The order of the Exponential Time Differencing Runge Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only solves the linear part of the equation. Use higher values for higher @@ -103,6 +107,7 @@ def __init__( self.convection_scale = convection_scale self.single_channel = single_channel self.dealiasing_fraction = dealiasing_fraction + self.conservative = conservative if single_channel: num_channels = 1 @@ -146,6 +151,7 @@ def _build_nonlinear_fun( dealiasing_fraction=self.dealiasing_fraction, scale=self.convection_scale, single_channel=self.single_channel, + conservative=self.conservative, ) @@ -161,6 +167,7 @@ def __init__( normalized_coefficients: tuple[float, ...] = (0.0, 0.0, 0.01 * 0.1), normalized_convection_scale: float = 1.0 * 0.1, single_channel: bool = False, + conservative: bool = False, order: int = 2, dealiasing_fraction: float = 2 / 3, num_circle_points: int = 16, @@ -210,6 +217,8 @@ def __init__( dimensions. In this case the the convection is `β (∇ ⋅ 1)(u²)`. In this case, the state always has a single channel, no matter the spatial dimension. Default: False. + - `conservative`: Whether to use the conservative form of the convection + term. Default: False. - `order`: The order of the Exponential Time Differencing Runge Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only solves the linear part of the equation. Use higher values for higher @@ -240,6 +249,7 @@ def __init__( num_circle_points=num_circle_points, circle_radius=circle_radius, single_channel=single_channel, + conservative=conservative, ) @@ -255,6 +265,7 @@ def __init__( linear_difficulties: tuple[float, ...] = (0.0, 0.0, 4.5), convection_difficulty: float = 5.0, single_channel: bool = False, + conservative: bool = False, maximum_absolute: float = 1.0, order: int = 2, dealiasing_fraction: float = 2 / 3, @@ -315,6 +326,7 @@ def __init__( dimensions. In this case the the convection is `δ (∇ ⋅ 1)(u²)`. In this case, the state always has a single channel, no matter the spatial dimension. Default: False. + - `conservative`: Whether to use the conservative form of the convection - `maximum_absolute`: The maximum absolute value of the state. This is used to extract the normalized dynamics from the convection difficulty. @@ -359,4 +371,5 @@ def __init__( dealiasing_fraction=dealiasing_fraction, num_circle_points=num_circle_points, circle_radius=circle_radius, + conservative=conservative, ) From 34d1b37899c99e16629c83e64099dc8cec8295ed Mon Sep 17 00:00:00 2001 From: "qiangliu.7@outlook.com" Date: Mon, 21 Oct 2024 12:51:37 +0200 Subject: [PATCH 2/4] fix errors in docstring --- exponax/nonlin_fun/_convection.py | 29 ++++++++++++--------------- exponax/stepper/_korteweg_de_vries.py | 2 +- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/exponax/nonlin_fun/_convection.py b/exponax/nonlin_fun/_convection.py index 3923869..4d221cb 100644 --- a/exponax/nonlin_fun/_convection.py +++ b/exponax/nonlin_fun/_convection.py @@ -41,22 +41,20 @@ def __init__( ``` 𝒩(u) = -b₁ u ⋅ ∇ u ``` - - with `∇ ⋅` the divergence operator and the outer product `u ⊗ u`. Meanwhile, if you use a conservative form, the convection term is given by ``` - 𝒩(u) = -b₁ u (u)ₓ + 𝒩(u) = -1/2 b₁ (u²)ₓ ``` - + for 1D and ``` - 𝒩(u) = -b₁ ∇ ⋅ (u ⊗ u) + 𝒩(u) = -1/2 b₁ ∇ ⋅ (u ⊗ u) ``` - for 2D and 3D. + for 2D and 3D with `∇ ⋅` the divergence operator and the outer product `u ⊗ u`. Another option is a "single-channel" hack requiring only one channel no matter the spatial dimensions. This reads @@ -64,11 +62,11 @@ def __init__( ``` 𝒩(u) = -b₁ 1/2 (1⃗ ⋅ ∇)(u²) ``` - + for the conservative form and - + ``` - 𝒩(u) = -b₁ 1/2 u (1⃗ ⋅ ∇)u + 𝒩(u) = -b₁ u (1⃗ ⋅ ∇)u ``` for the non-conservative form. @@ -92,7 +90,7 @@ def __init__( self.derivative_operator = derivative_operator self.scale = scale self.single_channel = single_channel - self.conservative=conservative + self.conservative = conservative super().__init__( num_spatial_dims, num_points, @@ -135,7 +133,7 @@ def _multi_channel_conservative_eval( ) # Requires minus to move term to the rhs return -self.scale * convection - + def _multi_channel_nonconservative_eval( self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: @@ -167,9 +165,6 @@ def _multi_channel_nonconservative_eval( u[None, :] * nabla_u, axis=1, ) - #conv_u=sum( - # [u[i:i+1]*self.ifft(self.derivative_operator[i:i+1]*u_hat) for i in range(num_channels)] - #) # Requires minus to move term to the rhs return -self.scale * self.fft(conv_u) @@ -213,7 +208,7 @@ def _single_channel_nonconservative_eval( Fourier space. The convection term is given by ``` - 𝒩(u) = -b₁ 1/2 u (1⃗ ⋅ ∇)u + 𝒩(u) = -b₁ u (1⃗ ⋅ ∇)u ``` **Arguments:** @@ -228,7 +223,9 @@ def _single_channel_nonconservative_eval( u = self.ifft(u_hat_dealiased) nabla_u = self.ifft(self.derivative_operator * u_hat) conv_u = jnp.sum( - u * nabla_u,axis=0,keepdims=True, + u * nabla_u, + axis=0, + keepdims=True, ) # Requires minus to bring convection to the right-hand side return -self.scale * self.fft(conv_u) diff --git a/exponax/stepper/_korteweg_de_vries.py b/exponax/stepper/_korteweg_de_vries.py index bbfe5c6..9ad5b83 100644 --- a/exponax/stepper/_korteweg_de_vries.py +++ b/exponax/stepper/_korteweg_de_vries.py @@ -19,7 +19,7 @@ class KortewegDeVries(BaseStepper): advect_over_diffuse: bool diffuse_over_diffuse: bool single_channel: bool - conservaitve: bool + conservative: bool def __init__( self, From 6fc5697bbad7f0836805d15b7d898e6cd548ed2c Mon Sep 17 00:00:00 2001 From: "qiangliu.7@outlook.com" Date: Mon, 21 Oct 2024 15:06:07 +0200 Subject: [PATCH 3/4] update docstring --- exponax/nonlin_fun/_convection.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/exponax/nonlin_fun/_convection.py b/exponax/nonlin_fun/_convection.py index 4d221cb..da1c3f5 100644 --- a/exponax/nonlin_fun/_convection.py +++ b/exponax/nonlin_fun/_convection.py @@ -42,7 +42,8 @@ def __init__( 𝒩(u) = -b₁ u ⋅ ∇ u ``` - Meanwhile, if you use a conservative form, the convection term is given by + Meanwhile, if you use a conservative form, the convection term is given + by ``` 𝒩(u) = -1/2 b₁ (u²)ₓ @@ -54,7 +55,8 @@ def __init__( 𝒩(u) = -1/2 b₁ ∇ ⋅ (u ⊗ u) ``` - for 2D and 3D with `∇ ⋅` the divergence operator and the outer product `u ⊗ u`. + for 2D and 3D with `∇ ⋅` the divergence operator and the outer product + `u ⊗ u`. Another option is a "single-channel" hack requiring only one channel no matter the spatial dimensions. This reads @@ -72,6 +74,7 @@ def __init__( for the non-conservative form. **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions `d`. - `num_points`: The number of points `N` used to discretize the domain. This **includes** the left boundary point and **excludes** @@ -85,7 +88,8 @@ def __init__( - `scale`: The scale `b₁` of the convection term. Defaults to `1.0`. - `single_channel`: Whether to use the single-channel hack. Defaults to `False`. - - `conservative`: Whether to use the conservative form. Defaults to `False`. + - `conservative`: Whether to use the conservative form. Defaults to + `False`. """ self.derivative_operator = derivative_operator self.scale = scale @@ -101,8 +105,8 @@ def _multi_channel_conservative_eval( self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: """ - Evaluates the conservative convection term for a multi-channel state `u_hat` in - Fourier space. The convection term is given by + Evaluates the conservative convection term for a multi-channel state + `u_hat` in Fourier space. The convection term is given by ``` 𝒩(u) = -b₁ 1/2 ∇ ⋅ (u ⊗ u) @@ -138,8 +142,8 @@ def _multi_channel_nonconservative_eval( self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: """ - Evaluates the non-conservative convection term for a multi-channel state `u_hat` in - Fourier space. The convection term is given by + Evaluates the non-conservative convection term for a multi-channel state + `u_hat` in Fourier space. The convection term is given by ``` 𝒩(u) = -b₁ u ⋅ ∇ u @@ -172,8 +176,8 @@ def _single_channel_conservative_eval( self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: """ - Evaluates the conservative convection term for a single-channel state `u_hat` in - Fourier space. The convection term is given by + Evaluates the conservative convection term for a single-channel state + `u_hat` in Fourier space. The convection term is given by ``` 𝒩(u) = -b₁ 1/2 (1⃗ ⋅ ∇)(u²) @@ -204,8 +208,8 @@ def _single_channel_nonconservative_eval( self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: """ - Evaluates the non-conservative convection term for a single-channel state `u_hat` in - Fourier space. The convection term is given by + Evaluates the non-conservative convection term for a single-channel + state `u_hat` in Fourier space. The convection term is given by ``` 𝒩(u) = -b₁ u (1⃗ ⋅ ∇)u From c6daa20bed18a3d55d5d284cd5bf8adbce22cfa0 Mon Sep 17 00:00:00 2001 From: "qiangliu.7@outlook.com" Date: Tue, 22 Oct 2024 11:06:48 +0200 Subject: [PATCH 4/4] use the dealiased value for non-conservative convection --- exponax/nonlin_fun/_convection.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/exponax/nonlin_fun/_convection.py b/exponax/nonlin_fun/_convection.py index da1c3f5..39f4aca 100644 --- a/exponax/nonlin_fun/_convection.py +++ b/exponax/nonlin_fun/_convection.py @@ -164,7 +164,9 @@ def _multi_channel_nonconservative_eval( ) u_hat_dealiased = self.dealias(u_hat) u = self.ifft(u_hat_dealiased) - nabla_u = self.ifft(self.derivative_operator[None, :] * u_hat[:, None]) + nabla_u = self.ifft( + self.derivative_operator[None, :] * u_hat_dealiased[:, None] + ) conv_u = jnp.sum( u[None, :] * nabla_u, axis=1, @@ -225,7 +227,7 @@ def _single_channel_nonconservative_eval( """ u_hat_dealiased = self.dealias(u_hat) u = self.ifft(u_hat_dealiased) - nabla_u = self.ifft(self.derivative_operator * u_hat) + nabla_u = self.ifft(self.derivative_operator * u_hat_dealiased) conv_u = jnp.sum( u * nabla_u, axis=0,