Skip to content

Commit

Permalink
[LoweringContext] SPMD support
Browse files Browse the repository at this point in the history
  • Loading branch information
rpsilva-aws committed Dec 3, 2024
1 parent 39e67b5 commit 2291041
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 20 deletions.
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_save_tensor_hlo "$CDIR/spmd/test_spmd_lowering_context.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
Expand Down
115 changes: 115 additions & 0 deletions test/spmd/test_spmd_lowering_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import os
import sys
from pathlib import Path

import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs

import unittest
import re

import test_xla_sharding_base


class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
super().setUpClass()

def _get_computation_hlo_txt(self, ctx):
hlo = ctx.hlo()
comp = xb.computation_from_module_proto("my_custom_comp", hlo)
return xb.get_computation_hlo(comp)

def test_basic(self):
# Validate that the output sharding from XLA dump files match our expectation.
save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE"
save_file += '.0'
assert save_format == 'hlo', "This test should be run with XLA_SAVE_TENSORS_FMT=hlo"

# Remove the save file (if exists) to start from a clean slate
try:
os.remove(save_file)
except:
pass

model_axis = min(8, self.n_devices)
data_axis = self.n_devices // model_axis
mesh_shape = (data_axis, model_axis)
spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))

device = xm.xla_device()
a = torch.zeros(2048, device=device, requires_grad=True)
xs.mark_sharding(a, spmd_mesh, ('x',))
b = torch.randn([32, 2048], device=device, requires_grad=True)
xs.mark_sharding(b, spmd_mesh, (None, 'y'))

def fn(x, y):
x = x + 1
return x, y * 2

result = fn(a, b)
ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
ctx.build(list(result))
torch_xla.sync()

# Sanity HLO check.
hlo_text = ctx.hlo_text()
self.assertIn('MyCustomName', hlo_text)
self.assertIn('opcode: "parameter"', hlo_text)
self.assertIn('opcode: "add"', hlo_text)
self.assertIn('sharding', hlo_text)

# Ensure that the corresponding input parameters contain the expected sharding.
hlo_comp_txt = self._get_computation_hlo_txt(ctx)
a_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(a)
self.assertRegex(
hlo_comp_txt,
rf'%custom-call.*.*f32[2048]{{0}}.*sharding={re.escape(a_sharding_spec)}'
)
b_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(b)
self.assertRegex(
hlo_comp_txt,
rf'%custom-call.*f32[32,2048]{{0}}.*sharding={re.escape(b_sharding_spec)}'
)

# Ensure that the results retain the same sharding specs.
result_a, result_b = result
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(result_a), a_sharding_spec)
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(result_b), b_sharding_spec)

hlo_content = Path(save_file).read_text()
assert len(re.findall('END_GRAPH',
hlo_content)) == 1, "There is a single graph"

# Extract the content between OUTPUT_SHARDING_BEGIN and OUTPUT_SHARDING_END
pattern = r'#OUTPUT_SHARDING_BEGIN\n(.*?)\n#OUTPUT_SHARDING_END'
match = re.search(pattern, hlo_content, re.DOTALL)
assert match is not None, "#OUTPUT_SHARDING not found in the file"
assert len(match.groups()
) == 1, f"Expected 1 group, but found {len(match.groups())}"
expected_output = match.group(1).strip().split('\n')

# Assert that the output sharding match our expectation.
assert len(expected_output
) == 4, f"Expected 4 lines, but found {len(expected_output)}"
assert expected_output[0] == f"f32[2048] {a_sharding_spec}"
assert expected_output[1] == f"f32[32,2048] {b_sharding_spec}"
assert expected_output[2] == f"f32[2048] {a_sharding_spec}"
assert expected_output[3] == f"f32[32,2048] {b_sharding_spec}"
self.assertTrue(met.counter_value("ExecuteReplicated") == 1)
self.assertTrue(met.counter_value("ExecuteComputation") is None)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
4 changes: 1 addition & 3 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2642,11 +2642,9 @@ def test_api(self):

ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
ctx.build([result])
hlo = ctx.hlo()
hlo_text = ctx.hlo_text()
self.assertIn('MyCustomName', hlo_text)
self.assertIn('opcode: "parameter"', hlo_text)
self.assertIn('opcode: "parameter"', hlo_text)
self.assertTrue(hlo_text.count('opcode: "parameter"'), 2)
self.assertIn('opcode: "add"', hlo_text)
mapping = ctx.parameter_id_tensor_mapping()
self.assertEqual(len(mapping), 2)
Expand Down
14 changes: 11 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,9 @@ class PyLoweringContext {
torch::lazy::Output(ir_value.node.get(), ir_value.index));
lowering_ctx.AddResult(root);
}

