Skip to content

Commit

Permalink
Rewrite derivative routines to use wrapped FFT calls
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed May 16, 2024
1 parent d6bebb9 commit e0cd2a8
Showing 1 changed file with 107 additions and 107 deletions.
214 changes: 107 additions & 107 deletions exponax/_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,113 +71,6 @@ def build_scaled_wavenumbers(
return scale * wavenumbers


def derivative(
field: Float[Array, "C ... N"],
domain_extent: float,
*,
order: int = 1,
indexing: str = "ij",
) -> Union[Float[Array, "C D ... (N//2)+1"], Float[Array, "D ... (N//2)+1"]]:
"""
Perform the spectral derivative of a field. In higher dimensions, this
defaults to the gradient (the collection of all partial derivatives). In 1d,
the resulting channel dimension holds the derivative. If the function is
called with an d-dimensional field which has 1 channel, the result will be a
d-dimensional field with d channels (one per partial derivative). If the
field originally had C channels, the result will be a matrix field with C
rows and d columns.
Note that applying this operator twice will produce issues at the Nyquist if
the number of degrees of freedom N is even. For this, consider also using
the order option.
**Arguments:**
- `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be
`1` for a scalar field or `D` for a vector field.
- `L`: The domain extent.
- `order`: The order of the derivative. Default is `1`.
- `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.
Either `"ij"` or `"xy"`. Default is `"ij"`.
**Returns:**
- `field_der`: The derivative of the field, shape `(C, D, ...,
(N//2)+1)` or `(D, ..., (N//2)+1)`.
"""
channel_shape = field.shape[0]
spatial_shape = field.shape[1:]
D = len(spatial_shape)
N = spatial_shape[0]
derivative_operator = build_derivative_operator(
D, domain_extent, N, indexing=indexing
)
# # I decided to not use this fix

# # Required for even N, no effect for odd N
# derivative_operator_fixed = (
# derivative_operator * nyquist_filter_mask(D, N)
# )
derivative_operator_fixed = derivative_operator**order

field_hat = jnp.fft.rfftn(field, axes=space_indices(D))
if channel_shape == 1:
# Do not introduce another channel axis
field_der_hat = derivative_operator_fixed * field_hat
else:
# Create a "derivative axis" right after the channel axis
field_der_hat = field_hat[:, None] * derivative_operator_fixed[None, ...]

field_der = jnp.fft.irfftn(field_der_hat, s=spatial_shape, axes=space_indices(D))

return field_der


def make_incompressible(
field: Float[Array, "D ... N"],
*,
indexing: str = "ij",
):
channel_shape = field.shape[0]
spatial_shape = field.shape[1:]
num_spatial_dims = len(spatial_shape)
if channel_shape != num_spatial_dims:
raise ValueError(
f"Expected the number of channels to be {num_spatial_dims}, got {channel_shape}."
)
num_points = spatial_shape[0]

derivative_operator = build_derivative_operator(
num_spatial_dims, 1.0, num_points, indexing=indexing
) # domain_extent does not matter because it will cancel out

incompressible_field_hat = jnp.fft.rfftn(
field, axes=space_indices(num_spatial_dims)
)

divergence = jnp.sum(
derivative_operator * incompressible_field_hat, axis=0, keepdims=True
)

laplace_operator = build_laplace_operator(derivative_operator)

inv_laplace_operator = jnp.where(
laplace_operator == 0,
1.0,
1.0 / laplace_operator,
)

pseudo_pressure = -inv_laplace_operator * divergence

pseudo_pressure_garadient = derivative_operator * pseudo_pressure

incompressible_field_hat = incompressible_field_hat - pseudo_pressure_garadient

incompressible_field = jnp.fft.irfftn(
incompressible_field_hat, s=spatial_shape, axes=space_indices(num_spatial_dims)
)

return incompressible_field


def build_derivative_operator(
num_spatial_dims: int,
domain_extent: float,
Expand Down Expand Up @@ -535,3 +428,110 @@ def ifft(
s=spatial_shape(num_spatial_dims, num_points),
axes=space_indices(num_spatial_dims),
)


def derivative(
field: Float[Array, "C ... N"],
domain_extent: float,
*,
order: int = 1,
indexing: str = "ij",
) -> Union[Float[Array, "C D ... (N//2)+1"], Float[Array, "D ... (N//2)+1"]]:
"""
Perform the spectral derivative of a field. In higher dimensions, this
defaults to the gradient (the collection of all partial derivatives). In 1d,
the resulting channel dimension holds the derivative. If the function is
called with an d-dimensional field which has 1 channel, the result will be a
d-dimensional field with d channels (one per partial derivative). If the
field originally had C channels, the result will be a matrix field with C
rows and d columns.
Note that applying this operator twice will produce issues at the Nyquist if
the number of degrees of freedom N is even. For this, consider also using
the order option.
**Arguments:**
- `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be
`1` for a scalar field or `D` for a vector field.
- `L`: The domain extent.
- `order`: The order of the derivative. Default is `1`.
- `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.
Either `"ij"` or `"xy"`. Default is `"ij"`.
**Returns:**
- `field_der`: The derivative of the field, shape `(C, D, ...,
(N//2)+1)` or `(D, ..., (N//2)+1)`.
"""
channel_shape = field.shape[0]
spatial_shape = field.shape[1:]
D = len(spatial_shape)
N = spatial_shape[0]
derivative_operator = build_derivative_operator(
D, domain_extent, N, indexing=indexing
)
# # I decided to not use this fix

# # Required for even N, no effect for odd N
# derivative_operator_fixed = (
# derivative_operator * nyquist_filter_mask(D, N)
# )
derivative_operator_fixed = derivative_operator**order

field_hat = fft(field, num_spatial_dims=D)
if channel_shape == 1:
# Do not introduce another channel axis
field_der_hat = derivative_operator_fixed * field_hat
else:
# Create a "derivative axis" right after the channel axis
field_der_hat = field_hat[:, None] * derivative_operator_fixed[None, ...]

field_der = ifft(field_der_hat, num_spatial_dims=D, num_points=N)

return field_der


def make_incompressible(
field: Float[Array, "D ... N"],
*,
indexing: str = "ij",
):
channel_shape = field.shape[0]
spatial_shape = field.shape[1:]
num_spatial_dims = len(spatial_shape)
if channel_shape != num_spatial_dims:
raise ValueError(
f"Expected the number of channels to be {num_spatial_dims}, got {channel_shape}."
)
num_points = spatial_shape[0]

derivative_operator = build_derivative_operator(
num_spatial_dims, 1.0, num_points, indexing=indexing
) # domain_extent does not matter because it will cancel out

incompressible_field_hat = fft(field, num_spatial_dims=num_spatial_dims)

divergence = jnp.sum(
derivative_operator * incompressible_field_hat, axis=0, keepdims=True
)

laplace_operator = build_laplace_operator(derivative_operator)

inv_laplace_operator = jnp.where(
laplace_operator == 0,
1.0,
1.0 / laplace_operator,
)

pseudo_pressure = -inv_laplace_operator * divergence

pseudo_pressure_garadient = derivative_operator * pseudo_pressure

incompressible_field_hat = incompressible_field_hat - pseudo_pressure_garadient

incompressible_field = ifft(
incompressible_field_hat,
num_spatial_dims=num_spatial_dims,
num_points=num_points,
)

return incompressible_field

0 comments on commit e0cd2a8

Please sign in to comment.