[SPMD] Add apply_backward_optimization_barrier #6157
Conversation
Summary:
This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier, which registers a full backward hook that applies an optimization barrier to the given module. This prevents the XLA compiler from fusing the module's backward pass with others, and is useful to keep gigantic buffers from being allocated to synchronize the gradients.
Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
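For illustration, here is a minimal usage sketch based on the description above. It is not code from the PR: the import path for xla_sharding is an assumption and has moved between releases (e.g. torch_xla.experimental.xla_sharding vs. torch_xla.distributed.spmd.xla_sharding), so adjust it for your installed version.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
# Assumed import path; older releases exposed this module as
# torch_xla.experimental.xla_sharding instead.
import torch_xla.distributed.spmd.xla_sharding as xs

device = xm.xla_device()
model = nn.Linear(128, 128).to(device)

# Registers a full backward hook that wraps the module's gradients in an
# optimization barrier, so XLA will not fuse this backward pass with others.
xs.apply_backward_optimization_barrier(model)

out = model(torch.randn(8, 128, device=device, requires_grad=True))
out.sum().backward()
```

The barrier is semantically a no-op: it only constrains what the compiler may fuse across it, so the gradients themselves are numerically unchanged.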
Thanks Jon and Jack.
Hi team, I have stumbled upon an issue here, where I am getting:
Command execution on worker 0 failed with exit status 1. Continuing.
Am I missing a package?
You need to update your pytorch-xla to the head of the tree, or use the r2.2 branch.
Thank you for the speedy response. How do I reference the r2.2 branch? With
pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
I get:
ERROR: Could not find a version that satisfies the requirement torch~=2.2.0 (from versions: 1.11.0, 1.12.0, 1.12.1, 1.13.0, 1.13.1, 2.0.0, 2.0.1, 2.1.0, 2.1.1, 2.1.2)
This is the latest that works for me:
@ManfeiBai @zpcore Can you share the details?
A minor update: if I drop the "-f ..." and instead modify the pip install to:
ERROR: The project you're trying to build requires Bazel 5.3.0 (specified in /tmp/pip-req-build-s_3ikvec/.bazelversion), but it wasn't found in /usr/bin.
However, that process is failing for me for other unrelated reasons. Alternatively, what would be the syntax to update pytorch-xla to the head of the tree? FWIW, I've triple-checked https://github.com/pytorch/xla/releases and the syntax mentioned there doesn't work for me. Appreciate the assist!
Sorry for the late reply; the release document is currently still a draft. You can use the following example command to install the 2.2 pre-release along with the other dependencies if you need torch and libtpu:
Let me know if you see any issues pop up. Thanks.
Thank you for this! It is precisely what I was looking for. Apart from informational deprecation warnings, I am able to move beyond the issues I encountered previously. TYVM!