
How do you remat GSPMD inserted all-gathers? #25010

Open
ptoulme-aws opened this issue Nov 20, 2024 · 1 comment
Labels
question Questions for the JAX team

Comments

@ptoulme-aws

ptoulme-aws commented Nov 20, 2024

Problem: I have some JAX code that does sequence parallelism, somewhat similar to this:

activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
activation = norm(activation)
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
# I want to remat this one ^
activation = attention(activation)

I have tried everything I can to remat the activation directly before attention, including JAX remat policies and explicitly applying jax.checkpoint around that exact tensor, but nothing seems to make it remat. The activation directly before attention is produced by a GSPMD-inserted all-gather on the sequence dimension (dim=0).
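
For reference, one attempt looked roughly like this (a sketch rather than my exact code, continuing the snippet above):

# Sketch of one attempt: wrap the sub-computation that consumes the
# all-gathered activation in jax.checkpoint with a policy that saves
# nothing, hoping the GSPMD-inserted all-gather would be recomputed
# in the backward pass instead of its output being saved.
attention_remat = jax.checkpoint(
    attention,
    policy=jax.checkpoint_policies.nothing_saveable,
)
activation = attention_remat(activation)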

I ended up writing an XLA pass to rematerialize large all-gathers and submitted it as a PR: openxla/xla#19163

Question: Is this possible to do from the JAX end, or is my pass really needed?

@mattjj
Collaborator

mattjj commented Nov 20, 2024

Thanks for the question.

No, I don't think a new pass is needed.

As I understand it, the standard way to spell this is to use a remat policy to mark the with_sharding_constraint that induces the all-gather as not-saveable. One way to do that would be to use save_only_these_names and to name only other arrays (ones that are either upstream of the all-gather-inducing with_sharding_constraint, or downstream of the operations that use the output of attention). Following your snippet, that might look something like:

activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
activation = checkpoint_name(norm(activation), 'scattered_activations')
activation = jax.lax.with_sharding_constraint(activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
activation = attention(activation)

together with a save_only_these_names policy that mentions 'scattered_activations' or something upstream of it.
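
To make that concrete, here is a minimal runnable sketch; the mesh shape and the norm/attention bodies are placeholders, not your real model:

import jax
import jax.numpy as jnp
import numpy as np
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Placeholder mesh; substitute your real device topology.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, ('data', 'tensor'))

def norm(x):  # stand-in for the real norm layer
    return x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + 1e-6)

def attention(x):  # stand-in for the real attention layer
    return x

def block(activation):
    activation = jax.lax.with_sharding_constraint(
        activation, NamedSharding(mesh, PartitionSpec('data', 'tensor', None)))
    # Name the still-scattered activations so the policy below saves them.
    activation = checkpoint_name(norm(activation), 'scattered_activations')
    # This constraint induces the all-gather; its output is not named,
    # so under save_only_these_names it is rematerialized, not saved.
    activation = jax.lax.with_sharding_constraint(
        activation, NamedSharding(mesh, PartitionSpec(None, 'tensor', None)))
    return attention(activation)

block_remat = jax.checkpoint(
    block,
    policy=jax.checkpoint_policies.save_only_these_names('scattered_activations'),
)

grads = jax.jit(jax.grad(lambda x: block_remat(x).sum()))(jnp.ones((8, 4, 16)))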

Did you try something like that? If you already tried it, we should put together a minimal example to debug what's going on.

mattjj added the question (Questions for the JAX team) label and removed the enhancement (New feature or request) label on Nov 21, 2024