Update test_spmd_debugging.py to avoid code test code self #6263
Thanks Manfei!
test/spmd/test_spmd_debugging.py
Outdated
if num_devices != 8:
  self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
Could we make these checks into test decorators? e.g. unittest.skipIf(xr.global_runtime_device_count() != 8)
Also for my understanding, what's the circular reasoning here?
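A minimal sketch of the decorator form suggested above. `unittest.skipIf` takes a condition and a reason string, and the condition is evaluated when the class body is defined, not when the test runs. The helper `global_device_count()` is a hypothetical stand-in for `xr.global_runtime_device_count()` so the example runs without torch_xla; the test names are also illustrative.

```python
import unittest


def global_device_count() -> int:
    # Hypothetical stand-in for xr.global_runtime_device_count(), so this
    # sketch runs without torch_xla installed.
    return 1


class DebuggingSketchTest(unittest.TestCase):

    # The condition is evaluated once, when the class body is executed.
    @unittest.skipIf(global_device_count() != 8,
                     "expected table is hard-coded for 8 devices")
    def test_eight_device_table(self):
        self.assertEqual(global_device_count(), 8)

    def test_any_device_count(self):
        self.assertGreaterEqual(global_device_count(), 1)


def run_suite() -> unittest.TestResult:
    # Run the sketch test case and return the collected result.
    loader = unittest.defaultTestLoader
    suite = loader.loadTestsFromTestCase(DebuggingSketchTest)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

With the stand-in returning 1, the 8-device test is reported as skipped rather than failed, and the reason string appears in the test report.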
+1, unittest.skipIf is the way to go. What's the reason behind this check? !=8 seems too restrictive -- so if we can generalize to n devices, it would be better?
Thanks for both suggestions. I have updated the test to generate the expected table in a more general way, without the restrictive device-count limit.
As for circular reasoning: I meant a test situation where we produce the output under test with our function and produce the expected output with that same function.
For test/spmd/test_spmd_debugging.py, testing on an arbitrary device kind would require generating the table with our function and generating the expected table with our function too. To avoid this (circular reasoning), we wanted to limit the test to 8 devices.
Let me change the wording to "code test code self" for a better description.
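To illustrate the "code test code self" concern, here is a small sketch, not the actual test code. The renderers and check functions are hypothetical. A check that builds its expected value with the function under test passes even when that function is broken, while a check against a hand-written expected value for a fixed input does not.

```python
def render_row(device_ids):
    # Hypothetical function under test: formats device ids as one table row.
    return " | ".join(f"dev {i}" for i in device_ids)


def broken_render_row(device_ids):
    # Deliberately wrong implementation, used to show what each check catches.
    return ""


def circular_check(render, device_ids):
    # Circular: the expected value comes from the function being tested,
    # so the comparison can never fail.
    expected = render(device_ids)
    return render(device_ids) == expected


def independent_check(render):
    # Independent: fixed input, expected value written out by hand.
    expected = "dev 0 | dev 1 | dev 2 | dev 3"
    return render([0, 1, 2, 3]) == expected
```

Here `circular_check(broken_render_row, [0, 1])` still returns `True`, whereas `independent_check(broken_render_row)` returns `False`.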
test/spmd/test_spmd_debugging.py
Outdated
@@ -255,6 +267,8 @@ def test_single_host_partial_replication_cpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = self.n_devices
if num_devices != 1:
Naive question, but is there a case where PJRT_DEVICE=CPU has n_devices > 1?
Thanks, Jon, very good question. I checked on a TPU Pod-16:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} --zone ${ZONE} --worker all --command='PJRT_DEVICE=CPU python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count());"'
Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
1
1
1
1
It looks like PJRT_DEVICE=CPU only has n_devices = 1, compared with the device count under PJRT_DEVICE=TPU:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} --zone ${ZONE} --worker all --command='PJRT_DEVICE=TPU python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count());"'
Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
16
16
16
We might want to move the code that gets the local device count into the doc example, and generalize the test code to start from a given sharding string.
Done.
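As a rough sketch of what starting from a given sharding string can look like: the device layout is read from the string itself instead of from the runtime, so the test input no longer depends on how many devices the host has. The parser below is hypothetical and assumes the explicit tiled form `{devices=[2,2]0,1,2,3}`, optionally followed by ` last_tile_dim_replicate`; other sharding string forms are not handled.

```python
import re
from math import prod

# Assumed format: "{devices=[<tile dims>]<device ids>}" with an optional
# " last_tile_dim_replicate" suffix before the closing brace.
_TILED = re.compile(
    r"\{devices=\[([\d,]+)\]([\d,]+)( last_tile_dim_replicate)?\}")


def parse_tiled_sharding(sharding: str):
    """Return (tile_shape, device_ids, partially_replicated)."""
    match = _TILED.fullmatch(sharding)
    if match is None:
        raise ValueError(f"unsupported sharding string: {sharding!r}")
    tile_shape = [int(d) for d in match.group(1).split(",")]
    device_ids = [int(d) for d in match.group(2).split(",")]
    # The tile grid must account for every listed device exactly once.
    if prod(tile_shape) != len(device_ids):
        raise ValueError("tile shape does not match number of devices")
    return tile_shape, device_ids, match.group(3) is not None
```

A test can then pair a fixed string with a hand-written expected table, which works the same on a 1-device CPU host and a 16-device TPU pod.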
LGTM
Fix #6252
test/spmd/test_spmd_debugging.py was added to test the SPMD debugging tool. To avoid local test failures and circular reasoning, the tests are now based on a given sharding string.
@unittest.skipIf(