diff --git a/exponax/ic/_scaled.py b/exponax/ic/_scaled.py index 2e035a0..33e4e8e 100644 --- a/exponax/ic/_scaled.py +++ b/exponax/ic/_scaled.py @@ -12,14 +12,23 @@ def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]: class ScaledICGenerator(BaseRandomICGenerator): - """ - Works best in combination with initial conditions that have `max_one=True` - or `std_one=True`. - """ - ic_gen: BaseRandomICGenerator scale: float + def __init__(self, ic_gen: BaseRandomICGenerator, scale: float): + """ + A scaled initial condition generator. + + Works best in combination with initial conditions that have + `max_one=True` or `std_one=True`. + + **Arguments**: + - `ic_gen`: The initial condition generator. + - `scale`: The scaling factor. + """ + self.ic_gen = ic_gen + self.scale = scale + def gen_ic_fun(self, *, key: PRNGKeyArray) -> BaseIC: return ScaledIC(self.ic_gen.gen_ic_fun(key=key), scale=self.scale)