Skip to content

Commit

Permalink
Reapply "Lower RandPerm" (#6394)
Browse files Browse the repository at this point in the history
This reverts commit 2f4275f.

Also fixes test infrastructure to call xla.mark_step() as required.
  • Loading branch information
changm committed Jan 31, 2024
1 parent 492fe27 commit 69288f8
Show file tree
Hide file tree
Showing 11 changed files with 383 additions and 12 deletions.
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ supported:
- random_
- random_.from
- random_.to
- randperm
- reflection_pad2d
- reflection_pad2d_backward
- remainder.Scalar
Expand Down
49 changes: 41 additions & 8 deletions test/cpp/test_aten_xla_tensor_1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1484,14 +1484,47 @@ TEST_F(AtenXlaTensorTest, TestNativeDropoutZeroProbability) {

TEST_F(AtenXlaTensorTest, TestRandperm) {
int n = 5;
torch::Tensor shuffle = torch::randperm(
n, torch::TensorOptions(torch::kLong).device(torch::kXLA));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);
std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
ExpectCounterNotChanged("aten::(?!randperm.generator_out).*",
cpp_test::GetIgnoredCounters());
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle =
torch::randperm(n, torch::TensorOptions(torch::kLong).device(device));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::randperm", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestRandpermZeroDoesntCrash) {
int n = 0;
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle =
torch::randperm(n, torch::TensorOptions(torch::kLong).device(device));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.empty());
});
}

TEST_F(AtenXlaTensorTest, TestRandpermCPUFallback) {
int n = 5;
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle = torch::randperm(
n,
torch::TensorOptions(torch::kLong).device(device).pinned_memory(true));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
});

ExpectCounterChanged("aten::.*", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSlice) {
Expand Down
6 changes: 4 additions & 2 deletions test/pytorch_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@
'test_resize_as_all_dtypes_and_devices', # uses half
'test_resize_all_dtypes_and_devices', # uses half
'test_pinverse', # lowering
'test_put', # Due to randperm and LTC, not deterministic.
'test_index_copy', # Due to randperm and LTC, not deterministic
'test_norm',
'test_multinomial',
'test_multinomial_alias',
Expand Down Expand Up @@ -607,8 +609,8 @@ def skipped_test(self, *args, reason=reason, **kwargs):
setattr(cls, dtype_test_name, disallowed_test)
if not skipped:
xla_dtypes.append(
dtype_combination
if len(dtype_combination) > 1 else dtype_combination[0])
dtype_combination if len(dtype_combination) >
1 else dtype_combination[0])
if len(xla_dtypes) != 0:
test.dtypes[cls.device_type] = xla_dtypes
super().instantiate_test(name, test, generic_cls=generic_cls)
Expand Down
2 changes: 2 additions & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ function run_xla_op_tests1 {
run_test "$CDIR/pjrt/test_internal_tpu.py"
run_test "$CDIR/pjrt/test_ddp.py"
run_test "$CDIR/pjrt/test_mesh_service.py"
run_test "$CDIR/test_index_copy.py"
run_test "$CDIR/test_put.py"
run_test "$CDIR/test_ops.py"
run_test "$CDIR/test_metrics.py"
run_test "$CDIR/test_zero1.py"
Expand Down
13 changes: 11 additions & 2 deletions test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3395,11 +3395,20 @@ def test_aten_prod_dim_int_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.prod.dim_int, args, kwargs)

@unittest.skip
# Due to the way randperm isn't on device, we manually assert checks here instead of using
# the existing test harness.
def test_aten_randperm_0(self):
args = (20,)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.randperm, args, kwargs)
pytorch = torch.randperm(20)

xla = torch.randperm(20, device=xm.xla_device())
xla_detached = xla.detach().cpu()

# Check equal lengths and that the sorted sets are equal. Since these numbers are randomly
# generated there's no way to check that pytorch == pytorch/xla.
self.assertEqual(len(pytorch), len(xla))
self.assertEqual(sorted(set(pytorch)), sorted(set(xla_detached)))

def test_aten_reciprocal_0(self):
args = (torch.randn((10, 10)).to(torch.float32),)
Expand Down
82 changes: 82 additions & 0 deletions test/test_index_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
from torch.testing import make_tensor
from itertools import product

from functools import partial

from torch.testing._internal.common_utils import TestCase, run_tests, IS_JETSON

from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, dtypes)

from torch.testing._internal.common_dtype import (all_types_and_complex_and)

import unittest


# This test is a copy of upstream pytorch test_put due to the way lazy tensors
# work. The randperm op generates a random tensor. Every iteration of ref_put
# recompiles the randperm op thus generating a different random tensor which
# makes the test non-deterministic. To force determinism, this test has to
# call PyTorch/XLA mark_step() to materialize the tensor rather than recompile.
class TestIndexCopy(TestCase):
unsupported_dtypes = {
torch.half, torch.complex32, torch.complex64, torch.complex128
}

@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
def test_index_copy(self, dtype):
if dtype in self.unsupported_dtypes:
raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format(
str(dtype)))

device = xm.xla_device()

# We just test for num_copy <= num_dest, as otherwise there are repeated indices
# and the behavior is undefined
num_copy, num_dest = 3, 5

def make_arg(batch_sizes, n, dim, contig):
size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
return make_tensor(
size_arg,
dtype=dtype,
device=device,
low=None,
high=None,
noncontiguous=not contig)

def ref_index_copy(tgt, dim, idx, src):
for i in range(idx.size(0)):
idx_dest = dim * (slice(None),) + (idx[i],)
idx_src = dim * (slice(None),) + (i,)
tgt[idx_dest] = src[idx_src]

