From 53e7ec0be1fcabbbd2907684889a0aacd7a4cc4b Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Thu, 12 Oct 2023 23:42:55 +0000 Subject: [PATCH 01/14] Implement mark_sharding as a custom op to support dynamo spmd activation sharding --- test/spmd/test_dynamo_spmd.py | 14 ++++++++++ torch_xla/csrc/ops/custom_mark_sharding.cpp | 29 +++++++++++++++++++++ torch_xla/csrc/ops/custom_mark_sharding.h | 22 ++++++++++++++++ torch_xla/csrc/tensor_methods.cpp | 9 +++++++ torch_xla/csrc/tensor_methods.h | 3 +++ torch_xla/csrc/xla_lower_util.cpp | 5 ++++ torch_xla/csrc/xla_lower_util.h | 2 ++ 7 files changed, 84 insertions(+) create mode 100644 torch_xla/csrc/ops/custom_mark_sharding.cpp create mode 100644 torch_xla/csrc/ops/custom_mark_sharding.h diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 22cd2980413..cc3edc4f79b 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -171,6 +171,20 @@ def test_dynamo_input_sharding_threashold(self): else: del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] + def test_mark_sharding_after_compile(self): + device = xm.xla_device() + linear = SimpleLinear().to(device) + linear.eval() + xla_x = torch.randn(1, 128, device=device) + xs.mark_sharding(linear.fc2.weight, self._get_mesh((1, self.n_devices)), + (1, 0)) + xla_res = linear(xla_x) + xm.mark_step() + + dynamo_linear = torch.compile(linear, backend="openxla") + dynamo_res = dynamo_linear(xla_x) + torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/ops/custom_mark_sharding.cpp b/torch_xla/csrc/ops/custom_mark_sharding.cpp new file mode 100644 index 00000000000..53d3480855d --- /dev/null +++ b/torch_xla/csrc/ops/custom_mark_sharding.cpp @@ -0,0 +1,29 @@ +#include "torch_xla/csrc/ops/custom_mark_sharding.h" + +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/xla_lower_util.h" + +namespace torch_xla { + +CustomMarkSharding::CustomMarkSharding(const torch::lazy::Value& input) + : XlaNode(xla_custom_mark_sharding, {input}, GetXlaShape(input), + /*num_outputs=*/1, torch::lazy::MHash(std::string("MarkSharding"))) {} + +torch::lazy::NodePtr CustomMarkSharding::Clone(torch::lazy::OpList operands) const { + return torch::lazy::MakeNode(operands.at(0)); +} + +XlaOpVector CustomMarkSharding::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + xla::XlaOp output = BuildCustomMarkSharding(input); + return ReturnOp(output, loctx); +} + +std::string CustomMarkSharding::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString() << ", MarkSharding"; + return ss.str(); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/custom_mark_sharding.h b/torch_xla/csrc/ops/custom_mark_sharding.h new file mode 100644 index 00000000000..15667cb94f7 --- /dev/null +++ b/torch_xla/csrc/ops/custom_mark_sharding.h @@ -0,0 +1,22 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_CUSTOM_MARK_SHARDING_H_ +#define XLA_TORCH_XLA_CSRC_OPS_CUSTOM_MARK_SHARDING_H_ + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class CustomMarkSharding : public XlaNode { + public: + // Make a custom call to Sharding. 
+ CustomMarkSharding(const torch::lazy::Value& input); + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + std::string ToString() const override; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_OPS_CUSTOM_MARK_SHARDING_H_ diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index fa54741190d..935e987afe6 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -442,6 +442,15 @@ void custom_sharding_( input->SetShardingSpec(*sharding_spec); } +void custom_mark_sharding( + const XLATensorPtr& input, + const std::shared_ptr& sharding_spec) { + torch::lazy::NodePtr node = torch::lazy::MakeNode( + torch::lazy::MakeNode(input->GetIrValue())); + return {input->CreateFrom(torch::lazy::Value(node, 0)), + torch::lazy::Value(node, 1)}; +} + XLATensorPtr get_dimensions_size(const XLATensorPtr& input, std::vector dimensions) { return input->CreateFrom(torch::lazy::MakeNode( diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 5a714170300..1ac71c9b4f8 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -57,6 +57,9 @@ std::pair collective_permute( void custom_sharding_(const XLATensorPtr& input, const std::shared_ptr& spec); +void custom_mark_sharding(const XLATensorPtr& input, + const std::shared_ptr& spec); + XLATensorPtr get_dimensions_size(const XLATensorPtr& input, std::vector dimensions); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 374e7569ca0..774f65a91e2 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1228,4 +1228,9 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) { {input}, ShapeHelper::ShapeOfXlaOp(input)); } +xla::XlaOp BuildCustomMarkSharding(const xla::XlaOp& input) { + return xla::CustomCall(input.builder(), /*call_target_name=*/"MarkSharding", + {input}, ShapeHelper::ShapeOfXlaOp(input)); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 252bbe5e31c..71ac1ef1a02 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -150,6 +150,8 @@ xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p, xla::XlaOp BuildCustomSharding(const xla::XlaOp& input); +xla::XlaOp BuildCustomMarkSharding(const xla::XlaOp& input); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_ From f0e8a94b05366c2760859093e5b07ed98e66cf1e Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Thu, 19 Oct 2023 20:19:31 +0000 Subject: [PATCH 02/14] Update to include OpSharding as an input --- torch_xla/csrc/init_python_bindings.cpp | 8 ++++++++ torch_xla/csrc/ops/custom_mark_sharding.cpp | 11 ++++++----- torch_xla/csrc/ops/custom_mark_sharding.h | 5 ++++- torch_xla/csrc/tensor_methods.cpp | 8 +++----- torch_xla/csrc/tensor_methods.h | 2 +- torch_xla/csrc/xla_lower_util.cpp | 21 ++++++++++++++++++--- torch_xla/csrc/xla_lower_util.h | 2 +- 7 files changed, 41 insertions(+), 16 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1e6bb020fe5..a1aea978c9d 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1625,6 +1625,14 @@ void InitXlaModuleBindings(py::module m) { // Register sharded tensor data. 
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); }); + m.def("_xla_mark_sharding_custom_op", [](const at::Tensor& input, + xla::OpSharding sharding) { + TORCH_LAZY_COUNTER("XlaMarkSharding", 1); + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + tensor_methods::custom_mark_sharding(input, sharding); + } m.def("_xla_clear_sharding", [](const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); xtensor->ClearShardingSpec(); diff --git a/torch_xla/csrc/ops/custom_mark_sharding.cpp b/torch_xla/csrc/ops/custom_mark_sharding.cpp index 53d3480855d..5e2e2e77a0c 100644 --- a/torch_xla/csrc/ops/custom_mark_sharding.cpp +++ b/torch_xla/csrc/ops/custom_mark_sharding.cpp @@ -6,18 +6,19 @@ namespace torch_xla { -CustomMarkSharding::CustomMarkSharding(const torch::lazy::Value& input) +CustomMarkSharding::CustomMarkSharding(const torch::lazy::Value& input, xla::OpSharding sharding) : XlaNode(xla_custom_mark_sharding, {input}, GetXlaShape(input), - /*num_outputs=*/1, torch::lazy::MHash(std::string("MarkSharding"))) {} + /*num_outputs=*/1, torch::lazy::MHash(std::string("MarkSharding"))), + sharding_(sharding) {} torch::lazy::NodePtr CustomMarkSharding::Clone(torch::lazy::OpList operands) const { - return torch::lazy::MakeNode(operands.at(0)); + return torch::lazy::MakeNode(operands.at(0), operands.at(1)); } XlaOpVector CustomMarkSharding::Lower(LoweringContext* loctx) const { xla::XlaOp input = loctx->GetOutputOp(operand(0)); - xla::XlaOp output = BuildCustomMarkSharding(input); - return ReturnOp(output, loctx); + xla::XlaOp sharding = loctx->GetOutputOp(operand(1)); + return ReturnOp(BuildCustomMarkSharding(loctx->device(), input, sharding), loctx); } std::string CustomMarkSharding::ToString() const { diff --git a/torch_xla/csrc/ops/custom_mark_sharding.h b/torch_xla/csrc/ops/custom_mark_sharding.h index 15667cb94f7..7da19bc15fe 100644 --- a/torch_xla/csrc/ops/custom_mark_sharding.h +++ b/torch_xla/csrc/ops/custom_mark_sharding.h @@ -8,13 +8,16 @@ namespace torch_xla { class CustomMarkSharding : public XlaNode { public: // Make a custom call to Sharding. - CustomMarkSharding(const torch::lazy::Value& input); + CustomMarkSharding(const torch::lazy::Value& input, xla::OpSharding sharding); torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; XlaOpVector Lower(LoweringContext* loctx) const override; std::string ToString() const override; + + private: + xla::OpSharding sharding_; }; } // namespace torch_xla diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 935e987afe6..bf9fc6a7a3f 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -443,12 +443,10 @@ void custom_sharding_( } void custom_mark_sharding( - const XLATensorPtr& input, - const std::shared_ptr& sharding_spec) { + const XLATensorPtr& input, xla::OpSharding sharding) { torch::lazy::NodePtr node = torch::lazy::MakeNode( - torch::lazy::MakeNode(input->GetIrValue())); - return {input->CreateFrom(torch::lazy::Value(node, 0)), - torch::lazy::Value(node, 1)}; + torch::lazy::MakeNode(input->GetIrValue(), sharding)); + // TODO (@wonjoo) what do I return here? 
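  // One possible answer, sketched here rather than taken from this patch:
  // mirror custom_sharding_() above and keep the function void, with the
  // side effect on `input` as the whole contract. The _xla_mark_sharding_custom_op
  // binding added alongside this already ignores any return value, so e.g.
  //
  //   input->SetInPlaceIrValue(torch::lazy::Value(node, 0));
  //   // ...plus re-attaching a ShardingSpec built from `sharding` and the
  //   // tensor's global shape, as the existing _xla_mark_sharding binding does.
  //
  // SetInPlaceIrValue is assumed from the XLATensor API and is not verified
  // against this revision; the ShardingSpec construction is deliberately left out.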
} XLATensorPtr get_dimensions_size(const XLATensorPtr& input, diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 1ac71c9b4f8..0c28ee2b9f4 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -58,7 +58,7 @@ void custom_sharding_(const XLATensorPtr& input, const std::shared_ptr& spec); void custom_mark_sharding(const XLATensorPtr& input, - const std::shared_ptr& spec); + xla::OpSharding sharding); XLATensorPtr get_dimensions_size(const XLATensorPtr& input, std::vector dimensions); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 774f65a91e2..150bdf8be9a 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1228,9 +1228,24 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) { {input}, ShapeHelper::ShapeOfXlaOp(input)); } -xla::XlaOp BuildCustomMarkSharding(const xla::XlaOp& input) { - return xla::CustomCall(input.builder(), /*call_target_name=*/"MarkSharding", - {input}, ShapeHelper::ShapeOfXlaOp(input)); +xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, const xla::XlaOp& input, xla::OpSharding sharding) { + auto new_sharding_spec = std::make_shared( + sharding, MakeShapeWithDeviceLayout( + ShapeHelper::ShapeOfXlaOp(input), + static_cast(device.type()))); + + // For Non DeviceData IR values, we directly attach the sharding spec + // to the xtensor. + const DeviceData* device_data_node = nullptr; + if (xtensor->CurrentIrValue()) { + device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); + if (!device_data_node) { + tensor_methods::custom_sharding_(xtensor, new_sharding_spec); + return; + } + + // TODO move rest of `xla/torch_xla/csrc/init_python_bindings.cpp::_xla_mark_sharding`. + // Note to self: `_xla_mark_sharding` works with XLATensorPtr directly, as opposed to XlaOp here. 
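  // A sketch of how this might eventually split, assuming the XLATensor /
  // ShardingSpec bookkeeping above moves back into tensor_methods or the Python
  // binding (it cannot stay here as written: `xtensor` is not in scope and this
  // function must return an xla::XlaOp). At this level the lowering would then
  // only need to emit the sharding custom call, along the lines of
  // BuildCustomSharding() above, e.g.
  //
  //   xla::XlaScopedShardingAssignment assign(input.builder(), sharding);
  //   return xla::CustomCall(input.builder(), /*call_target_name=*/"Sharding",
  //                          {input}, ShapeHelper::ShapeOfXlaOp(input));
  //
  // XlaScopedShardingAssignment is assumed from the XLA builder API; this is
  // illustrative only, not the author's final implementation.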
} } // namespace torch_xla diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 71ac1ef1a02..4ddaf74d1a3 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -150,7 +150,7 @@ xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p, xla::XlaOp BuildCustomSharding(const xla::XlaOp& input); -xla::XlaOp BuildCustomMarkSharding(const xla::XlaOp& input); +xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, const xla::XlaOp& input, xla::OpSharding sharding); } // namespace torch_xla From 7891b42c66ab5936e70200684b4810da1e138c74 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Fri, 20 Oct 2023 19:35:28 +0000 Subject: [PATCH 03/14] Rebase with master and run linter --- torch_xla/csrc/init_python_bindings.cpp | 575 ++++++++++---------- torch_xla/csrc/ops/custom_mark_sharding.cpp | 15 +- torch_xla/csrc/ops/custom_mark_sharding.h | 4 +- torch_xla/csrc/tensor_methods.cpp | 3 +- torch_xla/csrc/tensor_methods.h | 3 +- torch_xla/csrc/xla_lower_util.cpp | 18 +- torch_xla/csrc/xla_lower_util.h | 4 +- 7 files changed, 319 insertions(+), 303 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a1aea978c9d..6861b3812aa 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -968,27 +968,28 @@ void InitXlaModuleBindings(py::module m) { ShardingUtil::ShardingType(sharding_type)), global_shape, minibatch); })); - m.def("_xla_tensors_from_aten", - [](const std::vector& tensors, - const std::vector& devices, - const std::optional>& - shardings) { - std::vector result; - { - NoGilSection nogil; - std::vector xla_tensors = - GetXlaTensorsFromAten(tensors, devices, shardings); - result.reserve(xla_tensors.size()); - for (size_t i = 0; i < xla_tensors.size(); ++i) { - result.push_back(torch::autograd::make_variable( - xla_tensors[i], - /*requires_grad=*/tensors.at(i).requires_grad())); - } + m.def( + "_xla_tensors_from_aten", + [](const std::vector& tensors, + const std::vector& devices, + const std::optional>& + shardings) { + std::vector result; + { + NoGilSection nogil; + std::vector xla_tensors = + GetXlaTensorsFromAten(tensors, devices, shardings); + result.reserve(xla_tensors.size()); + for (size_t i = 0; i < xla_tensors.size(); ++i) { + result.push_back(torch::autograd::make_variable( + xla_tensors[i], + /*requires_grad=*/tensors.at(i).requires_grad())); } - return result; - }, - py::arg("tensors"), py::arg("devices"), - py::arg("shardings") = py::none()); + } + return result; + }, + py::arg("tensors"), py::arg("devices"), + py::arg("shardings") = py::none()); m.def("_xla_get_cpu_tensors", [](const std::vector& tensors) { std::vector result; { @@ -1288,45 +1289,51 @@ void InitXlaModuleBindings(py::module m) { } return list; }); - m.def("_xla_set_rng_seed", - [](uint64_t seed, const std::string& device) { - SetRngSeed(seed, device); - }, - py::arg("seed") = 101, py::arg("device") = ""); - m.def("_xla_get_rng_seed", - [](const std::string& device) { return GetRngSeed(device); }, - py::arg("device") = ""); - m.def("_xla_sync_multi", - [](const std::vector& tensors, - const std::vector& devices, bool wait, - bool sync_xla_data) { - NoGilSection nogil; - SyncTensors(tensors, devices, wait, sync_xla_data); - }, - py::arg("tensors"), py::arg("devices"), py::arg("wait") = true, - py::arg("sync_xla_data") = true); - m.def("_xla_warm_up_cache", - [](const std::vector& tensors, - const std::vector& devices) { - 
NoGilSection nogil; - SyncTensors(tensors, devices, /*wait=*/false, /*sync_xla_data=*/false, - /*warm_up_cache_only=*/true); - }, - py::arg("tensors"), py::arg("devices")); - m.def("_xla_sync_live_tensors", - [](const std::string& device, const std::vector& devices, - bool wait) { - NoGilSection nogil; - SyncLiveTensors(device, devices, wait); - }, - py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); - m.def("_xla_step_marker", - [](const std::string& device, const std::vector& devices, - bool wait) { - NoGilSection nogil; - StepMarker(device, devices, wait); - }, - py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); + m.def( + "_xla_set_rng_seed", + [](uint64_t seed, const std::string& device) { + SetRngSeed(seed, device); + }, + py::arg("seed") = 101, py::arg("device") = ""); + m.def( + "_xla_get_rng_seed", + [](const std::string& device) { return GetRngSeed(device); }, + py::arg("device") = ""); + m.def( + "_xla_sync_multi", + [](const std::vector& tensors, + const std::vector& devices, bool wait, + bool sync_xla_data) { + NoGilSection nogil; + SyncTensors(tensors, devices, wait, sync_xla_data); + }, + py::arg("tensors"), py::arg("devices"), py::arg("wait") = true, + py::arg("sync_xla_data") = true); + m.def( + "_xla_warm_up_cache", + [](const std::vector& tensors, + const std::vector& devices) { + NoGilSection nogil; + SyncTensors(tensors, devices, /*wait=*/false, /*sync_xla_data=*/false, + /*warm_up_cache_only=*/true); + }, + py::arg("tensors"), py::arg("devices")); + m.def( + "_xla_sync_live_tensors", + [](const std::string& device, const std::vector& devices, + bool wait) { + NoGilSection nogil; + SyncLiveTensors(device, devices, wait); + }, + py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); + m.def( + "_xla_step_marker", + [](const std::string& device, const std::vector& devices, + bool wait) { + NoGilSection nogil; + StepMarker(device, devices, wait); + }, + py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); m.def("_get_stablehlo", [](const std::vector& tensors, const std::string& device, const std::vector& devices, @@ -1363,18 +1370,19 @@ void InitXlaModuleBindings(py::module m) { } return retlist; }); - m.def("_xla_wait_device_ops", - [](const std::vector& devices) { - NoGilSection nogil; - XLAGraphExecutor::Get()->WaitDeviceOps(devices); - if (UseVirtualDevice()) { - std::vector spmd_device = {"SPMD:0"}; - runtime::GetComputationClient()->WaitDeviceOps(spmd_device); - } else { - runtime::GetComputationClient()->WaitDeviceOps(devices); - } - }, - py::arg("devices")); + m.def( + "_xla_wait_device_ops", + [](const std::vector& devices) { + NoGilSection nogil; + XLAGraphExecutor::Get()->WaitDeviceOps(devices); + if (UseVirtualDevice()) { + std::vector spmd_device = {"SPMD:0"}; + runtime::GetComputationClient()->WaitDeviceOps(spmd_device); + } else { + runtime::GetComputationClient()->WaitDeviceOps(devices); + } + }, + py::arg("devices")); m.def("_xla_counter_names", []() { auto counter_names = torch::lazy::GetCounterNames(); auto xla_counter_names = runtime::metrics::GetCounterNames(); @@ -1439,21 +1447,23 @@ void InitXlaModuleBindings(py::module m) { torch::lazy::MetricsArena::Get()->ResetMetrics(); runtime::metrics::ClearMetrics(); }); - m.def("_xla_tensors_report", - [](size_t nodes_threshold, const std::string& device) { - return GetLiveTensorsReport(nodes_threshold, device); - }, - py::arg("nodes_threshold") = 100, py::arg("device") = ""); + m.def( + "_xla_tensors_report", + [](size_t nodes_threshold, 
const std::string& device) { + return GetLiveTensorsReport(nodes_threshold, device); + }, + py::arg("nodes_threshold") = 100, py::arg("device") = ""); m.def("_xla_memory_info", [](const std::string& device) -> py::object { return GetMemoryInfo(device); }); - m.def("_xla_set_use_full_mat_mul_precision", - [](bool use_full_mat_mul_precision) { - XlaHelpers::set_mat_mul_precision( - use_full_mat_mul_precision ? xla::PrecisionConfig::HIGHEST - : xla::PrecisionConfig::DEFAULT); - }, - py::arg("use_full_mat_mul_precision") = true); + m.def( + "_xla_set_use_full_mat_mul_precision", + [](bool use_full_mat_mul_precision) { + XlaHelpers::set_mat_mul_precision(use_full_mat_mul_precision + ? xla::PrecisionConfig::HIGHEST + : xla::PrecisionConfig::DEFAULT); + }, + py::arg("use_full_mat_mul_precision") = true); py::class_(m, "XlaBuilder"); py::class_(m, "XlaOp"); @@ -1631,7 +1641,7 @@ void InitXlaModuleBindings(py::module m) { XLA_CHECK(UseVirtualDevice()) << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; XLATensorPtr xtensor = bridge::GetXlaTensor(input); - tensor_methods::custom_mark_sharding(input, sharding); + tensor_methods::custom_mark_sharding(xtensor, sharding); } m.def("_xla_clear_sharding", [](const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); @@ -1643,33 +1653,32 @@ void InitXlaModuleBindings(py::module m) { }); m.def("_get_xla_sharding_specs", [](const std::vector& tensors) -> std::vector { - tsl::profiler::TraceMe activity("_get_xla_sharding_specs", - tsl::profiler::TraceMeLevel::kInfo); - TORCH_LAZY_TIMED("_get_xla_sharding_specs"); - std::vector sharding_specs; - sharding_specs.reserve(tensors.size()); - for (const at::Tensor& tensor : tensors) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); - XLATensor::ShardingSpecPtr sharding_spec = - xtensor ? xtensor->sharding_spec() : nullptr; - if (sharding_spec != nullptr) { - sharding_specs.push_back( - xla::HloSharding::FromProto(sharding_spec->sharding) - ->ToString()); - } else { - sharding_specs.push_back(""); - } - } - return sharding_specs; + tsl::profiler::TraceMe activity("_get_xla_sharding_specs", + tsl::profiler::TraceMeLevel::kInfo); + TORCH_LAZY_TIMED("_get_xla_sharding_specs"); + std::vector sharding_specs; + sharding_specs.reserve(tensors.size()); + for (const at::Tensor& tensor : tensors) { + XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensor::ShardingSpecPtr sharding_spec = + xtensor ? xtensor->sharding_spec() : nullptr; + if (sharding_spec != nullptr) { + sharding_specs.push_back( + xla::HloSharding::FromProto(sharding_spec->sharding)->ToString()); + } else { + sharding_specs.push_back(""); + } + } + return sharding_specs; }); m.def("_get_xla_sharding_type", [](const at::Tensor& input) -> std::optional { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto sharding_spec = xtensor->sharding_spec(); - if (sharding_spec != nullptr) { - return ShardingUtil::GetShardingType(sharding_spec->sharding); - } - return std::nullopt; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + auto sharding_spec = xtensor->sharding_spec(); + if (sharding_spec != nullptr) { + return ShardingUtil::GetShardingType(sharding_spec->sharding); + } + return std::nullopt; }); // Reassemble the CPU shards into a global tensor. 
A new sharded tensor is // created from the local shards with the provided sharding annotation @@ -1745,33 +1754,31 @@ void InitXlaModuleBindings(py::module m) { m.def("_get_local_shards", [](const at::Tensor& input) -> std::tuple, std::vector> { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - XLA_CHECK(xtensor->GetXlaData() != nullptr) - << "Shard data is not available"; - XLA_CHECK(xtensor->sharding_spec() != nullptr) - << "Tensor is not sharded"; - XLA_CHECK(UseVirtualDevice()) - << "Virtual device must be enabled to use _get_local_shards"; - auto handle = - std::dynamic_pointer_cast( - xtensor->GetXlaData()); - std::vector shard_handles = - runtime::GetComputationClient()->GetDataShards(handle); - std::vector shards; - std::vector str_devices; - shards.reserve(shard_handles.size()); - str_devices.reserve(shard_handles.size()); - // Tansfer shards from the device and create cpu tensors. - for (const runtime::ComputationClient::DataPtr shard_handle : - shard_handles) { - shards.push_back( - XlaDataToTensors( - {shard_handle}, - TensorTypeFromXlaType(shard_handle->shape().element_type())) - .front()); - str_devices.push_back(shard_handle->device()); - } - return std::make_tuple(shards, str_devices); + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLA_CHECK(xtensor->GetXlaData() != nullptr) + << "Shard data is not available"; + XLA_CHECK(xtensor->sharding_spec() != nullptr) << "Tensor is not sharded"; + XLA_CHECK(UseVirtualDevice()) + << "Virtual device must be enabled to use _get_local_shards"; + auto handle = std::dynamic_pointer_cast( + xtensor->GetXlaData()); + std::vector shard_handles = + runtime::GetComputationClient()->GetDataShards(handle); + std::vector shards; + std::vector str_devices; + shards.reserve(shard_handles.size()); + str_devices.reserve(shard_handles.size()); + // Tansfer shards from the device and create cpu tensors. + for (const runtime::ComputationClient::DataPtr shard_handle : + shard_handles) { + shards.push_back( + XlaDataToTensors( + {shard_handle}, + TensorTypeFromXlaType(shard_handle->shape().element_type())) + .front()); + str_devices.push_back(shard_handle->device()); + } + return std::make_tuple(shards, str_devices); }); // For each local shard, returns the tuple: // (replica_id: int, indices: Union[List[Slice], Ellipsis]), @@ -1783,50 +1790,47 @@ void InitXlaModuleBindings(py::module m) { // of the shards returned from `_get_local_shards`. 
m.def("_get_local_shard_replica_and_indices", [](const at::Tensor& input) -> std::vector> { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - XLA_CHECK(xtensor->sharding_spec() != nullptr) - << "Tensor is not sharded"; - auto handle = - std::dynamic_pointer_cast( - xtensor->GetXlaData()); - auto shards = runtime::GetComputationClient()->GetDataShards(handle); - std::vector shard_devices; - for (auto& shard : shards) { - shard_devices.push_back(shard->device()); - } - auto sharding_spec = xtensor->sharding_spec(); - auto sharding = xtensor->sharding_spec()->sharding; - auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); - auto replica_and_indices = - ShardingUtil::GetShardReplicaAndIndicesForDevices( - shard_shape, input.sizes().vec(), sharding, shard_devices); - - // Convert each vector to List[py::slice] or py::ellipsis - std::vector> result; - result.reserve(shard_devices.size()); - for (auto& device_replica_and_indices : replica_and_indices) { - auto& replica_id = device_replica_and_indices.first; - auto& indices = device_replica_and_indices.second; - XLA_CHECK(indices.size() > 0) - << "Unexpected empty shard indices for tensor " << input; - if (indices[0].is_ellipsis()) { - result.push_back(std::make_pair(replica_id, py::ellipsis())); - } else { - std::vector index_slices; - for (auto& tensor_index : indices) { - XLA_CHECK(tensor_index.is_slice()) - << "Unexpected TensorIndex type: " << tensor_index; - auto slice = tensor_index.slice(); - ssize_t start = slice.start().expect_int(); - ssize_t stop = slice.stop().expect_int(); - ssize_t step = slice.step().expect_int(); - index_slices.push_back(py::slice(start, stop, step)); - } - result.push_back( - std::make_pair(replica_id, py::cast(index_slices))); - } - } - return result; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLA_CHECK(xtensor->sharding_spec() != nullptr) << "Tensor is not sharded"; + auto handle = std::dynamic_pointer_cast( + xtensor->GetXlaData()); + auto shards = runtime::GetComputationClient()->GetDataShards(handle); + std::vector shard_devices; + for (auto& shard : shards) { + shard_devices.push_back(shard->device()); + } + auto sharding_spec = xtensor->sharding_spec(); + auto sharding = xtensor->sharding_spec()->sharding; + auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); + auto replica_and_indices = + ShardingUtil::GetShardReplicaAndIndicesForDevices( + shard_shape, input.sizes().vec(), sharding, shard_devices); + + // Convert each vector to List[py::slice] or py::ellipsis + std::vector> result; + result.reserve(shard_devices.size()); + for (auto& device_replica_and_indices : replica_and_indices) { + auto& replica_id = device_replica_and_indices.first; + auto& indices = device_replica_and_indices.second; + XLA_CHECK(indices.size() > 0) + << "Unexpected empty shard indices for tensor " << input; + if (indices[0].is_ellipsis()) { + result.push_back(std::make_pair(replica_id, py::ellipsis())); + } else { + std::vector index_slices; + for (auto& tensor_index : indices) { + XLA_CHECK(tensor_index.is_slice()) + << "Unexpected TensorIndex type: " << tensor_index; + auto slice = tensor_index.slice(); + ssize_t start = slice.start().expect_int(); + ssize_t stop = slice.stop().expect_int(); + ssize_t step = slice.step().expect_int(); + index_slices.push_back(py::slice(start, stop, step)); + } + result.push_back(std::make_pair(replica_id, py::cast(index_slices))); + } + } + return result; }); // Load a list of local shards into an explicitly-sharded tensor. 
A shard must // be provided for each device. @@ -1863,26 +1867,25 @@ void InitXlaModuleBindings(py::module m) { bool choose_faster_windowed_einsum = false, bool unroll_windowed_einsum = false, bool bidirectional_windowed_einsum = false) -> std::string { - xla::HloModuleConfig config; - config.set_use_spmd_partitioning(true); - config.set_replica_count(num_replicas); - config.set_num_partitions(num_devices); - - std::string hlo_text = - GetTensorsHloGraph(tensors, EmitMode::kHloReadable); - auto hlo_module_error = - xla::ParseAndReturnUnverifiedModule(hlo_text, config); - XLA_CHECK_OK(hlo_module_error.status()) - << "HLO Module loading failed: " << hlo_module_error.status(); - - auto module = std::move(hlo_module_error.value()); - xla::HloModuleProto module_proto = ShardingUtil::SpmdPartitioningPass( - module->ToProto(), num_replicas, num_devices, - conv_halo_exchange_always_on_lhs, choose_faster_windowed_einsum, - unroll_windowed_einsum, bidirectional_windowed_einsum); - module = std::move( - xla::HloModule::CreateFromProto(module_proto, config).value()); - return module->ToString(); + xla::HloModuleConfig config; + config.set_use_spmd_partitioning(true); + config.set_replica_count(num_replicas); + config.set_num_partitions(num_devices); + + std::string hlo_text = GetTensorsHloGraph(tensors, EmitMode::kHloReadable); + auto hlo_module_error = + xla::ParseAndReturnUnverifiedModule(hlo_text, config); + XLA_CHECK_OK(hlo_module_error.status()) + << "HLO Module loading failed: " << hlo_module_error.status(); + + auto module = std::move(hlo_module_error.value()); + xla::HloModuleProto module_proto = ShardingUtil::SpmdPartitioningPass( + module->ToProto(), num_replicas, num_devices, + conv_halo_exchange_always_on_lhs, choose_faster_windowed_einsum, + unroll_windowed_einsum, bidirectional_windowed_einsum); + module = std::move( + xla::HloModule::CreateFromProto(module_proto, config).value()); + return module->ToString(); }); m.def("_is_placecholder", [](at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); @@ -1894,32 +1897,36 @@ void InitXlaModuleBindings(py::module m) { InitXlaBackend(); }); m.def("_set_ir_debug", - [](bool ir_debug) { FLAGS_torch_lazy_ir_debug = ir_debug; }); - m.def("_get_ir_debug", []() { return FLAGS_torch_lazy_ir_debug; }); + [](bool ir_debug) { + FLAGS_torch_lazy_ir_debug = ir_debug; }); + m.def("_get_ir_debug", []() { + return FLAGS_torch_lazy_ir_debug; }); m.def("_set_xla_handle_special_scalars", [](bool handle_special_scalars) { FLAGS_torch_lazy_handle_special_scalars = handle_special_scalars; }); m.def("_get_xla_handle_special_scalars", - []() { return FLAGS_torch_lazy_handle_special_scalars; }); + []() { + return FLAGS_torch_lazy_handle_special_scalars; }); m.def("_set_xla_enable_device_data_cache", [](bool enable_device_data_cache) { FLAGS_torch_lazy_enable_device_data_cache = enable_device_data_cache; }); m.def("_get_xla_enable_device_data_cache", - []() { return FLAGS_torch_lazy_enable_device_data_cache; }); + []() { + return FLAGS_torch_lazy_enable_device_data_cache; }); m.def("_replace_xla_tensor", [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& { - return XLANativeFunctions::set_(self, source); + return XLANativeFunctions::set_(self, source); }); m.def("_get_all_reduce_token", [](const std::string& device_str) -> const torch::lazy::Value& { - auto device = GetDeviceOrCurrent(device_str); - return GetAllReduceToken(device); + auto device = GetDeviceOrCurrent(device_str); + return GetAllReduceToken(device); }); 
m.def("_set_all_reduce_token", [](const std::string& device_str, const std::shared_ptr& token) { - auto device = GetDeviceOrCurrent(device_str); - SetAllReduceToken(device, token); + auto device = GetDeviceOrCurrent(device_str); + SetAllReduceToken(device, token); }); BuildProfilerSubmodule(&m); @@ -1927,12 +1934,12 @@ void InitXlaModuleBindings(py::module m) { m.def("_get_tensors_handle", [](const std::vector& tensors) -> std::vector { - std::vector handles; - handles.reserve(tensors.size()); - for (auto& tensor : tensors) { - handles.push_back(bridge::GetXlaTensor(tensor)->GetHandle()); - } - return handles; + std::vector handles; + handles.reserve(tensors.size()); + for (auto& tensor : tensors) { + handles.push_back(bridge::GetXlaTensor(tensor)->GetHandle()); + } + return handles; }); // -------------Dynamo Integration API Start------------------------- @@ -1943,76 +1950,77 @@ void InitXlaModuleBindings(py::module m) { m.def("_get_tensors_xla_device_data_node", [](const std::vector& tensors) -> std::pair, std::vector> { - std::vector tensor_ids; - std::vector ivalues; - std::vector roots; - for (const at::Tensor& tensor : tensors) { - auto xtensor = bridge::TryGetXlaTensor(tensor); - if (xtensor) { - roots.push_back(xtensor->GetIrValue().node.get()); - } - } - auto post_order = torch::lazy::Util::ComputePostOrder(roots); - std::unordered_set data_handles; - - for (const torch::lazy::Node* nodeptr : post_order) { - const auto backend_data = - torch::lazy::getBackend()->GetComputationDataFromNode(nodeptr); - if (!backend_data) { - continue; - } + std::vector tensor_ids; + std::vector ivalues; + std::vector roots; + for (const at::Tensor& tensor : tensors) { + auto xtensor = bridge::TryGetXlaTensor(tensor); + if (xtensor) { + roots.push_back(xtensor->GetIrValue().node.get()); + } + } + auto post_order = torch::lazy::Util::ComputePostOrder(roots); + std::unordered_set data_handles; - // Dedup by handle - torch::lazy::BackendData::Handle handle = backend_data->GetHandle(); - if (!data_handles.insert(handle).second) { - continue; - } - auto* infoptr = - static_cast( - backend_data->info()); - if (infoptr) { - tensor_ids.push_back(infoptr->tensor_id); - } else { - // TODO(JackCaoG): Make sure this device data is actually seed. - tensor_ids.push_back(seed_info_id); - } - at::Tensor tensor = bridge::AtenFromXlaTensor( - torch_xla::XLATensor::Create(backend_data)); - ivalues.emplace_back(tensor); - } - return std::make_pair(tensor_ids, ivalues); + for (const torch::lazy::Node* nodeptr : post_order) { + const auto backend_data = + torch::lazy::getBackend()->GetComputationDataFromNode(nodeptr); + if (!backend_data) { + continue; + } + + // Dedup by handle + torch::lazy::BackendData::Handle handle = backend_data->GetHandle(); + if (!data_handles.insert(handle).second) { + continue; + } + auto* infoptr = + static_cast( + backend_data->info()); + if (infoptr) { + tensor_ids.push_back(infoptr->tensor_id); + } else { + // TODO(JackCaoG): Make sure this device data is actually seed. 
+ tensor_ids.push_back(seed_info_id); + } + at::Tensor tensor = + bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create(backend_data)); + ivalues.emplace_back(tensor); + } + return std::make_pair(tensor_ids, ivalues); }); - m.def("_get_seed_info_id", []() -> int64_t { return seed_info_id; }); + m.def("_get_seed_info_id", []() -> int64_t { + return seed_info_id; }); m.def("_get_base_seed_as_tensor", [](const std::string& device_str) -> at::IValue { - torch::lazy::BackendDevice device = - bridge::AtenDeviceToXlaDevice(c10::Device(device_str)); - return bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create( - XLAGraphExecutor::Get()->GetBaseSeedData(device))); + torch::lazy::BackendDevice device = + bridge::AtenDeviceToXlaDevice(c10::Device(device_str)); + return bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create( + XLAGraphExecutor::Get()->GetBaseSeedData(device))); }); // Return true if value of the tensor requires a computation. m.def("_check_tensor_need_materialization", [](const std::vector& tensors) -> std::vector { - std::vector xtensors; - xtensors.reserve(tensors.size()); - for (const at::Tensor& tensor : tensors) { - xtensors.push_back(bridge::TryGetXlaTensor(tensor)); - } - return check_materialization_helper(xtensors); + std::vector xtensors; + xtensors.reserve(tensors.size()); + for (const at::Tensor& tensor : tensors) { + xtensors.push_back(bridge::TryGetXlaTensor(tensor)); + } + return check_materialization_helper(xtensors); }); // Return true if value of the any tensor in this devicerequires a // computation. m.def("_check_device_tensor_need_materialization", [](const std::string& device_str) -> std::vector { - auto opt_device = GetOptionalDevice(device_str); - std::vector xtensors = - XLAGraphExecutor::Get()->GetLiveTensors( - opt_device ? &opt_device.value() : nullptr); - return check_materialization_helper(xtensors); + auto opt_device = GetOptionalDevice(device_str); + std::vector xtensors = + XLAGraphExecutor::Get()->GetLiveTensors(opt_device ? &opt_device.value() + : nullptr); + return check_materialization_helper(xtensors); }); m.def("_get_graph_hash", [](const std::vector& tensors) { @@ -2038,24 +2046,23 @@ void InitXlaModuleBindings(py::module m) { [](const std::string& hash_str, const std::vector& graph_inputs) -> std::vector { - XLA_CHECK(hash_str.size() == sizeof(torch::lazy::hash_t)); - torch::lazy::hash_t hash = *(torch::lazy::hash_t*)(hash_str.c_str()); - // Device will be Virtual device if SPMD is enabled. - torch::lazy::BackendDevice device = - torch_xla::bridge::GetCurrentDevice(); - auto results = XLAGraphExecutor::Get()->ExecuteComputationWithBarrier( - hash, graph_inputs, device); - std::vector retlist; - { - TORCH_LAZY_TIMED("RunCachedGraphOutputData"); - // Convert result back to at::tensor - for (const auto& data : results) { - XLATensorPtr xla_tensor = torch_xla::XLATensor::Create(data); - retlist.push_back(bridge::AtenFromXlaTensor(xla_tensor)); - } - } + XLA_CHECK(hash_str.size() == sizeof(torch::lazy::hash_t)); + torch::lazy::hash_t hash = *(torch::lazy::hash_t*)(hash_str.c_str()); + // Device will be Virtual device if SPMD is enabled. 
+ torch::lazy::BackendDevice device = torch_xla::bridge::GetCurrentDevice(); + auto results = XLAGraphExecutor::Get()->ExecuteComputationWithBarrier( + hash, graph_inputs, device); + std::vector retlist; + { + TORCH_LAZY_TIMED("RunCachedGraphOutputData"); + // Convert result back to at::tensor + for (const auto& data : results) { + XLATensorPtr xla_tensor = torch_xla::XLATensor::Create(data); + retlist.push_back(bridge::AtenFromXlaTensor(xla_tensor)); + } + } - return retlist; + return retlist; }); // -------------Dynamo Integration API End------------------------- } diff --git a/torch_xla/csrc/ops/custom_mark_sharding.cpp b/torch_xla/csrc/ops/custom_mark_sharding.cpp index 5e2e2e77a0c..bb3d371c0ae 100644 --- a/torch_xla/csrc/ops/custom_mark_sharding.cpp +++ b/torch_xla/csrc/ops/custom_mark_sharding.cpp @@ -6,19 +6,24 @@ namespace torch_xla { -CustomMarkSharding::CustomMarkSharding(const torch::lazy::Value& input, xla::OpSharding sharding) +CustomMarkSharding::CustomMarkSharding(const torch::lazy::Value& input, + xla::OpSharding sharding) : XlaNode(xla_custom_mark_sharding, {input}, GetXlaShape(input), - /*num_outputs=*/1, torch::lazy::MHash(std::string("MarkSharding"))), + /*num_outputs=*/1, + torch::lazy::MHash(std::string("MarkSharding"))), sharding_(sharding) {} -torch::lazy::NodePtr CustomMarkSharding::Clone(torch::lazy::OpList operands) const { - return torch::lazy::MakeNode(operands.at(0), operands.at(1)); +torch::lazy::NodePtr CustomMarkSharding::Clone( + torch::lazy::OpList operands) const { + return torch::lazy::MakeNode(operands.at(0), + operands.at(1)); } XlaOpVector CustomMarkSharding::Lower(LoweringContext* loctx) const { xla::XlaOp input = loctx->GetOutputOp(operand(0)); xla::XlaOp sharding = loctx->GetOutputOp(operand(1)); - return ReturnOp(BuildCustomMarkSharding(loctx->device(), input, sharding), loctx); + return ReturnOp(BuildCustomMarkSharding(loctx->device(), input, sharding), + loctx); } std::string CustomMarkSharding::ToString() const { diff --git a/torch_xla/csrc/ops/custom_mark_sharding.h b/torch_xla/csrc/ops/custom_mark_sharding.h index 7da19bc15fe..0870a99112c 100644 --- a/torch_xla/csrc/ops/custom_mark_sharding.h +++ b/torch_xla/csrc/ops/custom_mark_sharding.h @@ -16,8 +16,8 @@ class CustomMarkSharding : public XlaNode { std::string ToString() const override; - private: - xla::OpSharding sharding_; + private: + xla::OpSharding sharding_; }; } // namespace torch_xla diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index bf9fc6a7a3f..59407e17044 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -442,8 +442,7 @@ void custom_sharding_( input->SetShardingSpec(*sharding_spec); } -void custom_mark_sharding( - const XLATensorPtr& input, xla::OpSharding sharding) { +void custom_mark_sharding(const XLATensorPtr& input, xla::OpSharding sharding) { torch::lazy::NodePtr node = torch::lazy::MakeNode( torch::lazy::MakeNode(input->GetIrValue(), sharding)); // TODO (@wonjoo) what do I return here? 
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 0c28ee2b9f4..0f2bf21f0c4 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -57,8 +57,7 @@ std::pair collective_permute( void custom_sharding_(const XLATensorPtr& input, const std::shared_ptr& spec); -void custom_mark_sharding(const XLATensorPtr& input, - xla::OpSharding sharding); +void custom_mark_sharding(const XLATensorPtr& input, xla::OpSharding sharding); XLATensorPtr get_dimensions_size(const XLATensorPtr& input, std::vector dimensions); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 150bdf8be9a..eef41c83980 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1228,11 +1228,13 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) { {input}, ShapeHelper::ShapeOfXlaOp(input)); } -xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, const xla::XlaOp& input, xla::OpSharding sharding) { +xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, + const xla::XlaOp& input, + xla::OpSharding sharding) { auto new_sharding_spec = std::make_shared( - sharding, MakeShapeWithDeviceLayout( - ShapeHelper::ShapeOfXlaOp(input), - static_cast(device.type()))); + sharding, + MakeShapeWithDeviceLayout(ShapeHelper::ShapeOfXlaOp(input), + static_cast(device.type()))); // For Non DeviceData IR values, we directly attach the sharding spec // to the xtensor. @@ -1244,8 +1246,10 @@ xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, con return; } - // TODO move rest of `xla/torch_xla/csrc/init_python_bindings.cpp::_xla_mark_sharding`. - // Note to self: `_xla_mark_sharding` works with XLATensorPtr directly, as opposed to XlaOp here. -} + // TODO move rest of + // `xla/torch_xla/csrc/init_python_bindings.cpp::_xla_mark_sharding`. Note + // to self: `_xla_mark_sharding` works with XLATensorPtr directly, as + // opposed to XlaOp here. 
+ } } // namespace torch_xla diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 4ddaf74d1a3..ae7ed8ff06e 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -150,7 +150,9 @@ xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p, xla::XlaOp BuildCustomSharding(const xla::XlaOp& input); -xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, const xla::XlaOp& input, xla::OpSharding sharding); +xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, + const xla::XlaOp& input, + xla::OpSharding sharding); } // namespace torch_xla From 6aeeecffa670186f8d1a367410fb21c84493c6f3 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Tue, 24 Oct 2023 20:48:10 +0000 Subject: [PATCH 04/14] Update unit tests --- test/spmd/test_dynamo_spmd.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index cc3edc4f79b..7dc839979ee 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -171,18 +171,23 @@ def test_dynamo_input_sharding_threashold(self): else: del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] - def test_mark_sharding_after_compile(self): + def test_mark_sharding_inside_compile(self): + + def fn_simple(x): + y = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], + dtype=torch.float, + device=xm.xla_device()) + ys = xs.mark_sharding(y, self._get_mesh((1, self.n_devices)), (0, 1)) + + return x + ys + device = xm.xla_device() - linear = SimpleLinear().to(device) - linear.eval() - xla_x = torch.randn(1, 128, device=device) - xs.mark_sharding(linear.fc2.weight, self._get_mesh((1, self.n_devices)), - (1, 0)) - xla_res = linear(xla_x) + x_xla = torch.zeros((1, 8)).to(device) + xla_res = fn_simple(x_xla) xm.mark_step() - dynamo_linear = torch.compile(linear, backend="openxla") - dynamo_res = dynamo_linear(xla_x) + dynamo_linear = torch.compile(fn_simple, backend="openxla") + dynamo_res = dynamo_linear(x_xla) torch.allclose(xla_res.cpu(), dynamo_res.cpu()) From dc19b9bcf8926e99731ef2b61e4c28cdaa227f89 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Tue, 24 Oct 2023 22:22:35 +0000 Subject: [PATCH 05/14] Refine custom marking sharding op --- torch_xla/csrc/aten_xla_type.cpp | 7 +- torch_xla/csrc/init_python_bindings.cpp | 410 ++++++++++---------- torch_xla/csrc/ops/custom_mark_sharding.cpp | 7 +- torch_xla/csrc/ops/custom_mark_sharding.h | 5 +- torch_xla/csrc/ops/xla_ops.cpp | 1 + torch_xla/csrc/ops/xla_ops.h | 1 + torch_xla/csrc/tensor_methods.cpp | 72 +++- torch_xla/csrc/xla_lower_util.cpp | 25 +- torch_xla/csrc/xla_lower_util.h | 2 +- 9 files changed, 288 insertions(+), 242 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 49780fc8811..dfceec0353a 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -114,10 +114,9 @@ void CheckSubOperandTypes(at::ScalarType type1, at::ScalarType type2) { c10::optional PromoteIntegralType( at::ScalarType src_dtype, const c10::optional& opt_dtype) { - return opt_dtype.has_value() - ? opt_dtype.value() - : at::isIntegralType(src_dtype, /*includeBool=*/true) ? at::kLong - : opt_dtype; + return opt_dtype.has_value() ? opt_dtype.value() + : at::isIntegralType(src_dtype, /*includeBool=*/true) ? 
at::kLong + : opt_dtype; } bool IsTypeWithLargerRangeThanLong(torch::ScalarType dtype) { diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 6861b3812aa..7d380f725b6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1635,14 +1635,11 @@ void InitXlaModuleBindings(py::module m) { // Register sharded tensor data. XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); }); - m.def("_xla_mark_sharding_custom_op", [](const at::Tensor& input, - xla::OpSharding sharding) { - TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - tensor_methods::custom_mark_sharding(xtensor, sharding); - } + m.def("_xla_mark_sharding_custom_op", + [](const at::Tensor& input, xla::OpSharding sharding) { + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + tensor_methods::custom_mark_sharding(xtensor, sharding); + }); m.def("_xla_clear_sharding", [](const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); xtensor->ClearShardingSpec(); @@ -1653,32 +1650,33 @@ void InitXlaModuleBindings(py::module m) { }); m.def("_get_xla_sharding_specs", [](const std::vector& tensors) -> std::vector { - tsl::profiler::TraceMe activity("_get_xla_sharding_specs", - tsl::profiler::TraceMeLevel::kInfo); - TORCH_LAZY_TIMED("_get_xla_sharding_specs"); - std::vector sharding_specs; - sharding_specs.reserve(tensors.size()); - for (const at::Tensor& tensor : tensors) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); - XLATensor::ShardingSpecPtr sharding_spec = - xtensor ? xtensor->sharding_spec() : nullptr; - if (sharding_spec != nullptr) { - sharding_specs.push_back( - xla::HloSharding::FromProto(sharding_spec->sharding)->ToString()); - } else { - sharding_specs.push_back(""); - } - } - return sharding_specs; + tsl::profiler::TraceMe activity("_get_xla_sharding_specs", + tsl::profiler::TraceMeLevel::kInfo); + TORCH_LAZY_TIMED("_get_xla_sharding_specs"); + std::vector sharding_specs; + sharding_specs.reserve(tensors.size()); + for (const at::Tensor& tensor : tensors) { + XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensor::ShardingSpecPtr sharding_spec = + xtensor ? xtensor->sharding_spec() : nullptr; + if (sharding_spec != nullptr) { + sharding_specs.push_back( + xla::HloSharding::FromProto(sharding_spec->sharding) + ->ToString()); + } else { + sharding_specs.push_back(""); + } + } + return sharding_specs; }); m.def("_get_xla_sharding_type", [](const at::Tensor& input) -> std::optional { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto sharding_spec = xtensor->sharding_spec(); - if (sharding_spec != nullptr) { - return ShardingUtil::GetShardingType(sharding_spec->sharding); - } - return std::nullopt; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + auto sharding_spec = xtensor->sharding_spec(); + if (sharding_spec != nullptr) { + return ShardingUtil::GetShardingType(sharding_spec->sharding); + } + return std::nullopt; }); // Reassemble the CPU shards into a global tensor. 
A new sharded tensor is // created from the local shards with the provided sharding annotation @@ -1754,31 +1752,33 @@ void InitXlaModuleBindings(py::module m) { m.def("_get_local_shards", [](const at::Tensor& input) -> std::tuple, std::vector> { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - XLA_CHECK(xtensor->GetXlaData() != nullptr) - << "Shard data is not available"; - XLA_CHECK(xtensor->sharding_spec() != nullptr) << "Tensor is not sharded"; - XLA_CHECK(UseVirtualDevice()) - << "Virtual device must be enabled to use _get_local_shards"; - auto handle = std::dynamic_pointer_cast( - xtensor->GetXlaData()); - std::vector shard_handles = - runtime::GetComputationClient()->GetDataShards(handle); - std::vector shards; - std::vector str_devices; - shards.reserve(shard_handles.size()); - str_devices.reserve(shard_handles.size()); - // Tansfer shards from the device and create cpu tensors. - for (const runtime::ComputationClient::DataPtr shard_handle : - shard_handles) { - shards.push_back( - XlaDataToTensors( - {shard_handle}, - TensorTypeFromXlaType(shard_handle->shape().element_type())) - .front()); - str_devices.push_back(shard_handle->device()); - } - return std::make_tuple(shards, str_devices); + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLA_CHECK(xtensor->GetXlaData() != nullptr) + << "Shard data is not available"; + XLA_CHECK(xtensor->sharding_spec() != nullptr) + << "Tensor is not sharded"; + XLA_CHECK(UseVirtualDevice()) + << "Virtual device must be enabled to use _get_local_shards"; + auto handle = + std::dynamic_pointer_cast( + xtensor->GetXlaData()); + std::vector shard_handles = + runtime::GetComputationClient()->GetDataShards(handle); + std::vector shards; + std::vector str_devices; + shards.reserve(shard_handles.size()); + str_devices.reserve(shard_handles.size()); + // Tansfer shards from the device and create cpu tensors. + for (const runtime::ComputationClient::DataPtr shard_handle : + shard_handles) { + shards.push_back( + XlaDataToTensors( + {shard_handle}, + TensorTypeFromXlaType(shard_handle->shape().element_type())) + .front()); + str_devices.push_back(shard_handle->device()); + } + return std::make_tuple(shards, str_devices); }); // For each local shard, returns the tuple: // (replica_id: int, indices: Union[List[Slice], Ellipsis]), @@ -1790,47 +1790,50 @@ void InitXlaModuleBindings(py::module m) { // of the shards returned from `_get_local_shards`. 
m.def("_get_local_shard_replica_and_indices", [](const at::Tensor& input) -> std::vector> { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - XLA_CHECK(xtensor->sharding_spec() != nullptr) << "Tensor is not sharded"; - auto handle = std::dynamic_pointer_cast( - xtensor->GetXlaData()); - auto shards = runtime::GetComputationClient()->GetDataShards(handle); - std::vector shard_devices; - for (auto& shard : shards) { - shard_devices.push_back(shard->device()); - } - auto sharding_spec = xtensor->sharding_spec(); - auto sharding = xtensor->sharding_spec()->sharding; - auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); - auto replica_and_indices = - ShardingUtil::GetShardReplicaAndIndicesForDevices( - shard_shape, input.sizes().vec(), sharding, shard_devices); - - // Convert each vector to List[py::slice] or py::ellipsis - std::vector> result; - result.reserve(shard_devices.size()); - for (auto& device_replica_and_indices : replica_and_indices) { - auto& replica_id = device_replica_and_indices.first; - auto& indices = device_replica_and_indices.second; - XLA_CHECK(indices.size() > 0) - << "Unexpected empty shard indices for tensor " << input; - if (indices[0].is_ellipsis()) { - result.push_back(std::make_pair(replica_id, py::ellipsis())); - } else { - std::vector index_slices; - for (auto& tensor_index : indices) { - XLA_CHECK(tensor_index.is_slice()) - << "Unexpected TensorIndex type: " << tensor_index; - auto slice = tensor_index.slice(); - ssize_t start = slice.start().expect_int(); - ssize_t stop = slice.stop().expect_int(); - ssize_t step = slice.step().expect_int(); - index_slices.push_back(py::slice(start, stop, step)); - } - result.push_back(std::make_pair(replica_id, py::cast(index_slices))); - } - } - return result; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLA_CHECK(xtensor->sharding_spec() != nullptr) + << "Tensor is not sharded"; + auto handle = + std::dynamic_pointer_cast( + xtensor->GetXlaData()); + auto shards = runtime::GetComputationClient()->GetDataShards(handle); + std::vector shard_devices; + for (auto& shard : shards) { + shard_devices.push_back(shard->device()); + } + auto sharding_spec = xtensor->sharding_spec(); + auto sharding = xtensor->sharding_spec()->sharding; + auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); + auto replica_and_indices = + ShardingUtil::GetShardReplicaAndIndicesForDevices( + shard_shape, input.sizes().vec(), sharding, shard_devices); + + // Convert each vector to List[py::slice] or py::ellipsis + std::vector> result; + result.reserve(shard_devices.size()); + for (auto& device_replica_and_indices : replica_and_indices) { + auto& replica_id = device_replica_and_indices.first; + auto& indices = device_replica_and_indices.second; + XLA_CHECK(indices.size() > 0) + << "Unexpected empty shard indices for tensor " << input; + if (indices[0].is_ellipsis()) { + result.push_back(std::make_pair(replica_id, py::ellipsis())); + } else { + std::vector index_slices; + for (auto& tensor_index : indices) { + XLA_CHECK(tensor_index.is_slice()) + << "Unexpected TensorIndex type: " << tensor_index; + auto slice = tensor_index.slice(); + ssize_t start = slice.start().expect_int(); + ssize_t stop = slice.stop().expect_int(); + ssize_t step = slice.step().expect_int(); + index_slices.push_back(py::slice(start, stop, step)); + } + result.push_back( + std::make_pair(replica_id, py::cast(index_slices))); + } + } + return result; }); // Load a list of local shards into an explicitly-sharded tensor. 
A shard must // be provided for each device. @@ -1867,25 +1870,26 @@ void InitXlaModuleBindings(py::module m) { bool choose_faster_windowed_einsum = false, bool unroll_windowed_einsum = false, bool bidirectional_windowed_einsum = false) -> std::string { - xla::HloModuleConfig config; - config.set_use_spmd_partitioning(true); - config.set_replica_count(num_replicas); - config.set_num_partitions(num_devices); - - std::string hlo_text = GetTensorsHloGraph(tensors, EmitMode::kHloReadable); - auto hlo_module_error = - xla::ParseAndReturnUnverifiedModule(hlo_text, config); - XLA_CHECK_OK(hlo_module_error.status()) - << "HLO Module loading failed: " << hlo_module_error.status(); - - auto module = std::move(hlo_module_error.value()); - xla::HloModuleProto module_proto = ShardingUtil::SpmdPartitioningPass( - module->ToProto(), num_replicas, num_devices, - conv_halo_exchange_always_on_lhs, choose_faster_windowed_einsum, - unroll_windowed_einsum, bidirectional_windowed_einsum); - module = std::move( - xla::HloModule::CreateFromProto(module_proto, config).value()); - return module->ToString(); + xla::HloModuleConfig config; + config.set_use_spmd_partitioning(true); + config.set_replica_count(num_replicas); + config.set_num_partitions(num_devices); + + std::string hlo_text = + GetTensorsHloGraph(tensors, EmitMode::kHloReadable); + auto hlo_module_error = + xla::ParseAndReturnUnverifiedModule(hlo_text, config); + XLA_CHECK_OK(hlo_module_error.status()) + << "HLO Module loading failed: " << hlo_module_error.status(); + + auto module = std::move(hlo_module_error.value()); + xla::HloModuleProto module_proto = ShardingUtil::SpmdPartitioningPass( + module->ToProto(), num_replicas, num_devices, + conv_halo_exchange_always_on_lhs, choose_faster_windowed_einsum, + unroll_windowed_einsum, bidirectional_windowed_einsum); + module = std::move( + xla::HloModule::CreateFromProto(module_proto, config).value()); + return module->ToString(); }); m.def("_is_placecholder", [](at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); @@ -1897,36 +1901,32 @@ void InitXlaModuleBindings(py::module m) { InitXlaBackend(); }); m.def("_set_ir_debug", - [](bool ir_debug) { - FLAGS_torch_lazy_ir_debug = ir_debug; }); - m.def("_get_ir_debug", []() { - return FLAGS_torch_lazy_ir_debug; }); + [](bool ir_debug) { FLAGS_torch_lazy_ir_debug = ir_debug; }); + m.def("_get_ir_debug", []() { return FLAGS_torch_lazy_ir_debug; }); m.def("_set_xla_handle_special_scalars", [](bool handle_special_scalars) { FLAGS_torch_lazy_handle_special_scalars = handle_special_scalars; }); m.def("_get_xla_handle_special_scalars", - []() { - return FLAGS_torch_lazy_handle_special_scalars; }); + []() { return FLAGS_torch_lazy_handle_special_scalars; }); m.def("_set_xla_enable_device_data_cache", [](bool enable_device_data_cache) { FLAGS_torch_lazy_enable_device_data_cache = enable_device_data_cache; }); m.def("_get_xla_enable_device_data_cache", - []() { - return FLAGS_torch_lazy_enable_device_data_cache; }); + []() { return FLAGS_torch_lazy_enable_device_data_cache; }); m.def("_replace_xla_tensor", [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& { - return XLANativeFunctions::set_(self, source); + return XLANativeFunctions::set_(self, source); }); m.def("_get_all_reduce_token", [](const std::string& device_str) -> const torch::lazy::Value& { - auto device = GetDeviceOrCurrent(device_str); - return GetAllReduceToken(device); + auto device = GetDeviceOrCurrent(device_str); + return GetAllReduceToken(device); }); 
m.def("_set_all_reduce_token", [](const std::string& device_str, const std::shared_ptr& token) { - auto device = GetDeviceOrCurrent(device_str); - SetAllReduceToken(device, token); + auto device = GetDeviceOrCurrent(device_str); + SetAllReduceToken(device, token); }); BuildProfilerSubmodule(&m); @@ -1934,12 +1934,12 @@ void InitXlaModuleBindings(py::module m) { m.def("_get_tensors_handle", [](const std::vector& tensors) -> std::vector { - std::vector handles; - handles.reserve(tensors.size()); - for (auto& tensor : tensors) { - handles.push_back(bridge::GetXlaTensor(tensor)->GetHandle()); - } - return handles; + std::vector handles; + handles.reserve(tensors.size()); + for (auto& tensor : tensors) { + handles.push_back(bridge::GetXlaTensor(tensor)->GetHandle()); + } + return handles; }); // -------------Dynamo Integration API Start------------------------- @@ -1950,77 +1950,76 @@ void InitXlaModuleBindings(py::module m) { m.def("_get_tensors_xla_device_data_node", [](const std::vector& tensors) -> std::pair, std::vector> { - std::vector tensor_ids; - std::vector ivalues; - std::vector roots; - for (const at::Tensor& tensor : tensors) { - auto xtensor = bridge::TryGetXlaTensor(tensor); - if (xtensor) { - roots.push_back(xtensor->GetIrValue().node.get()); - } - } - auto post_order = torch::lazy::Util::ComputePostOrder(roots); - std::unordered_set data_handles; - - for (const torch::lazy::Node* nodeptr : post_order) { - const auto backend_data = - torch::lazy::getBackend()->GetComputationDataFromNode(nodeptr); - if (!backend_data) { - continue; - } + std::vector tensor_ids; + std::vector ivalues; + std::vector roots; + for (const at::Tensor& tensor : tensors) { + auto xtensor = bridge::TryGetXlaTensor(tensor); + if (xtensor) { + roots.push_back(xtensor->GetIrValue().node.get()); + } + } + auto post_order = torch::lazy::Util::ComputePostOrder(roots); + std::unordered_set data_handles; + + for (const torch::lazy::Node* nodeptr : post_order) { + const auto backend_data = + torch::lazy::getBackend()->GetComputationDataFromNode(nodeptr); + if (!backend_data) { + continue; + } - // Dedup by handle - torch::lazy::BackendData::Handle handle = backend_data->GetHandle(); - if (!data_handles.insert(handle).second) { - continue; - } - auto* infoptr = - static_cast( - backend_data->info()); - if (infoptr) { - tensor_ids.push_back(infoptr->tensor_id); - } else { - // TODO(JackCaoG): Make sure this device data is actually seed. - tensor_ids.push_back(seed_info_id); - } - at::Tensor tensor = - bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create(backend_data)); - ivalues.emplace_back(tensor); - } - return std::make_pair(tensor_ids, ivalues); + // Dedup by handle + torch::lazy::BackendData::Handle handle = backend_data->GetHandle(); + if (!data_handles.insert(handle).second) { + continue; + } + auto* infoptr = + static_cast( + backend_data->info()); + if (infoptr) { + tensor_ids.push_back(infoptr->tensor_id); + } else { + // TODO(JackCaoG): Make sure this device data is actually seed. 
+ tensor_ids.push_back(seed_info_id); + } + at::Tensor tensor = bridge::AtenFromXlaTensor( + torch_xla::XLATensor::Create(backend_data)); + ivalues.emplace_back(tensor); + } + return std::make_pair(tensor_ids, ivalues); }); - m.def("_get_seed_info_id", []() -> int64_t { - return seed_info_id; }); + m.def("_get_seed_info_id", []() -> int64_t { return seed_info_id; }); m.def("_get_base_seed_as_tensor", [](const std::string& device_str) -> at::IValue { - torch::lazy::BackendDevice device = - bridge::AtenDeviceToXlaDevice(c10::Device(device_str)); - return bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create( - XLAGraphExecutor::Get()->GetBaseSeedData(device))); + torch::lazy::BackendDevice device = + bridge::AtenDeviceToXlaDevice(c10::Device(device_str)); + return bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create( + XLAGraphExecutor::Get()->GetBaseSeedData(device))); }); // Return true if value of the tensor requires a computation. m.def("_check_tensor_need_materialization", [](const std::vector& tensors) -> std::vector { - std::vector xtensors; - xtensors.reserve(tensors.size()); - for (const at::Tensor& tensor : tensors) { - xtensors.push_back(bridge::TryGetXlaTensor(tensor)); - } - return check_materialization_helper(xtensors); + std::vector xtensors; + xtensors.reserve(tensors.size()); + for (const at::Tensor& tensor : tensors) { + xtensors.push_back(bridge::TryGetXlaTensor(tensor)); + } + return check_materialization_helper(xtensors); }); // Return true if value of the any tensor in this devicerequires a // computation. m.def("_check_device_tensor_need_materialization", [](const std::string& device_str) -> std::vector { - auto opt_device = GetOptionalDevice(device_str); - std::vector xtensors = - XLAGraphExecutor::Get()->GetLiveTensors(opt_device ? &opt_device.value() - : nullptr); - return check_materialization_helper(xtensors); + auto opt_device = GetOptionalDevice(device_str); + std::vector xtensors = + XLAGraphExecutor::Get()->GetLiveTensors( + opt_device ? &opt_device.value() : nullptr); + return check_materialization_helper(xtensors); }); m.def("_get_graph_hash", [](const std::vector& tensors) { @@ -2046,23 +2045,24 @@ void InitXlaModuleBindings(py::module m) { [](const std::string& hash_str, const std::vector& graph_inputs) -> std::vector { - XLA_CHECK(hash_str.size() == sizeof(torch::lazy::hash_t)); - torch::lazy::hash_t hash = *(torch::lazy::hash_t*)(hash_str.c_str()); - // Device will be Virtual device if SPMD is enabled. - torch::lazy::BackendDevice device = torch_xla::bridge::GetCurrentDevice(); - auto results = XLAGraphExecutor::Get()->ExecuteComputationWithBarrier( - hash, graph_inputs, device); - std::vector retlist; - { - TORCH_LAZY_TIMED("RunCachedGraphOutputData"); - // Convert result back to at::tensor - for (const auto& data : results) { - XLATensorPtr xla_tensor = torch_xla::XLATensor::Create(data); - retlist.push_back(bridge::AtenFromXlaTensor(xla_tensor)); - } - } + XLA_CHECK(hash_str.size() == sizeof(torch::lazy::hash_t)); + torch::lazy::hash_t hash = *(torch::lazy::hash_t*)(hash_str.c_str()); + // Device will be Virtual device if SPMD is enabled. 
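+          // With SPMD enabled, GetCurrentDevice() resolves to the virtual SPMD
+          // device; ExecuteComputationWithBarrier is then expected to fan the
+          // cached computation out to the underlying physical devices.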
+ torch::lazy::BackendDevice device = + torch_xla::bridge::GetCurrentDevice(); + auto results = XLAGraphExecutor::Get()->ExecuteComputationWithBarrier( + hash, graph_inputs, device); + std::vector retlist; + { + TORCH_LAZY_TIMED("RunCachedGraphOutputData"); + // Convert result back to at::tensor + for (const auto& data : results) { + XLATensorPtr xla_tensor = torch_xla::XLATensor::Create(data); + retlist.push_back(bridge::AtenFromXlaTensor(xla_tensor)); + } + } - return retlist; + return retlist; }); // -------------Dynamo Integration API End------------------------- } diff --git a/torch_xla/csrc/ops/custom_mark_sharding.cpp b/torch_xla/csrc/ops/custom_mark_sharding.cpp index bb3d371c0ae..af9302590f4 100644 --- a/torch_xla/csrc/ops/custom_mark_sharding.cpp +++ b/torch_xla/csrc/ops/custom_mark_sharding.cpp @@ -7,11 +7,10 @@ namespace torch_xla { CustomMarkSharding::CustomMarkSharding(const torch::lazy::Value& input, - xla::OpSharding sharding) - : XlaNode(xla_custom_mark_sharding, {input}, GetXlaShape(input), + const torch::lazy::Value& sharding) + : XlaNode(xla_custom_mark_sharding, {input, sharding}, GetXlaShape(input), /*num_outputs=*/1, - torch::lazy::MHash(std::string("MarkSharding"))), - sharding_(sharding) {} + torch::lazy::MHash(std::string("MarkSharding"))) {} torch::lazy::NodePtr CustomMarkSharding::Clone( torch::lazy::OpList operands) const { diff --git a/torch_xla/csrc/ops/custom_mark_sharding.h b/torch_xla/csrc/ops/custom_mark_sharding.h index 0870a99112c..32c3c3c1512 100644 --- a/torch_xla/csrc/ops/custom_mark_sharding.h +++ b/torch_xla/csrc/ops/custom_mark_sharding.h @@ -8,16 +8,13 @@ namespace torch_xla { class CustomMarkSharding : public XlaNode { public: // Make a custom call to Sharding. - CustomMarkSharding(const torch::lazy::Value& input, xla::OpSharding sharding); + CustomMarkSharding(const torch::lazy::Value& input, const torch::lazy::Value& sharding); torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; XlaOpVector Lower(LoweringContext* loctx) const override; std::string ToString() const override; - - private: - xla::OpSharding sharding_; }; } // namespace torch_xla diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp index fa106e5849c..7e058573bd0 100644 --- a/torch_xla/csrc/ops/xla_ops.cpp +++ b/torch_xla/csrc/ops/xla_ops.cpp @@ -30,5 +30,6 @@ const OpKindWrapper xla_tensor_data("xla::tensor_data"); const OpKindWrapper xla_unselect("xla::unselect"); const OpKindWrapper xla_update_slice("xla::update_slice"); const OpKindWrapper xla_custom_sharding("xla::custom_sharding"); +const OpKindWrapper xla_custom_mark_sharding("xla::custom_mark_sharding"); } // namespace torch_xla diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h index fa8082b978a..b5dfa31ca4e 100644 --- a/torch_xla/csrc/ops/xla_ops.h +++ b/torch_xla/csrc/ops/xla_ops.h @@ -55,6 +55,7 @@ extern const OpKindWrapper xla_tensor_data; extern const OpKindWrapper xla_unselect; extern const OpKindWrapper xla_update_slice; extern const OpKindWrapper xla_custom_sharding; +extern const OpKindWrapper xla_custom_mark_sharding; } // namespace torch_xla diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 59407e17044..f2e80980c4f 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -39,6 +39,7 @@ #include "torch_xla/csrc/ops/count_nonzero.h" #include "torch_xla/csrc/ops/cumprod.h" #include "torch_xla/csrc/ops/cumsum.h" +#include "torch_xla/csrc/ops/custom_mark_sharding.h" #include 
"torch_xla/csrc/ops/custom_sharding.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/diagonal.h" @@ -140,6 +141,7 @@ #include "torch_xla/csrc/tensor_ops.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/xla_graph_executor.h" +#include "torch_xla/csrc/xla_sharding_util.h" #include "xla/literal_util.h" namespace torch_xla { @@ -443,9 +445,73 @@ void custom_sharding_( } void custom_mark_sharding(const XLATensorPtr& input, xla::OpSharding sharding) { - torch::lazy::NodePtr node = torch::lazy::MakeNode( - torch::lazy::MakeNode(input->GetIrValue(), sharding)); - // TODO (@wonjoo) what do I return here? + // TODO (@wonjoo) Do we need this `sharding` here? + input->SetInPlaceIrValue(torch::lazy::MakeNode( + input->GetIrValue(), input->GetIrValue())); + + TORCH_LAZY_COUNTER("XlaMarkSharding", 1); + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + // XLATensorPtr xtensor = bridge::GetXlaTensor(input); + auto new_sharding_spec = std::make_shared( + sharding, MakeShapeWithDeviceLayout( + input->shape(), + static_cast(input->GetDevice().type()))); + + // For Non DeviceData IR values, we directly attach the sharding spec + // to the xtensor. + const DeviceData* device_data_node = nullptr; + if (input->CurrentIrValue()) { + device_data_node = DeviceData::Cast(input->CurrentIrValue().node.get()); + if (!device_data_node) { + tensor_methods::custom_sharding_(input, new_sharding_spec); + return; + } + } + + // For data, we need to deal with the data transfers between + // host and device. + at::Tensor cpu_tensor; + if (input->CurrentTensorData().has_value()) { + TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); + // When virtual device is enabled for SPMD, we defer the initial + // data transfer to the device and retain the original data on the + // host, until the sharded data transfer. + cpu_tensor = input->CurrentTensorData().value(); + } else { + // A new input tensor is not expected to be sharded. But sometimes, + // the same input is called for sharding annotation over multiple steps, + // in which case we can skip if it's the same sharding; however, if it's + // the same input with a different sharding then we block & ask the user + // to clear the existing sharding first. + auto current_sharding_spec = input->sharding_spec(); + if (current_sharding_spec && (current_sharding_spec->sharding.type() != + xla::OpSharding::REPLICATED)) { + XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, + *current_sharding_spec)) + << "Existing annotation must be cleared first."; + return; + } + + // If the at::Tensor data is not present, we need to re-download the + // tensor from the physical device to CPU. In that case, the value + // must be present on the backend device. + XLA_CHECK((input->CurrentDataHandle() && + input->CurrentDataHandle()->HasValue()) || + device_data_node != nullptr) + << "Cannot shard tensor. Data does not present on any device."; + std::vector xla_tensors{input}; + cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; + } + auto xla_data = CreateTensorsData( + std::vector{cpu_tensor}, + std::vector{new_sharding_spec}, + std::vector{GetVirtualDevice().toString()})[0]; + input->SetXlaData(xla_data); + input->SetShardingSpec(*new_sharding_spec); + + // Register sharded tensor data. 
+ XLAGraphExecutor::Get()->RegisterTensor(input->data()); } XLATensorPtr get_dimensions_size(const XLATensorPtr& input, diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index eef41c83980..c573adeb31e 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1230,26 +1230,9 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) { xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, const xla::XlaOp& input, - xla::OpSharding sharding) { - auto new_sharding_spec = std::make_shared( - sharding, - MakeShapeWithDeviceLayout(ShapeHelper::ShapeOfXlaOp(input), - static_cast(device.type()))); - - // For Non DeviceData IR values, we directly attach the sharding spec - // to the xtensor. - const DeviceData* device_data_node = nullptr; - if (xtensor->CurrentIrValue()) { - device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); - if (!device_data_node) { - tensor_methods::custom_sharding_(xtensor, new_sharding_spec); - return; - } - - // TODO move rest of - // `xla/torch_xla/csrc/init_python_bindings.cpp::_xla_mark_sharding`. Note - // to self: `_xla_mark_sharding` works with XLATensorPtr directly, as - // opposed to XlaOp here. - } + const xla::XlaOp sharding) { + return xla::CustomCall(input.builder(), /*call_target_name=*/"MarkSharding", + {input}, ShapeHelper::ShapeOfXlaOp(input)); +} } // namespace torch_xla diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index ae7ed8ff06e..21970d0af23 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -152,7 +152,7 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input); xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, const xla::XlaOp& input, - xla::OpSharding sharding); + const xla::XlaOp& sharding); } // namespace torch_xla From a20d710f2e8bc0ac13e74a3e0d28a2effe19fb97 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Tue, 24 Oct 2023 22:55:13 +0000 Subject: [PATCH 06/14] Re-run linter due to wrong version --- torch_xla/csrc/aten_xla_type.cpp | 7 +- torch_xla/csrc/init_python_bindings.cpp | 176 ++++++++++------------ torch_xla/csrc/ops/custom_mark_sharding.h | 3 +- 3 files changed, 89 insertions(+), 97 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index dfceec0353a..49780fc8811 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -114,9 +114,10 @@ void CheckSubOperandTypes(at::ScalarType type1, at::ScalarType type2) { c10::optional PromoteIntegralType( at::ScalarType src_dtype, const c10::optional& opt_dtype) { - return opt_dtype.has_value() ? opt_dtype.value() - : at::isIntegralType(src_dtype, /*includeBool=*/true) ? at::kLong - : opt_dtype; + return opt_dtype.has_value() + ? opt_dtype.value() + : at::isIntegralType(src_dtype, /*includeBool=*/true) ? 
at::kLong + : opt_dtype; } bool IsTypeWithLargerRangeThanLong(torch::ScalarType dtype) { diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 7d380f725b6..bf5ff032766 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -968,28 +968,27 @@ void InitXlaModuleBindings(py::module m) { ShardingUtil::ShardingType(sharding_type)), global_shape, minibatch); })); - m.def( - "_xla_tensors_from_aten", - [](const std::vector& tensors, - const std::vector& devices, - const std::optional>& - shardings) { - std::vector result; - { - NoGilSection nogil; - std::vector xla_tensors = - GetXlaTensorsFromAten(tensors, devices, shardings); - result.reserve(xla_tensors.size()); - for (size_t i = 0; i < xla_tensors.size(); ++i) { - result.push_back(torch::autograd::make_variable( - xla_tensors[i], - /*requires_grad=*/tensors.at(i).requires_grad())); + m.def("_xla_tensors_from_aten", + [](const std::vector& tensors, + const std::vector& devices, + const std::optional>& + shardings) { + std::vector result; + { + NoGilSection nogil; + std::vector xla_tensors = + GetXlaTensorsFromAten(tensors, devices, shardings); + result.reserve(xla_tensors.size()); + for (size_t i = 0; i < xla_tensors.size(); ++i) { + result.push_back(torch::autograd::make_variable( + xla_tensors[i], + /*requires_grad=*/tensors.at(i).requires_grad())); + } } - } - return result; - }, - py::arg("tensors"), py::arg("devices"), - py::arg("shardings") = py::none()); + return result; + }, + py::arg("tensors"), py::arg("devices"), + py::arg("shardings") = py::none()); m.def("_xla_get_cpu_tensors", [](const std::vector& tensors) { std::vector result; { @@ -1289,51 +1288,45 @@ void InitXlaModuleBindings(py::module m) { } return list; }); - m.def( - "_xla_set_rng_seed", - [](uint64_t seed, const std::string& device) { - SetRngSeed(seed, device); - }, - py::arg("seed") = 101, py::arg("device") = ""); - m.def( - "_xla_get_rng_seed", - [](const std::string& device) { return GetRngSeed(device); }, - py::arg("device") = ""); - m.def( - "_xla_sync_multi", - [](const std::vector& tensors, - const std::vector& devices, bool wait, - bool sync_xla_data) { - NoGilSection nogil; - SyncTensors(tensors, devices, wait, sync_xla_data); - }, - py::arg("tensors"), py::arg("devices"), py::arg("wait") = true, - py::arg("sync_xla_data") = true); - m.def( - "_xla_warm_up_cache", - [](const std::vector& tensors, - const std::vector& devices) { - NoGilSection nogil; - SyncTensors(tensors, devices, /*wait=*/false, /*sync_xla_data=*/false, - /*warm_up_cache_only=*/true); - }, - py::arg("tensors"), py::arg("devices")); - m.def( - "_xla_sync_live_tensors", - [](const std::string& device, const std::vector& devices, - bool wait) { - NoGilSection nogil; - SyncLiveTensors(device, devices, wait); - }, - py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); - m.def( - "_xla_step_marker", - [](const std::string& device, const std::vector& devices, - bool wait) { - NoGilSection nogil; - StepMarker(device, devices, wait); - }, - py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); + m.def("_xla_set_rng_seed", + [](uint64_t seed, const std::string& device) { + SetRngSeed(seed, device); + }, + py::arg("seed") = 101, py::arg("device") = ""); + m.def("_xla_get_rng_seed", + [](const std::string& device) { return GetRngSeed(device); }, + py::arg("device") = ""); + m.def("_xla_sync_multi", + [](const std::vector& tensors, + const std::vector& devices, bool wait, + bool 
sync_xla_data) { + NoGilSection nogil; + SyncTensors(tensors, devices, wait, sync_xla_data); + }, + py::arg("tensors"), py::arg("devices"), py::arg("wait") = true, + py::arg("sync_xla_data") = true); + m.def("_xla_warm_up_cache", + [](const std::vector& tensors, + const std::vector& devices) { + NoGilSection nogil; + SyncTensors(tensors, devices, /*wait=*/false, /*sync_xla_data=*/false, + /*warm_up_cache_only=*/true); + }, + py::arg("tensors"), py::arg("devices")); + m.def("_xla_sync_live_tensors", + [](const std::string& device, const std::vector& devices, + bool wait) { + NoGilSection nogil; + SyncLiveTensors(device, devices, wait); + }, + py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); + m.def("_xla_step_marker", + [](const std::string& device, const std::vector& devices, + bool wait) { + NoGilSection nogil; + StepMarker(device, devices, wait); + }, + py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); m.def("_get_stablehlo", [](const std::vector& tensors, const std::string& device, const std::vector& devices, @@ -1370,19 +1363,18 @@ void InitXlaModuleBindings(py::module m) { } return retlist; }); - m.def( - "_xla_wait_device_ops", - [](const std::vector& devices) { - NoGilSection nogil; - XLAGraphExecutor::Get()->WaitDeviceOps(devices); - if (UseVirtualDevice()) { - std::vector spmd_device = {"SPMD:0"}; - runtime::GetComputationClient()->WaitDeviceOps(spmd_device); - } else { - runtime::GetComputationClient()->WaitDeviceOps(devices); - } - }, - py::arg("devices")); + m.def("_xla_wait_device_ops", + [](const std::vector& devices) { + NoGilSection nogil; + XLAGraphExecutor::Get()->WaitDeviceOps(devices); + if (UseVirtualDevice()) { + std::vector spmd_device = {"SPMD:0"}; + runtime::GetComputationClient()->WaitDeviceOps(spmd_device); + } else { + runtime::GetComputationClient()->WaitDeviceOps(devices); + } + }, + py::arg("devices")); m.def("_xla_counter_names", []() { auto counter_names = torch::lazy::GetCounterNames(); auto xla_counter_names = runtime::metrics::GetCounterNames(); @@ -1447,23 +1439,21 @@ void InitXlaModuleBindings(py::module m) { torch::lazy::MetricsArena::Get()->ResetMetrics(); runtime::metrics::ClearMetrics(); }); - m.def( - "_xla_tensors_report", - [](size_t nodes_threshold, const std::string& device) { - return GetLiveTensorsReport(nodes_threshold, device); - }, - py::arg("nodes_threshold") = 100, py::arg("device") = ""); + m.def("_xla_tensors_report", + [](size_t nodes_threshold, const std::string& device) { + return GetLiveTensorsReport(nodes_threshold, device); + }, + py::arg("nodes_threshold") = 100, py::arg("device") = ""); m.def("_xla_memory_info", [](const std::string& device) -> py::object { return GetMemoryInfo(device); }); - m.def( - "_xla_set_use_full_mat_mul_precision", - [](bool use_full_mat_mul_precision) { - XlaHelpers::set_mat_mul_precision(use_full_mat_mul_precision - ? xla::PrecisionConfig::HIGHEST - : xla::PrecisionConfig::DEFAULT); - }, - py::arg("use_full_mat_mul_precision") = true); + m.def("_xla_set_use_full_mat_mul_precision", + [](bool use_full_mat_mul_precision) { + XlaHelpers::set_mat_mul_precision( + use_full_mat_mul_precision ? 
xla::PrecisionConfig::HIGHEST + : xla::PrecisionConfig::DEFAULT); + }, + py::arg("use_full_mat_mul_precision") = true); py::class_(m, "XlaBuilder"); py::class_(m, "XlaOp"); diff --git a/torch_xla/csrc/ops/custom_mark_sharding.h b/torch_xla/csrc/ops/custom_mark_sharding.h index 32c3c3c1512..a23323c9bbf 100644 --- a/torch_xla/csrc/ops/custom_mark_sharding.h +++ b/torch_xla/csrc/ops/custom_mark_sharding.h @@ -8,7 +8,8 @@ namespace torch_xla { class CustomMarkSharding : public XlaNode { public: // Make a custom call to Sharding. - CustomMarkSharding(const torch::lazy::Value& input, const torch::lazy::Value& sharding); + CustomMarkSharding(const torch::lazy::Value& input, + const torch::lazy::Value& sharding); torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; From bd169c2e3861eb0975c9fed6f1dbd06df69a9876 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Thu, 26 Oct 2023 22:47:29 +0000 Subject: [PATCH 07/14] Add new API for custom mark sharding op and update tests --- test/spmd/test_dynamo_spmd.py | 18 ++++++++++++++++- torch_xla/csrc/init_python_bindings.cpp | 2 +- torch_xla/csrc/xla_lower_util.cpp | 2 +- torch_xla/experimental/xla_sharding.py | 27 +++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 3 deletions(-) diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 7dc839979ee..7b835a26754 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -53,6 +53,21 @@ def test_dynamo_spmd_basic(self): # TODO(JackCaoG): add counter checks after ExecuteReplicated also creates # a ExecuteMetric. + def test_dynamo_spmd_basic_with_custom_mark_sharding_op(self): + device = xm.xla_device() + linear = SimpleLinear().to(device) + linear.eval() + xla_x = torch.randn(1, 128, device=device) + xs.mark_sharding_dynamo_custom_op(linear.fc2.weight, + self._get_mesh((1, self.n_devices)), + (1, 0)) + xla_res = linear(xla_x) + xm.mark_step() + + dynamo_linear = torch.compile(linear, backend="openxla") + dynamo_res = dynamo_linear(xla_x) + torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + def test_dynamo_spmd_output_sharding_spec(self): device = xm.xla_device() linear = SimpleLinear().to(device) @@ -177,7 +192,8 @@ def fn_simple(x): y = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, device=xm.xla_device()) - ys = xs.mark_sharding(y, self._get_mesh((1, self.n_devices)), (0, 1)) + ys = xs.mark_sharding_dynamo_custom_op( + y, self._get_mesh((1, self.n_devices)), (0, 1)) return x + ys diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index bf5ff032766..a865b7ad7d0 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1625,7 +1625,7 @@ void InitXlaModuleBindings(py::module m) { // Register sharded tensor data. 
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); }); - m.def("_xla_mark_sharding_custom_op", + m.def("_xla_mark_sharding_dynamo_custom_op", [](const at::Tensor& input, xla::OpSharding sharding) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); tensor_methods::custom_mark_sharding(xtensor, sharding); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index c573adeb31e..567bd12ad5d 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1230,7 +1230,7 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) { xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, const xla::XlaOp& input, - const xla::XlaOp sharding) { + const xla::XlaOp& sharding) { return xla::CustomCall(input.builder(), /*call_target_name=*/"MarkSharding", {input}, ShapeHelper::ShapeOfXlaOp(input)); } diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 21d0e2e570a..31c44e3c9d3 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -506,6 +506,33 @@ def mark_sharding( return XLAShardedTensor(t) +@xr.requires_pjrt +def mark_sharding_dynamo_custom_op( + t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, + partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor: + """ + Same functionality as `mark_sharding` above, except this variant uses the custom mark_sharding op in torch_xla._XLAC to allow dynamo to recognize and trace it. + """ + num_devices = xr.global_runtime_device_count() + assert num_devices > 0, "This requires XLA supported device(s)." + assert mesh.size() == num_devices, \ + f"{mesh.mesh_shape} is not mappable over {num_devices} devices." + # We only allow fully specified `partition_spec` to be applicable, as opposed + # to filling in the unspecified replicated dims. Fully specified `partiion_spec` + # should be of the same rank as `t`. This is to support partial replication + # where the group assignment may vary with different input ranks. + assert len(t.shape) == len(partition_spec), \ + f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." 
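+  # A minimal usage sketch (hypothetical shapes; assumes 8 addressable devices
+  # arranged in a (1, 8) mesh):
+  #   mesh = Mesh(np.arange(8), (1, 8))
+  #   t = torch.randn(16, 128).to(xm.xla_device())
+  #   mark_sharding_dynamo_custom_op(t, mesh, (0, 1))  # shard dim 1 over 8 devices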
+ + op_sharding = mesh.get_op_sharding(partition_spec) + + if isinstance(t, XLAShardedTensor): + torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding) + return t + torch_xla._XLAC._xla_mark_sharding(t, op_sharding) + return XLAShardedTensor(t) + + def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: """Clear sharding annotation from the input tensor and return a `cpu` casted tensor.""" torch_xla._XLAC._xla_clear_sharding(t) From ae05c9a445258162082d57f2470d03fefb54441d Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Tue, 31 Oct 2023 06:54:28 +0000 Subject: [PATCH 08/14] Add torch pin --- .torch_pin | 1 + 1 file changed, 1 insertion(+) create mode 100644 .torch_pin diff --git a/.torch_pin b/.torch_pin new file mode 100644 index 00000000000..22aa007a16d --- /dev/null +++ b/.torch_pin @@ -0,0 +1 @@ +#112483 From 08d6296c91d4cd25fcd3ac8d4372e3f00693341e Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Tue, 31 Oct 2023 23:38:10 +0000 Subject: [PATCH 09/14] Clean up some code --- .torch_pin | 1 - test/spmd/test_dynamo_spmd.py | 32 +++- torch_xla/csrc/aten_autograd_ops.cpp | 22 +-- torch_xla/csrc/aten_autograd_ops.h | 11 ++ torch_xla/csrc/init_python_bindings.cpp | 191 +++++++++++++------- torch_xla/csrc/ops/custom_mark_sharding.cpp | 34 ---- torch_xla/csrc/ops/custom_mark_sharding.h | 23 --- torch_xla/csrc/tensor_methods.cpp | 71 -------- torch_xla/csrc/tensor_methods.h | 2 - torch_xla/experimental/xla_sharding.py | 19 +- 10 files changed, 188 insertions(+), 218 deletions(-) delete mode 100644 .torch_pin delete mode 100644 torch_xla/csrc/ops/custom_mark_sharding.cpp delete mode 100644 torch_xla/csrc/ops/custom_mark_sharding.h diff --git a/.torch_pin b/.torch_pin deleted file mode 100644 index 22aa007a16d..00000000000 --- a/.torch_pin +++ /dev/null @@ -1 +0,0 @@ -#112483 diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 7b835a26754..f3abaec8978 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -200,11 +200,35 @@ def fn_simple(x): device = xm.xla_device() x_xla = torch.zeros((1, 8)).to(device) xla_res = fn_simple(x_xla) - xm.mark_step() + print(xla_res) + # xm.mark_step() - dynamo_linear = torch.compile(fn_simple, backend="openxla") - dynamo_res = dynamo_linear(x_xla) - torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + # dynamo_linear = torch.compile(fn_simple, backend="openxla") + # dynamo_res = dynamo_linear(x_xla) + # torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + + # TODO (@wonjoo) Remove this test, this is just for debugging + def test_wonjoo(self): + + def fn_simple(x): + print(f'x inside fn_simple before: {x}') + torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(x_xla, [0], [0], [0], 0) + print(f'x inside fn_simple after: {x}') + return x + + device = xm.xla_device() + + x_xla = torch.zeros((1, 8)).to(device) + + # print(torch.ops.xla.add) + print(torch.ops.xla.max_pool2d_forward) + print(torch.ops.xla.xla_mark_sharding_dynamo_custom_op) + print(dir(torch.ops.xla.xla_mark_sharding_dynamo_custom_op)) + # print(f'x_xla before: {x_xla}') + + # dynamo_fn = torch.compile(fn_simple, backend="openxla") + # dynamo_res = dynamo_fn(x_xla) + # print(f'dynamo_res: {dynamo_res}') if __name__ == '__main__': diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp index 08c6b27f92c..9d8edbbb731 100644 --- a/torch_xla/csrc/aten_autograd_ops.cpp +++ b/torch_xla/csrc/aten_autograd_ops.cpp @@ -253,17 +253,17 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, 
torch::Tensor self, return grad; } -TORCH_LIBRARY(xla, m) { - m.def( - "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " - "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", - torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward))); +// TORCH_LIBRARY(xla, m) { +// m.def( +// "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " +// "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", +// torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward))); - m.def( - "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " - "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " - "-> Tensor", - torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward))); -} +// m.def( +// "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " +// "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " +// "-> Tensor", +// torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward))); +// } } // namespace aten_autograd_ops } // namespace torch_xla diff --git a/torch_xla/csrc/aten_autograd_ops.h b/torch_xla/csrc/aten_autograd_ops.h index be063b76620..d1cc8a98048 100644 --- a/torch_xla/csrc/aten_autograd_ops.h +++ b/torch_xla/csrc/aten_autograd_ops.h @@ -46,6 +46,17 @@ struct MaxPool3dAutogradFunction torch::autograd::variable_list grad_output); }; +torch::Tensor max_pool2d_forward(torch::Tensor self, + torch::IntArrayRef kernel_size, + torch::IntArrayRef stride, + torch::IntArrayRef padding, + torch::IntArrayRef dilation, bool ceil_mode); + +torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self, + torch::IntArrayRef kernel_size, + torch::IntArrayRef stride, + torch::IntArrayRef padding, bool ceil_mode); + } // namespace aten_autograd_ops } // namespace torch_xla diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a865b7ad7d0..b602a686e72 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -29,6 +29,7 @@ #include "pybind11/pytypes.h" #include "pybind11/stl_bind.h" #include "torch_xla/csrc/XLANativeFunctions.h" +#include "torch_xla/csrc/aten_autograd_ops.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/helpers.h" @@ -651,6 +652,121 @@ std::string GetPyTypeString(py::handle obj) { return type; } +void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) { + TORCH_LAZY_COUNTER("XlaMarkSharding", 1); + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + auto new_sharding_spec = std::make_shared( + sharding, MakeShapeWithDeviceLayout( + xtensor->shape(), + static_cast(xtensor->GetDevice().type()))); + + // For Non DeviceData IR values, we directly attach the sharding spec + // to the xtensor. + const DeviceData* device_data_node = nullptr; + if (xtensor->CurrentIrValue()) { + device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); + if (!device_data_node) { + tensor_methods::custom_sharding_(xtensor, new_sharding_spec); + return; + } + } + + // For data, we need to deal with the data transfers between + // host and device. 
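+  // Two cases follow: if the tensor data is still deferred on the host, that
+  // CPU copy is sharded directly; otherwise the value is fetched back from the
+  // device first, unless an equivalent sharding annotation is already in place
+  // (re-annotating with a different sharding is rejected).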
+ at::Tensor cpu_tensor; + if (xtensor->CurrentTensorData().has_value()) { + TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); + // When virtual device is enabled for SPMD, we defer the initial + // data transfer to the device and retain the original data on the + // host, until the sharded data transfer. + cpu_tensor = xtensor->CurrentTensorData().value(); + } else { + // A new input tensor is not expected to be sharded. But sometimes, + // the same input is called for sharding annotation over multiple steps, + // in which case we can skip if it's the same sharding; however, if it's + // the same input with a different sharding then we block & ask the user + // to clear the existing sharding first. + auto current_sharding_spec = xtensor->sharding_spec(); + if (current_sharding_spec && (current_sharding_spec->sharding.type() != + xla::OpSharding::REPLICATED)) { + XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, + *current_sharding_spec)) + << "Existing annotation must be cleared first."; + return; + } + + // If the at::Tensor data is not present, we need to re-download the + // tensor from the physical device to CPU. In that case, the value + // must be present on the backend device. + XLA_CHECK((xtensor->CurrentDataHandle() && + xtensor->CurrentDataHandle()->HasValue()) || + device_data_node != nullptr) + << "Cannot shard tensor. Data does not present on any device."; + std::vector xla_tensors{xtensor}; + cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; + } + auto xla_data = CreateTensorsData( + std::vector{cpu_tensor}, + std::vector{new_sharding_spec}, + std::vector{GetVirtualDevice().toString()})[0]; + xtensor->SetXlaData(xla_data); + xtensor->SetShardingSpec(*new_sharding_spec); + + // Register sharded tensor data. + XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); +} + +void xla_mark_sharding_dynamo_custom_op(const at::Tensor& input, c10::List tile_assignment, c10::List group_assignment, c10::List replication_groups, int64_t sharding_type) { + std::cout << "at xla_mark_sharding_dynamo_custom_op0" << std::endl; + + std::cout << "input: " << input << std::endl; + // std::cout << "tile_assignment: " << tile_assignment << std::endl; + std::cout << "converting tile_assignment_py" << std::endl; + // const py::list& tile_assignment_py = py::cast(tile_assignment[0]); + const py::list& tile_assignment_py = py::cast(torch::lazy::ToVector(tile_assignment[0])); + + // std::cout << "group_assignment: " << group_assignment << std::endl; + std::cout << "converting group_assignment_py" << std::endl; + const py::list& group_assignment_py = py::cast(group_assignment); + + // std::cout << "replication_groups: " << replication_groups << std::endl; + std::cout << "converting replication_groups_py" << std::endl; + const py::list& replication_groups_py = py::cast(replication_groups); + + std::cout << "at xla_mark_sharding_dynamo_custom_op1" << std::endl; + + const xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding( + tile_assignment_py, group_assignment_py, replication_groups_py, + ShardingUtil::ShardingType(sharding_type)); + + + std::cout << "at xla_mark_sharding_dynamo_custom_op2" << std::endl; + + xla_mark_sharding(input, op_sharding); + + std::cout << "at xla_mark_sharding_dynamo_custom_op3" << std::endl; +} + +// Macro for defining a function that will be run at static initialization time to define a library of operators in the namespace. +// Used to define a new set of custom operators that do not already exist in PyTorch. 
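+// Once registered here the ops are callable, and traceable by dynamo, from
+// Python as torch.ops.xla.<name>, e.g.
+//   torch.ops.xla.xla_mark_sharding_dynamo_custom_op(
+//       t, tile_assignment, group_assignment, replication_groups,
+//       sharding_type)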
+TORCH_LIBRARY(xla, m) { + m.def( + "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " + "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", + torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); + + m.def( + "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " + "-> Tensor", + torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); + m.def( + "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] tile_assignment, int[][] group_assignment, int[][] replication_groups, int sharding_type) -> ()", + torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(xla_mark_sharding_dynamo_custom_op))); +} + std::vector check_materialization_helper( const std::vector& xtensors) { std::vector need_materialization; @@ -1561,75 +1677,14 @@ void InitXlaModuleBindings(py::module m) { })); m.def("_xla_mark_sharding", [](const at::Tensor& input, xla::OpSharding sharding) { - TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto new_sharding_spec = std::make_shared( - sharding, MakeShapeWithDeviceLayout( - xtensor->shape(), - static_cast(xtensor->GetDevice().type()))); - - // For Non DeviceData IR values, we directly attach the sharding spec - // to the xtensor. - const DeviceData* device_data_node = nullptr; - if (xtensor->CurrentIrValue()) { - device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); - if (!device_data_node) { - tensor_methods::custom_sharding_(xtensor, new_sharding_spec); - return; - } - } - - // For data, we need to deal with the data transfers between - // host and device. - at::Tensor cpu_tensor; - if (xtensor->CurrentTensorData().has_value()) { - TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); - // When virtual device is enabled for SPMD, we defer the initial - // data transfer to the device and retain the original data on the - // host, until the sharded data transfer. - cpu_tensor = xtensor->CurrentTensorData().value(); - } else { - // A new input tensor is not expected to be sharded. But sometimes, - // the same input is called for sharding annotation over multiple steps, - // in which case we can skip if it's the same sharding; however, if it's - // the same input with a different sharding then we block & ask the user - // to clear the existing sharding first. - auto current_sharding_spec = xtensor->sharding_spec(); - if (current_sharding_spec && (current_sharding_spec->sharding.type() != - xla::OpSharding::REPLICATED)) { - XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, - *current_sharding_spec)) - << "Existing annotation must be cleared first."; - return; - } - - // If the at::Tensor data is not present, we need to re-download the - // tensor from the physical device to CPU. In that case, the value - // must be present on the backend device. - XLA_CHECK((xtensor->CurrentDataHandle() && - xtensor->CurrentDataHandle()->HasValue()) || - device_data_node != nullptr) - << "Cannot shard tensor. 
Data does not present on any device."; - std::vector xla_tensors{xtensor}; - cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; - } - auto xla_data = CreateTensorsData( - std::vector{cpu_tensor}, - std::vector{new_sharding_spec}, - std::vector{GetVirtualDevice().toString()})[0]; - xtensor->SetXlaData(xla_data); - xtensor->SetShardingSpec(*new_sharding_spec); - - // Register sharded tensor data. - XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); + xla_mark_sharding(input, sharding); }); - m.def("_xla_mark_sharding_dynamo_custom_op", - [](const at::Tensor& input, xla::OpSharding sharding) { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - tensor_methods::custom_mark_sharding(xtensor, sharding); - }); + // m.def("_xla_mark_sharding_dynamo_custom_op", + // [](const at::Tensor& input, xla::OpSharding sharding) { + // // xla_mark_sharding_dynamo_custom_op(input, tile_assignment, group_assignment, replication_groups, sharding_type); + // // at::IntArrayRef tile_assignment, at::IntArrayRef group_assignment, c10::List replication_groups, int64_t sharding_type + // at::IntArrayRef tile_assignment = + // }); m.def("_xla_clear_sharding", [](const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); xtensor->ClearShardingSpec(); diff --git a/torch_xla/csrc/ops/custom_mark_sharding.cpp b/torch_xla/csrc/ops/custom_mark_sharding.cpp deleted file mode 100644 index af9302590f4..00000000000 --- a/torch_xla/csrc/ops/custom_mark_sharding.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "torch_xla/csrc/ops/custom_mark_sharding.h" - -#include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/ops/xla_ops.h" -#include "torch_xla/csrc/xla_lower_util.h" - -namespace torch_xla { - -CustomMarkSharding::CustomMarkSharding(const torch::lazy::Value& input, - const torch::lazy::Value& sharding) - : XlaNode(xla_custom_mark_sharding, {input, sharding}, GetXlaShape(input), - /*num_outputs=*/1, - torch::lazy::MHash(std::string("MarkSharding"))) {} - -torch::lazy::NodePtr CustomMarkSharding::Clone( - torch::lazy::OpList operands) const { - return torch::lazy::MakeNode(operands.at(0), - operands.at(1)); -} - -XlaOpVector CustomMarkSharding::Lower(LoweringContext* loctx) const { - xla::XlaOp input = loctx->GetOutputOp(operand(0)); - xla::XlaOp sharding = loctx->GetOutputOp(operand(1)); - return ReturnOp(BuildCustomMarkSharding(loctx->device(), input, sharding), - loctx); -} - -std::string CustomMarkSharding::ToString() const { - std::stringstream ss; - ss << XlaNode::ToString() << ", MarkSharding"; - return ss.str(); -} - -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/custom_mark_sharding.h b/torch_xla/csrc/ops/custom_mark_sharding.h deleted file mode 100644 index a23323c9bbf..00000000000 --- a/torch_xla/csrc/ops/custom_mark_sharding.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef XLA_TORCH_XLA_CSRC_OPS_CUSTOM_MARK_SHARDING_H_ -#define XLA_TORCH_XLA_CSRC_OPS_CUSTOM_MARK_SHARDING_H_ - -#include "torch_xla/csrc/ir.h" - -namespace torch_xla { - -class CustomMarkSharding : public XlaNode { - public: - // Make a custom call to Sharding. 
- CustomMarkSharding(const torch::lazy::Value& input, - const torch::lazy::Value& sharding); - - torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; - - XlaOpVector Lower(LoweringContext* loctx) const override; - - std::string ToString() const override; -}; - -} // namespace torch_xla - -#endif // XLA_TORCH_XLA_CSRC_OPS_CUSTOM_MARK_SHARDING_H_ diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index f2e80980c4f..5a0d8e50d98 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -39,7 +39,6 @@ #include "torch_xla/csrc/ops/count_nonzero.h" #include "torch_xla/csrc/ops/cumprod.h" #include "torch_xla/csrc/ops/cumsum.h" -#include "torch_xla/csrc/ops/custom_mark_sharding.h" #include "torch_xla/csrc/ops/custom_sharding.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/diagonal.h" @@ -444,76 +443,6 @@ void custom_sharding_( input->SetShardingSpec(*sharding_spec); } -void custom_mark_sharding(const XLATensorPtr& input, xla::OpSharding sharding) { - // TODO (@wonjoo) Do we need this `sharding` here? - input->SetInPlaceIrValue(torch::lazy::MakeNode( - input->GetIrValue(), input->GetIrValue())); - - TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - // XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto new_sharding_spec = std::make_shared( - sharding, MakeShapeWithDeviceLayout( - input->shape(), - static_cast(input->GetDevice().type()))); - - // For Non DeviceData IR values, we directly attach the sharding spec - // to the xtensor. - const DeviceData* device_data_node = nullptr; - if (input->CurrentIrValue()) { - device_data_node = DeviceData::Cast(input->CurrentIrValue().node.get()); - if (!device_data_node) { - tensor_methods::custom_sharding_(input, new_sharding_spec); - return; - } - } - - // For data, we need to deal with the data transfers between - // host and device. - at::Tensor cpu_tensor; - if (input->CurrentTensorData().has_value()) { - TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); - // When virtual device is enabled for SPMD, we defer the initial - // data transfer to the device and retain the original data on the - // host, until the sharded data transfer. - cpu_tensor = input->CurrentTensorData().value(); - } else { - // A new input tensor is not expected to be sharded. But sometimes, - // the same input is called for sharding annotation over multiple steps, - // in which case we can skip if it's the same sharding; however, if it's - // the same input with a different sharding then we block & ask the user - // to clear the existing sharding first. - auto current_sharding_spec = input->sharding_spec(); - if (current_sharding_spec && (current_sharding_spec->sharding.type() != - xla::OpSharding::REPLICATED)) { - XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, - *current_sharding_spec)) - << "Existing annotation must be cleared first."; - return; - } - - // If the at::Tensor data is not present, we need to re-download the - // tensor from the physical device to CPU. In that case, the value - // must be present on the backend device. - XLA_CHECK((input->CurrentDataHandle() && - input->CurrentDataHandle()->HasValue()) || - device_data_node != nullptr) - << "Cannot shard tensor. 
Data does not present on any device."; - std::vector xla_tensors{input}; - cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; - } - auto xla_data = CreateTensorsData( - std::vector{cpu_tensor}, - std::vector{new_sharding_spec}, - std::vector{GetVirtualDevice().toString()})[0]; - input->SetXlaData(xla_data); - input->SetShardingSpec(*new_sharding_spec); - - // Register sharded tensor data. - XLAGraphExecutor::Get()->RegisterTensor(input->data()); -} - XLATensorPtr get_dimensions_size(const XLATensorPtr& input, std::vector dimensions) { return input->CreateFrom(torch::lazy::MakeNode( diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 0f2bf21f0c4..5a714170300 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -57,8 +57,6 @@ std::pair collective_permute( void custom_sharding_(const XLATensorPtr& input, const std::shared_ptr& spec); -void custom_mark_sharding(const XLATensorPtr& input, xla::OpSharding sharding); - XLATensorPtr get_dimensions_size(const XLATensorPtr& input, std::vector dimensions); diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 31c44e3c9d3..f08b2370d7d 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -82,7 +82,7 @@ def get_axis_name_idx(self, name: str) -> int: @functools.lru_cache(maxsize=None) def get_op_sharding(self, - partition_spec: Tuple) -> torch_xla._XLAC.OpSharding: + partition_spec: Tuple, flatten = False) -> torch_xla._XLAC.OpSharding: """ Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. @@ -104,6 +104,15 @@ def get_op_sharding(self, replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} group_assignment, replication_groups = _get_group_assignment( sharding_type, tile_assignment, len(partition_spec), replicate_dims) + + # If flatten = True, return the flattened version of OpSharding + # print each return type to debug + print(tile_assignment.tolist()) + print(group_assignment) + print(replication_groups) + if flatten: + return (tile_assignment.tolist(), group_assignment, replication_groups, int(sharding_type)) + return torch_xla._XLAC.OpSharding(tile_assignment.tolist(), group_assignment, replication_groups, int(sharding_type)) @@ -524,12 +533,14 @@ def mark_sharding_dynamo_custom_op( assert len(t.shape) == len(partition_spec), \ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." 
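  # The flattened variant below returns plain nested int lists plus an int
  # sharding type, since the custom-op schema can only carry basic types
  # (int[][] / int) rather than a torch_xla._XLAC.OpSharding object; the
  # OpSharding proto is rebuilt on the C++ side.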
- op_sharding = mesh.get_op_sharding(partition_spec) + tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(partition_spec, flatten = True) + print('about to call xla_mark_sharding_dynamo_custom_op') if isinstance(t, XLAShardedTensor): - torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding) + torch.ops.xla.xla_mark_sharding_dynamo_custom_op(t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type) return t - torch_xla._XLAC._xla_mark_sharding(t, op_sharding) + torch.ops.xla.xla_mark_sharding_dynamo_custom_op(t, tile_assignment, group_assignment, replication_groups, sharding_type) + print('xla_mark_sharding_dynamo_custom_op call finished') return XLAShardedTensor(t) From 9aaa533bc5574f5f00667523e02c5fa469cfddd3 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Mon, 6 Nov 2023 21:15:07 +0000 Subject: [PATCH 10/14] Update code to transfer pylist to xla::OpSharding --- test/spmd/test_dynamo_spmd.py | 54 ++++++----------- torch_xla/csrc/init_python_bindings.cpp | 77 +++++++++++++++---------- torch_xla/experimental/xla_sharding.py | 22 ++++--- 3 files changed, 75 insertions(+), 78 deletions(-) diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index f3abaec8978..87f9d20b62e 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -15,7 +15,7 @@ class SimpleLinear(nn.Module): - def __init__(self): + def __init__(self, mark_sharding_inside = False, op_sharding = None): super(SimpleLinear, self).__init__() self.fc1 = nn.Linear(128, 128) self.relu = nn.ReLU() @@ -25,6 +25,9 @@ def __init__(self): self.fc3 = nn.Linear(1, 1) def forward(self, x): + print(f'self.fc2.weight.device={self.fc2.weight.device}') + if self.mark_sharding_inside and self.op_sharding and 'xla' in self.fc2.weight.device: + xs.mark_sharding(self.fc2.weight, self.op_sharding) y = self.relu(self.fc1(x)) z = self.fc2(y) return self.fc3(z) @@ -187,48 +190,25 @@ def test_dynamo_input_sharding_threashold(self): del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] def test_mark_sharding_inside_compile(self): + device = xm.xla_device() - def fn_simple(x): - y = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], + def fn_simple(t): + xs.mark_sharding_dynamo_custom_op( + t, self._get_mesh((1, self.n_devices)), (0, 1)) + + x = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) - ys = xs.mark_sharding_dynamo_custom_op( - y, self._get_mesh((1, self.n_devices)), (0, 1)) + device=device) - return x + ys + return t + x - device = xm.xla_device() - x_xla = torch.zeros((1, 8)).to(device) + x_xla = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]]).to(device) xla_res = fn_simple(x_xla) - print(xla_res) - # xm.mark_step() - - # dynamo_linear = torch.compile(fn_simple, backend="openxla") - # dynamo_res = dynamo_linear(x_xla) - # torch.allclose(xla_res.cpu(), dynamo_res.cpu()) - - # TODO (@wonjoo) Remove this test, this is just for debugging - def test_wonjoo(self): - - def fn_simple(x): - print(f'x inside fn_simple before: {x}') - torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(x_xla, [0], [0], [0], 0) - print(f'x inside fn_simple after: {x}') - return x - - device = xm.xla_device() - - x_xla = torch.zeros((1, 8)).to(device) - - # print(torch.ops.xla.add) - print(torch.ops.xla.max_pool2d_forward) - print(torch.ops.xla.xla_mark_sharding_dynamo_custom_op) - print(dir(torch.ops.xla.xla_mark_sharding_dynamo_custom_op)) - # print(f'x_xla before: {x_xla}') + xm.mark_step() - # dynamo_fn = 
torch.compile(fn_simple, backend="openxla") - # dynamo_res = dynamo_fn(x_xla) - # print(f'dynamo_res: {dynamo_res}') + dynamo_fn_simple = torch.compile(fn_simple, backend="openxla") + dynamo_res = dynamo_fn_simple(x_xla) + torch.allclose(xla_res.cpu(), dynamo_res.cpu()) if __name__ == '__main__': diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b602a686e72..8106550c7f3 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -719,34 +719,38 @@ void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) { } void xla_mark_sharding_dynamo_custom_op(const at::Tensor& input, c10::List tile_assignment, c10::List group_assignment, c10::List replication_groups, int64_t sharding_type) { - std::cout << "at xla_mark_sharding_dynamo_custom_op0" << std::endl; - - std::cout << "input: " << input << std::endl; - // std::cout << "tile_assignment: " << tile_assignment << std::endl; - std::cout << "converting tile_assignment_py" << std::endl; - // const py::list& tile_assignment_py = py::cast(tile_assignment[0]); - const py::list& tile_assignment_py = py::cast(torch::lazy::ToVector(tile_assignment[0])); - - // std::cout << "group_assignment: " << group_assignment << std::endl; - std::cout << "converting group_assignment_py" << std::endl; - const py::list& group_assignment_py = py::cast(group_assignment); - - // std::cout << "replication_groups: " << replication_groups << std::endl; - std::cout << "converting replication_groups_py" << std::endl; - const py::list& replication_groups_py = py::cast(replication_groups); - - std::cout << "at xla_mark_sharding_dynamo_custom_op1" << std::endl; + py::list tile_assignment_py = py::list(); + for (int i = 0; i < tile_assignment.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : tile_assignment[i].get().toIntList()) { + pylist.append(t); + } + tile_assignment_py.append(pylist); + } - const xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding( - tile_assignment_py, group_assignment_py, replication_groups_py, - ShardingUtil::ShardingType(sharding_type)); + py::list group_assignment_py = py::list(); + for (int i = 0; i < group_assignment.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : group_assignment[i].get().toIntList()) { + pylist.append(t); + } + group_assignment_py.append(pylist); + } + py::list replication_groups_py = py::list(); + for (int i = 0; i < replication_groups.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : replication_groups[i].get().toIntList()) { + pylist.append(t); + } + replication_groups_py.append(pylist); + } - std::cout << "at xla_mark_sharding_dynamo_custom_op2" << std::endl; + xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding( + tile_assignment_py, group_assignment_py, replication_groups_py, + ShardingUtil::ShardingType(sharding_type)); xla_mark_sharding(input, op_sharding); - - std::cout << "at xla_mark_sharding_dynamo_custom_op3" << std::endl; } // Macro for defining a function that will be run at static initialization time to define a library of operators in the namespace. 
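// The conversion above rebuilds py::list arguments from the
// c10::List<at::IntArrayRef> payload because ShardingUtil::CreateOpSharding
// currently takes py::list inputs; the resulting xla::OpSharding is then
// handed to xla_mark_sharding.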
@@ -1679,12 +1683,27 @@ void InitXlaModuleBindings(py::module m) { xla::OpSharding sharding) { xla_mark_sharding(input, sharding); }); - // m.def("_xla_mark_sharding_dynamo_custom_op", - // [](const at::Tensor& input, xla::OpSharding sharding) { - // // xla_mark_sharding_dynamo_custom_op(input, tile_assignment, group_assignment, replication_groups, sharding_type); - // // at::IntArrayRef tile_assignment, at::IntArrayRef group_assignment, c10::List replication_groups, int64_t sharding_type - // at::IntArrayRef tile_assignment = - // }); + m.def("_xla_mark_sharding_dynamo_custom_op", + [](const at::Tensor& input, const py::list& tile_assignment, + const py::list& group_assignment, + const py::list& replication_groups, int sharding_type) { + c10::List time_assignment_list = c10::List(); + for (auto t : tile_assignment) { + time_assignment_list.push_back(at::IntArrayRef(t.cast>())); + } + + c10::List group_assignment_list = c10::List(); + for (auto t : group_assignment) { + group_assignment_list.push_back(at::IntArrayRef(t.cast>())); + } + + c10::List replication_groups_list = c10::List(); + for (auto t : replication_groups) { + replication_groups_list.push_back(at::IntArrayRef(t.cast>())); + } + + xla_mark_sharding_dynamo_custom_op(input, time_assignment_list, group_assignment_list, replication_groups_list, sharding_type); + }); m.def("_xla_clear_sharding", [](const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); xtensor->ClearShardingSpec(); diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index f08b2370d7d..811823c1415 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -106,14 +106,10 @@ def get_op_sharding(self, sharding_type, tile_assignment, len(partition_spec), replicate_dims) # If flatten = True, return the flattened version of OpSharding - # print each return type to debug - print(tile_assignment.tolist()) - print(group_assignment) - print(replication_groups) if flatten: return (tile_assignment.tolist(), group_assignment, replication_groups, int(sharding_type)) - - return torch_xla._XLAC.OpSharding(tile_assignment.tolist(), + else: + return torch_xla._XLAC.OpSharding(tile_assignment.tolist(), group_assignment, replication_groups, int(sharding_type)) @@ -460,7 +456,7 @@ def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple): @xr.requires_pjrt def mark_sharding( t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, - partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor: + partition_spec: Tuple[Union[Tuple, int, str, None]], dynamo_custom_op: bool = False) -> XLAShardedTensor: """ Annotates the tensor provided with XLA partition spec. Internally, it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass. @@ -480,6 +476,9 @@ def mark_sharding( >> mesh_shape = (4, 2) >> partition_spec = (0, None) + dynamo_custom_op (bool): if set to True, it calls the dynamo custom op variant of mark_sharding + to make itself recognizeable and traceable by dynamo. + Examples —------------------------------ mesh_shape = (4, 2) @@ -520,7 +519,8 @@ def mark_sharding_dynamo_custom_op( t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor: """ - Same functionality as `mark_sharding` above, except this variant uses the custom mark_sharding op in torch_xla._XLAC to allow dynamo to recognize and trace it. 
+ Same functionality as `mark_sharding` above, except this variant uses the custom + mark_sharding op in torch_xla._XLAC to allow dynamo to recognize and trace it. """ num_devices = xr.global_runtime_device_count() assert num_devices > 0, "This requires XLA supported device(s)." @@ -535,12 +535,10 @@ def mark_sharding_dynamo_custom_op( tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(partition_spec, flatten = True) - print('about to call xla_mark_sharding_dynamo_custom_op') if isinstance(t, XLAShardedTensor): - torch.ops.xla.xla_mark_sharding_dynamo_custom_op(t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type) + torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type) return t - torch.ops.xla.xla_mark_sharding_dynamo_custom_op(t, tile_assignment, group_assignment, replication_groups, sharding_type) - print('xla_mark_sharding_dynamo_custom_op call finished') + torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment, group_assignment, replication_groups, sharding_type) return XLAShardedTensor(t) From a98bfb22e33939e3fb399fb018b84551c884f521 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Tue, 7 Nov 2023 05:35:25 +0000 Subject: [PATCH 11/14] Update unit test and run linter --- test/spmd/test_dynamo_spmd.py | 40 +++--- torch_xla/csrc/aten_autograd_ops.cpp | 12 -- torch_xla/csrc/init_python_bindings.cpp | 181 +++++++++++++----------- torch_xla/csrc/ops/xla_ops.cpp | 1 - torch_xla/csrc/ops/xla_ops.h | 1 - torch_xla/csrc/tensor_methods.cpp | 1 - torch_xla/csrc/xla_lower_util.cpp | 7 - torch_xla/csrc/xla_lower_util.h | 4 - torch_xla/experimental/xla_sharding.py | 76 +++++----- 9 files changed, 153 insertions(+), 170 deletions(-) diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 87f9d20b62e..a49d41dfa59 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -15,7 +15,7 @@ class SimpleLinear(nn.Module): - def __init__(self, mark_sharding_inside = False, op_sharding = None): + def __init__(self, mesh=None): super(SimpleLinear, self).__init__() self.fc1 = nn.Linear(128, 128) self.relu = nn.ReLU() @@ -23,11 +23,14 @@ def __init__(self, mark_sharding_inside = False, op_sharding = None): # Add an additional 1x1 layer at the end to ensure the final layer # is not sharded. self.fc3 = nn.Linear(1, 1) + # If mesh is not none, we'll do a mark sharding inside the forward function + # to ensure dynamo can recognize and trace it in a torch compile. 
+ self.mesh = mesh def forward(self, x): - print(f'self.fc2.weight.device={self.fc2.weight.device}') - if self.mark_sharding_inside and self.op_sharding and 'xla' in self.fc2.weight.device: - xs.mark_sharding(self.fc2.weight, self.op_sharding) + if self.mesh and 'xla' in str(self.fc2.weight.device): + xs.mark_sharding( + self.fc2.weight, self.mesh, (1, 0), dynamo_custom_op=True) y = self.relu(self.fc1(x)) z = self.fc2(y) return self.fc3(z) @@ -61,9 +64,10 @@ def test_dynamo_spmd_basic_with_custom_mark_sharding_op(self): linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) - xs.mark_sharding_dynamo_custom_op(linear.fc2.weight, - self._get_mesh((1, self.n_devices)), - (1, 0)) + xs.mark_sharding( + linear.fc2.weight, + self._get_mesh((1, self.n_devices)), (1, 0), + dynamo_custom_op=True) xla_res = linear(xla_x) xm.mark_step() @@ -191,23 +195,19 @@ def test_dynamo_input_sharding_threashold(self): def test_mark_sharding_inside_compile(self): device = xm.xla_device() + mesh = self._get_mesh((1, self.n_devices)) - def fn_simple(t): - xs.mark_sharding_dynamo_custom_op( - t, self._get_mesh((1, self.n_devices)), (0, 1)) - - x = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], - dtype=torch.float, - device=device) - - return t + x + # Passing this `mesh` as a parameter to `SimpleLinear` will call the dynamo custom op + # variant of mark_sharding inside the forward function. + linear = SimpleLinear(mesh=mesh).to(device) + linear.eval() - x_xla = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]]).to(device) - xla_res = fn_simple(x_xla) + xla_x = torch.randn(1, 128, device=device) + xla_res = linear(xla_x) xm.mark_step() - dynamo_fn_simple = torch.compile(fn_simple, backend="openxla") - dynamo_res = dynamo_fn_simple(x_xla) + dynamo_linear = torch.compile(linear, backend="openxla") + dynamo_res = dynamo_linear(xla_x) torch.allclose(xla_res.cpu(), dynamo_res.cpu()) diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp index 9d8edbbb731..81cfdfb4f42 100644 --- a/torch_xla/csrc/aten_autograd_ops.cpp +++ b/torch_xla/csrc/aten_autograd_ops.cpp @@ -253,17 +253,5 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self, return grad; } -// TORCH_LIBRARY(xla, m) { -// m.def( -// "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " -// "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", -// torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward))); - -// m.def( -// "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " -// "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " -// "-> Tensor", -// torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward))); -// } } // namespace aten_autograd_ops } // namespace torch_xla diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 8106550c7f3..35385256b1b 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -653,72 +653,75 @@ std::string GetPyTypeString(py::handle obj) { } void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) { - TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto new_sharding_spec = std::make_shared( - sharding, MakeShapeWithDeviceLayout( - xtensor->shape(), - static_cast(xtensor->GetDevice().type()))); - - // For Non DeviceData 
IR values, we directly attach the sharding spec - // to the xtensor. - const DeviceData* device_data_node = nullptr; - if (xtensor->CurrentIrValue()) { - device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); - if (!device_data_node) { - tensor_methods::custom_sharding_(xtensor, new_sharding_spec); - return; - } + TORCH_LAZY_COUNTER("XlaMarkSharding", 1); + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + auto new_sharding_spec = std::make_shared( + sharding, MakeShapeWithDeviceLayout( + xtensor->shape(), + static_cast(xtensor->GetDevice().type()))); + + // For Non DeviceData IR values, we directly attach the sharding spec + // to the xtensor. + const DeviceData* device_data_node = nullptr; + if (xtensor->CurrentIrValue()) { + device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); + if (!device_data_node) { + tensor_methods::custom_sharding_(xtensor, new_sharding_spec); + return; } + } - // For data, we need to deal with the data transfers between - // host and device. - at::Tensor cpu_tensor; - if (xtensor->CurrentTensorData().has_value()) { - TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); - // When virtual device is enabled for SPMD, we defer the initial - // data transfer to the device and retain the original data on the - // host, until the sharded data transfer. - cpu_tensor = xtensor->CurrentTensorData().value(); - } else { - // A new input tensor is not expected to be sharded. But sometimes, - // the same input is called for sharding annotation over multiple steps, - // in which case we can skip if it's the same sharding; however, if it's - // the same input with a different sharding then we block & ask the user - // to clear the existing sharding first. - auto current_sharding_spec = xtensor->sharding_spec(); - if (current_sharding_spec && (current_sharding_spec->sharding.type() != - xla::OpSharding::REPLICATED)) { - XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, - *current_sharding_spec)) - << "Existing annotation must be cleared first."; - return; - } - - // If the at::Tensor data is not present, we need to re-download the - // tensor from the physical device to CPU. In that case, the value - // must be present on the backend device. - XLA_CHECK((xtensor->CurrentDataHandle() && - xtensor->CurrentDataHandle()->HasValue()) || - device_data_node != nullptr) - << "Cannot shard tensor. Data does not present on any device."; - std::vector xla_tensors{xtensor}; - cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; + // For data, we need to deal with the data transfers between + // host and device. + at::Tensor cpu_tensor; + if (xtensor->CurrentTensorData().has_value()) { + TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); + // When virtual device is enabled for SPMD, we defer the initial + // data transfer to the device and retain the original data on the + // host, until the sharded data transfer. + cpu_tensor = xtensor->CurrentTensorData().value(); + } else { + // A new input tensor is not expected to be sharded. But sometimes, + // the same input is called for sharding annotation over multiple steps, + // in which case we can skip if it's the same sharding; however, if it's + // the same input with a different sharding then we block & ask the user + // to clear the existing sharding first. 
+ auto current_sharding_spec = xtensor->sharding_spec(); + if (current_sharding_spec && (current_sharding_spec->sharding.type() != + xla::OpSharding::REPLICATED)) { + XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, + *current_sharding_spec)) + << "Existing annotation must be cleared first."; + return; } - auto xla_data = CreateTensorsData( - std::vector{cpu_tensor}, - std::vector{new_sharding_spec}, - std::vector{GetVirtualDevice().toString()})[0]; - xtensor->SetXlaData(xla_data); - xtensor->SetShardingSpec(*new_sharding_spec); - // Register sharded tensor data. - XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); + // If the at::Tensor data is not present, we need to re-download the + // tensor from the physical device to CPU. In that case, the value + // must be present on the backend device. + XLA_CHECK((xtensor->CurrentDataHandle() && + xtensor->CurrentDataHandle()->HasValue()) || + device_data_node != nullptr) + << "Cannot shard tensor. Data does not present on any device."; + std::vector xla_tensors{xtensor}; + cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; + } + auto xla_data = CreateTensorsData( + std::vector{cpu_tensor}, + std::vector{new_sharding_spec}, + std::vector{GetVirtualDevice().toString()})[0]; + xtensor->SetXlaData(xla_data); + xtensor->SetShardingSpec(*new_sharding_spec); + + // Register sharded tensor data. + XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); } -void xla_mark_sharding_dynamo_custom_op(const at::Tensor& input, c10::List tile_assignment, c10::List group_assignment, c10::List replication_groups, int64_t sharding_type) { +void xla_mark_sharding_dynamo_custom_op( + const at::Tensor& input, c10::List tile_assignment, + c10::List group_assignment, + c10::List replication_groups, int64_t sharding_type) { py::list tile_assignment_py = py::list(); for (int i = 0; i < tile_assignment.size(); i++) { py::list pylist = py::list(); @@ -747,28 +750,36 @@ void xla_mark_sharding_dynamo_custom_op(const at::Tensor& input, c10::List Tensor", - torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); m.def( "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " "-> Tensor", - torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); m.def( - "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] tile_assignment, int[][] group_assignment, int[][] replication_groups, int sharding_type) -> ()", - torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(xla_mark_sharding_dynamo_custom_op))); + "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] " + "tile_assignment, int[][] group_assignment, int[][] replication_groups, " + "int sharding_type) -> ()", + torch::dispatch(c10::DispatchKey::XLA, + TORCH_FN(xla_mark_sharding_dynamo_custom_op))); } std::vector check_materialization_helper( @@ -1679,30 +1690,38 @@ void InitXlaModuleBindings(py::module m) { tile_assignment, group_assignment, replication_groups, ShardingUtil::ShardingType(sharding_type)); })); - m.def("_xla_mark_sharding", [](const at::Tensor& input, - xla::OpSharding sharding) { - xla_mark_sharding(input, sharding); - }); + m.def("_xla_mark_sharding", + [](const at::Tensor& input, 
xla::OpSharding sharding) { + xla_mark_sharding(input, sharding); + }); m.def("_xla_mark_sharding_dynamo_custom_op", [](const at::Tensor& input, const py::list& tile_assignment, - const py::list& group_assignment, - const py::list& replication_groups, int sharding_type) { - c10::List time_assignment_list = c10::List(); + const py::list& group_assignment, const py::list& replication_groups, + int sharding_type) { + c10::List time_assignment_list = + c10::List(); for (auto t : tile_assignment) { - time_assignment_list.push_back(at::IntArrayRef(t.cast>())); + time_assignment_list.push_back( + at::IntArrayRef(t.cast>())); } - c10::List group_assignment_list = c10::List(); + c10::List group_assignment_list = + c10::List(); for (auto t : group_assignment) { - group_assignment_list.push_back(at::IntArrayRef(t.cast>())); + group_assignment_list.push_back( + at::IntArrayRef(t.cast>())); } - c10::List replication_groups_list = c10::List(); + c10::List replication_groups_list = + c10::List(); for (auto t : replication_groups) { - replication_groups_list.push_back(at::IntArrayRef(t.cast>())); + replication_groups_list.push_back( + at::IntArrayRef(t.cast>())); } - xla_mark_sharding_dynamo_custom_op(input, time_assignment_list, group_assignment_list, replication_groups_list, sharding_type); + xla_mark_sharding_dynamo_custom_op( + input, time_assignment_list, group_assignment_list, + replication_groups_list, sharding_type); }); m.def("_xla_clear_sharding", [](const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp index 7e058573bd0..fa106e5849c 100644 --- a/torch_xla/csrc/ops/xla_ops.cpp +++ b/torch_xla/csrc/ops/xla_ops.cpp @@ -30,6 +30,5 @@ const OpKindWrapper xla_tensor_data("xla::tensor_data"); const OpKindWrapper xla_unselect("xla::unselect"); const OpKindWrapper xla_update_slice("xla::update_slice"); const OpKindWrapper xla_custom_sharding("xla::custom_sharding"); -const OpKindWrapper xla_custom_mark_sharding("xla::custom_mark_sharding"); } // namespace torch_xla diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h index b5dfa31ca4e..fa8082b978a 100644 --- a/torch_xla/csrc/ops/xla_ops.h +++ b/torch_xla/csrc/ops/xla_ops.h @@ -55,7 +55,6 @@ extern const OpKindWrapper xla_tensor_data; extern const OpKindWrapper xla_unselect; extern const OpKindWrapper xla_update_slice; extern const OpKindWrapper xla_custom_sharding; -extern const OpKindWrapper xla_custom_mark_sharding; } // namespace torch_xla diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 5a0d8e50d98..fa54741190d 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -140,7 +140,6 @@ #include "torch_xla/csrc/tensor_ops.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/xla_graph_executor.h" -#include "torch_xla/csrc/xla_sharding_util.h" #include "xla/literal_util.h" namespace torch_xla { diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 567bd12ad5d..374e7569ca0 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1228,11 +1228,4 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) { {input}, ShapeHelper::ShapeOfXlaOp(input)); } -xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, - const xla::XlaOp& input, - const xla::XlaOp& sharding) { - return xla::CustomCall(input.builder(), /*call_target_name=*/"MarkSharding", - {input}, 
ShapeHelper::ShapeOfXlaOp(input)); -} - } // namespace torch_xla diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 21970d0af23..252bbe5e31c 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -150,10 +150,6 @@ xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p, xla::XlaOp BuildCustomSharding(const xla::XlaOp& input); -xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device, - const xla::XlaOp& input, - const xla::XlaOp& sharding); - } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_ diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 811823c1415..84872082b19 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -82,7 +82,8 @@ def get_axis_name_idx(self, name: str) -> int: @functools.lru_cache(maxsize=None) def get_op_sharding(self, - partition_spec: Tuple, flatten = False) -> torch_xla._XLAC.OpSharding: + partition_spec: Tuple, + flatten=False) -> torch_xla._XLAC.OpSharding: """ Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. @@ -104,14 +105,15 @@ def get_op_sharding(self, replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} group_assignment, replication_groups = _get_group_assignment( sharding_type, tile_assignment, len(partition_spec), replicate_dims) - + # If flatten = True, return the flattened version of OpSharding if flatten: - return (tile_assignment.tolist(), group_assignment, replication_groups, int(sharding_type)) + return (tile_assignment.tolist(), group_assignment, replication_groups, + int(sharding_type)) else: return torch_xla._XLAC.OpSharding(tile_assignment.tolist(), - group_assignment, replication_groups, - int(sharding_type)) + group_assignment, replication_groups, + int(sharding_type)) # HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4 @@ -454,9 +456,10 @@ def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple): @xr.requires_pjrt -def mark_sharding( - t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, - partition_spec: Tuple[Union[Tuple, int, str, None]], dynamo_custom_op: bool = False) -> XLAShardedTensor: +def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], + mesh: Mesh, + partition_spec: Tuple[Union[Tuple, int, str, None]], + dynamo_custom_op: bool = False) -> XLAShardedTensor: """ Annotates the tensor provided with XLA partition spec. Internally, it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass. @@ -505,41 +508,28 @@ def mark_sharding( assert len(t.shape) == len(partition_spec), \ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." 
- op_sharding = mesh.get_op_sharding(partition_spec) - - if isinstance(t, XLAShardedTensor): - torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding) - return t - torch_xla._XLAC._xla_mark_sharding(t, op_sharding) - return XLAShardedTensor(t) - - -@xr.requires_pjrt -def mark_sharding_dynamo_custom_op( - t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, - partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor: - """ - Same functionality as `mark_sharding` above, except this variant uses the custom - mark_sharding op in torch_xla._XLAC to allow dynamo to recognize and trace it. - """ - num_devices = xr.global_runtime_device_count() - assert num_devices > 0, "This requires XLA supported device(s)." - assert mesh.size() == num_devices, \ - f"{mesh.mesh_shape} is not mappable over {num_devices} devices." - # We only allow fully specified `partition_spec` to be applicable, as opposed - # to filling in the unspecified replicated dims. Fully specified `partiion_spec` - # should be of the same rank as `t`. This is to support partial replication - # where the group assignment may vary with different input ranks. - assert len(t.shape) == len(partition_spec), \ - f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." - - tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(partition_spec, flatten = True) - - if isinstance(t, XLAShardedTensor): - torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type) - return t - torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment, group_assignment, replication_groups, sharding_type) - return XLAShardedTensor(t) + if dynamo_custom_op: + tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding( + partition_spec, flatten=True) + + if isinstance(t, XLAShardedTensor): + torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op( + t.global_tensor, tile_assignment, group_assignment, + replication_groups, sharding_type) + return t + torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment, + group_assignment, + replication_groups, + sharding_type) + return XLAShardedTensor(t) + else: + op_sharding = mesh.get_op_sharding(partition_spec) + + if isinstance(t, XLAShardedTensor): + torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding) + return t + torch_xla._XLAC._xla_mark_sharding(t, op_sharding) + return XLAShardedTensor(t) def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: From a4318a6c1c009ef108b1cf3ec1a6c587deb0795f Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Tue, 7 Nov 2023 22:55:16 +0000 Subject: [PATCH 12/14] Address comments -- fix typos and variable names --- test/spmd/test_dynamo_spmd.py | 4 +- torch_xla/csrc/init_python_bindings.cpp | 138 +----------------------- torch_xla/csrc/xla_sharding_util.cpp | 135 +++++++++++++++++++++++ torch_xla/csrc/xla_sharding_util.h | 8 ++ torch_xla/experimental/xla_sharding.py | 27 ++--- 5 files changed, 163 insertions(+), 149 deletions(-) diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index a49d41dfa59..41bf3a7d178 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -30,7 +30,7 @@ def __init__(self, mesh=None): def forward(self, x): if self.mesh and 'xla' in str(self.fc2.weight.device): xs.mark_sharding( - self.fc2.weight, self.mesh, (1, 0), dynamo_custom_op=True) + 
self.fc2.weight, self.mesh, (1, 0), use_dynamo_custom_op=True) y = self.relu(self.fc1(x)) z = self.fc2(y) return self.fc3(z) @@ -67,7 +67,7 @@ def test_dynamo_spmd_basic_with_custom_mark_sharding_op(self): xs.mark_sharding( linear.fc2.weight, self._get_mesh((1, self.n_devices)), (1, 0), - dynamo_custom_op=True) + use_dynamo_custom_op=True) xla_res = linear(xla_x) xm.mark_step() diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 35385256b1b..622c834aed5 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -652,136 +652,6 @@ std::string GetPyTypeString(py::handle obj) { return type; } -void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) { - TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto new_sharding_spec = std::make_shared( - sharding, MakeShapeWithDeviceLayout( - xtensor->shape(), - static_cast(xtensor->GetDevice().type()))); - - // For Non DeviceData IR values, we directly attach the sharding spec - // to the xtensor. - const DeviceData* device_data_node = nullptr; - if (xtensor->CurrentIrValue()) { - device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); - if (!device_data_node) { - tensor_methods::custom_sharding_(xtensor, new_sharding_spec); - return; - } - } - - // For data, we need to deal with the data transfers between - // host and device. - at::Tensor cpu_tensor; - if (xtensor->CurrentTensorData().has_value()) { - TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); - // When virtual device is enabled for SPMD, we defer the initial - // data transfer to the device and retain the original data on the - // host, until the sharded data transfer. - cpu_tensor = xtensor->CurrentTensorData().value(); - } else { - // A new input tensor is not expected to be sharded. But sometimes, - // the same input is called for sharding annotation over multiple steps, - // in which case we can skip if it's the same sharding; however, if it's - // the same input with a different sharding then we block & ask the user - // to clear the existing sharding first. - auto current_sharding_spec = xtensor->sharding_spec(); - if (current_sharding_spec && (current_sharding_spec->sharding.type() != - xla::OpSharding::REPLICATED)) { - XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, - *current_sharding_spec)) - << "Existing annotation must be cleared first."; - return; - } - - // If the at::Tensor data is not present, we need to re-download the - // tensor from the physical device to CPU. In that case, the value - // must be present on the backend device. - XLA_CHECK((xtensor->CurrentDataHandle() && - xtensor->CurrentDataHandle()->HasValue()) || - device_data_node != nullptr) - << "Cannot shard tensor. Data does not present on any device."; - std::vector xla_tensors{xtensor}; - cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; - } - auto xla_data = CreateTensorsData( - std::vector{cpu_tensor}, - std::vector{new_sharding_spec}, - std::vector{GetVirtualDevice().toString()})[0]; - xtensor->SetXlaData(xla_data); - xtensor->SetShardingSpec(*new_sharding_spec); - - // Register sharded tensor data. 
- XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); -} - -void xla_mark_sharding_dynamo_custom_op( - const at::Tensor& input, c10::List tile_assignment, - c10::List group_assignment, - c10::List replication_groups, int64_t sharding_type) { - py::list tile_assignment_py = py::list(); - for (int i = 0; i < tile_assignment.size(); i++) { - py::list pylist = py::list(); - for (int64_t t : tile_assignment[i].get().toIntList()) { - pylist.append(t); - } - tile_assignment_py.append(pylist); - } - - py::list group_assignment_py = py::list(); - for (int i = 0; i < group_assignment.size(); i++) { - py::list pylist = py::list(); - for (int64_t t : group_assignment[i].get().toIntList()) { - pylist.append(t); - } - group_assignment_py.append(pylist); - } - - py::list replication_groups_py = py::list(); - for (int i = 0; i < replication_groups.size(); i++) { - py::list pylist = py::list(); - for (int64_t t : replication_groups[i].get().toIntList()) { - pylist.append(t); - } - replication_groups_py.append(pylist); - } - - xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding( - tile_assignment_py, group_assignment_py, replication_groups_py, - ShardingUtil::ShardingType(sharding_type)); - - xla_mark_sharding(input, op_sharding); -} - -// Macro for defining a function that will be run at static initialization time -// to define a library of operators in the namespace. Used to define a new set -// of custom operators that do not already exist in PyTorch. -TORCH_LIBRARY(xla, m) { - m.def( - "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " - "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", - torch::dispatch( - c10::DispatchKey::XLA, - TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); - - m.def( - "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " - "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " - "-> Tensor", - torch::dispatch( - c10::DispatchKey::XLA, - TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); - m.def( - "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] " - "tile_assignment, int[][] group_assignment, int[][] replication_groups, " - "int sharding_type) -> ()", - torch::dispatch(c10::DispatchKey::XLA, - TORCH_FN(xla_mark_sharding_dynamo_custom_op))); -} - std::vector check_materialization_helper( const std::vector& xtensors) { std::vector need_materialization; @@ -1692,16 +1562,16 @@ void InitXlaModuleBindings(py::module m) { })); m.def("_xla_mark_sharding", [](const at::Tensor& input, xla::OpSharding sharding) { - xla_mark_sharding(input, sharding); + ShardingUtil::xla_mark_sharding(input, sharding); }); m.def("_xla_mark_sharding_dynamo_custom_op", [](const at::Tensor& input, const py::list& tile_assignment, const py::list& group_assignment, const py::list& replication_groups, int sharding_type) { - c10::List time_assignment_list = + c10::List tile_assignment_list = c10::List(); for (auto t : tile_assignment) { - time_assignment_list.push_back( + tile_assignment_list.push_back( at::IntArrayRef(t.cast>())); } @@ -1720,7 +1590,7 @@ void InitXlaModuleBindings(py::module m) { } xla_mark_sharding_dynamo_custom_op( - input, time_assignment_list, group_assignment_list, + input, tile_assignment_list, group_assignment_list, replication_groups_list, sharding_type); }); m.def("_xla_clear_sharding", [](const at::Tensor& input) { diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index cde74256eee..10fe03fb174 100644 --- 
a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -6,6 +6,8 @@ #include #include "torch/csrc/lazy/core/ir_util.h" +#include "torch_xla/csrc/aten_autograd_ops.h" +#include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/device_data.h" @@ -14,7 +16,9 @@ #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/tensor.h" +#include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/xla_graph_executor.h" #include "tsl/profiler/lib/traceme.h" #include "xla/execution_options_util.h" #include "xla/hlo/ir/hlo_module.h" @@ -742,4 +746,135 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( source_tensors, GetVirtualDevice().toString(), global_shape, sharding); } +void ShardingUtil::xla_mark_sharding(const at::Tensor& input, + xla::OpSharding sharding) { + TORCH_LAZY_COUNTER("XlaMarkSharding", 1); + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + auto new_sharding_spec = std::make_shared( + sharding, MakeShapeWithDeviceLayout( + xtensor->shape(), + static_cast(xtensor->GetDevice().type()))); + + // For Non DeviceData IR values, we directly attach the sharding spec + // to the xtensor. + const DeviceData* device_data_node = nullptr; + if (xtensor->CurrentIrValue()) { + device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); + if (!device_data_node) { + tensor_methods::custom_sharding_(xtensor, new_sharding_spec); + return; + } + } + + // For data, we need to deal with the data transfers between + // host and device. + at::Tensor cpu_tensor; + if (xtensor->CurrentTensorData().has_value()) { + TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); + // When virtual device is enabled for SPMD, we defer the initial + // data transfer to the device and retain the original data on the + // host, until the sharded data transfer. + cpu_tensor = xtensor->CurrentTensorData().value(); + } else { + // A new input tensor is not expected to be sharded. But sometimes, + // the same input is called for sharding annotation over multiple steps, + // in which case we can skip if it's the same sharding; however, if it's + // the same input with a different sharding then we block & ask the user + // to clear the existing sharding first. + auto current_sharding_spec = xtensor->sharding_spec(); + if (current_sharding_spec && (current_sharding_spec->sharding.type() != + xla::OpSharding::REPLICATED)) { + XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, + *current_sharding_spec)) + << "Existing annotation must be cleared first."; + return; + } + + // If the at::Tensor data is not present, we need to re-download the + // tensor from the physical device to CPU. In that case, the value + // must be present on the backend device. + XLA_CHECK((xtensor->CurrentDataHandle() && + xtensor->CurrentDataHandle()->HasValue()) || + device_data_node != nullptr) + << "Cannot shard tensor. 
Data does not present on any device."; + std::vector xla_tensors{xtensor}; + cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; + } + auto xla_data = CreateTensorsData( + std::vector{cpu_tensor}, + std::vector{new_sharding_spec}, + std::vector{GetVirtualDevice().toString()})[0]; + xtensor->SetXlaData(xla_data); + xtensor->SetShardingSpec(*new_sharding_spec); + + // Register sharded tensor data. + XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); +} + +void xla_mark_sharding_dynamo_custom_op( + const at::Tensor& input, c10::List tile_assignment, + c10::List group_assignment, + c10::List replication_groups, int64_t sharding_type) { + py::list tile_assignment_py = py::list(); + for (int i = 0; i < tile_assignment.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : tile_assignment[i].get().toIntList()) { + pylist.append(t); + } + tile_assignment_py.append(pylist); + } + + py::list group_assignment_py = py::list(); + for (int i = 0; i < group_assignment.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : group_assignment[i].get().toIntList()) { + pylist.append(t); + } + group_assignment_py.append(pylist); + } + + py::list replication_groups_py = py::list(); + for (int i = 0; i < replication_groups.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : replication_groups[i].get().toIntList()) { + pylist.append(t); + } + replication_groups_py.append(pylist); + } + + xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding( + tile_assignment_py, group_assignment_py, replication_groups_py, + ShardingUtil::ShardingType(sharding_type)); + + ShardingUtil::xla_mark_sharding(input, op_sharding); +} + +// Macro for defining a function that will be run at static initialization time +// to define a library of operators in the namespace. Used to define a new set +// of custom operators that do not already exist in PyTorch. 
+TORCH_LIBRARY(xla, m) { + m.def( + "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " + "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); + + m.def( + "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " + "-> Tensor", + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); + m.def( + "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] " + "tile_assignment, int[][] group_assignment, int[][] replication_groups, " + "int sharding_type) -> ()", + torch::dispatch(c10::DispatchKey::XLA, + TORCH_FN(torch_xla::xla_mark_sharding_dynamo_custom_op))); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 32060c7fc09..3e600be6871 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -150,8 +150,16 @@ class ShardingUtil { const std::vector& shards, const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec); + + static void xla_mark_sharding(const at::Tensor& input, + xla::OpSharding sharding); }; +void xla_mark_sharding_dynamo_custom_op( + const at::Tensor& input, c10::List tile_assignment, + c10::List group_assignment, + c10::List replication_groups, int64_t sharding_type); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_XLA_SHARDING_UTIL_H_ diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 84872082b19..f6bec1acdb7 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -83,7 +83,7 @@ def get_axis_name_idx(self, name: str) -> int: @functools.lru_cache(maxsize=None) def get_op_sharding(self, partition_spec: Tuple, - flatten=False) -> torch_xla._XLAC.OpSharding: + flatten_opsharding = False) -> torch_xla._XLAC.OpSharding: """ Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. @@ -106,8 +106,8 @@ def get_op_sharding(self, group_assignment, replication_groups = _get_group_assignment( sharding_type, tile_assignment, len(partition_spec), replicate_dims) - # If flatten = True, return the flattened version of OpSharding - if flatten: + # If flatten_opsharding = True, return the flattened version of OpSharding + if flatten_opsharding: return (tile_assignment.tolist(), group_assignment, replication_groups, int(sharding_type)) else: @@ -459,7 +459,7 @@ def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple): def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Union[Tuple, int, str, None]], - dynamo_custom_op: bool = False) -> XLAShardedTensor: + use_dynamo_custom_op: bool = False) -> XLAShardedTensor: """ Annotates the tensor provided with XLA partition spec. Internally, it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass. @@ -508,28 +508,29 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], assert len(t.shape) == len(partition_spec), \ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." 
- if dynamo_custom_op: + if use_dynamo_custom_op: tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding( - partition_spec, flatten=True) + partition_spec, flatten_opsharding=True) if isinstance(t, XLAShardedTensor): torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op( t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type) return t - torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment, - group_assignment, - replication_groups, - sharding_type) - return XLAShardedTensor(t) + else: + torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op( + t, tile_assignment, group_assignment, replication_groups, + sharding_type) + return XLAShardedTensor(t) else: op_sharding = mesh.get_op_sharding(partition_spec) if isinstance(t, XLAShardedTensor): torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding) return t - torch_xla._XLAC._xla_mark_sharding(t, op_sharding) - return XLAShardedTensor(t) + else: + torch_xla._XLAC._xla_mark_sharding(t, op_sharding) + return XLAShardedTensor(t) def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: From e35ca6413f71467f8f32d33c60163e55ec8ceffc Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Wed, 8 Nov 2023 20:03:09 +0000 Subject: [PATCH 13/14] Run linter --- torch_xla/experimental/xla_sharding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index f6bec1acdb7..1b12513fc2e 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -83,7 +83,7 @@ def get_axis_name_idx(self, name: str) -> int: @functools.lru_cache(maxsize=None) def get_op_sharding(self, partition_spec: Tuple, - flatten_opsharding = False) -> torch_xla._XLAC.OpSharding: + flatten_opsharding=False) -> torch_xla._XLAC.OpSharding: """ Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. From ae721ae2dbfb9f1c2db148b3251a81c42478fc58 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Thu, 9 Nov 2023 05:32:28 +0000 Subject: [PATCH 14/14] Add metric assertions to unit tests --- test/spmd/test_dynamo_spmd.py | 43 ++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 41bf3a7d178..2874d5783bc 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -59,22 +59,6 @@ def test_dynamo_spmd_basic(self): # TODO(JackCaoG): add counter checks after ExecuteReplicated also creates # a ExecuteMetric. 
- def test_dynamo_spmd_basic_with_custom_mark_sharding_op(self): - device = xm.xla_device() - linear = SimpleLinear().to(device) - linear.eval() - xla_x = torch.randn(1, 128, device=device) - xs.mark_sharding( - linear.fc2.weight, - self._get_mesh((1, self.n_devices)), (1, 0), - use_dynamo_custom_op=True) - xla_res = linear(xla_x) - xm.mark_step() - - dynamo_linear = torch.compile(linear, backend="openxla") - dynamo_res = dynamo_linear(xla_x) - torch.allclose(xla_res.cpu(), dynamo_res.cpu()) - def test_dynamo_spmd_output_sharding_spec(self): device = xm.xla_device() linear = SimpleLinear().to(device) @@ -193,7 +177,29 @@ def test_dynamo_input_sharding_threashold(self): else: del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] + def test_dynamo_spmd_mark_sharding_outside_of_compile(self): + device = xm.xla_device() + linear = SimpleLinear().to(device) + linear.eval() + xla_x = torch.randn(1, 128, device=device) + xs.mark_sharding( + linear.fc2.weight, + self._get_mesh((1, self.n_devices)), (1, 0), + use_dynamo_custom_op=True) + xla_res = linear(xla_x) + xm.mark_step() + + dynamo_linear = torch.compile(linear, backend="openxla") + dynamo_res = dynamo_linear(xla_x) + torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + + # Ensure that another run with same input does not trigger additional compilation + compile_count = met.metric_data('CompileTime')[0] + dynamo_res = dynamo_linear(xla_x) + self.assertEqual(met.metric_data('CompileTime')[0], compile_count) + def test_mark_sharding_inside_compile(self): + met.clear_counters() device = xm.xla_device() mesh = self._get_mesh((1, self.n_devices)) @@ -210,6 +216,11 @@ def test_mark_sharding_inside_compile(self): dynamo_res = dynamo_linear(xla_x) torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + # Ensure that another run with same input does not trigger additional compilation + compile_count = met.metric_data('CompileTime')[0] + dynamo_res = dynamo_linear(xla_x) + self.assertEqual(met.metric_data('CompileTime')[0], compile_count) + if __name__ == '__main__': test = unittest.main()
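
Taken together, the series makes a mark_sharding call issued inside a compiled region traceable by dynamo through the xla_mark_sharding_dynamo_custom_op custom op. A minimal end-to-end sketch of the intended usage, mirroring test_mark_sharding_inside_compile above (device count, mesh shape, and tensor sizes are illustrative, and the Mesh constructor arguments are assumed from the library's usual (device_ids, mesh_shape) form):

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

xr.use_spmd()  # the sharding annotation path requires the SPMD virtual device
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))


class TinyLinear(nn.Module):

  def __init__(self, mesh):
    super().__init__()
    self.fc = nn.Linear(128, 64)
    self.mesh = mesh

  def forward(self, x):
    # With use_dynamo_custom_op=True this routes through the dynamo custom op
    # variant of mark_sharding so torch.compile can recognize and trace it.
    xs.mark_sharding(self.fc.weight, self.mesh, (1, 0), use_dynamo_custom_op=True)
    return self.fc(x)


model = TinyLinear(mesh).to(device)
x = torch.randn(8, 128, device=device)

eager_res = model(x)
xm.mark_step()

dynamo_model = torch.compile(model, backend="openxla")
dynamo_res = dynamo_model(x)
assert torch.allclose(eager_res.cpu(), dynamo_res.cpu())
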