Replace deprecated `jax.tree_*` functions with `jax.tree.*` #1546
| Job | Run time |
|---|---|
|  | 3m 23s |
|  | 12m 3s |
|  | 12m 7s |
|  | 12m 20s |
|  | 39m 53s |
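The PR title describes swapping the deprecated top-level `jax.tree_*` helpers for their counterparts in the `jax.tree` namespace. Below is a minimal sketch of that kind of change. It assumes a JAX version that provides `jax.tree`. The pytree `params` and the reduction are hypothetical examples, not code taken from this PR. Each "Before" comment shows the deprecated spelling.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree used only for illustration.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

# Before: jax.tree_map(lambda x: x * 2, params)
doubled = jax.tree.map(lambda x: x * 2, params)

# Before: jax.tree_leaves(params)
leaves = jax.tree.leaves(params)

# Before: jax.tree_flatten(params) / jax.tree_unflatten(treedef, flat)
flat, treedef = jax.tree.flatten(params)
restored = jax.tree.unflatten(treedef, flat)

# Before: jax.tree_structure(params)
structure = jax.tree.structure(params)

# Before: jax.tree_reduce(fn, params, 0.0)
# Sums every element across all leaves, starting from 0.0.
total = jax.tree.reduce(lambda acc, x: acc + x.sum(), params, 0.0)
```

In this sketch the arguments are unchanged and only the function names differ, so a migration like this is largely a mechanical rename.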