Commit

fix comment
vanbasten23 committed May 13, 2024
1 parent 4eb243f commit 1f1eeeb
Showing 5 changed files with 46 additions and 37 deletions.
10 changes: 4 additions & 6 deletions test/test_operations.py
@@ -2449,17 +2449,17 @@ class TestDLPack(parameterized.TestCase):
 
   def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
     dlpt = xdlpack.to_dlpack(xla_tensor)  # dlpt1 has type PyCapsule
-    got = xdlpack.from_dlpack(dlpt)
+    xla_tensor2 = xdlpack.from_dlpack(dlpt)
 
-    self.assertEqual(xla_tensor.device, got.device)
-    self.assertTrue(torch.allclose(xla_tensor.cpu(), got.cpu()))
+    self.assertEqual(xla_tensor.device, xla_tensor2.device)
+    self.assertTrue(torch.allclose(xla_tensor.cpu(), xla_tensor2.cpu()))
     self.assertRaisesRegex(RuntimeError,
                            "DLTensor capsule can be consumed only once",
                            lambda: xdlpack.from_dlpack(dlpt))
 
     self.assertEqual(
         torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),
-        torch_xla._XLAC._unsafe_buffer_pointer(got))
+        torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor2))
 
   @onlyIfTorchSupportsCUDA
   @onlyIfPJRTDeviceIsCUDA
@@ -2492,8 +2492,6 @@ def test_dlpack_roundtrip(self, dtype):
 
     xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device())
     xm.mark_step()
-    # Without the `wait_device_ops()`, the pjrt buffer (pjrt_data->buffer) at https://github.com/pytorch/xla/blob/e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd/torch_xla/csrc/runtime/pjrt_computation_client.cc#L467 will be nullptr.
-    xm.wait_device_ops()
     self._test_dlpack_capsule_conversion_helper(xla_tensor_3)
 
   @onlyIfTorchSupportsCUDA
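The helper above round-trips an XLA tensor through a DLPack capsule and checks that the imported tensor aliases the original device buffer. Below is a minimal, illustrative sketch of that flow, not part of this commit; it assumes the `xdlpack` alias used in this test file refers to torch_xla's DLPack helper module (e.g. torch_xla.utils.dlpack).

# Hedged sketch of the DLPack round trip exercised by the test helper above.
# Assumes torch_xla.utils.dlpack is the module aliased as `xdlpack` in the test.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.utils.dlpack as xdlpack

t1 = torch.arange(5, device=xm.xla_device())
capsule = xdlpack.to_dlpack(t1)    # export the XLA tensor as a PyCapsule
t2 = xdlpack.from_dlpack(capsule)  # import it; shares the underlying device buffer
assert torch.allclose(t1.cpu(), t2.cpu())
# The capsule is consumed: a second xdlpack.from_dlpack(capsule) raises
# RuntimeError("DLTensor capsule can be consumed only once").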
37 changes: 7 additions & 30 deletions torch_xla/csrc/dl_convertor.cpp
@@ -19,47 +19,25 @@
 
 namespace torch_xla {
 
-std::shared_ptr<runtime::ComputationClient::Data> get_data_handle(
-    const at::Tensor& input) {
-  XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-  XLA_CHECK(xtensor) << "The input has to be an XLA tensor.";
-  if (xtensor->CurrentDataHandle() != nullptr) {
-    TF_VLOG(4) << "The xla tensor has a current data handle.";
-    return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
-        xtensor->CurrentDataHandle());
-  } else if (xtensor->CurrentIrValue().node != nullptr) {
-    DeviceData* device_data =
-        DeviceData::Cast(xtensor->CurrentIrValue().node.get());
-    if (device_data != nullptr) {
-      return UnwrapXlaData(device_data->data());
-    }
-    TF_VLOG(4) << "The xla tensor has IR value but does not have device data.";
-  }
-  TF_VLOG(4)
-      << "The xla tensor either has no current data handle or has no IR value.";
-  return nullptr;
-}
-
-struct TorchXLADLMTensor {
-  ~TorchXLADLMTensor();
+struct DLPackTensor {
+  ~DLPackTensor();
   std::unique_ptr<xla::PjRtBuffer::ExternalReference> external_reference;
   std::shared_ptr<xla::PjRtBuffer> buffer_reference;
-  // at::Tensor source_tensor;
 
   std::vector<int64_t> shape;
   std::vector<int64_t> strides;
   DLManagedTensor tensor;
 };
 
-TorchXLADLMTensor::~TorchXLADLMTensor() {
+DLPackTensor::~DLPackTensor() {
   if (external_reference) {
     external_reference.reset(nullptr);
   }
 }
 
-void TorchXLADLMTensorDeleter(DLManagedTensor* t) {
+void DLPackTensorDeleter(DLManagedTensor* t) {
   if (t) {
-    delete static_cast<TorchXLADLMTensor*>(t->manager_ctx);
+    delete static_cast<DLPackTensor*>(t->manager_ctx);
   }
 }
 
@@ -151,7 +129,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
     XLA_ERROR() << "Unimplemented. DynamicShape is not implemented in DLPack.";
   }
 
-  auto pack = std::make_unique<TorchXLADLMTensor>();
+  auto pack = std::make_unique<DLPackTensor>();
   DLTensor& dt = pack->tensor.dl_tensor;
   {
     // AcquireExternalReference may block
@@ -163,11 +141,10 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
     XLA_CHECK_OK(status);
   }
   pack->buffer_reference = pjrt_buffer;
-  // pack->source_tensor = input;
 
   dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer();
   pack->tensor.manager_ctx = pack.get();
-  pack->tensor.deleter = TorchXLADLMTensorDeleter;
+  pack->tensor.deleter = DLPackTensorDeleter;
   dt.device = DLDeviceForDevice(*pjrt_buffer->device());
   dt.device.device_id = pjrt_buffer->device()->local_hardware_id();
   dt.ndim = pjrt_buffer->dimensions().size();
11 changes: 10 additions & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -482,8 +482,17 @@ std::shared_ptr<xla::PjRtBuffer> PjRtComputationClient::GetPjRtBuffer(
     const DataPtr handle) {
   std::shared_ptr<PjRtData> pjrt_data =
       std::dynamic_pointer_cast<PjRtData>(handle);
+
   XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString();
-  return pjrt_data->buffer;
+  std::shared_ptr<xla::PjRtBuffer> pjrt_buffer = pjrt_data->buffer;
+  if (pjrt_buffer != nullptr) {
+    return pjrt_buffer;
+  } else {
+    TF_VLOG(3) << "The pjrt buffer is null so we need to wait for device ops "
+                  "to finish.";
+    WaitDeviceOps({});
+    return std::dynamic_pointer_cast<PjRtData>(handle)->buffer;
+  }
 }
 
 std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
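With this change, GetPjRtBuffer no longer hands out a null buffer when the tensor's PjRt buffer has not materialized yet: it waits for pending device ops and re-reads the buffer, which is why the test above could drop its explicit xm.wait_device_ops() call. The following is a hedged Python sketch of that scenario, illustrative only and reusing the assumed torch_xla.utils.dlpack import from the earlier sketch.

# After mark_step(), execution is asynchronous and the underlying PjRt buffer
# may not exist yet; the new branch in GetPjRtBuffer waits for device ops
# instead of returning a nullptr, so no explicit wait is needed here.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.utils.dlpack as xdlpack

t = torch.arange(5, device=xm.xla_device())
xm.mark_step()
capsule = xdlpack.to_dlpack(t)  # previously required xm.wait_device_ops() first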
22 changes: 22 additions & 0 deletions torch_xla/csrc/tensor_util.cpp
@@ -17,6 +17,7 @@
 #include "torch_xla/csrc/dtype.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/layout_manager.h"
+#include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
@@ -931,4 +932,25 @@ xla::PrimitiveType GetShapeDimensionType(
   return xla::PrimitiveType::S32;
 }
 
+std::shared_ptr<runtime::ComputationClient::Data> get_data_handle(
+    const at::Tensor& input) {
+  XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  XLA_CHECK(xtensor) << "The input has to be an XLA tensor.";
+  if (xtensor->CurrentDataHandle() != nullptr) {
+    TF_VLOG(4) << "The xla tensor has a current data handle.";
+    return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+        xtensor->CurrentDataHandle());
+  } else if (xtensor->CurrentIrValue().node != nullptr) {
+    DeviceData* device_data =
+        DeviceData::Cast(xtensor->CurrentIrValue().node.get());
+    if (device_data != nullptr) {
+      return UnwrapXlaData(device_data->data());
+    }
+    TF_VLOG(4) << "The xla tensor has IR value but does not have device data.";
+  }
+  TF_VLOG(4)
+      << "The xla tensor either has no current data handle or has no IR value.";
+  return nullptr;
+}
+
 }  // namespace torch_xla
3 changes: 3 additions & 0 deletions torch_xla/csrc/tensor_util.h
@@ -212,6 +212,9 @@ inline std::vector<at::Tensor> xla_expand_outplace(at::TensorList to_expand) {
   }
 }
 
+std::shared_ptr<runtime::ComputationClient::Data> get_data_handle(
+    const at::Tensor& input);
+
 }  // namespace torch_xla
 
 #endif  // XLA_TORCH_XLA_CSRC_TENSOR_UTIL_H_
