Replace deprecated jax.tree_* functions with jax.tree.*
#1553
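The change named in the title is a mechanical rename from the deprecated top-level `jax.tree_*` aliases to the equivalent functions in the `jax.tree` namespace. The sketch below is a minimal illustration of that mapping, not this PR's diff, which is not shown here. It assumes a JAX release that ships the `jax.tree` module, and the `params` pytree is a hypothetical example. The old spelling of each call appears only in comments, because newer JAX releases may no longer provide the deprecated aliases.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree: a dict holding an array and a tuple of Python floats.
params = {"w": jnp.ones((2, 3)), "b": (1.0, 2.0)}

# Old: jax.tree_map(fn, tree)        New: jax.tree.map(fn, tree)
doubled = jax.tree.map(lambda x: x * 2, params)

# Old: jax.tree_leaves(tree)         New: jax.tree.leaves(tree)
# Dict keys are visited in sorted order, so "b" leaves come before "w".
leaves = jax.tree.leaves(params)

# Old: jax.tree_flatten / jax.tree_unflatten
# New: jax.tree.flatten / jax.tree.unflatten
flat, treedef = jax.tree.flatten(params)
rebuilt = jax.tree.unflatten(treedef, flat)

# Old: jax.tree_structure(tree)      New: jax.tree.structure(tree)
same_structure = jax.tree.structure(rebuilt) == treedef

# Old: jax.tree_reduce(fn, tree, init)   New: jax.tree.reduce(fn, tree, init)
# Sums every element of every leaf: 1.0 + 2.0 + six ones.
total = jax.tree.reduce(lambda acc, x: acc + jnp.sum(x), params, 0.0)
```

Only the function names change in this mapping; the arguments and the order in which leaves are visited stay the same.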
| Job | Run time |
|---|---|
| | 11m 45s |
| | 3m 40s |
| | 12m 1s |
| | 12m 0s |
| Total | 39m 26s |