test/spmd/test_spmd_debugging.py fails when run on a v4-8 TPU #6252

Closed
mbzomowski opened this issue Jan 3, 2024 · 3 comments · Fixed by #6263
@mbzomowski
Collaborator

🐛 Bug

test/spmd/test_spmd_debugging.py fails when run on a v4-8 TPU, with the following output:

root@8a626de6d6ae:/ansible/pytorch/xla# python3 -u test/spmd/test_spmd_debugging.py
s                                                           
 TPU 0  TPU 4  TPU 8  TPU 12  TPU 2  TPU 6  TPU 10  TPU 14 
                                                           
                                                           
 TPU 1  TPU 5  TPU 9  TPU 13  TPU 3  TPU 7  TPU 11  TPU 15 
                                                           
Fs              
 TPU 0  TPU 1 
              
              
 TPU 2  TPU 3 
              
Fs              
  TPU [0, 1]  
              
              
  TPU [4, 5]  
              
              
  TPU [8, 9]  
              
              
 TPU [12, 13] 
              
              
  TPU [2, 3]  
              
              
  TPU [6, 7]  
              
              
 TPU [10, 11] 
              
              
 TPU [14, 15] 
              
Fs                  
 TPU [0, 1, 2, 3] 
                  
Fs            
 TPU [0, 1] 
            
            
 TPU [2, 3] 
            
Fs                  
 TPU [0, 1, 2, 3] 
                  
F
======================================================================
FAIL: test_debugging_spmd_multi_host_tiled_tpu (__main__.DebuggingSpmdTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/spmd/test_spmd_debugging.py", line 454, in test_debugging_spmd_multi_host_tiled_tpu
    assert output == fake_output
AssertionError

======================================================================
FAIL: test_debugging_spmd_single_host_tiled_tpu (__main__.DebuggingSpmdTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/spmd/test_spmd_debugging.py", line 109, in test_debugging_spmd_single_host_tiled_tpu
    assert output == fake_output
AssertionError

======================================================================
FAIL: test_multi_host_partial_replication_tpu (__main__.DebuggingSpmdTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/spmd/test_spmd_debugging.py", line 538, in test_multi_host_partial_replication_tpu
    assert output == fake_output
AssertionError

======================================================================
FAIL: test_multi_host_replicated_tpu (__main__.DebuggingSpmdTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/spmd/test_spmd_debugging.py", line 574, in test_multi_host_replicated_tpu
    assert output == fake_output
AssertionError

======================================================================
FAIL: test_single_host_partial_replication_tpu (__main__.DebuggingSpmdTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/spmd/test_spmd_debugging.py", line 160, in test_single_host_partial_replication_tpu
    assert output == fake_output
AssertionError

======================================================================
FAIL: test_single_host_replicated_tpu (__main__.DebuggingSpmdTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/spmd/test_spmd_debugging.py", line 205, in test_single_host_replicated_tpu
    assert output == fake_output
AssertionError

----------------------------------------------------------------------
Ran 12 tests in 3.288s

FAILED (failures=6, skipped=6)

I ensured my setup was correct by running test/test_operations.py beforehand, which passed, and skipped individual tests which were not applicable to TPUs.
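
The tracebacks above end in a bare AssertionError because the test compares the two strings with a plain assert, so the log does not show how output and fake_output differ. As a minimal sketch (not the project's test code; the two table strings below are hypothetical stand-ins), the following compares a bare assert with unittest's assertEqual, which attaches a diff of the two strings to the failure message:

```python
import unittest

# Hypothetical stand-ins for the rendered table and the expected table.
OUTPUT = " TPU 0  TPU 1 \n TPU 2  TPU 3 "
FAKE_OUTPUT = " TPU 0  TPU 2 \n TPU 1  TPU 3 "


def bare_assert_message(output, fake_output):
    """Return the message carried by a failing bare assert, or None if it passes."""
    try:
        assert output == fake_output
    except AssertionError as err:
        return str(err)
    return None


def assert_equal_message(output, fake_output):
    """Return the message carried by a failing assertEqual, or None if it passes."""
    case = unittest.TestCase()
    try:
        # For two str arguments, assertEqual reports a line-by-line diff.
        case.assertEqual(output, fake_output)
    except AssertionError as err:
        return str(err)
    return None


if __name__ == "__main__":
    print("bare assert message:", repr(bare_assert_message(OUTPUT, FAKE_OUTPUT)))
    print("assertEqual message:")
    print(assert_equal_message(OUTPUT, FAKE_OUTPUT))
```

Run under plain python3, the bare assert carries no detail, while the assertEqual message includes both strings and their differences, which would make a mismatch like the ones reported here easier to diagnose.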

To Reproduce

Steps to reproduce the behavior:

  1. Created a v4-8 TPU & ssh'd into it
  2. Followed setup steps here, and ran export BUNDLE_LIBTPU=1; export TPUVM_MODE=1 before running the pytorch/xla setup script.
  3. Ran export PJRT_DEVICE=TPU; python3 -u test/test_operations.py -v to ensure my setup was working.
  4. Ran python3 -u test/spmd/test_spmd_debugging.py, which resulted in the above failure.

Expected behavior

Expected the test to pass.

Environment

  • Reproducible on XLA backend [CPU/TPU]: TPU
  • torch_xla version: cloned master & ran setup

Additional context

@JackCaoG
Collaborator

JackCaoG commented Jan 4, 2024

@ManfeiBai can you take a look since you added this test?

@ManfeiBai
Collaborator

ManfeiBai commented Jan 4, 2024

> @ManfeiBai can you take a look since you added this test?

Thanks, will do

@ManfeiBai ManfeiBai self-assigned this Jan 4, 2024
@ManfeiBai
Collaborator

Thanks, I reproduced this locally on a v4-8 too.

#6263 will fix this issue; please pull the newest commit. The test has been verified locally on both TPU and CPU:

# PJRT_DEVICE=TPU python3 test/spmd/test_spmd_debugging.py
s                                                           
 TPU 0  TPU 4  TPU 8  TPU 12  TPU 2  TPU 6  TPU 10  TPU 14 
                                                           
                                                           
 TPU 1  TPU 5  TPU 9  TPU 13  TPU 3  TPU 7  TPU 11  TPU 15 
                                                           
.sss              
  TPU [0, 1]  
              
              
  TPU [4, 5]  
              
              
  TPU [8, 9]  
              
              
 TPU [12, 13] 
              
              
  TPU [2, 3]  
              
              
  TPU [6, 7]  
              
              
 TPU [10, 11] 
              
              
 TPU [14, 15] 
              
.ssssss
----------------------------------------------------------------------
Ran 12 tests in 3.305s

OK (skipped=10)
# PJRT_DEVICE=CPU python3 test/spmd/test_spmd_debugging.py
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1704494185.579667   92651 cpu_client.cc:370] TfrtCpuClient created.
                                                           
 CPU 0  CPU 4  CPU 8  CPU 12  CPU 2  CPU 6  CPU 10  CPU 14 
                                                           
                                                           
 CPU 1  CPU 5  CPU 9  CPU 13  CPU 3  CPU 7  CPU 11  CPU 15 
                                                           
.s         
 CPU [0] 
         
.s              
  CPU [0, 1]  
              
              
  CPU [4, 5]  
              
              
  CPU [8, 9]  
              
              
 CPU [12, 13] 
              
              
  CPU [2, 3]  
              
              
  CPU [6, 7]  
              
              
 CPU [10, 11] 
              
              
 CPU [14, 15] 
              
.s         
 CPU [0] 
         
.s         
 CPU [0] 
         
.s         
 CPU [0] 
         
.s
----------------------------------------------------------------------
Ran 12 tests in 0.168s

OK (skipped=6)
I0000 00:00:1704494186.033125   92651 cpu_client.cc:373] TfrtCpuClient destroyed.
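
The thread does not show what #6263 changed, but the TPU run moves from 6 failures and 6 skips to 10 skips and no failures, which is consistent with tests being skipped when the hardware does not match the layout their expected table assumes. A minimal sketch of that pattern with unittest.skipIf follows; device_count() and the FAKE_DEVICE_COUNT variable are hypothetical stand-ins for a real runtime query and are not part of torch_xla:

```python
import io
import os
import unittest


def device_count():
    # Hypothetical helper: a real test would ask the runtime for this value.
    # An environment variable is used here so the sketch runs anywhere.
    return int(os.environ.get("FAKE_DEVICE_COUNT", "4"))


class ShardingVisualizationTest(unittest.TestCase):

    @unittest.skipIf(device_count() != 8, "expected table assumes 8 devices")
    def test_eight_device_layout(self):
        # Only meaningful when the mesh really has 8 devices.
        self.assertEqual(device_count(), 8)

    @unittest.skipIf(device_count() != 4, "expected table assumes 4 devices")
    def test_four_device_layout(self):
        # Only meaningful when the mesh really has 4 devices.
        self.assertEqual(device_count(), 4)


def run_demo():
    """Run the demo suite quietly and return the unittest result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(
        ShardingVisualizationTest)
    runner = unittest.TextTestRunner(stream=io.StringIO(), verbosity=0)
    return runner.run(suite)


if __name__ == "__main__":
    result = run_demo()
    print("skipped:", len(result.skipped), "failures:", len(result.failures))
```

With the default of 4 devices, the 8-device test is reported as skipped rather than failed, which mirrors the s markers in the runs above.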
