diff --git a/test/test_operations.py b/test/test_operations.py index cc3a73c4580..846dbd97f03 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -20,10 +20,8 @@ import copy import itertools import math -from numbers import Number from functools import reduce -import numpy -import random +import numpy as np import re import torch import torch.autograd as ad @@ -2627,6 +2625,75 @@ def test_multi_init_xla_backend(self): self.assertEqual(met.counter_value("RegisterXLAFunctions"), 1) +@unittest.skipIf( + os.environ.get('XLA_USE_EAGER_DEBUG_MODE'), + "Skipping test under XLA_USE_EAGER_DEBUG_MODE because `result` will not \ + reference a graph due to eager evaluation.") +class TestLoweringContextSPMD(test_utils.XlaTestCase): + def _get_computation_hlo_txt(self, ctx): + hlo = ctx.hlo() + comp = xb.computation_from_module_proto("my_custom_comp", hlo) + return xb.get_computation_hlo(comp) + + def setUp(self): + xr.use_spmd() + super().setUp() + num_devices = xr.global_runtime_device_count() + device_ids = np.arange(num_devices) + # Annotate a simple sharding for the test. + self.model_axis = min(8, num_devices) + self.data_axis = num_devices // self.model_axis + mesh_shape = (self.data_axis, self.model_axis) + self.spmd_mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y')) + xs.set_global_mesh(self.spmd_mesh) + + def test_basic(self): + device = xm.xla_device() + a = torch.ones(2048, requires_grad=True).to(device) + xs.mark_sharding(a, self.spmd_mesh, ('x',)) + b = torch.randn([32, 2048], requires_grad=True).to(device) + xs.mark_sharding(b, self.spmd_mesh, (None, 'y')) + + def fn(x, y): + x = x + 1 + return x, y * 2 + + result = fn(a, b) + + ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName") + ctx.build(list(result)) + torch_xla.sync() + + # Sanity HLO check. + hlo_text = ctx.hlo_text() + self.assertIn('MyCustomName', hlo_text) + self.assertIn('opcode: "parameter"', hlo_text) + self.assertIn('opcode: "add"', hlo_text) + self.assertIn('sharding', hlo_text) + + # Ensure that the corresponding input parameters contain the expected sharding. + hlo_comp_txt = self._get_computation_hlo_txt(ctx) + a_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(a) + self.assertRegex( + hlo_comp_txt, + rf'%p\d+\.\d+.*f32[2048]{{0}}.*sharding={re.escape(a_sharding_spec)}' + ) + b_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(b) + self.assertRegex( + hlo_comp_txt, + rf'%p\d+\.\d+.*f32[32,2048]{{0}}.*sharding={re.escape(b_sharding_spec)}' + ) + + # Ensure that the results retain the same sharding specs. + result_a, result_b = result + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(result_a), a_sharding_spec + ) + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(result_b), b_sharding_spec + ) + + @unittest.skipIf( os.environ.get('XLA_USE_EAGER_DEBUG_MODE'), "Skipping test under XLA_USE_EAGER_DEBUG_MODE because `result` will not \ @@ -2642,11 +2709,9 @@ def test_api(self): ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName") ctx.build([result]) - hlo = ctx.hlo() hlo_text = ctx.hlo_text() self.assertIn('MyCustomName', hlo_text) - self.assertIn('opcode: "parameter"', hlo_text) - self.assertIn('opcode: "parameter"', hlo_text) + self.assertTrue(hlo_text.count('opcode: "parameter"'), 2) self.assertIn('opcode: "add"', hlo_text) mapping = ctx.parameter_id_tensor_mapping() self.assertEqual(len(mapping), 2) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 84a9d066cbf..9a8c9312a85 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1000,6 +1000,8 @@ class PyLoweringContext { ir_values.push_back(value); } + ShardingUtil::SetHloSharding(&lowering_ctx); + // Lower the graph using the output IR values for (auto& ir_value : ir_values) { xla::XlaOp root = lowering_ctx.GetOutputOp( @@ -1048,12 +1050,18 @@ class PyLoweringContext { } } + ShardingUtil::SetHloSharding(&lowering_ctx); + computation = ConsumeValue(lowering_ctx.BuildXla()); // wrap inputs of cond/body_computation if ((GetNameString() == "condctx") || (GetNameString() == "bodyctx")) { std::vector> input_output_alias_pair; - std::vector buffer_donor_indices; + std::vector param_shardings; + // If sharded, then extract all input Op shardings. + if (UseVirtualDevice()) { + param_shardings = XlaHelpers::ExtractInputShardings(computation); + } xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); // TODO(@manfei): please confirm whether we check for more than two or use @@ -1062,7 +1070,7 @@ class PyLoweringContext { if (should_wrap_parameter) { // For now we assume that we for i loop input is not sharded. computation = ConsumeValue(XlaHelpers::WrapXlaComputation( - computation, program_shape.parameters(), {}, buffer_donor_indices)); + computation, program_shape.parameters(), param_shardings, {})); } } } diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index c2db9b36309..304f480360b 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -111,23 +111,31 @@ LoweringContext::LoweringContext( static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); xla::XlaOp LoweringContext::GetParameter( - const std::shared_ptr& data, + const std::shared_ptr& backend_data, const std::unordered_set& unbounded_dynamic_dims) { - torch::lazy::BackendData::Handle handle = data->GetHandle(); + torch::lazy::BackendData::Handle handle = backend_data->GetHandle(); auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { - xla::Shape shape = - std::dynamic_pointer_cast(data) - ->shape(); + auto data = std::dynamic_pointer_cast(backend_data); + XLA_CHECK(data != nullptr); + xla::Shape shape = data->shape(); for (const int dim : unbounded_dynamic_dims) { shape.set_dynamic_dimension(dim, true); shape.set_dimensions(dim, kUnboundedSize); } - xla::XlaOp param = xla::Parameter(builder(), parameters_.size(), shape, - absl::StrCat("p", parameters_.size())); - it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()}) + size_t param_index = parameters_.size(); + std::string param_name = absl::StrCat("p", param_index); + xla::XlaOp param; + if (data->HasSharding()) { + xla::OpSharding sharding = data->GetSharding(); + xla::XlaScopedShardingAssignment scoped_sharding(builder(), sharding); + param = xla::Parameter(builder(), param_index, shape, param_name); + } else { + param = xla::Parameter(builder(), param_index, shape, param_name); + } + it = parameters_map_.emplace(handle, Parameter{param, param_index}) .first; - parameters_.push_back(data); + parameters_.push_back(backend_data); } else { XLA_CHECK(unbounded_dynamic_dims.empty()) << "The unbounded dynamic dims can only be set when Parameter is " @@ -138,8 +146,8 @@ xla::XlaOp LoweringContext::GetParameter( } std::optional LoweringContext::GetParameterId( - const std::shared_ptr& data) const { - torch::lazy::BackendData::Handle handle = data->GetHandle(); + const std::shared_ptr& backend_data) const { + torch::lazy::BackendData::Handle handle = backend_data->GetHandle(); auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { return std::nullopt; diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index 3a36695e1c0..cb4f0bc2d2f 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -50,13 +50,13 @@ class LoweringContext : public torch::lazy::LoweringContext { // returned. Otherwise a new one will be created, associated with the tensor // held in data. xla::XlaOp GetParameter( - const std::shared_ptr& data, + const std::shared_ptr& backend_data, const std::unordered_set& dynamic_dims = {}); // If a parameter associated with data has already been declared, returns its // ID. Otherwise, returns `std::nullopt`. std::optional GetParameterId( - const std::shared_ptr& data) const; + const std::shared_ptr& backend_data) const; // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created.