[LoweringContext] SPMD support
rpsilva-aws committed Nov 26, 2024
1 parent 39e67b5 commit a313cee
Showing 4 changed files with 102 additions and 21 deletions.
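
At a high level, this change makes the Python-facing `LoweringContext` SPMD-aware: shardings annotated on lazy tensors are applied to the HLO graph before lowering (`ShardingUtil::SetHloSharding`), and parameters created for sharded backend data now carry the tensor's `OpSharding`. Below is a minimal usage sketch of the user-visible effect, reusing only APIs exercised by the new test in this commit; the mesh shape, tensor size, and the "SpmdExample" name are illustrative, and an SPMD-capable multi-device runtime is assumed.

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

# Enable SPMD mode and build a 1-D mesh over all available devices.
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('x',))

# Shard a tensor along the mesh axis and run a simple computation.
t = torch.ones(2048).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('x',))
out = t + 1

# Lower the pending graph and inspect the HLO; with this change the
# parameter created for `t` is expected to carry a sharding annotation.
ctx = torch_xla._XLAC.lowering.LoweringContext("SpmdExample")
ctx.build([out])
print('sharding' in ctx.hlo_text())  # expected: True
```
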
77 changes: 71 additions & 6 deletions test/test_operations.py
@@ -20,10 +20,8 @@
import copy
import itertools
import math
from numbers import Number
from functools import reduce
import numpy
import random
import numpy as np
import re
import torch
import torch.autograd as ad
@@ -2627,6 +2625,75 @@ def test_multi_init_xla_backend(self):
    self.assertEqual(met.counter_value("RegisterXLAFunctions"), 1)


@unittest.skipIf(
    os.environ.get('XLA_USE_EAGER_DEBUG_MODE'),
    "Skipping test under XLA_USE_EAGER_DEBUG_MODE because `result` will not \
    reference a graph due to eager evaluation.")
class TestLoweringContextSPMD(test_utils.XlaTestCase):

  def _get_computation_hlo_txt(self, ctx):
    hlo = ctx.hlo()
    comp = xb.computation_from_module_proto("my_custom_comp", hlo)
    return xb.get_computation_hlo(comp)

  def setUp(self):
    xr.use_spmd()
    super().setUp()
    num_devices = xr.global_runtime_device_count()
    device_ids = np.arange(num_devices)
    # Annotate a simple sharding for the test.
    self.model_axis = min(8, num_devices)
    self.data_axis = num_devices // self.model_axis
    mesh_shape = (self.data_axis, self.model_axis)
    self.spmd_mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
    xs.set_global_mesh(self.spmd_mesh)

  def test_basic(self):
    device = xm.xla_device()
    a = torch.ones(2048, requires_grad=True).to(device)
    xs.mark_sharding(a, self.spmd_mesh, ('x',))
    b = torch.randn([32, 2048], requires_grad=True).to(device)
    xs.mark_sharding(b, self.spmd_mesh, (None, 'y'))

    def fn(x, y):
      x = x + 1
      return x, y * 2

    result = fn(a, b)

    ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
    ctx.build(list(result))
    torch_xla.sync()

    # Sanity HLO check.
    hlo_text = ctx.hlo_text()
    self.assertIn('MyCustomName', hlo_text)
    self.assertIn('opcode: "parameter"', hlo_text)
    self.assertIn('opcode: "add"', hlo_text)
    self.assertIn('sharding', hlo_text)

    # Ensure that the corresponding input parameters contain the expected sharding.
    hlo_comp_txt = self._get_computation_hlo_txt(ctx)
    a_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(a)
    self.assertRegex(
        hlo_comp_txt,
        rf'%p\d+\.\d+.*f32\[2048\].*sharding={re.escape(a_sharding_spec)}')
    b_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(b)
    self.assertRegex(
        hlo_comp_txt,
        rf'%p\d+\.\d+.*f32\[32,2048\].*sharding={re.escape(b_sharding_spec)}')

    # Ensure that the results retain the same sharding specs.
    result_a, result_b = result
    self.assertEqual(
        torch_xla._XLAC._get_xla_sharding_spec(result_a), a_sharding_spec)
    self.assertEqual(
        torch_xla._XLAC._get_xla_sharding_spec(result_b), b_sharding_spec)


@unittest.skipIf(
    os.environ.get('XLA_USE_EAGER_DEBUG_MODE'),
    "Skipping test under XLA_USE_EAGER_DEBUG_MODE because `result` will not \
    reference a graph due to eager evaluation.")
@@ -2642,11 +2709,9 @@ def test_api(self):

    ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
    ctx.build([result])
    hlo = ctx.hlo()
    hlo_text = ctx.hlo_text()
    self.assertIn('MyCustomName', hlo_text)
    self.assertIn('opcode: "parameter"', hlo_text)
    self.assertIn('opcode: "parameter"', hlo_text)
    self.assertEqual(hlo_text.count('opcode: "parameter"'), 2)
    self.assertIn('opcode: "add"', hlo_text)
    mapping = ctx.parameter_id_tensor_mapping()
    self.assertEqual(len(mapping), 2)
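
For reference, the `_get_computation_hlo_txt` helper above generalizes into a small standalone inspection utility. This is only a sketch: it assumes `xb` is the same `torch_xla.core.xla_builder` alias imported by `test_operations.py`, and the `dump_param_shardings` name is made up for illustration.

```python
import re
import torch_xla.core.xla_builder as xb


def dump_param_shardings(ctx):
  """Returns the parameter instructions of a lowered context that carry a
  sharding attribute, as lines of human-readable HLO text."""
  # Round-trip the lowering context's HLO module proto through the builder,
  # mirroring _get_computation_hlo_txt in the test above.
  comp = xb.computation_from_module_proto("inspect_shardings", ctx.hlo())
  hlo_txt = xb.get_computation_hlo(comp)
  return [
      line.strip()
      for line in hlo_txt.splitlines()
      if re.search(r'parameter\(\d+\).*sharding=', line)
  ]
```

With a sharded input like the one in `test_basic`, each returned line should contain the same spec string that `torch_xla._XLAC._get_xla_sharding_spec` reports for the corresponding tensor.
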
12 changes: 10 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1000,6 +1000,8 @@ class PyLoweringContext {
ir_values.push_back(value);
}

ShardingUtil::SetHloSharding(&lowering_ctx);

// Lower the graph using the output IR values
for (auto& ir_value : ir_values) {
xla::XlaOp root = lowering_ctx.GetOutputOp(
@@ -1048,12 +1050,18 @@
}
}

ShardingUtil::SetHloSharding(&lowering_ctx);

computation = ConsumeValue(lowering_ctx.BuildXla());

// wrap inputs of cond/body_computation
if ((GetNameString() == "condctx") || (GetNameString() == "bodyctx")) {
std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
std::vector<size_t> buffer_donor_indices;
std::vector<xla::HloSharding> param_shardings;
// If sharded, then extract all input Op shardings.
if (UseVirtualDevice()) {
param_shardings = XlaHelpers::ExtractInputShardings(computation);
}
xla::ProgramShape program_shape =
ConsumeValue(computation.GetProgramShape());
// TODO(@manfei): please confirm whether we check for more than two or use
@@ -1062,7 +1070,7 @@
if (should_wrap_parameter) {
// For now we assume that the fori_loop input is not sharded.
computation = ConsumeValue(XlaHelpers::WrapXlaComputation(
computation, program_shape.parameters(), {}, buffer_donor_indices));
computation, program_shape.parameters(), param_shardings, {}));
}
}
}
30 changes: 19 additions & 11 deletions torch_xla/csrc/lowering_context.cpp
@@ -111,23 +111,31 @@ LoweringContext::LoweringContext(
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

xla::XlaOp LoweringContext::GetParameter(
const std::shared_ptr<torch::lazy::BackendData>& data,
const std::shared_ptr<torch::lazy::BackendData>& backend_data,
const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
torch::lazy::BackendData::Handle handle = data->GetHandle();
torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
auto it = parameters_map_.find(handle);
if (it == parameters_map_.end()) {
xla::Shape shape =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data)
->shape();
auto data = std::dynamic_pointer_cast<runtime::ComputationClient::Data>(backend_data);
XLA_CHECK(data != nullptr);
xla::Shape shape = data->shape();
for (const int dim : unbounded_dynamic_dims) {
shape.set_dynamic_dimension(dim, true);
shape.set_dimensions(dim, kUnboundedSize);
}
xla::XlaOp param = xla::Parameter(builder(), parameters_.size(), shape,
absl::StrCat("p", parameters_.size()));
it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
size_t param_index = parameters_.size();
std::string param_name = absl::StrCat("p", param_index);
xla::XlaOp param;
if (data->HasSharding()) {
xla::OpSharding sharding = data->GetSharding();
xla::XlaScopedShardingAssignment scoped_sharding(builder(), sharding);
param = xla::Parameter(builder(), param_index, shape, param_name);
} else {
param = xla::Parameter(builder(), param_index, shape, param_name);
}
it = parameters_map_.emplace(handle, Parameter{param, param_index})
.first;
parameters_.push_back(data);
parameters_.push_back(backend_data);
} else {
XLA_CHECK(unbounded_dynamic_dims.empty())
<< "The unbounded dynamic dims can only be set when Parameter is "
@@ -138,8 +146,8 @@ xla::XlaOp LoweringContext::GetParameter(
}

std::optional<size_t> LoweringContext::GetParameterId(
const std::shared_ptr<torch::lazy::BackendData>& data) const {
torch::lazy::BackendData::Handle handle = data->GetHandle();
const std::shared_ptr<torch::lazy::BackendData>& backend_data) const {
torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
auto it = parameters_map_.find(handle);
if (it == parameters_map_.end()) {
return std::nullopt;
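
The practical effect of the new `HasSharding()` branch in `GetParameter` is easiest to see from Python: a parameter lowered from a tensor that was annotated with `mark_sharding` should carry a `sharding=` attribute in the computation HLO, while an unmarked input goes through the `else` branch and is lowered as a plain parameter. A hedged sketch follows, assuming SPMD mode and a 1-D mesh as before; the "ParamShardings" and "params" names are illustrative, and the printed output depends on the runtime and device count.

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('x',))

device = xm.xla_device()
sharded = torch.ones(2048).to(device)
xs.mark_sharding(sharded, mesh, ('x',))
unsharded = torch.ones(2048).to(device)

ctx = torch_xla._XLAC.lowering.LoweringContext("ParamShardings")
ctx.build([sharded + unsharded])

# Print every parameter instruction; the one lowered from `sharded` is
# expected to include a sharding= attribute after this change.
comp = xb.computation_from_module_proto("params", ctx.hlo())
for line in xb.get_computation_hlo(comp).splitlines():
  if 'parameter(' in line:
    print(line.strip())
```
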
4 changes: 2 additions & 2 deletions torch_xla/csrc/lowering_context.h
@@ -50,13 +50,13 @@ class LoweringContext : public torch::lazy::LoweringContext {
// returned. Otherwise a new one will be created, associated with the tensor
// held in data.
xla::XlaOp GetParameter(
const std::shared_ptr<torch::lazy::BackendData>& data,
const std::shared_ptr<torch::lazy::BackendData>& backend_data,
const std::unordered_set<uint32_t>& dynamic_dims = {});

// If a parameter associated with data has already been declared, returns its
// ID. Otherwise, returns `std::nullopt`.
std::optional<size_t> GetParameterId(
const std::shared_ptr<torch::lazy::BackendData>& data) const;
const std::shared_ptr<torch::lazy::BackendData>& backend_data) const;

// Retrieves the vector holding all the tensors associated with the parameter
// instructions which have been created.
