[scan] Add a test under SPMD
Verifies that the GSPMD sharding annotation propagation pass can
propagate through a While op and through the Body computation just fine.
tengyifei committed Nov 26, 2024
1 parent 39e67b5 commit c9a5907
Showing 3 changed files with 88 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -209,6 +209,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/scan/test_scan.py"
run_save_tensor_hlo "$CDIR/scan/test_scan_spmd.py"
run_test "$CDIR/scan/test_scan_layers.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
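Note: run_save_tensor_hlo is an existing helper in run_tests.sh, not part of this diff. Presumably it wraps run_test with the HLO-dump environment variables set, much like the explicit invocation added to the TPU script below. A rough sketch under that assumption (the dump path here is illustrative, not taken from the repo):

  function run_save_tensor_hlo {
    # Assumed behavior: dump traced graphs as HLO text so the test can inspect them.
    XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@"
  }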
86 changes: 86 additions & 0 deletions test/scan/test_scan_spmd.py
@@ -0,0 +1,86 @@
import sys
import os
import re
import unittest
from pathlib import Path

import torch
import torch_xla
from torch_xla.experimental.scan import scan
from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, Mesh
import torch_xla.runtime as xr


class ScanSpmdTest(unittest.TestCase):

  def setUp(self):
    # Set up a simple SPMD mesh for these tests.
    num_devices = xr.global_runtime_device_count()
    mesh_shape = (num_devices,)
    self.spmd_mesh = Mesh(list(range(num_devices)), mesh_shape, ('model',))
    set_global_mesh(self.spmd_mesh)
    xr.use_spmd()
    self.device = torch_xla.device()

  @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                       "Multiple devices required")
  def test_scan_cumsum(self):
    """This test uses `scan` to implement `torch.cumsum`."""

    save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
    save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
    if not save_file:
      assert False, "This test should be run with XLA_SAVE_TENSORS_FILE"
    save_file += '.0'
    assert save_format == 'hlo', "This test should be run with XLA_SAVE_TENSORS_FMT=hlo"

    # Remove the save file (if it exists) to start from a clean slate.
    try:
      os.remove(save_file)
    except:
      pass

    def fn(carry, x):
      new_carry = carry + x
      y = new_carry
      return new_carry, y

    init = torch.zeros(1024, requires_grad=True, device=self.device)
    mark_sharding(init, self.spmd_mesh, ('model',))
    xs = torch.randn([8, 1024], requires_grad=True, device=self.device)
    mark_sharding(xs, self.spmd_mesh, (None, 'model'))
    final_carry, ys = scan(fn, init, xs)
    torch_xla.sync()

    # Check the HLO
    hlo_content = Path(save_file).read_text()
    lines = hlo_content.splitlines()

    # There should be only one graph.
    assert len(re.findall('END_GRAPH', hlo_content)) == 1

    # The graph should have output sharding.
    begin_magic = '#OUTPUT_SHARDING_BEGIN'
    end_magic = '#OUTPUT_SHARDING_END'
    self.assertIn(end_magic, str(lines[-2]))

    # Extract the output sharding descriptions.
    start = hlo_content.find(begin_magic)
    assert start != -1
    start += len(begin_magic)
    end = hlo_content.find(end_magic, start)
    assert end != -1
    output_sharding = hlo_content[start:end].strip().splitlines()

    # There should be 4 tensors in output sharding: init, xs, final_carry, ys.
    self.assertEqual(len(output_sharding), 4)
    for sharding in output_sharding:
      self.assertIn('devices=[', sharding)

    # Remove the save file again to avoid cluttering other tests.
    os.remove(save_file)


if __name__ == '__main__':
  # exit=False so the result can be inspected on the next line.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
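
Since the test asserts on XLA_SAVE_TENSORS_FILE and XLA_SAVE_TENSORS_FMT, a standalone run needs both set, mirroring the TPU script change below (the dump path is arbitrary):

  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" \
    python3 test/scan/test_scan_spmd.py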
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -26,6 +26,7 @@ python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_while_loop.py
python3 test/scan/test_scan.py
XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" python3 test/scan/test_scan_spmd.py
python3 test/scan/test_scan_layers.py
python3 test/test_pallas.py -v
python3 test/test_pallas_spmd.py
