From 3e8c2d9b3b4ab2de9df2dde4b934b609638dcc06 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 12:56:20 +0200
Subject: [PATCH 01/26] Add documentation to base

---
 exponax/etdrk/_base_etdrk.py | 45 ++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/exponax/etdrk/_base_etdrk.py b/exponax/etdrk/_base_etdrk.py
index 3e9f28b..63d1fee 100644
--- a/exponax/etdrk/_base_etdrk.py
+++ b/exponax/etdrk/_base_etdrk.py
@@ -21,6 +21,43 @@ def __init__(
         dt: float,
         linear_operator: Complex[Array, "E ... (N//2)+1"],
     ):
+        """
+        Base class for exponential time differencing Runge-Kutta methods.
+
+        **Arguments:**
+
+        - `dt`: The time step size.
+        - `linear_operator`: The linear operator of the PDE. Must have a leading
+            channel axis, followed by one, two or three spatial axes whereas the
+            last axis must be of size `(N//2)+1` where `N` is the number of
+            dimensions in the former spatial axes.
+
+        !!! Example
+            Below is an example how to get the linear operator for
+            the heat equation.
+
+            ```python
+            import jax.numpy as jnp
+            import exponax as ex
+
+            # Define the linear operator
+            N = 256
+            L = 5.0  # The domain size
+            D = 1  # Being in 1D
+
+            derivative_operator = 1j * ex.spectral.build_derivative_operator(
+                D,
+                L,
+                N,
+            )
+
+            print(derivative_operator.shape)  # (1, (N//2)+1)
+
+            nu = 0.01 # The diffusion coefficient
+
+            linear_operator = nu * derivative_operator**2
+            ```
+        """
         self.dt = dt
         self._exp_term = jnp.exp(self.dt * linear_operator)
 
@@ -31,5 +68,13 @@ def step_fourier(
     ) -> Complex[Array, "C ... (N//2)+1"]:
         """
         Advance the state in Fourier space.
+
+        **Arguments:**
+
+        - `u_hat`: The previous state in Fourier space.
+
+        **Returns:**
+
+        - The next state in Fourier space, i.e., `self.dt` time units later.
         """
         pass

From bda4de5ccc409d375c22b7360863be99e0056b32 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 13:14:17 +0200
Subject: [PATCH 02/26] Add documentation to etdrk core

---
 exponax/etdrk/_etdrk_0.py | 25 ++++++++++++++++---
 exponax/etdrk/_etdrk_1.py | 38 +++++++++++++++++++++++++++++
 exponax/etdrk/_etdrk_2.py | 46 +++++++++++++++++++++++++++++++++++
 exponax/etdrk/_etdrk_3.py | 49 +++++++++++++++++++++++++++++++++++++
 exponax/etdrk/_etdrk_4.py | 51 +++++++++++++++++++++++++++++++++++++++
 exponax/etdrk/_utils.py   | 11 ++++++++-
 6 files changed, 216 insertions(+), 4 deletions(-)

diff --git a/exponax/etdrk/_etdrk_0.py b/exponax/etdrk/_etdrk_0.py
index 06a0dea..30e94cd 100644
--- a/exponax/etdrk/_etdrk_0.py
+++ b/exponax/etdrk/_etdrk_0.py
@@ -4,9 +4,28 @@
 
 
 class ETDRK0(BaseETDRK):
-    """
-    Exactly solve a linear PDE in Fourier space
-    """
+    def __init__(
+        self,
+        dt: float,
+        linear_operator: Complex[Array, "E ... (N//2)+1"],
+    ):
+        r"""
+        Exactly solve a linear PDE in Fourier space.
+
+        $$
+            \hat{u}_h^{[t+1]} = \exp(\hat{\mathcal{L}}_h \Delta t) \odot
+            \hat{u}_h^{[t]}
+        $$
+
+        **Arguments:**
+
+        - `dt`: The time step size.
+        - `linear_operator`: The linear operator of the PDE. Must have a leading
+            channel axis, followed by one, two or three spatial axes whereas the
+            last axis must be of size `(N//2)+1` where `N` is the number of
+            dimensions in the former spatial axes.
+        """
+        super().__init__(dt, linear_operator)
 
     def step_fourier(
         self,
diff --git a/exponax/etdrk/_etdrk_1.py b/exponax/etdrk/_etdrk_1.py
index de8215e..33d445d 100644
--- a/exponax/etdrk/_etdrk_1.py
+++ b/exponax/etdrk/_etdrk_1.py
@@ -19,6 +19,44 @@ def __init__(
         num_circle_points: int = 16,
         circle_radius: float = 1.0,
     ):
+        r"""
+        Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
+        with a **first order approximation**.
+
+        Adapted from Eq. (4) of [Cox and Matthews
+        (2002)](https://doi.org/10.1006/jcph.2002.6995):
+
+        $$
+            \hat{u}_h^{[t+1]} = \exp(\hat{\mathcal{L}}_h \Delta t) \odot
+            \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) -
+            1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
+        $$
+
+        where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
+        the nonlinear differential operator.
+
+        **Arguments:**
+
+        - `dt`: The time step size.
+        - `linear_operator`: The linear operator of the PDE. Must have a leading
+            channel axis, followed by one, two or three spatial axes whereas the
+            last axis must be of size `(N//2)+1` where `N` is the number of
+            dimensions in the former spatial axes.
+        - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
+            nonlinear differential operator.
+        - `num_circle_points`: The number of points on the unit circle used to
+            approximate the numerically challenging coefficients.
+        - `circle_radius`: The radius of the circle used to approximate the
+            numerically challenging coefficients.
+
+        !!! warning
+            The nonlinear function must take care of proper dealiasing.
+
+        !!! note
+            The numerically stable evaluation of the coefficients follows
+            [Kassam and Trefethen
+            (2005)](https://doi.org/10.1137/S1064827502410633).
+        """
         super().__init__(dt, linear_operator)
         self._nonlinear_fun = nonlinear_fun
 
diff --git a/exponax/etdrk/_etdrk_2.py b/exponax/etdrk/_etdrk_2.py
index 45a2123..f7b9d62 100644
--- a/exponax/etdrk/_etdrk_2.py
+++ b/exponax/etdrk/_etdrk_2.py
@@ -20,6 +20,52 @@ def __init__(
         num_circle_points: int = 16,
         circle_radius: float = 1.0,
     ):
+        r"""
+        Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
+        with a **second order approximation**.
+
+        Adopted from Eq. (22) of [Cox and Matthews
+        (2002)](https://doi.org/10.1006/jcph.2002.6995):
+
+        $$
+            \begin{aligned}
+                \hat{u}_h^* &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot
+                \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) -
+                1}{\hat{\mathcal{L}}_h} \odot
+                \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}). \\ \hat{u}_h^{[t+1]} &=
+                \hat{u}_h^* + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) - 1 -
+                \hat{\mathcal{L}}_h \Delta t}{\hat{\mathcal{L}}_h^2 \Delta t}
+                \left( \hat{\mathcal{N}}_h(\hat{u}_h^*) -
+                \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \right)
+            \end{aligned}
+        $$
+
+        where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
+        the nonlinear differential operator.
+
+        **Arguments:**
+
+        - `dt`: The time step size.
+        - `linear_operator`: The linear operator of the PDE. Must have a leading
+            channel axis, followed by one, two or three spatial axes whereas the
+            last axis must be of size `(N//2)+1` where `N` is the number of
+            dimensions in the former spatial axes.
+        - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
+            nonlinear differential operator. ! The operator must take care of
+            proper dealiasing.
+        - `num_circle_points`: The number of points on the unit circle used to
+            approximate the numerically challenging coefficients.
+        - `circle_radius`: The radius of the circle used to approximate the
+            numerically challenging coefficients.
+
+        !!! warning
+            The nonlinear function must take care of proper dealiasing.
+
+        !!! note
+            The numerically stable evaluation of the coefficients follows
+            [Kassam and Trefethen
+            (2005)](https://doi.org/10.1137/S1064827502410633).
+        """
         super().__init__(dt, linear_operator)
         self._nonlinear_fun = nonlinear_fun
 
diff --git a/exponax/etdrk/_etdrk_3.py b/exponax/etdrk/_etdrk_3.py
index 8d1cade..5a8ab31 100644
--- a/exponax/etdrk/_etdrk_3.py
+++ b/exponax/etdrk/_etdrk_3.py
@@ -24,6 +24,55 @@ def __init__(
         num_circle_points: int = 16,
         circle_radius: float = 1.0,
     ):
+        r"""
+        Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
+        with a **third order approximation**.
+
+        Adapted from Eq. (23-25) of [Cox and Matthews
+        (2002)](https://doi.org/10.1006/jcph.2002.6995):
+
+        $$
+            \begin{aligned}
+                \hat{u}_h^* &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t/2) - 1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}).
+                \\
+                \hat{u}_h^{**} &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) - 1}{\hat{\mathcal{L}}_h} \odot \left( 2 \hat{\mathcal{N}}_h(\hat{u}_h^*) - \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \right).
+                \\
+                \hat{u}_h^{[t+1]} &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{[t]}
+                \\
+                &+ \frac{-4 - \exp(\hat{\mathcal{L}}_h \Delta t) + \exp(\hat{\mathcal{L}}_h \Delta) \left( 4 - 3 \hat{\mathcal{L}}_h \Delta t + \left(\hat{\mathcal{L}}_h \Delta t\right)^2 \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}).
+                \\
+                &+ 4 \frac{2 + \hat{\mathcal{L}}_h \Delta t + \exp(\hat{\mathcal{L}}_h \Delta t) \left( -2 + \hat{\mathcal{L}}_h \Delta t \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \hat{\mathcal{N}}_h(\hat{u}_h^*)
+                \\
+                &+ \frac{-4 - 3 \hat{\mathcal{L}}_h \Delta t - \left( \hat{\mathcal{L}}_h \Delta t \right)^2 + \exp(\hat{\mathcal{L}}_h \Delta t) \left( 4 - \hat{\mathcal{L}}_h \Delta t \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{**})
+            \end{aligned}
+        $$
+
+        where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
+        the nonlinear differential operator.
+
+        **Arguments:**
+
+        - `dt`: The time step size.
+        - `linear_operator`: The linear operator of the PDE. Must have a leading
+            channel axis, followed by one, two or three spatial axes whereas the
+            last axis must be of size `(N//2)+1` where `N` is the number of
+            dimensions in the former spatial axes.
+        - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
+            nonlinear differential operator. ! The operator must take care of
+            proper dealiasing.
+        - `num_circle_points`: The number of points on the unit circle used to
+            approximate the numerically challenging coefficients.
+        - `circle_radius`: The radius of the circle used to approximate the
+            numerically challenging coefficients.
+
+        !!! warning
+            The nonlinear function must take care of proper dealiasing.
+
+        !!! note
+            The numerically stable evaluation of the coefficients follows
+            [Kassam and Trefethen
+            (2005)](https://doi.org/10.1137/S1064827502410633).
+        """
         super().__init__(dt, linear_operator)
         self._nonlinear_fun = nonlinear_fun
         self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)
diff --git a/exponax/etdrk/_etdrk_4.py b/exponax/etdrk/_etdrk_4.py
index a5c7336..04b16ae 100644
--- a/exponax/etdrk/_etdrk_4.py
+++ b/exponax/etdrk/_etdrk_4.py
@@ -25,6 +25,57 @@ def __init__(
         num_circle_points: int = 16,
         circle_radius: float = 1.0,
     ):
+        r"""
+        Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
+        with a **fourth order approximation**.
+
+        Adapted from Eq. (26-29) of [Cox and Matthews
+        (2002)](https://doi.org/10.1006/jcph.2002.6995):
+
+        $$
+            \begin{aligned}
+                \hat{u}_h^* &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t/2) - 1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}).
+                \\
+                \hat{u}_h^{**} &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t / 2) - 1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^*).
+                \\
+                \hat{u}_h^{***} &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{*} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t/2) - 1}{\hat{\mathcal{L}}_h} \odot \left( 2 \hat{\mathcal{N}}_h(\hat{u}_h^{**}) - \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \right).
+                \\
+                \hat{u}_h^{[t+1]} &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{[t]}
+                \\
+                &+ \frac{-4 - \hat{\mathcal{L}}_h \Delta t + \exp(\hat{\mathcal{L}}_h \Delta t) \left( 4 - 3 \hat{\mathcal{L}}_h \Delta t + \left(\hat{\mathcal{L}}_h \Delta t\right)^2 \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
+                \\
+                &+ 2 \frac{2 + \hat{\mathcal{L}}_h \Delta t + \exp(\hat{\mathcal{L}}_h \Delta t) \left( -2 + \hat{\mathcal{L}}_h \Delta t \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \left( \hat{\mathcal{N}}_h(\hat{u}_h^*) + \hat{\mathcal{N}}_h(\hat{u}_h^{**}) \right)
+                \\
+                &+ \frac{-4 - 3 \hat{\mathcal{L}}_h \Delta t - \left( \hat{\mathcal{L}}_h \Delta t \right)^2 + \exp(\hat{\mathcal{L}}_h \Delta t) \left( 4 - \hat{\mathcal{L}}_h \Delta t \right)}{\hat{\mathcal{L}}_h^3 (\Delta t)^2} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{***})
+            \end{aligned}
+        $$
+
+        where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
+        the nonlinear differential operator.
+
+        **Arguments:**
+
+        - `dt`: The time step size.
+        - `linear_operator`: The linear operator of the PDE. Must have a leading
+            channel axis, followed by one, two or three spatial axes whereas the
+            last axis must be of size `(N//2)+1` where `N` is the number of
+            dimensions in the former spatial axes.
+        - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
+            nonlinear differential operator. ! The operator must take care of
+            proper dealiasing.
+        - `num_circle_points`: The number of points on the unit circle used to
+            approximate the numerically challenging coefficients.
+        - `circle_radius`: The radius of the circle used to approximate the
+            numerically challenging coefficients.
+
+        !!! warning
+            The nonlinear function must take care of proper dealiasing.
+
+        !!! note
+            The numerically stable evaluation of the coefficients follows
+            [Kassam and Trefethen
+            (2005)](https://doi.org/10.1137/S1064827502410633).
+        """
         super().__init__(dt, linear_operator)
         self._nonlinear_fun = nonlinear_fun
         self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)
diff --git a/exponax/etdrk/_utils.py b/exponax/etdrk/_utils.py
index 909bf6a..ed0abb8 100644
--- a/exponax/etdrk/_utils.py
+++ b/exponax/etdrk/_utils.py
@@ -8,7 +8,16 @@
 
 def roots_of_unity(M: int) -> Complex[Array, "M"]:
     """
-    Return (complex-valued) array with M roots of unity.
+    Return (complex-valued) array with M roots of unity. Useful to perform
+    contour integrals in the complex plane.
+
+    **Arguments:**
+
+    - `M`: The number of roots of unity.
+
+    **Returns:**
+
+    - `roots`: The M roots of unity in an array of shape `(M,)`.
     """
     # return jnp.exp(1j * jnp.pi * (jnp.arange(1, M+1) - 0.5) / M)
     return jnp.exp(2j * jnp.pi * (jnp.arange(1, M + 1) - 0.5) / M)

From 9c92ca88581a6c1384b83a8832ec4829e53e8fab Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 13:22:27 +0200
Subject: [PATCH 03/26] Adapt to Equinox style

---
 exponax/_utils.py | 180 ++++++++++++++++++++++++----------------------
 1 file changed, 96 insertions(+), 84 deletions(-)

diff --git a/exponax/_utils.py b/exponax/_utils.py
index 3de27c5..1d587cd 100644
--- a/exponax/_utils.py
+++ b/exponax/_utils.py
@@ -17,28 +17,31 @@ def make_grid(
     indexing: str = "ij",
 ) -> Float[Array, "D ... N"]:
     """
-    Return a grid in the spatial domain. A grid in d dimensions is an array of
-    shape (d,) + (num_points,)*d with the first axis representing all coordiate
-    inidices.
+    Return a grid in the spatial domain. A grid in D dimensions is an array of
+    shape (D,) + (num_points,)*D with the leading axis representing all
+    coordiate inidices.
 
     Notice, that if `num_spatial_dims = 1`, the returned array has a singleton
     dimension in the first axis, i.e., the shape is `(1, num_points)`.
 
     **Arguments:**
-        - `num_spatial_dims`: The number of spatial dimensions.
-        - `domain_extent`: The extent of the domain in each spatial dimension.
-        - `num_points`: The number of points in each spatial dimension.
-        - `full`: Whether to include the right boundary point in the grid.
-            Default: `False`. The right point is redundant for periodic boundary
-            conditions and is not considered a degree of freedom. Use this
-            option, for example, if you need a full grid for plotting.
-        - `zero_centered`: Whether to center the grid around zero. Default:
-            `False`. By default the grid considers a domain of (0,
-            domain_extent)^(num_spatial_dims).
-        - `indexing`: The indexing convention to use. Default: `'ij'`.
+
+    - `num_spatial_dims`: The number of spatial dimensions.
+    - `domain_extent`: The extent of the domain in each spatial dimension.
+    - `num_points`: The number of points in each spatial dimension.
+    - `full`: Whether to include the right boundary point in the grid.
+        Default: `False`. The right point is redundant for periodic boundary
+        conditions and is not considered a degree of freedom. Use this option,
+        for example, if you need a full grid for plotting.
+    - `zero_centered`: Whether to center the grid around zero. Default:
+        `False`. By default the grid considers a domain of (0,
+        domain_extent)^(num_spatial_dims).
+    - `indexing`: The indexing convention to use. Default: `'ij'`.
 
     **Returns:**
-        - `grid`: The grid in the spatial domain. Shape: `(num_spatial_dims, ..., num_points)`.
+
+    - `grid`: The grid in the spatial domain. Shape: `(num_spatial_dims,
+        ..., num_points)`.
     """
     if full:
         grid_1d = jnp.linspace(0, domain_extent, num_points + 1, endpoint=True)
