diff --git a/exponax/_utils.py b/exponax/_utils.py index 760be63..4370501 100644 --- a/exponax/_utils.py +++ b/exponax/_utils.py @@ -286,12 +286,19 @@ def stack_sub_trajectories( n_sub_trjs = n_time_steps - sub_len + 1 - sub_trjs = jtu.tree_map( - lambda trj: jnp.stack( - [trj[i : i + sub_len] for i in range(n_sub_trjs)], axis=0 - ), - trj, - ) + def scan_fn(_, i): + sliced = jtu.tree_map( + lambda leaf: jax.lax.dynamic_slice_in_dim( + leaf, + start_index=i, + slice_size=sub_len, + axis=0, + ), + trj, + ) + return _, sliced + + _, sub_trjs = jax.lax.scan(scan_fn, None, jnp.arange(n_sub_trjs)) return sub_trjs