Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[scan] Add a test under SPMD #8419

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/scan/test_scan.py"
run_test "$CDIR/scan/test_scan_spmd.py"
run_test "$CDIR/scan/test_scan_layers.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
Expand Down
51 changes: 51 additions & 0 deletions test/scan/test_scan_spmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import sys
import unittest

import torch
import torch_xla
from torch_xla.experimental.scan import scan
from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh
import torch_xla.runtime as xr


class ScanSpmdTest(unittest.TestCase):

def setUp(self):
# Activate SPMD
xr.use_spmd()

# Set up a simple SPMD mesh for these tests.
self.spmd_mesh = get_1d_mesh(axis_name="model")
set_global_mesh(self.spmd_mesh)
self.device = torch_xla.device()

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
"Multiple devices required")
def test_scan_cumsum(self):
"""This test uses `scan` to implement `torch.cumsum`."""

def fn(carry, x):
new_carry = carry + x
y = new_carry
return new_carry, y

init = torch.zeros(1024, requires_grad=True, device=self.device)
mark_sharding(init, self.spmd_mesh, ('model',))
xs = torch.randn([8, 1024], requires_grad=True, device=self.device)
mark_sharding(xs, self.spmd_mesh, (None, 'model'))
final_carry, ys = scan(fn, init, xs)
torch_xla.sync()

# Check the input and output sharding. Note that we do this after
# `torch_xla.sync()` to ensure the output tensors are materialized and
# have taken on sharding annotations propagated by the compiler.
for tensor in [init, xs, final_carry, ys]:
self.assertIn('ShardingSpec: {devices=[',
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
self.assertIn('OpSharding: {devices=[',
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_while_loop.py
python3 test/scan/test_scan.py
python3 test/scan/test_scan_spmd.py
python3 test/scan/test_scan_layers.py
python3 test/test_pallas.py -v
python3 test/test_pallas_spmd.py
Expand Down
Loading