@@ -59,18 +62,23 @@ def make_grid(
     return grid
 
 
-def wrap_bc(u):
+def wrap_bc(u: Float[Array, "C N"]) -> Float[Array, "C N+1"]:
     """
     Wraps the periodic boundary conditions around the array `u`.
 
     This can be used to plot the solution of a periodic problem on the full
-    interval [0, L] by plotting `wrap_bc(u)` instead of `u`.
+    interval [0, L] by plotting `wrap_bc(u)` instead of `u`. Consider using
+    `exponax.make_grid` with the `full=True` option to create a full grid. Note
+    that all routines in `exponax.viz` already correctly wrap the boundary
+    conditions.
 
-    **Parameters:**
-        - `u`: The array to wrap, shape `(N,)`.
+    **Arguments:**
+
+    - `u`: The array to wrap, shape `(C, N,)`.
 
     **Returns:**
-        - `u_wrapped`: The wrapped array, shape `(N + 1,)`.
+
+    - `u_wrapped`: The wrapped array, shape `(C, N + 1,)`.
     """
     _, *spatial_shape = u.shape
     num_spatial_dims = len(spatial_shape)
@@ -98,33 +106,32 @@ def rollout(
     a force/control or additional metadata (like physical parameters, or time
     for non-autonomous systems).
 
-    Args:
-        - `stepper_fn`: The time stepper to transform. If `takes_aux = False`
-            (default), expected signature is `u_next = stepper_fn(u)`, else
-            `u_next = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees
-            of identical structure, in the easiest case just arrays of same
-            shape.
-        - `n`: The number of time steps to rollout the trajectory into the
-            future. If `include_init = False` (default) produces the `n` steps
-            into the future.
-        - `include_init`: Whether to include the initial condition in the
-            trajectory. If `True`, the arrays in the returning PyTree have shape
-            `(n + 1, ...)`, else `(n, ...)`. Default: `False`.
-        - `takes_aux`: Whether the stepper function takes an additional PyTree
-            as second argument.
-        - `constant_aux`: Whether the auxiliary input is constant over the
-            trajectory. If `True`, the auxiliary input is repeated `n` times,
-            otherwise the leading axis in the PyTree arrays has to be of length
-            `n`.
-
-    Returns:
-        - `rollout_stepper_fn`: A function that takes an initial condition `u_0`
-            and an auxiliary input `aux` (if `takes_aux = True`) and produces
-            the trajectory by autoregressively applying the stepper `n` times.
-            If `include_init = True`, the trajectory has shape `(n + 1, ...)`,
-            else `(n, ...)`. Returns a PyTree of the same structure as the
-            initial condition, but with an additional leading axis of length
-            `n`.
+    **Arguments:**
+
+    - `stepper_fn`: The time stepper to transform. If `takes_aux = False`
+        (default), expected signature is `u_next = stepper_fn(u)`, else `u_next
+        = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees of identical
+        structure, in the easiest case just arrays of same shape.
+    - `n`: The number of time steps to rollout the trajectory into the
+        future. If `include_init = False` (default) produces the `n` steps into
+        the future.
+    - `include_init`: Whether to include the initial condition in the
+        trajectory. If `True`, the arrays in the returning PyTree have shape `(n
+        + 1, ...)`, else `(n, ...)`. Default: `False`.
+    - `takes_aux`: Whether the stepper function takes an additional PyTree
+        as second argument.
+    - `constant_aux`: Whether the auxiliary input is constant over the
+        trajectory. If `True`, the auxiliary input is repeated `n` times,
+        otherwise the leading axis in the PyTree arrays has to be of length `n`.
+
+    **Returns:**
+
+    - `rollout_stepper_fn`: A function that takes an initial condition `u_0`
+        and an auxiliary input `aux` (if `takes_aux = True`) and produces the
+        trajectory by autoregressively applying the stepper `n` times. If
+        `include_init = True`, the trajectory has shape `(n + 1, ...)`, else
+        `(n, ...)`. Returns a PyTree of the same structure as the initial
+        condition, but with an additional leading axis of length `n`.
     """
 
     if takes_aux:
@@ -196,26 +203,25 @@ def repeat(
     a force/control or additional metadata (like physical parameters, or time
     for non-autonomous systems).
 
-    Args:
-        - `stepper_fn`: The time stepper to transform. If `takes_aux = False`
-            (default), expected signature is `u_next = stepper_fn(u)`, else
-            `u_next = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees
-            of identical structure, in the easiest case just arrays of same
-            shape.
-        - `n`: The number of times to apply the stepper.
-        - `takes_aux`: Whether the stepper function takes an additional PyTree
-            as second argument.
-        - `constant_aux`: Whether the auxiliary input is constant over the
-            trajectory. If `True`, the auxiliary input is repeated `n` times,
-            otherwise the leading axis in the PyTree arrays has to be of length
-            `n`.
-
-    Returns:
-        - `repeated_stepper_fn`: A function that takes an initial condition
-            `u_0` and an auxiliary input `aux` (if `takes_aux = True`) and
-            produces the final state by autoregressively applying the stepper
-            `n` times. Returns a PyTree of the same structure as the initial
-            condition.
+    **Arguments:**
+
+    - `stepper_fn`: The time stepper to transform. If `takes_aux = False`
+        (default), expected signature is `u_next = stepper_fn(u)`, else `u_next
+        = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees of identical
+        structure, in the easiest case just arrays of same shape.
+    - `n`: The number of times to apply the stepper.
+    - `takes_aux`: Whether the stepper function takes an additional PyTree
+        as second argument.
+    - `constant_aux`: Whether the auxiliary input is constant over the
+        trajectory. If `True`, the auxiliary input is repeated `n` times,
+        otherwise the leading axis in the PyTree arrays has to be of length `n`.
+
+    **Returns:**
+
+    - `repeated_stepper_fn`: A function that takes an initial condition
+        `u_0` and an auxiliary input `aux` (if `takes_aux = True`) and produces
+        the final state by autoregressively applying the stepper `n` times.
+        Returns a PyTree of the same structure as the initial condition.
     """
 
     if takes_aux:
@@ -256,18 +262,22 @@ def stack_sub_trajectories(
     Slice a trajectory into subtrajectories of length `n` and stack them
     together. Useful for rollout training neural operators with temporal mixing.
 
-    !!! Note that this function can produce very large arrays.
+    !!! warning
+        This function can produce very large arrays, especially if `sub_le >>
+        1`.
 
     **Arguments:**
-        - `trj`: The trajectory to slice. Expected shape: `(n_timesteps, ...)`.
-        - `sub_len`: The length of the subtrajectories. If you want to perform rollout
-            training with k steps, note that `n=k+1` to also have an initial
-            condition in the subtrajectories.
+
+    - `trj`: The trajectory to slice. Expected shape: `(n_timesteps, ...)`.
+    - `sub_len`: The length of the subtrajectories. If you want to perform
+        rollout training with k steps, note that `n=k+1` to also have an initial
+        condition in the subtrajectories.
 
     **Returns:**
-        - `sub_trjs`: The stacked subtrajectories. Expected shape: `(n_stacks, n, ...)`.
-           `n_stacks` is the number of subtrajectories stacked together, i.e.,
-           `n_timesteps - n + 1`.
+
+    - `sub_trjs`: The stacked subtrajectories. Expected shape: `(n_stacks,
+        n, ...)`. `n_stacks` is the number of subtrajectories stacked together,
+        i.e., `n_timesteps - n + 1`.
     """
     n_time_steps = [leaf.shape[0] for leaf in jtu.tree_leaves(trj)]
 
@@ -303,26 +313,28 @@ def scan_fn(_, i):
 
 
 def build_ic_set(
-    ic_generator,
+    ic_generator: Callable[[int, PRNGKeyArray], Float[Array, "C ... N"]],
     *,
     num_points: int,
     num_samples: int,
     key: PRNGKeyArray,
-) -> Float[Array, "S 1 ... N"]:
+) -> Float[Array, "S C ... N"]:
     """
     Generate a set of initial conditions by sampling from a given initial
     condition distribution and evaluating the function on the given grid.
 
     **Arguments:**
-        - `ic_generator`: A function that takes a PRNGKey and returns a
-            function that takes a grid and returns a sample from the initial
-            condition distribution.
-        - `num_samples`: The number of initial conditions to sample.
-        - `key`: The PRNGKey to use for sampling.
+
+    - `ic_generator`: A function that takes a number of points and a PRNGKey
+        and returns an array representing the discrete state of an initial
+        condition. The shape of the returned array is `(C, ..., N)`.
+    - `num_samples`: The number of initial conditions to sample.
+    - `key`: The PRNGKey to use for sampling.
 
     **Returns:**
-        - `ic_set`: The set of initial conditions. Shape: `(S, 1, ..., N)`.
-            `S = num_samples`.
+
+    - `ic_set`: The set of initial conditions. Shape: `(S, C, ..., N)`.
+        `S = num_samples`.
     """
 
     def scan_fn(k, _):

From 4c86177dff47360ac2e90599966e8fb08a01473a Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 13:27:24 +0200
Subject: [PATCH 04/26] Change to Equinox style

---
 exponax/_poisson.py | 36 ++++++++++++++++++++++++++----------
 1 file changed, 26 insertions(+), 10 deletions(-)

diff --git a/exponax/_poisson.py b/exponax/_poisson.py
index 19168ff..2c9b9d4 100644
--- a/exponax/_poisson.py
+++ b/exponax/_poisson.py
@@ -41,11 +41,12 @@ def __init__(
         It is included for completion.
 
         **Arguments:**
-            - `num_spatial_dims`: The number of spatial dimensions.
-            - `domain_extent`: The extent of the domain.
-            - `num_points`: The number of points in each spatial dimension.
-            - `order`: The order of the Poisson equation. Defaults to 2. You can
-              also set `order=4` for the biharmonic equation.
+
+        - `num_spatial_dims`: The number of spatial dimensions.
+        - `domain_extent`: The extent of the domain.
+        - `num_points`: The number of points in each spatial dimension.
+        - `order`: The order of the Poisson equation. Defaults to 2. You can
+            also set `order=4` for the biharmonic equation.
         """
         self.num_spatial_dims = num_spatial_dims
         self.domain_extent = domain_extent
@@ -71,10 +72,12 @@ def step_fourier(
         Solve the Poisson equation in Fourier space.
 
         **Arguments:**
-            - `f_hat`: The Fourier transform of the right hand side.
+
+        - `f_hat`: The Fourier transform of the right hand side.
 
         **Returns:**
-            - `u_hat`: The Fourier transform of the solution.
+
+        - `u_hat`: The Fourier transform of the solution.
         """
         return -self._inv_operator * f_hat
 
@@ -83,13 +86,15 @@ def step(
         f: Float[Array, "C ... N"],
     ) -> Float[Array, "C ... N"]:
         """
-        Solve the Poisson equation in real space.
+        Solve the Poisson equation in state space.
 
         **Arguments:**
-            - `f`: The right hand side.
+
+        - `f`: The right hand side.
 
         **Returns:**
-            - `u`: The solution.
+
+        - `u`: The solution.
         """
         f_hat = fft(f, num_spatial_dims=self.num_spatial_dims)
         u_hat = self.step_fourier(f_hat)
@@ -104,6 +109,17 @@ def __call__(
         self,
         f: Float[Array, "C ... N"],
     ) -> Float[Array, "C ... N"]:
+        """
+        Solve the Poisson equation in state space.
+
+        **Arguments:**
+
+        - `f`: The right hand side.
+
+        **Returns:**
+
+        - `u`: The solution.
+        """
         if f.shape[1:] != spatial_shape(self.num_spatial_dims, self.num_points):
             raise ValueError(
                 f"Shape of f[1:] is {f.shape[1:]} but should be {spatial_shape(self.num_spatial_dims, self.num_points)}"

From 88858a4492f16197b4f199102c37a815a152bfbb Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 13:29:26 +0200
Subject: [PATCH 05/26] Change to Equinox style

---
 exponax/_forced_stepper.py | 32 +++++++++++++++++++-------------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/exponax/_forced_stepper.py b/exponax/_forced_stepper.py
index 149b339..81d684b 100644
--- a/exponax/_forced_stepper.py
+++ b/exponax/_forced_stepper.py
@@ -33,7 +33,8 @@ def __init__(
         transient integrators to forced problems.
 
         **Arguments**:
-            - `stepper`: The stepper to be transformed.
+
+        - `stepper`: The stepper to be transformed.
         """
         self.stepper = stepper
 
@@ -49,11 +50,13 @@ def step(
         The forcing term `f` is assumed to be evaluated on the same grid as `u`.
 
         **Arguments**:
-            - `u`: The current state.
-            - `f`: The forcing term.
+
+        - `u`: The current state.
+        - `f`: The forcing term.
 
         **Returns**:
-            - `u_next`: The state after one time step.
+
+        - `u_next`: The state after one time step.
         """
         u_with_force = u + self.stepper.dt * f
         return self.stepper.step(u_with_force)
@@ -71,11 +74,13 @@ def step_fourier(
         `u_hat`.
 
         **Arguments**:
-            - `u_hat`: The current state in Fourier space.
-            - `f_hat`: The forcing term in Fourier space.
+
+        - `u_hat`: The current state in Fourier space.
+        - `f_hat`: The forcing term in Fourier space.
 
         **Returns**:
-            - `u_next_hat`: The state after one time step in Fourier space.
+
+        - `u_next_hat`: The state after one time step in Fourier space.
         """
         u_hat_with_force = u_hat + self.stepper.dt * f_hat
         return self.stepper.step_fourier(u_hat_with_force)
@@ -91,12 +96,13 @@ def __call__(
 
         The forcing term `f` is assumed to be evaluated on the same grid as `u`.
 
-        **Arguments**:
-            - `u`: The current state.
-            - `f`: The forcing term.
+        **Arguments:**
 
-        **Returns**:
-            - `u_next`: The state after one time step.
-        """
+        - `u`: The current state.
+        - `f`: The forcing term.
 
+        **Returns:**
+
+        - `u_next`: The state after one time step.
+        """
         return self.step(u, f)

From 386867f48b36a198e585e1acd1b0e53863435a46 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 13:30:43 +0200
Subject: [PATCH 06/26] Change to Equinox style

---
 exponax/_repeated_stepper.py | 32 +++++++++++++++++++++++++++++---
 1 file changed, 29 insertions(+), 3 deletions(-)

diff --git a/exponax/_repeated_stepper.py b/exponax/_repeated_stepper.py
index eb7045f..4914622 100644
--- a/exponax/_repeated_stepper.py
+++ b/exponax/_repeated_stepper.py
@@ -33,8 +33,9 @@ def __init__(
         time step of X/Y and then wrap it in a RepeatedStepper with num_sub_steps=Y.
 
         **Arguments:**
-            - `stepper`: The stepper to repeat.
-            - `num_sub_steps`: The number of substeps to perform.
+
+        - `stepper`: The stepper to repeat.
+        - `num_sub_steps`: The number of substeps to perform.
         """
         self.stepper = stepper
         self.num_sub_steps = num_sub_steps
@@ -52,8 +53,16 @@ def step(
         u: Float[Array, "C ... N"],
     ) -> Float[Array, "C ... N"]:
         """
-        Step the PDE forward in time by self.num_sub_steps time steps given the
+        Step the PDE forward in time by `self.num_sub_steps` time steps given the
         current state `u`.
+
+        **Arguments:**
+
+        - `u`: The current state.
+
+        **Returns:**
+
+        - `u_next`: The state after `self.num_sub_steps` time steps.
         """
         return repeat(self.stepper.step, self.num_sub_steps)(u)
 
@@ -64,6 +73,15 @@ def step_fourier(
         """
         Step the PDE forward in time by self.num_sub_steps time steps given the
         current state `u_hat` in real-valued Fourier space.
+
+        **Arguments:**
+
+        - `u_hat`: The current state in Fourier space.
+
+        **Returns:**
+
+        - `u_next_hat`: The state after `self.num_sub_steps` time steps in Fourier
+            space.
         """
         return repeat(self.stepper.step_fourier, self.num_sub_steps)(u_hat)
 
@@ -74,5 +92,13 @@ def __call__(
         """
         Step the PDE forward in time by self.num_sub_steps time steps given the
         current state `u`.
+
+        **Arguments:**
+
+        - `u`: The current state.
+
+        **Returns:**
+
+        - `u_next`: The state after `self.num_sub_steps` time steps.
         """
         return repeat(self.stepper, self.num_sub_steps)(u)

From 6fb0a5f9c109493b3d4da3e4cca7514c5fd531ec Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 13:36:14 +0200
Subject: [PATCH 07/26] Fix typo

---
 exponax/stepper/_advection.py           | 2 +-
 exponax/stepper/_advection_diffusion.py | 2 +-
 exponax/stepper/_diffusion.py           | 2 +-
 exponax/stepper/_dispersion.py          | 2 +-
 exponax/stepper/_hyper_diffusion.py     | 2 +-
 5 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/exponax/stepper/_advection.py b/exponax/stepper/_advection.py
index 34fc3c8..cbabf0a 100644
--- a/exponax/stepper/_advection.py
+++ b/exponax/stepper/_advection.py
@@ -61,7 +61,7 @@ def __init__(
 
         **Notes:**
 
-        - The stepper is unconditionally stable, not matter the choice of
+        - The stepper is unconditionally stable, no matter the choice of
             any argument because the equation is solved analytically in Fourier
             space. **However**, note that initial conditions with modes higher
             than the Nyquist freuency (`(N//2)+1` with `N` being the
diff --git a/exponax/stepper/_advection_diffusion.py b/exponax/stepper/_advection_diffusion.py
index 4a00b75..a965176 100644
--- a/exponax/stepper/_advection_diffusion.py
+++ b/exponax/stepper/_advection_diffusion.py
@@ -78,7 +78,7 @@ def __init__(
 
         **Notes:**
 
-        - The stepper is unconditionally stable, not matter the choice of
+        - The stepper is unconditionally stable, no matter the choice of
             any argument because the equation is solved analytically in Fourier
             space. **However**, note that initial conditions with modes higher
             than the Nyquist freuency (`(N//2)+1` with `N` being the
diff --git a/exponax/stepper/_diffusion.py b/exponax/stepper/_diffusion.py
index ccc1f85..403a66c 100644
--- a/exponax/stepper/_diffusion.py
+++ b/exponax/stepper/_diffusion.py
@@ -73,7 +73,7 @@ def __init__(
 
         **Notes:**
 
-        - The stepper is unconditionally stable, not matter the choice of
+        - The stepper is unconditionally stable, no matter the choice of
             any argument because the equation is solved analytically in Fourier
             space.
         - A `ν > 0` leads to stable and decaying solutions (i.e., energy is
diff --git a/exponax/stepper/_dispersion.py b/exponax/stepper/_dispersion.py
index b7d8347..a330af1 100644
--- a/exponax/stepper/_dispersion.py
+++ b/exponax/stepper/_dispersion.py
@@ -75,7 +75,7 @@ def __init__(
 
         **Notes:**
 
-        - The stepper is unconditionally stable, not matter the choice of
+        - The stepper is unconditionally stable, no matter the choice of
             any argument because the equation is solved analytically in Fourier
             space. **However**, note that initial conditions with modes higher
             than the Nyquist freuency (`(N//2)+1` with `N` being the
diff --git a/exponax/stepper/_hyper_diffusion.py b/exponax/stepper/_hyper_diffusion.py
index 73bac08..9c0623f 100644
--- a/exponax/stepper/_hyper_diffusion.py
+++ b/exponax/stepper/_hyper_diffusion.py
@@ -72,7 +72,7 @@ def __init__(
 
         **Notes:**
 
-        - The stepper is unconditionally stable, not matter the choice of
+        - The stepper is unconditionally stable, no matter the choice of
             any argument because the equation is solved analytically in Fourier
             space.
         - Ultimately, only the factor `μ Δt / L⁴` affects the characteristic

From 98ed94b2bcaaa044aba20ccac8ddfbd56dd7c2d8 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 13:36:49 +0200
Subject: [PATCH 08/26] Consistent documentation

---
 exponax/_base_stepper.py | 52 +++++++++++++++++++++++++++++-----------
 1 file changed, 38 insertions(+), 14 deletions(-)

diff --git a/exponax/_base_stepper.py b/exponax/_base_stepper.py
index c41a8a2..80a8875 100644
--- a/exponax/_base_stepper.py
+++ b/exponax/_base_stepper.py
@@ -164,12 +164,14 @@ def _build_linear_operator(
         Assemble the L operator in Fourier space.
 
         **Arguments:**
-            - `derivative_operator`: The derivative operator, shape `( D, ...,
-              N//2+1 )`. The ellipsis are (D-1) axis of size N (**not** of size
-              N//2+1).
+
+        - `derivative_operator`: The derivative operator, shape `( D, ...,
+            N//2+1 )`. The ellipsis are (D-1) axis of size N (**not** of size
+            N//2+1).
 
         **Returns:**
-            - `L`: The linear operator, shape `( C, ..., N//2+1 )`.
+
+        - `L`: The linear operator, shape `( C, ..., N//2+1 )`.
         """
         pass
 
@@ -183,12 +185,15 @@ def _build_nonlinear_fun(
         transforms to Fourier space, and evaluates derivatives there.
 
         **Arguments:**
-            - `derivative_operator`: The derivative operator, shape `( D, ..., N//2+1 )`.
+
+        - `derivative_operator`: The derivative operator, shape `( D, ...,
+            N//2+1 )`.
 
         **Returns:**
-            - `nonlinear_fun`: A function that evaluates the nonlinearities in
-                time space, transforms to Fourier space, and evaluates the
-                derivatives there. Should be a subclass of `BaseNonlinearFun`.
+
+        - `nonlinear_fun`: A function that evaluates the nonlinearities in
+            time space, transforms to Fourier space, and evaluates the
+            derivatives there. Should be a subclass of `BaseNonlinearFun`.
         """
         pass
 
@@ -197,10 +202,12 @@ def step(self, u: Float[Array, "C ... N"]) -> Float[Array, "C ... N"]:
         Perform one step of the time integration.
 
         **Arguments:**
-            - `u`: The state vector, shape `(C, ..., N,)`.
+
+        - `u`: The state vector, shape `(C, ..., N,)`.
 
         **Returns:**
-            - `u_next`: The state vector after one step, shape `(C, ..., N,)`.
+
+        - `u_next`: The state vector after one step, shape `(C, ..., N,)`.
         """
         u_hat = fft(u, num_spatial_dims=self.num_spatial_dims)
         u_next_hat = self.step_fourier(u_hat)
@@ -220,11 +227,13 @@ def step_fourier(
         transforms.
 
         **Arguments:**
-            - `u_hat`: The (real) Fourier transform of the state vector
+
+        - `u_hat`: The (real) Fourier transform of the state vector
 
         **Returns:**
-            - `u_next_hat`: The (real) Fourier transform of the state vector
-                after one step
+
+        - `u_next_hat`: The (real) Fourier transform of the state vector
+            after one step
         """
         return self._integrator.step_fourier(u_hat)
 
@@ -233,7 +242,22 @@ def __call__(
         u: Float[Array, "C ... N"],
     ) -> Float[Array, "C ... N"]:
         """
-        Performs a check
+        Perform one step of the time integration for a single state.
+
+        **Arguments:**
+
+        - `u`: The state vector, shape `(C, ..., N,)`.
+
+        **Returns:**
+
+        - `u_next`: The state vector after one step, shape `(C, ..., N,)`.
+
+        !!! tip
+            Use this call method together with `exponax.rollout` to produce
+            temporal trajectories by efficiently autogressive rollout.
+
+        !!! info
+            For batched operation, use `jax.vmap` on this function.
         """
         expected_shape = (self.num_channels,) + spatial_shape(
             self.num_spatial_dims, self.num_points

From 3375fb994a9348db74e285b337315113a7aa706a Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 13:46:12 +0200
Subject: [PATCH 09/26] Improve nonlin fun base documentation

---
 exponax/nonlin_fun/_base.py | 77 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 76 insertions(+), 1 deletion(-)

diff --git a/exponax/nonlin_fun/_base.py b/exponax/nonlin_fun/_base.py
index d04d992..8770b9b 100644
--- a/exponax/nonlin_fun/_base.py
+++ b/exponax/nonlin_fun/_base.py
@@ -19,6 +19,31 @@ def __init__(
         *,
         dealiasing_fraction: Optional[float] = None,
     ):
+        """
+        Base class for all nonlinear functions. This class provides the basic
+        functionality to dealias the nonlinear terms and perform forward and
+        inverse Fourier transforms.
+
+        **Arguments:**
+
+        - `num_spatial_dims`: The number of spatial dimensions `D`.
+        - `num_points`: The number of points `N` used to discretize the domain.
+            This **includes** the left boundary point and **excludes** the right
+            boundary point. In higher dimensions; the number of points in each
+            dimension is the same.
+        - `dealiasing_fraction`: The fraction of the highest resolved mode to
+            keep for dealiasing. For example, `2/3` corresponds to Orszag's 2/3
+            rule typically used for quadratic nonlinearities. If `None`, no
+            dealiasing is performed.
+
+        !!! info
+            Some dealiasing strategies (like Orszag's 2/3 rule) are designed to
+            not fully remove aliasing (which would require 1/2 in the case of
+            quadratic nonlinearities), rather to only have aliases being created
+            in those modes that will be zeroed out anyway in the next
+            dealiasing step. See also [Orszag
+            (1971)](https://doi.org/10.1175/1520-0469(1971)028%3C1074:OTEOAI%3E2.0.CO;2)
+        """
         self.num_spatial_dims = num_spatial_dims
         self.num_points = num_points
 
@@ -39,14 +64,52 @@ def __init__(
     def dealias(
         self, u_hat: Complex[Array, "C ... (N//2)+1"]
     ) -> Complex[Array, "C ... (N//2)+1"]:
+        """
+        Dealias the Fourier representation of a state `u_hat` by zeroing out all
+        the coefficients associated with modes beyond `dealiasing_fraction` set
+        in the constructor.
+
+        **Arguments:**
+
+        - `u_hat`: The Fourier representation of the state `u`.
+
+        **Returns:**
+
+        - `u_hat_dealiased`: The dealiased Fourier representation of the state
+            `u`.
+        """
         if self.dealiasing_mask is None:
             raise ValueError("Nonlinear function was set up without dealiasing")
         return self.dealiasing_mask * u_hat
 
     def fft(self, u: Float[Array, "C ... N"]) -> Complex[Array, "C ... (N//2)+1"]:
+        """
+        Correctly wrapped **real-valued** Fourier transform for the shape of the
+        state vector associated with this nonlinear function.
+
+        **Arguments:**
+
+        - `u`: The state vector in real space.
+
+        **Returns:**
+
+        - `u_hat`: The (real-valued) Fourier transform of the state vector.
+        """
         return fft(u, num_spatial_dims=self.num_spatial_dims)
 
     def ifft(self, u_hat: Complex[Array, "C ... (N//2)+1"]) -> Float[Array, "C ... N"]:
+        """
+        Correctly wrapped **real-valued** inverse Fourier transform for the shape
+        of the state vector associated with this nonlinear function.
+
+        **Arguments:**
+
+        - `u_hat`: The (real-valued) Fourier transform of the state vector.
+
+        **Returns:**
+
+        - `u`: The state vector in real space.
+        """
         return ifft(
             u_hat, num_spatial_dims=self.num_spatial_dims, num_points=self.num_points
         )
@@ -57,6 +120,18 @@ def __call__(
         u_hat: Complex[Array, "C ... (N//2)+1"],
     ) -> Complex[Array, "C ... (N//2)+1"]:
         """
-        Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing.
+        Evaluates the nonlinear function with a pseudo-spectral treatment and
+        accounts for dealiasing.
+
+        Use this in combination with `exponax.etdrk` routines to solve
+        semi-linear PDEs in Fourier space.
+
+        **Arguments:**
+
+        - `u_hat`: The Fourier representation of the state `u`.
+
+        **Returns:**
+
+        - `𝒩(u_hat)`: The Fourier representation of the nonlinear term.
         """
         pass

From 6471b789c6044778ae6cf4d7a1ac27fc91dbed6f Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 14:07:22 +0200
Subject: [PATCH 10/26] Change to Equinox style

---
 exponax/nonlin_fun/_convection.py | 63 ++++++++++++++++++++++++-------
 1 file changed, 50 insertions(+), 13 deletions(-)

diff --git a/exponax/nonlin_fun/_convection.py b/exponax/nonlin_fun/_convection.py
index 0e30f9a..9f4d151 100644
--- a/exponax/nonlin_fun/_convection.py
+++ b/exponax/nonlin_fun/_convection.py
@@ -43,19 +43,20 @@ def __init__(
         ```
 
         **Arguments:**
-            - `num_spatial_dims`: The number of spatial dimensions `d`.
-            - `num_points`: The number of points `N` used to discretize the
-                domain. This **includes** the left boundary point and
-                **excludes** the right boundary point. In higher dimensions; the
-                number of points in each dimension is the same.
-            - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)`
-                that represents the derivative operator in Fourier space.
-            - `dealiasing_fraction`: The fraction of the highest resolved modes
-                that are not aliased. Defaults to `2/3` which corresponds to
-                Orszag's 2/3 rule.
-            - `scale`: The scale `b₁` of the convection term. Defaults to `1.0`.
-            - `single_channel`: Whether to use the single-channel hack. Defaults
-                to `False`.
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `num_points`: The number of points `N` used to discretize the
+            domain. This **includes** the left boundary point and **excludes**
+            the right boundary point. In higher dimensions; the number of points
+            in each dimension is the same.
+        - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)`
+            that represents the derivative operator in Fourier space.
+        - `dealiasing_fraction`: The fraction of the highest resolved modes
+            that are not aliased. Defaults to `2/3` which corresponds to
+            Orszag's 2/3 rule.
+        - `scale`: The scale `b₁` of the convection term. Defaults to `1.0`.
+        - `single_channel`: Whether to use the single-channel hack. Defaults
+            to `False`.
         """
         self.derivative_operator = derivative_operator
         self.scale = scale
@@ -69,6 +70,24 @@ def __init__(
     def _multi_channel_eval(
         self, u_hat: Complex[Array, "C ... (N//2)+1"]
     ) -> Complex[Array, "C ... (N//2)+1"]:
+        """
+        Evaluates the convection term for a multi-channel state `u_hat` in
+        Fourier space. The convection term is given by
+
+        ```
+            𝒩(u) = b₁ 1/2 ∇ ⋅ (u ⊗ u)
+        ```
+
+        with `∇ ⋅` the divergence operator and the outer product `u ⊗ u`.
+
+        **Arguments:**
+
+        - `u_hat`: The state in Fourier space.
+
+        **Returns:**
+
+        - `convection`: The evaluation of the convection term in Fourier space.
+        """
         num_channels = u_hat.shape[0]
         if num_channels != self.num_spatial_dims:
             raise ValueError(
@@ -88,6 +107,24 @@ def _multi_channel_eval(
     def _single_channel_eval(
         self, u_hat: Complex[Array, "C ... (N//2)+1"]
     ) -> Complex[Array, "C ... (N//2)+1"]:
+        """
+        Evaluates the convection term for a single-channel state `u_hat` in
+        Fourier space. The convection term is given by
+
+        ```
+            𝒩(u) = b₁ 1/2 (1⃗ ⋅ ∇)(u²)
+        ```
+
+        with `∇ ⋅` the divergence operator and `1⃗` a vector of ones.
+
+        **Arguments:**
+
+        - `u_hat`: The state in Fourier space.
+
+        **Returns:**
+
+        - `convection`: The evaluation of the convection term in Fourier space.
+        """
         u_hat_dealiased = self.dealias(u_hat)
         u = self.ifft(u_hat_dealiased)
         u_square = u**2

From 8a94c0ef42ddf927a30e8db3433b4a80c5ebaad8 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 14:19:16 +0200
Subject: [PATCH 11/26] Improve docs for convection

---
 exponax/nonlin_fun/_convection.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/exponax/nonlin_fun/_convection.py b/exponax/nonlin_fun/_convection.py
index 9f4d151..ea594ce 100644
--- a/exponax/nonlin_fun/_convection.py
+++ b/exponax/nonlin_fun/_convection.py
@@ -24,14 +24,20 @@ def __init__(
         found in the Burgers equation. In 1d and state space, this reads
 
         ```
-            𝒩(u) = b₁ 1/2 (u²)ₓ
+            𝒩(u) = - b₁ 1/2 (u²)ₓ
         ```
 
-        with a scale `b₁`. The typical extension to higher dimensions requires u
-        to have as many channels as spatial dimensions and then gives
+        with a scale `b₁`. The minus arises because `Exponax` follows the
+        convention that all nonlinear and linear differential operators are on
+        the right-hand side of the equation. Typically, the convection term is
+        on the left-hand side. Hence, the minus is required to move the term to
+        the right-hand side.
+
+        The typical extension to higher dimensions requires u to have as many
+        channels as spatial dimensions and then gives
 
         ```
-            𝒩(u) = b₁ 1/2 ∇ ⋅ (u ⊗ u)
+            𝒩(u) = - b₁ 1/2 ∇ ⋅ (u ⊗ u)
         ```
 
         with `∇ ⋅` the divergence operator and the outer product `u ⊗ u`.
@@ -39,7 +45,7 @@ def __init__(
         matter the spatial dimensions. This reads
 
         ```
-            𝒩(u) = b₁ 1/2 (1⃗ ⋅ ∇)(u²)
+            𝒩(u) = - b₁ 1/2 (1⃗ ⋅ ∇)(u²)
         ```
 
         **Arguments:**

From 4e510b197a876409384e37d930b83016710828f9 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 14:19:39 +0200
Subject: [PATCH 12/26] Improve docs for gradient norm

---
 exponax/nonlin_fun/_gradient_norm.py | 53 ++++++++++++++++------------
 1 file changed, 30 insertions(+), 23 deletions(-)

diff --git a/exponax/nonlin_fun/_gradient_norm.py b/exponax/nonlin_fun/_gradient_norm.py
index df946c0..53f1fdd 100644
--- a/exponax/nonlin_fun/_gradient_norm.py
+++ b/exponax/nonlin_fun/_gradient_norm.py
@@ -26,38 +26,45 @@ def __init__(
         In 1d and state space, this reads
 
         ```
-            𝒩(u) = b₂ 1/2 (u²)ₓ
+            𝒩(u) = - b₂ 1/2 (uₓ)²
         ```
 
-        with a scale `b₂`. In higher dimensions, u has to be single channel and
-        the nonlinear function reads
+        with a scale `b₂`. The minus arises because `Exponax` follows the
+        convention that all nonlinear and linear differential operators are on
+        the right-hand side of the equation. Typically, the gradient norm term
+        is on the left-hand side. Hence, the minus is required to move the term
+        to the right-hand side.
+
+        In higher dimensions, u has to be single channel and the nonlinear
+        function reads
 
         ```
-            𝒩(u) = b₂ 1/2 ‖∇u‖₂²
+            𝒩(u) = - b₂ 1/2 ‖∇u‖₂²
         ```
 
         with `‖∇u‖₂²` the squared L2 norm of the gradient of `u`.
 
         **Arguments:**
-            - `num_spatial_dims`: The number of spatial dimensions `d`.
-            - `num_points`: The number of points `N` used to discretize the
-                domain. This **includes** the left boundary point and
-                **excludes** the right boundary point. In higher dimensions; the
-                number of points in each dimension is the same.
-            - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)`
-                that represents the derivative operator in Fourier space.
-            - `dealiasing_fraction`: The fraction of the highest resolved modes
-                that are not aliased. Defaults to `2/3` which corresponds to
-                Orszag's 2/3 rule.
-            - `zero_mode_fix`: Whether to set the zero mode to zero. In other
-                words, whether to have mean zero energy after nonlinear function
-                activation. This exists because the nonlinear operation happens
-                after the derivative operator is applied. Naturally, the
-                derivative sets any constant offset to zero. However, the square
-                nonlinearity introduces again a new constant offset. Setting
-                this argument to `True` removes this offset. Defaults to `True`.
-            - `scale`: The scale `b₂` of the gradient norm term. Defaults to
-              `1.0`.
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `num_points`: The number of points `N` used to discretize the
+            domain. This **includes** the left boundary point and **excludes**
+            the right boundary point. In higher dimensions; the number of points
+            in each dimension is the same.
+        - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)`
+            that represents the derivative operator in Fourier space.
+        - `dealiasing_fraction`: The fraction of the highest resolved modes
+            that are not aliased. Defaults to `2/3` which corresponds to
+            Orszag's 2/3 rule.
+        - `zero_mode_fix`: Whether to set the zero mode to zero. In other
+            words, whether to have mean zero energy after nonlinear function
+            activation. This exists because the nonlinear operation happens
+            after the derivative operator is applied. Naturally, the derivative
+            sets any constant offset to zero. However, the square nonlinearity
+            introduces again a new constant offset. Setting this argument to
+            `True` removes this offset. Defaults to `True`.
+        - `scale`: The scale `b₂` of the gradient norm term. Defaults to
+            `1.0`.
         """
         super().__init__(
             num_spatial_dims,

From bd2a4285e130f16cbcc86b9c7a0d503dda10f437 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 14:19:59 +0200
Subject: [PATCH 13/26] Improve docs for combined general nonlinearity

---
 exponax/nonlin_fun/_general_nonlinear.py | 50 +++++++++++++++++++++++-
 1 file changed, 48 insertions(+), 2 deletions(-)

diff --git a/exponax/nonlin_fun/_general_nonlinear.py b/exponax/nonlin_fun/_general_nonlinear.py
index 00c848d..c266572 100644
--- a/exponax/nonlin_fun/_general_nonlinear.py
+++ b/exponax/nonlin_fun/_general_nonlinear.py
@@ -22,9 +22,55 @@ def __init__(
         zero_mode_fix: bool = True,
     ):
         """
-        Uses an additional scaling of 0.5 on the latter two components only
+        Fourier pseudo-spectral evaluation of a nonlinear differential operator
+        that has a square, convection (with single-channel hack), and gradient
+        norm term. In 1D and state space, this reads
 
-        By default: Burgers equation
+        ```
+            𝒩(u) = b₀ u² + b₁ 1/2 (u²)ₓ + b₂ 1/2 (uₓ)²
+        ```
+
+        The higher-dimensional extension is designed for a single-channel state
+        `u` (i.e., the number of channels do not grow with the number of spatial
+        dimensions, see also the description of
+        `exponax.nonlin_fun.ConvectionNonlinearFun`). The extension reads
+
+        ```
+            𝒩(u) = b₀ u² + b₁ 1/2 (1⃗ ⋅ ∇)(u²) + b₂ 1/2 ‖∇u‖₂²
+        ```
+
+        !!! warning
+            In contrast to the individual nonlinear functions
+            `exponax.nonlin_fun.ConvectionNonlinearFun` and
+            `exponax.nonlin_fun.GradientNormNonlinearFun`, there is no minus.
+            Hence, to have a "propoper" convection term, consider supplying a
+            negative scale for the convection term, etc.
+
+        **Arguments**:
+
+        - `num_spatial_dims`: The number of spatial dimensions `D`.
+        - `num_points`: The number of points `N` used to discretize the domain.
+            This **includes** the left boundary point and **excludes** the right
+            boundary point. In higher dimensions; the number of points in each
+            dimension is the same.
+        - `derivative_operator`: A complex array of shape `(D, ..., N//2+1)`
+            that represents the derivative operator in Fourier space.
+        - `dealiasing_fraction`: The fraction of the highest resolved modes that
+            are not aliased. Defaults to `2/3` which corresponds to Orszag's 2/3
+            rule.
+        - `scale_list`: A tuple of three floats `[b₀, b₁, b₂]` that represent
+            the scales of the square, (single-channel) convection, and gradient
+            norm term, respectively. Defaults to `[0.0, -1.0, 0.0]` which
+            corresponds to a pure convection term (i.e, in 1D together with a
+            diffusion linear term, this would be the Burgers equation). !!!
+            important: note that negation has to be manually provided!
+        - `zero_mode_fix`: Whether to set the zero mode to zero. In other words,
+            whether to have mean zero energy after nonlinear function activation.
+            This exists because the nonlinear operation happens after the
+            derivative operator is applied. Naturally, the derivative sets any
+            constant offset to zero. However, the square nonlinearity introduces
+            again a new constant offset. Setting this argument to `True` removes
+            this offset. Defaults to `True`.
         """
         if len(scale_list) != 3:
             raise ValueError("The scale list must have exactly 3 elements")

From 1f5342e1c7e8b9c48105ef83459de3af074b72ad Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 14:20:42 +0200
Subject: [PATCH 14/26] Fix docstring

---
 exponax/nonlin_fun/_zero.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/exponax/nonlin_fun/_zero.py b/exponax/nonlin_fun/_zero.py
index f365d11..322146a 100644
--- a/exponax/nonlin_fun/_zero.py
+++ b/exponax/nonlin_fun/_zero.py
@@ -19,11 +19,12 @@ def __init__(
         ```
 
         **Arguments:**
-            - `num_spatial_dims`: The number of spatial dimensions `d`.
-            - `num_points`: The number of points `N` used to discretize the
-                domain. This **includes** the left boundary point and
-                **excludes** the right boundary point. In higher dimensions; the
-                number of points in each dimension is the same.
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `num_points`: The number of points `N` used to discretize the domain.
+            This **includes** the left boundary point and **excludes** the right
+            boundary point. In higher dimensions; the number of points in each
+            dimension is the same.
         """
         super().__init__(
             num_spatial_dims,

From 01a155e5d18738dcdb1beebeba8a5520c5b033af Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 14:25:47 +0200
Subject: [PATCH 15/26] Improve and fix documentation

---
 exponax/nonlin_fun/_vorticity_convection.py | 76 ++++++++++++---------
 1 file changed, 44 insertions(+), 32 deletions(-)

diff --git a/exponax/nonlin_fun/_vorticity_convection.py b/exponax/nonlin_fun/_vorticity_convection.py
index 46231ab..558f56f 100644
--- a/exponax/nonlin_fun/_vorticity_convection.py
+++ b/exponax/nonlin_fun/_vorticity_convection.py
@@ -25,25 +25,35 @@ def __init__(
         streamfunction-vorticity formulation. In state space, it reads
 
         ```
-            𝒩(ω) = b ([1, -1]ᵀ ⊙ ∇(Δ⁻¹u)) ⋅ ∇u
+            𝒩(u) = - b ([1, -1]ᵀ ⊙ ∇(Δ⁻¹u)) ⋅ ∇u
         ```
 
         with `b` the convection scale, `⊙` the Hadamard product, `∇` the
         derivative operator, `Δ⁻¹` the inverse Laplacian, and `u` the vorticity.
 
+        The minus arises because `Exponax` follows the convention that all
+        nonlinear and linear differential operators are on the right-hand side
+        of the equation. Typically, the vorticity convection term is on the
+        left-hand side. Hence, the minus is required to move the term to the
+        right-hand side.
+
+        Since the inverse Laplacian is required, it internally performs a
+        Poisson solve which is straightforward in Fourier space.
+
         **Arguments:**
-            - `num_spatial_dims`: The number of spatial dimensions `d`.
-            - `num_points`: The number of points `N` used to discretize the
-                domain. This **includes** the left boundary point and **excludes**
-                the right boundary point. In higher dimensions; the number of
-                points in each dimension is the same.
-            - `convection_scale`: The scale `b` of the convection term. Defaults to
-                `1.0`.
-            - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)` that
-                represents the derivative operator in Fourier space.
-            - `dealiasing_fraction`: The fraction of the highest resolved modes that
-                are not aliased. Defaults to `2/3` which corresponds to Orszag's 2/3
-                rule.
+
+        - `num_spatial_dims`: The number of spatial dimensions `D`.
+        - `num_points`: The number of points `N` used to discretize the domain.
+            This **includes** the left boundary point and **excludes** the right
+            boundary point. In higher dimensions; the number of points in each
+            dimension is the same.
+        - `convection_scale`: The scale `b` of the convection term. Defaults
+            to `1.0`.
+        - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)`
+            that represents the derivative operator in Fourier space.
+        - `dealiasing_fraction`: The fraction of the highest resolved modes
+            that are not aliased. Defaults to `2/3` which corresponds to
+            Orszag's 2/3 rule.
         """
         if num_spatial_dims != 2:
             raise ValueError(f"Expected num_spatial_dims = 2, got {num_spatial_dims}.")
@@ -60,7 +70,9 @@ def __init__(
         laplacian = build_laplace_operator(derivative_operator, order=2)
 
         # Uses the UNCHANGED mean solution to the Poisson equation (hence, the
-        # mean of the "right-hand side" will be the mean of the solution)
+        # mean of the "right-hand side" will be the mean of the solution).
+        # However, this does not matter because we subsequently take the
+        # gradient which would annihilate any mean energy anyway.
         self.inv_laplacian = jnp.where(laplacian == 0, 1.0, 1 / laplacian)
 
     def __call__(
@@ -110,11 +122,12 @@ def __init__(
         In state space, it reads
 
         ```
-            𝒩(ω) = b ([1, -1]ᵀ ⊙ ∇(Δ⁻¹u)) ⋅ ∇u - f
+            𝒩(u) = - b ([1, -1]ᵀ ⊙ ∇(Δ⁻¹u)) ⋅ ∇u - f
         ```
 
         For details on the vorticity convective term, see
-        `VorticityConvection2d`. The forcing term has the form
+        `exponax.nonlin_fun.VorticityConvection2d`. The forcing term has the
+        form
 
         ```
             f = -k (2π/L) γ cos(k (2π/L) x₁)
@@ -126,22 +139,21 @@ def __init__(
         the vorticity is derived via the curl).
 
         **Arguments:**
-            - `num_spatial_dims`: The number of spatial dimensions `d`.
-            - `num_points`: The number of points `N` used to discretize the
-                domain. This **includes** the left boundary point and
-                **excludes** the right boundary point. In higher dimensions; the
-                number of points in each dimension is the same.
-            - `convection_scale`: The scale `b` of the convection term. Defaults
-                to `1.0`.
-            - `injection_mode`: The wavenumber `k` at which energy is injected.
-                Defaults to `4`.
-            - `injection_scale`: The intensity `γ` of the injection term.
-                Defaults to `1.0`.
-            - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)`
-                that represents the derivative operator in Fourier space.
-            - `dealiasing_fraction`: The fraction of the highest resolved modes
-                that are not aliased. Defaults to `2/3` which corresponds to
-                Orszag's 2/3 rule.
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `num_points`: The number of points `N` used to discretize the
+            domain. This **includes** the left boundary point and **excludes**
+            the right boundary point. In higher dimensions; the number of points
+            in each dimension is the same.
+        - `convection_scale`: The scale `b` of the convection term. Defaults
+            to `1.0`.
+        - `injection_mode`: The wavenumber `k` at which energy is injected.
+        - `injection_scale`: The intensity `γ` of the injection term.
+        - `derivative_operator`: A complex array of shape `(d, ..., N//2+1)`
+            that represents the derivative operator in Fourier space.
+        - `dealiasing_fraction`: The fraction of the highest resolved modes
+            that are not aliased. Defaults to `2/3` which corresponds to
+            Orszag's 2/3 rule.
         """
         super().__init__(
             num_spatial_dims,

From 872426a6a8f2e80cb2fae57c239abeddefd411c1 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 14:29:47 +0200
Subject: [PATCH 16/26] Fix broken link

---
 exponax/stepper/_kuramoto_sivashinsky.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/exponax/stepper/_kuramoto_sivashinsky.py b/exponax/stepper/_kuramoto_sivashinsky.py
index 88d7d4c..9799f29 100644
--- a/exponax/stepper/_kuramoto_sivashinsky.py
+++ b/exponax/stepper/_kuramoto_sivashinsky.py
@@ -31,7 +31,7 @@ def __init__(
         equation on periodic boundary conditions. Uses the **combustion format**
         (or non-conservative format). Most deep learning papers in 1d considered
         the conservative format available as
-        [`KuramotoSivashinskyConservative`](exponax/stepper/KuramotoSivashinskyConservative).
+        [`exponax.stepper.KuramotoSivashinskyConservative`][].
 
         In 1d, the KS equation is given by
 

From 79250f8dd55a7af2cdc21177a8395e238cba9b16 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Mon, 2 Sep 2024 14:47:22 +0200
Subject: [PATCH 17/26] Adapt to Equinox style

---
 exponax/ic/_base_ic.py                  | 25 ++++++----
 exponax/ic/_clamping.py                 |  8 +++-
 exponax/ic/_diffused_noise.py           | 24 +++++-----
 exponax/ic/_discontinuities.py          | 43 +++++++++--------
 exponax/ic/_gaussian_blob.py            | 38 +++++++++------
 exponax/ic/_gaussian_random_field.py    | 23 +++++-----
 exponax/ic/_multi_channel.py            | 37 +++++++++++++--
 exponax/ic/_scaled.py                   |  5 +-
 exponax/ic/_sine_waves_1d.py            | 61 ++++++++++++++-----------
 exponax/ic/_truncated_fourier_series.py | 35 ++++++++------
 10 files changed, 181 insertions(+), 118 deletions(-)

diff --git a/exponax/ic/_base_ic.py b/exponax/ic/_base_ic.py
index e0461a9..926f38d 100644
--- a/exponax/ic/_base_ic.py
+++ b/exponax/ic/_base_ic.py
@@ -13,10 +13,12 @@ def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]:
         Evaluate the initial condition.
 
         **Arguments**:
-            - `x`: The grid points.
+
+        - `x`: The grid points.
 
         **Returns**:
-            - `u`: The initial condition evaluated at the grid points.
+
+        - `u`: The initial condition evaluated at the grid points.
         """
         pass
 
@@ -30,11 +32,13 @@ def gen_ic_fun(self, *, key: PRNGKeyArray) -> BaseIC:
         Generate an initial condition function.
 
         **Arguments**:
-            - `key`: A jax random key.
+
+        - `key`: A jax random key.
 
         **Returns**:
-            - `ic`: An initial condition function that can be evaluated at
-                degree of freedom locations.
+
+        - `ic`: An initial condition function that can be evaluated at
+            degree of freedom locations.
         """
         raise NotImplementedError(
             "This random ic generator cannot represent its initial condition as a function. Directly evaluate it."
@@ -47,15 +51,16 @@ def __call__(
         key: PRNGKeyArray,
     ) -> Float[Array, "1 ... N"]:
         """
-        Generate a random initial condition.
+        Generate a random initial condition on a grid with `num_points` points.
 
         **Arguments**:
-            - `num_points`: The number of grid points in each dimension.
-            - `key`: A jax random key.
-            - `indexing`: The indexing convention for the grid.
+
+        - `num_points`: The number of grid points in each dimension.
+        - `key`: A jax random key.
 
         **Returns**:
-            - `u`: The initial condition evaluated at the grid points.
+
+        - `u`: The initial condition evaluated at the grid points.
         """
         ic_fun = self.gen_ic_fun(key=key)
         grid = make_grid(
diff --git a/exponax/ic/_clamping.py b/exponax/ic/_clamping.py
index 236ec6e..41beb9e 100644
--- a/exponax/ic/_clamping.py
+++ b/exponax/ic/_clamping.py
@@ -15,9 +15,13 @@ def __init__(
         A generator based on another generator that clamps the output to a given
         range.
 
+        Some dynamics (like the Fisher-KPP equation) require such initial
+        conditions.
+
         **Arguments**:
-            - `ic_gen`: The initial condition generator to clamp.
-            - `limits`: The lower and upper limits of the clamping range.
+
+        - `ic_gen`: The initial condition generator to clamp.
+        - `limits`: The lower and upper limits of the clamping range.
         """
         self.ic_gen = ic_gen
         self.limits = limits
diff --git a/exponax/ic/_diffused_noise.py b/exponax/ic/_diffused_noise.py
index 7e1c466..710261b 100644
--- a/exponax/ic/_diffused_noise.py
+++ b/exponax/ic/_diffused_noise.py
@@ -31,20 +31,20 @@ def __init__(
 
         The original noise is drawn in state space with a uniform normal
         distribution. After the application of the diffusion operator, the
-        spectrum decays exponentially with a rate of `intensity`.
+        spectrum decays exponentially quadratic with a rate of `intensity`.
 
         **Arguments**:
-            - `num_spatial_dims`: The number of spatial dimensions `d`.
-            - `domain_extent`: The extent of the domain. Defaults to `1.0`. This
-                indirectly affects the intensity of the noise. It is best to
-                keep it at `1.0` and just adjust the `intensity` instead.
-            - `intensity`: The intensity of the noise. Defaults to `0.001`.
-            - `zero_mean`: Whether to zero the mean of the noise. Defaults to
-                `True`.
-            - `std_one`: Whether to normalize the noise to have a standard
-                deviation of one. Defaults to `False`.
-            - `max_one`: Whether to normalize the noise to the maximum absolute
-                value of one. Defaults to `False`.
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `domain_extent`: The extent of the domain. Defaults to `1.0`. This
+            indirectly affects the intensity of the noise. It is best to keep it
+            at `1.0` and just adjust the `intensity` instead.
+        - `intensity`: The intensity of the noise. Defaults to `0.001`.
+        - `zero_mean`: Whether to zero the mean of the noise.
+        - `std_one`: Whether to normalize the noise to have a standard
+            deviation of one. Defaults to `False`.
+        - `max_one`: Whether to normalize the noise to the maximum absolute
+            value of one. Defaults to `False`.
         """
         if not zero_mean and std_one:
             raise ValueError("Cannot have `zero_mean=False` and `std_one=True`.")
diff --git a/exponax/ic/_discontinuities.py b/exponax/ic/_discontinuities.py
index 9af30e6..cc75bd1 100644
--- a/exponax/ic/_discontinuities.py
+++ b/exponax/ic/_discontinuities.py
@@ -42,14 +42,15 @@ def __init__(
         A state described by a collection of discontinuities.
 
         **Arguments**:
-            - `discontinuity_list`: A tuple of discontinuities.
-            - `zero_mean`: Whether the state should have zero mean.
-            - `std_one`: Whether to normalize the state to have a standard
-                deviation of one. Defaults to `False`. Only works if the offset
-                is zero.
-            - `max_one`: Whether to normalize the state to have the maximum
-                absolute value of one. Defaults to `False`. Only one of
-                `std_one` and `max_one` can be `True`.
+
+        - `discontinuity_list`: A tuple of discontinuities.
+        - `zero_mean`: Whether the state should have zero mean.
+        - `std_one`: Whether to normalize the state to have a standard
+            deviation of one. Defaults to `False`. Only works if the offset is
+            zero.
+        - `max_one`: Whether to normalize the state to have the maximum
+            absolute value of one. Defaults to `False`. Only one of `std_one`
+            and `max_one` can be `True`.
         """
         if not zero_mean and std_one:
             raise ValueError("Cannot have `zero_mean=False` and `std_one=True`.")
@@ -102,17 +103,18 @@ def __init__(
         discontinuities.
 
         **Arguments**:
-            - `num_spatial_dims`: The number of spatial dimensions.
-            - `domain_extent`: The extent of the domain in each spatial direction.
-            - `num_discontinuities`: The number of discontinuities.
-            - `value_range`: The range of values for the discontinuities.
-            - `zero_mean`: Whether the state should have zero mean.
-            - `std_one`: Whether to normalize the state to have a standard
-                deviation of one. Defaults to `False`. Only works if the offset
-                is zero.
-            - `max_one`: Whether to normalize the state to have the maximum
-                absolute value of one. Defaults to `False`. Only one of
-                `std_one` and `max_one` can be `True`.
+
+        - `num_spatial_dims`: The number of spatial dimensions.
+        - `domain_extent`: The extent of the domain in each spatial direction.
+        - `num_discontinuities`: The number of discontinuities.
+        - `value_range`: The range of values for the discontinuities.
+        - `zero_mean`: Whether the state should have zero mean.
+        - `std_one`: Whether to normalize the state to have a standard
+            deviation of one. Defaults to `False`. Only works if the offset is
+            zero.
+        - `max_one`: Whether to normalize the state to have the maximum
+            absolute value of one. Defaults to `False`. Only one of `std_one`
+            and `max_one` can be `True`.
         """
         if not zero_mean and std_one:
             raise ValueError("Cannot have `zero_mean=False` and `std_one=True`.")
@@ -129,6 +131,9 @@ def __init__(
         self.max_one = max_one
 
     def gen_one_ic_fn(self, *, key: PRNGKeyArray) -> Discontinuity:
+        """
+        Generates a single discontinuity.
+        """
         lower_limits = []
         upper_limits = []
         for i in range(self.num_spatial_dims):
diff --git a/exponax/ic/_gaussian_blob.py b/exponax/ic/_gaussian_blob.py
index e82b703..65911c4 100644
--- a/exponax/ic/_gaussian_blob.py
+++ b/exponax/ic/_gaussian_blob.py
@@ -24,12 +24,15 @@ def __init__(
         one_complement: bool = False,
     ):
         """
-        A state described by a Gaussian blob.
+        A state described by a Gaussian blob. Note that the produced function is
+        not perfectly periodic, especially if the blobs are close to the domain
+        boundaries.
 
         **Arguments**:
-            - `position`: The position of the blob.
-            - `covariance`: The covariance matrix of the blob.
-            - `one_complement`: Whether to return one minus the Gaussian blob.
+
+        - `position`: The position of the blob.
+        - `covariance`: The covariance matrix of the blob.
+        - `one_complement`: Whether to return one minus the Gaussian blob.
         """
         self.position = position
         self.covariance = covariance
@@ -78,7 +81,8 @@ def __init__(
         A state described by a collection of Gaussian blobs.
 
         **Arguments**:
-            - `blob_list`: A tuple of Gaussian blobs.
+
+        - `blob_list`: A tuple of Gaussian blobs.
         """
         self.blob_list = blob_list
 
@@ -111,16 +115,17 @@ def __init__(
         A random Gaussian blob initial condition generator.
 
         **Arguments**:
-            - `num_spatial_dims`: The number of spatial dimensions.
-            - `domain_extent`: The extent of the domain.
-            - `num_blobs`: The number of blobs.
-            - `position_range`: The range of the position of the blobs. This
-                will be scaled by the domain extent. Hence, this acts as if the
-                domain_extent was 1
-            - `variance_range`: The range of the variance of the blobs. This will
-                be scaled by the domain extent. Hence, this acts as if the
-                domain_extent was 1
-            - `one_complement`: Whether to return one minus the Gaussian blob.
+
+        - `num_spatial_dims`: The number of spatial dimensions.
+        - `domain_extent`: The extent of the domain.
+        - `num_blobs`: The number of blobs.
+        - `position_range`: The range of the position of the blobs. This
+            will be scaled by the domain extent. Hence, this acts as if the
+            domain_extent was 1
+        - `variance_range`: The range of the variance of the blobs. This will
+            be scaled by the domain extent. Hence, this acts as if the
+            domain_extent was 1
+        - `one_complement`: Whether to return one minus the Gaussian blob.
         """
         self.num_spatial_dims = num_spatial_dims
         self.domain_extent = domain_extent
@@ -130,6 +135,9 @@ def __init__(
         self.one_complement = one_complement
 
     def gen_blob(self, *, key) -> GaussianBlob:
+        """
+        Generates a single Gaussian blob.
+        """
         position_key, variance_key = jr.split(key)
 
         position = jr.uniform(
diff --git a/exponax/ic/_gaussian_random_field.py b/exponax/ic/_gaussian_random_field.py
index 6fd1435..f3f8ba6 100644
--- a/exponax/ic/_gaussian_random_field.py
+++ b/exponax/ic/_gaussian_random_field.py
@@ -31,19 +31,20 @@ def __init__(
     ):
         """
         Random generator for initial states following a power-law spectrum in
-        Fourier space.
+        Fourier space, i.e., it decays polynomially with the wavenumber.
 
         **Arguments:**
-            - `num_spatial_dims`: The number of spatial dimensions.
-            - `domain_extent`: The extent of the domain in each spatial direction.
-            - `powerlaw_exponent`: The exponent of the power-law spectrum.
-            - `zero_mean`: Whether the field should have zero mean.
-            - `std_one`: Whether to normalize the state to have a standard
-                deviation of one. Defaults to `False`. Only works if the offset
-                is zero.
-            - `max_one`: Whether to normalize the state to have the maximum
-                absolute value of one. Defaults to `False`. Only one of
-                `std_one` and `max_one` can be `True`.
+
+        - `num_spatial_dims`: The number of spatial dimensions.
+        - `domain_extent`: The extent of the domain in each spatial direction.
+        - `powerlaw_exponent`: The exponent of the power-law spectrum.
+        - `zero_mean`: Whether the field should have zero mean.
+        - `std_one`: Whether to normalize the state to have a standard
+            deviation of one. Defaults to `False`. Only works if the offset is
+            zero.
+        - `max_one`: Whether to normalize the state to have the maximum
+            absolute value of one. Defaults to `False`. Only one of `std_one`
+            and `max_one` can be `True`.
         """
         if not zero_mean and std_one:
             raise ValueError("Cannot have `zero_mean=False` and `std_one=True`.")
diff --git a/exponax/ic/_multi_channel.py b/exponax/ic/_multi_channel.py
index 09c4a97..8084dea 100644
--- a/exponax/ic/_multi_channel.py
+++ b/exponax/ic/_multi_channel.py
@@ -14,7 +14,8 @@ def __init__(self, initial_conditions: tuple[BaseIC, ...]):
         A multi-channel initial condition.
 
         **Arguments**:
-            - `initial_conditions`: A tuple of initial conditions.
+
+        - `initial_conditions`: A tuple of initial conditions.
         """
         self.initial_conditions = initial_conditions
 
@@ -23,10 +24,12 @@ def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "C ... N"]:
         Evaluate the initial condition.
 
         **Arguments**:
-            - `x`: The grid points.
+
+        - `x`: The grid points.
 
         **Returns**:
-            - `u`: The initial condition evaluated at the grid points.
+
+        - `u`: The initial condition evaluated at the grid points.
         """
         return jnp.concatenate([ic(x) for ic in self.initial_conditions], axis=0)
 
@@ -36,10 +39,34 @@ class RandomMultiChannelICGenerator(eqx.Module):
 
     def __init__(self, ic_generators: tuple[BaseRandomICGenerator, ...]):
         """
-        A multi-channel random initial condition generator.
+        A multi-channel random initial condition generator. Use this for
+        problems with multiple channels, like Burgers in higher dimensions or
+        the Gray-Scott dynamics.
 
         **Arguments**:
-            - `ic_generators`: A tuple of initial condition generators.
+
+        - `ic_generators`: A tuple of initial condition generators.
+
+        !!! example
+            Below is an example for generating a random multi-channel initial
+            condition for the three-dimensional Burgers equation which has three
+            channels. For simplicity, we will use the same IC generator for each
+            channel.
+
+            ```python
+            import jax
+            import exponax as ex
+
+            single_channel_ic_gen = ex.ic.RandomTruncatedFourierSeries(
+                3,
+                max_one=True,
+            )
+            multi_channel_ic_gen = ex.ic.RandomMultiChannelICGenerator(
+                [single_channel_ic_gen,] * 3
+            )
+
+            ic = multi_channel_ic_gen(100, key=jax.random.PRNGKey(0))
+            ```
         """
         self.ic_generators = ic_generators
 
diff --git a/exponax/ic/_scaled.py b/exponax/ic/_scaled.py
index 6ecdbec..fde2adf 100644
--- a/exponax/ic/_scaled.py
+++ b/exponax/ic/_scaled.py
@@ -23,8 +23,9 @@ def __init__(self, ic_gen: BaseRandomICGenerator, scale: float):
         `max_one=True` or `std_one=True`.
 
         **Arguments**:
-            - `ic_gen`: The initial condition generator.
-            - `scale`: The scaling factor.
+
+        - `ic_gen`: The initial condition generator.
+        - `scale`: The scaling factor.
         """
         self.ic_gen = ic_gen
         self.scale = scale
diff --git a/exponax/ic/_sine_waves_1d.py b/exponax/ic/_sine_waves_1d.py
index 6cd8f14..de96370 100644
--- a/exponax/ic/_sine_waves_1d.py
+++ b/exponax/ic/_sine_waves_1d.py
@@ -29,17 +29,18 @@ def __init__(
         A state described by a collection of sine waves. Only works in 1d.
 
         **Arguments**:
-            - `domain_extent`: The extent of the domain.
-            - `amplitudes`: A tuple of amplitudes.
-            - `wavenumbers`: A tuple of wavenumbers.
-            - `phases`: A tuple of phases.
-            - `offset`: A constant offset.
-            - `std_one`: Whether to normalize the state to have a standard
-                deviation of one. Defaults to `False`. Only works if the offset
-                is zero.
-            - `max_one`: Whether to normalize the state to have the maximum
-                absolute value of one. Defaults to `False`. Only one of
-                `std_one` and `max_one` can be `True`.
+
+        - `domain_extent`: The extent of the domain.
+        - `amplitudes`: A tuple of amplitudes.
+        - `wavenumbers`: A tuple of wavenumbers.
+        - `phases`: A tuple of phases.
+        - `offset`: A constant offset.
+        - `std_one`: Whether to normalize the state to have a standard
+            deviation of one. Defaults to `False`. Only works if the offset
+            is zero.
+        - `max_one`: Whether to normalize the state to have the maximum
+            absolute value of one. Defaults to `False`. Only one of
+            `std_one` and `max_one` can be `True`.
         """
         if offset != 0.0 and std_one:
             raise ValueError("Cannot have non-zero offset and `std_one=True`.")
@@ -103,23 +104,29 @@ def __init__(
         Random generator for initial states described by a collection of sine
         waves. Only works in 1d.
 
+        This is a simplified version of the `RandomTruncatedFourierSeries`
+        generator that works in arbitrary dimensions. However, only this
+        generator can produce a functional representation of the initial
+        condition.
+
         **Arguments**:
-            - `num_spatial_dims`: The number of spatial dimensions.
-            - `domain_extent`: The extent of the domain.
-            - `cutoff`: The cutoff of the wavenumbers. This limits the
-                "complexity" of the initial state. Note that some dynamics are
-                very sensitive to high-frequency information.
-            - `amplitude_range`: The range of the amplitudes. Defaults to
-              `(-1.0, 1.0)`.
-            - `phase_range`: The range of the phases. Defaults to `(0.0, 2π)`.
-            - `offset_range`: The range of the offsets. Defaults to `(0.0,
-                0.0)`, meaning **zero-mean** by default.
-            - `std_one`: Whether to normalize the state to have a standard
-                deviation of one. Defaults to `False`. Only works if the offset
-                is zero.
-            - `max_one`: Whether to normalize the state to have the maximum
-                absolute value of one. Defaults to `False`. Only one of
-                `std_one` and `max_one` can be `True`.
+
+        - `num_spatial_dims`: The number of spatial dimensions.
+        - `domain_extent`: The extent of the domain.
+        - `cutoff`: The cutoff of the wavenumbers. This limits the
+            "complexity" of the initial state. Note that some dynamics are very
+            sensitive to high-frequency information.
+        - `amplitude_range`: The range of the amplitudes. Defaults to
+            `(-1.0, 1.0)`.
+        - `phase_range`: The range of the phases. Defaults to `(0.0, 2π)`.
+        - `offset_range`: The range of the offsets. Defaults to `(0.0,
+            0.0)`, meaning **zero-mean** by default.
+        - `std_one`: Whether to normalize the state to have a standard
+            deviation of one. Defaults to `False`. Only works if the offset is
+            zero.
+        - `max_one`: Whether to normalize the state to have the maximum
+            absolute value of one. Defaults to `False`. Only one of `std_one`
+            and `max_one` can be `True`.
         """
         if num_spatial_dims != 1:
             raise ValueError("RandomSineWaves1d only works in 1d.")
diff --git a/exponax/ic/_truncated_fourier_series.py b/exponax/ic/_truncated_fourier_series.py
index c200a2f..ee4aad0 100644
--- a/exponax/ic/_truncated_fourier_series.py
+++ b/exponax/ic/_truncated_fourier_series.py
@@ -59,22 +59,27 @@ def __init__(
         in the range `amplitude_range`. Angles (=angular offsets) are drawn
         according to a uniform distribution in the range `angle_range`.
 
+        See also `exponax.ic.RandomSineWaves1d` for a simplified version that
+        only works in 1d but can also produce a functional representation of the
+        initial state.
+
         **Arguments**:
-            - `num_spatial_dims`: The number of spatial dimensions `d`.
-            - `cutoff`: The cutoff of the wavenumbers. This limits the
-                "complexity" of the initial state. Note that some dynamics are
-                very sensitive to high-frequency information.
-            - `amplitude_range`: The range of the amplitudes. Defaults to
-              `(-1.0, 1.0)`.
-            - `angle_range`: The range of the angles. Defaults to `(0.0, 2π)`.
-            - `offset_range`: The range of the offsets. Defaults to `(0.0,
-                0.0)`, meaning **zero-mean** by default.
-            - `std_one`: Whether to normalize the state to have a standard
-                deviation of one. Defaults to `False`. Only works if the offset
-                is zero.
-            - `max_one`: Whether to normalize the state to have the maximum
-                absolute value of one. Defaults to `False`. Only one of
-                `std_one` and `max_one` can be `True`.
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `cutoff`: The cutoff of the wavenumbers. This limits the
+            "complexity" of the initial state. Note that some dynamics are very
+            sensitive to high-frequency information.
+        - `amplitude_range`: The range of the amplitudes. Defaults to
+            `(-1.0, 1.0)`.
+        - `angle_range`: The range of the angles. Defaults to `(0.0, 2π)`.
+        - `offset_range`: The range of the offsets. Defaults to `(0.0,
+            0.0)`, meaning **zero-mean** by default.
+        - `std_one`: Whether to normalize the state to have a standard
+            deviation of one. Defaults to `False`. Only works if the offset is
+            zero.
+        - `max_one`: Whether to normalize the state to have the maximum
+            absolute value of one. Defaults to `False`. Only one of `std_one`
+            and `max_one` can be `True`.
         """
         if offset_range == (0.0, 0.0) and std_one:
             raise ValueError("Cannot have non-zero offset and `std_one=True`.")

From bfaefa85d6a56cc895d5f71f3d15dc01a216836a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Felix=20K=C3=B6hler?=
 <27728103+Ceyron@users.noreply.github.com>
Date: Mon, 2 Sep 2024 15:30:41 +0200
Subject: [PATCH 18/26] Improve documentation

---
 exponax/stepper/generic/_convection.py | 196 ++++++++++++++++++++-----
 1 file changed, 156 insertions(+), 40 deletions(-)

diff --git a/exponax/stepper/generic/_convection.py b/exponax/stepper/generic/_convection.py
index c5b5aca..9754b45 100644
--- a/exponax/stepper/generic/_convection.py
+++ b/exponax/stepper/generic/_convection.py
@@ -56,47 +56,48 @@ def __init__(
         Alternatively, with `single_channel=True`, the number of channels can be
         kept to constant 1 no matter the number of spatial dimensions.
 
-        Depending on the collection of linear coefficients can be represented,
-        for example:
+        Depending on the collection of linear coefficients a range of dynamics
+        can be represented, for example:
             - Burgers equation with `a = (0, 0, 0.01)` with `len(a) = 3`
             - KdV equation with `a = (0, 0, 0, 0.01)` with `len(a) = 4`
 
         **Arguments:**
-            - `num_spatial_dims`: The number of spatial dimensions `d`.
-            - `domain_extent`: The size of the domain `L`; in higher dimensions
-                the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
-            - `num_points`: The number of points `N` used to discretize the
-                domain. This **includes** the left boundary point and
-                **excludes** the right boundary point. In higher dimensions; the
-                number of points in each dimension is the same. Hence, the total
-                number of degrees of freedom is `Nᵈ`.
-            - `dt`: The timestep size `Δt` between two consecutive states.
-            - `coefficients` (keyword-only): The list of coefficients `a_j`
-                corresponding to the derivatives. The length of this tuple
-                represents the highest occuring derivative. The default value
-                `(0.0, 0.0, 0.01)` corresponds to the Burgers equation (because
-                of the diffusion)
-            - `convection_scale` (keyword-only): The scale `b₁` of the
-                convection term. Default is `1.0`.
-            - `single_channel`: Whether to use the single channel mode in higher
-                dimensions. In this case the the convection is `b₁ (∇ ⋅ 1)(u²)`.
-                In this case, the state always has a single channel, no matter
-                the spatial dimension. Default: False.
-            - `order`: The order of the Exponential Time Differencing Runge
-                Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0`
-                only solves the linear part of the equation. Use higher values
-                for higher accuracy and stability. The default choice of `2` is
-                a good compromise for single precision floats.
-            - `dealiasing_fraction`: The fraction of the wavenumbers to keep
-                before evaluating the nonlinearity. The default 2/3 corresponds
-                to Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2.
-                Default: 2/3.
-            - `num_circle_points`: How many points to use in the complex contour
-                integral method to compute the coefficients of the exponential
-                time differencing Runge Kutta method. Default: 16.
-            - `circle_radius`: The radius of the contour used to compute the
-                coefficients of the exponential time differencing Runge Kutta
-                method. Default: 1.0.
+
+        - `num_spatial_dims`: The number of spatial dimensions `D`.
+        - `domain_extent`: The size of the domain `L`; in higher dimensions
+            the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
+        - `num_points`: The number of points `N` used to discretize the
+            domain. This **includes** the left boundary point and **excludes**
+            the right boundary point. In higher dimensions; the number of points
+            in each dimension is the same. Hence, the total number of degrees of
+            freedom is `Nᵈ`.
+        - `dt`: The timestep size `Δt` between two consecutive states.
+        - `coefficients` (keyword-only): The list of coefficients `a_j`
+            corresponding to the derivatives. The length of this tuple
+            represents the highest occuring derivative. The default value `(0.0,
+            0.0, 0.01)` corresponds to the Burgers equation (because of the
+            diffusion)
+        - `convection_scale` (keyword-only): The scale `b₁` of the
+            convection term. Default is `1.0`.
+        - `single_channel`: Whether to use the single channel mode in higher
+            dimensions. In this case the the convection is `b₁ (∇ ⋅ 1)(u²)`. In
+            this case, the state always has a single channel, no matter the
+            spatial dimension. Default: False.
+        - `order`: The order of the Exponential Time Differencing Runge
+            Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
+            solves the linear part of the equation. Use higher values for higher
+            accuracy and stability. The default choice of `2` is a good
+            compromise for single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep
+            before evaluating the nonlinearity. The default 2/3 corresponds to
+            Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default:
+            2/3.
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method. Default: 16.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta
+            method. Default: 1.0.
         """
         self.coefficients = coefficients
         self.convection_scale = convection_scale
@@ -166,12 +167,60 @@ def __init__(
         circle_radius: float = 1.0,
     ):
         """
-        By default: Behaves like a Burgers with
+        Time stepper for the **normalized** d-dimensional (`d ∈ {1, 2, 3}`)
+        semi-linear PDEs consisting of a convection nonlinearity and an
+        arbitrary combination of (isotropic) linear derivatives. Uses a
+        normalized interface, i.e., the domain is scaled to `Ω = (0, 1)ᵈ` and
+        time step size is `Δt = 1.0`.
+
+        See `exponax.stepper.generic.GeneralConvectionStepper` for more details
+        on the functional form of the PDE.
 
-        ``` Burgers(
+        In the default configuration, the number of channel grows with the
+        number of spatial dimensions. Setting the flag `single_channel=True`
+        activates a single-channel hack.
+
+        Under the default settings, it behaves like the Burgers equation under
+        the following settings ```python exponax.stepper.Burgers(
             D=D, L=1, N=N, dt=0.1, diffusivity=0.01,
         )
         ```
+
+        **Arguments:**
+
+        - `num_spatial_dims`: The number of spatial dimensions `D`.
+        - `num_points`: The number of points `N` used to discretize the domain.
+            This **includes** the left boundary point and **excludes** the right
+            boundary point. In higher dimensions; the number of points in each
+            dimension is the same. Hence, the total number of degrees of freedom
+            is `Nᵈ`.
+        - `normalized_coefficients`: The list of coefficients
+            `α_j` corresponding to the derivatives. The length of this tuple
+            represents the highest occuring derivative. The default value `(0.0,
+            0.0, 0.01)` corresponds to the Burgers equation (because of the
+            diffusion contribution). Note that these coefficients are normalized
+            on the unit domain and unit time step size.
+        - `normalized_convection_scale`: The scale `β` of the convection term.
+            Default is `1.0`.
+        - `single_channel`: Whether to use the single channel mode in higher
+            dimensions. In this case the the convection is `β (∇ ⋅ 1)(u²)`. In
+            this case, the state always has a single channel, no matter the
+            spatial dimension. Default: False.
+        - `order`: The order of the Exponential Time Differencing Runge
+            Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
+            solves the linear part of the equation. Use higher values for higher
+            accuracy and stability. The default choice of `2` is a good
+            compromise for single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep
+            before evaluating the nonlinearity. The default 2/3 corresponds to
+            Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default:
+            2/3.
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method. Default: 16.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta
+            method. Default: 1.0.
         """
         self.normalized_coefficients = normalized_coefficients
         self.normalized_convection_scale = normalized_convection_scale
@@ -209,8 +258,75 @@ def __init__(
         circle_radius: float = 1.0,
     ):
         """
-        By default: Behaves like a Burgers
+        Timestepper for the **difficulty-based** d-dimensional (`d ∈ {1, 2, 3}`)
+        semi-linear PDEs consisting of a convection nonlinearity and an
+        arbitrary combination of (isotropic) linear derivatives. Uses a
+        difficulty-based interface where the "intensity" of the dynamics reduces
+        with increasing resolution. This is intended such that emulator learning
+        problems on two resolutions are comparibly difficult.
+
+        Different to `exponax.stepper.generic.NormalizedConvectionStepper`, the
+        dynamics are defined by difficulties. The difficulties are a different
+        combination of normalized dynamics, `num_spatial_dims`, and
+        `num_points`.
+
+            γᵢ = αᵢ Nⁱ 2ⁱ⁻¹ d
+
+        with `d` the number of spatial dimensions, `N` the number of points, and
+        `αᵢ` the normalized coefficient.
+
+        For the nonlinear convection scale it is defined as
+
+            δ = β * M * N² * D
+
+        with `M` the maximum absolute value of the input state.
+
+        This interface is more natural because the difficulties for all orders
+        (given by `i`) are around 1.0. Additionally, they relate to stability
+        condition of explicit Finite Difference schemes for the particular
+        equations. For example, for advection (`i=1`), the absolute of the
+        difficulty is the Courant-Friedrichs-Lewy (CFL) number.
+
+        Under the default settings, this timestepper represents the Burgers
+        equation.
+
+        **Arguments:**
 
+        - `num_spatial_dims`: The number of spatial dimensions `D`.
+        - `num_points`: The number of points `N` used to discretize the domain.
+            This **includes** the left boundary point and **excludes** the right
+            boundary point. In higher dimensions; the number of points in each
+            dimension is the same. Hence, the total number of degrees of freedom
+            is `Nᵈ`.
+        - `linear_difficulties`: The list of difficulties `γᵢ` corresponding to
+            the derivatives. The length of this tuple represents the highest
+            occuring derivative. The default value `(0.0, 0.0, 4.5)` corresponds
+            to the Burgers equation. Note that these coefficients are normalized
+            on the unit domain and unit time step size.
+        - `convection_difficulty`: The difficulty `δ` of the convection term.
+            Default is `5.0`.
+        - `single_channel`: Whether to use the single channel mode in higher
+            dimensions. In this case the the convection is `δ (∇ ⋅ 1)(u²)`. In
+            this case, the state always has a single channel, no matter the
+            spatial dimension. Default: False.
+        - `maximum_absolute`: The maximum absolute value of the state. This is
+            used to extract the normalized dynamics from the convection
+            difficulty.
+        - `order`: The order of the Exponential Time Differencing Runge
+            Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
+            solves the linear part of the equation. Use higher values for higher
+            accuracy and stability. The default choice of `2` is a good
+            compromise for single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep
+            before evaluating the nonlinearity. The default 2/3 corresponds to
+            Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default:
+            2/3.
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method. Default: 16.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta
+            method. Default: 1.0.
         """
         self.linear_difficulties = linear_difficulties
         self.convection_difficulty = convection_difficulty

From b36adb21a4dd58fa6b6790434770ff7d5a03bd16 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Tue, 3 Sep 2024 08:45:12 +0200
Subject: [PATCH 19/26] Add clarification

---
 exponax/stepper/generic/_convection.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/exponax/stepper/generic/_convection.py b/exponax/stepper/generic/_convection.py
index 9754b45..4540246 100644
--- a/exponax/stepper/generic/_convection.py
+++ b/exponax/stepper/generic/_convection.py
@@ -275,17 +275,19 @@ def __init__(
         with `d` the number of spatial dimensions, `N` the number of points, and
         `αᵢ` the normalized coefficient.
 
-        For the nonlinear convection scale it is defined as
+        The difficulty of the nonlinear convection scale is defined by
 
             δ = β * M * N² * D
 
-        with `M` the maximum absolute value of the input state.
+        with `M` the maximum absolute value of the input state (typically `1.0`
+        if one uses the `exponax.ic` random generators with the `max_one=True`
+        argument).
 
-        This interface is more natural because the difficulties for all orders
-        (given by `i`) are around 1.0. Additionally, they relate to stability
-        condition of explicit Finite Difference schemes for the particular
-        equations. For example, for advection (`i=1`), the absolute of the
-        difficulty is the Courant-Friedrichs-Lewy (CFL) number.
+        This interface is more natural than the normalized interface because the
+        difficulties for all orders (given by `i`) are around 1.0. Additionally,
+        they relate to stability condition of explicit Finite Difference schemes
+        for the particular equations. For example, for advection (`i=1`), the
+        absolute of the difficulty is the Courant-Friedrichs-Lewy (CFL) number.
 
         Under the default settings, this timestepper represents the Burgers
         equation.

From bc25c685e0fe63cb2998d97427a06eff62d56b82 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Tue, 3 Sep 2024 08:45:58 +0200
Subject: [PATCH 20/26] better wording

---
 exponax/_base_stepper.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/exponax/_base_stepper.py b/exponax/_base_stepper.py
index 80a8875..60d1de2 100644
--- a/exponax/_base_stepper.py
+++ b/exponax/_base_stepper.py
@@ -253,8 +253,8 @@ def __call__(
         - `u_next`: The state vector after one step, shape `(C, ..., N,)`.
 
         !!! tip
-            Use this call method together with `exponax.rollout` to produce
-            temporal trajectories by efficiently autogressive rollout.
+            Use this call method together with `exponax.rollout` to efficiently
+            produce temporal trajectories.
 
         !!! info
             For batched operation, use `jax.vmap` on this function.

From a8de7f1fcb1fe7f155972d74dcae1be868ed2d90 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Tue, 3 Sep 2024 08:59:18 +0200
Subject: [PATCH 21/26] Fix convection difficulty computation

---
 exponax/stepper/generic/_convection.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/exponax/stepper/generic/_convection.py b/exponax/stepper/generic/_convection.py
index 4540246..00df9c2 100644
--- a/exponax/stepper/generic/_convection.py
+++ b/exponax/stepper/generic/_convection.py
@@ -181,7 +181,11 @@ def __init__(
         activates a single-channel hack.
 
         Under the default settings, it behaves like the Burgers equation under
-        the following settings ```python exponax.stepper.Burgers(
+        the following settings
+
+        ```python
+
+        exponax.stepper.Burgers(
             D=D, L=1, N=N, dt=0.1, diffusivity=0.01,
         )
         ```
@@ -277,7 +281,7 @@ def __init__(
 
         The difficulty of the nonlinear convection scale is defined by
 
-            δ = β * M * N² * D
+            δ₁ = β₁ * M * N * D
 
         with `M` the maximum absolute value of the input state (typically `1.0`
         if one uses the `exponax.ic` random generators with the `max_one=True`

From aa09c3286a491ea917b54b500e723f50b360d7ee Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Tue, 3 Sep 2024 09:01:09 +0200
Subject: [PATCH 22/26] Improve doc of generic gradient norm stepper

---
 exponax/stepper/generic/_gradient_norm.py | 198 ++++++++++++++++++----
 1 file changed, 161 insertions(+), 37 deletions(-)

diff --git a/exponax/stepper/generic/_gradient_norm.py b/exponax/stepper/generic/_gradient_norm.py
index fe03f7d..cd6e299 100644
--- a/exponax/stepper/generic/_gradient_norm.py
+++ b/exponax/stepper/generic/_gradient_norm.py
@@ -51,42 +51,43 @@ def __init__(
         ```
 
         The default configuration coincides with a Kuramoto-Sivashinsky equation
-        in combustion format. Note that this requires negative values (because
-        the KS usually defines their linear operators on the left hand side of
-        the equation)
+        in combustion format (see `exponax.stepper.KuramotoSivashinsky`). Note
+        that this requires negative values (because the KS usually defines their
+        linear operators on the left hand side of the equation)
 
         **Arguments:**
-            - `num_spatial_dims`: The number of spatial dimensions `d`.
-            - `domain_extent`: The size of the domain `L`; in higher dimensions
-                the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
-            - `num_points`: The number of points `N` used to discretize the
-                domain. This **includes** the left boundary point and
-                **excludes** the right boundary point. In higher dimensions; the
-                number of points in each dimension is the same. Hence, the total
-                number of degrees of freedom is `Nᵈ`.
-            - `dt`: The timestep size `Δt` between two consecutive states.
-            - `coefficients` (keyword-only): The list of coefficients `a_j`
-                corresponding to the derivatives. The length of this tuple
-                represents the highest occuring derivative. The default value
-                `(0.0, 0.0, -1.0, 0.0, -1.0)` corresponds to the Kuramoto-
-                Sivashinsky equation in combustion format.
-            - `gradient_norm_scale` (keyword-only): The scale of the gradient
-                norm term `b₂`. Default: 1.0.
-            - `order`: The order of the Exponential Time Differencing Runge
-                Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0`
-                only solves the linear part of the equation. Use higher values
-                for higher accuracy and stability. The default choice of `2` is
-                a good compromise for single precision floats.
-            - `dealiasing_fraction`: The fraction of the wavenumbers to keep
-                before evaluating the nonlinearity. The default 2/3 corresponds
-                to Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2.
-                Default: 2/3.
-            - `num_circle_points`: How many points to use in the complex contour
-                integral method to compute the coefficients of the exponential
-                time differencing Runge Kutta method. Default: 16.
-            - `circle_radius`: The radius of the contour used to compute the
-                coefficients of the exponential time differencing Runge Kutta
-                method. Default: 1.0.
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `domain_extent`: The size of the domain `L`; in higher dimensions
+            the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
+        - `num_points`: The number of points `N` used to discretize the
+            domain. This **includes** the left boundary point and **excludes**
+            the right boundary point. In higher dimensions; the number of points
+            in each dimension is the same. Hence, the total number of degrees of
+            freedom is `Nᵈ`.
+        - `dt`: The timestep size `Δt` between two consecutive states.
+        - `coefficients` (keyword-only): The list of coefficients `a_j`
+            corresponding to the derivatives. The length of this tuple
+            represents the highest occuring derivative. The default value `(0.0,
+            0.0, -1.0, 0.0, -1.0)` corresponds to the Kuramoto- Sivashinsky
+            equation in combustion format.
+        - `gradient_norm_scale` (keyword-only): The scale of the gradient
+            norm term `b₂`. Default: 1.0.
+        - `order`: The order of the Exponential Time Differencing Runge
+            Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
+            solves the linear part of the equation. Use higher values for higher
+            accuracy and stability. The default choice of `2` is a good
+            compromise for single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep
+            before evaluating the nonlinearity. The default 2/3 corresponds to
+            Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default:
+            2/3.
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method. Default: 16.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta
+            method. Default: 1.0.
         """
         self.coefficients = coefficients
         self.gradient_norm_scale = gradient_norm_scale
@@ -153,11 +154,68 @@ def __init__(
         circle_radius: float = 1.0,
     ):
         """
-        the number of channels do **not** grow with the number of spatial
-        dimensions. They are always 1.
+        Timestepper for the **normalized** d-dimensional (`d ∈ {1, 2, 3}`)
+        semi-linear PDEs consisting of a gradient norm nonlinearity and an
+        arbitrary combination of (isotropic) linear operators. Uses a normalized
+        interface, i.e., the domain is scaled to `Ω = (0, 1)ᵈ` and time step
+        size is `Δt = 1.0`.
+
+        See `exponax.stepper.generic.GeneralGradientNormStepper` for more
+        details on the functional form of the PDE.
+
+        The number of channels do **not** grow with the number of spatial
+        dimensions. They are always one.
+
+        Under the default settings, it behaves like the Kuramoto-Sivashinsky
+        equation in combustion format under the following settings.
 
         By default: the KS equation on L=60.0
 
+        ```python
+
+        exponax.stepper.KuramotoSivashinsky(
+            num_spatial_dims=D, domain_extent=60.0, num_points=N, dt=0.1,
+            gradient_norm_scale=1.0, second_order_diffusivity=1.0,
+            fourth_order_diffusivity=1.0,
+        )
+        ```
+
+        Note that the coefficient list requires a negative sign because the
+        linear derivatives are moved to the right-hand side in this generic
+        interface.
+
+        **Arguments:**
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `num_points`: The number of points `N` used to discretize the
+            domain. This **includes** the left boundary point and **excludes**
+            the right boundary point. In higher dimensions; the number of points
+            in each dimension is the same. Hence, the total number of degrees of
+            freedom is `Nᵈ`.
+        - `normalized_coefficients`: The list of coefficients `a_j`
+            corresponding to the derivatives. The length of this tuple
+            represents the highest occuring derivative. The default value `(0.0,
+            0.0, -1.0 * 0.1 / (60.0**2), 0.0, -1.0 * 0.1 / (60.0**4))`
+            corresponds to the Kuramoto-Sivashinsky equation in combustion
+            format on a domain of size `L=60.0` with a time step size of
+            `Δt=0.1`.
+        - `normalized_gradient_norm_scale`: The scale of the gradient
+            norm term `b₂`. Default: `1.0 * 0.1 / (60.0**2)`.
+        - `order`: The order of the Exponential Time Differencing Runge
+            Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
+            solves the linear part of the equation. Use higher values for higher
+            accuracy and stability. The default choice of `2` is a good
+            compromise for single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep
+            before evaluating the nonlinearity. The default 2/3 corresponds to
+            Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default:
+            2/3.
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method. Default: 16.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta
+            method. Default: 1.0.
         """
         self.normalized_coefficients = normalized_coefficients
         self.normalized_gradient_norm_scale = normalized_gradient_norm_scale
@@ -193,7 +251,73 @@ def __init__(
         circle_radius: float = 1.0,
     ):
         """
-        By default: KS equation
+        Timestepper for the **difficulty-based** d-dimensional (`d ∈ {1, 2, 3}`)
+        semi-linear PDEs consisting of a gradient norm nonlinearity and an
+        arbitrary combination of (isotropic) linear operators. Uses a
+        difficulty-based interface where the "intensity" of the dynamics reduces
+        with increasing resolution. This is intended such that emulator learning
+        problems on two resolutions are comparibly difficult.
+
+        Different to `exponax.stepper.generic.NormalizedGradientNormStepper`,
+        the dynamics are defined by difficulties. The difficulties are a
+        different combination of normalized dynamics, `num_spatial_dims`, and
+        `num_points`.
+
+            γᵢ = αᵢ Nⁱ 2ⁱ⁻¹ d
+
+        with `d` the number of spatial dimensions, `N` the number of points, and
+        `αᵢ` the normalized coefficient.
+
+        The difficulty of the nonlinear convection scale is defined by
+
+            δ₂ = β₂ * M * N² * D
+
+        with `M` the maximum absolute value of the input state (typically `1.0`
+        if one uses the `exponax.ic` random generators with the `max_one=True`
+        argument).
+
+        This interface is more natural than the normalized interface because the
+        difficulties for all orders (given by `i`) are around 1.0. Additionally,
+        they relate to stability condition of explicit Finite Difference schemes
+        for the particular equations. For example, for advection (`i=1`), the
+        absolute of the difficulty is the Courant-Friedrichs-Lewy (CFL) number.
+
+        Under the default settings, this timestepper represents the
+        Kuramoto-Sivashinsky equation (in combustion format).
+
+        **Arguments:**
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `num_points`: The number of points `N` used to discretize the
+            domain. This **includes** the left boundary point and **excludes**
+            the right boundary point. In higher dimensions; the number of points
+            in each dimension is the same. Hence, the total number of degrees of
+            freedom is `Nᵈ`.
+        - `linear_difficulties`: The list of difficulties `γᵢ` corresponding to
+            the derivatives. The length of this tuple represents the highest
+            occuring derivative. The default value `(0.0, 0.0, -0.128, 0.0,
+            -0.32768)` corresponds to the Kuramoto-Sivashinsky equation in
+            combustion format (because it contains both a negative diffusion and
+            a negative hyperdiffusion term).
+        - `gradient_norm_difficulty`: The difficulty of the gradient norm term
+            `δ₂`.
+        - `maximum_absolute`: The maximum absolute value of the input state. This
+            is used to scale the gradient norm term.
+        - `order`: The order of the Exponential Time Differencing Runge
+            Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
+            solves the linear part of the equation. Use higher values for higher
+            accuracy and stability. The default choice of `2` is a good
+            compromise for single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep
+            before evaluating the nonlinearity. The default 2/3 corresponds to
+            Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default:
+            2/3.
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method. Default: 16.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta
+            method. Default: 1.0.
         """
         self.linear_difficulties = linear_difficulties
         self.gradient_norm_difficulty = gradient_norm_difficulty

From cb562fd3f436b8dd769d8aff449b021b7036abb1 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Tue, 3 Sep 2024 09:17:41 +0200
Subject: [PATCH 23/26] Add docstring

---
 .../stepper/generic/_vorticity_convection.py  | 58 +++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/exponax/stepper/generic/_vorticity_convection.py b/exponax/stepper/generic/_vorticity_convection.py
index 8f64e81..f75b2c4 100644
--- a/exponax/stepper/generic/_vorticity_convection.py
+++ b/exponax/stepper/generic/_vorticity_convection.py
@@ -28,6 +28,64 @@ def __init__(
         num_circle_points: int = 16,
         circle_radius: float = 1.0,
     ):
+        """
+        Timestepper for 2D PDEs consisting of vorticity convection term and an
+        arbitrary combination of (isotropic) linear derivatives.
+
+        ```
+            uₜ + b ([1, -1]ᵀ ⊙ ∇(Δ⁻¹u)) ⋅ ∇u = sum_j a_j (1⋅∇ʲ)u
+        ```
+
+        where `b` is the vorticity convection scale, `a_j` are the coefficients
+        of the linear derivatives, and `∇ʲ` is the `j`-th derivative operator.
+
+        In the default configuration, this corresponds to the 2D Navier-Stokes
+        simulation with a viscosity of `ν = 0.001` (the resulting Reynols number
+        depends on the `domain_extent`. In the case of a unit square domain,
+        i.e., `domain_extent = 1`, the Reynols number is `Re = 1/ν = 1000`).
+        Depending on the initial state, this corresponds to a decaying 2D
+        turbulence.
+
+        Additionally, one can set an `injection_mode` and `injection_scale` to
+        inject energy into the system. For example, this allows for the
+        simulation of forced turbulence (=Kolmogorov flow).
+
+        **Arguments:**
+
+        - `num_spatial_dims`: number of spatial dimensions `D`.
+        - `domain_extent`: The size of the domain `L`; in higher dimensions
+            the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
+        - `num_points`: The number of points `N` used to discretize the
+            domain. This **includes** the left boundary point and **excludes**
+            the right boundary point. In higher dimensions; the number of points
+            in each dimension is the same. Hence, the total number of degrees of
+            freedom is `Nᵈ`.
+        - `dt`: The timestep size `Δt` between two consecutive states.
+        - `coefficients`: The list of coefficients `a_j`
+            corresponding to the derivatives. The length of this tuple
+            represents the highest occuring derivative. The default value `(0.0,
+            0.0, 0.001)` corresponds to pure regular diffusion.
+        - `vorticity_convection_scale`: The scale `b` of the vorticity
+            convection term.
+        - `injection_mode`: The mode of the injection.
+        - `injection_scale`: The scale of the injection. Defaults to `0.0` which
+            means no injection. Hence, the flow will decay over time.
+        - `dealiasing_fraction`: The fraction of the modes that are kept after
+            dealiasing. The default value `2/3` corresponds to the 2/3 rule.
+        - `order`: The order of the ETDRK method to use. Must be one of {0, 1,
+            2, 3, 4}. The option `0` only solves the linear part of the
+            equation. Hence, only use this for linear PDEs. For nonlinear PDEs,
+            a higher order method tends to be more stable and accurate. `2` is
+            often a good compromis in single-precision. Use `4` together with
+            double precision (`jax.config.update("jax_enable_x64", True)`) for
+            highest accuracy.
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta
+            method.
+        """
         if num_spatial_dims != 2:
             raise ValueError(f"Expected num_spatial_dims = 2, got {num_spatial_dims}.")
         self.vorticity_convection_scale = vorticity_convection_scale

From 7ccef6f47667649c144ed844753735f7ce09f182 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Tue, 3 Sep 2024 09:42:39 +0200
Subject: [PATCH 24/26] Add documentation

---
 exponax/stepper/generic/_polynomial.py | 180 ++++++++++++++++++++++++-
 1 file changed, 175 insertions(+), 5 deletions(-)

diff --git a/exponax/stepper/generic/_polynomial.py b/exponax/stepper/generic/_polynomial.py
index bb95451..372611b 100644
--- a/exponax/stepper/generic/_polynomial.py
+++ b/exponax/stepper/generic/_polynomial.py
@@ -26,11 +26,84 @@ def __init__(
         circle_radius: float = 1.0,
     ):
         """
-        By default: Fisher-KPP with a small diffusion and 10.0 reactivity
+        Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) semi-linear PDEs
+        consisting of an arbitrary combination of polynomial nonlinearities and
+        (isotropic) linear derivatives. This can be used to represent a wide
+        array of reaction-diffusion equations.
 
-        Note that the first two entries in the polynomial_scales list are often zero.
+        In 1d, the PDE is of the form
 
-        The effect of polynomial_scale[1] is similar to the effect of coefficients[0]
+        ```
+            uₜ = ∑ₖ pₖ uᵏ + ∑ⱼ aⱼ uₓʲ
+        ```
+
+        where `pₖ` are the polynomial coefficients and `aⱼ` are the linear
+        coefficients. `uᵏ` denotes `u` pointwise raised to the power of `k`
+        (hence the polynomial contribution) and `uₓʲ` denotes the `j`-th
+        derivative of `u`.
+
+        The higher-dimensional generalization reads
+
+        ```
+            uₜ = ∑ₖ pₖ uᵏ + ∑ⱼ a_j (1⋅∇ʲ)u
+
+        ```
+
+        where `∇ʲ` is the `j`-th derivative operator.
+
+        The default configuration corresponds to the Fisher-KPP equation with
+        the following settings
+
+        ```python
+
+        exponax.stepper.reaction.FisherKPP(
+            num_spatial_dims=num_spatial_dims, domain_extent=domain_extent,
+            num_points=num_points, dt=dt, diffusivity=0.01, reactivity=-10.0,
+            #TODO: Check this
+        )
+        ```
+
+        Note that the effect of polynomial_scale[1] is similar to the effect of
+        coefficients[0] with the difference that in ETDRK integration the latter
+        is treated anlytically and should be preferred.
+
+        **Arguments:**
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `domain_extent`: The size of the domain `L`; in higher dimensions
+            the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
+        - `num_points`: The number of points `N` used to discretize the
+            domain. This **includes** the left boundary point and **excludes**
+            the right boundary point. In higher dimensions; the number of points
+            in each dimension is the same. Hence, the total number of degrees of
+            freedom is `Nᵈ`.
+        - `dt`: The timestep size `Δt` between two consecutive states.
+        - `coefficients`: The list of coefficients `a_j` corresponding to the
+            derivatives. The length of this tuple represents the highest
+            occuring derivative. The default value `(10.0, 0.0, 0.01)` in
+            combination with the default `polynomial_scales` corresponds to the
+            Fisher-KPP equation.
+        - `polynomial_scales`: The list of scales `pₖ` corresponding to the
+            polynomial contributions. The length of this tuple represents the
+            highest occuring polynomial. The default value `(0.0, 0.0, 10.0)` in
+            combination with the default `coefficients` corresponds to the
+            Fisher-KPP equation.
+        - `order`: The order of the Exponential Time Differencing Runge
+            Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
+            solves the linear part of the equation. Use higher values for higher
+            accuracy and stability. The default choice of `2` is a good
+            compromise for single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep
+            before evaluating the nonlinearity. The default 2/3 corresponds to
+            Orszag's 2/3 rule which is sufficient if the highest occuring
+            polynomial is quadratic (i.e., there are at maximum three entries in
+            the `polynomial_scales` tuple).
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta
+            method.
         """
         self.coefficients = coefficients
         self.polynomial_scales = polynomial_scales
@@ -98,7 +171,48 @@ def __init__(
         circle_radius: float = 1.0,
     ):
         """
-        By default: Fisher-KPP
+        Timestepper for the **normalized** d-dimensional (`d ∈ {1, 2, 3}`)
+        semi-linear PDEs consisting of an arbitrary combination of polynomial
+        nonlinearities and (isotropic) linear derivatives. Uses a normalized
+        interface, i.e., the domain is scaled to `Ω = (0, 1)ᵈ` and time step
+        size is `Δt = 1.0`.
+
+        See `exponax.stepper.generic.GeneralPolynomialStepper` for more details
+        on the functional form of the PDE.
+
+        The default settings correspond to the Fisher-KPP equation.
+
+        **Arguments:**
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `num_points`: The number of points `N` used to discretize the domain.
+            This **includes** the left boundary point and **excludes** the right
+            boundary point. In higher dimensions; the number of points in each
+            dimension is the same. Hence, the total number of degrees of freedom
+            is `Nᵈ`.
+        - `normalized_coefficients`: The list of coefficients `α_j` corresponding
+            to the derivatives. The length of this tuple represents the highest
+            occuring derivative. The default value corresponds to the Fisher-KPP
+            equation.
+        - `normalized_polynomial_scales`: The list of scales `βₖ` corresponding
+            to the polynomial contributions. The length of this tuple represents
+            the highest occuring polynomial. The default value corresponds to the
+            Fisher-KPP equation.
+        - `order`: The order of the Exponential Time Differencing Runge Kutta
+            method. Must be one of {0, 1, 2, 3, 4}. The option `0` only solves
+            the linear part of the equation. Use higher values for higher accuracy
+            and stability. The default choice of `2` is a good compromise for
+            single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep before
+            evaluating the nonlinearity. The default 2/3 corresponds to Orszag's
+            2/3 rule which is sufficient if the highest occuring polynomial is
+            quadratic (i.e., there are at maximum three entries in the
+            `normalized_polynomial_scales` tuple).
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta method.
         """
         self.normalized_coefficients = normalized_coefficients
         self.normalized_polynomial_scales = normalized_polynomial_scales
@@ -142,7 +256,63 @@ def __init__(
         circle_radius: float = 1.0,
     ):
         """
-        By default: Fisher-KPP
+        Timestepper for **difficulty-based** d-dimensional (`d ∈ {1, 2, 3}`)
+        semi-linear PDEs consisting of an arbitrary combination of polynomial
+        nonlinearities and (isotropic) linear derivatives. Uses a
+        difficulty-based interface where the "intensity" of the dynamics reduces
+        with increasing resolution. This is intended such that emulator learning
+        problems on two resolutions are comparibly difficult.
+
+        Different to `exponax.stepper.generic.NormalizedPolynomialStepper`, the
+        dynamics are defined by difficulties. The difficulties are a different
+        combination of normalized dynamics, `num_spatial_dims`, and
+        `num_points`.
+
+            γᵢ = αᵢ Nⁱ 2ⁱ⁻¹ d
+
+        with `d` the number of spatial dimensions, `N` the number of points, and
+        `αᵢ` the normalized coefficient.
+
+        Since the polynomial nonlinearity does not contain any derivatives, we
+        have that
+
+        ```
+            normalized_polynomial_scales = polynomial_difficulties
+        ```
+
+        The default settings correspond to the Fisher-KPP equation.
+
+        **Arguments:**
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `num_points`: The number of points `N` used to discretize the domain.
+            This **includes** the left boundary point and **excludes** the right
+            boundary point. In higher dimensions; the number of points in each
+            dimension is the same. Hence, the total number of degrees of freedom
+            is `Nᵈ`.
+        - `linear_difficulties`: The list of difficulties `γ_j` corresponding to
+            the derivatives. The length of this tuple represents the highest
+            occuring derivative. The default value corresponds to the Fisher-KPP
+            equation.
+        - `polynomial_difficulties`: The list of difficulties `δₖ` corresponding
+            to the polynomial contributions. The length of this tuple represents
+            the highest occuring polynomial. The default value corresponds to the
+            Fisher-KPP equation.
+        - `order`: The order of the Exponential Time Differencing Runge Kutta
+            method. Must be one of {0, 1, 2, 3, 4}. The option `0` only solves
+            the linear part of the equation. Use higher values for higher accuracy
+            and stability. The default choice of `2` is a good compromise for
+            single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep before
+            evaluating the nonlinearity. The default 2/3 corresponds to Orszag's
+            2/3 rule which is sufficient if the highest occuring polynomial is
+            quadratic (i.e., there are at maximum three entries in the
+            `polynomial_difficulties` tuple).
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta method.
         """
         self.linear_difficulties = linear_difficulties
         self.polynomial_difficulties = polynomial_difficulties

From 257c1dd7fe36e4505dfa4d5e597b63a3321b3a18 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Tue, 3 Sep 2024 10:40:17 +0200
Subject: [PATCH 25/26] Add documentation

---
 exponax/stepper/generic/_nonlinear.py | 229 +++++++++++++++++++++++++-
 1 file changed, 227 insertions(+), 2 deletions(-)

diff --git a/exponax/stepper/generic/_nonlinear.py b/exponax/stepper/generic/_nonlinear.py
index 67c43ec..60315be 100644
--- a/exponax/stepper/generic/_nonlinear.py
+++ b/exponax/stepper/generic/_nonlinear.py
@@ -29,7 +29,83 @@ def __init__(
         circle_radius: float = 1.0,
     ):
         """
-        By default Burgers equation
+        Timestepper for d-dimensional (`d ∈ {1, 2, 3}`) semi-linear PDEs
+        consisting of a quadratic, a single-channel convection, and a gradient
+        norm nonlinearity together with an arbitrary combination of (isotropic)
+        linear derivatives.
+
+        In 1d, the PDE is of the form
+
+        ```
+            uₜ = b₀ u² + b₁ 1/2 (u²)ₓ + b₂ 1/2 (uₓ)² + sum_j a_j uₓʲ
+        ```
+
+        where `b₀`, `b₁`, `b₂` are the coefficients of the quadratic,
+        convection, and gradient norm nonlinearity, respectively, and `a_j` are
+        the coefficients of the linear derivatives. Effectively, this
+        timestepper is a combination of the
+        `exponax.stepper.generic.GeneralPolynomialStepper` (with only the
+        coefficient to the quadratic polynomial being set with `b₀`), the
+        `exponax.stepper.generic.GeneralConvectionStepper` (with the
+        single-channel hack activated via `single_channel=True` and the
+        convection scale set with `- b₁`), and the
+        `exponax.stepper.generic.GeneralGradientNormStepper` (with the gradient
+        norm scale set with `- b₂`).
+
+        !!! warning
+            In contrast to the
+            `exponax.stepper.generic.GeneralConvectionStepper` and the
+            `exponax.stepper.generic.GeneralGradientNormStepper`, the nonlinear
+            terms are no considered to be on right-hand side of the PDE. Hence,
+            in order to get the same dynamics as in the other steppers, the
+            coefficients must be negated. (This is not relevant for the
+            coefficient of the quadratic polynomial because in the
+            `exponax.stepper.generic.GeneralPolynomialStepper` the polynomial
+            nonlinearity is already on the right-hand side.)
+
+        The higher-dimensional generalization is
+
+        ```
+            uₜ = b₀ u² + b₁ 1/2 (1⃗ ⋅ ∇)(u²) + b₂ 1/2 ‖ ∇u ‖₂² + sum_j a_j uₓˢ
+        ```
+
+        Under the default configuration, this correspons to a Burgers equation
+        in single-channel mode.
+
+        **Arguments:**
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `domain_extent`: The size of the domain `L`; in higher dimensions the
+            domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
+        - `num_points`: The number of points `N` used to discretize the domain.
+            This **includes** the left boundary point and **excludes** the right
+            boundary point. In higher dimensions; the number of points in each
+            dimension is the same. Hence, the total number of degrees of freedom
+            is `Nᵈ`.
+        - `dt`: The timestep size `Δt` between two consecutive states.
+        - `coefficients_linear`: The list of coefficients `a_j` corresponding to
+            the derivatives. The length of this tuple represents the highest
+            occuring derivative. The default value `(0.0, 0.0, 0.01)` together
+            with the default `coefficients_nonlinear` corresponds to the Burgers
+            equation.
+        - `coefficients_nonlinear`: The list of coefficients `b₀`, `b₁`, `b₂`
+            (in this order). The default value `(0.0, -1.0, 0.0)` corresponds to
+            a (single-channel) convection nonlinearity scaled with `1.0`. Note
+            that all nonlinear contributions are considered to be on the
+            right-hand side of the PDE. Hence, in order to get the "correct
+            convection" dynamics, the coefficients must be negated.
+        - `order`: The order of the ETDRK method to use. Must be one of {0, 1, 2,
+            3, 4}. The option `0` only solves the linear part of the equation.
+            Use higher values for higher accuracy and stability. The default
+            choice of `2` is a good compromise for single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep before
+            evaluating the nonlinearity. The default value `2/3` corresponds to
+            Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2.
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta method.
         """
         if len(coefficients_nonlinear) != 3:
             raise ValueError(
@@ -96,6 +172,84 @@ def __init__(
     ):
         """
         By default Burgers.
+
+        Timesteppr for **normalized** d-dimensional (`d ∈ {1, 2, 3}`)
+        semi-linear PDEs consisting of a quadratic, a single-channel convection,
+        and a gradient norm nonlinearity together with an arbitrary combination
+        of (isotropic) linear derivatives. Uses a normalized interface, i.e.,
+        the domain is scaled to `Ω = (0, 1)ᵈ` and time step size is `Δt = 1.0`.
+
+        See `exponax.stepper.generic.GeneralNonlinearStepper` for more details
+        on the functional form of the PDE.
+
+        Behaves like a single-channel Burgers equation by default under the
+        following settings
+
+        ```python
+
+        exponax.stepper.Burgers(
+            num_spatial_dims=num_spatial_dims, domain_extent=1.0,
+            num_points=num_points, dt=1.0, convection_scale=1.0,
+            diffusivity=0.1, single_channel=True,
+        )
+        ```
+
+        Effectively, this timestepper is a combination of the
+        `exponax.stepper.generic.NormalizedPolynomialStepper` (with only the
+        coefficient to the quadratic polynomial being set with `b₀`), the
+        `exponax.stepper.generic.NormalizedConvectionStepper` (with the
+        single-channel hack activated via `single_channel=True` and the
+        convection scale set with `- b₁`), and the
+        `exponax.stepper.generic.NormalizedGradientNormStepper` (with the
+        gradient norm scale set with `- b₂`).
+
+        !!! warning
+            In contrast to the
+            `exponax.stepper.generic.NormalizedConvectionStepper` and the
+            `exponax.stepper.generic.NormalizedGradientNormStepper`, the
+            nonlinear terms are no considered to be on right-hand side of the
+            PDE. Hence, in order to get the same dynamics as in the other
+            steppers, the coefficients must be negated. (This is not relevant
+            for the coefficient of the quadratic polynomial because in the
+            `exponax.stepper.generic.NormalizedPolynomialStepper` the polynomial
+            nonlinearity is already on the right-hand side.)
+
+
+        **Arguments:**
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `num_points`: The number of points `N` used to discretize the domain.
+            This **includes** the left boundary point and **excludes** the right
+            boundary point. In higher dimensions; the number of points in each
+            dimension is the same. Hence, the total number of degrees of freedom
+            is `Nᵈ`.
+        - `normalized_coefficients_linear`: The list of coefficients `αⱼ`
+            corresponding to the linear derivatives. The length of this tuple
+            represents the highest occuring derivative. The default value `(0.0,
+            0.0, 0.1 * 0.1)` together with the default
+            `normalized_coefficients_nonlinear` corresponds to the Burgers
+            equation (in single-channel mode).
+        - `normalized_coefficients_nonlinear`: The list of coefficients `β₀`,
+            `β₁`, and `β₂` (in this order) corresponding to the quadratic,
+            (single-channel) convection, and gradient norm nonlinearity,
+            respectively. The default value `(0.0, -1.0 * 0.1, 0.0)` corresponds
+            to a (single-channel) convection nonlinearity scaled with `0.1`.
+            Note that all nonlinear contributions are considered to be on the
+            right-hand side of the PDE. Hence, in order to get the "correct
+            convection" dynamics, the coefficients must be negated. (Also
+            relevant for the gradient norm nonlinearity)
+        - `order`: The order of the ETDRK method to use. Must be one of {0, 1, 2,
+            3, 4}. The option `0` only solves the linear part of the equation.
+            Use higher values for higher accuracy and stability. The default
+            choice of `2` is a good compromise for single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep before
+            evaluating the nonlinearity. The default value `2/3` corresponds to
+            Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2.
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta method.
         """
 
         self.normalized_coefficients_linear = normalized_coefficients_linear
@@ -141,7 +295,78 @@ def __init__(
         circle_radius: float = 1.0,
     ):
         """
-        By default Burgers.
+        Timestepper for **difficulty-based** d-dimensional (`d ∈ {1, 2, 3}`)
+        semi-linear PDEs consisting of a quadratic, a single-channel convection,
+        and a gradient norm nonlinearity together with an arbitrary combination
+        of (isotropic) linear derivatives. Uses a difficulty-based interface
+        where the "intensity" of the dynamics reduces with increasing
+        resolution. This is intended such that emulator learning problems on two
+        resolutions are comparibly difficult.
+
+        Different to `exponax.stepper.generic.NormalizedNonlinearStepper`, the
+        dynamics are defined by difficulties. The difficulties are a different
+        combination of normalized dynamics, `num_spatial_dims`, and
+        `num_points`.
+
+            γᵢ = αᵢ Nⁱ 2ⁱ⁻¹ d
+
+        with `d` the number of spatial dimensions, `N` the number of points, and
+        `αᵢ` the normalized coefficient.
+
+        The difficulties of the nonlinear terms are
+
+            δ₀ = β₀
+
+            δ₁ = β₁ * M * N * D
+
+            δ₂ = β₂ * M * N² * D
+
+        with `βᵢ` the normalized coefficient and `M` the maximum absolute value
+        of the input state (typically `1.0` if one uses the `exponax.ic` random
+        generators with the `max_one=True` argument).
+
+        This interface is more natural than the normalized interface because the
+        difficulties for all orders (given by `i`) are around 1.0. Additionally,
+        they relate to stability condition of explicit Finite Difference schemes
+        for the particular equations. For example, for advection (`i=1`), the
+        absolute of the difficulty is the Courant-Friedrichs-Lewy (CFL) number.
+
+        Under the default settings, this timestep corresponds to a Burgers
+        equation in single-channel mode.
+
+        **Arguments:**
+
+        - `num_spatial_dims`: The number of spatial dimensions `d`.
+        - `num_points`: The number of points `N` used to discretize the domain.
+            This **includes** the left boundary point and **excludes** the right
+            boundary point. In higher dimensions; the number of points in each
+            dimension is the same. Hence, the total number of degrees of freedom
+            is `Nᵈ`.
+        - `linear_difficulties`: The list of difficulties `γᵢ` corresponding to
+            the linear derivatives. The length of this tuple represents the
+            highest occuring derivative. The default value `(0.0, 0.0, 0.1 * 0.1
+            / 1.0 * 48**2 * 2)` together with the default `nonlinear_difficulties`
+            corresponds to the Burgers equation.
+        - `nonlinear_difficulties`: The list of difficulties `δ₀`, `δ₁`, and `δ₂`
+            (in this order) corresponding to the quadratic, (single-channel)
+            convection, and gradient norm nonlinearity, respectively. The default
+            value `(0.0, -1.0 * 0.1 / 1.0 * 48, 0.0)` corresponds to a
+            (single-channel) convection nonlinearity. Note that all nonlinear
+            contributions are considered to be on the right-hand side of the PDE.
+        - `maximum_absolute`: The maximum absolute value of the input state. This
+            is used to scale the nonlinear difficulties.
+        - `order`: The order of the ETDRK method to use. Must be one of {0, 1, 2,
+            3, 4}. The option `0` only solves the linear part of the equation.
+            Use higher values for higher accuracy and stability. The default
+            choice of `2` is a good compromise for single precision floats.
+        - `dealiasing_fraction`: The fraction of the wavenumbers to keep before
+            evaluating the nonlinearity. The default value `2/3` corresponds to
+            Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2.
+        - `num_circle_points`: How many points to use in the complex contour
+            integral method to compute the coefficients of the exponential time
+            differencing Runge Kutta method.
+        - `circle_radius`: The radius of the contour used to compute the
+            coefficients of the exponential time differencing Runge Kutta method.
         """
         self.linear_difficulties = linear_difficulties
         self.nonlinear_difficulties = nonlinear_difficulties

From 356691b3e982aaf522b1a00acf0ffb7d786f7ba9 Mon Sep 17 00:00:00 2001
From: Felix Koehler <f.koehler@tum.de>
Date: Tue, 3 Sep 2024 12:38:48 +0200
Subject: [PATCH 26/26] Add documentation

---
 exponax/stepper/generic/_utils.py | 307 ++++++++++++++++++++++++++++--
 1 file changed, 295 insertions(+), 12 deletions(-)

diff --git a/exponax/stepper/generic/_utils.py b/exponax/stepper/generic/_utils.py
index 6049e34..4263913 100644
--- a/exponax/stepper/generic/_utils.py
+++ b/exponax/stepper/generic/_utils.py
@@ -11,11 +11,24 @@ def normalize_coefficients(
     Normalize the coefficients to a linear time stepper to be used with the
     normalized linear stepper.
 
+        αᵢ = aᵢ Δt / Lⁱ
+
+    !!! warning
+        A consequence of this normalization is that the normalized coefficients
+        for high order derivatives will be very small.
+
     **Arguments:**
+
     - `coefficients`: coefficients for the linear operator, `coefficients[i]` is
         the coefficient for the `i`-th derivative
     - `domain_extent`: extent of the domain
     - `dt`: time step
+
+    **Returns:**
+
+    - `normalized_coefficients`: normalized coefficients for the linear
+        operator, `normalized_coefficients[i]` is the coefficient for the `i`-th
+        derivative
     """
     normalized_coefficients = tuple(
         c * dt / (domain_extent**i) for i, c in enumerate(coefficients)
@@ -31,14 +44,22 @@ def denormalize_coefficients(
 ) -> tuple[float, ...]:
     """
     Denormalize the coefficients as they were used in the normalized linear to
-    then be used again in a regular linear stepper.
+    then be used again in a genric linear stepper with a physical interface.
+
+        aᵢ = αᵢ Lⁱ / Δt
 
     **Arguments:**
+
     - `normalized_coefficients`: coefficients for the linear operator,
         `normalized_coefficients[i]` is the coefficient for the `i`-th
         derivative
     - `domain_extent`: extent of the domain
     - `dt`: time step
+
+    **Returns:**
+
+    - `coefficients`: coefficients for the linear operator, `coefficients[i]` is
+        the coefficient for the `i`-th derivative
     """
     coefficients = tuple(
         c_n / dt * domain_extent**i for i, c_n in enumerate(normalized_coefficients)
@@ -52,6 +73,23 @@ def normalize_convection_scale(
     domain_extent: float,
     dt: float,
 ) -> float:
+    """
+    Normalize the scale (=coefficient) in front of the convection term to be
+    used with the normalized generic steppers.
+
+        β₁ = b₁ Δt / L
+
+    **Arguments:**
+
+    - `convection_scale`: scale in front of the convection term, i.e., the `b_1`
+        in `𝒩(u) = - b₁ 1/2 (u²)ₓ`
+    - `domain_extent`: extent of the domain
+    - `dt`: time step
+
+    **Returns:**
+
+    - `normalized_convection_scale`: normalized scale in front of the convection
+    """
     normalized_convection_scale = convection_scale * dt / domain_extent
     return normalized_convection_scale
 
@@ -62,6 +100,24 @@ def denormalize_convection_scale(
     domain_extent: float,
     dt: float,
 ) -> float:
+    """
+    Denormalize the scale in front of the convection term as it was used in the
+    normalized generic steppers to then be used again in a generic stepper with
+    a physical interface.
+
+        b₁ = β₁ L / Δt
+
+    **Arguments:**
+
+    - `normalized_convection_scale`: normalized scale in front of the convection
+    - `domain_extent`: extent of the domain
+    - `dt`: time step
+
+    **Returns:**
+
+    - `convection_scale`: scale in front of the convection term, i.e., the `b_1`
+        in `𝒩(u) = - b₁ 1/2 (u²)ₓ`
+    """
     convection_scale = normalized_convection_scale / dt * domain_extent
     return convection_scale
 
@@ -72,6 +128,24 @@ def normalize_gradient_norm_scale(
     domain_extent: float,
     dt: float,
 ):
+    """
+    Normalize the scale in front of the gradient norm term to be used with the
+    normalized generic steppers.
+
+        β₂ = b₂ Δt / L²
+
+    **Arguments:**
+
+    - `gradient_norm_scale`: scale in front of the gradient norm term, i.e., the
+        `b_2` in `𝒩(u) = - b₂ 1/2 ‖∇u‖₂²`
+    - `domain_extent`: extent of the domain
+    - `dt`: time step
+
+    **Returns:**
+
+    - `normalized_gradient_norm_scale`: normalized scale in front of the
+        gradient norm term
+    """
     normalized_gradient_norm_scale = (
         gradient_norm_scale * dt / jnp.square(domain_extent)
     )
@@ -84,6 +158,25 @@ def denormalize_gradient_norm_scale(
     domain_extent: float,
     dt: float,
 ):
+    """
+    Denormalize the scale in front of the gradient norm term as it was used in
+    the normalized generic steppers to then be used again in a generic stepper
+    with a physical interface.
+
+        b₂ = β₂ L² / Δt
+
+    **Arguments:**
+
+    - `normalized_gradient_norm_scale`: normalized scale in front of the gradient
+        norm term
+    - `domain_extent`: extent of the domain
+    - `dt`: time step
+
+    **Returns:**
+
+    - `gradient_norm_scale`: scale in front of the gradient norm term, i.e., the
+        `b_2` in `𝒩(u) = - b₂ 1/2 ‖∇u‖₂²`
+    """
     gradient_norm_scale = (
         normalized_gradient_norm_scale / dt * jnp.square(domain_extent)
     )
@@ -101,11 +194,18 @@ def normalize_polynomial_scales(
     stepper.
 
     **Arguments:**
-        - `polynomial_scales`: scales for the polynomial operator,
-            `polynomial_scales[i]` is the scale for the `i`-th derivative
-        - `domain_extent`: extent of the domain (not needed, kept for
-            compatibility with other normalization APIs)
-        - `dt`: time step
+
+    - `polynomial_scales`: scales for the polynomial operator,
+        `polynomial_scales[i]` is the scale for the `i`-th degree polynomial
+    - `domain_extent`: extent of the domain (not needed, kept for
+        compatibility with other normalization APIs)
+    - `dt`: time step
+
+    **Returns:**
+
+    - `normalized_polynomial_scales`: normalized scales for the polynomial
+        operator, `normalized_polynomial_scales[i]` is the scale for the `i`-th
+        degree polynomial
     """
     normalized_polynomial_scales = tuple(c * dt for c in polynomial_scales)
     return normalized_polynomial_scales
@@ -122,12 +222,17 @@ def denormalize_polynomial_scales(
     polynomial to then be used again in a regular polynomial stepper.
 
     **Arguments:**
-        - `normalized_polynomial_scales`: scales for the polynomial operator,
-            `normalized_polynomial_scales[i]` is the scale for the `i`-th
-            derivative
-        - `domain_extent`: extent of the domain (not needed, kept for
-            compatibility with other normalization APIs)
-        - `dt`: time step
+
+    - `normalized_polynomial_scales`: scales for the polynomial operator,
+        `normalized_polynomial_scales[i]` is the scale for the `i`-th degree
+        polynomial
+    - `domain_extent`: extent of the domain (not needed, kept for
+        compatibility with other normalization APIs)
+    - `dt`: time step
+
+    **Returns:**
+
+    - `polynomial_scales`: scales for the polynomial operator,
     """
     polynomial_scales = tuple(c_n / dt for c_n in normalized_polynomial_scales)
     return polynomial_scales
@@ -139,6 +244,31 @@ def reduce_normalized_coefficients_to_difficulty(
     num_spatial_dims: int,
     num_points: int,
 ):
+    """
+    Reduce the normalized coefficients for a linear operator to a difficulty
+    based interface. This interface is designed to "reduce the intensity of the
+    dynamics" at higher resolutions to make emulator learning across resolutions
+    comparible. Thereby, it resembles the stability numbers of the most compact
+    finite difference scheme of the respective PDE.
+
+        γ₀ = α₀
+
+        γⱼ = αⱼ Nʲ 2ʲ⁻¹ D
+
+    **Arguments:**
+
+    - `normalized_coefficients`: normalized coefficients for the linear
+        operator, `normalized_coefficients[i]` is the coefficient for the `i`-th
+        derivative
+    - `num_spatial_dims`: number of spatial dimensions `d`
+    - `num_points`: number of points `N` used to discretize the domain per
+        dimension
+
+    **Returns:**
+
+    - `difficulty_coefficients`: difficulty coefficients for the linear operator,
+        `difficulty_coefficients[i]` is the coefficient for the `i`-th derivative
+    """
     difficulty_coefficients = list(
         alpha * num_points**j * 2 ** (j - 1) * num_spatial_dims
         for j, alpha in enumerate(normalized_coefficients)
@@ -155,6 +285,27 @@ def extract_normalized_coefficients_from_difficulty(
     num_spatial_dims: int,
     num_points: int,
 ):
+    """
+    Extract the normalized coefficients for a linear operator from a difficulty
+    based interface.
+
+        α₀ = γ₀
+
+        αⱼ = γⱼ / (Nʲ 2ʲ⁻¹ D)
+
+    **Arguments:**
+
+    - `difficulty_coefficients`: difficulty coefficients for the linear operator,
+        `difficulty_coefficients[i]` is the coefficient for the `i`-th derivative
+    - `num_spatial_dims`: number of spatial dimensions `d`
+    - `num_points`: number of points `N` used to discretize the domain per
+        dimension
+
+    **Returns:**
+
+    - `normalized_coefficients`: normalized coefficients for the linear operator,
+        `normalized_coefficients[i]` is the coefficient for the `i`-th derivative
+    """
     normalized_coefficients = list(
         gamma / (num_points**j * 2 ** (j - 1) * num_spatial_dims)
         for j, gamma in enumerate(difficulty_coefficients)
@@ -172,6 +323,25 @@ def reduce_normalized_convection_scale_to_difficulty(
     num_points: int,
     maximum_absolute: float,
 ):
+    """
+    Reduce the normalized convection scale to a difficulty based interface.
+
+        δ₁ = β₁ * M * N * D
+
+    **Arguments:**
+
+    - `normalized_convection_scale`: normalized convection scale, see also
+        `exponax.stepper.generic.normalize_convection_scale`
+    - `num_spatial_dims`: number of spatial dimensions `d`
+    - `num_points`: number of points `N` used to discretize the domain per
+        dimension
+    - `maximum_absolute`: maximum absolute value of the input state the
+        resulting stepper is applied to
+
+    **Returns:**
+
+    - `difficulty_convection_scale`: difficulty convection scale
+    """
     difficulty_convection_scale = (
         normalized_convection_scale * maximum_absolute * num_points * num_spatial_dims
     )
@@ -185,6 +355,25 @@ def extract_normalized_convection_scale_from_difficulty(
     num_points: int,
     maximum_absolute: float,
 ):
+    """
+    Extract the normalized convection scale from a difficulty based interface.
+
+        β₁ = δ₁ / (M * N * D)
+
+    **Arguments:**
+
+    - `difficulty_convection_scale`: difficulty convection scale
+    - `num_spatial_dims`: number of spatial dimensions `d`
+    - `num_points`: number of points `N` used to discretize the domain per
+        dimension
+    - `maximum_absolute`: maximum absolute value of the input state the
+        resulting stepper is applied to
+
+    **Returns:**
+
+    - `normalized_convection_scale`: normalized convection scale, see also
+        `exponax.stepper.generic.normalize_convection_scale`
+    """
     normalized_convection_scale = difficulty_convection_scale / (
         maximum_absolute * num_points * num_spatial_dims
     )
@@ -198,6 +387,25 @@ def reduce_normalized_gradient_norm_scale_to_difficulty(
     num_points: int,
     maximum_absolute: float,
 ):
+    """
+    Reduce the normalized gradient norm scale to a difficulty based interface.
+
+        δ₂ = β₂ * M * N² * D
+
+    **Arguments:**
+
+    - `normalized_gradient_norm_scale`: normalized gradient norm scale, see also
+        `exponax.stepper.generic.normalize_gradient_norm_scale`
+    - `num_spatial_dims`: number of spatial dimensions `d`
+    - `num_points`: number of points `N` used to discretize the domain per
+        dimension
+    - `maximum_absolute`: maximum absolute value of the input state the
+        resulting stepper is applied to
+
+    **Returns:**
+
+    - `difficulty_gradient_norm_scale`: difficulty gradient norm scale
+    """
     difficulty_gradient_norm_scale = (
         normalized_gradient_norm_scale
         * maximum_absolute
@@ -214,6 +422,25 @@ def extract_normalized_gradient_norm_scale_from_difficulty(
     num_points: int,
     maximum_absolute: float,
 ):
+    """
+    Extract the normalized gradient norm scale from a difficulty based interface.
+
+        β₂ = δ₂ / (M * N² * D)
+
+    **Arguments:**
+
+    - `difficulty_gradient_norm_scale`: difficulty gradient norm scale
+    - `num_spatial_dims`: number of spatial dimensions `d`
+    - `num_points`: number of points `N` used to discretize the domain per
+        dimension
+    - `maximum_absolute`: maximum absolute value of the input state the
+        resulting stepper is applied to
+
+    **Returns:**
+
+    - `normalized_gradient_norm_scale`: normalized gradient norm scale, see also
+        `exponax.stepper.generic.normalize_gradient_norm_scale`
+    """
     normalized_gradient_norm_scale = difficulty_gradient_norm_scale / (
         maximum_absolute * jnp.square(num_points) * num_spatial_dims
     )
@@ -227,6 +454,34 @@ def reduce_normalized_nonlinear_scales_to_difficulty(
     num_points: int,
     maximum_absolute: float,
 ):
+    """
+    Reduce the normalized nonlinear scales associated with a quadratic, a
+    (single-channel) convection term, and a gradient norm term to a difficulty
+    based interface.
+
+        δ₀ = β₀
+
+        δ₁ = β₁ * M * N * D
+
+        δ₂ = β₂ * M * N² * D
+
+    **Arguments:**
+
+    - `normalized_nonlinear_scales`: normalized nonlinear scales associated with
+        a quadratic, a (single-channel) convection term, and a gradient norm
+        term (in this order)
+    - `num_spatial_dims`: number of spatial dimensions `d`
+    - `num_points`: number of points `N` used to discretize the domain per
+        dimension
+    - `maximum_absolute`: maximum absolute value of the input state the
+        resulting stepper is applied to
+
+    **Returns:**
+
+    - `nonlinear_difficulties`: difficulty nonlinear scales associated with a
+        quadratic, a (single-channel) convection term, and a gradient norm term
+        (in this order)
+    """
     nonlinear_difficulties = (
         normalized_nonlinear_scales[0],  # Polynomial: normalized == difficulty
         reduce_normalized_convection_scale_to_difficulty(
@@ -252,6 +507,34 @@ def extract_normalized_nonlinear_scales_from_difficulty(
     num_points: int,
     maximum_absolute: float,
 ):
+    """
+    Extract the normalized nonlinear scales associated with a quadratic, a
+    (single-channel) convection term, and a gradient norm term from a difficulty
+    based interface.
+
+        β₀ = δ₀
+
+        β₁ = δ₁ / (M * N * D)
+
+        β₂ = δ₂ / (M * N² * D)
+
+    **Arguments:**
+
+    - `nonlinear_difficulties`: difficulty nonlinear scales associated with a
+        quadratic, a (single-channel) convection term, and a gradient norm term
+        (in this order)
+    - `num_spatial_dims`: number of spatial dimensions `d`
+    - `num_points`: number of points `N` used to discretize the domain per
+        dimension
+    - `maximum_absolute`: maximum absolute value of the input state the
+        resulting stepper is applied to
+
+    **Returns:**
+
+    - `normalized_nonlinear_scales`: normalized nonlinear scales associated with
+        a quadratic, a (single-channel) convection term, and a gradient norm term
+        (in this order)
+    """
     normalized_nonlinear_scales = (
         nonlinear_difficulties[0],  # Polynomial: normalized == difficulty
         extract_normalized_convection_scale_from_difficulty(