Skip to content

Commit

Permalink
Major revision to the nonlinear function interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Mar 13, 2024
1 parent 965e896 commit a1fcbfb
Show file tree
Hide file tree
Showing 31 changed files with 301 additions and 320 deletions.
4 changes: 2 additions & 2 deletions exponax/nonlin_fun/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._base import BaseNonlinearFun
from ._convection import ConvectionNonlinearFun
from ._general_nonlinear import GeneralNonlinearFun1d
from ._general_nonlinear import GeneralNonlinearFun
from ._gradient_norm import GradientNormNonlinearFun
from ._polynomial import PolynomialNonlinearFun
from ._vorticity_convection import (
Expand All @@ -12,7 +12,7 @@
__all__ = [
"BaseNonlinearFun",
"ConvectionNonlinearFun",
"GeneralNonlinearFun1d",
"GeneralNonlinearFun",
"GradientNormNonlinearFun",
"PolynomialNonlinearFun",
"VorticityConvection2d",
Expand Down
90 changes: 54 additions & 36 deletions exponax/nonlin_fun/_base.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,84 @@
from abc import ABC, abstractmethod
from typing import Optional

import equinox as eqx
from jaxtyping import Array, Bool, Complex
import jax.numpy as jnp
from jaxtyping import Array, Bool, Complex, Float

from .._spectral import low_pass_filter_mask, wavenumber_shape
from .._spectral import low_pass_filter_mask, space_indices, spatial_shape


class BaseNonlinearFun(eqx.Module, ABC):
num_spatial_dims: int
num_points: int
num_channels: int
derivative_operator: Complex[Array, "D ... (N//2)+1"]
dealiasing_mask: Bool[Array, "1 ... (N//2)+1"]
dealiasing_mask: Optional[Bool[Array, "1 ... (N//2)+1"]]

def __init__(
self,
num_spatial_dims: int,
num_points: int,
num_channels: int,
*,
derivative_operator: Complex[Array, "D ... (N//2)+1"],
dealiasing_fraction: float,
dealiasing_fraction: Optional[float] = None,
):
self.num_spatial_dims = num_spatial_dims
self.num_points = num_points
self.num_channels = num_channels
self.derivative_operator = derivative_operator

# Can be done because num_points is identical in all spatial dimensions
nyquist_mode = (num_points // 2) + 1
highest_resolved_mode = nyquist_mode - 1
start_of_aliased_modes = dealiasing_fraction * highest_resolved_mode
if dealiasing_fraction is None:
self.dealiasing_mask = None
else:
# Can be done because num_points is identical in all spatial dimensions
nyquist_mode = (num_points // 2) + 1
highest_resolved_mode = nyquist_mode - 1
start_of_aliased_modes = dealiasing_fraction * highest_resolved_mode

self.dealiasing_mask = low_pass_filter_mask(
num_spatial_dims,
num_points,
cutoff=start_of_aliased_modes - 1,
)
self.dealiasing_mask = low_pass_filter_mask(
num_spatial_dims,
num_points,
cutoff=start_of_aliased_modes - 1,
)

@abstractmethod
def evaluate(
self,
u_hat: Complex[Array, "C ... (N//2)+1"],
def dealias(
self, u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]:
"""
Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing.
"""
raise NotImplementedError("Must be implemented by subclass")
if self.dealiasing_mask is None:
raise ValueError("Nonlinear function was set up without dealiasing")
return self.dealiasing_mask * u_hat

def fft(self, u: Float[Array, "C ... N"]) -> Complex[Array, "C ... (N//2)+1"]:
return jnp.fft.rfftn(u, axes=space_indices(self.num_spatial_dims))

def ifft(self, u_hat: Complex[Array, "C ... (N//2)+1"]) -> Float[Array, "C ... N"]:
return jnp.fft.irfftn(
u_hat,
s=spatial_shape(self.num_spatial_dims, self.num_points),
axes=space_indices(self.num_spatial_dims),
)

# @abstractmethod
# def evaluate(
# self,
# u_hat: Complex[Array, "C ... (N//2)+1"],
# ) -> Complex[Array, "C ... (N//2)+1"]:
# """
# Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing.
# """
# raise NotImplementedError("Must be implemented by subclass")

@abstractmethod
def __call__(
self,
u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
"""
Perform check
Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing.
"""
expected_shape = (self.num_channels,) + wavenumber_shape(
self.num_spatial_dims, self.num_points
)
if u_hat.shape != expected_shape:
raise ValueError(
f"Expected shape {expected_shape}, got {u_hat.shape}. For batched operation use `jax.vmap` on this function."
)
# expected_shape = (self.num_channels,) + wavenumber_shape(
# self.num_spatial_dims, self.num_points
# )
# if u_hat.shape != expected_shape:
# raise ValueError(
# f"Expected shape {expected_shape}, got {u_hat.shape}. For batched operation use `jax.vmap` on this function."
# )

return self.evaluate(u_hat)
# return self.evaluate(u_hat)
pass
55 changes: 36 additions & 19 deletions exponax/nonlin_fun/_convection.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import jax.numpy as jnp
from jaxtyping import Array, Complex

from .._spectral import space_indices, spatial_shape
from ._base import BaseNonlinearFun


class ConvectionNonlinearFun(BaseNonlinearFun):
derivative_operator: Complex[Array, "D ... (N//2)+1"]
scale: float

def __init__(
self,
num_spatial_dims: int,
num_points: int,
num_channels: int,
*,
derivative_operator: Complex[Array, "D ... (N//2)+1"],
dealiasing_fraction: float,
Expand All @@ -21,33 +20,51 @@ def __init__(
"""
Uses by default a scaling of 0.5 to take into account the conservative evaluation
"""
self.derivative_operator = derivative_operator
self.scale = scale
super().__init__(
num_spatial_dims,
num_points,
num_channels,
derivative_operator=derivative_operator,
dealiasing_fraction=dealiasing_fraction,
)

def evaluate(
self,
u_hat: Complex[Array, "C ... (N//2)+1"],
def __call__(
self, u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]:
u_hat_dealiased = self.dealiasing_mask * u_hat
u = jnp.fft.irfftn(
u_hat_dealiased,
s=spatial_shape(self.num_spatial_dims, self.num_points),
axes=space_indices(self.num_spatial_dims),
)
num_channels = u_hat.shape[0]
if num_channels != self.num_spatial_dims:
raise ValueError(
"Number of channels in u_hat should match number of spatial dimensions"
)
u_hat_dealiased = self.dealias(u_hat)
u = self.ifft(u_hat_dealiased)
u_outer_product = u[:, None] * u[None, :]

u_outer_product_hat = jnp.fft.rfftn(
u_outer_product, axes=space_indices(self.num_spatial_dims)
)
u_divergence_on_outer_product_hat = jnp.sum(
u_outer_product_hat = self.fft(u_outer_product)
convection = 0.5 * jnp.sum(
self.derivative_operator[None, :] * u_outer_product_hat,
axis=1,
)
# Requires minus to move term to the rhs
return -self.scale * 0.5 * u_divergence_on_outer_product_hat
return -self.scale * convection

# def evaluate(
# self,
# u_hat: Complex[Array, "C ... (N//2)+1"],
# ) -> Complex[Array, "C ... (N//2)+1"]:
# u_hat_dealiased = self.dealiasing_mask * u_hat
# u = jnp.fft.irfftn(
# u_hat_dealiased,
# s=spatial_shape(self.num_spatial_dims, self.num_points),
# axes=space_indices(self.num_spatial_dims),
# )
# u_outer_product = u[:, None] * u[None, :]

# u_outer_product_hat = jnp.fft.rfftn(
# u_outer_product, axes=space_indices(self.num_spatial_dims)
# )
# u_divergence_on_outer_product_hat = jnp.sum(
# self.derivative_operator[None, :] * u_outer_product_hat,
# axis=1,
# )
# # Requires minus to move term to the rhs
# return -self.scale * 0.5 * u_divergence_on_outer_product_hat
40 changes: 16 additions & 24 deletions exponax/nonlin_fun/_general_nonlinear.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,71 @@
from jaxtyping import Array, Complex

from ._base import BaseNonlinearFun
from ._convection import ConvectionNonlinearFun
from ._gradient_norm import GradientNormNonlinearFun
from ._polynomial import PolynomialNonlinearFun
from ._single_channel_convection import SingleChannelConvectionNonlinearFun


class GeneralNonlinearFun1d(BaseNonlinearFun):
class GeneralNonlinearFun(BaseNonlinearFun):
square_nonlinear_fun: PolynomialNonlinearFun
convection_nonlinear_fun: ConvectionNonlinearFun
convection_nonlinear_fun: SingleChannelConvectionNonlinearFun
gradient_norm_nonlinear_fun: GradientNormNonlinearFun

def __init__(
self,
num_spatial_dims: int,
num_points: int,
num_channels: int,
*,
derivative_operator: Complex[Array, "D ... (N//2)+1"],
dealiasing_fraction: float,
scale_list: list[float] = [0.0, -1.0, 0.0],
zero_mode_fix: bool = False,
zero_mode_fix: bool = True,
):
"""
Uses an additional scaling of 0.5 on the latter two components only
By default: Burgers equation
"""
if num_spatial_dims != 1:
raise ValueError("The number of spatial dimensions must be 1")
if len(scale_list) != 3:
raise ValueError("The scale list must have exactly 3 elements")

self.square_nonlinear_fun = PolynomialNonlinearFun(
num_spatial_dims=num_spatial_dims,
num_points=num_points,
num_channels=num_channels,
num_spatial_dims,
num_points,
derivative_operator=derivative_operator,
dealiasing_fraction=dealiasing_fraction,
coefficients=[0.0, 0.0, scale_list[0]],
)
self.convection_nonlinear_fun = ConvectionNonlinearFun(
num_spatial_dims=num_spatial_dims,
num_points=num_points,
num_channels=num_channels,
self.convection_nonlinear_fun = SingleChannelConvectionNonlinearFun(
num_spatial_dims,
num_points,
derivative_operator=derivative_operator,
dealiasing_fraction=dealiasing_fraction,
# Minus required because it internally has another minus
scale=-scale_list[1],
zero_mode_fix=zero_mode_fix,
)
self.gradient_norm_nonlinear_fun = GradientNormNonlinearFun(
num_spatial_dims=num_spatial_dims,
num_points=num_points,
num_channels=num_channels,
num_spatial_dims,
num_points,
derivative_operator=derivative_operator,
dealiasing_fraction=dealiasing_fraction,
# Minus required because it internally has another minus
scale=-scale_list[2],
zero_mode_fix=zero_mode_fix,
)

super().__init__(
num_spatial_dims,
num_points,
num_channels,
derivative_operator=derivative_operator,
dealiasing_fraction=dealiasing_fraction,
)

def evaluate(
def __call__(
self,
u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
return (
self.square_nonlinear_fun.evaluate(u_hat)
+ self.convection_nonlinear_fun.evaluate(u_hat)
+ self.gradient_norm_nonlinear_fun.evaluate(u_hat)
self.square_nonlinear_fun(u_hat)
+ self.convection_nonlinear_fun(u_hat)
+ self.gradient_norm_nonlinear_fun(u_hat)
)
21 changes: 6 additions & 15 deletions exponax/nonlin_fun/_gradient_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@
import jax.numpy as jnp
from jaxtyping import Array, Complex, Float

from .._spectral import space_indices, spatial_shape
from ._base import BaseNonlinearFun


class GradientNormNonlinearFun(BaseNonlinearFun):
scale: float
zero_mode_fix: bool
derivative_operator: Complex[Array, "D ... (N//2)+1"]

def __init__(
self,
num_spatial_dims: int,
num_points: int,
num_channels: int,
*,
derivative_operator: Complex[Array, "D ... (N//2)+1"],
dealiasing_fraction: float,
Expand All @@ -27,10 +26,9 @@ def __init__(
super().__init__(
num_spatial_dims,
num_points,
num_channels,
derivative_operator=derivative_operator,
dealiasing_fraction=dealiasing_fraction,
)
self.derivative_operator = derivative_operator
self.zero_mode_fix = zero_mode_fix
self.scale = scale

Expand All @@ -40,17 +38,12 @@ def zero_fix(
):
return f - jnp.mean(f)

def evaluate(
def __call__(
self,
u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
u_gradient_hat = self.derivative_operator[None, :] * u_hat[:, None]
u_gradient_dealiased_hat = self.dealiasing_mask * u_gradient_hat
u_gradient = jnp.fft.irfftn(
u_gradient_dealiased_hat,
s=spatial_shape(self.num_spatial_dims, self.num_points),
axes=space_indices(self.num_spatial_dims),
)
u_gradient = self.ifft(self.dealias(u_gradient_hat))

# Reduces the axis introduced by the gradient
u_gradient_norm_squared = jnp.sum(u_gradient**2, axis=1)
Expand All @@ -59,14 +52,12 @@ def evaluate(
# Maybe there is more efficient way
u_gradient_norm_squared = jax.vmap(self.zero_fix)(u_gradient_norm_squared)

u_gradient_norm_squared_hat = jnp.fft.rfftn(
u_gradient_norm_squared, axes=space_indices(self.num_spatial_dims)
)
u_gradient_norm_squared_hat = 0.5 * self.fft(u_gradient_norm_squared)
# if self.zero_mode_fix:
# # Fix the mean mode
# u_gradient_norm_squared_hat = u_gradient_norm_squared_hat.at[..., 0].set(
# u_hat[..., 0]
# )

# Requires minus to move term to the rhs
return -self.scale * 0.5 * u_gradient_norm_squared_hat
return -self.scale * u_gradient_norm_squared_hat
Loading

0 comments on commit a1fcbfb

Please sign in to comment.