ShardingUtil::SetHloSharding(&lowering_ctx);

computation = ConsumeValue(lowering_ctx.BuildXla());
}

Expand Down Expand Up @@ -1048,21 +1051,26 @@ class PyLoweringContext {
}
}

ShardingUtil::SetHloSharding(&lowering_ctx);

computation = ConsumeValue(lowering_ctx.BuildXla());

// wrap inputs of cond/body_computation
if ((GetNameString() == "condctx") || (GetNameString() == "bodyctx")) {
std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
std::vector<size_t> buffer_donor_indices;
std::vector<xla::HloSharding> param_shardings;
// If sharded, then extract all input Op shardings.
if (UseVirtualDevice()) {
param_shardings = XlaHelpers::ExtractInputShardings(computation);
}
xla::ProgramShape program_shape =
ConsumeValue(computation.GetProgramShape());
// TODO(@manfei): please confirm whether we check for more than two or use
// default value true
bool should_wrap_parameter = (program_shape.parameters_size() >= 2);
if (should_wrap_parameter) {
// For now we assume that we for i loop input is not sharded.
computation = ConsumeValue(XlaHelpers::WrapXlaComputation(
computation, program_shape.parameters(), {}, buffer_donor_indices));
computation, program_shape.parameters(), param_shardings, {}));
}
}
}
Expand Down
32 changes: 20 additions & 12 deletions torch_xla/csrc/lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,23 +111,31 @@ LoweringContext::LoweringContext(
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

xla::XlaOp LoweringContext::GetParameter(
const std::shared_ptr<torch::lazy::BackendData>& data,
const std::shared_ptr<torch::lazy::BackendData>& backend_data,
const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
torch::lazy::BackendData::Handle handle = data->GetHandle();
torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
auto it = parameters_map_.find(handle);
if (it == parameters_map_.end()) {
xla::Shape shape =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data)
->shape();
auto data = std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
backend_data);
XLA_CHECK(data != nullptr);
xla::Shape shape = data->shape();
for (const int dim : unbounded_dynamic_dims) {
shape.set_dynamic_dimension(dim, true);
shape.set_dimensions(dim, kUnboundedSize);
}
xla::XlaOp param = xla::Parameter(builder(), parameters_.size(), shape,
absl::StrCat("p", parameters_.size()));
it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
.first;
parameters_.push_back(data);
size_t param_index = parameters_.size();
std::string param_name = absl::StrCat("p", param_index);
xla::XlaOp param;
if (data->HasSharding()) {
xla::OpSharding sharding = data->GetSharding();
xla::XlaScopedShardingAssignment scoped_sharding(builder(), sharding);
param = xla::Parameter(builder(), param_index, shape, param_name);
} else {
param = xla::Parameter(builder(), param_index, shape, param_name);
}
it = parameters_map_.emplace(handle, Parameter{param, param_index}).first;
parameters_.push_back(backend_data);
} else {
XLA_CHECK(unbounded_dynamic_dims.empty())
<< "The unbounded dynamic dims can only be set when Parameter is "
Expand All @@ -138,8 +146,8 @@ xla::XlaOp LoweringContext::GetParameter(
}

std::optional<size_t> LoweringContext::GetParameterId(
const std::shared_ptr<torch::lazy::BackendData>& data) const {
torch::lazy::BackendData::Handle handle = data->GetHandle();
const std::shared_ptr<torch::lazy::BackendData>& backend_data) const {
torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
auto it = parameters_map_.find(handle);
if (it == parameters_map_.end()) {
return std::nullopt;
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/lowering_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ class LoweringContext : public torch::lazy::LoweringContext {
// returned. Otherwise a new one will be created, associated with the tensor
// held in data.
xla::XlaOp GetParameter(
const std::shared_ptr<torch::lazy::BackendData>& data,
const std::shared_ptr<torch::lazy::BackendData>& backend_data,
const std::unordered_set<uint32_t>& dynamic_dims = {});

// If a parameter associated with data has already been declared, returns its
// ID. Otherwise, returns `std::nullopt`.
std::optional<size_t> GetParameterId(
const std::shared_ptr<torch::lazy::BackendData>& data) const;
const std::shared_ptr<torch::lazy::BackendData>& backend_data) const;

// Retrieves the vector holding all the tensors associated with the parameter
// instructions which have been created.
Expand Down

0 comments on commit 2291041

Please sign in to comment.