Skip to content

Commit

Permalink
add non-conservative convection term
Browse files Browse the repository at this point in the history
  • Loading branch information
qiauil committed Oct 21, 2024
1 parent 52a191d commit 2e752f9
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 15 deletions.
127 changes: 112 additions & 15 deletions exponax/nonlin_fun/_convection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class ConvectionNonlinearFun(BaseNonlinearFun):
derivative_operator: Complex[Array, "D ... (N//2)+1"]
scale: float
single_channel: bool
conservative: bool

def __init__(
self,
Expand All @@ -18,13 +19,14 @@ def __init__(
dealiasing_fraction: float = 2 / 3,
scale: float = 1.0,
single_channel: bool = False,
conservative: bool = False,
):
"""
Performs a pseudo-spectral evaluation of the nonlinear convection, e.g.,
found in the Burgers equation. In 1d and state space, this reads
```
𝒩(u) = - b₁ 1/2 (u²)ₓ
𝒩(u) = -b₁ u (u)ₓ
```
with a scale `b₁`. The minus arises because `Exponax` follows the
Expand All @@ -37,19 +39,41 @@ def __init__(
channels as spatial dimensions and then gives
```
𝒩(u) = - b₁ 1/2 ∇ ⋅ (u ⊗ u)
𝒩(u) = -b₁ u ⋅ ∇ u
```
with `∇ ⋅` the divergence operator and the outer product `u ⊗ u`.
Meanwhile, if you use a conservative form, the convection term is given by
```
𝒩(u) = -b₁ u (u)ₓ
```
for 1D and
```
𝒩(u) = -b₁ ∇ ⋅ (u ⊗ u)
```
for 2D and 3D.
Another option is a "single-channel" hack requiring only one channel no
matter the spatial dimensions. This reads
```
𝒩(u) = - b₁ 1/2 (1⃗ ⋅ ∇)(u²)
𝒩(u) = -b₁ 1/2 (1⃗ ⋅ ∇)(u²)
```
for the conservative form and
```
𝒩(u) = -b₁ 1/2 u (1⃗ ⋅ ∇)u
```
**Arguments:**
for the non-conservative form.
**Arguments:**
- `num_spatial_dims`: The number of spatial dimensions `d`.
- `num_points`: The number of points `N` used to discretize the
domain. This **includes** the left boundary point and **excludes**
Expand All @@ -63,25 +87,27 @@ def __init__(
- `scale`: The scale `b₁` of the convection term. Defaults to `1.0`.
- `single_channel`: Whether to use the single-channel hack. Defaults
to `False`.
- `conservative`: Whether to use the conservative form. Defaults to `False`.
"""
self.derivative_operator = derivative_operator
self.scale = scale
self.single_channel = single_channel
self.conservative=conservative
super().__init__(
num_spatial_dims,
num_points,
dealiasing_fraction=dealiasing_fraction,
)

def _multi_channel_eval(
def _multi_channel_conservative_eval(
self, u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]:
"""
Evaluates the convection term for a multi-channel state `u_hat` in
Evaluates the conservative convection term for a multi-channel state `u_hat` in
Fourier space. The convection term is given by
```
𝒩(u) = b₁ 1/2 ∇ ⋅ (u ⊗ u)
𝒩(u) = -b₁ 1/2 ∇ ⋅ (u ⊗ u)
```
with `∇ ⋅` the divergence operator and the outer product `u ⊗ u`.
Expand All @@ -101,24 +127,61 @@ def _multi_channel_eval(
)
u_hat_dealiased = self.dealias(u_hat)
u = self.ifft(u_hat_dealiased)
u_outer_product = u[:, None] * u[None, :]
u_outer_product = u[None, :] * u[:, None]
u_outer_product_hat = self.fft(u_outer_product)
convection = 0.5 * jnp.sum(
convection = jnp.sum(
self.derivative_operator[None, :] * u_outer_product_hat,
axis=1,
)
# Requires minus to move term to the rhs
return -self.scale * convection

def _multi_channel_nonconservative_eval(
self, u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]:
"""
Evaluates the non-conservative convection term for a multi-channel state `u_hat` in
Fourier space. The convection term is given by
```
𝒩(u) = -b₁ u ⋅ ∇ u
```
**Arguments:**
def _single_channel_eval(
- `u_hat`: The state in Fourier space.
**Returns:**
- `convection`: The evaluation of the convection term in Fourier space.
"""
num_channels = u_hat.shape[0]
if num_channels != self.num_spatial_dims:
raise ValueError(
"Number of channels in u_hat should match number of spatial dimensions"
)
u_hat_dealiased = self.dealias(u_hat)
u = self.ifft(u_hat_dealiased)
nabla_u = self.ifft(self.derivative_operator[None, :] * u_hat[:, None])
conv_u = jnp.sum(
u[None, :] * nabla_u,
axis=1,
)
#conv_u=sum(
# [u[i:i+1]*self.ifft(self.derivative_operator[i:i+1]*u_hat) for i in range(num_channels)]
#)
# Requires minus to move term to the rhs
return -self.scale * self.fft(conv_u)

def _single_channel_conservative_eval(
self, u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]:
"""
Evaluates the convection term for a single-channel state `u_hat` in
Evaluates the conservative convection term for a single-channel state `u_hat` in
Fourier space. The convection term is given by
```
𝒩(u) = b₁ 1/2 (1⃗ ⋅ ∇)(u²)
𝒩(u) = -b₁ 1/2 (1⃗ ⋅ ∇)(u²)
```
with `∇ ⋅` the divergence operator and `1⃗` a vector of ones.
Expand All @@ -142,10 +205,44 @@ def _single_channel_eval(
# Requires minus to bring convection to the right-hand side
return -self.scale * convection

def _single_channel_nonconservative_eval(
self, u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]:
"""
Evaluates the non-conservative convection term for a single-channel state `u_hat` in
Fourier space. The convection term is given by
```
𝒩(u) = -b₁ 1/2 u (1⃗ ⋅ ∇)u
```
**Arguments:**
- `u_hat`: The state in Fourier space.
**Returns:**
- `convection`: The evaluation of the convection term in Fourier space.
"""
u_hat_dealiased = self.dealias(u_hat)
u = self.ifft(u_hat_dealiased)
nabla_u = self.ifft(self.derivative_operator * u_hat)
conv_u = jnp.sum(
u * nabla_u,axis=0,keepdims=True,
)
# Requires minus to bring convection to the right-hand side
return -self.scale * self.fft(conv_u)

def __call__(
self, u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]:
if self.single_channel:
return self._single_channel_eval(u_hat)
if self.conservative:
return self._single_channel_conservative_eval(u_hat)
else:
return self._single_channel_nonconservative_eval(u_hat)
else:
return self._multi_channel_eval(u_hat)
if self.conservative:
return self._multi_channel_conservative_eval(u_hat)
else:
return self._multi_channel_nonconservative_eval(u_hat)
6 changes: 6 additions & 0 deletions exponax/stepper/_burgers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Burgers(BaseStepper):
convection_scale: float
dealiasing_fraction: float
single_channel: bool
conservative: bool

def __init__(
self,
Expand All @@ -21,6 +22,7 @@ def __init__(
diffusivity: float = 0.1,
convection_scale: float = 1.0,
single_channel: bool = False,
conservative: bool = False,
order=2,
dealiasing_fraction: float = 2 / 3,
num_circle_points: int = 16,
Expand Down Expand Up @@ -75,6 +77,8 @@ def __init__(
dimensions. In this case the the convection is `b₁ (∇ ⋅ 1)(u²)`. In
this case, the state always has a single channel, no matter the
spatial dimension. Default: False.
- `conservative`: Whether to use the conservative form of the convection
term. Default: False.
- `order`: The order of the Exponential Time Differencing Runge
Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
solves the linear part of the equation. Use higher values for higher
Expand Down Expand Up @@ -112,6 +116,7 @@ def __init__(
self.diffusivity = diffusivity
self.convection_scale = convection_scale
self.single_channel = single_channel
self.conservative = conservative
self.dealiasing_fraction = dealiasing_fraction

if single_channel:
Expand Down Expand Up @@ -149,4 +154,5 @@ def _build_nonlinear_fun(
dealiasing_fraction=self.dealiasing_fraction,
scale=self.convection_scale,
single_channel=self.single_channel,
conservative=False,
)
6 changes: 6 additions & 0 deletions exponax/stepper/_korteweg_de_vries.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class KortewegDeVries(BaseStepper):
advect_over_diffuse: bool
diffuse_over_diffuse: bool
single_channel: bool
conservaitve: bool

def __init__(
self,
Expand All @@ -34,6 +35,7 @@ def __init__(
advect_over_diffuse: bool = False,
diffuse_over_diffuse: bool = False,
single_channel: bool = False,
conservative: bool = False,
order: int = 2,
dealiasing_fraction: float = 2 / 3,
num_circle_points: int = 16,
Expand Down Expand Up @@ -108,6 +110,8 @@ def __init__(
dimensions. In this case the the convection is `b₁ (∇ ⋅ 1)(u²)`. In
this case, the state always has a single channel, no matter the
spatial dimension. Default: False.
- `conservative`: Whether to use the conservative form of the convection
term. Default: False.
- `order`: The order of the Exponential Time Differencing Runge
Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
solves the linear part of the equation. Use higher values for higher
Expand Down Expand Up @@ -145,6 +149,7 @@ def __init__(
self.advect_over_diffuse = advect_over_diffuse
self.diffuse_over_diffuse = diffuse_over_diffuse
self.single_channel = single_channel
self.conservative = conservative
self.dealiasing_fraction = dealiasing_fraction

if single_channel:
Expand Down Expand Up @@ -210,4 +215,5 @@ def _build_nonlinear_fun(
dealiasing_fraction=self.dealiasing_fraction,
scale=self.convection_scale,
single_channel=self.single_channel,
conservative=self.conservative,
)
4 changes: 4 additions & 0 deletions exponax/stepper/_kuramoto_sivashinsky.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ class KuramotoSivashinskyConservative(BaseStepper):
second_order_diffusivity: float
fourth_order_diffusivity: float
single_channel: bool
conservative: bool
dealiasing_fraction: float

def __init__(
Expand All @@ -189,6 +190,7 @@ def __init__(
second_order_diffusivity: float = 1.0,
fourth_order_diffusivity: float = 1.0,
single_channel: bool = False,
conservative: bool = False,
dealiasing_fraction: float = 2 / 3,
order: int = 2,
num_circle_points: int = 16,
Expand All @@ -203,6 +205,7 @@ def __init__(
self.second_order_diffusivity = second_order_diffusivity
self.fourth_order_diffusivity = fourth_order_diffusivity
self.single_channel = single_channel
self.conservative = conservative
self.dealiasing_fraction = dealiasing_fraction

if num_spatial_dims > 1:
Expand Down Expand Up @@ -252,4 +255,5 @@ def _build_nonlinear_fun(
dealiasing_fraction=self.dealiasing_fraction,
scale=self.convection_scale,
single_channel=self.single_channel,
conservative=self.conservative,
)
Loading

0 comments on commit 2e752f9

Please sign in to comment.