diff --git a/test/test_operations.py b/test/test_operations.py
index 97a7a950972..16956dea8de 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2449,17 +2449,17 @@ class TestDLPack(parameterized.TestCase):
 
   def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
     dlpt = xdlpack.to_dlpack(xla_tensor)  # dlpt1 has type PyCapsule
-    got = xdlpack.from_dlpack(dlpt)
+    xla_tensor2 = xdlpack.from_dlpack(dlpt)
 
-    self.assertEqual(xla_tensor.device, got.device)
-    self.assertTrue(torch.allclose(xla_tensor.cpu(), got.cpu()))
+    self.assertEqual(xla_tensor.device, xla_tensor2.device)
+    self.assertTrue(torch.allclose(xla_tensor.cpu(), xla_tensor2.cpu()))
     self.assertRaisesRegex(RuntimeError,
                            "DLTensor capsule can be consumed only once",
                            lambda: xdlpack.from_dlpack(dlpt))
 
     self.assertEqual(
         torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),
-        torch_xla._XLAC._unsafe_buffer_pointer(got))
+        torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor2))
 
   @onlyIfTorchSupportsCUDA
   @onlyIfPJRTDeviceIsCUDA
@@ -2492,8 +2492,6 @@ def test_dlpack_roundtrip(self, dtype):
     xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device())
     xm.mark_step()
-    # Without the `wait_device_ops()`, the pjrt buffer (pjrt_data->buffer) at https://github.com/pytorch/xla/blob/e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd/torch_xla/csrc/runtime/pjrt_computation_client.cc#L467 will be nullptr.
-    xm.wait_device_ops()
     self._test_dlpack_capsule_conversion_helper(xla_tensor_3)
 
   @onlyIfTorchSupportsCUDA
diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp
index 3c0371efcf6..5790b7c2112 100644
--- a/torch_xla/csrc/dl_convertor.cpp
+++ b/torch_xla/csrc/dl_convertor.cpp
@@ -19,47 +19,25 @@
 
 namespace torch_xla {
 
-std::shared_ptr<runtime::ComputationClient::Data> get_data_handle(
-    const at::Tensor& input) {
-  XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-  XLA_CHECK(xtensor) << "The input has to be an XLA tensor.";
-  if (xtensor->CurrentDataHandle() != nullptr) {
-    TF_VLOG(4) << "The xla tensor has a current data handle.";
-    return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
-        xtensor->CurrentDataHandle());
-  } else if (xtensor->CurrentIrValue().node != nullptr) {
-    DeviceData* device_data =
-        DeviceData::Cast(xtensor->CurrentIrValue().node.get());
-    if (device_data != nullptr) {
-      return UnwrapXlaData(device_data->data());
-    }
-    TF_VLOG(4) << "The xla tensor has IR value but does not have device data.";
-  }
-  TF_VLOG(4)
-      << "The xla tensor either has no current data handle or has no IR value.";
-  return nullptr;
-}
-
-struct TorchXLADLMTensor {
-  ~TorchXLADLMTensor();
+struct DLPackTensor {
+  ~DLPackTensor();
   std::unique_ptr<xla::PjRtBuffer::ExternalReference> external_reference;
   std::shared_ptr<xla::PjRtBuffer> buffer_reference;
 
-  // at::Tensor source_tensor;
   std::vector<int64_t> shape;
   std::vector<int64_t> strides;
   DLManagedTensor tensor;
 };
 
-TorchXLADLMTensor::~TorchXLADLMTensor() {
+DLPackTensor::~DLPackTensor() {
   if (external_reference) {
     external_reference.reset(nullptr);
   }
 }
 
-void TorchXLADLMTensorDeleter(DLManagedTensor* t) {
+void DLPackTensorDeleter(DLManagedTensor* t) {
   if (t) {
-    delete static_cast<TorchXLADLMTensor*>(t->manager_ctx);
+    delete static_cast<DLPackTensor*>(t->manager_ctx);
   }
 }
 
@@ -151,7 +129,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
     XLA_ERROR() << "Unimplemented. DynamicShape is not implemented in DLPack.";
   }
 
-  auto pack = std::make_unique<TorchXLADLMTensor>();
+  auto pack = std::make_unique<DLPackTensor>();
   DLTensor& dt = pack->tensor.dl_tensor;
   {
     // AcquireExternalReference may block
@@ -163,11 +141,10 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
     XLA_CHECK_OK(status);
   }
   pack->buffer_reference = pjrt_buffer;
-  // pack->source_tensor = input;
 
   dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer();
   pack->tensor.manager_ctx = pack.get();
-  pack->tensor.deleter = TorchXLADLMTensorDeleter;
+  pack->tensor.deleter = DLPackTensorDeleter;
   dt.device = DLDeviceForDevice(*pjrt_buffer->device());
   dt.device.device_id = pjrt_buffer->device()->local_hardware_id();
   dt.ndim = pjrt_buffer->dimensions().size();
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index ba3e9baf8c8..d46f2712166 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -482,8 +482,17 @@ std::shared_ptr<xla::PjRtBuffer> PjRtComputationClient::GetPjRtBuffer(
     const DataPtr handle) {
   std::shared_ptr<PjRtData> pjrt_data =
       std::dynamic_pointer_cast<PjRtData>(handle);
 
+  XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString();
-  return pjrt_data->buffer;
+  std::shared_ptr<xla::PjRtBuffer> pjrt_buffer = pjrt_data->buffer;
+  if (pjrt_buffer != nullptr) {
+    return pjrt_buffer;
+  } else {
+    TF_VLOG(3) << "The pjrt buffer is null so we need to wait for device ops "
+                  "to finish.";
+    WaitDeviceOps({});
+    return std::dynamic_pointer_cast<PjRtData>(handle)->buffer;
+  }
 }
 
 std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index dd13bd63d1b..870f6945973 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -17,6 +17,7 @@
 #include "torch_xla/csrc/dtype.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/layout_manager.h"
+#include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
@@ -931,4 +932,25 @@ xla::PrimitiveType GetShapeDimensionType(
   return xla::PrimitiveType::S32;
 }
 
+std::shared_ptr<runtime::ComputationClient::Data> get_data_handle(
+    const at::Tensor& input) {
+  XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  XLA_CHECK(xtensor) << "The input has to be an XLA tensor.";
+  if (xtensor->CurrentDataHandle() != nullptr) {
+    TF_VLOG(4) << "The xla tensor has a current data handle.";
+    return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+        xtensor->CurrentDataHandle());
+  } else if (xtensor->CurrentIrValue().node != nullptr) {
+    DeviceData* device_data =
+        DeviceData::Cast(xtensor->CurrentIrValue().node.get());
+    if (device_data != nullptr) {
+      return UnwrapXlaData(device_data->data());
+    }
+    TF_VLOG(4) << "The xla tensor has IR value but does not have device data.";
+  }
+  TF_VLOG(4)
+      << "The xla tensor either has no current data handle or has no IR value.";
+  return nullptr;
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h
index 7d726c00b50..0804d3e9f78 100644
--- a/torch_xla/csrc/tensor_util.h
+++ b/torch_xla/csrc/tensor_util.h
@@ -212,6 +212,9 @@ inline std::vector<at::Tensor> xla_expand_outplace(at::TensorList to_expand) {
   }
 }
 
+std::shared_ptr<runtime::ComputationClient::Data> get_data_handle(
+    const at::Tensor& input);
+
 }  // namespace torch_xla
 
 #endif  // XLA_TORCH_XLA_CSRC_TENSOR_UTIL_H_
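
For reviewers, a minimal usage sketch of the behavior this patch enables. It assumes `xdlpack` resolves to `torch_xla.utils.dlpack` (the alias used by the test file); the tensor shape and dtype are illustrative only:

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.utils.dlpack as xdlpack

    t = torch.arange(5, dtype=torch.float32, device=xm.xla_device())
    xm.mark_step()  # the underlying pjrt buffer may still be null right after this

    # No explicit xm.wait_device_ops() is needed anymore: GetPjRtBuffer now
    # calls WaitDeviceOps({}) itself when it finds a null buffer.
    capsule = xdlpack.to_dlpack(t)     # PyCapsule; consumable only once
    t2 = xdlpack.from_dlpack(capsule)  # zero-copy: aliases t's device memory
    assert torch.allclose(t.cpu(), t2.cpu())

The zero-copy claim follows from the test's `_unsafe_buffer_pointer` equality assertion: both tensors report the same device buffer address.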