Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistent docs #26

Merged
merged 26 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 38 additions & 14 deletions exponax/_base_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,14 @@ def _build_linear_operator(
Assemble the L operator in Fourier space.

**Arguments:**
- `derivative_operator`: The derivative operator, shape `( D, ...,
N//2+1 )`. The ellipsis are (D-1) axis of size N (**not** of size
N//2+1).

- `derivative_operator`: The derivative operator, shape `( D, ...,
N//2+1 )`. The ellipsis are (D-1) axis of size N (**not** of size
N//2+1).

**Returns:**
- `L`: The linear operator, shape `( C, ..., N//2+1 )`.

- `L`: The linear operator, shape `( C, ..., N//2+1 )`.
"""
pass

Expand All @@ -183,12 +185,15 @@ def _build_nonlinear_fun(
transforms to Fourier space, and evaluates derivatives there.

**Arguments:**
- `derivative_operator`: The derivative operator, shape `( D, ..., N//2+1 )`.

- `derivative_operator`: The derivative operator, shape `( D, ...,
N//2+1 )`.

**Returns:**
- `nonlinear_fun`: A function that evaluates the nonlinearities in
time space, transforms to Fourier space, and evaluates the
derivatives there. Should be a subclass of `BaseNonlinearFun`.

- `nonlinear_fun`: A function that evaluates the nonlinearities in
time space, transforms to Fourier space, and evaluates the
derivatives there. Should be a subclass of `BaseNonlinearFun`.
"""
pass

Expand All @@ -197,10 +202,12 @@ def step(self, u: Float[Array, "C ... N"]) -> Float[Array, "C ... N"]:
Perform one step of the time integration.

**Arguments:**
- `u`: The state vector, shape `(C, ..., N,)`.

- `u`: The state vector, shape `(C, ..., N,)`.

**Returns:**
- `u_next`: The state vector after one step, shape `(C, ..., N,)`.

- `u_next`: The state vector after one step, shape `(C, ..., N,)`.
"""
u_hat = fft(u, num_spatial_dims=self.num_spatial_dims)
u_next_hat = self.step_fourier(u_hat)
Expand All @@ -220,11 +227,13 @@ def step_fourier(
transforms.

**Arguments:**
- `u_hat`: The (real) Fourier transform of the state vector

- `u_hat`: The (real) Fourier transform of the state vector

**Returns:**
- `u_next_hat`: The (real) Fourier transform of the state vector
after one step

- `u_next_hat`: The (real) Fourier transform of the state vector
after one step
"""
return self._integrator.step_fourier(u_hat)

Expand All @@ -233,7 +242,22 @@ def __call__(
u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
"""
Performs a check
Perform one step of the time integration for a single state.

**Arguments:**

- `u`: The state vector, shape `(C, ..., N,)`.

**Returns:**

- `u_next`: The state vector after one step, shape `(C, ..., N,)`.

!!! tip
Use this call method together with `exponax.rollout` to efficiently
produce temporal trajectories.

!!! info
For batched operation, use `jax.vmap` on this function.
"""
expected_shape = (self.num_channels,) + spatial_shape(
self.num_spatial_dims, self.num_points
Expand Down
32 changes: 19 additions & 13 deletions exponax/_forced_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(
transient integrators to forced problems.

**Arguments**:
- `stepper`: The stepper to be transformed.

- `stepper`: The stepper to be transformed.
"""
self.stepper = stepper

Expand All @@ -49,11 +50,13 @@ def step(
The forcing term `f` is assumed to be evaluated on the same grid as `u`.

**Arguments**:
- `u`: The current state.
- `f`: The forcing term.

- `u`: The current state.
- `f`: The forcing term.

**Returns**:
- `u_next`: The state after one time step.

- `u_next`: The state after one time step.
"""
u_with_force = u + self.stepper.dt * f
return self.stepper.step(u_with_force)
Expand All @@ -71,11 +74,13 @@ def step_fourier(
`u_hat`.

**Arguments**:
- `u_hat`: The current state in Fourier space.
- `f_hat`: The forcing term in Fourier space.

- `u_hat`: The current state in Fourier space.
- `f_hat`: The forcing term in Fourier space.

**Returns**:
- `u_next_hat`: The state after one time step in Fourier space.

- `u_next_hat`: The state after one time step in Fourier space.
"""
u_hat_with_force = u_hat + self.stepper.dt * f_hat
return self.stepper.step_fourier(u_hat_with_force)
Expand All @@ -91,12 +96,13 @@ def __call__(

The forcing term `f` is assumed to be evaluated on the same grid as `u`.

**Arguments**:
- `u`: The current state.
- `f`: The forcing term.
**Arguments:**

**Returns**:
- `u_next`: The state after one time step.
"""
- `u`: The current state.
- `f`: The forcing term.

**Returns:**

- `u_next`: The state after one time step.
"""
return self.step(u, f)
36 changes: 26 additions & 10 deletions exponax/_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ def __init__(
It is included for completion.

**Arguments:**
- `num_spatial_dims`: The number of spatial dimensions.
- `domain_extent`: The extent of the domain.
- `num_points`: The number of points in each spatial dimension.
- `order`: The order of the Poisson equation. Defaults to 2. You can
also set `order=4` for the biharmonic equation.

- `num_spatial_dims`: The number of spatial dimensions.
- `domain_extent`: The extent of the domain.
- `num_points`: The number of points in each spatial dimension.
- `order`: The order of the Poisson equation. Defaults to 2. You can
also set `order=4` for the biharmonic equation.
"""
self.num_spatial_dims = num_spatial_dims
self.domain_extent = domain_extent
Expand All @@ -71,10 +72,12 @@ def step_fourier(
Solve the Poisson equation in Fourier space.

**Arguments:**
- `f_hat`: The Fourier transform of the right hand side.

- `f_hat`: The Fourier transform of the right hand side.

**Returns:**
- `u_hat`: The Fourier transform of the solution.

- `u_hat`: The Fourier transform of the solution.
"""
return -self._inv_operator * f_hat

Expand All @@ -83,13 +86,15 @@ def step(
f: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
"""
Solve the Poisson equation in real space.
Solve the Poisson equation in state space.

**Arguments:**
- `f`: The right hand side.

- `f`: The right hand side.

**Returns:**
- `u`: The solution.

- `u`: The solution.
"""
f_hat = fft(f, num_spatial_dims=self.num_spatial_dims)
u_hat = self.step_fourier(f_hat)
Expand All @@ -104,6 +109,17 @@ def __call__(
self,
f: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
"""
Solve the Poisson equation in state space.

**Arguments:**

- `f`: The right hand side.

**Returns:**

- `u`: The solution.
"""
if f.shape[1:] != spatial_shape(self.num_spatial_dims, self.num_points):
raise ValueError(
f"Shape of f[1:] is {f.shape[1:]} but should be {spatial_shape(self.num_spatial_dims, self.num_points)}"
Expand Down
32 changes: 29 additions & 3 deletions exponax/_repeated_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ def __init__(
time step of X/Y and then wrap it in a RepeatedStepper with num_sub_steps=Y.

**Arguments:**
- `stepper`: The stepper to repeat.
- `num_sub_steps`: The number of substeps to perform.

- `stepper`: The stepper to repeat.
- `num_sub_steps`: The number of substeps to perform.
"""
self.stepper = stepper
self.num_sub_steps = num_sub_steps
Expand All @@ -52,8 +53,16 @@ def step(
u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
"""
Step the PDE forward in time by self.num_sub_steps time steps given the
Step the PDE forward in time by `self.num_sub_steps` time steps given the
current state `u`.

**Arguments:**

- `u`: The current state.

**Returns:**

- `u_next`: The state after `self.num_sub_steps` time steps.
"""
return repeat(self.stepper.step, self.num_sub_steps)(u)

Expand All @@ -64,6 +73,15 @@ def step_fourier(
"""
Step the PDE forward in time by self.num_sub_steps time steps given the
current state `u_hat` in real-valued Fourier space.

**Arguments:**

- `u_hat`: The current state in Fourier space.

**Returns:**

- `u_next_hat`: The state after `self.num_sub_steps` time steps in Fourier
space.
"""
return repeat(self.stepper.step_fourier, self.num_sub_steps)(u_hat)

Expand All @@ -74,5 +92,13 @@ def __call__(
"""
Step the PDE forward in time by self.num_sub_steps time steps given the
current state `u`.

**Arguments:**

- `u`: The current state.

**Returns:**

- `u_next`: The state after `self.num_sub_steps` time steps.
"""
return repeat(self.stepper, self.num_sub_steps)(u)
Loading
Loading