From 08d6296c91d4cd25fcd3ac8d4372e3f00693341e Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Tue, 31 Oct 2023 23:38:10 +0000 Subject: [PATCH] Clean up some code --- .torch_pin | 1 - test/spmd/test_dynamo_spmd.py | 32 +++- torch_xla/csrc/aten_autograd_ops.cpp | 22 +-- torch_xla/csrc/aten_autograd_ops.h | 11 ++ torch_xla/csrc/init_python_bindings.cpp | 191 +++++++++++++------- torch_xla/csrc/ops/custom_mark_sharding.cpp | 34 ---- torch_xla/csrc/ops/custom_mark_sharding.h | 23 --- torch_xla/csrc/tensor_methods.cpp | 71 -------- torch_xla/csrc/tensor_methods.h | 2 - torch_xla/experimental/xla_sharding.py | 19 +- 10 files changed, 188 insertions(+), 218 deletions(-) delete mode 100644 .torch_pin delete mode 100644 torch_xla/csrc/ops/custom_mark_sharding.cpp delete mode 100644 torch_xla/csrc/ops/custom_mark_sharding.h diff --git a/.torch_pin b/.torch_pin deleted file mode 100644 index 22aa007a16d..00000000000 --- a/.torch_pin +++ /dev/null @@ -1 +0,0 @@ -#112483 diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 7b835a26754..f3abaec8978 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -200,11 +200,35 @@ def fn_simple(x): device = xm.xla_device() x_xla = torch.zeros((1, 8)).to(device) xla_res = fn_simple(x_xla) - xm.mark_step() + print(xla_res) + # xm.mark_step() - dynamo_linear = torch.compile(fn_simple, backend="openxla") - dynamo_res = dynamo_linear(x_xla) - torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + # dynamo_linear = torch.compile(fn_simple, backend="openxla") + # dynamo_res = dynamo_linear(x_xla) + # torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + + # TODO (@wonjoo) Remove this test, this is just for debugging + def test_wonjoo(self): + + def fn_simple(x): + print(f'x inside fn_simple before: {x}') + torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(x_xla, [0], [0], [0], 0) + print(f'x inside fn_simple after: {x}') + return x + + device = xm.xla_device() + + x_xla = torch.zeros((1, 8)).to(device) + + # print(torch.ops.xla.add) + print(torch.ops.xla.max_pool2d_forward) + print(torch.ops.xla.xla_mark_sharding_dynamo_custom_op) + print(dir(torch.ops.xla.xla_mark_sharding_dynamo_custom_op)) + # print(f'x_xla before: {x_xla}') + + # dynamo_fn = torch.compile(fn_simple, backend="openxla") + # dynamo_res = dynamo_fn(x_xla) + # print(f'dynamo_res: {dynamo_res}') if __name__ == '__main__': diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp index 08c6b27f92c..9d8edbbb731 100644 --- a/torch_xla/csrc/aten_autograd_ops.cpp +++ b/torch_xla/csrc/aten_autograd_ops.cpp @@ -253,17 +253,17 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self, return grad; } -TORCH_LIBRARY(xla, m) { - m.def( - "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " - "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", - torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward))); +// TORCH_LIBRARY(xla, m) { +// m.def( +// "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " +// "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", +// torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward))); - m.def( - "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " - "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " - "-> Tensor", - torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward))); -} +// m.def( +// "max_pool2d_backward(Tensor grad_output, Tensor 
self, int[2] "
+//       "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
+//       "-> Tensor",
+//       torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
+// }
 
 }  // namespace aten_autograd_ops
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/aten_autograd_ops.h b/torch_xla/csrc/aten_autograd_ops.h
index be063b76620..d1cc8a98048 100644
--- a/torch_xla/csrc/aten_autograd_ops.h
+++ b/torch_xla/csrc/aten_autograd_ops.h
@@ -46,6 +46,17 @@ struct MaxPool3dAutogradFunction
                                        torch::autograd::variable_list grad_output);
 };
 
+torch::Tensor max_pool2d_forward(torch::Tensor self,
+                                 torch::IntArrayRef kernel_size,
+                                 torch::IntArrayRef stride,
+                                 torch::IntArrayRef padding,
+                                 torch::IntArrayRef dilation, bool ceil_mode);
+
+torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
+                                  torch::IntArrayRef kernel_size,
+                                  torch::IntArrayRef stride,
+                                  torch::IntArrayRef padding, bool ceil_mode);
+
 }  // namespace aten_autograd_ops
 }  // namespace torch_xla
 
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index a865b7ad7d0..b602a686e72 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -29,6 +29,7 @@
 #include "pybind11/pytypes.h"
 #include "pybind11/stl_bind.h"
 #include "torch_xla/csrc/XLANativeFunctions.h"
+#include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/helpers.h"
@@ -651,6 +652,121 @@ std::string GetPyTypeString(py::handle obj) {
   return type;
 }
 
+void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) {
+  TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
+  XLA_CHECK(UseVirtualDevice())
+      << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
+  XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
+      sharding, MakeShapeWithDeviceLayout(
+                    xtensor->shape(),
+                    static_cast<XlaDeviceType>(xtensor->GetDevice().type())));
+
+  // For Non DeviceData IR values, we directly attach the sharding spec
+  // to the xtensor.
+  const DeviceData* device_data_node = nullptr;
+  if (xtensor->CurrentIrValue()) {
+    device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
+    if (!device_data_node) {
+      tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
+      return;
+    }
+  }
+
+  // For data, we need to deal with the data transfers between
+  // host and device.
+  at::Tensor cpu_tensor;
+  if (xtensor->CurrentTensorData().has_value()) {
+    TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
+    // When virtual device is enabled for SPMD, we defer the initial
+    // data transfer to the device and retain the original data on the
+    // host, until the sharded data transfer.
+    cpu_tensor = xtensor->CurrentTensorData().value();
+  } else {
+    // A new input tensor is not expected to be sharded. But sometimes,
+    // the same input is called for sharding annotation over multiple steps,
+    // in which case we can skip if it's the same sharding; however, if it's
+    // the same input with a different sharding then we block & ask the user
+    // to clear the existing sharding first.
+    auto current_sharding_spec = xtensor->sharding_spec();
+    if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
+                                  xla::OpSharding::REPLICATED)) {
+      XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
+                                                 *current_sharding_spec))
+          << "Existing annotation must be cleared first.";
+      return;
+    }
+
+    // If the at::Tensor data is not present, we need to re-download the
+    // tensor from the physical device to CPU. In that case, the value
+    // must be present on the backend device.
+    XLA_CHECK((xtensor->CurrentDataHandle() &&
+               xtensor->CurrentDataHandle()->HasValue()) ||
+              device_data_node != nullptr)
+        << "Cannot shard tensor. Data does not present on any device.";
+    std::vector<XLATensorPtr> xla_tensors{xtensor};
+    cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
+  }
+  auto xla_data = CreateTensorsData(
+      std::vector<at::Tensor>{cpu_tensor},
+      std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
+      std::vector<std::string>{GetVirtualDevice().toString()})[0];
+  xtensor->SetXlaData(xla_data);
+  xtensor->SetShardingSpec(*new_sharding_spec);
+
+  // Register sharded tensor data.
+  XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
+}
+
+void xla_mark_sharding_dynamo_custom_op(
+    const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
+    c10::List<at::IntArrayRef> group_assignment,
+    c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
+  std::cout << "at xla_mark_sharding_dynamo_custom_op0" << std::endl;
+
+  std::cout << "input: " << input << std::endl;
+  // std::cout << "tile_assignment: " << tile_assignment << std::endl;
+  std::cout << "converting tile_assignment_py" << std::endl;
+  // const py::list& tile_assignment_py = py::cast(tile_assignment[0]);
+  const py::list& tile_assignment_py =
+      py::cast(torch::lazy::ToVector<int64_t>(tile_assignment[0]));
+
+  // std::cout << "group_assignment: " << group_assignment << std::endl;
+  std::cout << "converting group_assignment_py" << std::endl;
+  const py::list& group_assignment_py = py::cast(group_assignment);
+
+  // std::cout << "replication_groups: " << replication_groups << std::endl;
+  std::cout << "converting replication_groups_py" << std::endl;
+  const py::list& replication_groups_py = py::cast(replication_groups);
+
+  std::cout << "at xla_mark_sharding_dynamo_custom_op1" << std::endl;
+
+  const xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding(
+      tile_assignment_py, group_assignment_py, replication_groups_py,
+      ShardingUtil::ShardingType(sharding_type));
+
+  std::cout << "at xla_mark_sharding_dynamo_custom_op2" << std::endl;
+
+  xla_mark_sharding(input, op_sharding);
+
+  std::cout << "at xla_mark_sharding_dynamo_custom_op3" << std::endl;
+}
+
+// Macro for defining a function that will be run at static initialization
+// time to define a library of operators in the namespace. Used to define a
+// new set of custom operators that do not already exist in PyTorch.
+TORCH_LIBRARY(xla, m) { + m.def( + "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " + "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", + torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); + + m.def( + "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " + "-> Tensor", + torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); + m.def( + "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] tile_assignment, int[][] group_assignment, int[][] replication_groups, int sharding_type) -> ()", + torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(xla_mark_sharding_dynamo_custom_op))); +} + std::vector check_materialization_helper( const std::vector& xtensors) { std::vector need_materialization; @@ -1561,75 +1677,14 @@ void InitXlaModuleBindings(py::module m) { })); m.def("_xla_mark_sharding", [](const at::Tensor& input, xla::OpSharding sharding) { - TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto new_sharding_spec = std::make_shared( - sharding, MakeShapeWithDeviceLayout( - xtensor->shape(), - static_cast(xtensor->GetDevice().type()))); - - // For Non DeviceData IR values, we directly attach the sharding spec - // to the xtensor. - const DeviceData* device_data_node = nullptr; - if (xtensor->CurrentIrValue()) { - device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); - if (!device_data_node) { - tensor_methods::custom_sharding_(xtensor, new_sharding_spec); - return; - } - } - - // For data, we need to deal with the data transfers between - // host and device. - at::Tensor cpu_tensor; - if (xtensor->CurrentTensorData().has_value()) { - TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); - // When virtual device is enabled for SPMD, we defer the initial - // data transfer to the device and retain the original data on the - // host, until the sharded data transfer. - cpu_tensor = xtensor->CurrentTensorData().value(); - } else { - // A new input tensor is not expected to be sharded. But sometimes, - // the same input is called for sharding annotation over multiple steps, - // in which case we can skip if it's the same sharding; however, if it's - // the same input with a different sharding then we block & ask the user - // to clear the existing sharding first. - auto current_sharding_spec = xtensor->sharding_spec(); - if (current_sharding_spec && (current_sharding_spec->sharding.type() != - xla::OpSharding::REPLICATED)) { - XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, - *current_sharding_spec)) - << "Existing annotation must be cleared first."; - return; - } - - // If the at::Tensor data is not present, we need to re-download the - // tensor from the physical device to CPU. In that case, the value - // must be present on the backend device. - XLA_CHECK((xtensor->CurrentDataHandle() && - xtensor->CurrentDataHandle()->HasValue()) || - device_data_node != nullptr) - << "Cannot shard tensor. 
Data does not present on any device."; - std::vector xla_tensors{xtensor}; - cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; - } - auto xla_data = CreateTensorsData( - std::vector{cpu_tensor}, - std::vector{new_sharding_spec}, - std::vector{GetVirtualDevice().toString()})[0]; - xtensor->SetXlaData(xla_data); - xtensor->SetShardingSpec(*new_sharding_spec); - - // Register sharded tensor data. - XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); + xla_mark_sharding(input, sharding); }); - m.def("_xla_mark_sharding_dynamo_custom_op", - [](const at::Tensor& input, xla::OpSharding sharding) { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - tensor_methods::custom_mark_sharding(xtensor, sharding); - }); + // m.def("_xla_mark_sharding_dynamo_custom_op", + // [](const at::Tensor& input, xla::OpSharding sharding) { + // // xla_mark_sharding_dynamo_custom_op(input, tile_assignment, group_assignment, replication_groups, sharding_type); + // // at::IntArrayRef tile_assignment, at::IntArrayRef group_assignment, c10::List replication_groups, int64_t sharding_type + // at::IntArrayRef tile_assignment = + // }); m.def("_xla_clear_sharding", [](const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); xtensor->ClearShardingSpec(); diff --git a/torch_xla/csrc/ops/custom_mark_sharding.cpp b/torch_xla/csrc/ops/custom_mark_sharding.cpp deleted file mode 100644 index af9302590f4..00000000000 --- a/torch_xla/csrc/ops/custom_mark_sharding.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "torch_xla/csrc/ops/custom_mark_sharding.h" - -#include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/ops/xla_ops.h" -#include "torch_xla/csrc/xla_lower_util.h" - -namespace torch_xla { - -CustomMarkSharding::CustomMarkSharding(const torch::lazy::Value& input, - const torch::lazy::Value& sharding) - : XlaNode(xla_custom_mark_sharding, {input, sharding}, GetXlaShape(input), - /*num_outputs=*/1, - torch::lazy::MHash(std::string("MarkSharding"))) {} - -torch::lazy::NodePtr CustomMarkSharding::Clone( - torch::lazy::OpList operands) const { - return torch::lazy::MakeNode(operands.at(0), - operands.at(1)); -} - -XlaOpVector CustomMarkSharding::Lower(LoweringContext* loctx) const { - xla::XlaOp input = loctx->GetOutputOp(operand(0)); - xla::XlaOp sharding = loctx->GetOutputOp(operand(1)); - return ReturnOp(BuildCustomMarkSharding(loctx->device(), input, sharding), - loctx); -} - -std::string CustomMarkSharding::ToString() const { - std::stringstream ss; - ss << XlaNode::ToString() << ", MarkSharding"; - return ss.str(); -} - -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/custom_mark_sharding.h b/torch_xla/csrc/ops/custom_mark_sharding.h deleted file mode 100644 index a23323c9bbf..00000000000 --- a/torch_xla/csrc/ops/custom_mark_sharding.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef XLA_TORCH_XLA_CSRC_OPS_CUSTOM_MARK_SHARDING_H_ -#define XLA_TORCH_XLA_CSRC_OPS_CUSTOM_MARK_SHARDING_H_ - -#include "torch_xla/csrc/ir.h" - -namespace torch_xla { - -class CustomMarkSharding : public XlaNode { - public: - // Make a custom call to Sharding. 
- CustomMarkSharding(const torch::lazy::Value& input, - const torch::lazy::Value& sharding); - - torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; - - XlaOpVector Lower(LoweringContext* loctx) const override; - - std::string ToString() const override; -}; - -} // namespace torch_xla - -#endif // XLA_TORCH_XLA_CSRC_OPS_CUSTOM_MARK_SHARDING_H_ diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index f2e80980c4f..5a0d8e50d98 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -39,7 +39,6 @@ #include "torch_xla/csrc/ops/count_nonzero.h" #include "torch_xla/csrc/ops/cumprod.h" #include "torch_xla/csrc/ops/cumsum.h" -#include "torch_xla/csrc/ops/custom_mark_sharding.h" #include "torch_xla/csrc/ops/custom_sharding.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/diagonal.h" @@ -444,76 +443,6 @@ void custom_sharding_( input->SetShardingSpec(*sharding_spec); } -void custom_mark_sharding(const XLATensorPtr& input, xla::OpSharding sharding) { - // TODO (@wonjoo) Do we need this `sharding` here? - input->SetInPlaceIrValue(torch::lazy::MakeNode( - input->GetIrValue(), input->GetIrValue())); - - TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - // XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto new_sharding_spec = std::make_shared( - sharding, MakeShapeWithDeviceLayout( - input->shape(), - static_cast(input->GetDevice().type()))); - - // For Non DeviceData IR values, we directly attach the sharding spec - // to the xtensor. - const DeviceData* device_data_node = nullptr; - if (input->CurrentIrValue()) { - device_data_node = DeviceData::Cast(input->CurrentIrValue().node.get()); - if (!device_data_node) { - tensor_methods::custom_sharding_(input, new_sharding_spec); - return; - } - } - - // For data, we need to deal with the data transfers between - // host and device. - at::Tensor cpu_tensor; - if (input->CurrentTensorData().has_value()) { - TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); - // When virtual device is enabled for SPMD, we defer the initial - // data transfer to the device and retain the original data on the - // host, until the sharded data transfer. - cpu_tensor = input->CurrentTensorData().value(); - } else { - // A new input tensor is not expected to be sharded. But sometimes, - // the same input is called for sharding annotation over multiple steps, - // in which case we can skip if it's the same sharding; however, if it's - // the same input with a different sharding then we block & ask the user - // to clear the existing sharding first. - auto current_sharding_spec = input->sharding_spec(); - if (current_sharding_spec && (current_sharding_spec->sharding.type() != - xla::OpSharding::REPLICATED)) { - XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, - *current_sharding_spec)) - << "Existing annotation must be cleared first."; - return; - } - - // If the at::Tensor data is not present, we need to re-download the - // tensor from the physical device to CPU. In that case, the value - // must be present on the backend device. - XLA_CHECK((input->CurrentDataHandle() && - input->CurrentDataHandle()->HasValue()) || - device_data_node != nullptr) - << "Cannot shard tensor. 
Data does not present on any device."; - std::vector xla_tensors{input}; - cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; - } - auto xla_data = CreateTensorsData( - std::vector{cpu_tensor}, - std::vector{new_sharding_spec}, - std::vector{GetVirtualDevice().toString()})[0]; - input->SetXlaData(xla_data); - input->SetShardingSpec(*new_sharding_spec); - - // Register sharded tensor data. - XLAGraphExecutor::Get()->RegisterTensor(input->data()); -} - XLATensorPtr get_dimensions_size(const XLATensorPtr& input, std::vector dimensions) { return input->CreateFrom(torch::lazy::MakeNode( diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 0f2bf21f0c4..5a714170300 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -57,8 +57,6 @@ std::pair collective_permute( void custom_sharding_(const XLATensorPtr& input, const std::shared_ptr& spec); -void custom_mark_sharding(const XLATensorPtr& input, xla::OpSharding sharding); - XLATensorPtr get_dimensions_size(const XLATensorPtr& input, std::vector dimensions); diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 31c44e3c9d3..f08b2370d7d 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -82,7 +82,7 @@ def get_axis_name_idx(self, name: str) -> int: @functools.lru_cache(maxsize=None) def get_op_sharding(self, - partition_spec: Tuple) -> torch_xla._XLAC.OpSharding: + partition_spec: Tuple, flatten = False) -> torch_xla._XLAC.OpSharding: """ Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. @@ -104,6 +104,15 @@ def get_op_sharding(self, replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} group_assignment, replication_groups = _get_group_assignment( sharding_type, tile_assignment, len(partition_spec), replicate_dims) + + # If flatten = True, return the flattened version of OpSharding + # print each return type to debug + print(tile_assignment.tolist()) + print(group_assignment) + print(replication_groups) + if flatten: + return (tile_assignment.tolist(), group_assignment, replication_groups, int(sharding_type)) + return torch_xla._XLAC.OpSharding(tile_assignment.tolist(), group_assignment, replication_groups, int(sharding_type)) @@ -524,12 +533,14 @@ def mark_sharding_dynamo_custom_op( assert len(t.shape) == len(partition_spec), \ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." - op_sharding = mesh.get_op_sharding(partition_spec) + tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(partition_spec, flatten = True) + print('about to call xla_mark_sharding_dynamo_custom_op') if isinstance(t, XLAShardedTensor): - torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding) + torch.ops.xla.xla_mark_sharding_dynamo_custom_op(t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type) return t - torch_xla._XLAC._xla_mark_sharding(t, op_sharding) + torch.ops.xla.xla_mark_sharding_dynamo_custom_op(t, tile_assignment, group_assignment, replication_groups, sharding_type) + print('xla_mark_sharding_dynamo_custom_op call finished') return XLAShardedTensor(t)
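
Usage sketch (not part of the diff): a minimal example of how the new dynamo
custom-op path is meant to be driven once the debugging prints and the
temporary `test_wonjoo` test are removed. The mesh shape, axis names, and the
toy `fn_simple` below are illustrative assumptions, and I assume
`mark_sharding_dynamo_custom_op` keeps the same (tensor, mesh, partition_spec)
argument order as `xs.mark_sharding`; the entry points themselves
(`xs.mark_sharding_dynamo_custom_op`, `torch.ops.xla.xla_mark_sharding_dynamo_custom_op`,
and `torch.compile(..., backend="openxla")`) come from this patch and the
existing tests.

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.experimental.xla_sharding as xs

    # xla_mark_sharding XLA_CHECKs that SPMD (the virtual device) is enabled.
    xr.use_spmd()

    num_devices = xr.global_runtime_device_count()
    # Hypothetical 1 x num_devices device mesh; any valid mesh works here.
    mesh = xs.Mesh(np.arange(num_devices), (1, num_devices), ('x', 'y'))


    def fn_simple(t):
      # Annotates `t` through torch.ops.xla.xla_mark_sharding_dynamo_custom_op,
      # fed by the flattened tuple from Mesh.get_op_sharding(..., flatten=True),
      # instead of going through torch_xla._XLAC._xla_mark_sharding.
      xs.mark_sharding_dynamo_custom_op(t, mesh, (0, 1))
      return t + 1


    x = torch.zeros((1, 8)).to(xm.xla_device())
    eager_res = fn_simple(x)

    dynamo_fn = torch.compile(fn_simple, backend="openxla")
    dynamo_res = dynamo_fn(x)
    assert torch.allclose(eager_res.cpu(), dynamo_res.cpu())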