Skip to content

Commit

Permalink
Comply with pre-commit hook
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Feb 27, 2024
1 parent 444a145 commit 8139c22
Show file tree
Hide file tree
Showing 33 changed files with 254 additions and 247 deletions.
105 changes: 58 additions & 47 deletions exponax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,94 +1,105 @@
from .forced_stepper import ForcedStepper
from .initial_conditions import (
DiffusedNoise,
GaussianRandomField,
MultiChannelIC,
RandomMultiChannelICGenerator,
RandomTruncatedFourierSeries,
DiffusedNoise,
GaussianRandomField,
)
from .normalized_stepper import (
NormalizedConvectionStepper,
NormalizedGradientNormStepper,
NormalizedLinearStepper,
denormalize_coefficients,
denormalize_convection_scale,
denormalize_gradient_norm_scale,
normalize_coefficients,
normalize_convection_scale,
normalize_gradient_norm_scale,
)
from .poisson import Poisson
from .repeated_stepper import RepeatedStepper
from .sample_stepper import (
Advection,
Diffusion,
AdvectionDiffusion,
AllenCahn,
BelousovZhabotinsky,
Burgers,
CahnHilliard,
Diffusion,
Dispersion,
HyperDiffusion,
FisherKPP,
GeneralConvectionStepper,
GeneralGradientNormStepper,
GeneralLinearStepper,
Burgers,
GrayScott,
HyperDiffusion,
KolmogorovFlowVorticity2d,
KortevegDeVries,
KuramotoSivashinsky,
KuramotoSivashinskyConservative,
NavierStokesVorticity2d,
Nikolaevskiy,
NikolaevskiyConservative,
GeneralConvectionStepper,
GeneralGradientNormStepper,
NavierStokesVorticity2d,
KolmogorovFlowVorticity2d,
SwiftHohenberg,
GrayScott,
KortevegDeVries,
FisherKPP,
AllenCahn,
CahnHilliard,
BelousovZhabotinsky,
)
from .normalized_stepper import (
NormalizedLinearStepper,
NormalizedConvectionStepper,
NormalizedGradientNormStepper,
normalize_coefficients,
denormalize_coefficients,
normalize_convection_scale,
denormalize_convection_scale,
normalize_gradient_norm_scale,
denormalize_gradient_norm_scale,
)
from .spectral import derivative
from .utils import (
get_grid,
build_ic_set,
get_animation,
get_grid,
get_grouped_animation,
rollout,
repeat,
rollout,
stack_sub_trajectories,
build_ic_set,
)
from .spectral import (
derivative,
)

__all__ = [
"ForcedStepper",
"SineWaves",
"RandomSineWaves",
"DiffusedNoise",
"RandomDiffusedNoise",
"GaussianRandomField",
"MultiChannelIC",
"RandomMultiChannelICGenerator",
"RandomTruncatedFourierSeries",
"NormalizedConvectionStepper",
"NormalizedGradientNormStepper",
"NormalizedLinearStepper",
"Poisson",
"RepeatedStepper",
"Advection",
"Advection1d",
"Advection2d",
"Advection3d",
"Diffusion",
"Diffusion1d",
"Diffusion2d",
"Diffusion3d",
"AdvectionDiffusion",
"AllenCahn",
"BelousovZhabotinsky",
"CahnHilliard",
"Dispersion",
"FisherKPP",
"GeneralConvectionStepper",
"GeneralGradientNormStepper",
"GeneralLinearStepper",
"GrayScott",
"HyperDiffusion",
"Burgers",
"Burgers1d",
"Burgers2d",
"Burgers3d",
"SwiftHohenberg",
"KortevegDeVries",
"KolmogorovFlowVorticity2d",
"KuramotoSivashinsky",
"KuramotoSivashinsky1d",
"KuramotoSivashinsky2d",
"KuramotoSivashinsky3d",
"KuramotoSivashinskyConservative",
"Nikolaevskiy",
"NikolaevskiyConservative",
"NavierStokesVorticity2d",
"derivative",
"get_grid",
"get_animation",
"get_grouped_animation",
"rollout",
"repeat",
"stack_sub_trajectories",
"build_ic_set",
"normalize_coefficients",
"denormalize_coefficients",
"normalize_convection_scale",
"denormalize_convection_scale",
"normalize_gradient_norm_scale",
"denormalize_gradient_norm_scale",
]
10 changes: 4 additions & 6 deletions exponax/base_stepper.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import jax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Array, Float, Complex
import jax.numpy as jnp
from jaxtyping import Array, Complex, Float

from .exponential_integrators import BaseETDRK, ETDRK0, ETDRK1, ETDRK2, ETDRK3, ETDRK4
from .exponential_integrators import ETDRK0, ETDRK1, ETDRK2, ETDRK3, ETDRK4, BaseETDRK
from .nonlinear_functions import BaseNonlinearFun
from .spectral import (
build_derivative_operator,
space_indices,
spatial_shape,
wavenumber_shape,
)

from .nonlinear_functions import BaseNonlinearFun


class BaseStepper(eqx.Module):
num_spatial_dims: int
Expand Down
12 changes: 8 additions & 4 deletions exponax/exponential_integrators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import jax.numpy as jnp
from typing import TypeVar

import equinox as eqx
from jaxtyping import Complex, Array, Float
from typing import Callable
import jax.numpy as jnp
from jaxtyping import Array, Complex

from .nonlinear_functions import BaseNonlinearFun, ZeroNonlinearFun
from .nonlinear_functions import BaseNonlinearFun

