diff --git a/test/test_operations.py b/test/test_operations.py
index 8981c98066c..2839db8c12f 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2449,27 +2449,18 @@ class TestDLPack(parameterized.TestCase):
   def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
     dlpt = xdlpack.to_dlpack(xla_tensor)  # dlpt1 has type PyCapsule
-    print('xw32 finished the to_dlpack')
     got = xdlpack.from_dlpack(dlpt)
-    print('xw32 finished the from_dlpack')
-    print('xla_tensor.device=', xla_tensor.device, ', got.device=', got.device)
     self.assertEqual(xla_tensor.device, got.device)
-    print('xla_tensor.cpu()=', xla_tensor.cpu())
-    print('got.cpu()=', got.cpu())
     self.assertTrue(torch.allclose(xla_tensor.cpu(), got.cpu()))
     self.assertRaisesRegex(RuntimeError,
                            "DLTensor capsule can be consumed only once",
                            lambda: xdlpack.from_dlpack(dlpt))
-    print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor)=', torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor))
-    print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(got)=', torch_xla._XLAC._unsafe_buffer_pointer(got))
     self.assertEqual(
         torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),
         torch_xla._XLAC._unsafe_buffer_pointer(got))
-    # TODO(xw32): need to test different data type such as pytorch/test/test_dlpack.py
 
   @onlyIfTorchSupportsCUDA
   @onlyIfPJRTDeviceIsCUDA
   @parameterized.parameters(*all_types_and_complex_and(
       torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32,
       torch.uint64))
   def test_dlpack_roundtrip(self, dtype):
-    print('xw32 dtype=', dtype)
     # "arange_cpu" not implemented for complex64 and complex128.
     # xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) failed with `RuntimeError: false INTERNAL ASSERT FAILED at "/ansible/pytorch/torch/csrc/lazy/core/hash.h":139, please report a bug to PyTorch. Unsupported scalar type:UInt64`, similar to other uint.
     if dtype in {
         torch.complex128, torch.complex64, torch.uint64, torch.uint32,
         torch.uint16, torch.bool
     }:
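
Note: the one-shot capsule semantics asserted by this helper can be reproduced
standalone. A minimal sketch, assuming a configured PJRT device (the imports
mirror the test file's):

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.utils.dlpack as xdlpack

    t = torch.arange(8, device=xm.xla_device())
    capsule = xdlpack.to_dlpack(t)      # PyCapsule wrapping a DLManagedTensor
    got = xdlpack.from_dlpack(capsule)  # consumes the capsule
    assert got.device == t.device
    assert torch.allclose(t.cpu(), got.cpu())
    # A second xdlpack.from_dlpack(capsule) raises RuntimeError:
    # "DLTensor capsule can be consumed only once".
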
diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp
index d2cb0110bd2..e1a21aed46e 100644
--- a/torch_xla/csrc/dl_convertor.cpp
+++ b/torch_xla/csrc/dl_convertor.cpp
@@ -55,9 +55,7 @@ TorchXLADLMTensor::~TorchXLADLMTensor() {
 }
 
 void TorchXLADLMTensorDeleter(DLManagedTensor* t) {
-  std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl;
   if (t) {
-    std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl;
     delete static_cast<TorchXLADLMTensor*>(t->manager_ctx);
   }
 }
@@ -67,8 +65,6 @@ DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) {
     return DLDeviceType::kDLCPU;
   } else if (device.client()->platform_id() == xla::CudaId()) {
     return DLDeviceType::kDLCUDA;
-  } else if (device.client()->platform_id() == xla::RocmId()) {
-    return DLDeviceType::kDLROCM;
   }
   XLA_ERROR() << "Device " << device.DebugString() << " cannot be used as a DLPack device.";
 }
@@ -131,7 +127,7 @@ std::vector<int64_t> StridesForShape(xla::PrimitiveType element_type,
   return strides;
 }
 
-// Convert an XLA tensor to dlPack tensor.
+// Convert an XLA tensor to a dlPack tensor.
 DLManagedTensor* toDLPack(const at::Tensor& input) {
   std::shared_ptr<runtime::ComputationClient::Data> handle = get_data_handle(input);
   XLA_CHECK(handle != nullptr) << "Could not extract a valid data handle from the input tensor";
@@ -146,63 +142,46 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
     XLA_ERROR() << "Unimplemented. DynamicShape is not implemented in DLPack.";
   }
 
-  auto torchXlaDLMTensor = std::make_unique<TorchXLADLMTensor>();
-  DLTensor& dt = torchXlaDLMTensor->tensor.dl_tensor;
+  auto pack = std::make_unique<TorchXLADLMTensor>();
+  DLTensor& dt = pack->tensor.dl_tensor;
   {
     // AcquireExternalReference may block
     auto external_ref = pjrt_buffer->AcquireExternalReference();
     XLA_CHECK_OK(external_ref.status());
-    torchXlaDLMTensor->external_reference = std::move(external_ref.value());
+    pack->external_reference = std::move(external_ref.value());
     xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture();
     absl::Status status = future.Await();
     XLA_CHECK_OK(status);
   }
-  torchXlaDLMTensor->buffer_reference = pjrt_buffer;
-  // torchXlaDLMTensor->source_tensor = input;
-  // pack->buffer_reference = nb::borrow(py_buffer); // xw32: should we do it?
+  pack->buffer_reference = pjrt_buffer;
+  // pack->source_tensor = input;
 
-  dt.data = torchXlaDLMTensor->external_reference->OpaqueDeviceMemoryDataPointer();
-  torchXlaDLMTensor->tensor.manager_ctx = torchXlaDLMTensor.get();
-  torchXlaDLMTensor->tensor.deleter = TorchXLADLMTensorDeleter;
+  dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer();
+  pack->tensor.manager_ctx = pack.get();
+  pack->tensor.deleter = TorchXLADLMTensorDeleter;
   dt.device = DLDeviceForDevice(*pjrt_buffer->device());
   dt.device.device_id = pjrt_buffer->device()->local_hardware_id();
   dt.ndim = pjrt_buffer->dimensions().size();
   dt.dtype = PrimitiveTypeToDLDataType(pjrt_buffer->element_type());
-  torchXlaDLMTensor->shape = std::vector<int64_t>(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end());
+  pack->shape = std::vector<int64_t>(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end());
   xla::Layout xla_layout = xla::GetXlaLayoutUnsafe(pjrt_buffer->layout());
-  torchXlaDLMTensor->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout);
-  dt.shape = reinterpret_cast<std::int64_t*>(torchXlaDLMTensor->shape.data());
-  dt.strides = reinterpret_cast<std::int64_t*>(torchXlaDLMTensor->strides.data());
+  pack->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout);
+  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
+  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
   dt.byte_offset = 0;
 
-  return &(torchXlaDLMTensor.release()->tensor);
+  return &(pack.release()->tensor);
 }
 
 absl::StatusOr<xla::PjRtDevice*> DeviceForDLDevice(const DLDevice& context) {
   switch (context.device_type) {
     case DLDeviceType::kDLCPU:
-      // if (cpu_client == nullptr) {
-      //   return InvalidArgument(
-      //       "DLPack tensor is on CPU, but no CPU backend was provided.");
-      // }
       XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CpuId());
       return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id);
     case DLDeviceType::kDLCUDA:
-      // if (gpu_client == nullptr) { // xw32 TODO: check if client_ is GPU client
-      //   return InvalidArgument(
-      //       "DLPack tensor is on GPU, but no GPU backend was provided.");
-      // }
       XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CudaId());
       return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id);
-    // case DLDeviceType::kDLROCM:
-    //   // if (gpu_client == nullptr) {
-    //   //   return InvalidArgument(
-    //   //       "DLPack tensor is on GPU, but no GPU backend was provided.");
-    //   // }
-    //   XLA_CHECK_EQ(pjrt_client->platform_id(), xla::RocmId());
-    //   xla::PjRtDevice* device = pjrt_client->addressable_devices()[context.device_id];
-    //   return device;
     default:
       return tsl::errors::InvalidArgument("Unknown/unsupported DLPack device type %d", context.device_type);
@@ -325,7 +304,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
   if (dlmt->dl_tensor.ndim < 0) {
     XLA_ERROR() << "Number of dimensions in DLManagedTensor must be nonnegative, got " << dlmt->dl_tensor.ndim;
   }
-  xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value(); // client_ is a xla::PjRtClient. So this fromDLPack should be inside pjrt_computation_client class.
+  xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value();
   absl::Span<int64_t const> dimensions(
       const_cast<int64_t*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
   xla::PrimitiveType element_type = DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype).value();
@@ -344,19 +323,6 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
   xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, minor_to_major);
 
-  // Raise an error if the resulting PjRtBuffer would have a non-default layout.
-  // TODO(skyewm): we do this because JAX doesn't currently have good support
-  // for non-default layouts, and will return wrong results if a non-default
-  // layout is passed to a computation expecting default layouts. Remove this
-  // special case when non-default layouts are better supported by JAX.
-  absl::StatusOr<xla::Layout> default_layout_from_client =
-      device->client()->GetDefaultLayout(element_type, dimensions);
-  XLA_CHECK_OK(default_layout_from_client.status()) << "Failed to get a default layout in " << __FUNCTION__;
-  xla::Layout default_layout = default_layout_from_client.value(); // TODO(xw32): the check below is needed due to an limitation in ifrt. Since torch_xla uses pjrt, we may not need the check below and the var default_layout.
-  // if (shape.layout() != default_layout) {
-  //   XLA_ERROR() << "from_dlpack got array with non-default layout with minor-to-major dimensions (" << absl::StrJoin(shape.layout().minor_to_major(), ",") << "), expected (" << absl::StrJoin(default_layout.minor_to_major(), ",") << ")";
-  // }
-
   std::function<void()> on_delete_callback;
   if (dlmt->deleter) {
     on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
@@ -370,9 +336,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
   runtime::ComputationClient::DataPtr data = runtime::GetComputationClient()->CreateData(
       runtime::GetComputationClient()->PjRtDeviceToString(device), shape, std::move(pjrt_buffer.value()));
-  // xw32 note: XlaDataToTensors does a fromDeviceToHost transfer.XlaDataToTensors
   at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype);
-  // return XlaDataToTensors({data}, {tensor_type})[0];
   XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type);
   return bridge::AtenFromXlaTensor(xla_tensor);
 }
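
Note: toDLPack pins the underlying PjRtBuffer (via external_reference and
buffer_reference) until the consumer runs the deleter, so the export is
zero-copy. A sketch of consuming the capsule from native PyTorch, assuming an
XLA:CUDA setup; torch.utils.dlpack.from_dlpack accepts legacy capsules:

    import torch
    import torch.utils.dlpack
    import torch_xla.core.xla_model as xm
    import torch_xla.utils.dlpack as xdlpack

    xla_t = torch.ones(4, device=xm.xla_device())     # XLA:CUDA tensor
    capsule = xdlpack.to_dlpack(xla_t)                # buffer kept alive by the external reference
    cuda_t = torch.utils.dlpack.from_dlpack(capsule)  # torch.cuda tensor aliasing the same memory
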
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index e9e4a9ef0fe..2f4b90e42cf 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1116,12 +1116,6 @@ void dlPack_Capsule_Destructor(PyObject* data) {
   }
 }
 
-// PyObject* tensor_toDLPack(const at::Tensor& input) {
-//   DLManagedTensor* dlMTensor = torch_xla::toDLPack(input);
-//   std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl;
-//   return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor);
-// }
-
 at::Tensor tensor_fromDLPack(PyObject* data) {
   DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
   XLA_CHECK(dlMTensor != nullptr) << "from_dlpack received an invalid capsule. Note that a DLTensor capsule can be consumed only once. You may have already constructed a tensor from it once.";
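
Note: the import path can be exercised in the other direction as well; a
minimal sketch, assuming a CUDA build of PyTorch alongside the XLA:CUDA plugin:

    import torch
    import torch.utils.dlpack
    import torch_xla.utils.dlpack as xdlpack

    cuda_t = torch.arange(4, device='cuda')
    capsule = torch.utils.dlpack.to_dlpack(cuda_t)  # export from native PyTorch
    xla_t = xdlpack.from_dlpack(capsule)            # import as an XLA tensor
    # Passing the spent capsule again raises the "invalid capsule" error above.
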
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index b17b8c451fd..1b1e0aa5c47 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -181,10 +181,7 @@ class PjRtComputationClient : public ComputationClient {
     };
     void Assign(const torch::lazy::BackendData& data) override;
     bool HasValue() const override {
-      // bool has_value = buffer != nullptr && !buffer->IsDeleted();
-      // std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": buffer != nullptr=" << (buffer != nullptr) << ", buffer->IsDeleted()=" << buffer->IsDeleted() << std::endl;
-      // return has_value;
-      return buffer != nullptr && !buffer->IsDeleted(); // TODO(xw32): uncomment this line and remove all above lines in the method.
+      return buffer != nullptr && !buffer->IsDeleted();
     };
 
     bool HasSharding() const override { return false; }
@@ -240,7 +237,6 @@ class PjRtComputationClient : public ComputationClient {
     }
 
     bool HasValue() const override {
-      std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": PjRtShardedData::HasValue is called." << std::endl;
       if (shards.empty()) {
         return false;
       }
diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py
index 67a37d5dac1..9ae99b8f802 100644
--- a/torch_xla/utils/dlpack.py
+++ b/torch_xla/utils/dlpack.py
@@ -3,9 +3,6 @@
 
 def to_dlpack(xla_tensor: Any):
   return torch_xla._XLAC._to_dlpack(xla_tensor)
-  # dlt = torch_xla._XLAC._to_dlpack(xla_tensor)
-  # print('xw32 torch_xla._XLAC._to_dlpack has returned. dlt has __dlpack_+=', hasattr(dlt, "__dlpack__"), ', dlt has __dlpack_device__=', hasattr(dlt, "__dlpack_device__"))
-  # return dlt
 
 def from_dlpack(ext_tensor: Any):
   return torch_xla._XLAC._from_dlpack(ext_tensor)
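
Note: with the ROCm branch dropped, only CPU and CUDA PJRT backends map to a
DLPack device type, so exporting on other backends (e.g. TPU) hits the
XLA_ERROR in DLDeviceTypeForDevice. A hypothetical caller-side guard, assuming
torch_xla.runtime.device_type() reports the PJRT backend name:

    import torch_xla.runtime as xr
    import torch_xla.utils.dlpack as xdlpack

    def checked_to_dlpack(t):
        # Hypothetical helper, not part of this change.
        if xr.device_type() not in ("CPU", "CUDA"):
            raise RuntimeError(f"DLPack export unsupported on {xr.device_type()}")
        return xdlpack.to_dlpack(t)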