Commit 8e08c84: Clean up unused prints and comments

vanbasten23 committed May 10, 2024
1 parent d1b6dd2 commit 8e08c84

Showing 5 changed files with 16 additions and 74 deletions.
9 changes: 0 additions & 9 deletions test/test_operations.py
@@ -2449,27 +2449,18 @@ class TestDLPack(parameterized.TestCase):

  def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
    dlpt = xdlpack.to_dlpack(xla_tensor)  # dlpt has type PyCapsule
-    print('xw32 finished the to_dlpack')
    got = xdlpack.from_dlpack(dlpt)
-    print('xw32 finished the from_dlpack')

-    print('xla_tensor.device=', xla_tensor.device, ', got.device=', got.device)
    self.assertEqual(xla_tensor.device, got.device)
-    print('xla_tensor.cpu()=', xla_tensor.cpu())
-    print('got.cpu()=', got.cpu())
    self.assertTrue(torch.allclose(xla_tensor.cpu(), got.cpu()))
    self.assertRaisesRegex(RuntimeError, "DLTensor capsule can be consumed only once", lambda: xdlpack.from_dlpack(dlpt))

-    print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor)=', torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor))
-    print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(got)=', torch_xla._XLAC._unsafe_buffer_pointer(got))
    self.assertEqual(torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor), torch_xla._XLAC._unsafe_buffer_pointer(got))

-  # TODO(xw32): need to test different data types, as in pytorch/test/test_dlpack.py
  @onlyIfTorchSupportsCUDA
  @onlyIfPJRTDeviceIsCUDA
  @parameterized.parameters(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
  def test_dlpack_roundtrip(self, dtype):
-    print('xw32 dtype=', dtype)
    # "arange_cpu" is not implemented for complex64 and complex128.
    # xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) fails with `RuntimeError: false INTERNAL ASSERT FAILED at "/ansible/pytorch/torch/csrc/lazy/core/hash.h":139, please report a bug to PyTorch. Unsupported scalar type:UInt64`; the other uint types fail similarly.
    if dtype in {torch.complex128, torch.complex64, torch.uint64, torch.uint32, torch.uint16, torch.bool}:
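
For orientation, here is a minimal sketch of the round trip this helper exercises, assuming a PJRT-backed XLA device and the same imports the test file uses (`xdlpack` is torch_xla.utils.dlpack):

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.utils.dlpack as xdlpack

    t = torch.arange(4, device=xm.xla_device())
    capsule = xdlpack.to_dlpack(t)      # wraps the underlying PjRtBuffer in a PyCapsule
    got = xdlpack.from_dlpack(capsule)  # zero-copy: shares t's device buffer

    assert got.device == t.device
    assert torch.allclose(t.cpu(), got.cpu())
    # Aliasing is observable through the raw device pointers:
    assert torch_xla._XLAC._unsafe_buffer_pointer(t) == torch_xla._XLAC._unsafe_buffer_pointer(got)
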
66 changes: 15 additions & 51 deletions torch_xla/csrc/dl_convertor.cpp
@@ -55,9 +55,7 @@ TorchXLADLMTensor::~TorchXLADLMTensor() {
}

void TorchXLADLMTensorDeleter(DLManagedTensor* t) {
-  std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl;
  if (t) {
-    std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl;
    delete static_cast<TorchXLADLMTensor*>(t->manager_ctx);
  }
}
@@ -67,8 +65,6 @@ DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) {
    return DLDeviceType::kDLCPU;
  } else if (device.client()->platform_id() == xla::CudaId()) {
    return DLDeviceType::kDLCUDA;
-  } else if (device.client()->platform_id() == xla::RocmId()) {
-    return DLDeviceType::kDLROCM;
  }
  XLA_ERROR() << "Device " << device.DebugString() << " cannot be used as a DLPack device.";
}
@@ -131,7 +127,7 @@ std::vector<int64_t> StridesForShape(xla::PrimitiveType element_type,
  return strides;
}

-// Convert an XLA tensor to dlPack tensor.
+// Convert an XLA tensor to a dlPack tensor.
DLManagedTensor* toDLPack(const at::Tensor& input) {
  std::shared_ptr<runtime::ComputationClient::Data> handle = get_data_handle(input);
  XLA_CHECK(handle != nullptr) << "Could not extract a valid data handle from the input tensor";
@@ -146,63 +142,46 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
    XLA_ERROR() << "Unimplemented. DynamicShape is not implemented in DLPack.";
  }

-  auto torchXlaDLMTensor = std::make_unique<TorchXLADLMTensor>();
-  DLTensor& dt = torchXlaDLMTensor->tensor.dl_tensor;
+  auto pack = std::make_unique<TorchXLADLMTensor>();
+  DLTensor& dt = pack->tensor.dl_tensor;
  {
    // AcquireExternalReference may block
    auto external_ref = pjrt_buffer->AcquireExternalReference();
    XLA_CHECK_OK(external_ref.status());
-    torchXlaDLMTensor->external_reference = std::move(external_ref.value());
+    pack->external_reference = std::move(external_ref.value());
    xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture();
    absl::Status status = future.Await();
    XLA_CHECK_OK(status);
  }
-  torchXlaDLMTensor->buffer_reference = pjrt_buffer;
-  // torchXlaDLMTensor->source_tensor = input;
-  // pack->buffer_reference = nb::borrow<nb::object>(py_buffer); // xw32: should we do it?
+  pack->buffer_reference = pjrt_buffer;
+  // pack->source_tensor = input;

-  dt.data = torchXlaDLMTensor->external_reference->OpaqueDeviceMemoryDataPointer();
-  torchXlaDLMTensor->tensor.manager_ctx = torchXlaDLMTensor.get();
-  torchXlaDLMTensor->tensor.deleter = TorchXLADLMTensorDeleter;
+  dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer();
+  pack->tensor.manager_ctx = pack.get();
+  pack->tensor.deleter = TorchXLADLMTensorDeleter;
  dt.device = DLDeviceForDevice(*pjrt_buffer->device());
  dt.device.device_id = pjrt_buffer->device()->local_hardware_id();
  dt.ndim = pjrt_buffer->dimensions().size();
  dt.dtype = PrimitiveTypeToDLDataType(pjrt_buffer->element_type());

-  torchXlaDLMTensor->shape = std::vector<int64_t>(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end());
+  pack->shape = std::vector<int64_t>(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end());
  xla::Layout xla_layout = xla::GetXlaLayoutUnsafe(pjrt_buffer->layout());
-  torchXlaDLMTensor->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout);
-  dt.shape = reinterpret_cast<std::int64_t*>(torchXlaDLMTensor->shape.data());
-  dt.strides = reinterpret_cast<std::int64_t*>(torchXlaDLMTensor->strides.data());
+  pack->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout);
+  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
+  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
  dt.byte_offset = 0;

-  return &(torchXlaDLMTensor.release()->tensor);
+  return &(pack.release()->tensor);
}

absl::StatusOr<xla::PjRtDevice*> DeviceForDLDevice(const DLDevice& context) {
  switch (context.device_type) {
    case DLDeviceType::kDLCPU:
-      // if (cpu_client == nullptr) {
-      //   return InvalidArgument(
-      //       "DLPack tensor is on CPU, but no CPU backend was provided.");
-      // }
      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CpuId());
      return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id);
    case DLDeviceType::kDLCUDA:
-      // if (gpu_client == nullptr) { // xw32 TODO: check if client_ is GPU client
-      //   return InvalidArgument(
-      //       "DLPack tensor is on GPU, but no GPU backend was provided.");
-      // }
      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CudaId());
      return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id);
-    // case DLDeviceType::kDLROCM:
-    //   // if (gpu_client == nullptr) {
-    //   //   return InvalidArgument(
-    //   //       "DLPack tensor is on GPU, but no GPU backend was provided.");
-    //   // }
-    //   XLA_CHECK_EQ(pjrt_client->platform_id(), xla::RocmId());
-    //   xla::PjRtDevice* device = pjrt_client->addressable_devices()[context.device_id];
-    //   return device;
    default:
      return tsl::errors::InvalidArgument("Unknown/unsupported DLPack device type %d",
                                          context.device_type);
@@ -325,7 +304,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
  if (dlmt->dl_tensor.ndim < 0) {
    XLA_ERROR() << "Number of dimensions in DLManagedTensor must be nonnegative, got " << dlmt->dl_tensor.ndim;
  }
-  xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value(); // client_ is a xla::PjRtClient. So this fromDLPack should be inside pjrt_computation_client class.
+  xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value();
  absl::Span<int64_t const> dimensions(
      const_cast<int64_t*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
  xla::PrimitiveType element_type = DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype).value();
@@ -344,19 +323,6 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
  xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions,
                                                              minor_to_major);

-  // Raise an error if the resulting PjRtBuffer would have a non-default layout.
-  // TODO(skyewm): we do this because JAX doesn't currently have good support
-  // for non-default layouts, and will return wrong results if a non-default
-  // layout is passed to a computation expecting default layouts. Remove this
-  // special case when non-default layouts are better supported by JAX.
-  absl::StatusOr<xla::Layout> default_layout_from_client =
-      device->client()->GetDefaultLayout(element_type, dimensions);
-  XLA_CHECK_OK(default_layout_from_client.status()) << "Failed to get a default layout in " << __FUNCTION__;
-  xla::Layout default_layout = default_layout_from_client.value(); // TODO(xw32): the check below is needed due to a limitation in IFRT. Since torch_xla uses PJRT, we may not need the check or the var default_layout.
-  // if (shape.layout() != default_layout) {
-  //   XLA_ERROR() << "from_dlpack got array with non-default layout with minor-to-major dimensions (" << absl::StrJoin(shape.layout().minor_to_major(), ",") << "), expected (" << absl::StrJoin(default_layout.minor_to_major(), ",") << ")";
-  // }
-
  std::function<void()> on_delete_callback;
  if (dlmt->deleter) {
    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
@@ -370,9 +336,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {

  runtime::ComputationClient::DataPtr data = runtime::GetComputationClient()->CreateData(runtime::GetComputationClient()->PjRtDeviceToString(device), shape, std::move(pjrt_buffer.value()));

-  // xw32 note: XlaDataToTensors does a device-to-host transfer.
  at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype);
-  // return XlaDataToTensors({data}, {tensor_type})[0];
  XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type);
  return bridge::AtenFromXlaTensor(xla_tensor);
}
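
The toDLPack/fromDLPack pair above is what enables zero-copy interop with other CUDA libraries. A sketch of the intended use, assuming the PJRT device is CUDA (the only GPU platform DeviceForDLDevice still accepts):

    import torch
    import torch.utils.dlpack
    import torch_xla.utils.dlpack as xdlpack

    # A torch CUDA tensor and an XLA:CUDA tensor sharing one device allocation.
    cuda_t = torch.arange(4, device="cuda")
    xla_t = xdlpack.from_dlpack(torch.utils.dlpack.to_dlpack(cuda_t))  # no host copy
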
6 changes: 0 additions & 6 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1116,12 +1116,6 @@ void dlPack_Capsule_Destructor(PyObject* data) {
  }
}

-// PyObject* tensor_toDLPack(const at::Tensor& input) {
-//   DLManagedTensor* dlMTensor = torch_xla::toDLPack(input);
-//   std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl;
-//   return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor);
-// }
-
at::Tensor tensor_fromDLPack(PyObject* data) {
  DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
  XLA_CHECK(dlMTensor != nullptr) << "from_dlpack received an invalid capsule. Note that a DLTensor capsule can be consumed only once. You may have already constructed a tensor from it once.";
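
The XLA_CHECK above is what surfaces the DLPack single-consumption rule to Python. Roughly, under the same CUDA assumptions as the tests:

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.utils.dlpack as xdlpack

    capsule = xdlpack.to_dlpack(torch.ones(2, device=xm.xla_device()))
    _ = xdlpack.from_dlpack(capsule)  # first consumption succeeds
    try:
      xdlpack.from_dlpack(capsule)    # the capsule is single-use
    except RuntimeError as e:
      assert "consumed only once" in str(e)
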
6 changes: 1 addition & 5 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -181,10 +181,7 @@ class PjRtComputationClient : public ComputationClient {
  };
  void Assign(const torch::lazy::BackendData& data) override;
  bool HasValue() const override {
-    // bool has_value = buffer != nullptr && !buffer->IsDeleted();
-    // std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": buffer != nullptr=" << (buffer != nullptr) << ", buffer->IsDeleted()=" << buffer->IsDeleted() << std::endl;
-    // return has_value;
-    return buffer != nullptr && !buffer->IsDeleted(); // TODO(xw32): uncomment this line and remove all above lines in the method.
+    return buffer != nullptr && !buffer->IsDeleted();
  };

  bool HasSharding() const override { return false; }
@@ -240,7 +237,6 @@ class PjRtComputationClient : public ComputationClient {
  }

  bool HasValue() const override {
-    std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": PjRtShardedData::HasValue is called." << std::endl;
    if (shards.empty()) {
      return false;
    }
3 changes: 0 additions & 3 deletions torch_xla/utils/dlpack.py
@@ -3,9 +3,6 @@

def to_dlpack(xla_tensor: Any):
  return torch_xla._XLAC._to_dlpack(xla_tensor)
-  # dlt = torch_xla._XLAC._to_dlpack(xla_tensor)
-  # print('xw32 torch_xla._XLAC._to_dlpack has returned. dlt has __dlpack__=', hasattr(dlt, "__dlpack__"), ', dlt has __dlpack_device__=', hasattr(dlt, "__dlpack_device__"))
-  # return dlt

def from_dlpack(ext_tensor: Any):
  return torch_xla._XLAC._from_dlpack(ext_tensor)
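
For completeness, a sketch of the reverse direction, moving an XLA tensor into torch.cuda without a copy (again assuming a CUDA-backed PJRT device):

    import torch
    import torch.utils.dlpack
    import torch_xla.core.xla_model as xm
    from torch_xla.utils.dlpack import to_dlpack

    x = torch.ones(3, device=xm.xla_device())
    cuda_x = torch.utils.dlpack.from_dlpack(to_dlpack(x))  # a torch.cuda view of x's buffer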
