From 5d11f66057dfac4be0a3407de35ff7cf98bd2e0c Mon Sep 17 00:00:00 2001
From: Rui <179625410+rpsilva-aws@users.noreply.github.com>
Date: Fri, 6 Dec 2024 17:20:00 -0800
Subject: [PATCH] [LoweringContext] Support an optimized parameter mapping for
 SPMD (#8460)

---
 test/run_tests.sh                       |  1 +
 test/spmd/test_spmd_lowering_context.py | 58 +++++++++++++++++++++++++
 test/test_operations.py                 | 14 ++++--
 test/tpu/run_tests.sh                   |  1 +
 torch_xla/csrc/init_python_bindings.cpp | 36 ++++++++++++++--
 5 files changed, 103 insertions(+), 7 deletions(-)
 create mode 100644 test/spmd/test_spmd_lowering_context.py

diff --git a/test/run_tests.sh b/test/run_tests.sh
index a3a8c74cedd..eeb2e8ee34d 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -248,6 +248,7 @@ function run_xla_op_tests3 {
   run_test "$CDIR/spmd/test_xla_auto_sharding.py"
   run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$CDIR/spmd/test_mp_input_sharding.py"
+  run_test "$CDIR/spmd/test_spmd_lowering_context.py"
   run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_input_output_aliases.py"
   run_test "$CDIR/test_torch_distributed_xla_backend.py"
diff --git a/test/spmd/test_spmd_lowering_context.py b/test/spmd/test_spmd_lowering_context.py
new file mode 100644
index 00000000000..cb5018b1a4f
--- /dev/null
+++ b/test/spmd/test_spmd_lowering_context.py
@@ -0,0 +1,58 @@
+import sys
+
+import unittest
+
+import test_xla_sharding_base
+
+import torch
+import torch_xla
+import torch_xla.debug.metrics as met
+import torch_xla.distributed.spmd as xs
+import torch_xla.core.xla_model as xm
+import contextlib
+
+
+class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+
+  def test_device_parameter_id_tensor_mapping(self):
+    met.clear_all()
+
+    model_axis = min(8, self.n_devices)
+    data_axis = self.n_devices // model_axis
+    mesh_shape = (data_axis, model_axis)
+    spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))
+
+    device = xm.xla_device()
+    a = torch.randn([32, 2048]).to(device)
+    xs.mark_sharding(a, spmd_mesh, ('x', 'y'))
+    b = torch.ones(2048).to(device)
+    xs.mark_sharding(b, spmd_mesh, ('x',))
+
+    def fn(a, b):
+      return a + b
+
+    result = fn(a, b)
+    ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
+    ctx.build([result])
+    torch_xla.sync()
+
+    mapping = ctx.device_parameter_id_tensor_mapping()
+    num_params = len(mapping)
+    self.assertEqual(num_params, 2)
+    self.assertNotEqual(ctx.tensor_parameter_id(a), -1)
+    self.assertNotEqual(ctx.tensor_parameter_id(b), -1)
+    self.assertEqual(met.counter_value("VirtualDeviceUsage"), num_params)
+
+    # Ensure that the parameter mapping does not require transferring data
+    # from the device to the host when sharded.
+    self.assertFalse(met.metric_data("TransferFromDeviceTime"))
+    self.assertFalse(met.counter_value("ReplicateShardedData"))
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/test_operations.py b/test/test_operations.py
index 892a02ddb0b..20365cb4f1e 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2634,6 +2634,7 @@ def test_multi_init_xla_backend(self):
 class TestLoweringContext(test_utils.XlaTestCase):
 
   def test_api(self):
+    met.clear_all()
     device = xm.xla_device()
     a = torch.tensor([1.0, 2.0, 3.0], device=device)
     b = torch.tensor([4.0, 5.0, 6.0], device=device)
@@ -2642,14 +2643,19 @@ def test_api(self):
 
     ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
     ctx.build([result])
-    hlo = ctx.hlo()
+    _ = ctx.hlo()
     hlo_text = ctx.hlo_text()
     self.assertIn('MyCustomName', hlo_text)
-    self.assertIn('opcode: "parameter"', hlo_text)
-    self.assertIn('opcode: "parameter"', hlo_text)
+    self.assertEqual(hlo_text.count('opcode: "parameter"'), 2)
     self.assertIn('opcode: "add"', hlo_text)
+    num_expected_params = 2
     mapping = ctx.parameter_id_tensor_mapping()
-    self.assertEqual(len(mapping), 2)
+    self.assertEqual(len(mapping), num_expected_params)
+    self.assertTrue(met.metric_data("TransferFromDeviceTime"))
+    met.clear_all()
+    device_mapping = ctx.device_parameter_id_tensor_mapping()
+    self.assertEqual(len(device_mapping), num_expected_params)
+    self.assertFalse(met.metric_data("TransferFromDeviceTime"))
 
   def test_get_parameters_scalar(self):
     """Scalar tensors parameters may be shared in the HLO graph if their
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index fb5cdd51c8e..0134e6730f8 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -6,6 +6,7 @@ python3 test/test_operations.py -v
 python3 test/pjrt/test_runtime_tpu.py
 python3 test/pjrt/test_collective_ops_tpu.py
 python3 test/spmd/test_mp_input_sharding.py
+python3 test/spmd/test_spmd_lowering_context.py
 python3 test/spmd/test_xla_sharding.py
 python3 test/spmd/test_xla_virtual_device.py
 python3 test/spmd/test_xla_distributed_checkpoint.py
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 84a9d066cbf..b42a084e85c 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1070,7 +1070,7 @@ class PyLoweringContext {
   // Get a mapping from the HLO input parameters to the backing Tensor values.
   // This allows the caller to get all parameter information regardless of
   // how the parameter was allocated (inline tensor, nn.Parameter, constant,
-  // etc.)
+  // etc.). This will copy the tensor data from the device to the host.
   std::unordered_map<int64_t, at::Tensor> GetParameterIdTensorMapping() {
     // Find parameters in the lowering
     const std::vector<torch::lazy::BackendDataPtr>& device_data =
@@ -1096,10 +1096,38 @@
     return results;
   }
 
+  // Returns a mapping from HLO parameter IDs to their corresponding
+  // device-backed Tensors. This version only returns parameters that were
+  // explicitly allocated as device data, accessible via GetTensorParameterId().
+  // Unlike GetParameterIdTensorMapping(), it avoids transferring data from
+  // the device to the host, making it more efficient, especially for SPMD
+  // scenarios where transferring data involves costly collectives.
+  std::unordered_map<int64_t, at::Tensor> GetDeviceParameterIdTensorMapping() {
+    // Find parameters in the lowering
+    const std::vector<torch::lazy::BackendDataPtr>& device_data =
+        lowering_ctx.GetParametersData();
+
+    // Create a mapping from parameter id to the tensor data
+    std::unordered_map<int64_t, at::Tensor> param_to_tensor;
+    param_to_tensor.reserve(device_data.size());
+
+    for (const auto& data : device_data) {
+      std::optional<int64_t> param_id = lowering_ctx.GetParameterId(data);
+      XLA_CHECK(param_id.has_value())
+          << "Parameter ID must exist for device data";
+
+      at::Tensor tensor =
+          bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create(data));
+      param_to_tensor.emplace(param_id.value(), std::move(tensor));
+    }
+    return param_to_tensor;
+  }
+
   // Get the parameter identifier of a given tensor. If the tensor is not a
   // parameter this will always return -1. This is useful in conjunction with
-  // GetParameterIdTensorMapping to identify which values can be baked into
-  // the graph and which values must remain parameters.
+  // GetParameterIdTensorMapping or GetDeviceParameterIdTensorMapping to
+  // identify which values can be baked into the graph and which values must
+  // remain parameters.
   int64_t GetTensorParameterId(at::Tensor tensor) {
     // Convert tensor into the backing lazy node
     XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
@@ -1201,6 +1229,8 @@ void BuildLoweringContextSubmodule(py::module* m) {
      .def("hlo_json", &PyLoweringContext::GetHloJsonText)
      .def("parameter_id_tensor_mapping",
           &PyLoweringContext::GetParameterIdTensorMapping)
+     .def("device_parameter_id_tensor_mapping",
+          &PyLoweringContext::GetDeviceParameterIdTensorMapping)
      .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
      .def("set_name_string", &PyLoweringContext::SetNameString)
      .def("get_name_string", &PyLoweringContext::GetNameString);
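
A minimal usage sketch of the two mappings exposed above, assuming this patch is applied; the tensor shapes and the "Example" context name are illustrative only, and only calls exercised by the tests in this patch are used.

# Sketch: comparing the host-materialized and device-backed parameter mappings.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(4, 4, device=device)
b = torch.randn(4, 4, device=device)
result = a + b

ctx = torch_xla._XLAC.lowering.LoweringContext("Example")
ctx.build([result])

# Host mapping: copies each parameter's data from the device to the host.
host_mapping = ctx.parameter_id_tensor_mapping()

# Device mapping: wraps the existing device buffers, so no device-to-host
# transfer (and, under SPMD, no resharding collective) is issued.
device_mapping = ctx.device_parameter_id_tensor_mapping()

# Both are keyed by HLO parameter id; tensor_parameter_id() recovers the id
# for a given tensor, or -1 if the tensor is not a parameter.
param_id = ctx.tensor_parameter_id(a)
assert param_id in device_mapping

The device-backed mapping is the variant the new SPMD test exercises: it only covers parameters backed by device data, which is why it can skip the transfer that parameter_id_tensor_mapping() performs.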