
TPU memory use increased significantly in torch/xla - 2.6.0.dev20241107 #8423

Open
dudulightricks opened this issue Nov 27, 2024 · 4 comments

Comments

@dudulightricks
Contributor

🐛 Bug

I'm working with nightly versions of torch/xla on TPU. When moving from torch==2.6.0.dev20241106+cpu to torch==2.6.0.dev20241107, TPU memory use for SPMD training increases significantly (about 2.5x), and in some settings training crashes with OOM. The newest nightly still hasn't solved this problem. I suspect a change in torch that affects SPMD training in torch_xla.

Environment

  • Reproducible on XLA backend - TPU
  • torch_xla version: 2.6.0.dev20241107
  • torch/torchvision versions: 2.6.0.dev20241107
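
For context, the affected training uses SPMD sharding roughly like the sketch below. This is a minimal sketch only: the real model, mesh shape, and partition specs differ, and the tiny MLP here just stands in for our model.

```python
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable the SPMD execution mode

# 2D mesh over all TPU devices; only the 'data' (batch) axis is actually used here.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

device = xm.xla_device()
# Stand-in for the real model in the affected training job.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(10):
    inputs = torch.randn(64, 1024, device=device)
    xs.mark_sharding(inputs, mesh, ('data', None))  # shard activations on the batch dim
    loss = model(inputs).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    xm.mark_step()  # cut the graph; this is where device memory per compiled step shows up
```

Comparing per-device memory after a few steps of this kind of loop between the two nightlies is how the ~2.5x increase was observed.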
@JackCaoG
Collaborator

Hmm, this is a bit weird. Looking at the commit history:
[screenshot of the commit history around the 11/06 nightly]
If the breaking change comes from the 11/06 -> 11/07 nightly, the change must have been merged on 11/06, but nothing suspicious was merged that date.

@JackCaoG
Collaborator

I will try to repro with https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py, which does FSDPv2 sharding with the decoder-only model, and see if I can reproduce the memory issue.
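
For anyone following along, that example wraps a decoder-only model with the SPMD-based FSDPv2 wrapper, roughly like the sketch below. This is from memory, not copied from the script: the exact import path, mesh construction, and wrapper arguments should be checked against the example, and the small MLP is only a placeholder for the decoder-only model.

```python
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

xr.use_spmd()

# 1D 'fsdp' mesh over all TPU devices, as in the decoder-only FSDPv2 example.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))
xs.set_global_mesh(mesh)

device = xm.xla_device()
# Placeholder for the example's decoder-only model.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).to(device)
model = FSDPv2(model)  # parameters get sharded along the 'fsdp' mesh axis
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(10):
    inputs = torch.randn(16, 1024, device=device)
    loss = model(inputs).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    xm.mark_step()
```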

@miladm
Collaborator

miladm commented Dec 2, 2024

@JackCaoG do we have an insight on the repro to share?

@JackCaoG
Collaborator

JackCaoG commented Dec 2, 2024

@qihqi identified that the regression might be coming from pytorch/pytorch@54e6801, which changed something about SDPA upstream. We never tested SDPA ourselves because it does not lower into any custom kernel; it is the same as using the native implementation of attention. @dudulightricks was able to confirm that the regression comes from SDPA. I was trying to help them use our flash attention in #8425, and @dudulightricks followed up with #8427. Right now the remaining issue seems to be that XLA_DISABLE_FUNCTIONIZATION crashes on the backward pass.
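
For reference, the substitution discussed in #8425 roughly amounts to replacing the native SDPA call with torch_xla's Pallas flash attention kernel, along the lines of the sketch below. Argument names and defaults are from memory and may differ between nightlies; under SPMD the kernel can additionally be given sharding information, which is omitted here.

```python
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
# [batch, num_heads, seq_len, head_dim]
q = torch.randn(4, 8, 1024, 128, device=device)
k = torch.randn(4, 8, 1024, 128, device=device)
v = torch.randn(4, 8, 1024, 128, device=device)

# Native SDPA: lowers to plain matmul/softmax HLO on XLA (no custom kernel),
# which is the path affected by the upstream change.
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# torch_xla flash attention (Pallas custom kernel).
out_flash = flash_attention(q, k, v, causal=True)

xm.mark_step()
```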


3 participants