Skip to content

Commit

Permalink
Implement mark_sharding as a custom op to support dynamo spmd activat…
Browse files Browse the repository at this point in the history
…ion sharding (pytorch#5712)
  • Loading branch information
wonjoolee95 authored and chunnienc committed Dec 14, 2023
1 parent 5c481db commit c7a04ac
Show file tree
Hide file tree
Showing 7 changed files with 276 additions and 89 deletions.
52 changes: 51 additions & 1 deletion test/spmd/test_dynamo_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@

class SimpleLinear(nn.Module):

def __init__(self):
def __init__(self, mesh=None):
super(SimpleLinear, self).__init__()
self.fc1 = nn.Linear(128, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 1)
# Add an additional 1x1 layer at the end to ensure the final layer
# is not sharded.
self.fc3 = nn.Linear(1, 1)
# If mesh is not none, we'll do a mark sharding inside the forward function
# to ensure dynamo can recognize and trace it in a torch compile.
self.mesh = mesh

def forward(self, x):
if self.mesh and 'xla' in str(self.fc2.weight.device):
xs.mark_sharding(
self.fc2.weight, self.mesh, (1, 0), use_dynamo_custom_op=True)
y = self.relu(self.fc1(x))
z = self.fc2(y)
return self.fc3(z)
Expand Down Expand Up @@ -171,6 +177,50 @@ def test_dynamo_input_sharding_threashold(self):
else:
del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD']

def test_dynamo_spmd_mark_sharding_outside_of_compile(self):
device = xm.xla_device()
linear = SimpleLinear().to(device)
linear.eval()
xla_x = torch.randn(1, 128, device=device)
xs.mark_sharding(
linear.fc2.weight,
self._get_mesh((1, self.n_devices)), (1, 0),
use_dynamo_custom_op=True)
xla_res = linear(xla_x)
xm.mark_step()

dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())

# Ensure that another run with same input does not trigger additional compilation
compile_count = met.metric_data('CompileTime')[0]
dynamo_res = dynamo_linear(xla_x)
self.assertEqual(met.metric_data('CompileTime')[0], compile_count)

def test_mark_sharding_inside_compile(self):
met.clear_counters()
device = xm.xla_device()
mesh = self._get_mesh((1, self.n_devices))

# Passing this `mesh` as a parameter to `SimpleLinear` will call the dynamo custom op
# variant of mark_sharding inside the forward function.
linear = SimpleLinear(mesh=mesh).to(device)
linear.eval()

xla_x = torch.randn(1, 128, device=device)
xla_res = linear(xla_x)
xm.mark_step()

dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())

# Ensure that another run with same input does not trigger additional compilation
compile_count = met.metric_data('CompileTime')[0]
dynamo_res = dynamo_linear(xla_x)
self.assertEqual(met.metric_data('CompileTime')[0], compile_count)


if __name__ == '__main__':
test = unittest.main()
Expand Down
12 changes: 0 additions & 12 deletions torch_xla/csrc/aten_autograd_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,17 +253,5 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
return grad;
}

TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
}
} // namespace aten_autograd_ops
} // namespace torch_xla
11 changes: 11 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ struct MaxPool3dAutogradFunction
torch::autograd::variable_list grad_output);
};

torch::Tensor max_pool2d_forward(torch::Tensor self,
torch::IntArrayRef kernel_size,
torch::IntArrayRef stride,
torch::IntArrayRef padding,
torch::IntArrayRef dilation, bool ceil_mode);

torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
torch::IntArrayRef kernel_size,
torch::IntArrayRef stride,
torch::IntArrayRef padding, bool ceil_mode);

} // namespace aten_autograd_ops
} // namespace torch_xla

Expand Down
94 changes: 31 additions & 63 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "torch_xla/csrc/XLANativeFunctions.h"
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/dtype.h"
Expand Down Expand Up @@ -1561,72 +1562,39 @@ void InitXlaModuleBindings(py::module m) {
tile_assignment, group_assignment, replication_groups,
ShardingUtil::ShardingType(sharding_type));
}));
m.def("_xla_mark_sharding", [](const at::Tensor& input,
xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}
m.def("_xla_mark_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
ShardingUtil::xla_mark_sharding(input, sharding);
});
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, const py::list& tile_assignment,
const py::list& group_assignment, const py::list& replication_groups,
int sharding_type) {
c10::List<at::IntArrayRef> tile_assignment_list =
c10::List<at::IntArrayRef>();
for (auto t : tile_assignment) {
tile_assignment_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}
c10::List<at::IntArrayRef> group_assignment_list =
c10::List<at::IntArrayRef>();
for (auto t : group_assignment) {
group_assignment_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);
c10::List<at::IntArrayRef> replication_groups_list =
c10::List<at::IntArrayRef>();
for (auto t : replication_groups) {
replication_groups_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
});
xla_mark_sharding_dynamo_custom_op(
input, tile_assignment_list, group_assignment_list,
replication_groups_list, sharding_type);
});
m.def("_xla_clear_sharding", [](const at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xtensor->ClearShardingSpec();
Expand Down
135 changes: 135 additions & 0 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <unordered_map>

#include "torch/csrc/lazy/core/ir_util.h"
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/dtype.h"
#include "torch_xla/csrc/helpers.h"
Expand All @@ -15,7 +17,9 @@
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/tensor_methods.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_graph_executor.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/execution_options_util.h"
#include "xla/hlo/ir/hlo_module.h"
Expand Down Expand Up @@ -743,4 +747,135 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
source_tensors, GetVirtualDevice().toString(), global_shape, sharding);
}

void ShardingUtil::xla_mark_sharding(const at::Tensor& input,
xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}

void xla_mark_sharding_dynamo_custom_op(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
py::list tile_assignment_py = py::list();
for (int i = 0; i < tile_assignment.size(); i++) {
py::list pylist = py::list();
for (int64_t t : tile_assignment[i].get().toIntList()) {
pylist.append(t);
}
tile_assignment_py.append(pylist);
}

py::list group_assignment_py = py::list();
for (int i = 0; i < group_assignment.size(); i++) {
py::list pylist = py::list();
for (int64_t t : group_assignment[i].get().toIntList()) {
pylist.append(t);
}
group_assignment_py.append(pylist);
}

py::list replication_groups_py = py::list();
for (int i = 0; i < replication_groups.size(); i++) {
py::list pylist = py::list();
for (int64_t t : replication_groups[i].get().toIntList()) {
pylist.append(t);
}
replication_groups_py.append(pylist);
}

xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding(
tile_assignment_py, group_assignment_py, replication_groups_py,
ShardingUtil::ShardingType(sharding_type));

ShardingUtil::xla_mark_sharding(input, op_sharding);
}

// Macro for defining a function that will be run at static initialization time
// to define a library of operators in the namespace. Used to define a new set
// of custom operators that do not already exist in PyTorch.
TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
m.def(
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
"tile_assignment, int[][] group_assignment, int[][] replication_groups, "
"int sharding_type) -> ()",
torch::dispatch(c10::DispatchKey::XLA,
TORCH_FN(torch_xla::xla_mark_sharding_dynamo_custom_op)));
}

} // namespace torch_xla
8 changes: 8 additions & 0 deletions torch_xla/csrc/xla_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,16 @@ class ShardingUtil {
const std::vector<at::Tensor>& shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);

static void xla_mark_sharding(const at::Tensor& input,
xla::OpSharding sharding);
};

void xla_mark_sharding_dynamo_custom_op(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_XLA_SHARDING_UTIL_H_
Loading

0 comments on commit c7a04ac

Please sign in to comment.