# E can either be 1 (single channel) or num_channels (multi-channel) for either
# the same linear operator for each channel or a different linear operator for
Expand Down Expand Up @@ -47,6 +48,9 @@ def step_fourier(
return self._exp_term * u_hat


M = TypeVar("M")


def roots_of_unity(M: int) -> Complex[Array, "M"]:
"""
Return (complex-valued) array with M roots of unity.
Expand Down
5 changes: 2 additions & 3 deletions exponax/forced_stepper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any
import equinox as eqx
from .base_stepper import BaseStepper
from jaxtyping import Array, Complex, Float

from jaxtyping import Array, Float, Complex
from .base_stepper import BaseStepper


class ForcedStepper(eqx.Module):
Expand Down
30 changes: 14 additions & 16 deletions exponax/initial_conditions.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,26 @@
import jax.numpy as jnp
import jax.random as jr
from abc import ABC, abstractmethod
from typing import List
import equinox as eqx
from jaxtyping import Complex, Array, Float, PRNGKeyArray

from abc import ABC, abstractmethod
from typing import Optional
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray

from .sample_stepper import Diffusion
from .spectral import (
build_scaled_wavenumbers,
spatial_shape,
wavenumber_shape,
build_scaling_array,
low_pass_filter_mask,
space_indices,
build_scaling_array,
spatial_shape,
wavenumber_shape,
)
from .utils import get_grid

### --- Base classes --- ###
# --- Base classes ---


class BaseIC(eqx.Module, ABC):

@abstractmethod
def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]:
"""
Expand Down Expand Up @@ -85,7 +83,7 @@ def __call__(
return ic_fun(grid)


### Utilities to create ICs for multi-channel fields
# Utilities to create ICs for multi-channel fields


class MultiChannelIC(eqx.Module):
Expand Down Expand Up @@ -120,7 +118,7 @@ def __call__(
return jnp.concatenate(u_list, axis=0)


### New version
# New version

# class TruncatedFourierSeries(BaseIC):
# coefficient_array: Complex[Array, "1 ... (N//2)+1"]
Expand Down Expand Up @@ -220,7 +218,7 @@ def __call__(
return u


### --- Legacy Sine Waves (truncated Fourier series) --- ###
# --- Legacy Sine Waves (truncated Fourier series) ---

# class SineWaves(BaseIC):
# L: float
Expand Down Expand Up @@ -400,7 +398,7 @@ def __call__(
return ic


### Gausian Random Field ###
# Gausian Random Field


class GaussianRandomField(BaseRandomICGenerator):
Expand Down Expand Up @@ -461,7 +459,7 @@ def __call__(
return ic


### Discontinuities ###
# Discontinuities


class Discontinuities(BaseIC):
Expand Down
22 changes: 16 additions & 6 deletions exponax/nonlinear_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,22 @@
from .gradient_norm import GradientNormNonlinearFun
from .polynomial import PolynomialNonlinearFun
from .reaction import (
GrayScottNonlinearFun,
CahnHilliardNonlinearFun,
BelousovZhabotinskyNonlinearFun,
CahnHilliardNonlinearFun,
GrayScottNonlinearFun,
)
from .vorticity_convection import (
VorticityConvection2d,
VorticityConvection2dKolmogorov,
)
from .vorticity_convection import VorticityConvection2d, VorticityConvection2dKolmogorov
from .zero import ZeroNonlinearFun

__all__ = [
"BaseNonlinearFun",
"ConvectionNonlinearFun",
"GradientNormNonlinearFun",
"PolynomialNonlinearFun",
"BelousovZhabotinskyNonlinearFun",
"CahnHilliardNonlinearFun",
"GrayScottNonlinearFun",
"VorticityConvection2d",
"VorticityConvection2dKolmogorov",
"ZeroNonlinearFun",
]
13 changes: 5 additions & 8 deletions exponax/nonlinear_functions/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import jax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Complex, Array, Float, Bool
from ..spectral import (
wavenumber_shape,
low_pass_filter_mask,
)
from abc import ABC, abstractmethod

import equinox as eqx
from jaxtyping import Array, Bool, Complex

from ..spectral import low_pass_filter_mask, wavenumber_shape


class BaseNonlinearFun(eqx.Module, ABC):
num_spatial_dims: int
Expand Down
10 changes: 3 additions & 7 deletions exponax/nonlinear_functions/convection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import jax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Complex, Array, Float, Bool
from ..spectral import (
space_indices,
spatial_shape,
)
from jaxtyping import Array, Complex, Float

from ..spectral import space_indices, spatial_shape
from .base import BaseNonlinearFun


Expand Down Expand Up @@ -68,4 +64,4 @@ def evaluate(
axis=1,
)
# Requires minus to move term to the rhs
return - self.scale * 0.5 * u_divergence_on_outer_product_hat
return -self.scale * 0.5 * u_divergence_on_outer_product_hat
10 changes: 3 additions & 7 deletions exponax/nonlinear_functions/gradient_norm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import jax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Complex, Array, Float, Bool
from ..spectral import (
space_indices,
spatial_shape,
)
from jaxtyping import Array, Complex, Float

from ..spectral import space_indices, spatial_shape
from .base import BaseNonlinearFun


Expand Down Expand Up @@ -73,4 +69,4 @@ def evaluate(
# )

# Requires minus to move term to the rhs
return - self.scale * 0.5 * u_gradient_norm_squared_hat
return -self.scale * 0.5 * u_gradient_norm_squared_hat
9 changes: 2 additions & 7 deletions exponax/nonlinear_functions/polynomial.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import jax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Complex, Array, Float, Bool
from ..spectral import (
space_indices,
spatial_shape,
)
from jaxtyping import Array, Complex

from ..spectral import space_indices, spatial_shape
from .base import BaseNonlinearFun


Expand Down
Loading

0 comments on commit 8139c22

Please sign in to comment.