Reapply lower randperm (#6482)
Also updates the test infrastructure to copy the upstream PyTorch test_index_copy and test_put tests so that they call the XLA-required `mark_step` during testing.
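For context, a minimal sketch (not part of the commit) of the pattern the copied tests follow, assuming only the torch_xla APIs already used in this diff (`xm.xla_device()` and `xm.mark_step()`):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# randperm is traced lazily; every re-trace would produce a fresh random permutation.
idx = torch.randperm(5, dtype=torch.int64, device=device)

# Materialize the pending graph now so later uses of `idx` see one fixed permutation
# instead of a re-traced (and therefore different) random tensor.
xm.mark_step()

reference = idx.cpu()
assert sorted(reference.tolist()) == list(range(5))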
changm authored Feb 8, 2024
1 parent b935268 commit 157e06e
Showing 10 changed files with 356 additions and 8 deletions.
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
@@ -280,6 +280,7 @@ supported:
- random_
- random_.from
- random_.to
- randperm
- reflection_pad2d
- reflection_pad2d_backward
- remainder.Scalar
49 changes: 41 additions & 8 deletions test/cpp/test_aten_xla_tensor_1.cpp
@@ -1484,14 +1484,47 @@ TEST_F(AtenXlaTensorTest, TestNativeDropoutZeroProbability) {

TEST_F(AtenXlaTensorTest, TestRandperm) {
int n = 5;
torch::Tensor shuffle = torch::randperm(
n, torch::TensorOptions(torch::kLong).device(torch::kXLA));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);
std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
ExpectCounterNotChanged("aten::(?!randperm.generator_out).*",
cpp_test::GetIgnoredCounters());
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle =
torch::randperm(n, torch::TensorOptions(torch::kLong).device(device));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::randperm", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestRandpermZeroDoesntCrash) {
int n = 0;
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle =
torch::randperm(n, torch::TensorOptions(torch::kLong).device(device));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.empty());
});
}

TEST_F(AtenXlaTensorTest, TestRandpermCPUFallback) {
int n = 5;
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle = torch::randperm(
n,
torch::TensorOptions(torch::kLong).device(device).pinned_memory(true));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
});

ExpectCounterChanged("aten::.*", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSlice) {
9 changes: 9 additions & 0 deletions test/pytorch_test_base.py
@@ -4,11 +4,16 @@
import re
import sys
import runpy
import torch
import unittest

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu

from functools import wraps
from torch.testing._internal.common_device_type import (DeviceTypeTestBase)

DEFAULT_FLOATING_PRECISION = 1e-3

TORCH_TEST_PRECIIONS = {
@@ -18,6 +23,8 @@
'test_var_neg_dim_xla_bfloat16': 0.01,
'test_sum_xla_bfloat16': 0.1,
'test_put_xla_bfloat16': 0.05,
# Note test_put_* is local to PyTorch/XLA repo and not upstream.
'test_put_cpu_bfloat16': 0.05,
'test_take_xla_bfloat16': 0.05,
}

@@ -119,6 +126,8 @@
'test_resize_as_all_dtypes_and_devices', # uses half
'test_resize_all_dtypes_and_devices', # uses half
'test_pinverse', # lowering
'test_put', # Due to randperm and LTC, not deterministic.
'test_index_copy', # Due to randperm and LTC, not deterministic
'test_norm',
'test_multinomial',
'test_multinomial_alias',
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -172,6 +172,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/pjrt/test_internal_tpu.py"
run_test "$CDIR/pjrt/test_ddp.py"
run_test "$CDIR/pjrt/test_mesh_service.py"
run_test "$CDIR/test_python_ops.py"
run_test "$CDIR/test_ops.py"
run_test "$CDIR/test_metrics.py"
run_test "$CDIR/test_zero1.py"
15 changes: 15 additions & 0 deletions test/test_core_aten_ops.py
@@ -3050,6 +3050,21 @@ def test_aten_prod_dim_int_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.prod.dim_int, args, kwargs)

  # The CPU and XLA randperm results are generated independently and can't be compared
  # element-wise, so we manually assert the expected properties here instead of using
  # the existing test harness.
def test_aten_randperm_0(self):
args = (20,)
kwargs = dict()
pytorch = torch.randperm(20)

xla = torch.randperm(20, device=xm.xla_device())
xla_detached = xla.detach().cpu()

    # Check equal lengths and that the sorted sets are equal. Since these numbers are
    # randomly generated, there's no way to check that pytorch == pytorch/xla.
self.assertEqual(len(pytorch), len(xla))
self.assertEqual(sorted(set(pytorch)), sorted(set(xla_detached)))

def test_aten_reciprocal_0(self):
args = (torch.randn((10, 10)).to(torch.float32),)
kwargs = dict()
156 changes: 156 additions & 0 deletions test/test_python_ops.py
@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
import unittest
import test_utils
import pytorch_test_base

from torch.testing import make_tensor
from itertools import product
from functools import partial
from torch.testing._internal.common_utils import TestCase, run_tests, IS_JETSON
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, dtypes)
from torch.testing._internal.common_dtype import (all_types_and_complex_and)


# These tests are a copy of the upstream PyTorch tests because of the way lazy
# tensors work. The randperm op generates a random tensor, and every iteration of
# the test recompiles the randperm op, producing a different random tensor and
# making the test non-deterministic. To force determinism, these tests call
# PyTorch/XLA mark_step() to materialize the tensor rather than recompile.
class TestPythonOps(pytorch_test_base.XLATestBase):

@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
def test_put(self, dtype):
if dtype in self.unsupported_dtypes:
raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format(
str(dtype)))

device = xm.xla_device()
real_device_type = xm.xla_device_hw(str(xm.xla_device()))
if real_device_type == "TPU":
raise unittest.SkipTest("TestPut is too slow on TPU. Skipped")

src_size = (4,)

make_arg = partial(make_tensor, device=device, dtype=dtype)
make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64)

def ref_put(dst, idx, src, accumulate):
new_dst = dst.clone(memory_format=torch.contiguous_format).view(-1)
new_idx = idx.contiguous().view(-1)
new_src = src.contiguous().view(-1)
method = new_dst.index_add_ if accumulate else new_dst.index_copy_
return method(0, new_idx, new_src).view_as(dst)

for dst_contig, src_contig, idx_contig, idx_reshape, accumulate in product(
[True, False], repeat=5):
for dst_size in ((5,), (4, 5)):
dst = make_arg(dst_size, noncontiguous=not dst_contig)
src = make_arg(src_size, noncontiguous=not src_contig)

# If accumulate=True, `put_` should be deterministic regardless of the inputs on CPU
# On CUDA it may not be, but the test has enough tolerance to account for this
if accumulate:
idx = make_idx(src_size, high=dst.numel())
else:
idx = torch.randperm(
dst.numel(), dtype=torch.int64, device=device)[:src_size[0]]
if not idx_contig:
idx = torch.repeat_interleave(idx, 2, dim=-1)[..., ::2]
if idx_reshape:
idx = idx.reshape(2, 2)
out = torch.put(dst, idx, src, accumulate)

xm.mark_step()

# out-place
reference = ref_put(dst, idx, src, accumulate)
self.assertEqual(out, reference)

# in-place
dst.put_(idx, src, accumulate)
self.assertEqual(dst, reference)

# Create the 8 possible combinations of scalar sizes for target / index / source
scalars = ((make_arg(size_t), make_idx(size_i, high=1), make_arg(size_s))
for size_t, size_i, size_s in product([(), (1,)], repeat=3))
for (dest, idx, source), accumulate in product(scalars, [True, False]):
dest_init = dest.clone()
# out-place
out = torch.put(dest, idx, source, accumulate=accumulate)
# in-place
dest1 = dest.clone()
dest1.put_(idx, source, accumulate=accumulate)
for d in [out, dest1]:
if accumulate:
self.assertEqual(d.item(), (dest_init + source).item())
else:
self.assertEqual(d.item(), source.item())

# Empty case
dest = make_arg((3, 2))
reference = dest.clone()
idx = make_idx((0,), high=1)
source = make_arg((0,))
for accumulate in [True, False]:
out = torch.put(dest, idx, source, accumulate=accumulate)
self.assertEqual(out, reference)
dest.put_(idx, source, accumulate=accumulate)
self.assertEqual(dest, reference)

@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
def test_index_copy(self, dtype):
if dtype in self.unsupported_dtypes:
raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format(
str(dtype)))

device = xm.xla_device()

# We just test for num_copy <= num_dest, as otherwise there are repeated indices
# and the behavior is undefined
num_copy, num_dest = 3, 5

def make_arg(batch_sizes, n, dim, contig):
size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
return make_tensor(
size_arg,
dtype=dtype,
device=device,
low=None,
high=None,
noncontiguous=not contig)

def ref_index_copy(tgt, dim, idx, src):
for i in range(idx.size(0)):
idx_dest = dim * (slice(None),) + (idx[i],)
idx_src = dim * (slice(None),) + (i,)
tgt[idx_dest] = src[idx_src]

# More thorough testing as in index_add
for dest_contig, src_contig, index_contig in product([True, False],
repeat=3):
for other_sizes in ((), (4, 5)):
for dim in range(len(other_sizes)):
dest = make_arg(other_sizes, num_dest, dim, dest_contig)
src = make_arg(other_sizes, num_copy, dim, src_contig)
idx = torch.randperm(
num_dest, dtype=torch.int64, device=device)[:num_copy]
if not index_contig:
idx = torch.repeat_interleave(idx, 2, dim=-1)
idx = idx[..., ::2]

xm.mark_step()

dest2 = dest.clone()
dest.index_copy_(dim, idx, src)
ref_index_copy(dest2, dim, idx, src)
self.assertEqual(dest, dest2)


instantiate_device_type_tests(TestPythonOps, globals())

if __name__ == '__main__':
run_tests()
26 changes: 26 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2470,6 +2470,32 @@ at::Tensor& XLANativeFunctions::random_(
return self;
}

at::Tensor XLANativeFunctions::randperm(int64_t n,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");

// Only support the basic version of randperm(int64_t) to start. If any
// other parameters are set, fall back to CPU.
bool fallback_to_cpu = false;
fallback_to_cpu |= layout.has_value();
fallback_to_cpu |= pin_memory.has_value() && pin_memory.value() == true;
fallback_to_cpu |= dtype.value() != at::ScalarType::Long;
fallback_to_cpu |= n == 0;

if (fallback_to_cpu) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
ATEN_OP(randperm)>::call(n, dtype,
layout, device,
pin_memory);
}

return bridge::AtenFromXlaTensor(tensor_methods::randperm(
n, GetXlaDeviceOrCurrent(device), at::ScalarType::Long));
}

at::Tensor XLANativeFunctions::reflection_pad2d(const at::Tensor& self,
at::IntArrayRef padding) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
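As a hedged illustration (not part of the commit) of the dispatch split produced by the conditions above, mirroring the C++ tests earlier in this diff:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Basic int64 randperm: handled by the new XLA lowering (the xla::randperm counter changes).
on_device = torch.randperm(5, device=device)

# Cases the code above routes to the CPU fallback instead of the XLA lowering:
empty = torch.randperm(0, device=device)                        # n == 0
non_long = torch.randperm(5, device=device, dtype=torch.int32)  # dtype != Long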
72 changes: 72 additions & 0 deletions torch_xla/csrc/ops/randperm.cpp
@@ -0,0 +1,72 @@
#include "torch_xla/csrc/ops/randperm.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "tsl/platform/stacktrace.h"
#include "xla/client/lib/loops.h"
#include "xla/shape_util.h"

namespace torch_xla {
namespace {

using namespace xla;

xla::Shape NodeOutputShape(int64_t n) {
return xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {n});
}

XlaOp Swap(XlaOp input, XlaOp i, XlaOp j) {
XlaOp i_value = xla::DynamicSlice(input, {i}, /*slice_sizes=*/{1});
XlaOp j_value = xla::DynamicSlice(input, {j}, /*slice_sizes=*/{1});

XlaOp write_i = xla::DynamicUpdateSlice(input, j_value, {i});
XlaOp write_j = xla::DynamicUpdateSlice(write_i, i_value, {j});

return write_j;
}

StatusOr<std::vector<XlaOp>> LoopBodyFn(XlaOp i, absl::Span<const XlaOp> values,
XlaBuilder* builder) {
XlaOp input_array = values[0];
XlaOp upper_bound_exclusive = values[1];

XlaOp target_index = xla::RngUniform(
i, upper_bound_exclusive,
ShapeUtil::MakeShape(xla::PrimitiveType::S64, /*dimensions=*/{1}));

XlaOp swapped_array = Swap(input_array, i, target_index);
return std::vector<XlaOp>{swapped_array, upper_bound_exclusive};
}

} // namespace

RandPerm::RandPerm(int64_t n, const at::ScalarType dtype,
const at::Layout layout, const at::Device device,
bool pin_memory)
: XlaNode(torch::lazy::OpKind(at::aten::randperm), /*operands=*/{},
[&]() { return NodeOutputShape(n); }, /*num_outputs=*/1,
torch::lazy::MHash(n)),
n_(n) {}

// Fisher-Yates shuffle.
XlaOpVector RandPerm::Lower(LoweringContext* lotcx) const {
xla::XlaBuilder* builder = lotcx->builder();
auto init_tensor = xla::Iota(lotcx->builder(), xla::PrimitiveType::S64, n_);

auto upper_bound_exclusive = xla::ConstantLiteral(
lotcx->builder(), xla::LiteralUtil::CreateR0<int64_t>(n_));
auto fischer_yates_loop = xla::ForEachIndex(
/*num_iterations=*/n_ - 1, xla::PrimitiveType::S64, &LoopBodyFn,
{init_tensor, upper_bound_exclusive}, "Fischer-Yates-Shuffle", builder);

return ReturnOp(fischer_yates_loop.value()[0], lotcx);
}

std::string RandPerm::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", n=" << n_;
return ss.str();
}

} // namespace torch_xla
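For readers unfamiliar with the algorithm, a hedged Python reference sketch of the same Fisher-Yates loop that the lowering above expresses with Iota, ForEachIndex, RngUniform, and DynamicSlice/DynamicUpdateSlice:

import random

def randperm_reference(n):
  values = list(range(n))         # xla::Iota(S64, n)
  for i in range(n - 1):          # ForEachIndex with n - 1 iterations
    j = random.randint(i, n - 1)  # RngUniform over [i, n), as in LoopBodyFn
    values[i], values[j] = values[j], values[i]  # Swap(input, i, j)
  return values

# Every call yields a permutation of range(n).
assert sorted(randperm_reference(5)) == list(range(5))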