Skip to content

Commit

Permalink
Add dlpack support (pytorch#7025)
Browse files Browse the repository at this point in the history
  • Loading branch information
vanbasten23 authored May 22, 2024
1 parent 5e1d454 commit 6023855
Show file tree
Hide file tree
Showing 13 changed files with 673 additions and 6 deletions.
140 changes: 140 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
parser.add_argument('--verbosity', type=int, default=0)
FLAGS, leftovers = parser.parse_known_args()
sys.argv = [sys.argv[0]] + leftovers
from absl.testing import absltest, parameterized

# Normal imports section starts here.
import collections
Expand All @@ -28,6 +29,11 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.testing._internal.common_device_type import dtypes
from torch.testing._internal.common_dtype import (
all_types_and_complex_and,
all_types_and,
)
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
Expand All @@ -40,6 +46,7 @@
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr
import torch_xla.test.test_utils as xtu
import torch_xla.utils.dlpack as xdlpack
import torch_xla.utils.utils as xu
import torch_xla.utils.serialization as xser
import torch_xla.core.xla_model as xm
Expand Down Expand Up @@ -2464,6 +2471,139 @@ def test_unsafe_buffer_pointer(self):
self.assertGreaterEqual(buf_ptr_3, 0)


class TestDLPack(parameterized.TestCase):

def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
dlpt = xdlpack.to_dlpack(xla_tensor) # dlpt1 has type PyCapsule
xla_tensor2 = xdlpack.from_dlpack(dlpt)

self.assertEqual(xla_tensor.device, xla_tensor2.device)
self.assertTrue(torch.allclose(xla_tensor.cpu(), xla_tensor2.cpu()))
self.assertRaisesRegex(RuntimeError,
"DLTensor capsule can be consumed only once",
lambda: xdlpack.from_dlpack(dlpt))

self.assertEqual(
torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),
torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor2))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
@parameterized.parameters(*all_types_and(torch.half, torch.bfloat16))
def test_dlpack_roundtrip_tensor(self, dtype):
xla_device = xm.xla_device()
# xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
# xla_tensor_2 uses XLANativeFunctions::_to_copy
xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device)
self._test_dlpack_capsule_conversion_helper(xla_tensor_2)

# xla_tensor_3 uses arange_out IR node.
xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device())
xm.mark_step()
self._test_dlpack_capsule_conversion_helper(xla_tensor_3)

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
@parameterized.parameters(*all_types_and_complex_and(torch.half,
torch.bfloat16,
torch.bool, torch.uint16,
torch.uint32,
torch.uint64))
def test_dlpack_roundtrip_scalar(self, dtype):
xla_device = xm.xla_device()
xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device)
# `mark_step` ensures xtensor->CurrentDataHandle() != nullptr
xm.mark_step()
self._test_dlpack_capsule_conversion_helper(xla_tensor_0)

xla_tensor_1 = torch.tensor(42, dtype=dtype).to(xla_device)
# xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
self._test_dlpack_capsule_conversion_helper(xla_tensor_1)

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_roundtrip_bool(self):
xla_tensor = torch.ones(1, dtype=torch.bool).to(xm.xla_device())
self._test_dlpack_capsule_conversion_helper(xla_tensor)

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_pytorch_cuda_to_xla(self):
t1_cuda = torch.arange(5).cuda()
dlt1 = torch.utils.dlpack.to_dlpack(t1_cuda)
xla_t1 = xdlpack.from_dlpack(dlt1)
self.assertEqual(xla_t1.device.type, 'xla')
self.assertEqual(xla_t1.device.index, t1_cuda.device.index)
t1_cuda[0] = t1_cuda[0] + 20
self.assertTrue(torch.allclose(xla_t1.cpu(), t1_cuda.cpu()))

t2_cuda = torch.tensor(5).cuda()
dlt2 = torch.utils.dlpack.to_dlpack(t2_cuda)
xla_t2 = xdlpack.from_dlpack(dlt2)
self.assertEqual(xla_t2.device.type, 'xla')
self.assertEqual(xla_t2.device.index, t2_cuda.device.index)
t2_cuda.fill_(6)
self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu()))

cuda1 = torch.device('cuda:1')
t3_cuda = torch.tensor(5, device=cuda1)
dlt3 = torch.utils.dlpack.to_dlpack(t3_cuda)
xla_t3 = xdlpack.from_dlpack(dlt3)
self.assertEqual(xla_t3.device.type, 'xla')
self.assertEqual(
xla_t3.device.index,
t3_cuda.device.index,
msg='both value should 1. xla_t3.device should be xla:1.')
t3_cuda.fill_(6)
self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_xla_to_pytorch_cuda(self):
xla_t1 = torch.arange(5).to(xm.xla_device())
dlt1 = xdlpack.to_dlpack(xla_t1)
cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1)
self.assertEqual(cuda_t1.device.type, 'cuda')
self.assertEqual(cuda_t1.device.index, xla_t1.device.index)
cuda_t1[0] = cuda_t1[0] + 20
self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu()))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_non_default_layout(self):
cuda_t = torch.arange(25, device=torch.device('cuda')).reshape(5, 5)

t1 = cuda_t.t()
xla_t1 = xdlpack.from_dlpack(t1.__dlpack__())
self.assertEqual(xla_t1.device.type, 'xla')
self.assertEqual(xla_t1.device.index, 0)
self.assertTrue(torch.allclose(t1.cpu(), xla_t1.cpu()))

t2 = cuda_t[0]
xla_t2 = xdlpack.from_dlpack(t2.__dlpack__())
self.assertEqual(xla_t2.device.type, 'xla')
self.assertEqual(xla_t2.device.index, 0)
self.assertTrue(torch.allclose(t2.cpu(), xla_t2.cpu()))

t3 = cuda_t[:, 0]
self.assertRaisesRegex(
RuntimeError,
r"Only DLPack tensors with trivial \(compact\) striding are supported",
lambda: xdlpack.from_dlpack(t3.__dlpack__()))

t4 = cuda_t[1, :]
xla_t4 = xdlpack.from_dlpack(t4.__dlpack__())
self.assertEqual(xla_t4.device.type, 'xla')
self.assertEqual(xla_t4.device.index, 0)
self.assertTrue(torch.allclose(t4.cpu(), xla_t4.cpu()))

t5 = cuda_t[1]
xla_t5 = xdlpack.from_dlpack(t5.__dlpack__())
self.assertEqual(xla_t5.device.type, 'xla')
self.assertEqual(xla_t5.device.index, 0)
self.assertTrue(torch.allclose(t5.cpu(), xla_t5.cpu()))


class SimpleModelWithDropout(torch.nn.Module):

def __init__(self):
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ ptxla_cc_library(
"cross_replica_reduces.cpp",
"data_ops.cpp",
"debug_util.cpp",
"dl_convertor.cpp",
"elementwise.cpp",
"helpers.cpp",
"ir_dump_util.cpp",
Expand Down Expand Up @@ -81,6 +82,7 @@ ptxla_cc_library(
"cross_replica_reduces.h",
"data_ops.h",
"debug_util.h",
"dl_convertor.h",
"elementwise.h",
"generated_file_include.h",
"helpers.h",
Expand Down
Loading

0 comments on commit 6023855

Please sign in to comment.