diff --git a/test/run_tests.sh b/test/run_tests.sh index eeb2e8ee34d..fe857133792 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -248,7 +248,7 @@ function run_xla_op_tests3 { run_test "$CDIR/spmd/test_xla_auto_sharding.py" run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py" run_test "$CDIR/spmd/test_mp_input_sharding.py" - run_test "$CDIR/spmd/test_spmd_lowering_context.py" + run_save_tensor_hlo "$CDIR/spmd/test_spmd_lowering_context.py" run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_input_output_aliases.py" run_test "$CDIR/test_torch_distributed_xla_backend.py" diff --git a/test/spmd/test_spmd_lowering_context.py b/test/spmd/test_spmd_lowering_context.py index cb5018b1a4f..ac509fd0735 100644 --- a/test/spmd/test_spmd_lowering_context.py +++ b/test/spmd/test_spmd_lowering_context.py @@ -1,4 +1,7 @@ +import os +import re import sys +from pathlib import Path import unittest @@ -6,10 +9,10 @@ import torch import torch_xla +import torch_xla.core.xla_builder as xb import torch_xla.debug.metrics as met import torch_xla.distributed.spmd as xs import torch_xla.core.xla_model as xm -import contextlib class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest): @@ -18,6 +21,97 @@ class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest): def setUpClass(cls): super().setUpClass() + def _get_computation_hlo_txt(self, ctx): + hlo = ctx.hlo() + comp = xb.computation_from_module_proto("my_custom_comp", hlo) + return xb.get_computation_hlo(comp) + + def test_basic(self): + save_file = os.getenv('XLA_SAVE_TENSORS_FILE') + save_format = os.getenv('XLA_SAVE_TENSORS_FMT') + assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE" + save_file += '.0' # Identify a single device + assert save_format == 'hlo', "This test should be run with XLA_SAVE_TENSORS_FMT=hlo" + + # Ensure that there is no existing file to begin with. + try: + os.remove(save_file) + except: + pass + + model_axis = min(8, self.n_devices) + data_axis = self.n_devices // model_axis + mesh_shape = (data_axis, model_axis) + spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y')) + + device = xm.xla_device() + a = torch.zeros(2048, device=device, requires_grad=True) + xs.mark_sharding(a, spmd_mesh, ('x',)) + b = torch.randn([32, 2048], device=device, requires_grad=True) + xs.mark_sharding(b, spmd_mesh, (None, 'y')) + + def fn(x, y): + x = x + 1 + return x, y * 2 + + result = fn(a, b) + ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName") + ctx.build(list(result)) + torch_xla.sync() + + # Sanity HLO check. + hlo_text = ctx.hlo_text() + self.assertIn('MyCustomName', hlo_text) + self.assertIn('opcode: "parameter"', hlo_text) + self.assertIn('opcode: "add"', hlo_text) + self.assertIn('sharding', hlo_text) + + # Ensure that the corresponding input parameters contain the expected sharding. + hlo_comp_txt = self._get_computation_hlo_txt(ctx) + a_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(a) + self.assertRegex( + hlo_comp_txt, + rf'%custom-call.*.*f32[2048]{{0}}.*sharding={re.escape(a_sharding_spec)}' + ) + b_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(b) + self.assertRegex( + hlo_comp_txt, + rf'%custom-call.*f32[32,2048]{{0}}.*sharding={re.escape(b_sharding_spec)}' + ) + + # Ensure that the results retain the same sharding specs. + result_a, result_b = result + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(result_a), a_sharding_spec) + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(result_b), b_sharding_spec) + + hlo_content = Path(save_file).read_text() + assert len(re.findall('END_GRAPH', + hlo_content)) == 1, "There is a single graph" + + # Extract the content between OUTPUT_SHARDING_BEGIN and OUTPUT_SHARDING_END + pattern = r'#OUTPUT_SHARDING_BEGIN\n(.*?)\n#OUTPUT_SHARDING_END' + match = re.search(pattern, hlo_content, re.DOTALL) + assert match is not None, "#OUTPUT_SHARDING not found in the file" + assert len(match.groups() + ) == 1, f"Expected 1 group, but found {len(match.groups())}" + expected_output = match.group(1).strip().split('\n') + + # Assert that the output sharding match our expectation. + assert len(expected_output + ) == 4, f"Expected 4 lines, but found {len(expected_output)}" + assert expected_output[0] == f"f32[2048] {a_sharding_spec}" + assert expected_output[1] == f"f32[32,2048] {b_sharding_spec}" + assert expected_output[2] == f"f32[2048] {a_sharding_spec}" + assert expected_output[3] == f"f32[32,2048] {b_sharding_spec}" + self.assertTrue(met.counter_value("ExecuteReplicated") == 1) + self.assertTrue(met.counter_value("ExecuteComputation") is None) + + # Remove the file once the test is complete. + # TODO(rpsilva-aws): Add a proper cleanup wrapper to avoid lingering files. + os.remove(save_file) + def test_device_parameter_id_tensor_mapping(self): met.clear_all() diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b42a084e85c..9aa14fd51c6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1006,6 +1006,9 @@ class PyLoweringContext { torch::lazy::Output(ir_value.node.get(), ir_value.index)); lowering_ctx.AddResult(root); } + + ShardingUtil::SetHloSharding(&lowering_ctx); + computation = ConsumeValue(lowering_ctx.BuildXla()); } @@ -1048,21 +1051,26 @@ class PyLoweringContext { } } + ShardingUtil::SetHloSharding(&lowering_ctx); + computation = ConsumeValue(lowering_ctx.BuildXla()); // wrap inputs of cond/body_computation if ((GetNameString() == "condctx") || (GetNameString() == "bodyctx")) { std::vector> input_output_alias_pair; - std::vector buffer_donor_indices; + std::vector param_shardings; + // If sharded, then extract all input Op shardings. + if (UseVirtualDevice()) { + param_shardings = XlaHelpers::ExtractInputShardings(computation); + } xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); // TODO(@manfei): please confirm whether we check for more than two or use // default value true bool should_wrap_parameter = (program_shape.parameters_size() >= 2); if (should_wrap_parameter) { - // For now we assume that we for i loop input is not sharded. computation = ConsumeValue(XlaHelpers::WrapXlaComputation( - computation, program_shape.parameters(), {}, buffer_donor_indices)); + computation, program_shape.parameters(), param_shardings, {})); } } } diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index c2db9b36309..6c2906dc724 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -111,23 +111,31 @@ LoweringContext::LoweringContext( static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); xla::XlaOp LoweringContext::GetParameter( - const std::shared_ptr& data, + const std::shared_ptr& backend_data, const std::unordered_set& unbounded_dynamic_dims) { - torch::lazy::BackendData::Handle handle = data->GetHandle(); + torch::lazy::BackendData::Handle handle = backend_data->GetHandle(); auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { - xla::Shape shape = - std::dynamic_pointer_cast(data) - ->shape(); + auto data = std::dynamic_pointer_cast( + backend_data); + XLA_CHECK(data != nullptr); + xla::Shape shape = data->shape(); for (const int dim : unbounded_dynamic_dims) { shape.set_dynamic_dimension(dim, true); shape.set_dimensions(dim, kUnboundedSize); } - xla::XlaOp param = xla::Parameter(builder(), parameters_.size(), shape, - absl::StrCat("p", parameters_.size())); - it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()}) - .first; - parameters_.push_back(data); + size_t param_index = parameters_.size(); + std::string param_name = absl::StrCat("p", param_index); + xla::XlaOp param; + if (data->HasSharding()) { + xla::OpSharding sharding = data->GetSharding(); + xla::XlaScopedShardingAssignment scoped_sharding(builder(), sharding); + param = xla::Parameter(builder(), param_index, shape, param_name); + } else { + param = xla::Parameter(builder(), param_index, shape, param_name); + } + it = parameters_map_.emplace(handle, Parameter{param, param_index}).first; + parameters_.push_back(backend_data); } else { XLA_CHECK(unbounded_dynamic_dims.empty()) << "The unbounded dynamic dims can only be set when Parameter is " @@ -138,8 +146,8 @@ xla::XlaOp LoweringContext::GetParameter( } std::optional LoweringContext::GetParameterId( - const std::shared_ptr& data) const { - torch::lazy::BackendData::Handle handle = data->GetHandle(); + const std::shared_ptr& backend_data) const { + torch::lazy::BackendData::Handle handle = backend_data->GetHandle(); auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { return std::nullopt; diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index 3a36695e1c0..cb4f0bc2d2f 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -50,13 +50,13 @@ class LoweringContext : public torch::lazy::LoweringContext { // returned. Otherwise a new one will be created, associated with the tensor // held in data. xla::XlaOp GetParameter( - const std::shared_ptr& data, + const std::shared_ptr& backend_data, const std::unordered_set& dynamic_dims = {}); // If a parameter associated with data has already been declared, returns its // ID. Otherwise, returns `std::nullopt`. std::optional GetParameterId( - const std::shared_ptr& data) const; + const std::shared_ptr& backend_data) const; // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created.