diff --git a/test/test_operations.py b/test/test_operations.py
index 2839db8c12f..97a7a950972 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2448,22 +2448,33 @@ def test_unsafe_buffer_pointer(self):
 class TestDLPack(parameterized.TestCase):

   def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
-    dlpt = xdlpack.to_dlpack(xla_tensor) # dlpt1 has type PyCapsule
+    dlpt = xdlpack.to_dlpack(xla_tensor)  # dlpt1 has type PyCapsule
     got = xdlpack.from_dlpack(dlpt)

     self.assertEqual(xla_tensor.device, got.device)
     self.assertTrue(torch.allclose(xla_tensor.cpu(), got.cpu()))
-    self.assertRaisesRegex(RuntimeError, "DLTensor capsule can be consumed only once", lambda: xdlpack.from_dlpack(dlpt))
+    self.assertRaisesRegex(RuntimeError,
+                           "DLTensor capsule can be consumed only once",
+                           lambda: xdlpack.from_dlpack(dlpt))

-    self.assertEqual(torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),torch_xla._XLAC._unsafe_buffer_pointer(got))
+    self.assertEqual(
+        torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),
+        torch_xla._XLAC._unsafe_buffer_pointer(got))

   @onlyIfTorchSupportsCUDA
   @onlyIfPJRTDeviceIsCUDA
-  @parameterized.parameters(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
+  @parameterized.parameters(*all_types_and_complex_and(torch.half,
+                                                       torch.bfloat16,
+                                                       torch.bool, torch.uint16,
+                                                       torch.uint32,
+                                                       torch.uint64))
   def test_dlpack_roundtrip(self, dtype):
     # "arange_cpu" not implemented for complex64 and complex128.
     # xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) failed with `RuntimeError: false INTERNAL ASSERT FAILED at "/ansible/pytorch/torch/csrc/lazy/core/hash.h":139, please report a bug to PyTorch. Unsupported scalar type:UInt64`, similar to other uint.
-    if dtype in { torch.complex128, torch.complex64, torch.uint64, torch.uint32, torch.uint16, torch.bool }:
+    if dtype in {
+        torch.complex128, torch.complex64, torch.uint64, torch.uint32,
+        torch.uint16, torch.bool
+    }:
       return
     xla_device = xm.xla_device()
     xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device)
@@ -2489,7 +2500,7 @@ def test_dlpack_roundtrip(self, dtype):
   @onlyIfPJRTDeviceIsCUDA
   def test_dlpack_roundtrip_bool(self):
     xla_tensor = torch.ones(1, dtype=torch.bool).to(xm.xla_device())
-    self._test_dlpack_capsule_conversion_helper(xla_tensor)
+    self._test_dlpack_capsule_conversion_helper(xla_tensor)

   @onlyIfTorchSupportsCUDA
   @onlyIfPJRTDeviceIsCUDA
@@ -2529,7 +2540,10 @@ def test_dlpack_non_default_layout(self):
     self.assertTrue(torch.allclose(t2.cpu(), xla_t2.cpu()))

     t3 = cuda_t[:, 0]
-    self.assertRaisesRegex(RuntimeError, r"Only DLPack tensors with trivial \(compact\) striding are supported", lambda: xdlpack.from_dlpack(t3.__dlpack__()))
+    self.assertRaisesRegex(
+        RuntimeError,
+        r"Only DLPack tensors with trivial \(compact\) striding are supported",
+        lambda: xdlpack.from_dlpack(t3.__dlpack__()))

     t4 = cuda_t[1, :]
     xla_t4 = xdlpack.from_dlpack(t4.__dlpack__())
@@ -2540,9 +2554,6 @@ def test_dlpack_non_default_layout(self):
     self.assertTrue(torch.allclose(t5.cpu(), xla_t5.cpu()))


-
-
-
 class SimpleModelWithDropout(torch.nn.Module):

   def __init__(self):
diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp
index e1a21aed46e..3c0371efcf6 100644
--- a/torch_xla/csrc/dl_convertor.cpp
+++ b/torch_xla/csrc/dl_convertor.cpp
@@ -1,17 +1,17 @@
 #include "torch_xla/csrc/dl_convertor.h"

-#include "absl/types/span.h"
 #include <ATen/DLConvertor.h>
-#include "torch_xla/csrc/tensor.h"

+#include "absl/types/span.h"
"torch_xla/csrc/aten_xla_bridge.h" -#include "torch_xla/csrc/runtime/runtime.h" +#include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/ops/device_data.h" +#include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/unwrap_data.h" +#include "torch_xla/csrc/tensor.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/unwrap_data.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" @@ -19,12 +19,14 @@ namespace torch_xla { -std::shared_ptr get_data_handle(const at::Tensor& input) { +std::shared_ptr get_data_handle( + const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); XLA_CHECK(xtensor) << "The input has to be an XLA tensor."; if (xtensor->CurrentDataHandle() != nullptr) { TF_VLOG(4) << "The xla tensor has a current data handle."; - return std::dynamic_pointer_cast(xtensor->CurrentDataHandle()); + return std::dynamic_pointer_cast( + xtensor->CurrentDataHandle()); } else if (xtensor->CurrentIrValue().node != nullptr) { DeviceData* device_data = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); @@ -33,7 +35,8 @@ std::shared_ptr get_data_handle(const at::Tens } TF_VLOG(4) << "The xla tensor has IR value but does not have device data."; } - TF_VLOG(4) << "The xla tensor either has no current data handle or has no IR value."; + TF_VLOG(4) + << "The xla tensor either has no current data handle or has no IR value."; return nullptr; } @@ -66,7 +69,8 @@ DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) { } else if (device.client()->platform_id() == xla::CudaId()) { return DLDeviceType::kDLCUDA; } - XLA_ERROR() << "Device " << device.DebugString() << " cannot be used as a DLPack device."; + XLA_ERROR() << "Device " << device.DebugString() + << " cannot be used as a DLPack device."; } DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) { @@ -109,7 +113,8 @@ DLDataType PrimitiveTypeToDLDataType(xla::PrimitiveType type) { case xla::PrimitiveType::C128: return DLDataType{kDLComplex, 128, 1}; default: - XLA_ERROR() << "XLA type " << xla::PrimitiveType_Name(type) << " has no DLPack equivalent"; + XLA_ERROR() << "XLA type " << xla::PrimitiveType_Name(type) + << " has no DLPack equivalent"; } } @@ -129,14 +134,18 @@ std::vector StridesForShape(xla::PrimitiveType element_type, // Convert an XLA tensor to a dlPack tensor. DLManagedTensor* toDLPack(const at::Tensor& input) { - std::shared_ptr handle = get_data_handle(input); - XLA_CHECK(handle != nullptr) << "Could not extract a valid data handle from the input tensor"; + std::shared_ptr handle = + get_data_handle(input); + XLA_CHECK(handle != nullptr) + << "Could not extract a valid data handle from the input tensor"; - std::shared_ptr pjrt_buffer = runtime::GetComputationClient()->GetPjRtBuffer(handle); + std::shared_ptr pjrt_buffer = + runtime::GetComputationClient()->GetPjRtBuffer(handle); XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer"; if (pjrt_buffer->IsTuple()) { - XLA_ERROR() << "Unimplemented. BufferToDLPackManagedTensor is not implemented for tuple buffers."; + XLA_ERROR() << "Unimplemented. BufferToDLPackManagedTensor is not " + "implemented for tuple buffers."; } if (pjrt_buffer->has_dynamic_dimensions()) { XLA_ERROR() << "Unimplemented. 
     XLA_ERROR() << "Unimplemented. DynamicShape is not implemented in DLPack.";
@@ -164,9 +173,11 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
   dt.ndim = pjrt_buffer->dimensions().size();
   dt.dtype = PrimitiveTypeToDLDataType(pjrt_buffer->element_type());

-  pack->shape = std::vector<int64_t>(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end());
+  pack->shape = std::vector<int64_t>(pjrt_buffer->dimensions().begin(),
+                                     pjrt_buffer->dimensions().end());
   xla::Layout xla_layout = xla::GetXlaLayoutUnsafe(pjrt_buffer->layout());
-  pack->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout);
+  pack->strides = StridesForShape(pjrt_buffer->element_type(),
+                                  pjrt_buffer->dimensions(), xla_layout);
   dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
   dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
   dt.byte_offset = 0;
@@ -177,21 +188,25 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
 absl::StatusOr<xla::PjRtDevice*> DeviceForDLDevice(const DLDevice& context) {
   switch (context.device_type) {
     case DLDeviceType::kDLCPU:
-      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CpuId());
-      return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id);
+      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(),
+                   xla::CpuId());
+      return runtime::GetComputationClient()->LookupAddressableDevice(
+          context.device_id);
     case DLDeviceType::kDLCUDA:
-      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CudaId());
-      return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id);
+      XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(),
+                   xla::CudaId());
+      return runtime::GetComputationClient()->LookupAddressableDevice(
+          context.device_id);
     default:
-      return tsl::errors::InvalidArgument("Unknown/unsupported DLPack device type %d",
-                                          context.device_type);
+      return tsl::errors::InvalidArgument(
+          "Unknown/unsupported DLPack device type %d", context.device_type);
   }
 }

 absl::StatusOr<xla::PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
   if (type.lanes != 1) {
-    return tsl::errors::Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
-                                      type.lanes);
+    return tsl::errors::Unimplemented(
+        "DLPack types with lanes != 1 not implemented, got %d", type.lanes);
   }
   switch (type.code) {
     case kDLBool:
@@ -265,7 +280,8 @@ absl::StatusOr<xla::PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
           type.bits);
     }
     default:
-      return tsl::errors::Unimplemented("Unknown or invalid DLPack type code %d", type.code);
+      return tsl::errors::Unimplemented(
+          "Unknown or invalid DLPack type code %d", type.code);
   }
 }

@@ -302,43 +318,51 @@ absl::StatusOr<std::vector<int64_t>> StridesToLayout(

 at::Tensor fromDLPack(DLManagedTensor* dlmt) {
   if (dlmt->dl_tensor.ndim < 0) {
-    XLA_ERROR() << "Number of dimensions in DLManagedTensor must be nonnegative, got " << dlmt->dl_tensor.ndim;
+    XLA_ERROR()
+        << "Number of dimensions in DLManagedTensor must be nonnegative, got "
+        << dlmt->dl_tensor.ndim;
   }
   xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value();
   absl::Span<int64_t const> dimensions(
       const_cast<int64_t*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
-  xla::PrimitiveType element_type = DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype).value();
+  xla::PrimitiveType element_type =
+      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype).value();

   std::vector<int64_t> minor_to_major;
   if (dlmt->dl_tensor.strides &&
       absl::c_find(dimensions, 0) == dimensions.end()) {
     absl::Span<int64_t const> strides(
-        const_cast<int64_t*>(dlmt->dl_tensor.strides),
-        dlmt->dl_tensor.ndim);
+        const_cast<int64_t*>(dlmt->dl_tensor.strides), dlmt->dl_tensor.ndim);
     minor_to_major = StridesToLayout(dimensions, strides).value();
   } else {
     minor_to_major.resize(dlmt->dl_tensor.ndim);
     std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
   }
-  xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions,
-                                                              minor_to_major);
+  xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(
+      element_type, dimensions, minor_to_major);

   std::function<void()> on_delete_callback;
   if (dlmt->deleter) {
     on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
   }
-  xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> pjrt_buffer = device->client()->CreateViewOfDeviceBuffer(
-      static_cast<char*>(dlmt->dl_tensor.data) +
-      dlmt->dl_tensor.byte_offset,
-      shape, device, on_delete_callback);
-  XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer in " << __FUNCTION__;
-  XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null in " << __FUNCTION__;
+  xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> pjrt_buffer =
+      device->client()->CreateViewOfDeviceBuffer(
+          static_cast<char*>(dlmt->dl_tensor.data) +
+              dlmt->dl_tensor.byte_offset,
+          shape, device, on_delete_callback);
+  XLA_CHECK_OK(pjrt_buffer.status())
+      << "Failed to create a pjrt buffer in " << __FUNCTION__;
+  XLA_CHECK(pjrt_buffer.value() != nullptr)
+      << "pjrt buffer is null in " << __FUNCTION__;
+
+  runtime::ComputationClient::DataPtr data =
+      runtime::GetComputationClient()->CreateData(
+          runtime::GetComputationClient()->PjRtDeviceToString(device), shape,
+          std::move(pjrt_buffer.value()));

-  runtime::ComputationClient::DataPtr data = runtime::GetComputationClient()->CreateData(runtime::GetComputationClient()->PjRtDeviceToString(device), shape, std::move(pjrt_buffer.value()));
-
   at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype);
   XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type);
   return bridge::AtenFromXlaTensor(xla_tensor);
 }

-}
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/dl_convertor.h b/torch_xla/csrc/dl_convertor.h
index 07d4587146a..f5a54823e2e 100644
--- a/torch_xla/csrc/dl_convertor.h
+++ b/torch_xla/csrc/dl_convertor.h
@@ -1,8 +1,8 @@
 #ifndef XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_
 #define XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_

-#include <ATen/dlpack.h>
 #include <ATen/Tensor.h>
+#include <ATen/dlpack.h>

 namespace torch_xla {
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 2f4b90e42cf..e3fc165eb71 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1,7 +1,7 @@
+#include <ATen/dlpack.h>
 #include <Python.h>
 #include <c10/core/Device.h>
 #include <c10/util/Optional.h>
-#include <ATen/dlpack.h>
 #include <google/protobuf/text_format.h>
 #include <torch/csrc/lazy/backend/backend_data.h>
 #include <torch/csrc/lazy/core/config.h>
@@ -35,8 +35,8 @@
 #include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/device.h"
-#include "torch_xla/csrc/dtype.h"
 #include "torch_xla/csrc/dl_convertor.h"
+#include "torch_xla/csrc/dtype.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/ir.h"
 #include "torch_xla/csrc/ir_dump_util.h"
@@ -1117,8 +1117,12 @@ void dlPack_Capsule_Destructor(PyObject* data) {
 }

 at::Tensor tensor_fromDLPack(PyObject* data) {
-  DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
-  XLA_CHECK(dlMTensor != nullptr) << "from_dlpack received an invalid capsule. Note that a DLTensor capsule can be consumed only once. You may have already constructed a tensor from it once.";
+  DLManagedTensor* dlMTensor =
+      (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
+  XLA_CHECK(dlMTensor != nullptr)
+      << "from_dlpack received an invalid capsule. Note that a DLTensor "
+         "capsule can be consumed only once. You may have already constructed "
+         "a tensor from it once.";

   at::Tensor tensor = torch_xla::fromDLPack(dlMTensor);
   PyCapsule_SetName(data, "used_dltensor");
@@ -2543,17 +2547,17 @@ void InitXlaModuleBindings(py::module m) {
           NoGilSection nogil;
           dlMTensor = torch_xla::toDLPack(input);
         }
-        // return py::reinterpret_steal<py::object>(PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor));
+        // return py::reinterpret_steal<py::object>(PyCapsule_New(dlMTensor,
+        // "dltensor", dlPack_Capsule_Destructor));
         return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor);
       });

-  // m.def("_to_dlpack", &tensor_toDLPack, "");  //
+  // m.def("_to_dlpack", &tensor_toDLPack, "");  //

   // from a dlpack tensor to an XLA tensor
   m.def("_from_dlpack", [](py::handle ext_data) -> at::Tensor {
     return tensor_fromDLPack(ext_data.ptr());
   });
