-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[LoweringContext] Support an optimized parameter mapping for SPMD (#8460
- Loading branch information
1 parent
cff26e5
commit 5d11f66
Showing
5 changed files
with
103 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import sys | ||
|
||
import unittest | ||
|
||
import test_xla_sharding_base | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.debug.metrics as met | ||
import torch_xla.distributed.spmd as xs | ||
import torch_xla.core.xla_model as xm | ||
import contextlib | ||
|
||
|
||
class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest): | ||
|
||
@classmethod | ||
def setUpClass(cls): | ||
super().setUpClass() | ||
|
||
def test_device_parameter_id_tensor_mapping(self): | ||
met.clear_all() | ||
|
||
model_axis = min(8, self.n_devices) | ||
data_axis = self.n_devices // model_axis | ||
mesh_shape = (data_axis, model_axis) | ||
spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y')) | ||
|
||
device = xm.xla_device() | ||
a = torch.randn([32, 2048]).to(device) | ||
xs.mark_sharding(a, spmd_mesh, ('x', 'y')) | ||
b = torch.ones(2048).to(device) | ||
xs.mark_sharding(b, spmd_mesh, ('x',)) | ||
|
||
def fn(a, b): | ||
return a + b | ||
|
||
result = fn(a, b) | ||
ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName") | ||
ctx.build([result]) | ||
torch_xla.sync() | ||
|
||
mapping = ctx.device_parameter_id_tensor_mapping() | ||
num_params = len(mapping) | ||
self.assertEqual(num_params, 2) | ||
self.assertNotEqual(ctx.tensor_parameter_id(a), -1) | ||
self.assertNotEqual(ctx.tensor_parameter_id(b), -1) | ||
self.assertEqual(met.counter_value("VirtualDeviceUsage"), num_params) | ||
|
||
# Ensure that the parameter mapping does not require transferring data | ||
# from the device to the host when sharded. | ||
self.assertFalse(met.metric_data("TransferFromDeviceTime")) | ||
self.assertFalse(met.counter_value("ReplicateShardedData")) | ||
|
||
|
||
if __name__ == '__main__': | ||
test = unittest.main() | ||
sys.exit(0 if test.result.wasSuccessful() else 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters