Problem: I have some JAX code that does sequence parallelism, somewhat similar to this:
import jax
from jax.sharding import NamedSharding, PartitionSpec
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
activation = norm(activation)
# Replicating dim 0 (the sequence dim) makes GSPMD insert an all-gather here.
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
# I want to remat this one ^
activation = attention(activation)
I have tried everything I can think of to remat the activation directly before attention, including JAX remat policies and explicitly applying jax.checkpoint to that exact tensor, but nothing seems to make it remat. The activation directly before attention is an all-gather on the sequence dimension (dim=0) inserted by GSPMD.
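To see where the GSPMD-inserted all-gather shows up, one option is to compile a small version of the pattern and search the optimized HLO. This is a minimal sketch under assumptions that are not in the snippet above: 8 simulated CPU devices (via the XLA flag --xla_force_host_platform_device_count), a 4 x 2 ('data', 'tensor') mesh, an arbitrary (16, 8, 32) activation shape, and a hypothetical RMS-style norm standing in for the real one; attention is omitted.

```python
import os

# Assumption: run on CPU and fake 8 devices; both must be set before jax is imported.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ("data", "tensor"))


def norm(x):
    # Hypothetical stand-in for the real norm: RMS-normalize over the last dim.
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + 1e-6)


def block(x):
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data", "tensor", None)))
    x = norm(x)
    # Un-sharding dim 0 (the sequence dim) forces a gather across the 'data' axis.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(None, "tensor", None)))


x = jnp.ones((16, 8, 32), dtype=jnp.float32)
hlo = jax.jit(block).lower(x).compile().as_text()

# Print the collective ops XLA emitted for the resharding.
for line in hlo.splitlines():
    if "all-gather" in line or "all-reduce" in line:
        print(line.strip())
```

The exact op spelling can vary by backend and XLA version (GPU builds, for example, may show async all-gather-start/all-gather-done pairs), so the filter also matches all-reduce in case the gather is lowered differently; treat it as a starting point rather than an exact check.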
I ended up writing an XLA pass that rematerializes large all-gathers and submitted a PR: openxla/xla#19163
Question: Is this possible to do from the JAX side, or is my pass really needed?
As I understand it, the standard way to spell this is to use a remat policy that marks the output of the with_sharding_constraint which induces the all-gather as not saveable. One way to do that would be to use save_only_these_names and to name only other arrays (ones that are either upstream of the all-gather-inducing with_sharding_constraint, or downstream of the operations that use the output of attention). Following your snippet, that might look something like:
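A minimal, self-contained sketch of that idea, under illustrative assumptions (8 simulated CPU devices, a 4 x 2 ('data', 'tensor') mesh, and hypothetical norm and attention stand-ins): the norm output, still sharded over 'data', and the attention output are tagged with checkpoint_name and kept via jax.checkpoint_policies.save_only_these_names, while the all-gathered tensor is left unnamed. It is therefore not a saved residual, and the backward pass has to recompute it from the saved shard. The names "normed" and "attn_out" are made up for the example.

```python
import os

# Assumption: run on CPU and fake 8 devices; both must be set before jax is imported.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.ad_checkpoint import checkpoint_name, print_saved_residuals
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ("data", "tensor"))


def norm(x):
    # Hypothetical stand-in: RMS-normalize over the last dim.
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + 1e-6)


def attention(x):
    # Hypothetical stand-in: per-head self-attention over the sequence dim,
    # with x of shape (seq, heads, head_dim) used as query, key, and value.
    scores = jnp.einsum("qhd,khd->hqk", x, x) / jnp.sqrt(x.shape[-1])
    return jnp.einsum("hqk,khd->qhd", jax.nn.softmax(scores, axis=-1), x)


def layer(x):
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data", "tensor", None)))
    # Saveable: the norm output while it is still sequence-sharded.
    x = checkpoint_name(norm(x), "normed")
    # Not named, so not saved: the all-gathered activation is recomputed in backward.
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(None, "tensor", None)))
    return checkpoint_name(attention(x), "attn_out")


policy = jax.checkpoint_policies.save_only_these_names("normed", "attn_out")
remat_layer = jax.checkpoint(layer, policy=policy)


def loss(x):
    return jnp.sum(remat_layer(x) ** 2)


x = jnp.ones((16, 8, 32), dtype=jnp.float32)
print_saved_residuals(loss, x)  # lists the residuals the policy keeps
grads = jax.jit(jax.grad(loss))(x)
```

Whether XLA actually keeps the recomputed all-gather in the backward pass, rather than merging it with the forward one, is worth confirming in the compiled HLO of jax.jit(jax.grad(loss)).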