-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Major revision to the nonlinear function interface
- Loading branch information
Showing
31 changed files
with
301 additions
and
320 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,66 +1,84 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Optional | ||
|
||
import equinox as eqx | ||
from jaxtyping import Array, Bool, Complex | ||
import jax.numpy as jnp | ||
from jaxtyping import Array, Bool, Complex, Float | ||
|
||
from .._spectral import low_pass_filter_mask, wavenumber_shape | ||
from .._spectral import low_pass_filter_mask, space_indices, spatial_shape | ||
|
||
|
||
class BaseNonlinearFun(eqx.Module, ABC): | ||
num_spatial_dims: int | ||
num_points: int | ||
num_channels: int | ||
derivative_operator: Complex[Array, "D ... (N//2)+1"] | ||
dealiasing_mask: Bool[Array, "1 ... (N//2)+1"] | ||
dealiasing_mask: Optional[Bool[Array, "1 ... (N//2)+1"]] | ||
|
||
def __init__( | ||
self, | ||
num_spatial_dims: int, | ||
num_points: int, | ||
num_channels: int, | ||
*, | ||
derivative_operator: Complex[Array, "D ... (N//2)+1"], | ||
dealiasing_fraction: float, | ||
dealiasing_fraction: Optional[float] = None, | ||
): | ||
self.num_spatial_dims = num_spatial_dims | ||
self.num_points = num_points | ||
self.num_channels = num_channels | ||
self.derivative_operator = derivative_operator | ||
|
||
# Can be done because num_points is identical in all spatial dimensions | ||
nyquist_mode = (num_points // 2) + 1 | ||
highest_resolved_mode = nyquist_mode - 1 | ||
start_of_aliased_modes = dealiasing_fraction * highest_resolved_mode | ||
if dealiasing_fraction is None: | ||
self.dealiasing_mask = None | ||
else: | ||
# Can be done because num_points is identical in all spatial dimensions | ||
nyquist_mode = (num_points // 2) + 1 | ||
highest_resolved_mode = nyquist_mode - 1 | ||
start_of_aliased_modes = dealiasing_fraction * highest_resolved_mode | ||
|
||
self.dealiasing_mask = low_pass_filter_mask( | ||
num_spatial_dims, | ||
num_points, | ||
cutoff=start_of_aliased_modes - 1, | ||
) | ||
self.dealiasing_mask = low_pass_filter_mask( | ||
num_spatial_dims, | ||
num_points, | ||
cutoff=start_of_aliased_modes - 1, | ||
) | ||
|
||
@abstractmethod | ||
def evaluate( | ||
self, | ||
u_hat: Complex[Array, "C ... (N//2)+1"], | ||
def dealias( | ||
self, u_hat: Complex[Array, "C ... (N//2)+1"] | ||
) -> Complex[Array, "C ... (N//2)+1"]: | ||
""" | ||
Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing. | ||
""" | ||
raise NotImplementedError("Must be implemented by subclass") | ||
if self.dealiasing_mask is None: | ||
raise ValueError("Nonlinear function was set up without dealiasing") | ||
return self.dealiasing_mask * u_hat | ||
|
||
def fft(self, u: Float[Array, "C ... N"]) -> Complex[Array, "C ... (N//2)+1"]: | ||
return jnp.fft.rfftn(u, axes=space_indices(self.num_spatial_dims)) | ||
|
||
def ifft(self, u_hat: Complex[Array, "C ... (N//2)+1"]) -> Float[Array, "C ... N"]: | ||
return jnp.fft.irfftn( | ||
u_hat, | ||
s=spatial_shape(self.num_spatial_dims, self.num_points), | ||
axes=space_indices(self.num_spatial_dims), | ||
) | ||
|
||
# @abstractmethod | ||
# def evaluate( | ||
# self, | ||
# u_hat: Complex[Array, "C ... (N//2)+1"], | ||
# ) -> Complex[Array, "C ... (N//2)+1"]: | ||
# """ | ||
# Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing. | ||
# """ | ||
# raise NotImplementedError("Must be implemented by subclass") | ||
|
||
@abstractmethod | ||
def __call__( | ||
self, | ||
u_hat: Complex[Array, "C ... (N//2)+1"], | ||
) -> Complex[Array, "C ... (N//2)+1"]: | ||
""" | ||
Perform check | ||
Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing. | ||
""" | ||
expected_shape = (self.num_channels,) + wavenumber_shape( | ||
self.num_spatial_dims, self.num_points | ||
) | ||
if u_hat.shape != expected_shape: | ||
raise ValueError( | ||
f"Expected shape {expected_shape}, got {u_hat.shape}. For batched operation use `jax.vmap` on this function." | ||
) | ||
# expected_shape = (self.num_channels,) + wavenumber_shape( | ||
# self.num_spatial_dims, self.num_points | ||
# ) | ||
# if u_hat.shape != expected_shape: | ||
# raise ValueError( | ||
# f"Expected shape {expected_shape}, got {u_hat.shape}. For batched operation use `jax.vmap` on this function." | ||
# ) | ||
|
||
return self.evaluate(u_hat) | ||
# return self.evaluate(u_hat) | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,79 +1,71 @@ | ||
from jaxtyping import Array, Complex | ||
|
||
from ._base import BaseNonlinearFun | ||
from ._convection import ConvectionNonlinearFun | ||
from ._gradient_norm import GradientNormNonlinearFun | ||
from ._polynomial import PolynomialNonlinearFun | ||
from ._single_channel_convection import SingleChannelConvectionNonlinearFun | ||
|
||
|
||
class GeneralNonlinearFun1d(BaseNonlinearFun): | ||
class GeneralNonlinearFun(BaseNonlinearFun): | ||
square_nonlinear_fun: PolynomialNonlinearFun | ||
convection_nonlinear_fun: ConvectionNonlinearFun | ||
convection_nonlinear_fun: SingleChannelConvectionNonlinearFun | ||
gradient_norm_nonlinear_fun: GradientNormNonlinearFun | ||
|
||
def __init__( | ||
self, | ||
num_spatial_dims: int, | ||
num_points: int, | ||
num_channels: int, | ||
*, | ||
derivative_operator: Complex[Array, "D ... (N//2)+1"], | ||
dealiasing_fraction: float, | ||
scale_list: list[float] = [0.0, -1.0, 0.0], | ||
zero_mode_fix: bool = False, | ||
zero_mode_fix: bool = True, | ||
): | ||
""" | ||
Uses an additional scaling of 0.5 on the latter two components only | ||
By default: Burgers equation | ||
""" | ||
if num_spatial_dims != 1: | ||
raise ValueError("The number of spatial dimensions must be 1") | ||
if len(scale_list) != 3: | ||
raise ValueError("The scale list must have exactly 3 elements") | ||
|
||
self.square_nonlinear_fun = PolynomialNonlinearFun( | ||
num_spatial_dims=num_spatial_dims, | ||
num_points=num_points, | ||
num_channels=num_channels, | ||
num_spatial_dims, | ||
num_points, | ||
derivative_operator=derivative_operator, | ||
dealiasing_fraction=dealiasing_fraction, | ||
coefficients=[0.0, 0.0, scale_list[0]], | ||
) | ||
self.convection_nonlinear_fun = ConvectionNonlinearFun( | ||
num_spatial_dims=num_spatial_dims, | ||
num_points=num_points, | ||
num_channels=num_channels, | ||
self.convection_nonlinear_fun = SingleChannelConvectionNonlinearFun( | ||
num_spatial_dims, | ||
num_points, | ||
derivative_operator=derivative_operator, | ||
dealiasing_fraction=dealiasing_fraction, | ||
# Minus required because it internally has another minus | ||
scale=-scale_list[1], | ||
zero_mode_fix=zero_mode_fix, | ||
) | ||
self.gradient_norm_nonlinear_fun = GradientNormNonlinearFun( | ||
num_spatial_dims=num_spatial_dims, | ||
num_points=num_points, | ||
num_channels=num_channels, | ||
num_spatial_dims, | ||
num_points, | ||
derivative_operator=derivative_operator, | ||
dealiasing_fraction=dealiasing_fraction, | ||
# Minus required because it internally has another minus | ||
scale=-scale_list[2], | ||
zero_mode_fix=zero_mode_fix, | ||
) | ||
|
||
super().__init__( | ||
num_spatial_dims, | ||
num_points, | ||
num_channels, | ||
derivative_operator=derivative_operator, | ||
dealiasing_fraction=dealiasing_fraction, | ||
) | ||
|
||
def evaluate( | ||
def __call__( | ||
self, | ||
u_hat: Complex[Array, "C ... (N//2)+1"], | ||
) -> Complex[Array, "C ... (N//2)+1"]: | ||
return ( | ||
self.square_nonlinear_fun.evaluate(u_hat) | ||
+ self.convection_nonlinear_fun.evaluate(u_hat) | ||
+ self.gradient_norm_nonlinear_fun.evaluate(u_hat) | ||
self.square_nonlinear_fun(u_hat) | ||
+ self.convection_nonlinear_fun(u_hat) | ||
+ self.gradient_norm_nonlinear_fun(u_hat) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.