Skip to content

Commit

Permalink
Fix crucial bug in multichannel ic generation
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Mar 8, 2024
1 parent 1a79211 commit 32e0fa1
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions exponax/ic/_multi_channel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray

Expand Down Expand Up @@ -28,12 +29,22 @@ class RandomMultiChannelICGenerator(eqx.Module):

def gen_ic_fun(self, num_points: int, *, key: PRNGKeyArray) -> MultiChannelIC:
ic_funs = [
ic_gen.gen_ic_fun(num_points, key=key) for ic_gen in self.ic_generators
ic_gen.gen_ic_fun(num_points, key=k)
for (ic_gen, k) in zip(
self.ic_generators,
jax.random.split(key, len(self.ic_generators)),
)
]
return MultiChannelIC(ic_funs)

def __call__(
self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "C ... N"]:
u_list = [ic_gen(num_points, key=key) for ic_gen in self.ic_generators]
u_list = [
ic_gen(num_points, key=k)
for (ic_gen, k) in zip(
self.ic_generators,
jax.random.split(key, len(self.ic_generators)),
)
]
return jnp.concatenate(u_list, axis=0)

0 comments on commit 32e0fa1

Please sign in to comment.