diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 35385256b1ba..622c834aed54 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -652,136 +652,6 @@ std::string GetPyTypeString(py::handle obj) { return type; } -void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) { - TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto new_sharding_spec = std::make_shared( - sharding, MakeShapeWithDeviceLayout( - xtensor->shape(), - static_cast(xtensor->GetDevice().type()))); - - // For Non DeviceData IR values, we directly attach the sharding spec - // to the xtensor. - const DeviceData* device_data_node = nullptr; - if (xtensor->CurrentIrValue()) { - device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); - if (!device_data_node) { - tensor_methods::custom_sharding_(xtensor, new_sharding_spec); - return; - } - } - - // For data, we need to deal with the data transfers between - // host and device. - at::Tensor cpu_tensor; - if (xtensor->CurrentTensorData().has_value()) { - TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); - // When virtual device is enabled for SPMD, we defer the initial - // data transfer to the device and retain the original data on the - // host, until the sharded data transfer. - cpu_tensor = xtensor->CurrentTensorData().value(); - } else { - // A new input tensor is not expected to be sharded. But sometimes, - // the same input is called for sharding annotation over multiple steps, - // in which case we can skip if it's the same sharding; however, if it's - // the same input with a different sharding then we block & ask the user - // to clear the existing sharding first. - auto current_sharding_spec = xtensor->sharding_spec(); - if (current_sharding_spec && (current_sharding_spec->sharding.type() != - xla::OpSharding::REPLICATED)) { - XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, - *current_sharding_spec)) - << "Existing annotation must be cleared first."; - return; - } - - // If the at::Tensor data is not present, we need to re-download the - // tensor from the physical device to CPU. In that case, the value - // must be present on the backend device. - XLA_CHECK((xtensor->CurrentDataHandle() && - xtensor->CurrentDataHandle()->HasValue()) || - device_data_node != nullptr) - << "Cannot shard tensor. Data does not present on any device."; - std::vector xla_tensors{xtensor}; - cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; - } - auto xla_data = CreateTensorsData( - std::vector{cpu_tensor}, - std::vector{new_sharding_spec}, - std::vector{GetVirtualDevice().toString()})[0]; - xtensor->SetXlaData(xla_data); - xtensor->SetShardingSpec(*new_sharding_spec); - - // Register sharded tensor data. - XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); -} - -void xla_mark_sharding_dynamo_custom_op( - const at::Tensor& input, c10::List tile_assignment, - c10::List group_assignment, - c10::List replication_groups, int64_t sharding_type) { - py::list tile_assignment_py = py::list(); - for (int i = 0; i < tile_assignment.size(); i++) { - py::list pylist = py::list(); - for (int64_t t : tile_assignment[i].get().toIntList()) { - pylist.append(t); - } - tile_assignment_py.append(pylist); - } - - py::list group_assignment_py = py::list(); - for (int i = 0; i < group_assignment.size(); i++) { - py::list pylist = py::list(); - for (int64_t t : group_assignment[i].get().toIntList()) { - pylist.append(t); - } - group_assignment_py.append(pylist); - } - - py::list replication_groups_py = py::list(); - for (int i = 0; i < replication_groups.size(); i++) { - py::list pylist = py::list(); - for (int64_t t : replication_groups[i].get().toIntList()) { - pylist.append(t); - } - replication_groups_py.append(pylist); - } - - xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding( - tile_assignment_py, group_assignment_py, replication_groups_py, - ShardingUtil::ShardingType(sharding_type)); - - xla_mark_sharding(input, op_sharding); -} - -// Macro for defining a function that will be run at static initialization time -// to define a library of operators in the namespace. Used to define a new set -// of custom operators that do not already exist in PyTorch. -TORCH_LIBRARY(xla, m) { - m.def( - "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " - "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", - torch::dispatch( - c10::DispatchKey::XLA, - TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); - - m.def( - "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " - "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " - "-> Tensor", - torch::dispatch( - c10::DispatchKey::XLA, - TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); - m.def( - "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] " - "tile_assignment, int[][] group_assignment, int[][] replication_groups, " - "int sharding_type) -> ()", - torch::dispatch(c10::DispatchKey::XLA, - TORCH_FN(xla_mark_sharding_dynamo_custom_op))); -} - std::vector check_materialization_helper( const std::vector& xtensors) { std::vector need_materialization; @@ -1692,16 +1562,16 @@ void InitXlaModuleBindings(py::module m) { })); m.def("_xla_mark_sharding", [](const at::Tensor& input, xla::OpSharding sharding) { - xla_mark_sharding(input, sharding); + ShardingUtil::xla_mark_sharding(input, sharding); }); m.def("_xla_mark_sharding_dynamo_custom_op", [](const at::Tensor& input, const py::list& tile_assignment, const py::list& group_assignment, const py::list& replication_groups, int sharding_type) { - c10::List time_assignment_list = + c10::List tile_assignment_list = c10::List(); for (auto t : tile_assignment) { - time_assignment_list.push_back( + tile_assignment_list.push_back( at::IntArrayRef(t.cast>())); } @@ -1720,7 +1590,7 @@ void InitXlaModuleBindings(py::module m) { } xla_mark_sharding_dynamo_custom_op( - input, time_assignment_list, group_assignment_list, + input, tile_assignment_list, group_assignment_list, replication_groups_list, sharding_type); }); m.def("_xla_clear_sharding", [](const at::Tensor& input) { diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index cde74256eeee..10fe03fb1749 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -6,6 +6,8 @@ #include #include "torch/csrc/lazy/core/ir_util.h" +#include "torch_xla/csrc/aten_autograd_ops.h" +#include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/device_data.h" @@ -14,7 +16,9 @@ #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/tensor.h" +#include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/xla_graph_executor.h" #include "tsl/profiler/lib/traceme.h" #include "xla/execution_options_util.h" #include "xla/hlo/ir/hlo_module.h" @@ -742,4 +746,135 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( source_tensors, GetVirtualDevice().toString(), global_shape, sharding); } +void ShardingUtil::xla_mark_sharding(const at::Tensor& input, + xla::OpSharding sharding) { + TORCH_LAZY_COUNTER("XlaMarkSharding", 1); + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + auto new_sharding_spec = std::make_shared( + sharding, MakeShapeWithDeviceLayout( + xtensor->shape(), + static_cast(xtensor->GetDevice().type()))); + + // For Non DeviceData IR values, we directly attach the sharding spec + // to the xtensor. + const DeviceData* device_data_node = nullptr; + if (xtensor->CurrentIrValue()) { + device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); + if (!device_data_node) { + tensor_methods::custom_sharding_(xtensor, new_sharding_spec); + return; + } + } + + // For data, we need to deal with the data transfers between + // host and device. + at::Tensor cpu_tensor; + if (xtensor->CurrentTensorData().has_value()) { + TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); + // When virtual device is enabled for SPMD, we defer the initial + // data transfer to the device and retain the original data on the + // host, until the sharded data transfer. + cpu_tensor = xtensor->CurrentTensorData().value(); + } else { + // A new input tensor is not expected to be sharded. But sometimes, + // the same input is called for sharding annotation over multiple steps, + // in which case we can skip if it's the same sharding; however, if it's + // the same input with a different sharding then we block & ask the user + // to clear the existing sharding first. + auto current_sharding_spec = xtensor->sharding_spec(); + if (current_sharding_spec && (current_sharding_spec->sharding.type() != + xla::OpSharding::REPLICATED)) { + XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, + *current_sharding_spec)) + << "Existing annotation must be cleared first."; + return; + } + + // If the at::Tensor data is not present, we need to re-download the + // tensor from the physical device to CPU. In that case, the value + // must be present on the backend device. + XLA_CHECK((xtensor->CurrentDataHandle() && + xtensor->CurrentDataHandle()->HasValue()) || + device_data_node != nullptr) + << "Cannot shard tensor. Data does not present on any device."; + std::vector xla_tensors{xtensor}; + cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; + } + auto xla_data = CreateTensorsData( + std::vector{cpu_tensor}, + std::vector{new_sharding_spec}, + std::vector{GetVirtualDevice().toString()})[0]; + xtensor->SetXlaData(xla_data); + xtensor->SetShardingSpec(*new_sharding_spec); + + // Register sharded tensor data. + XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); +} + +void xla_mark_sharding_dynamo_custom_op( + const at::Tensor& input, c10::List tile_assignment, + c10::List group_assignment, + c10::List replication_groups, int64_t sharding_type) { + py::list tile_assignment_py = py::list(); + for (int i = 0; i < tile_assignment.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : tile_assignment[i].get().toIntList()) { + pylist.append(t); + } + tile_assignment_py.append(pylist); + } + + py::list group_assignment_py = py::list(); + for (int i = 0; i < group_assignment.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : group_assignment[i].get().toIntList()) { + pylist.append(t); + } + group_assignment_py.append(pylist); + } + + py::list replication_groups_py = py::list(); + for (int i = 0; i < replication_groups.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : replication_groups[i].get().toIntList()) { + pylist.append(t); + } + replication_groups_py.append(pylist); + } + + xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding( + tile_assignment_py, group_assignment_py, replication_groups_py, + ShardingUtil::ShardingType(sharding_type)); + + ShardingUtil::xla_mark_sharding(input, op_sharding); +} + +// Macro for defining a function that will be run at static initialization time +// to define a library of operators in the namespace. Used to define a new set +// of custom operators that do not already exist in PyTorch. +TORCH_LIBRARY(xla, m) { + m.def( + "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " + "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); + + m.def( + "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " + "-> Tensor", + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); + m.def( + "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] " + "tile_assignment, int[][] group_assignment, int[][] replication_groups, " + "int sharding_type) -> ()", + torch::dispatch(c10::DispatchKey::XLA, + TORCH_FN(torch_xla::xla_mark_sharding_dynamo_custom_op))); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 32060c7fc098..3e600be68715 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -150,8 +150,16 @@ class ShardingUtil { const std::vector& shards, const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec); + + static void xla_mark_sharding(const at::Tensor& input, + xla::OpSharding sharding); }; +void xla_mark_sharding_dynamo_custom_op( + const at::Tensor& input, c10::List tile_assignment, + c10::List group_assignment, + c10::List replication_groups, int64_t sharding_type); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_XLA_SHARDING_UTIL_H_ diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 84872082b198..fd050dcd73b2 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -83,7 +83,7 @@ def get_axis_name_idx(self, name: str) -> int: @functools.lru_cache(maxsize=None) def get_op_sharding(self, partition_spec: Tuple, - flatten=False) -> torch_xla._XLAC.OpSharding: + flatten_opsharding = False) -> torch_xla._XLAC.OpSharding: """ Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. @@ -107,7 +107,7 @@ def get_op_sharding(self, sharding_type, tile_assignment, len(partition_spec), replicate_dims) # If flatten = True, return the flattened version of OpSharding - if flatten: + if flatten_opsharding: return (tile_assignment.tolist(), group_assignment, replication_groups, int(sharding_type)) else: @@ -459,7 +459,7 @@ def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple): def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Union[Tuple, int, str, None]], - dynamo_custom_op: bool = False) -> XLAShardedTensor: + use_dynamo_custom_op: bool = False) -> XLAShardedTensor: """ Annotates the tensor provided with XLA partition spec. Internally, it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass. @@ -508,7 +508,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], assert len(t.shape) == len(partition_spec), \ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." - if dynamo_custom_op: + if use_dynamo_custom_op: tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding( partition_spec, flatten=True) @@ -517,19 +517,21 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type) return t - torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment, - group_assignment, - replication_groups, - sharding_type) - return XLAShardedTensor(t) + else: + torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment, + group_assignment, + replication_groups, + sharding_type) + return XLAShardedTensor(t) else: op_sharding = mesh.get_op_sharding(partition_spec) if isinstance(t, XLAShardedTensor): torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding) return t - torch_xla._XLAC._xla_mark_sharding(t, op_sharding) - return XLAShardedTensor(t) + else: + torch_xla._XLAC._xla_mark_sharding(t, op_sharding) + return XLAShardedTensor(t) def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: