Skip to content

Commit

Permalink
Move initial conditions into submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Mar 1, 2024
1 parent 539a297 commit f8374a3
Show file tree
Hide file tree
Showing 11 changed files with 347 additions and 488 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ ks_stepper = ex.stepper.KuramotoSivashinskyConservative(
num_points=200, dt=0.1,
)

u_0 = ex.RandomTruncatedFourierSeries(
u_0 = ex.ic.RandomTruncatedFourierSeries(
num_spatial_dims=1, cutoff=5
)(num_points=200, key=jax.random.PRNGKey(0))

Expand Down
15 changes: 2 additions & 13 deletions exponax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
from . import metrics, nonlin_fun, normalized, poisson, stepper
from . import ic, metrics, nonlin_fun, normalized, poisson, stepper
from .forced_stepper import ForcedStepper
from .initial_conditions import (
DiffusedNoise,
GaussianRandomField,
MultiChannelIC,
RandomMultiChannelICGenerator,
RandomTruncatedFourierSeries,
)
from .repeated_stepper import RepeatedStepper
from .spectral import derivative
from .utils import (
Expand All @@ -22,11 +15,6 @@

__all__ = [
"ForcedStepper",
"DiffusedNoise",
"GaussianRandomField",
"MultiChannelIC",
"RandomMultiChannelICGenerator",
"RandomTruncatedFourierSeries",
"normalized",
"poisson",
"RepeatedStepper",
Expand All @@ -42,4 +30,5 @@
"wrap_bc",
"metrics",
"nonlin_fun",
"ic",
]
15 changes: 15 additions & 0 deletions exponax/ic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .base_ic import BaseIC, BaseRandomICGenerator
from .diffused_noise import DiffusedNoise
from .gaussian_random_field import GaussianRandomField
from .multi_channel import MultiChannelIC, RandomMultiChannelICGenerator
from .truncated_fourier_series import RandomTruncatedFourierSeries

__all__ = [
"BaseIC",
"BaseRandomICGenerator",
"DiffusedNoise",
"GaussianRandomField",
"MultiChannelIC",
"RandomMultiChannelICGenerator",
"RandomTruncatedFourierSeries",
]
69 changes: 69 additions & 0 deletions exponax/ic/base_ic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from abc import ABC, abstractmethod

import equinox as eqx
from jaxtyping import Array, Float, PRNGKeyArray

from ..utils import get_grid


class BaseIC(eqx.Module, ABC):
@abstractmethod
def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]:
"""
Evaluate the initial condition.
**Arguments**:
- `x`: The grid points.
**Returns**:
- `u`: The initial condition evaluated at the grid points.
"""
pass


class BaseRandomICGenerator(eqx.Module):
num_spatial_dims: int
domain_extent: float
indexing: str = "ij"

def gen_ic_fun(self, num_points: int, *, key: PRNGKeyArray) -> BaseIC:
"""
Generate an initial condition function.
**Arguments**:
- `num_points`: The number of grid points in each dimension.
- `key`: A jax random key.
**Returns**:
- `ic`: An initial condition function that can be evaluated at
degree of freedom locations.
"""
raise NotImplementedError(
"This random ic generator cannot represent its initial condition as a function. Directly evaluate it."
)

def __call__(
self,
num_points: int,
*,
key: PRNGKeyArray,
) -> Float[Array, "1 ... N"]:
"""
Generate a random initial condition.
**Arguments**:
- `num_points`: The number of grid points in each dimension.
- `key`: A jax random key.
- `indexing`: The indexing convention for the grid.
**Returns**:
- `u`: The initial condition evaluated at the grid points.
"""
ic_fun = self.gen_ic_fun(num_points, key=key)
grid = get_grid(
self.num_spatial_dims,
self.domain_extent,
num_points,
indexing=self.indexing,
)
return ic_fun(grid)
59 changes: 59 additions & 0 deletions exponax/ic/diffused_noise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray

from ..spectral import spatial_shape
from ..stepper import Diffusion
from .base_ic import BaseRandomICGenerator


class DiffusedNoise(BaseRandomICGenerator):
num_spatial_dims: int
domain_extent: float
intensity: float
zero_mean: bool

def __init__(
self,
num_spatial_dims: int,
domain_extent: float = 1.0,
*,
intensity=0.001,
zero_mean: bool = False,
):
"""
Randomly generated initial condition consisting of a diffused noise field.
Arguments are drawn from uniform distributions.
**Arguments**:
- `D`: The dimension of the domain.
- `L`: The length of the domain.
- `N`: The number of grid points in each dimension.
- `intensity`: The diffusivity.
- `zero_mean`: Whether to subtract the mean.
"""
self.num_spatial_dims = num_spatial_dims
self.domain_extent = domain_extent
self.intensity = intensity
self.zero_mean = zero_mean

def __call__(
self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]:
noise_shape = (1,) + spatial_shape(self.num_spatial_dims, num_points)
noise = jr.normal(key, shape=noise_shape)

diffusion_stepper = Diffusion(
self.num_spatial_dims,
self.domain_extent,
num_points,
1.0,
diffusivity=self.intensity,
)
ic = diffusion_stepper(noise)

if self.zero_mean:
ic = ic - jnp.mean(ic)

return ic
1 change: 1 addition & 0 deletions exponax/ic/discontinuities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# ToDo
69 changes: 69 additions & 0 deletions exponax/ic/gaussian_random_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray

from ..spectral import (
build_scaled_wavenumbers,
space_indices,
spatial_shape,
wavenumber_shape,
)
from .base_ic import BaseRandomICGenerator


class GaussianRandomField(BaseRandomICGenerator):
num_spatial_dims: int
domain_extent: float
powerlaw_exponent: float
normalize: bool

def __init__(
self,
num_spatial_dims: int,
domain_extent: float = 1.0,
*,
powerlaw_exponent: float = 3.0,
normalize: bool = True,
):
"""
Randomly generated initial condition consisting of a Gaussian random field.
"""
self.num_spatial_dims = num_spatial_dims
self.domain_extent = domain_extent
self.powerlaw_exponent = powerlaw_exponent
self.normalize = normalize

def __call__(
self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]:
wavenumber_grid = build_scaled_wavenumbers(
self.num_spatial_dims, self.domain_extent, num_points
)
wavenumer_norm_grid = jnp.linalg.norm(wavenumber_grid, axis=0, keepdims=True)
amplitude = jnp.power(wavenumer_norm_grid, -self.powerlaw_exponent / 2.0)
amplitude = (
amplitude.flatten().at[0].set(0.0).reshape(wavenumer_norm_grid.shape)
)

real_key, imag_key = jr.split(key, 2)
noise = jr.normal(
real_key,
shape=(1,) + wavenumber_shape(self.num_spatial_dims, num_points),
) + 1j * jr.normal(
imag_key,
shape=(1,) + wavenumber_shape(self.num_spatial_dims, num_points),
)

noise = noise * amplitude

ic = jnp.fft.irfftn(
noise,
s=spatial_shape(self.num_spatial_dims, num_points),
axes=space_indices(self.num_spatial_dims),
)

if self.normalize:
ic = ic - jnp.mean(ic)
ic = ic / jnp.std(ic)

return ic
39 changes: 39 additions & 0 deletions exponax/ic/multi_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import List

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray

from .base_ic import BaseIC, BaseRandomICGenerator


class MultiChannelIC(eqx.Module):
initial_conditions: List[BaseIC]

def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "C ... N"]:
"""
Evaluate the initial condition.
**Arguments**:
- `x`: The grid points.
**Returns**:
- `u`: The initial condition evaluated at the grid points.
"""
return jnp.concatenate([ic(x) for ic in self.initial_conditions], axis=0)


class RandomMultiChannelICGenerator(eqx.Module):
ic_generators: List[BaseRandomICGenerator]

def gen_ic_fun(self, num_points: int, *, key: PRNGKeyArray) -> MultiChannelIC:
ic_funs = [
ic_gen.gen_ic_fun(num_points, key=key) for ic_gen in self.ic_generators
]
return MultiChannelIC(ic_funs)

def __call__(
self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "C ... N"]:
u_list = [ic_gen(num_points, key=key) for ic_gen in self.ic_generators]
return jnp.concatenate(u_list, axis=0)
88 changes: 88 additions & 0 deletions exponax/ic/truncated_fourier_series.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray

from ..spectral import (
build_scaling_array,
low_pass_filter_mask,
space_indices,
spatial_shape,
wavenumber_shape,
)
from .base_ic import BaseRandomICGenerator


class RandomTruncatedFourierSeries(BaseRandomICGenerator):
num_spatial_dims: int
domain_extent: float
cutoff: int
amplitude_range: tuple[int, int]
angle_range: tuple[int, int]
offset_range: tuple[int, int]

def __init__(
self,
num_spatial_dims: int,
domain_extent: float = 1.0,
*,
cutoff: int = 10,
amplitude_range: tuple[int, int] = (-1.0, 1.0),
angle_range: tuple[int, int] = (0.0, 2.0 * jnp.pi),
offset_range: tuple[int, int] = (0.0, 0.0), # no offset by default
):
self.num_spatial_dims = num_spatial_dims
self.domain_extent = domain_extent

self.cutoff = cutoff
self.amplitude_range = amplitude_range
self.angle_range = angle_range
self.offset_range = offset_range

def __call__(
self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]:
fourier_noise_shape = (1,) + wavenumber_shape(self.num_spatial_dims, num_points)
amplitude_key, angle_key, offset_key = jr.split(key, 3)

amplitude = jr.uniform(
amplitude_key,
shape=fourier_noise_shape,
minval=self.amplitude_range[0],
maxval=self.amplitude_range[1],
)
angle = jr.uniform(
angle_key,
shape=fourier_noise_shape,
minval=self.angle_range[0],
maxval=self.angle_range[1],
)

fourier_noise = amplitude * jnp.exp(1j * angle)

low_pass_filter = low_pass_filter_mask(
self.num_spatial_dims, num_points, cutoff=self.cutoff, axis_separate=True
)

fourier_noise = fourier_noise * low_pass_filter

offset = jr.uniform(
offset_key,
shape=(1,),
minval=self.offset_range[0],
maxval=self.offset_range[1],
)[0]
fourier_noise = (
fourier_noise.flatten().at[0].set(offset).reshape(fourier_noise_shape)
)

fourier_noise = fourier_noise * build_scaling_array(
self.num_spatial_dims, num_points
)

u = jnp.fft.irfftn(
fourier_noise,
s=spatial_shape(self.num_spatial_dims, num_points),
axes=space_indices(self.num_spatial_dims),
)

return u
Loading

0 comments on commit f8374a3

Please sign in to comment.