Skip to content

Commit

Permalink
Change to a significantly more efficient substacker
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Apr 5, 2024
1 parent 3015427 commit 853a3c7
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions exponax/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,19 @@ def stack_sub_trajectories(

n_sub_trjs = n_time_steps - sub_len + 1

sub_trjs = jtu.tree_map(
lambda trj: jnp.stack(
[trj[i : i + sub_len] for i in range(n_sub_trjs)], axis=0
),
trj,
)
def scan_fn(_, i):
sliced = jtu.tree_map(
lambda leaf: jax.lax.dynamic_slice_in_dim(
leaf,
start_index=i,
slice_size=sub_len,
axis=0,
),
trj,
)
return _, sliced

_, sub_trjs = jax.lax.scan(scan_fn, None, jnp.arange(n_sub_trjs))

return sub_trjs

Expand Down

0 comments on commit 853a3c7

Please sign in to comment.