From e88c7d39b9a76b417012d4e4cc9f723ea7091fa2 Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Thu, 26 Oct 2023 22:47:29 +0000
Subject: [PATCH] Add new API for custom mark sharding op and update tests

---
 test/spmd/test_dynamo_spmd.py           | 18 ++++++++++++++++-
 torch_xla/csrc/init_python_bindings.cpp |  2 +-
 torch_xla/csrc/xla_lower_util.cpp       |  2 +-
 torch_xla/experimental/xla_sharding.py  | 27 +++++++++++++++++++++++++
 4 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py
index 7dc839979eee..7b835a267549 100644
--- a/test/spmd/test_dynamo_spmd.py
+++ b/test/spmd/test_dynamo_spmd.py
@@ -53,6 +53,21 @@ def test_dynamo_spmd_basic(self):
     # TODO(JackCaoG): add counter checks after ExecuteReplicated also creates
     # a ExecuteMetric.
 
+  def test_dynamo_spmd_basic_with_custom_mark_sharding_op(self):
+    device = xm.xla_device()
+    linear = SimpleLinear().to(device)
+    linear.eval()
+    xla_x = torch.randn(1, 128, device=device)
+    xs.mark_sharding_dynamo_custom_op(linear.fc2.weight,
+                                      self._get_mesh((1, self.n_devices)),
+                                      (1, 0))
+    xla_res = linear(xla_x)
+    xm.mark_step()
+
+    dynamo_linear = torch.compile(linear, backend="openxla")
+    dynamo_res = dynamo_linear(xla_x)
+    self.assertTrue(torch.allclose(xla_res.cpu(), dynamo_res.cpu()))
+
   def test_dynamo_spmd_output_sharding_spec(self):
     device = xm.xla_device()
     linear = SimpleLinear().to(device)
@@ -177,7 +192,8 @@ def fn_simple(x):
       y = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                        dtype=torch.float,
                        device=xm.xla_device())
-      ys = xs.mark_sharding(y, self._get_mesh((1, self.n_devices)), (0, 1))
+      ys = xs.mark_sharding_dynamo_custom_op(
+          y, self._get_mesh((1, self.n_devices)), (0, 1))
       return x + ys
 
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 6d3a8335566c..507aabd83ab2 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1626,7 +1626,7 @@ void InitXlaModuleBindings(py::module m) {
           // Register sharded tensor data.
          XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
         });
-  m.def("_xla_mark_sharding_custom_op",
+  m.def("_xla_mark_sharding_dynamo_custom_op",
         [](const at::Tensor& input, xla::OpSharding sharding) {
          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
          tensor_methods::custom_mark_sharding(xtensor, sharding);
diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index c573adeb31e6..567bd12ad5de 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -1230,7 +1230,7 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) {
 
 xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device,
                                    const xla::XlaOp& input,
-                                   const xla::XlaOp sharding) {
+                                   const xla::XlaOp& sharding) {
   return xla::CustomCall(input.builder(), /*call_target_name=*/"MarkSharding",
                          {input}, ShapeHelper::ShapeOfXlaOp(input));
 }
diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py
index 21d0e2e570ac..31c44e3c9d31 100644
--- a/torch_xla/experimental/xla_sharding.py
+++ b/torch_xla/experimental/xla_sharding.py
@@ -506,6 +506,33 @@ def mark_sharding(
   return XLAShardedTensor(t)
 
 
+@xr.requires_pjrt
+def mark_sharding_dynamo_custom_op(
+    t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
+    partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor:
+  """
+  Same functionality as `mark_sharding` above, except this variant uses the custom mark_sharding op in torch_xla._XLAC to allow dynamo to recognize and trace it.
+  """
+  num_devices = xr.global_runtime_device_count()
+  assert num_devices > 0, "This requires XLA supported device(s)."
+  assert mesh.size() == num_devices, \
+    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
+  # We only allow fully specified `partition_spec` to be applicable, as opposed
+  # to filling in the unspecified replicated dims. Fully specified `partition_spec`
+  # should be of the same rank as `t`. This is to support partial replication
+  # where the group assignment may vary with different input ranks.
+  assert len(t.shape) == len(partition_spec), \
+    f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
+
+  op_sharding = mesh.get_op_sharding(partition_spec)
+
+  if isinstance(t, XLAShardedTensor):
+    torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t.global_tensor, op_sharding)
+    return t
+  torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, op_sharding)
+  return XLAShardedTensor(t)
+
+
 def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
   """Clear sharding annotation from the input tensor and return a `cpu` casted tensor."""
   torch_xla._XLAC._xla_clear_sharding(t)
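
Note (not part of the patch): a minimal usage sketch of the new mark_sharding_dynamo_custom_op API, assuming an SPMD-enabled PJRT runtime. The tensor shape, the (1, num_devices) mesh, and the variable names below are illustrative only.

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.experimental.xla_sharding as xs
    from torch_xla.experimental.xla_sharding import Mesh

    xr.use_spmd()  # assumed: SPMD execution mode is enabled before sharding

    # Build a (1, num_devices) mesh over all available devices.
    num_devices = xr.global_runtime_device_count()
    mesh = Mesh(np.arange(num_devices), (1, num_devices))

    # Annotate a 2D tensor; the partition spec must match the tensor rank.
    # This variant lowers through the custom op added in this patch, so a
    # subsequent torch.compile(..., backend="openxla") call can trace it.
    t = torch.randn(8, 128, device=xm.xla_device())
    xs.mark_sharding_dynamo_custom_op(t, mesh, (0, 1))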