Commit: linter
vanbasten23 committed May 10, 2024
1 parent 8e08c84 commit 4eb243f
Showing 10 changed files with 134 additions and 84 deletions.
31 changes: 21 additions & 10 deletions test/test_operations.py
@@ -2448,22 +2448,33 @@ def test_unsafe_buffer_pointer(self):
class TestDLPack(parameterized.TestCase):

def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
-    dlpt = xdlpack.to_dlpack(xla_tensor) # dlpt1 has type PyCapsule
+    dlpt = xdlpack.to_dlpack(xla_tensor)  # dlpt1 has type PyCapsule
got = xdlpack.from_dlpack(dlpt)

self.assertEqual(xla_tensor.device, got.device)
self.assertTrue(torch.allclose(xla_tensor.cpu(), got.cpu()))
-    self.assertRaisesRegex(RuntimeError, "DLTensor capsule can be consumed only once", lambda: xdlpack.from_dlpack(dlpt))
+    self.assertRaisesRegex(RuntimeError,
+                           "DLTensor capsule can be consumed only once",
+                           lambda: xdlpack.from_dlpack(dlpt))

-    self.assertEqual(torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),torch_xla._XLAC._unsafe_buffer_pointer(got))
+    self.assertEqual(
+        torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),
+        torch_xla._XLAC._unsafe_buffer_pointer(got))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
-  @parameterized.parameters(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
+  @parameterized.parameters(*all_types_and_complex_and(torch.half,
+                                                       torch.bfloat16,
+                                                       torch.bool, torch.uint16,
+                                                       torch.uint32,
+                                                       torch.uint64))
def test_dlpack_roundtrip(self, dtype):
# "arange_cpu" not implemented for complex64 and complex128.
# xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) failed with `RuntimeError: false INTERNAL ASSERT FAILED at "/ansible/pytorch/torch/csrc/lazy/core/hash.h":139, please report a bug to PyTorch. Unsupported scalar type:UInt64`, similar to other uint.
-    if dtype in { torch.complex128, torch.complex64, torch.uint64, torch.uint32, torch.uint16, torch.bool }:
+    if dtype in {
+        torch.complex128, torch.complex64, torch.uint64, torch.uint32,
+        torch.uint16, torch.bool
+    }:
return
xla_device = xm.xla_device()
xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device)
@@ -2489,7 +2500,7 @@ def test_dlpack_roundtrip(self, dtype):
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_roundtrip_bool(self):
xla_tensor = torch.ones(1, dtype=torch.bool).to(xm.xla_device())
-      self._test_dlpack_capsule_conversion_helper(xla_tensor)
+    self._test_dlpack_capsule_conversion_helper(xla_tensor)

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
@@ -2529,7 +2540,10 @@ def test_dlpack_non_default_layout(self):
self.assertTrue(torch.allclose(t2.cpu(), xla_t2.cpu()))

t3 = cuda_t[:, 0]
-    self.assertRaisesRegex(RuntimeError, r"Only DLPack tensors with trivial \(compact\) striding are supported", lambda: xdlpack.from_dlpack(t3.__dlpack__()))
+    self.assertRaisesRegex(
+        RuntimeError,
+        r"Only DLPack tensors with trivial \(compact\) striding are supported",
+        lambda: xdlpack.from_dlpack(t3.__dlpack__()))

t4 = cuda_t[1, :]
xla_t4 = xdlpack.from_dlpack(t4.__dlpack__())
@@ -2540,9 +2554,6 @@ def test_dlpack_non_default_layout(self):
self.assertTrue(torch.allclose(t5.cpu(), xla_t5.cpu()))


-
-
-
class SimpleModelWithDropout(torch.nn.Module):

def __init__(self):
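The tests above all drive the same DLPack capsule round trip. As a usage-level summary, here is a minimal sketch of the flow they cover, assuming `xdlpack` is the `torch_xla.utils.dlpack` module imported by this test file (import path assumed) and that `xm.xla_device()` resolves to a device:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.dlpack as xdlpack  # assumed path behind the xdlpack alias

t = torch.arange(4, dtype=torch.float32).to(xm.xla_device())

capsule = xdlpack.to_dlpack(t)      # a PyCapsule named "dltensor"
got = xdlpack.from_dlpack(capsule)  # wraps the same device buffer, no copy

# Both tensors alias one buffer, which is what the pointer assertion checks.
assert (torch_xla._XLAC._unsafe_buffer_pointer(t) ==
        torch_xla._XLAC._unsafe_buffer_pointer(got))

# A capsule is single-use; consuming it twice raises
# "DLTensor capsule can be consumed only once".
try:
  xdlpack.from_dlpack(capsule)
except RuntimeError as e:
  print(e)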
104 changes: 64 additions & 40 deletions torch_xla/csrc/dl_convertor.cpp
@@ -1,30 +1,32 @@
#include "torch_xla/csrc/dl_convertor.h"

-#include "absl/types/span.h"
#include <ATen/DLConvertor.h>

-#include "torch_xla/csrc/tensor.h"
+#include "absl/types/span.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
-#include "torch_xla/csrc/runtime/runtime.h"
+#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
-#include "torch_xla/csrc/ops/device_data.h"
+#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
-#include "torch_xla/csrc/unwrap_data.h"
+#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/tensor_util.h"
+#include "torch_xla/csrc/unwrap_data.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/status.h"

namespace torch_xla {

-std::shared_ptr<runtime::ComputationClient::Data> get_data_handle(const at::Tensor& input) {
+std::shared_ptr<runtime::ComputationClient::Data> get_data_handle(
+    const at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
XLA_CHECK(xtensor) << "The input has to be an XLA tensor.";
if (xtensor->CurrentDataHandle() != nullptr) {
TF_VLOG(4) << "The xla tensor has a current data handle.";
-    return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(xtensor->CurrentDataHandle());
+    return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+        xtensor->CurrentDataHandle());
} else if (xtensor->CurrentIrValue().node != nullptr) {
DeviceData* device_data =
DeviceData::Cast(xtensor->CurrentIrValue().node.get());
@@ -33,7 +35,8 @@ std::shared_ptr<runtime::ComputationClient::Data> get_data_handle(const at::Tensor& input)
}
TF_VLOG(4) << "The xla tensor has IR value but does not have device data.";
}
-  TF_VLOG(4) << "The xla tensor either has no current data handle or has no IR value.";
+  TF_VLOG(4)
+      << "The xla tensor either has no current data handle or has no IR value.";
return nullptr;
}

@@ -66,7 +69,8 @@ DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) {
} else if (device.client()->platform_id() == xla::CudaId()) {
return DLDeviceType::kDLCUDA;
}
-  XLA_ERROR() << "Device " << device.DebugString() << " cannot be used as a DLPack device.";
+  XLA_ERROR() << "Device " << device.DebugString()
+              << " cannot be used as a DLPack device.";
}

DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) {
@@ -109,7 +113,8 @@ DLDataType PrimitiveTypeToDLDataType(xla::PrimitiveType type) {
case xla::PrimitiveType::C128:
return DLDataType{kDLComplex, 128, 1};
default:
-      XLA_ERROR() << "XLA type " << xla::PrimitiveType_Name(type) << " has no DLPack equivalent";
+      XLA_ERROR() << "XLA type " << xla::PrimitiveType_Name(type)
+                  << " has no DLPack equivalent";
}
}

@@ -129,14 +134,18 @@ std::vector<int64_t> StridesForShape(xla::PrimitiveType element_type,

// Convert an XLA tensor to a dlPack tensor.
DLManagedTensor* toDLPack(const at::Tensor& input) {
-  std::shared_ptr<runtime::ComputationClient::Data> handle = get_data_handle(input);
-  XLA_CHECK(handle != nullptr) << "Could not extract a valid data handle from the input tensor";
+  std::shared_ptr<runtime::ComputationClient::Data> handle =
+      get_data_handle(input);
+  XLA_CHECK(handle != nullptr)
+      << "Could not extract a valid data handle from the input tensor";

-  std::shared_ptr<xla::PjRtBuffer> pjrt_buffer = runtime::GetComputationClient()->GetPjRtBuffer(handle);
+  std::shared_ptr<xla::PjRtBuffer> pjrt_buffer =
+      runtime::GetComputationClient()->GetPjRtBuffer(handle);
XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";

if (pjrt_buffer->IsTuple()) {
-    XLA_ERROR() << "Unimplemented. BufferToDLPackManagedTensor is not implemented for tuple buffers.";
+    XLA_ERROR() << "Unimplemented. BufferToDLPackManagedTensor is not "
+                   "implemented for tuple buffers.";
}
if (pjrt_buffer->has_dynamic_dimensions()) {
XLA_ERROR() << "Unimplemented. DynamicShape is not implemented in DLPack.";
@@ -164,9 +173,11 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
dt.ndim = pjrt_buffer->dimensions().size();
dt.dtype = PrimitiveTypeToDLDataType(pjrt_buffer->element_type());

-  pack->shape = std::vector<int64_t>(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end());
+  pack->shape = std::vector<int64_t>(pjrt_buffer->dimensions().begin(),
+                                     pjrt_buffer->dimensions().end());
xla::Layout xla_layout = xla::GetXlaLayoutUnsafe(pjrt_buffer->layout());
-  pack->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout);
+  pack->strides = StridesForShape(pjrt_buffer->element_type(),
+                                  pjrt_buffer->dimensions(), xla_layout);
dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
dt.byte_offset = 0;
@@ -177,21 +188,25 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
absl::StatusOr<xla::PjRtDevice*> DeviceForDLDevice(const DLDevice& context) {
switch (context.device_type) {
case DLDeviceType::kDLCPU:
-      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CpuId());
-      return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id);
+      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(),
+                   xla::CpuId());
+      return runtime::GetComputationClient()->LookupAddressableDevice(
+          context.device_id);
case DLDeviceType::kDLCUDA:
-      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CudaId());
-      return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id);
+      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(),
+                   xla::CudaId());
+      return runtime::GetComputationClient()->LookupAddressableDevice(
+          context.device_id);
default:
-      return tsl::errors::InvalidArgument("Unknown/unsupported DLPack device type %d",
-                                          context.device_type);
+      return tsl::errors::InvalidArgument(
+          "Unknown/unsupported DLPack device type %d", context.device_type);
}
}

absl::StatusOr<xla::PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
if (type.lanes != 1) {
-    return tsl::errors::Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
-                                      type.lanes);
+    return tsl::errors::Unimplemented(
+        "DLPack types with lanes != 1 not implemented, got %d", type.lanes);
}
switch (type.code) {
case kDLBool:
@@ -265,7 +280,8 @@ absl::StatusOr<xla::PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
type.bits);
}
default:
-      return tsl::errors::Unimplemented("Unknown or invalid DLPack type code %d", type.code);
+      return tsl::errors::Unimplemented(
+          "Unknown or invalid DLPack type code %d", type.code);
}
}

@@ -302,43 +318,51 @@ absl::StatusOr<std::vector<int64_t>> StridesToLayout(

at::Tensor fromDLPack(DLManagedTensor* dlmt) {
if (dlmt->dl_tensor.ndim < 0) {
-    XLA_ERROR() << "Number of dimensions in DLManagedTensor must be nonnegative, got " << dlmt->dl_tensor.ndim;
+    XLA_ERROR()
+        << "Number of dimensions in DLManagedTensor must be nonnegative, got "
+        << dlmt->dl_tensor.ndim;
}
xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value();
absl::Span<int64_t const> dimensions(
const_cast<int64_t*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
-  xla::PrimitiveType element_type = DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype).value();
+  xla::PrimitiveType element_type =
+      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype).value();

std::vector<int64_t> minor_to_major;
if (dlmt->dl_tensor.strides &&
absl::c_find(dimensions, 0) == dimensions.end()) {
absl::Span<int64_t const> strides(
-        const_cast<int64_t*>(dlmt->dl_tensor.strides),
-        dlmt->dl_tensor.ndim);
+        const_cast<int64_t*>(dlmt->dl_tensor.strides), dlmt->dl_tensor.ndim);
minor_to_major = StridesToLayout(dimensions, strides).value();
} else {
minor_to_major.resize(dlmt->dl_tensor.ndim);
std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
}
-  xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions,
-                                                              minor_to_major);
+  xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(
+      element_type, dimensions, minor_to_major);

std::function<void()> on_delete_callback;
if (dlmt->deleter) {
on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
}
-  xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> pjrt_buffer = device->client()->CreateViewOfDeviceBuffer(
-      static_cast<char*>(dlmt->dl_tensor.data) +
-          dlmt->dl_tensor.byte_offset,
-      shape, device, on_delete_callback);
-  XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer in " << __FUNCTION__;
-  XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null in " << __FUNCTION__;
+  xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> pjrt_buffer =
+      device->client()->CreateViewOfDeviceBuffer(
+          static_cast<char*>(dlmt->dl_tensor.data) +
+              dlmt->dl_tensor.byte_offset,
+          shape, device, on_delete_callback);
+  XLA_CHECK_OK(pjrt_buffer.status())
+      << "Failed to create a pjrt buffer in " << __FUNCTION__;
+  XLA_CHECK(pjrt_buffer.value() != nullptr)
+      << "pjrt buffer is null in " << __FUNCTION__;

+  runtime::ComputationClient::DataPtr data =
+      runtime::GetComputationClient()->CreateData(
+          runtime::GetComputationClient()->PjRtDeviceToString(device), shape,
+          std::move(pjrt_buffer.value()));

-  runtime::ComputationClient::DataPtr data = runtime::GetComputationClient()->CreateData(runtime::GetComputationClient()->PjRtDeviceToString(device), shape, std::move(pjrt_buffer.value()));

at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype);
XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type);
return bridge::AtenFromXlaTensor(xla_tensor);
}

-}
+}  // namespace torch_xla
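For orientation: toDLPack derives the DLPack strides from the buffer's XLA layout via StridesForShape, and fromDLPack inverts that mapping with StridesToLayout, falling back to a row-major layout when the capsule carries no strides. A rough Python sketch of the forward computation, with a hypothetical helper name and strides counted in elements rather than bytes:

def strides_for_shape(dims, minor_to_major):
  # Walk dimensions from most-minor to most-major, accumulating the
  # running product of extents.
  strides = [0] * len(dims)
  stride = 1
  for d in minor_to_major:
    strides[d] = stride
    stride *= dims[d]
  return strides

# The default dense (row-major) layout lists dimensions minor-to-major as
# [ndim-1, ..., 0], which is what fromDLPack's std::iota fallback produces.
assert strides_for_shape([2, 3, 4], [2, 1, 0]) == [12, 4, 1]

StridesToLayout runs this in reverse, recovering minor_to_major by sorting dimensions by ascending stride; stride patterns no dense layout can produce surface as the "Only DLPack tensors with trivial (compact) striding are supported" error exercised in the tests.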
2 changes: 1 addition & 1 deletion torch_xla/csrc/dl_convertor.h
@@ -1,8 +1,8 @@
#ifndef XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_
#define XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_

-#include <ATen/dlpack.h>
#include <ATen/Tensor.h>
+#include <ATen/dlpack.h>

namespace torch_xla {

18 changes: 11 additions & 7 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1,7 +1,7 @@
+#include <ATen/dlpack.h>
#include <Python.h>
#include <c10/core/Device.h>
#include <c10/util/Optional.h>
-#include <ATen/dlpack.h>
#include <google/protobuf/text_format.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/variable.h>
@@ -35,8 +35,8 @@
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
-#include "torch_xla/csrc/dtype.h"
#include "torch_xla/csrc/dl_convertor.h"
+#include "torch_xla/csrc/dtype.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ir_dump_util.h"
@@ -1117,8 +1117,12 @@ void dlPack_Capsule_Destructor(PyObject* data) {
}

at::Tensor tensor_fromDLPack(PyObject* data) {
-  DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
-  XLA_CHECK(dlMTensor != nullptr) << "from_dlpack received an invalid capsule. Note that a DLTensor capsule can be consumed only once. You may have already constructed a tensor from it once.";
+  DLManagedTensor* dlMTensor =
+      (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
+  XLA_CHECK(dlMTensor != nullptr)
+      << "from_dlpack received an invalid capsule. Note that a DLTensor "
+         "capsule can be consumed only once. You may have already constructed "
+         "a tensor from it once.";

at::Tensor tensor = torch_xla::fromDLPack(dlMTensor);
PyCapsule_SetName(data, "used_dltensor");
@@ -2543,17 +2547,17 @@ void InitXlaModuleBindings(py::module m) {
NoGilSection nogil;
dlMTensor = torch_xla::toDLPack(input);
}
-    // return py::reinterpret_steal<py::object>(PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor));
+    // return py::reinterpret_steal<py::object>(PyCapsule_New(dlMTensor,
+    // "dltensor", dlPack_Capsule_Destructor));
return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor);
});
-  // m.def("_to_dlpack", &tensor_toDLPack, ""); //
+  // m.def("_to_dlpack", &tensor_toDLPack, ""); //

// from a dlpack tensor to an XLA tensor
m.def("_from_dlpack", [](py::handle ext_data) -> at::Tensor {
return tensor_fromDLPack(ext_data.ptr());
});

-
// -------------Dynamo Integration API Start-------------------------
/*
* Return tensor ids and at::tensors for all DeviceData nodes that is needed
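These bindings back the Python-level tests: _to_dlpack returns a capsule named "dltensor", and tensor_fromDLPack renames a consumed capsule to "used_dltensor" so that converting it a second time fails loudly. A short sketch of the interop path through the standard __dlpack__ protocol, assuming a CUDA-enabled build as in test_dlpack_non_default_layout (import path assumed):

import torch
import torch_xla.utils.dlpack as xdlpack  # assumed path

cuda_t = torch.arange(12, dtype=torch.float32, device="cuda").reshape(3, 4)

# A row slice is contiguous (compact strides), so conversion succeeds.
xla_row = xdlpack.from_dlpack(cuda_t[1, :].__dlpack__())

# A column slice is strided; fromDLPack rejects it with
# "Only DLPack tensors with trivial (compact) striding are supported".
try:
  xdlpack.from_dlpack(cuda_t[:, 0].__dlpack__())
except RuntimeError as e:
  print(e)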