# More thorough testing as in index_add
for dest_contig, src_contig, index_contig in product([True, False],
repeat=3):
for other_sizes in ((), (4, 5)):
for dim in range(len(other_sizes)):
dest = make_arg(other_sizes, num_dest, dim, dest_contig)
src = make_arg(other_sizes, num_copy, dim, src_contig)
idx = torch.randperm(
num_dest, dtype=torch.int64, device=device)[:num_copy]
if not index_contig:
idx = torch.repeat_interleave(idx, 2, dim=-1)
idx = idx[..., ::2]

xm.mark_step()

dest2 = dest.clone()
dest.index_copy_(dim, idx, src)
ref_index_copy(dest2, dim, idx, src)
self.assertEqual(dest, dest2)


instantiate_device_type_tests(TestIndexCopy, globals())

if __name__ == '__main__':
run_tests()
109 changes: 109 additions & 0 deletions test/test_put.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
from torch.testing import make_tensor
from itertools import product

from functools import partial

from torch.testing._internal.common_utils import TestCase, run_tests, IS_JETSON

from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, dtypes)

from torch.testing._internal.common_dtype import (all_types_and_complex_and)

import unittest


# This test is a copy of upstream pytorch test_put due to the way lazy tensors
# work. The randperm op generates a random tensor. Every iteration of ref_put
# recompiles the randperm op thus generating a different random tensor which
# makes the test non-deterministic. To force determinism, this test has to
# call PyTorch/XLA mark_step() to materialize the tensor rather than recompile.
class TestPut(TestCase):
unsupported_dtypes = {
torch.half, torch.complex32, torch.complex64, torch.complex128
}

@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
def test_put(self, dtype):
if dtype in self.unsupported_dtypes:
raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format(
str(dtype)))

device = xm.xla_device()
src_size = (4,)

make_arg = partial(make_tensor, device=device, dtype=dtype)
make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64)

def ref_put(dst, idx, src, accumulate):
new_dst = dst.clone(memory_format=torch.contiguous_format).view(-1)
new_idx = idx.contiguous().view(-1)
new_src = src.contiguous().view(-1)
method = new_dst.index_add_ if accumulate else new_dst.index_copy_
return method(0, new_idx, new_src).view_as(dst)

for dst_contig, src_contig, idx_contig, idx_reshape, accumulate in product(
[True, False], repeat=5):
for dst_size in ((5,), (4, 5)):
dst = make_arg(dst_size, noncontiguous=not dst_contig)
src = make_arg(src_size, noncontiguous=not src_contig)

# If accumulate=True, `put_` should be deterministic regardless of the inputs on CPU
# On CUDA it may not be, but the test has enough tolerance to account for this
if accumulate:
idx = make_idx(src_size, high=dst.numel())
else:
idx = torch.randperm(
dst.numel(), dtype=torch.int64, device=device)[:src_size[0]]
if not idx_contig:
idx = torch.repeat_interleave(idx, 2, dim=-1)[..., ::2]
if idx_reshape:
idx = idx.reshape(2, 2)
out = torch.put(dst, idx, src, accumulate)

xm.mark_step()

# out-place
reference = ref_put(dst, idx, src, accumulate)
self.assertEqual(out, reference)

# in-place
dst.put_(idx, src, accumulate)
self.assertEqual(dst, reference)

# Create the 8 possible combinations of scalar sizes for target / index / source
scalars = ((make_arg(size_t), make_idx(size_i, high=1), make_arg(size_s))
for size_t, size_i, size_s in product([(), (1,)], repeat=3))
for (dest, idx, source), accumulate in product(scalars, [True, False]):
dest_init = dest.clone()
# out-place
out = torch.put(dest, idx, source, accumulate=accumulate)
# in-place
dest1 = dest.clone()
dest1.put_(idx, source, accumulate=accumulate)
for d in [out, dest1]:
if accumulate:
self.assertEqual(d.item(), (dest_init + source).item())
else:
self.assertEqual(d.item(), source.item())

# Empty case
dest = make_arg((3, 2))
reference = dest.clone()
idx = make_idx((0,), high=1)
source = make_arg((0,))
for accumulate in [True, False]:
out = torch.put(dest, idx, source, accumulate=accumulate)
self.assertEqual(out, reference)
dest.put_(idx, source, accumulate=accumulate)
self.assertEqual(dest, reference)


instantiate_device_type_tests(TestPut, globals())

if __name__ == '__main__':
run_tests()
26 changes: 26 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2482,6 +2482,32 @@ at::Tensor& XLANativeFunctions::random_(
return self;
}

at::Tensor XLANativeFunctions::randperm(int64_t n,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");

// Only support the basic version of randperm(int64_t) to start. If there are
// any other parameters, fallback to CPU.
bool fallback_to_cpu = false;
fallback_to_cpu |= layout.has_value();
fallback_to_cpu |= pin_memory.has_value() && pin_memory.value() == true;
fallback_to_cpu |= dtype.value() != at::ScalarType::Long;
fallback_to_cpu |= n == 0;

if (fallback_to_cpu) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
ATEN_OP(randperm)>::call(n, dtype,
layout, device,
pin_memory);
}

return bridge::AtenFromXlaTensor(tensor_methods::randperm(
n, GetXlaDeviceOrCurrent(device), at::ScalarType::Long));
}

at::Tensor XLANativeFunctions::reflection_pad2d(const at::Tensor& self,
at::IntArrayRef padding) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
Expand Down
Loading

0 comments on commit 69288f8

Please sign in to comment.