
update doc to use PJRT_DEVICE=CUDA instead of PJRT_DEVICE=GPU #5754

Merged
merged 4 commits into from
Nov 4, 2023

Conversation

vanbasten23
Collaborator

@vanbasten23 vanbasten23 commented Nov 1, 2023

#5675 discovered that the runtime has replaced the PJRT_DEVICE value GPU with CUDA and ROCM. This PR updates our GPU documentation accordingly.

Test via:

  • GPU_NUM_DEVICES=1 PJRT_DEVICE=GPU python pytorch/xla/test/test_train_mp_imagenet.py --fake_data
  • PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node 4 pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
  • GPU CI

@zpcore
Collaborator

zpcore commented Nov 1, 2023

Thanks for the update. Do we also want to remove GPU from the device list here: https://github.com/pytorch/xla/blob/1f0e972a6dfc5ed3205b23d0fccac8ebb128584f/torch_xla/core/xla_model.py#L92C1-L92C1? There are several places in the xla_model.py file that still mention the GPU choice.

@JackCaoG
Collaborator

JackCaoG commented Nov 1, 2023

Hey @vanbasten23, can we make sure PJRT_DEVICE=GPU continues to work, and add a warning message telling users to set it to CUDA instead?
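One way to implement this request is to alias the legacy value and emit a deprecation warning. The sketch below is illustrative only: `resolve_pjrt_device` is a hypothetical helper name, not the actual torch_xla API.

```python
import os
import warnings

def resolve_pjrt_device() -> str:
    """Hypothetical shim: accept the legacy PJRT_DEVICE=GPU but warn
    and treat it as CUDA, so old scripts keep working."""
    device = os.environ.get("PJRT_DEVICE", "")
    if device == "GPU":
        warnings.warn(
            "PJRT_DEVICE=GPU is deprecated; please set PJRT_DEVICE=CUDA instead.",
            DeprecationWarning,
        )
        device = "CUDA"
    return device

os.environ["PJRT_DEVICE"] = "GPU"
print(resolve_pjrt_device())  # prints CUDA (and emits a DeprecationWarning)
```

The key design point is that the legacy spelling is remapped rather than rejected, so existing workflows degrade gracefully instead of breaking.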

@vanbasten23
Collaborator Author

Thanks for the update. Do we also want to remove GPU from the device list here: https://github.com/pytorch/xla/blob/1f0e972a6dfc5ed3205b23d0fccac8ebb128584f/torch_xla/core/xla_model.py#L92C1-L92C1? There are several places in the xla_model.py file that still mention the GPU choice.

Yeah, that's a good point. I plan to remove it in a follow-up PR.

@vanbasten23
Collaborator Author

vanbasten23 commented Nov 2, 2023

Hey @vanbasten23, can we make sure PJRT_DEVICE=GPU continues to work, and add a warning message telling users to set it to CUDA instead?

FWIW, it's already causing some problems: if I use our nightly (11/1/23), I see:

root@xiowei-gpu-1:/# PJRT_DEVICE=GPU python
Python 3.8.18 (default, Oct 12 2023, 10:35:13)
[GCC 10.2.1 20210110] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch_xla.core.xla_model as xm
>>> import torch, torch_xla
>>> xm.get_xla_supported_devices(devkind='GPU')
# it returns nothing

In contrast, r2.1 returns ['xla:0']. Let me see if there is an easy fix.

@zpcore
Collaborator

zpcore commented Nov 2, 2023

Hey @vanbasten23, can we make sure PJRT_DEVICE=GPU continues to work, and add a warning message telling users to set it to CUDA instead?

FWIW, it's already causing some problems: if I use our nightly (11/1/23), I see:

root@xiowei-gpu-1:/# PJRT_DEVICE=GPU python
Python 3.8.18 (default, Oct 12 2023, 10:35:13)
[GCC 10.2.1 20210110] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch_xla.core.xla_model as xm
>>> import torch, torch_xla
>>> xm.get_xla_supported_devices(devkind='GPU')
# it returns nothing

In contrast, r2.1 returns ['xla:0']. Let me see if there is an easy fix.

Yes, that's what we saw when running on GPU recently.
As a temporary fix to run on GPU, we added devkind = "CUDA" if devkind == "GPU" else devkind before the xla_devices = _DEVICES.value line.
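The workaround described above can be sketched as a small helper. This is a hedged illustration: `normalize_devkind` and the sample device list are stand-ins for the real torch_xla internals, not the actual code.

```python
def normalize_devkind(devkind: str) -> str:
    # Temporary workaround discussed above: map the legacy "GPU" kind
    # onto the new "CUDA" kind before filtering the device list.
    return "CUDA" if devkind == "GPU" else devkind

# Illustrative device list; the real values come from the PJRT runtime.
devices = ["CUDA:0", "CUDA:1", "TPU:0"]

kind = normalize_devkind("GPU")
print([d for d in devices if d.startswith(kind)])  # ['CUDA:0', 'CUDA:1']
```

With this aliasing in place, a caller passing the old devkind='GPU' still sees the CUDA devices instead of an empty list.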

@vanbasten23
Collaborator Author

can we make sure PJRT_DEVICE=GPU continues to work, and add a warning message telling users to set it to CUDA instead?

Hi @JackCaoG, I've added warning messages and fixed the currently known bug.
Regarding "make sure PJRT_DEVICE=GPU continues to work": we're already in a mixed state where the GPU CI uses CUDA but our public documentation uses GPU. IMO, the only way to guarantee that is to run the GPU CI with PJRT_DEVICE=GPU, but that sounds like too much overhead. How about we just make a best effort?

@JackCaoG
Collaborator

JackCaoG commented Nov 2, 2023

@vanbasten23 just have one test run with PJRT_DEVICE=GPU (explicitly set the env var when running the test, or something similar) and the rest run with PJRT_DEVICE=CUDA (the default).
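Using the commands from the PR description, the suggested split might look like this in CI (a sketch, not the actual CI configuration):

```shell
# One backward-compatibility run with the legacy value...
GPU_NUM_DEVICES=1 PJRT_DEVICE=GPU python pytorch/xla/test/test_train_mp_imagenet.py --fake_data

# ...and everything else with the new default value.
PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node 4 pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
```

This keeps the deprecated spelling covered by at least one job without doubling the whole test matrix.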

@vanbasten23
Collaborator Author

vanbasten23 commented Nov 3, 2023

The GPU CI test TestDynamicShapeModels is failing:

test_backward_pass_with_dynamic_input (__main__.TestDynamicShapeModels) ... 2023-11-03 17:22:04.313597: E external/xla/xla/status_macros.cc:54] INTERNAL: RET_CHECK failure (external/xla/xla/service/dynamic_padder.cc:1989) op_support != OpDynamismSupport::kNoSupport Dynamic input unexpectedly found for unsupported instruction: %add = f32[<=10,1]{1,0} add(f32[<=10,1]{1,0} %broadcast.2, f32[10,1]{1,0} %exponential)

and it's likely due to the recent pin update. A likely culprit is openxla/xla@33bcc66.

The test has not actually been run since the last pin update: its skip condition contains not xm.get_xla_supported_devices("GPU"), which now silently evaluates to true, so the test has always been skipped.

Luckily, another CL is trying to revert openxla/xla@33bcc66. I'll skip this test on GPU for now and only let it run on TPU (where the test succeeds).
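A minimal sketch of skipping a test on CUDA while keeping it on TPU, using plain unittest. Here `device_type` is a stand-in for the real runtime query; in the actual suite the device type would come from the PJRT runtime.

```python
import unittest

def device_type() -> str:
    # Stand-in for querying the PJRT runtime device type; hardcoded here
    # purely to make the sketch self-contained.
    return "CUDA"

class TestDynamicShapeModels(unittest.TestCase):

    @unittest.skipIf(device_type() == "CUDA",
                     "Dynamic padder regression on CUDA; run on TPU only.")
    def test_backward_pass_with_dynamic_input(self):
        self.assertTrue(True)

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestDynamicShapeModels)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(len(result.skipped))  # 1
```

The skip reason string shows up in the test report, which avoids the silent-skip problem described above.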

@vanbasten23
Collaborator Author

Thanks for the review!

@vanbasten23 vanbasten23 merged commit f01cdb6 into master Nov 4, 2023
17 checks passed
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Nov 16, 2023
…h#5754)

* update doc to use PJRT_DEVICE=CUDA instead of PJRT_DEVICE=GPU

* add warning message.

* fix comment and test failure.

* skip dynamic shape model test on cuda.
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
…h#5754)

golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024