From cbdeeb28f92180fe7bdfa97242b28031ec9061f6 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Wed, 20 Mar 2024 15:30:13 +0100 Subject: [PATCH] Add scaled IC generation --- exponax/ic/__init__.py | 3 +++ exponax/ic/_scaled.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 exponax/ic/_scaled.py diff --git a/exponax/ic/__init__.py b/exponax/ic/__init__.py index 6a4961d..d2a0111 100644 --- a/exponax/ic/__init__.py +++ b/exponax/ic/__init__.py @@ -14,6 +14,7 @@ from ._discontinuities import Discontinuities, RandomDiscontinuities from ._gaussian_random_field import GaussianRandomField from ._multi_channel import MultiChannelIC, RandomMultiChannelICGenerator +from ._scaled import ScaledIC, ScaledICGenerator from ._truncated_fourier_series import RandomTruncatedFourierSeries __all__ = [ @@ -27,4 +28,6 @@ "RandomDiscontinuities", "RandomMultiChannelICGenerator", "RandomTruncatedFourierSeries", + "ScaledIC", + "ScaledICGenerator", ] diff --git a/exponax/ic/_scaled.py b/exponax/ic/_scaled.py new file mode 100644 index 0000000..2e035a0 --- /dev/null +++ b/exponax/ic/_scaled.py @@ -0,0 +1,30 @@ +from jaxtyping import Array, Float, PRNGKeyArray + +from ._base_ic import BaseIC, BaseRandomICGenerator + + +class ScaledIC(BaseIC): + ic: BaseIC + scale: float + + def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]: + return self.ic(x) * self.scale + + +class ScaledICGenerator(BaseRandomICGenerator): + """ + Works best in combination with initial conditions that have `max_one=True` + or `std_one=True`. + """ + + ic_gen: BaseRandomICGenerator + scale: float + + def gen_ic_fun(self, *, key: PRNGKeyArray) -> BaseIC: + return ScaledIC(self.ic_gen.gen_ic_fun(key=key), scale=self.scale) + + def __call__( + self, num_points: int, *, key: PRNGKeyArray + ) -> Float[Array, "1 ... N"]: + ic = self.ic_gen(num_points=num_points, key=key) + return ic * self.scale