🐛 Bug
I'm working with nightly versions of torch/xla on TPU. When moving from torch==2.6.0.dev20241106+cpu to torch==2.6.0.dev20241107, TPU memory use for SPMD training increases significantly (about 2.5×), and in some settings training crashes with an OOM. The newest nightly still hasn't fixed this. I suspect a change in torch itself is affecting SPMD training in torch_xla.
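For reference, this is roughly the kind of SPMD setup involved. This is a minimal sketch only: the model, mesh shape, batch size, and sharding spec are placeholders I chose for illustration, not taken from the actual training code.

```python
# Minimal SPMD sketch on TPU (hypothetical model and mesh; not the exact training code).
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode

num_devices = xr.global_runtime_device_count()
# 2D mesh: all devices on the "data" axis, model axis of size 1 (placeholder layout).
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("data", "model"))

device = xm.xla_device()
model = nn.Linear(1024, 1024).to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for step in range(10):
    x = torch.randn(64, 1024, device=device)
    xs.mark_sharding(x, mesh, ("data", None))  # shard the batch dimension across devices
    loss = model(x).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    xm.mark_step()  # cut the graph and execute on the TPU
```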
Environment
Reproducible on XLA backend: TPU
torch_xla version: 2.6.0.dev20241107
torch/torchvision versions: 2.6.0.dev20241107
Hmm, this is a bit weird. Looking at the commit history: if the breaking change comes from the 11/06 -> 11/07 nightly, the change must have been merged on 11/06, but nothing suspicious was merged that day.
@qihqi identified that the regression might be coming from pytorch/pytorch@54e6801, where upstream changed something about SDPA. We never tested SDPA ourselves because it does not lower into any custom kernel; it just uses the native implementation of attention. @dudulightricks was able to confirm that the regression is from SDPA. I was trying to help them use our flash attention in #8425, and @dudulightricks followed up with #8427. Right now the remaining issue seems to be that XLA_DISABLE_FUNCTIONIZATION crashes on backward.
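For context, the workaround being discussed is roughly to route attention through torch_xla's flash-attention custom kernel instead of the native `torch.nn.functional.scaled_dot_product_attention` path. The sketch below is illustrative only: the tensor shapes, dtype, and the `causal` flag are assumptions, and the exact integration is what #8425/#8427 work through.

```python
# Hypothetical sketch: comparing the native SDPA path with torch_xla's
# flash-attention kernel (shapes and flags are illustrative).
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
q = torch.randn(4, 8, 1024, 128, device=device)
k = torch.randn(4, 8, 1024, 128, device=device)
v = torch.randn(4, 8, 1024, 128, device=device)

# Native path affected by the upstream SDPA change:
out_native = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# torch_xla custom-kernel path suggested in #8425 / #8427:
out_flash = flash_attention(q, k, v, causal=True)

xm.mark_step()
```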