-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Verifies that the GSPMD sharding annotation propagation pass can propagate through a While op and through the Body computation just fine.
- Loading branch information
Showing
3 changed files
with
89 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import sys | ||
import os | ||
import re | ||
import unittest | ||
from pathlib import Path | ||
|
||
import torch | ||
import torch_xla | ||
from torch_xla.experimental.scan import scan | ||
from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, Mesh | ||
import torch_xla.runtime as xr | ||
|
||
|
||
class ScanSpmdTest(unittest.TestCase): | ||
|
||
def setUp(self): | ||
# Set up a simple SPMD mesh for these tests. | ||
num_devices = xr.global_runtime_device_count() | ||
mesh_shape = (num_devices,) | ||
self.spmd_mesh = Mesh(list(range(num_devices)), mesh_shape, ('model',)) | ||
set_global_mesh(self.spmd_mesh) | ||
xr.use_spmd() | ||
self.device = torch_xla.device() | ||
|
||
@unittest.skipUnless(xr.global_runtime_device_count() >= 4, | ||
"Multiple devices required") | ||
def test_scan_cumsum(self): | ||
"""This test uses `scan` to implement `torch.cumsum`.""" | ||
|
||
save_file = os.getenv('XLA_SAVE_TENSORS_FILE') | ||
save_format = os.getenv('XLA_SAVE_TENSORS_FMT') | ||
if not save_file: | ||
assert False, "This test should be run with XLA_SAVE_TENSORS_FILE" | ||
save_file += '.0' | ||
assert save_format == 'hlo', "This test should be run with XLA_SAVE_TENSORS_FMT=hlo" | ||
|
||
# Remove the save file (if exists) to start from a clean slate | ||
try: | ||
os.remove(save_file) | ||
except: | ||
pass | ||
|
||
def fn(carry, x): | ||
new_carry = carry + x | ||
y = new_carry | ||
return new_carry, y | ||
|
||
init = torch.zeros(1024, requires_grad=True, device=self.device) | ||
mark_sharding(init, self.spmd_mesh, ('model',)) | ||
xs = torch.randn([8, 1024], requires_grad=True, device=self.device) | ||
mark_sharding(xs, self.spmd_mesh, (None, 'model')) | ||
final_carry, ys = scan(fn, init, xs) | ||
torch_xla.sync() | ||
|
||
# Check the HLO | ||
hlo_content = Path(save_file).read_text() | ||
lines = hlo_content.splitlines() | ||
|
||
# There should be only one graph. | ||
assert len(re.findall('END_GRAPH', hlo_content)) == 1 | ||
|
||
# The graph should have output sharding. | ||
begin_magic = '#OUTPUT_SHARDING_BEGIN' | ||
end_magic = '#OUTPUT_SHARDING_END' | ||
self.assertIn(end_magic, str(lines[-2])) | ||
|
||
# Extract the output sharding descriptions. | ||
start = hlo_content.find(begin_magic) | ||
assert start != -1 | ||
start += len(begin_magic) | ||
end = hlo_content.find(end_magic, start) | ||
assert end != -1 | ||
end -= len(end_magic) | ||
output_sharding = hlo_content[start:end].strip().splitlines() | ||
|
||
# There should be 4 tensors in output sharding: init, xs, final_carry, ys. | ||
self.assertEqual(len(output_sharding), 4) | ||
for sharding in output_sharding: | ||
self.assertIn('devices=[', sharding) | ||
|
||
# Remove the save file again to avoid cluttering other tests. | ||
os.remove(save_file) | ||
|
||
|
||
if __name__ == '__main__': | ||
test = unittest.main() | ||
sys.exit(0 if test.result.wasSuccessful() else 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters