Update test_spmd_debugging.py to avoid code test code self #6263

Merged
merged 25 commits into from
Jan 19, 2024

Collaborator

@ManfeiBai ManfeiBai commented Jan 5, 2024

Fixes #6252

  • test/spmd/test_spmd_debugging.py was added to test the SPMD debugging tool. To avoid local test failures and circular reasoning, the tests now start from a given sharding string.
  • Add GPU tests by changing the @unittest.skipIf( condition.
  • Enable the local color check in the tests before drawing the table.

Collaborator

@jonb377 jonb377 left a comment

Thanks Manfei!

Comment on lines 122 to 123
if num_devices != 8:
  self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
Collaborator

Could we make these checks into test decorators? e.g. unittest.skipIf(xr.global_runtime_device_count() != 8)

Also for my understanding, what's the circular reasoning here?
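
For illustration, here is a minimal sketch of the decorator form suggested above. It is not the code from this PR: `device_count()` is a hypothetical stand-in for `xr.global_runtime_device_count()` so that the sketch runs without torch_xla, and the class and test names are made up. Note that `unittest.skipIf` takes a reason string as its second argument.

```python
import unittest


def device_count() -> int:
    # Hypothetical stand-in for xr.global_runtime_device_count().
    # It returns a fixed value so the sketch runs without torch_xla.
    return 1


class SkipIfSketchTest(unittest.TestCase):

    # The condition is evaluated once, when the class body is executed.
    @unittest.skipIf(device_count() != 8, "requires exactly 8 devices")
    def test_eight_device_table(self):
        self.assertEqual(device_count(), 8)

    # No decorator: this test runs for any device count.
    def test_any_device_count(self):
        self.assertGreaterEqual(device_count(), 1)
```

Because the decorator condition is evaluated at class-definition time, a real version would query the runtime when the test module is imported, which may or may not be acceptable for a given setup.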

Contributor

+1, unittest.skipIf is the way to go. What's the reason behind this check? != 8 seems too restrictive; it would be better if we could generalize to n devices.

Collaborator Author

Thanks for both suggestions. I have updated the test to generate the expected table in a more general way, without the restrictive device-count limit.

As for circular reasoning: I meant the situation where the test produces the example under test with our function and also produces the expected example with the same function.

For test/spmd/test_spmd_debugging.py, if we want to test on whatever device kind is available, we need to generate the table with our function and generate the expected table with our function too. To avoid this (circular reasoning), we wanted to limit the test to 8-device environments.

Let me change the wording to "code test code self" for a better description.
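
To make the "code test code self" problem concrete, here is a small sketch that contrasts a circular test with one that starts from a given sharding string and a hand-written expectation. Everything in it is hypothetical: `parse_tile_shape` is a toy helper invented for this illustration, not part of torch_xla, and the sharding strings are example inputs only.

```python
import re
import unittest
from typing import List


def parse_tile_shape(sharding: str) -> List[int]:
    # Hypothetical toy helper: pull the tile dimensions out of a string
    # such as "{devices=[2,2]0,1,2,3}". Returns [] when there is no tiling.
    match = re.search(r"devices=\[([\d,]+)\]", sharding)
    if match is None:
        return []
    return [int(dim) for dim in match.group(1).split(",")]


class ShardingStringSketchTest(unittest.TestCase):

    def test_circular(self):
        # Anti-pattern: the expected value comes from the function under
        # test, so this passes even if the function is wrong.
        sharding = "{devices=[2,2]0,1,2,3}"
        self.assertEqual(parse_tile_shape(sharding), parse_tile_shape(sharding))

    def test_fixed_expectation(self):
        # Preferred: a fixed input string and a hand-written expected value,
        # independent of the devices available on the test machine.
        self.assertEqual(parse_tile_shape("{devices=[2,2]0,1,2,3}"), [2, 2])
        self.assertEqual(parse_tile_shape("{replicated}"), [])
```

The second test does not depend on how many devices the machine has, which is the property the updated tests aim for.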

@@ -255,6 +267,8 @@ def test_single_host_partial_replication_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = self.n_devices
if num_devices != 1:
Collaborator

Naive question, but is there a case where PJRT_DEVICE=CPU has n_devices > 1?

Collaborator Author

@ManfeiBai ManfeiBai Jan 12, 2024

Thanks, Jon, very good question. I checked on a TPU Pod-16:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} --zone  ${ZONE} --worker all --command='PJRT_DEVICE=CPU python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count());"'
Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
1
1
1
1

It looks like PJRT_DEVICE=CPU only has n_devices = 1. For comparison, here is the device count with PJRT_DEVICE=TPU:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} --zone  ${ZONE} --worker all --command='PJRT_DEVICE=TPU python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count());"'
Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
16
16
16

Collaborator Author

ManfeiBai commented Jan 9, 2024

We might want to move the code that gets the local device count into a doc example, and generalize the test code to start from a given sharding string.

@ManfeiBai ManfeiBai changed the title Update test_spmd_debugging.py to avoid circular reasoning Update test_spmd_debugging.py to avoid code test code self Jan 12, 2024
@ManfeiBai
Collaborator Author

> We might want to move the code that gets the local device count into a doc example, and generalize the test code to start from a given sharding string.

done

@ManfeiBai ManfeiBai requested a review from jonb377 January 19, 2024 00:58
Contributor

@yeounoh yeounoh left a comment

LGTM

@ManfeiBai ManfeiBai merged commit df097d7 into master Jan 19, 2024
18 checks passed
ManfeiBai added a commit that referenced this pull request Jan 20, 2024
…elf and Promote int to float for tanh operation (#6263)(#6166) (#6329)
Development

Successfully merging this pull request may close these issues.

test/spmd/test_spmd_debugging.py fails when run on a v4-8 TPU
3 participants