
[SPMD] Add apply_backward_optimization_barrier #6157

Merged
merged 1 commit into r2.2 on Dec 14, 2023

Conversation

alanwaketan
Collaborator

Summary:
This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier, which registers a full backward hook that applies an optimization barrier to the given module. This API prevents the XLA compiler from fusing the module's backward pass with others, which is useful for preventing gigantic buffers from being allocated to synchronize the gradients.

Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
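
For illustration, here is a minimal usage sketch, not taken from the PR itself: the toy DecoderBlock module and the loop below are assumptions for demonstration, while the call pattern mirrors the one that appears in the traceback later in this thread.

import torch.nn as nn
import torch_xla.experimental.xla_sharding as xs

# Toy stand-in for a transformer decoder layer (hypothetical, for illustration only).
class DecoderBlock(nn.Module):
    def __init__(self, hidden=4096):
        super().__init__()
        self.ffn = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.ffn(x)

layers = nn.ModuleList([DecoderBlock() for _ in range(4)])

# Register the full backward hook on each layer; the hook applies an
# optimization barrier so the XLA compiler does not fuse this layer's
# backward pass with that of neighboring layers.
for layer in layers:
    xs.apply_backward_optimization_barrier(layer)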

@JackCaoG merged commit e854d07 into r2.2 on Dec 14, 2023
17 checks passed
@alanwaketan
Collaborator Author

Thanks Jon and Jack.

@isaacr

isaacr commented Jan 10, 2024

Hi team,

I have stumbled upon an issue here, where I am getting:

[2D] Sharding tensor model.layers.31.mlp.down_proj.weight torch.Size([4096, 11008])
model.layers.31.mlp.down_proj.weight {devices=[16,1]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}
[2D] Sharding tensor model.layers.31.input_layernorm.weight torch.Size([4096])
[2D] Sharding tensor model.layers.31.post_attention_layernorm.weight torch.Size([4096])
[2D] Sharding tensor model.norm.weight torch.Size([4096])
[2D] Sharding tensor lm_head.weight torch.Size([32000, 4096])
lm_head.weight {devices=[1,16]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}
Traceback (most recent call last):
File "/home/me/transformers/examples/pytorch/language-modeling/run_clm.py", line 838, in
main()
File "/home/me/transformers/examples/pytorch/language-modeling/run_clm.py", line 627, in main
xs.apply_backward_optimization_barrier(model.model.layers[i])
AttributeError: module 'torch_xla.experimental.xla_sharding' has no attribute 'apply_backward_optimization_barrier'

Command execution on worker 0 failed with exit status 1. Continuing.

Am I missing a package?

@alanwaketan
Collaborator Author


You need to update your pytorch-xla to the head of the tree, or use the r2.2 branch.

@isaacr

isaacr commented Jan 11, 2024

Thank you for a speedy response. How do I reference the r2.2 branch?

'pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html'

ERROR: Could not find a version that satisfies the requirement torch~=2.2.0 (from versions: 1.11.0, 1.12.0, 1.12.1, 1.13.0, 1.13.1, 2.0.0, 2.0.1, 2.1.0, 2.1.1, 2.1.2)

This is the latest that works for me:
pip install torch~=2.1.2 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html

@alanwaketan
Collaborator Author


@ManfeiBai @zpcore Can you share the details?

@isaacr

isaacr commented Jan 11, 2024

A minor update: if I drop the "-f ..." and instead change the pip install to
"pip install git+https://github.com/pytorch/xla",
then the journey took me down the rabbit hole of getting build tools like Bazel (5.3.0) in order to build:

ERROR: The project you're trying to build requires Bazel 5.3.0 (specified in /tmp/pip-req-build-s_3ikvec/.bazelversion), but it wasn't found in /usr/bin.
....
Building torch_xla version: 2.2.0+git8141078

However, that process is failing for me for other, unrelated reasons.
Of course, I appreciate the help, and ideally I'd just like to obtain pre-built pre-release 2.2 bits, if at all possible.

Alternatively, what would be the syntax to update pytorch-xla to the head of the tree?

FWIW: I've triple-checked here https://github.com/pytorch/xla/releases and the syntax mentioned there doesn't work yet:
pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html

Appreciate the assist!

@zpcore
Collaborator

zpcore commented Jan 12, 2024


Sorry for the late reply; the release document is currently still a draft. You can use the following example commands to install the 2.2 pre-release, along with torch and libtpu if you need those dependencies:

pip install torch==2.2.0 --index-url https://download.pytorch.org/whl/test/cpu
pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0rc6-cp310-cp310-manylinux_2_28_x86_64.whl
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

Let me know if you see any issues pop up. Thanks

@isaacr

isaacr commented Jan 16, 2024

Thank you for this! It is precisely what I was looking for. Apart from informational deprecation warnings, I am able to move beyond issues I encountered previously. TYVM!
