Skip to content

Commit

Permalink
Regular convection does not need zero mode fix because the derivative…
Browse files Browse the repository at this point in the history
… operators already set it to zero
  • Loading branch information
Ceyron committed Mar 13, 2024
1 parent 2986e22 commit 965e896
Showing 1 changed file with 1 addition and 15 deletions.
16 changes: 1 addition & 15 deletions exponax/nonlin_fun/_convection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array, Complex, Float
from jaxtyping import Array, Complex

from .._spectral import space_indices, spatial_shape
from ._base import BaseNonlinearFun


class ConvectionNonlinearFun(BaseNonlinearFun):
scale: float
zero_mode_fix: bool

def __init__(
self,
Expand All @@ -19,13 +17,11 @@ def __init__(
derivative_operator: Complex[Array, "D ... (N//2)+1"],
dealiasing_fraction: float,
scale: float = 1.0,
zero_mode_fix: bool = False,
):
"""
Uses by default a scaling of 0.5 to take into account the conservative evaluation
"""
self.scale = scale
self.zero_mode_fix = zero_mode_fix
super().__init__(
num_spatial_dims,
num_points,
Expand All @@ -34,12 +30,6 @@ def __init__(
dealiasing_fraction=dealiasing_fraction,
)

def zero_fix(
self,
f: Float[Array, "... N"],
):
return f - jnp.mean(f)

def evaluate(
self,
u_hat: Complex[Array, "C ... (N//2)+1"],
Expand All @@ -52,10 +42,6 @@ def evaluate(
)
u_outer_product = u[:, None] * u[None, :]

if self.zero_mode_fix:
# Maybe there is more efficient way
u_outer_product = jax.vmap(self.zero_fix)(u_outer_product)

u_outer_product_hat = jnp.fft.rfftn(
u_outer_product, axes=space_indices(self.num_spatial_dims)
)
Expand Down

0 comments on commit 965e896

Please sign in to comment.