Skip to content

Commit

Permalink
add fome
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier committed Feb 28, 2024
1 parent 7f2199f commit 7da474a
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions sbijax/_src/fmpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ def _ut(theta_t, theta, times, sigma_min):
return num / denom


# pylint: disable=too-many-locals
def _cfm_loss(
params, rng_key, apply_fn, sigma_min=0.001, is_training=True, **batch
):
theta = batch["theta"]
n, p = theta.shape
n, _ = theta.shape

t_key, rng_keyt = jr.split(rng_key)
t_key, rng_key = jr.split(rng_key)
times = jr.uniform(t_key, shape=(n, 1))

theta_key, rng_key = jr.split(rng_key)
Expand Down

0 comments on commit 7da474a

Please sign in to comment.