diff --git a/exponax/nonlin_fun/__init__.py b/exponax/nonlin_fun/__init__.py index 72b2f10..73c3cc9 100644 --- a/exponax/nonlin_fun/__init__.py +++ b/exponax/nonlin_fun/__init__.py @@ -1,6 +1,6 @@ from ._base import BaseNonlinearFun from ._convection import ConvectionNonlinearFun -from ._general_nonlinear import GeneralNonlinearFun1d +from ._general_nonlinear import GeneralNonlinearFun from ._gradient_norm import GradientNormNonlinearFun from ._polynomial import PolynomialNonlinearFun from ._vorticity_convection import ( @@ -12,7 +12,7 @@ __all__ = [ "BaseNonlinearFun", "ConvectionNonlinearFun", - "GeneralNonlinearFun1d", + "GeneralNonlinearFun", "GradientNormNonlinearFun", "PolynomialNonlinearFun", "VorticityConvection2d", diff --git a/exponax/nonlin_fun/_base.py b/exponax/nonlin_fun/_base.py index cac2eed..78f517d 100644 --- a/exponax/nonlin_fun/_base.py +++ b/exponax/nonlin_fun/_base.py @@ -1,66 +1,84 @@ from abc import ABC, abstractmethod +from typing import Optional import equinox as eqx -from jaxtyping import Array, Bool, Complex +import jax.numpy as jnp +from jaxtyping import Array, Bool, Complex, Float -from .._spectral import low_pass_filter_mask, wavenumber_shape +from .._spectral import low_pass_filter_mask, space_indices, spatial_shape class BaseNonlinearFun(eqx.Module, ABC): num_spatial_dims: int num_points: int - num_channels: int - derivative_operator: Complex[Array, "D ... (N//2)+1"] - dealiasing_mask: Bool[Array, "1 ... (N//2)+1"] + dealiasing_mask: Optional[Bool[Array, "1 ... (N//2)+1"]] def __init__( self, num_spatial_dims: int, num_points: int, - num_channels: int, *, - derivative_operator: Complex[Array, "D ... (N//2)+1"], - dealiasing_fraction: float, + dealiasing_fraction: Optional[float] = None, ): self.num_spatial_dims = num_spatial_dims self.num_points = num_points - self.num_channels = num_channels - self.derivative_operator = derivative_operator - # Can be done because num_points is identical in all spatial dimensions - nyquist_mode = (num_points // 2) + 1 - highest_resolved_mode = nyquist_mode - 1 - start_of_aliased_modes = dealiasing_fraction * highest_resolved_mode + if dealiasing_fraction is None: + self.dealiasing_mask = None + else: + # Can be done because num_points is identical in all spatial dimensions + nyquist_mode = (num_points // 2) + 1 + highest_resolved_mode = nyquist_mode - 1 + start_of_aliased_modes = dealiasing_fraction * highest_resolved_mode - self.dealiasing_mask = low_pass_filter_mask( - num_spatial_dims, - num_points, - cutoff=start_of_aliased_modes - 1, - ) + self.dealiasing_mask = low_pass_filter_mask( + num_spatial_dims, + num_points, + cutoff=start_of_aliased_modes - 1, + ) - @abstractmethod - def evaluate( - self, - u_hat: Complex[Array, "C ... (N//2)+1"], + def dealias( + self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: - """ - Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing. - """ - raise NotImplementedError("Must be implemented by subclass") + if self.dealiasing_mask is None: + raise ValueError("Nonlinear function was set up without dealiasing") + return self.dealiasing_mask * u_hat + def fft(self, u: Float[Array, "C ... N"]) -> Complex[Array, "C ... (N//2)+1"]: + return jnp.fft.rfftn(u, axes=space_indices(self.num_spatial_dims)) + + def ifft(self, u_hat: Complex[Array, "C ... (N//2)+1"]) -> Float[Array, "C ... N"]: + return jnp.fft.irfftn( + u_hat, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + + # @abstractmethod + # def evaluate( + # self, + # u_hat: Complex[Array, "C ... (N//2)+1"], + # ) -> Complex[Array, "C ... (N//2)+1"]: + # """ + # Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing. + # """ + # raise NotImplementedError("Must be implemented by subclass") + + @abstractmethod def __call__( self, u_hat: Complex[Array, "C ... (N//2)+1"], ) -> Complex[Array, "C ... (N//2)+1"]: """ - Perform check + Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing. """ - expected_shape = (self.num_channels,) + wavenumber_shape( - self.num_spatial_dims, self.num_points - ) - if u_hat.shape != expected_shape: - raise ValueError( - f"Expected shape {expected_shape}, got {u_hat.shape}. For batched operation use `jax.vmap` on this function." - ) + # expected_shape = (self.num_channels,) + wavenumber_shape( + # self.num_spatial_dims, self.num_points + # ) + # if u_hat.shape != expected_shape: + # raise ValueError( + # f"Expected shape {expected_shape}, got {u_hat.shape}. For batched operation use `jax.vmap` on this function." + # ) - return self.evaluate(u_hat) + # return self.evaluate(u_hat) + pass diff --git a/exponax/nonlin_fun/_convection.py b/exponax/nonlin_fun/_convection.py index e19f208..ad002b3 100644 --- a/exponax/nonlin_fun/_convection.py +++ b/exponax/nonlin_fun/_convection.py @@ -1,18 +1,17 @@ import jax.numpy as jnp from jaxtyping import Array, Complex -from .._spectral import space_indices, spatial_shape from ._base import BaseNonlinearFun class ConvectionNonlinearFun(BaseNonlinearFun): + derivative_operator: Complex[Array, "D ... (N//2)+1"] scale: float def __init__( self, num_spatial_dims: int, num_points: int, - num_channels: int, *, derivative_operator: Complex[Array, "D ... (N//2)+1"], dealiasing_fraction: float, @@ -21,33 +20,51 @@ def __init__( """ Uses by default a scaling of 0.5 to take into account the conservative evaluation """ + self.derivative_operator = derivative_operator self.scale = scale super().__init__( num_spatial_dims, num_points, - num_channels, - derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, ) - def evaluate( - self, - u_hat: Complex[Array, "C ... (N//2)+1"], + def __call__( + self, u_hat: Complex[Array, "C ... (N//2)+1"] ) -> Complex[Array, "C ... (N//2)+1"]: - u_hat_dealiased = self.dealiasing_mask * u_hat - u = jnp.fft.irfftn( - u_hat_dealiased, - s=spatial_shape(self.num_spatial_dims, self.num_points), - axes=space_indices(self.num_spatial_dims), - ) + num_channels = u_hat.shape[0] + if num_channels != self.num_spatial_dims: + raise ValueError( + "Number of channels in u_hat should match number of spatial dimensions" + ) + u_hat_dealiased = self.dealias(u_hat) + u = self.ifft(u_hat_dealiased) u_outer_product = u[:, None] * u[None, :] - - u_outer_product_hat = jnp.fft.rfftn( - u_outer_product, axes=space_indices(self.num_spatial_dims) - ) - u_divergence_on_outer_product_hat = jnp.sum( + u_outer_product_hat = self.fft(u_outer_product) + convection = 0.5 * jnp.sum( self.derivative_operator[None, :] * u_outer_product_hat, axis=1, ) # Requires minus to move term to the rhs - return -self.scale * 0.5 * u_divergence_on_outer_product_hat + return -self.scale * convection + + # def evaluate( + # self, + # u_hat: Complex[Array, "C ... (N//2)+1"], + # ) -> Complex[Array, "C ... (N//2)+1"]: + # u_hat_dealiased = self.dealiasing_mask * u_hat + # u = jnp.fft.irfftn( + # u_hat_dealiased, + # s=spatial_shape(self.num_spatial_dims, self.num_points), + # axes=space_indices(self.num_spatial_dims), + # ) + # u_outer_product = u[:, None] * u[None, :] + + # u_outer_product_hat = jnp.fft.rfftn( + # u_outer_product, axes=space_indices(self.num_spatial_dims) + # ) + # u_divergence_on_outer_product_hat = jnp.sum( + # self.derivative_operator[None, :] * u_outer_product_hat, + # axis=1, + # ) + # # Requires minus to move term to the rhs + # return -self.scale * 0.5 * u_divergence_on_outer_product_hat diff --git a/exponax/nonlin_fun/_general_nonlinear.py b/exponax/nonlin_fun/_general_nonlinear.py index 7b80e1c..d01dcda 100644 --- a/exponax/nonlin_fun/_general_nonlinear.py +++ b/exponax/nonlin_fun/_general_nonlinear.py @@ -1,79 +1,71 @@ from jaxtyping import Array, Complex from ._base import BaseNonlinearFun -from ._convection import ConvectionNonlinearFun from ._gradient_norm import GradientNormNonlinearFun from ._polynomial import PolynomialNonlinearFun +from ._single_channel_convection import SingleChannelConvectionNonlinearFun -class GeneralNonlinearFun1d(BaseNonlinearFun): +class GeneralNonlinearFun(BaseNonlinearFun): square_nonlinear_fun: PolynomialNonlinearFun - convection_nonlinear_fun: ConvectionNonlinearFun + convection_nonlinear_fun: SingleChannelConvectionNonlinearFun gradient_norm_nonlinear_fun: GradientNormNonlinearFun def __init__( self, num_spatial_dims: int, num_points: int, - num_channels: int, *, derivative_operator: Complex[Array, "D ... (N//2)+1"], dealiasing_fraction: float, scale_list: list[float] = [0.0, -1.0, 0.0], - zero_mode_fix: bool = False, + zero_mode_fix: bool = True, ): """ Uses an additional scaling of 0.5 on the latter two components only By default: Burgers equation """ - if num_spatial_dims != 1: - raise ValueError("The number of spatial dimensions must be 1") if len(scale_list) != 3: raise ValueError("The scale list must have exactly 3 elements") self.square_nonlinear_fun = PolynomialNonlinearFun( - num_spatial_dims=num_spatial_dims, - num_points=num_points, - num_channels=num_channels, + num_spatial_dims, + num_points, derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, coefficients=[0.0, 0.0, scale_list[0]], ) - self.convection_nonlinear_fun = ConvectionNonlinearFun( - num_spatial_dims=num_spatial_dims, - num_points=num_points, - num_channels=num_channels, + self.convection_nonlinear_fun = SingleChannelConvectionNonlinearFun( + num_spatial_dims, + num_points, derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, # Minus required because it internally has another minus scale=-scale_list[1], - zero_mode_fix=zero_mode_fix, ) self.gradient_norm_nonlinear_fun = GradientNormNonlinearFun( - num_spatial_dims=num_spatial_dims, - num_points=num_points, - num_channels=num_channels, + num_spatial_dims, + num_points, derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, # Minus required because it internally has another minus scale=-scale_list[2], + zero_mode_fix=zero_mode_fix, ) super().__init__( num_spatial_dims, num_points, - num_channels, - derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, ) - def evaluate( + def __call__( self, u_hat: Complex[Array, "C ... (N//2)+1"], ) -> Complex[Array, "C ... (N//2)+1"]: return ( - self.square_nonlinear_fun.evaluate(u_hat) - + self.convection_nonlinear_fun.evaluate(u_hat) - + self.gradient_norm_nonlinear_fun.evaluate(u_hat) + self.square_nonlinear_fun(u_hat) + + self.convection_nonlinear_fun(u_hat) + + self.gradient_norm_nonlinear_fun(u_hat) ) diff --git a/exponax/nonlin_fun/_gradient_norm.py b/exponax/nonlin_fun/_gradient_norm.py index 7dd2373..14d24fb 100644 --- a/exponax/nonlin_fun/_gradient_norm.py +++ b/exponax/nonlin_fun/_gradient_norm.py @@ -2,19 +2,18 @@ import jax.numpy as jnp from jaxtyping import Array, Complex, Float -from .._spectral import space_indices, spatial_shape from ._base import BaseNonlinearFun class GradientNormNonlinearFun(BaseNonlinearFun): scale: float zero_mode_fix: bool + derivative_operator: Complex[Array, "D ... (N//2)+1"] def __init__( self, num_spatial_dims: int, num_points: int, - num_channels: int, *, derivative_operator: Complex[Array, "D ... (N//2)+1"], dealiasing_fraction: float, @@ -27,10 +26,9 @@ def __init__( super().__init__( num_spatial_dims, num_points, - num_channels, - derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, ) + self.derivative_operator = derivative_operator self.zero_mode_fix = zero_mode_fix self.scale = scale @@ -40,17 +38,12 @@ def zero_fix( ): return f - jnp.mean(f) - def evaluate( + def __call__( self, u_hat: Complex[Array, "C ... (N//2)+1"], ) -> Complex[Array, "C ... (N//2)+1"]: u_gradient_hat = self.derivative_operator[None, :] * u_hat[:, None] - u_gradient_dealiased_hat = self.dealiasing_mask * u_gradient_hat - u_gradient = jnp.fft.irfftn( - u_gradient_dealiased_hat, - s=spatial_shape(self.num_spatial_dims, self.num_points), - axes=space_indices(self.num_spatial_dims), - ) + u_gradient = self.ifft(self.dealias(u_gradient_hat)) # Reduces the axis introduced by the gradient u_gradient_norm_squared = jnp.sum(u_gradient**2, axis=1) @@ -59,9 +52,7 @@ def evaluate( # Maybe there is more efficient way u_gradient_norm_squared = jax.vmap(self.zero_fix)(u_gradient_norm_squared) - u_gradient_norm_squared_hat = jnp.fft.rfftn( - u_gradient_norm_squared, axes=space_indices(self.num_spatial_dims) - ) + u_gradient_norm_squared_hat = 0.5 * self.fft(u_gradient_norm_squared) # if self.zero_mode_fix: # # Fix the mean mode # u_gradient_norm_squared_hat = u_gradient_norm_squared_hat.at[..., 0].set( @@ -69,4 +60,4 @@ def evaluate( # ) # Requires minus to move term to the rhs - return -self.scale * 0.5 * u_gradient_norm_squared_hat + return -self.scale * u_gradient_norm_squared_hat diff --git a/exponax/nonlin_fun/_polynomial.py b/exponax/nonlin_fun/_polynomial.py index ba28365..fbe8a03 100644 --- a/exponax/nonlin_fun/_polynomial.py +++ b/exponax/nonlin_fun/_polynomial.py @@ -1,7 +1,5 @@ -import jax.numpy as jnp from jaxtyping import Array, Complex -from .._spectral import space_indices, spatial_shape from ._base import BaseNonlinearFun @@ -16,9 +14,7 @@ def __init__( self, num_spatial_dims: int, num_points: int, - num_channels: int, *, - derivative_operator: Complex[Array, "D ... (N//2)+1"], dealiasing_fraction: float, coefficients: list[float], ): @@ -28,29 +24,20 @@ def __init__( super().__init__( num_spatial_dims, num_points, - num_channels, - derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, ) self.coefficients = coefficients - def evaluate( + def __call__( self, u_hat: Complex[Array, "C ... (N//2)+1"], ) -> Complex[Array, "C ... (N//2)+1"]: - u_hat_dealiased = self.dealiasing_mask * u_hat - u = jnp.fft.irfftn( - u_hat_dealiased, - s=spatial_shape(self.num_spatial_dims, self.num_points), - axes=space_indices(self.num_spatial_dims), - ) + u = self.ifft(self.dealias(u_hat)) u_power = 1.0 u_nonlin = 0.0 for coeff in self.coefficients: u_nonlin += coeff * u_power u_power = u_power * u - u_nonlin_hat = jnp.fft.rfftn( - u_nonlin, axes=space_indices(self.num_spatial_dims) - ) + u_nonlin_hat = self.fft(u_nonlin) return u_nonlin_hat diff --git a/exponax/nonlin_fun/_single_channel_convection.py b/exponax/nonlin_fun/_single_channel_convection.py new file mode 100644 index 0000000..778d567 --- /dev/null +++ b/exponax/nonlin_fun/_single_channel_convection.py @@ -0,0 +1,50 @@ +import jax.numpy as jnp +from jaxtyping import Array, Complex + +from ._base import BaseNonlinearFun + + +class SingleChannelConvectionNonlinearFun(BaseNonlinearFun): + sum_of_derivatives_operator: Complex[Array, "1 ... (N//2)+1"] + scale: float + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float = 2 / 3, + scale: float = 1.0, + ): + """ + Use additional default scaling of 0.5 to account for conservative eval. + + In contrast to the classical convection function, this one does not grow + in channels as the number of spatial dimensions grow. + """ + self.scale = scale + self.sum_of_derivatives_operator = jnp.sum( + derivative_operator, axis=0, keepdims=True + ) + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + def evaluate( + self, u_hat: Complex[Array, "C ... (N//2)+1"] + ) -> Complex[Array, "C ... (N//2)+1"]: + u_hat_dealiased = self.dealias(u_hat) + u = self.ifft(u_hat_dealiased) + u_square = u**2 + u_square_hat = self.fft(u_square) + single_channel_convection = ( + 0.5 * self.sum_of_derivatives_operator * u_square_hat + ) + # Requires minus to bring convection to the right-hand side + return -self.scale * single_channel_convection diff --git a/exponax/nonlin_fun/_vorticity_convection.py b/exponax/nonlin_fun/_vorticity_convection.py index 7acc69c..ad95022 100644 --- a/exponax/nonlin_fun/_vorticity_convection.py +++ b/exponax/nonlin_fun/_vorticity_convection.py @@ -13,7 +13,6 @@ def __init__( self, num_spatial_dims: int, num_points: int, - num_channels: int, *, convection_scale: float = 1.0, derivative_operator: Complex[Array, "D ... (N//2)+1"], @@ -21,14 +20,10 @@ def __init__( ): if num_spatial_dims != 2: raise ValueError(f"Expected num_spatial_dims = 2, got {num_spatial_dims}.") - if num_channels != 1: - raise ValueError(f"Expected num_channels = 1, got {num_channels}.") super().__init__( num_spatial_dims, num_points, - num_channels, - derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, ) @@ -40,7 +35,7 @@ def __init__( # mean of the "right-hand side" will be the mean of the solution) self.inv_laplacian = jnp.where(laplacian == 0, 1.0, 1 / laplacian) - def evaluate( + def __call__( self, u_hat: Complex[Array, "1 ... (N//2)+1"] ) -> Complex[Array, "1 ... (N//2)+1"]: vorticity_hat = u_hat @@ -51,31 +46,53 @@ def evaluate( del_vorticity_del_x_hat = self.derivative_operator[0:1] * vorticity_hat del_vorticity_del_y_hat = self.derivative_operator[1:2] * vorticity_hat - u = jnp.fft.irfft2( - u_hat * self.dealiasing_mask, s=(self.num_points, self.num_points) - ) - v = jnp.fft.irfft2( - v_hat * self.dealiasing_mask, s=(self.num_points, self.num_points) - ) - del_vorticity_del_x = jnp.fft.irfft2( - del_vorticity_del_x_hat * self.dealiasing_mask, - s=(self.num_points, self.num_points), - ) - del_vorticity_del_y = jnp.fft.irfft2( - del_vorticity_del_y_hat * self.dealiasing_mask, - s=(self.num_points, self.num_points), - ) + u = self.ifft(self.dealias(u_hat)) + v = self.ifft(self.dealias(v_hat)) + del_vorticity_del_x = self.ifft(self.dealias(del_vorticity_del_x_hat)) + del_vorticity_del_y = self.ifft(self.dealias(del_vorticity_del_y_hat)) convection = u * del_vorticity_del_x + v * del_vorticity_del_y - convection_hat = jnp.fft.rfft2(convection) - - # Do we need another dealiasing mask here? - # convection_hat = self.dealiasing_mask * convection_hat + convection_hat = self.fft(convection) - # Requires minus to move term to the rhs return -self.convection_scale * convection_hat + # def evaluate( + # self, u_hat: Complex[Array, "1 ... (N//2)+1"] + # ) -> Complex[Array, "1 ... (N//2)+1"]: + # vorticity_hat = u_hat + # stream_function_hat = self.inv_laplacian * vorticity_hat + + # u_hat = +self.derivative_operator[1:2] * stream_function_hat + # v_hat = -self.derivative_operator[0:1] * stream_function_hat + # del_vorticity_del_x_hat = self.derivative_operator[0:1] * vorticity_hat + # del_vorticity_del_y_hat = self.derivative_operator[1:2] * vorticity_hat + + # u = jnp.fft.irfft2( + # u_hat * self.dealiasing_mask, s=(self.num_points, self.num_points) + # ) + # v = jnp.fft.irfft2( + # v_hat * self.dealiasing_mask, s=(self.num_points, self.num_points) + # ) + # del_vorticity_del_x = jnp.fft.irfft2( + # del_vorticity_del_x_hat * self.dealiasing_mask, + # s=(self.num_points, self.num_points), + # ) + # del_vorticity_del_y = jnp.fft.irfft2( + # del_vorticity_del_y_hat * self.dealiasing_mask, + # s=(self.num_points, self.num_points), + # ) + + # convection = u * del_vorticity_del_x + v * del_vorticity_del_y + + # convection_hat = jnp.fft.rfft2(convection) + + # # Do we need another dealiasing mask here? + # # convection_hat = self.dealiasing_mask * convection_hat + + # # Requires minus to move term to the rhs + # return -self.convection_scale * convection_hat + class VorticityConvection2dKolmogorov(VorticityConvection2d): injection: Complex[Array, "1 ... (N//2)+1"] @@ -84,7 +101,6 @@ def __init__( self, num_spatial_dims: int, num_points: int, - num_channels: int, *, convection_scale: float = 1.0, injection_mode: int = 4, @@ -95,7 +111,6 @@ def __init__( super().__init__( num_spatial_dims, num_points, - num_channels, convection_scale=convection_scale, derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, @@ -109,8 +124,8 @@ def __init__( 0.0, ) - def evaluate( + def __call__( self, u_hat: Complex[Array, "1 ... (N//2)+1"] ) -> Complex[Array, "1 ... (N//2)+1"]: - neg_convection_hat = super().evaluate(u_hat) + neg_convection_hat = super()(u_hat) return neg_convection_hat + self.injection diff --git a/exponax/nonlin_fun/_zero.py b/exponax/nonlin_fun/_zero.py index aea491a..322d096 100644 --- a/exponax/nonlin_fun/_zero.py +++ b/exponax/nonlin_fun/_zero.py @@ -9,20 +9,13 @@ def __init__( self, num_spatial_dims: int, num_points: int, - num_channels: int, - *, - derivative_operator: Complex[Array, "D ... (N//2)+1"], - dealiasing_fraction: float = 1.0, ): super().__init__( num_spatial_dims, num_points, - num_channels, - derivative_operator=derivative_operator, - dealiasing_fraction=dealiasing_fraction, ) - def evaluate( + def __call__( self, u_hat: Complex[Array, "C ... (N//2)+1"], ) -> Complex[Array, "C ... (N//2)+1"]: diff --git a/exponax/normalized/_convection.py b/exponax/normalized/_convection.py index 3c511fd..f867e3f 100644 --- a/exponax/normalized/_convection.py +++ b/exponax/normalized/_convection.py @@ -58,9 +58,8 @@ def _build_linear_operator(self, derivative_operator: Array) -> Array: def _build_nonlinear_fun(self, derivative_operator: Array): return ConvectionNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, scale=self.normalized_convection_scale, diff --git a/exponax/normalized/_general_nonlinear.py b/exponax/normalized/_general_nonlinear.py index d5233ec..55cd45f 100644 --- a/exponax/normalized/_general_nonlinear.py +++ b/exponax/normalized/_general_nonlinear.py @@ -2,7 +2,7 @@ from jaxtyping import Array, Complex from .._base_stepper import BaseStepper -from ..nonlin_fun import GeneralNonlinearFun1d +from ..nonlin_fun import GeneralNonlinearFun class NormlizedGeneralNonlinearStepper1d(BaseStepper): @@ -26,10 +26,6 @@ def __init__( By default Burgers. """ - if num_spatial_dims != 1: - raise ValueError( - "The number of spatial dimensions must be 1 because of ambiguity in channel growth" - ) if len(normalized_coefficients_nonlinear) != 3: raise ValueError( "The nonlinear coefficients list must have exactly 3 elements" @@ -63,11 +59,10 @@ def _build_linear_operator(self, derivative_operator: Array) -> Array: def _build_nonlinear_fun( self, derivative_operator: Complex[Array, "D ... (N//2)+1"], - ) -> GeneralNonlinearFun1d: - return GeneralNonlinearFun1d( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + ) -> GeneralNonlinearFun: + return GeneralNonlinearFun( + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, scale_list=self.normalized_coefficients_nonlinear, diff --git a/exponax/normalized/_gradient_norm.py b/exponax/normalized/_gradient_norm.py index d414102..2fc4461 100644 --- a/exponax/normalized/_gradient_norm.py +++ b/exponax/normalized/_gradient_norm.py @@ -64,7 +64,6 @@ def _build_nonlinear_fun(self, derivative_operator: Array): return GradientNormNonlinearFun( num_spatial_dims=self.num_spatial_dims, num_points=self.num_points, - num_channels=self.num_channels, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, scale=self.normalized_gradient_norm_scale, diff --git a/exponax/normalized/_linear.py b/exponax/normalized/_linear.py index f33d285..7ccebe9 100644 --- a/exponax/normalized/_linear.py +++ b/exponax/normalized/_linear.py @@ -49,6 +49,4 @@ def _build_nonlinear_fun(self, derivative_operator: Array): return ZeroNonlinearFun( num_spatial_dims=self.num_spatial_dims, num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, ) diff --git a/exponax/normalized/_polynomial.py b/exponax/normalized/_polynomial.py index 3da08dd..97934da 100644 --- a/exponax/normalized/_polynomial.py +++ b/exponax/normalized/_polynomial.py @@ -67,10 +67,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> PolynomialNonlinearFun: return PolynomialNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, + self.num_spatial_dims, + self.num_points, coefficients=self.normalized_polynomial_scales, dealiasing_fraction=self.dealiasing_fraction, ) diff --git a/exponax/normalized/_vorticity_convection.py b/exponax/normalized/_vorticity_convection.py index 962551f..43a99cb 100644 --- a/exponax/normalized/_vorticity_convection.py +++ b/exponax/normalized/_vorticity_convection.py @@ -70,18 +70,16 @@ def _build_nonlinear_fun( ) -> VorticityConvection2d: if self.normalized_injection_scale == 0.0: return VorticityConvection2d( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, convection_scale=self.normalized_vorticity_convection_scale, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, ) else: return VorticityConvection2dKolmogorov( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, convection_scale=self.normalized_vorticity_convection_scale, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, diff --git a/exponax/reaction/_allen_cahn.py b/exponax/reaction/_allen_cahn.py index 8cf5ad2..b8922c1 100644 --- a/exponax/reaction/_allen_cahn.py +++ b/exponax/reaction/_allen_cahn.py @@ -46,10 +46,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> PolynomialNonlinearFun: return PolynomialNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, + self.num_spatial_dims, + self.num_points, dealiasing_fraction=self.dealiasing_fraction, coefficients=[0.0, 0.0, 0.0, -1.0], ) diff --git a/exponax/reaction/_belousov_zhabotinsky.py b/exponax/reaction/_belousov_zhabotinsky.py index dc4ab79..835722f 100644 --- a/exponax/reaction/_belousov_zhabotinsky.py +++ b/exponax/reaction/_belousov_zhabotinsky.py @@ -2,7 +2,7 @@ from jaxtyping import Array, Complex from .._base_stepper import BaseStepper -from .._spectral import build_laplace_operator, space_indices, spatial_shape +from .._spectral import build_laplace_operator from ..nonlin_fun import BaseNonlinearFun @@ -15,31 +15,23 @@ def __init__( self, num_spatial_dims: int, num_points: int, - num_channels: int, *, - derivative_operator: Complex[Array, "D ... (N//2)+1"], dealiasing_fraction: float, ): - if num_channels != 3: - raise ValueError(f"Expected num_channels = 3, got {num_channels}.") super().__init__( num_spatial_dims, num_points, - num_channels, - derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, ) - def evaluate( + def __call__( self, u_hat: Complex[Array, "C ... (N//2)+1"], ) -> Complex[Array, "C ... (N//2)+1"]: - u_hat_dealiased = self.dealiasing_mask * u_hat - u = jnp.fft.irfftn( - u_hat_dealiased, - s=spatial_shape(self.num_spatial_dims, self.num_points), - axes=space_indices(self.num_spatial_dims), - ) + num_channels = u_hat.shape[0] + if num_channels != 3: + raise ValueError("num_channels must be 3") + u = self.ifft(self.dealias(u_hat)) u_power = jnp.stack( [ u[0] + u[1] - u[0] * u[1] - u[0] ** 2, @@ -47,7 +39,7 @@ def evaluate( u[0] - u[2], ] ) - u_power_hat = jnp.fft.rfftn(u_power, axes=space_indices(self.num_spatial_dims)) + u_power_hat = self.fft(u_power) return u_power_hat @@ -101,9 +93,7 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> BelousovZhabotinskyNonlinearFun: return BelousovZhabotinskyNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, + self.num_spatial_dims, + self.num_points, dealiasing_fraction=self.dealiasing_fraction, ) diff --git a/exponax/reaction/_cahn_hilliard.py b/exponax/reaction/_cahn_hilliard.py index 3104f6a..544f3d2 100644 --- a/exponax/reaction/_cahn_hilliard.py +++ b/exponax/reaction/_cahn_hilliard.py @@ -1,46 +1,39 @@ -import jax.numpy as jnp from jaxtyping import Array, Complex from .._base_stepper import BaseStepper -from .._spectral import build_laplace_operator, space_indices, spatial_shape +from .._spectral import build_laplace_operator from ..nonlin_fun import BaseNonlinearFun class CahnHilliardNonlinearFun(BaseNonlinearFun): + laplace_operator: Complex[Array, "1 ... (N//2)+1"] + def __init__( self, num_spatial_dims: int, num_points: int, - num_channels: int, *, derivative_operator: Complex[Array, "D ... (N//2)+1"], dealiasing_fraction: float, ): - if num_channels != 1: - raise ValueError(f"Expected num_channels = 1, got {num_channels}.") super().__init__( num_spatial_dims, num_points, - num_channels, - derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, ) + self.laplace_operator = build_laplace_operator(derivative_operator) - def evaluate( + def __call__( self, u_hat: Complex[Array, "C ... (N//2)+1"], ) -> Complex[Array, "C ... (N//2)+1"]: - u_hat_dealiased = self.dealiasing_mask * u_hat - u = jnp.fft.irfftn( - u_hat_dealiased, - s=spatial_shape(self.num_spatial_dims, self.num_points), - axes=space_indices(self.num_spatial_dims), - ) + u = self.ifft(self.dealias(u_hat)) u_power = u[0] ** 3 - u_power_hat = jnp.fft.rfftn(u_power, axes=space_indices(self.num_spatial_dims)) + u_power_hat = self.fft(u_power) u_power_laplace_hat = ( build_laplace_operator(self.derivative_operator, order=2) * u_power_hat ) + u_power_laplace_hat = self.laplace_operator * u_power_hat return u_power_laplace_hat @@ -89,9 +82,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> CahnHilliardNonlinearFun: return CahnHilliardNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, ) diff --git a/exponax/reaction/_fisher_kpp.py b/exponax/reaction/_fisher_kpp.py index 76b65fb..9c5d136 100644 --- a/exponax/reaction/_fisher_kpp.py +++ b/exponax/reaction/_fisher_kpp.py @@ -48,10 +48,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> PolynomialNonlinearFun: return PolynomialNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, + self.num_spatial_dims, + self.num_points, dealiasing_fraction=self.dealiasing_fraction, coefficients=[0.0, 0.0, -self.r], ) diff --git a/exponax/reaction/_gray_scott.py b/exponax/reaction/_gray_scott.py index 1f2ac74..547328d 100644 --- a/exponax/reaction/_gray_scott.py +++ b/exponax/reaction/_gray_scott.py @@ -2,7 +2,7 @@ from jaxtyping import Array, Complex from .._base_stepper import BaseStepper -from .._spectral import build_laplace_operator, space_indices, spatial_shape +from .._spectral import build_laplace_operator from ..nonlin_fun import BaseNonlinearFun @@ -14,42 +14,34 @@ def __init__( self, num_spatial_dims: int, num_points: int, - num_channels: int, *, - derivative_operator: Complex[Array, "D ... (N//2)+1"], dealiasing_fraction: float, b: float, d: float, ): - if num_channels != 2: - raise ValueError(f"Expected num_channels = 2, got {num_channels}.") super().__init__( num_spatial_dims, num_points, - num_channels, - derivative_operator=derivative_operator, dealiasing_fraction=dealiasing_fraction, ) self.b = b self.d = d - def evaluate( + def __call__( self, u_hat: Complex[Array, "C ... (N//2)+1"], ) -> Complex[Array, "C ... (N//2)+1"]: - u_hat_dealiased = self.dealiasing_mask * u_hat - u = jnp.fft.irfftn( - u_hat_dealiased, - s=spatial_shape(self.num_spatial_dims, self.num_points), - axes=space_indices(self.num_spatial_dims), - ) + num_channels = u_hat.shape[0] + if num_channels != 2: + raise ValueError("num_channels must be 2") + u = self.ifft(self.dealias(u_hat)) u_power = jnp.stack( [ self.b * (1 - u[0]) - u[0] * u[1] ** 2, -self.d * u[1] + u[0] * u[1] ** 2, ] ) - u_power_hat = jnp.fft.rfftn(u_power, axes=space_indices(self.num_spatial_dims)) + u_power_hat = self.fft(u_power) return u_power_hat @@ -111,10 +103,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> GrayScottNonlinearFun: return GrayScottNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, + self.num_spatial_dims, + self.num_points, b=self.b, d=self.d, dealiasing_fraction=self.dealiasing_fraction, diff --git a/exponax/reaction/_swift_hohenberg.py b/exponax/reaction/_swift_hohenberg.py index 9a1d009..e36c9ed 100644 --- a/exponax/reaction/_swift_hohenberg.py +++ b/exponax/reaction/_swift_hohenberg.py @@ -52,10 +52,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> PolynomialNonlinearFun: return PolynomialNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, + self.num_spatial_dims, + self.num_points, dealiasing_fraction=self.dealiasing_fraction, coefficients=[0.0, 0.0, self.g, -1.0], ) diff --git a/exponax/stepper/_burgers.py b/exponax/stepper/_burgers.py index 2341feb..76ff4e6 100644 --- a/exponax/stepper/_burgers.py +++ b/exponax/stepper/_burgers.py @@ -88,9 +88,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> ConvectionNonlinearFun: return ConvectionNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, scale=self.convection_scale, diff --git a/exponax/stepper/_convection.py b/exponax/stepper/_convection.py index a7fad7b..222fb6b 100644 --- a/exponax/stepper/_convection.py +++ b/exponax/stepper/_convection.py @@ -63,11 +63,9 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> ConvectionNonlinearFun: return ConvectionNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, scale=self.convection_scale, - zero_mode_fix=False, # Todo: check this ) diff --git a/exponax/stepper/_general_nonlinear.py b/exponax/stepper/_general_nonlinear.py index 1c24211..0c3a3e3 100644 --- a/exponax/stepper/_general_nonlinear.py +++ b/exponax/stepper/_general_nonlinear.py @@ -2,7 +2,7 @@ from jaxtyping import Array, Complex from .._base_stepper import BaseStepper -from ..nonlin_fun import GeneralNonlinearFun1d +from ..nonlin_fun import GeneralNonlinearFun class GeneralNonlinearStepper1d(BaseStepper): @@ -27,10 +27,6 @@ def __init__( """ By default Burgers equation """ - if num_spatial_dims != 1: - raise ValueError( - "The number of spatial dimensions must be 1 because of ambiguity in channel growth" - ) if len(coefficients_nonlinear) != 3: raise ValueError( "The nonlinear coefficients list must have exactly 3 elements" @@ -67,11 +63,10 @@ def _build_linear_operator( def _build_nonlinear_fun( self, derivative_operator: Complex[Array, "D ... (N//2)+1"], - ) -> GeneralNonlinearFun1d: - return GeneralNonlinearFun1d( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + ) -> GeneralNonlinearFun: + return GeneralNonlinearFun( + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, scale_list=self.coefficients_nonlinear, diff --git a/exponax/stepper/_gradient_norm.py b/exponax/stepper/_gradient_norm.py index 1afe83b..22b2d35 100644 --- a/exponax/stepper/_gradient_norm.py +++ b/exponax/stepper/_gradient_norm.py @@ -63,9 +63,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> GradientNormNonlinearFun: return GradientNormNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, scale=self.gradient_norm_scale, diff --git a/exponax/stepper/_korteweg_de_vries.py b/exponax/stepper/_korteweg_de_vries.py index 6b78979..2cdb0af 100644 --- a/exponax/stepper/_korteweg_de_vries.py +++ b/exponax/stepper/_korteweg_de_vries.py @@ -77,9 +77,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> ConvectionNonlinearFun: return ConvectionNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, scale=self.convection_scale, diff --git a/exponax/stepper/_kuramoto_sivashinsky.py b/exponax/stepper/_kuramoto_sivashinsky.py index 2207701..8a39b05 100644 --- a/exponax/stepper/_kuramoto_sivashinsky.py +++ b/exponax/stepper/_kuramoto_sivashinsky.py @@ -63,9 +63,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> GradientNormNonlinearFun: return GradientNormNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, zero_mode_fix=True, @@ -130,11 +129,9 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> ConvectionNonlinearFun: return ConvectionNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, - zero_mode_fix=True, scale=self.convection_scale, ) diff --git a/exponax/stepper/_linear.py b/exponax/stepper/_linear.py index 48f86de..ef870e1 100644 --- a/exponax/stepper/_linear.py +++ b/exponax/stepper/_linear.py @@ -97,11 +97,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> ZeroNonlinearFun: return ZeroNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, - dealiasing_fraction=1.0, + self.num_spatial_dims, + self.num_points, ) @@ -215,11 +212,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> ZeroNonlinearFun: return ZeroNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, - dealiasing_fraction=1.0, + self.num_spatial_dims, + self.num_points, ) @@ -347,11 +341,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> ZeroNonlinearFun: return ZeroNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, - dealiasing_fraction=1.0, + self.num_spatial_dims, + self.num_points, ) @@ -464,11 +455,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> ZeroNonlinearFun: return ZeroNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, - dealiasing_fraction=1.0, + self.num_spatial_dims, + self.num_points, ) @@ -579,11 +567,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> ZeroNonlinearFun: return ZeroNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, - dealiasing_fraction=1.0, + self.num_spatial_dims, + self.num_points, ) @@ -743,9 +728,6 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> ZeroNonlinearFun: return ZeroNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, - dealiasing_fraction=1.0, + self.num_spatial_dims, + self.num_points, ) diff --git a/exponax/stepper/_navier_stokes.py b/exponax/stepper/_navier_stokes.py index 2deb0ac..2ad58d0 100644 --- a/exponax/stepper/_navier_stokes.py +++ b/exponax/stepper/_navier_stokes.py @@ -26,6 +26,9 @@ def __init__( num_circle_points: int = 16, circle_radius: float = 1.0, ): + if num_spatial_dims != 2: + raise ValueError(f"Expected num_spatial_dims = 2, got {num_spatial_dims}.") + self.diffusivity = diffusivity self.vorticity_convection_scale = vorticity_convection_scale self.drag = drag @@ -54,9 +57,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> VorticityConvection2d: return VorticityConvection2d( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, convection_scale=self.vorticity_convection_scale, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, @@ -120,9 +122,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> VorticityConvection2dKolmogorov: return VorticityConvection2dKolmogorov( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, convection_scale=self.convection_scale, injection_mode=self.injection_mode, injection_scale=self.injection_scale, diff --git a/exponax/stepper/_nikolaevskiy.py b/exponax/stepper/_nikolaevskiy.py index 1b02371..3b4dd49 100644 --- a/exponax/stepper/_nikolaevskiy.py +++ b/exponax/stepper/_nikolaevskiy.py @@ -63,9 +63,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> GradientNormNonlinearFun: return GradientNormNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, zero_mode_fix=True, @@ -131,11 +130,9 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> ConvectionNonlinearFun: return ConvectionNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, + self.num_spatial_dims, + self.num_points, derivative_operator=derivative_operator, dealiasing_fraction=self.dealiasing_fraction, - zero_mode_fix=True, scale=self.convection_scale, ) diff --git a/exponax/stepper/_polynomial.py b/exponax/stepper/_polynomial.py index d65230e..48df604 100644 --- a/exponax/stepper/_polynomial.py +++ b/exponax/stepper/_polynomial.py @@ -65,10 +65,8 @@ def _build_nonlinear_fun( derivative_operator: Complex[Array, "D ... (N//2)+1"], ) -> PolynomialNonlinearFun: return PolynomialNonlinearFun( - num_spatial_dims=self.num_spatial_dims, - num_points=self.num_points, - num_channels=self.num_channels, - derivative_operator=derivative_operator, + self.num_spatial_dims, + self.num_points, dealiasing_fraction=self.dealiasing_fraction, coefficients=self.polynomial_scales, )