Implement mark_sharding as a custom op to support dynamo spmd activation sharding #5712

Merged: 14 commits, Nov 11, 2023
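For orientation, here is a minimal sketch of the workflow this PR enables, modeled on the tests added below in test/spmd/test_dynamo_spmd.py. It is a sketch under assumptions, not part of the diff: the `xs` import path, `xs.Mesh`, and `xr.global_runtime_device_count()` are assumed to match the torch_xla sharding API that the tests import.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs  # assumed import path for the sharding module

xr.use_spmd()  # SPMD via the virtual device must be enabled before mark_sharding

device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))

# SimpleLinear(mesh=mesh), defined in the test below, calls
# xs.mark_sharding(..., use_dynamo_custom_op=True) inside forward(),
# so dynamo can trace the sharding annotation instead of graph-breaking.
model = SimpleLinear(mesh=mesh).to(device)
model.eval()

x = torch.randn(1, 128, device=device)
compiled = torch.compile(model, backend="openxla")
out = compiled(x)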
52 changes: 51 additions & 1 deletion test/spmd/test_dynamo_spmd.py
@@ -15,16 +15,22 @@

class SimpleLinear(nn.Module):

def __init__(self):
def __init__(self, mesh=None):
super(SimpleLinear, self).__init__()
self.fc1 = nn.Linear(128, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 1)
# Add an additional 1x1 layer at the end to ensure the final layer
# is not sharded.
self.fc3 = nn.Linear(1, 1)
# If mesh is not None, we'll do a mark_sharding inside the forward function
# to ensure dynamo can recognize and trace it under torch.compile.
self.mesh = mesh

def forward(self, x):
if self.mesh and 'xla' in str(self.fc2.weight.device):
Review comment (PR author, Collaborator):

@yeounoh, I updated the unit test for doing mark_sharding inside torch compile to be part of the existing SimpleLinear. Here, I just do a mark_sharding call inside the forward function. Please let me know if you think this will suffice.

Reply (Contributor):

SG

xs.mark_sharding(
self.fc2.weight, self.mesh, (1, 0), use_dynamo_custom_op=True)
y = self.relu(self.fc1(x))
z = self.fc2(y)
return self.fc3(z)
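Because the custom-op variant is traceable, the same pattern extends to annotating intermediate activations inside a compiled forward pass, which is the activation sharding named in the PR title. A minimal sketch, not part of this diff; the partition spec (0, 1) is illustrative only:

def forward(self, x):
    y = self.relu(self.fc1(x))
    if self.mesh:
        # Annotate the activation itself. With use_dynamo_custom_op=True the call
        # is intended to dispatch to the torch.ops.xla custom op added in this PR,
        # so torch.compile can trace it rather than graph-breaking on a pybind call.
        xs.mark_sharding(y, self.mesh, (0, 1), use_dynamo_custom_op=True)
    z = self.fc2(y)
    return self.fc3(z)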
@@ -171,6 +177,50 @@ def test_dynamo_input_sharding_threashold(self):
else:
del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD']

def test_dynamo_spmd_mark_sharding_outside_of_compile(self):
device = xm.xla_device()
linear = SimpleLinear().to(device)
linear.eval()
xla_x = torch.randn(1, 128, device=device)
xs.mark_sharding(
linear.fc2.weight,
self._get_mesh((1, self.n_devices)), (1, 0),
use_dynamo_custom_op=True)
xla_res = linear(xla_x)
xm.mark_step()

dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())

# Ensure that another run with same input does not trigger additional compilation
compile_count = met.metric_data('CompileTime')[0]
dynamo_res = dynamo_linear(xla_x)
self.assertEqual(met.metric_data('CompileTime')[0], compile_count)

def test_mark_sharding_inside_compile(self):
met.clear_counters()
device = xm.xla_device()
mesh = self._get_mesh((1, self.n_devices))

# Passing this `mesh` as a parameter to `SimpleLinear` will call the dynamo custom op
# variant of mark_sharding inside the forward function.
linear = SimpleLinear(mesh=mesh).to(device)
linear.eval()

xla_x = torch.randn(1, 128, device=device)
xla_res = linear(xla_x)
xm.mark_step()

dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())

# Ensure that another run with same input does not trigger additional compilation
compile_count = met.metric_data('CompileTime')[0]
dynamo_res = dynamo_linear(xla_x)
self.assertEqual(met.metric_data('CompileTime')[0], compile_count)


if __name__ == '__main__':
test = unittest.main()
12 changes: 0 additions & 12 deletions torch_xla/csrc/aten_autograd_ops.cpp
@@ -253,17 +253,5 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
return grad;
}

TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
}
} // namespace aten_autograd_ops
} // namespace torch_xla
11 changes: 11 additions & 0 deletions torch_xla/csrc/aten_autograd_ops.h
@@ -46,6 +46,17 @@ struct MaxPool3dAutogradFunction
torch::autograd::variable_list grad_output);
};

torch::Tensor max_pool2d_forward(torch::Tensor self,
torch::IntArrayRef kernel_size,
torch::IntArrayRef stride,
torch::IntArrayRef padding,
torch::IntArrayRef dilation, bool ceil_mode);

torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
torch::IntArrayRef kernel_size,
torch::IntArrayRef stride,
torch::IntArrayRef padding, bool ceil_mode);

} // namespace aten_autograd_ops
} // namespace torch_xla

94 changes: 31 additions & 63 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -29,6 +29,7 @@
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "torch_xla/csrc/XLANativeFunctions.h"
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/helpers.h"
@@ -1559,72 +1560,39 @@ void InitXlaModuleBindings(py::module m) {
tile_assignment, group_assignment, replication_groups,
ShardingUtil::ShardingType(sharding_type));
}));
m.def("_xla_mark_sharding", [](const at::Tensor& input,
xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}
m.def("_xla_mark_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
ShardingUtil::xla_mark_sharding(input, sharding);
});
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, const py::list& tile_assignment,
const py::list& group_assignment, const py::list& replication_groups,
int sharding_type) {
c10::List<at::IntArrayRef> tile_assignment_list =
Copy link
Contributor

@yeounoh yeounoh Nov 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move the following data processing logic into xla_mark_sharding_dynamo_custom_op and just call

xla_mark_sharding_dynamo_custom_op(
              input, tile_assignment_list, group_assignment_list,
              replication_groups_list, sharding_type);
        });

similar to what you've done for mark_sharing.

c10::List<at::IntArrayRef>();
for (auto t : tile_assignment) {
tile_assignment_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}
c10::List<at::IntArrayRef> group_assignment_list =
c10::List<at::IntArrayRef>();
for (auto t : group_assignment) {
group_assignment_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);
c10::List<at::IntArrayRef> replication_groups_list =
c10::List<at::IntArrayRef>();
for (auto t : replication_groups) {
replication_groups_list.push_back(
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
});
xla_mark_sharding_dynamo_custom_op(
input, tile_assignment_list, group_assignment_list,
replication_groups_list, sharding_type);
});
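For reference, a hypothetical illustration of how this new binding could be invoked from the Python layer; the real caller lives in the Python sharding module and is not part of this diff. The nested lists mirror the py::list-of-lists that the lambda converts to c10::List<at::IntArrayRef>, and the concrete values below are made up:

import torch_xla

torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
    xla_tensor,   # a tensor already placed on the XLA (SPMD virtual) device
    [[0], [1]],   # tile_assignment: nested list of ints
    [[0, 1]],     # group_assignment
    [[0, 1]],     # replication_groups
    3)            # sharding_type as an int (a ShardingUtil::ShardingType value)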
m.def("_xla_clear_sharding", [](const at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xtensor->ClearShardingSpec();
135 changes: 135 additions & 0 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -6,6 +6,8 @@
#include <unordered_map>

#include "torch/csrc/lazy/core/ir_util.h"
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ops/device_data.h"
@@ -14,7 +16,9 @@
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/tensor_methods.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_graph_executor.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/execution_options_util.h"
#include "xla/hlo/ir/hlo_module.h"
@@ -742,4 +746,135 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
source_tensors, GetVirtualDevice().toString(), global_shape, sharding);
}

void ShardingUtil::xla_mark_sharding(const at::Tensor& input,
xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
xla::OpSharding::REPLICATED)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}

void xla_mark_sharding_dynamo_custom_op(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
py::list tile_assignment_py = py::list();
for (int i = 0; i < tile_assignment.size(); i++) {
py::list pylist = py::list();
for (int64_t t : tile_assignment[i].get().toIntList()) {
pylist.append(t);
}
tile_assignment_py.append(pylist);
}

py::list group_assignment_py = py::list();
for (int i = 0; i < group_assignment.size(); i++) {
py::list pylist = py::list();
for (int64_t t : group_assignment[i].get().toIntList()) {
pylist.append(t);
}
group_assignment_py.append(pylist);
}

py::list replication_groups_py = py::list();
for (int i = 0; i < replication_groups.size(); i++) {
py::list pylist = py::list();
for (int64_t t : replication_groups[i].get().toIntList()) {
pylist.append(t);
}
replication_groups_py.append(pylist);
}

xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding(
tile_assignment_py, group_assignment_py, replication_groups_py,
ShardingUtil::ShardingType(sharding_type));

ShardingUtil::xla_mark_sharding(input, op_sharding);
}

// Macro for defining a function that will be run at static initialization time
// to define a library of operators in the namespace. Used to define a new set
// of custom operators that do not already exist in PyTorch.
TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
m.def(
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
"tile_assignment, int[][] group_assignment, int[][] replication_groups, "
"int sharding_type) -> ()",
torch::dispatch(c10::DispatchKey::XLA,
TORCH_FN(torch_xla::xla_mark_sharding_dynamo_custom_op)));
}

} // namespace torch_xla
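With the TORCH_LIBRARY registration above in place, the operator becomes reachable through the PyTorch dispatcher as torch.ops.xla.xla_mark_sharding_dynamo_custom_op, which is what allows dynamo to trace the sharding annotation instead of graph-breaking on an opaque pybind call. A minimal sketch of invoking it directly; the argument values are illustrative, and the tensor must live on an XLA device since the kernel is registered under the XLA dispatch key:

import torch
import torch_xla.core.xla_model as xm

t = torch.randn(8, 8, device=xm.xla_device())
torch.ops.xla.xla_mark_sharding_dynamo_custom_op(
    t, [[0], [1]], [[0, 1]], [[0, 1]], 3)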
8 changes: 8 additions & 0 deletions torch_xla/csrc/xla_sharding_util.h
@@ -150,8 +150,16 @@ class ShardingUtil {
const std::vector<at::Tensor>& shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);

static void xla_mark_sharding(const at::Tensor& input,
xla::OpSharding sharding);
};

void xla_mark_sharding_dynamo_custom_op(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_XLA_SHARDING_UTIL_H_