-
   // -------------Dynamo Integration API Start-------------------------
   /*
    * Return tensor ids and at::tensors for all DeviceData nodes that is needed
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index 6eca719c896..b275ef562ee 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -25,11 +25,11 @@
 #include "torch_xla/csrc/runtime/types.h"
 #include "torch_xla/csrc/runtime/util.h"
 #include "xla/client/xla_computation.h"
-#include "xla/pjrt/pjrt_client.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/literal_util.h"
-#include "xla/types.h"
+#include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_common.h"
+#include "xla/types.h"

 namespace torch_xla {
 namespace runtime {
@@ -260,9 +260,8 @@ class ComputationClient {
       std::string device, xla::Shape shape,
       std::optional<xla::OpSharding> sharding = std::nullopt) = 0;

-  virtual DataPtr CreateData(
-      std::string device, xla::Shape shape,
-      std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) = 0;
+  virtual DataPtr CreateData(std::string device, xla::Shape shape,
+                             std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) = 0;

   // Returns data shards. We expect this to be called on PjRtShardedData to
   // retrieve the shards. If other data type is passed, it returns the input
@@ -281,7 +280,8 @@ class ComputationClient {
   // structure will be empty if there is no sharding, like with PjRtData.
   virtual std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) = 0;

-  virtual std::string PjRtDeviceToString(xla::PjRtDevice* const device) const = 0;
+  virtual std::string PjRtDeviceToString(
+      xla::PjRtDevice* const device) const = 0;

   // Transfers local tensor values to the TPU devices and fetches the handles.
   virtual std::vector<DataPtr> TransferToDevice(
@@ -312,7 +312,8 @@ class ComputationClient {

   virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0;

-  virtual std::shared_ptr<xla::PjRtBuffer> GetPjRtBuffer(const DataPtr handle) = 0;
+  virtual std::shared_ptr<xla::PjRtBuffer> GetPjRtBuffer(
+      const DataPtr handle) = 0;

   // Compiles a set of computations.
   virtual std::vector<ComputationPtr> Compile(
@@ -356,7 +357,8 @@ class ComputationClient {

   virtual xla::PjRtPlatformId GetPlatformID() const = 0;

-  virtual absl::StatusOr<xla::PjRtDevice*> LookupAddressableDevice(int local_device_id) const = 0;
+  virtual absl::StatusOr<xla::PjRtDevice*> LookupAddressableDevice(
+      int local_device_id) const = 0;

   virtual size_t GetNumDevices() const = 0;
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc
index e059be41a08..e2a72992d6f 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -402,7 +402,8 @@ std::uintptr_t IfrtComputationClient::UnsafeBufferPointer(
   XLA_ERROR() << __FUNCTION__ << " not implemented";
 }

-std::shared_ptr<xla::PjRtBuffer> IfrtComputationClient::GetPjRtBuffer(const DataPtr handle) {
+std::shared_ptr<xla::PjRtBuffer> IfrtComputationClient::GetPjRtBuffer(
+    const DataPtr handle) {
   XLA_ERROR() << __FUNCTION__ << " not implemented";
 }
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index b2185842289..f843a2e53e4 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -33,11 +33,10 @@ class IfrtComputationClient : public ComputationClient {
       std::string device, xla::Shape shape,
       std::optional<xla::OpSharding> sharding = std::nullopt) override;

-  DataPtr CreateData(
-      std::string device, xla::Shape shape,
-      std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) override {
-    XLA_ERROR() << __FUNCTION__ << " not implemented";
-  };
+  DataPtr CreateData(std::string device, xla::Shape shape,
+                     std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) override {
+    XLA_ERROR() << __FUNCTION__ << " not implemented";
+  };

   std::vector<DataPtr> GetDataShards(DataPtr data) override;
@@ -96,7 +95,8 @@ class IfrtComputationClient : public ComputationClient {
     return client_->platform_id();
   }

-  absl::StatusOr<xla::PjRtDevice*> LookupAddressableDevice(int local_device_id) const override {
+  absl::StatusOr<xla::PjRtDevice*> LookupAddressableDevice(
+      int local_device_id) const override {
     XLA_ERROR() << __FUNCTION__ << " not implemented";
   }
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index c7896c09c20..ba3e9baf8c8 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -188,7 +188,8 @@ ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder(
 ComputationClient::DataPtr PjRtComputationClient::CreateData(
     std::string device, xla::Shape shape,
     std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) {
-  return std::make_shared<PjRtData>(std::move(device), std::move(shape), pjrt_buffer);
+  return std::make_shared<PjRtData>(std::move(device), std::move(shape),
+                                    pjrt_buffer);
 }

 std::vector<ComputationClient::DataPtr> PjRtComputationClient::GetDataShards(
@@ -469,17 +470,19 @@ std::uintptr_t PjRtComputationClient::UnsafeBufferPointer(
   std::shared_ptr<PjRtData> pjrt_data =
       std::dynamic_pointer_cast<PjRtData>(handle);
   XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString();
-  XLA_CHECK(pjrt_data->buffer != nullptr) << "PjRt buffer is null in " << __FUNCTION__;
+  XLA_CHECK(pjrt_data->buffer != nullptr)
+      << "PjRt buffer is null in " << __FUNCTION__;
   xla::StatusOr<std::uintptr_t> ptr =
       client_->UnsafeBufferPointer(pjrt_data->buffer.get());
   XLA_CHECK(ptr.ok());
   return ptr.value();
 }

-std::shared_ptr<xla::PjRtBuffer> PjRtComputationClient::GetPjRtBuffer(const DataPtr handle) {
+std::shared_ptr<xla::PjRtBuffer> PjRtComputationClient::GetPjRtBuffer(
+    const DataPtr handle) {
   std::shared_ptr<PjRtData> pjrt_data =
       std::dynamic_pointer_cast<PjRtData>(handle);
-  XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString();
got " << handle->ToString(); + XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString(); return pjrt_data->buffer; } @@ -498,7 +501,8 @@ std::vector PjRtComputationClient::TransferFromDevice( // is not sharded, then it is a no-op. std::shared_ptr pjrt_data = ReplicateShardedData(handle); XLA_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__; - XLA_CHECK(pjrt_data->buffer != nullptr) << "PjRt buffer is null in " << __FUNCTION__; + XLA_CHECK(pjrt_data->buffer != nullptr) + << "PjRt buffer is null in " << __FUNCTION__; xla::Literal& literal = literals.emplace_back(host_output_shape(pjrt_data->buffer.get())); @@ -508,7 +512,8 @@ std::vector PjRtComputationClient::TransferFromDevice( } for (auto& future : futures) { absl::Status status = future.Await(); - XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in" << __FUNCTION__; + XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in" + << __FUNCTION__; } InboundDataMetric()->AddSample(total_size); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 1b1e0aa5c47..aff4b781f99 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -32,9 +32,8 @@ class PjRtComputationClient : public ComputationClient { std::string device, xla::Shape shape, std::optional sharding = std::nullopt) override; - DataPtr CreateData( - std::string device, xla::Shape shape, - std::shared_ptr pjrt_buffer) override; + DataPtr CreateData(std::string device, xla::Shape shape, + std::shared_ptr pjrt_buffer) override; std::vector GetDataShards(DataPtr data) override; @@ -99,8 +98,10 @@ class PjRtComputationClient : public ComputationClient { return client_->platform_id(); } - absl::StatusOr LookupAddressableDevice(int local_device_id) const override { - return client_->LookupAddressableDevice(xla::PjRtLocalDeviceId(local_device_id)); + absl::StatusOr LookupAddressableDevice( + int local_device_id) const override { + return client_->LookupAddressableDevice( + xla::PjRtLocalDeviceId(local_device_id)); } std::vector GetLocalDevices() const override; diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py index 9ae99b8f802..9f93d532b27 100644 --- a/torch_xla/utils/dlpack.py +++ b/torch_xla/utils/dlpack.py @@ -1,8 +1,10 @@ from typing import Any import torch_xla + def to_dlpack(xla_tensor: Any): return torch_xla._XLAC._to_dlpack(xla_tensor) + def from_dlpack(ext_tensor: Any): return torch_xla._XLAC._from_dlpack(ext_tensor)