[scan] Add a test under SPMD (#8419)
tengyifei authored Dec 4, 2024
1 parent ae17298 commit 2ba14fc
Showing 3 changed files with 53 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -209,6 +209,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/scan/test_scan.py"
run_test "$CDIR/scan/test_scan_spmd.py"
run_test "$CDIR/scan/test_scan_layers.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
51 changes: 51 additions & 0 deletions test/scan/test_scan_spmd.py
@@ -0,0 +1,51 @@
import sys
import unittest

import torch
import torch_xla
from torch_xla.experimental.scan import scan
from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh
import torch_xla.runtime as xr


class ScanSpmdTest(unittest.TestCase):

def setUp(self):
# Activate SPMD
xr.use_spmd()

# Set up a simple SPMD mesh for these tests.
self.spmd_mesh = get_1d_mesh(axis_name="model")
set_global_mesh(self.spmd_mesh)
self.device = torch_xla.device()

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
"Multiple devices required")
def test_scan_cumsum(self):
"""This test uses `scan` to implement `torch.cumsum`."""

def fn(carry, x):
new_carry = carry + x
y = new_carry
return new_carry, y

init = torch.zeros(1024, requires_grad=True, device=self.device)
mark_sharding(init, self.spmd_mesh, ('model',))
xs = torch.randn([8, 1024], requires_grad=True, device=self.device)
mark_sharding(xs, self.spmd_mesh, (None, 'model'))
final_carry, ys = scan(fn, init, xs)
torch_xla.sync()

# Check the input and output sharding. Note that we do this after
# `torch_xla.sync()` to ensure the output tensors are materialized and
# have taken on sharding annotations propagated by the compiler.
for tensor in [init, xs, final_carry, ys]:
self.assertIn('ShardingSpec: {devices=[',
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
self.assertIn('OpSharding: {devices=[',
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))


if __name__ == '__main__':
  # Run with exit=False so the result can be inspected and the process exit code set below.
  test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -25,6 +25,7 @@ python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_while_loop.py
python3 test/scan/test_scan.py
python3 test/scan/test_scan_spmd.py
python3 test/scan/test_scan_layers.py
python3 test/test_pallas.py -v
python3 test/test_pallas_spmd.py
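Note: the new test skips itself (via unittest.skipUnless) when fewer than four devices are visible, so it is only meaningful on a multi-device host. As a minimal sketch of running it standalone, assuming torch_xla is installed and PJRT_DEVICE selects a backend exposing at least four devices (e.g. a TPU host):

# Skips unless xr.global_runtime_device_count() >= 4
PJRT_DEVICE=TPU python3 test/scan/test_scan_spmd.py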
