From f89862992daa8605c984a7e3238284fda7e8c75c Mon Sep 17 00:00:00 2001 From: rpsilva-aws Date: Wed, 4 Dec 2024 21:38:18 +0000 Subject: [PATCH] [LoweringContext] Support an optimized parameter mapping for SPMD --- test/run_tests.sh | 1 + test/spmd/test_spmd_lowering_context.py | 58 +++++++++++++++++++++++++ test/test_operations.py | 8 ++-- torch_xla/csrc/init_python_bindings.cpp | 33 +++++++++++++- 4 files changed, 95 insertions(+), 5 deletions(-) create mode 100644 test/spmd/test_spmd_lowering_context.py diff --git a/test/run_tests.sh b/test/run_tests.sh index 543bc5f8403..365cd9595de 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -247,6 +247,7 @@ function run_xla_op_tests3 { run_test "$CDIR/spmd/test_xla_auto_sharding.py" run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py" run_test "$CDIR/spmd/test_mp_input_sharding.py" + run_test "$CDIR/spmd/test_spmd_lowering_context.py" run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_input_output_aliases.py" run_test "$CDIR/test_torch_distributed_xla_backend.py" diff --git a/test/spmd/test_spmd_lowering_context.py b/test/spmd/test_spmd_lowering_context.py new file mode 100644 index 00000000000..cb5018b1a4f --- /dev/null +++ b/test/spmd/test_spmd_lowering_context.py @@ -0,0 +1,58 @@ +import sys + +import unittest + +import test_xla_sharding_base + +import torch +import torch_xla +import torch_xla.debug.metrics as met +import torch_xla.distributed.spmd as xs +import torch_xla.core.xla_model as xm +import contextlib + + +class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_device_parameter_id_tensor_mapping(self): + met.clear_all() + + model_axis = min(8, self.n_devices) + data_axis = self.n_devices // model_axis + mesh_shape = (data_axis, model_axis) + spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y')) + + device = xm.xla_device() + a = torch.randn([32, 2048]).to(device) + xs.mark_sharding(a, spmd_mesh, ('x', 'y')) + b = torch.ones(2048).to(device) + xs.mark_sharding(b, spmd_mesh, ('x',)) + + def fn(a, b): + return a + b + + result = fn(a, b) + ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName") + ctx.build([result]) + torch_xla.sync() + + mapping = ctx.device_parameter_id_tensor_mapping() + num_params = len(mapping) + self.assertEqual(num_params, 2) + self.assertNotEqual(ctx.tensor_parameter_id(a), -1) + self.assertNotEqual(ctx.tensor_parameter_id(b), -1) + self.assertEqual(met.counter_value("VirtualDeviceUsage"), num_params) + + # Ensure that the parameter mapping does not require transferring data + # from the device to the host when sharded. + self.assertFalse(met.metric_data("TransferFromDeviceTime")) + self.assertFalse(met.counter_value("ReplicateShardedData")) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_operations.py b/test/test_operations.py index cc3a73c4580..2e0260fd3ae 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2642,14 +2642,14 @@ def test_api(self): ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName") ctx.build([result]) - hlo = ctx.hlo() hlo_text = ctx.hlo_text() self.assertIn('MyCustomName', hlo_text) - self.assertIn('opcode: "parameter"', hlo_text) - self.assertIn('opcode: "parameter"', hlo_text) + self.assertTrue(hlo_text.count('opcode: "parameter"'), 2) self.assertIn('opcode: "add"', hlo_text) mapping = ctx.parameter_id_tensor_mapping() - self.assertEqual(len(mapping), 2) + num_params = len(mapping) + self.assertEqual(num_params, 2) + self.assertTrue(met.metric_data("TransferFromDeviceTime")) def test_get_parameters_scalar(self): """Scalar tensors parameters may be shared in the HLO graph if their diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 84a9d066cbf..24767bc7fe8 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1096,10 +1096,39 @@ class PyLoweringContext { return results; } + // Returns a mapping from HLO parameter IDs to their corresponding + // device-backed Tensors. This version only returns parameters that were + // explicitly allocated on device data, accessible via GetTensorParameterId(). + // Unlike GetParameterIdTensorMapping(), it avoids transferring data from + // device to host, making it more efficient especially for SPMD scenarios + // where data may be sharded. + std::unordered_map GetDeviceParameterIdTensorMapping() { + // Find parameters in the lowering + const std::vector& device_data = + lowering_ctx.GetParametersData(); + + // Create a mapping from parameter id to the tensor data + std::unordered_map param_to_tensor; + param_to_tensor.reserve(device_data.size()); + + for (const auto& data : device_data) { + std::optional param_id = lowering_ctx.GetParameterId(data); + XLA_CHECK(param_id.has_value()) + << "Parameter ID must exist for device data"; + + at::Tensor tensor = + bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create(data)); + param_to_tensor.emplace(param_id.value(), std::move(tensor)); + } + return param_to_tensor; + } + // Get the parameter identifier of a given tensor. If the tensor is not a // parameter this will always return -1. This is useful in conjunction with // GetParameterIdTensorMapping to identify which values can be baked into - // the graph and which values must remain parameters. + // the graph and which values must remain parameters. Note that in + // conjunction with GetDeviceParameterIdTensorMapping, all tensors are + // parameters with a valid parameter id. int64_t GetTensorParameterId(at::Tensor tensor) { // Convert tensor into the backing lazy node XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); @@ -1201,6 +1230,8 @@ void BuildLoweringContextSubmodule(py::module* m) { .def("hlo_json", &PyLoweringContext::GetHloJsonText) .def("parameter_id_tensor_mapping", &PyLoweringContext::GetParameterIdTensorMapping) + .def("device_parameter_id_tensor_mapping", + &PyLoweringContext::GetDeviceParameterIdTensorMapping) .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId) .def("set_name_string", &PyLoweringContext::SetNameString) .def("get_name_string", &PyLoweringContext::GetNameString);