Reapply lower randperm #6482

Merged 2 commits on Feb 8, 2024
Changes from 1 commit
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
@@ -280,6 +280,7 @@ supported:
- random_
- random_.from
- random_.to
- randperm
- reflection_pad2d
- reflection_pad2d_backward
- remainder.Scalar
49 changes: 41 additions & 8 deletions test/cpp/test_aten_xla_tensor_1.cpp
@@ -1484,14 +1484,47 @@ TEST_F(AtenXlaTensorTest, TestNativeDropoutZeroProbability) {

TEST_F(AtenXlaTensorTest, TestRandperm) {
int n = 5;
torch::Tensor shuffle = torch::randperm(
n, torch::TensorOptions(torch::kLong).device(torch::kXLA));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);
std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
ExpectCounterNotChanged("aten::(?!randperm.generator_out).*",
cpp_test::GetIgnoredCounters());
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle =
torch::randperm(n, torch::TensorOptions(torch::kLong).device(device));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::randperm", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestRandpermZeroDoesntCrash) {
int n = 0;
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle =
torch::randperm(n, torch::TensorOptions(torch::kLong).device(device));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.empty());
});
}

TEST_F(AtenXlaTensorTest, TestRandpermCPUFallback) {
int n = 5;
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle = torch::randperm(
n,
torch::TensorOptions(torch::kLong).device(device).pinned_memory(true));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
});

ExpectCounterChanged("aten::.*", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSlice) {
9 changes: 9 additions & 0 deletions test/pytorch_test_base.py
@@ -4,11 +4,16 @@
import re
import sys
import runpy
import torch
import unittest

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu

from functools import wraps
from torch.testing._internal.common_device_type import (DeviceTypeTestBase)

DEFAULT_FLOATING_PRECISION = 1e-3

TORCH_TEST_PRECIIONS = {
@@ -18,6 +18,8 @@
'test_var_neg_dim_xla_bfloat16': 0.01,
'test_sum_xla_bfloat16': 0.1,
'test_put_xla_bfloat16': 0.05,
# Note test_put_* is local to PyTorch/XLA repo and not upstream.
'test_put_cpu_bfloat16': 0.05,
'test_take_xla_bfloat16': 0.05,
}

@@ -119,6 +126,8 @@
'test_resize_as_all_dtypes_and_devices', # uses half
'test_resize_all_dtypes_and_devices', # uses half
'test_pinverse', # lowering
'test_put', # Due to randperm and LTC, not deterministic.
'test_index_copy', # Due to randperm and LTC, not deterministic
'test_norm',
'test_multinomial',
'test_multinomial_alias',
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -172,6 +172,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/pjrt/test_internal_tpu.py"
run_test "$CDIR/pjrt/test_ddp.py"
run_test "$CDIR/pjrt/test_mesh_service.py"
run_test "$CDIR/test_python_ops.py"
run_test "$CDIR/test_ops.py"
run_test "$CDIR/test_metrics.py"
run_test "$CDIR/test_zero1.py"
13 changes: 11 additions & 2 deletions test/test_core_aten_ops.py
@@ -3348,11 +3348,20 @@ def test_aten_prod_dim_int_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.prod.dim_int, args, kwargs)

@unittest.skip
# Because randperm's output is random, it can't be compared against a CPU reference,
# so we manually assert the checks here instead of using the existing test harness.
def test_aten_randperm_0(self):
args = (20,)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.randperm, args, kwargs)
pytorch = torch.randperm(20)

xla = torch.randperm(20, device=xm.xla_device())
xla_detached = xla.detach().cpu()

# Check equal lengths and that the sorted sets are equal. Since these numbers are randomly
# generated there's no way to check that pytorch == pytorch/xla.
self.assertEqual(len(pytorch), len(xla))
self.assertEqual(sorted(set(pytorch)), sorted(set(xla_detached)))

def test_aten_reciprocal_0(self):
args = (torch.randn((10, 10)).to(torch.float32),)
156 changes: 156 additions & 0 deletions test/test_python_ops.py
@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
import unittest
import test_utils
import pytorch_test_base

from torch.testing import make_tensor
from itertools import product
from functools import partial
from torch.testing._internal.common_utils import TestCase, run_tests, IS_JETSON
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, dtypes)
from torch.testing._internal.common_dtype import (all_types_and_complex_and)


# These tests are a copy of upstream pytorch tests due to the way lazy tensors
# work. The randperm op generates a random tensor. Every iteration of the test
# recompiles the randperm op thus generating a different random tensor which
# makes the test non-deterministic. To force determinism, this test has to
# call PyTorch/XLA mark_step() to materialize the tensor rather than recompile.
class TestPythonOps(pytorch_test_base.XLATestBase):

@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
def test_put(self, dtype):
if dtype in self.unsupported_dtypes:
raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format(
str(dtype)))

device = xm.xla_device()
real_device_type = xm.xla_device_hw(str(xm.xla_device()))
if real_device_type == "TPU":
raise unittest.SkipTest("TestPut is too slow on TPU. Skipped")

src_size = (4,)

make_arg = partial(make_tensor, device=device, dtype=dtype)
make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64)

def ref_put(dst, idx, src, accumulate):
new_dst = dst.clone(memory_format=torch.contiguous_format).view(-1)
new_idx = idx.contiguous().view(-1)
new_src = src.contiguous().view(-1)
method = new_dst.index_add_ if accumulate else new_dst.index_copy_
return method(0, new_idx, new_src).view_as(dst)

for dst_contig, src_contig, idx_contig, idx_reshape, accumulate in product(
[True, False], repeat=5):
for dst_size in ((5,), (4, 5)):
dst = make_arg(dst_size, noncontiguous=not dst_contig)
src = make_arg(src_size, noncontiguous=not src_contig)

# If accumulate=True, `put_` should be deterministic regardless of the inputs on CPU
# On CUDA it may not be, but the test has enough tolerance to account for this
if accumulate:
idx = make_idx(src_size, high=dst.numel())
else:
idx = torch.randperm(
dst.numel(), dtype=torch.int64, device=device)[:src_size[0]]
if not idx_contig:
idx = torch.repeat_interleave(idx, 2, dim=-1)[..., ::2]
if idx_reshape:
idx = idx.reshape(2, 2)
out = torch.put(dst, idx, src, accumulate)

xm.mark_step()

# out-place
reference = ref_put(dst, idx, src, accumulate)
self.assertEqual(out, reference)

# in-place
dst.put_(idx, src, accumulate)
self.assertEqual(dst, reference)

# Create the 8 possible combinations of scalar sizes for target / index / source
scalars = ((make_arg(size_t), make_idx(size_i, high=1), make_arg(size_s))
for size_t, size_i, size_s in product([(), (1,)], repeat=3))
for (dest, idx, source), accumulate in product(scalars, [True, False]):
dest_init = dest.clone()
# out-place
out = torch.put(dest, idx, source, accumulate=accumulate)
# in-place
dest1 = dest.clone()
dest1.put_(idx, source, accumulate=accumulate)
for d in [out, dest1]:
if accumulate:
self.assertEqual(d.item(), (dest_init + source).item())
else:
self.assertEqual(d.item(), source.item())

# Empty case
dest = make_arg((3, 2))
reference = dest.clone()
idx = make_idx((0,), high=1)
source = make_arg((0,))
for accumulate in [True, False]:
out = torch.put(dest, idx, source, accumulate=accumulate)
self.assertEqual(out, reference)
dest.put_(idx, source, accumulate=accumulate)
self.assertEqual(dest, reference)

@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
def test_index_copy(self, dtype):
if dtype in self.unsupported_dtypes:
raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format(
str(dtype)))

device = xm.xla_device()

# We just test for num_copy <= num_dest, as otherwise there are repeated indices
# and the behavior is undefined
num_copy, num_dest = 3, 5

def make_arg(batch_sizes, n, dim, contig):
size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
return make_tensor(
size_arg,
dtype=dtype,
device=device,
low=None,
high=None,
noncontiguous=not contig)

def ref_index_copy(tgt, dim, idx, src):
for i in range(idx.size(0)):
idx_dest = dim * (slice(None),) + (idx[i],)
idx_src = dim * (slice(None),) + (i,)
tgt[idx_dest] = src[idx_src]

# More thorough testing as in index_add
for dest_contig, src_contig, index_contig in product([True, False],
repeat=3):
for other_sizes in ((), (4, 5)):
for dim in range(len(other_sizes)):
dest = make_arg(other_sizes, num_dest, dim, dest_contig)
src = make_arg(other_sizes, num_copy, dim, src_contig)
idx = torch.randperm(
num_dest, dtype=torch.int64, device=device)[:num_copy]
if not index_contig:
idx = torch.repeat_interleave(idx, 2, dim=-1)
idx = idx[..., ::2]

xm.mark_step()

dest2 = dest.clone()
dest.index_copy_(dim, idx, src)
ref_index_copy(dest2, dim, idx, src)
self.assertEqual(dest, dest2)


instantiate_device_type_tests(TestPythonOps, globals())

if __name__ == '__main__':
run_tests()
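
As context for the comment at the top of this new file: a minimal standalone sketch (not part of the diff) of the mark_step() pattern these tests rely on. It assumes a working torch_xla install; the point is that materializing the lazy graph once pins the random values, instead of retracing (and re-randomizing) the op on later use.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Recorded lazily: only a graph node exists at this point.
idx = torch.randperm(5, dtype=torch.int64, device=device)

# Materialize the pending graph so `idx` holds fixed values that later
# comparisons can reuse.
xm.mark_step()

print(idx.cpu())  # e.g. tensor([3, 0, 4, 1, 2])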
26 changes: 26 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2470,6 +2470,32 @@ at::Tensor& XLANativeFunctions::random_(
return self;
}

at::Tensor XLANativeFunctions::randperm(int64_t n,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");

  // Only support the basic version of randperm(int64_t) to start. If there are
  // any other parameters, fall back to CPU.
bool fallback_to_cpu = false;
fallback_to_cpu |= layout.has_value();
fallback_to_cpu |= pin_memory.has_value() && pin_memory.value() == true;
fallback_to_cpu |= dtype.value() != at::ScalarType::Long;
fallback_to_cpu |= n == 0;

if (fallback_to_cpu) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
ATEN_OP(randperm)>::call(n, dtype,
layout, device,
pin_memory);
}

return bridge::AtenFromXlaTensor(tensor_methods::randperm(
n, GetXlaDeviceOrCurrent(device), at::ScalarType::Long));
}

at::Tensor XLANativeFunctions::reflection_pad2d(const at::Tensor& self,
at::IntArrayRef padding) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
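
For illustration only (not part of the diff): a hedged Python sketch of which randperm calls should stay on the XLA device versus take the CPU fallback path, mirroring the checks in XLANativeFunctions::randperm above. The exact call forms shown are assumptions based on this diff and the C++ fallback test.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Lowered on device: default (long) dtype, n > 0, no layout or pin_memory.
on_device = torch.randperm(10, device=device)

# These variants match the fallback conditions above and should route to CPU:
pinned = torch.randperm(10, device=device, pin_memory=True)    # pin_memory == true
empty  = torch.randperm(0, device=device)                      # n == 0
int32  = torch.randperm(10, device=device, dtype=torch.int32)  # dtype != Long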
72 changes: 72 additions & 0 deletions torch_xla/csrc/ops/randperm.cpp
@@ -0,0 +1,72 @@
#include "torch_xla/csrc/ops/randperm.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "tsl/platform/stacktrace.h"
#include "xla/client/lib/loops.h"
#include "xla/shape_util.h"

namespace torch_xla {
namespace {

using namespace xla;

xla::Shape NodeOutputShape(int64_t n) {
return xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {n});
}

XlaOp Swap(XlaOp input, XlaOp i, XlaOp j) {
XlaOp i_value = xla::DynamicSlice(input, {i}, /*slice_sizes=*/{1});
XlaOp j_value = xla::DynamicSlice(input, {j}, /*slice_sizes=*/{1});

XlaOp write_i = xla::DynamicUpdateSlice(input, j_value, {i});
XlaOp write_j = xla::DynamicUpdateSlice(write_i, i_value, {j});

return write_j;
}

StatusOr<std::vector<XlaOp>> LoopBodyFn(XlaOp i, absl::Span<const XlaOp> values,
XlaBuilder* builder) {
XlaOp input_array = values[0];
XlaOp upper_bound_exclusive = values[1];

XlaOp target_index = xla::RngUniform(
i, upper_bound_exclusive,
ShapeUtil::MakeShape(xla::PrimitiveType::S64, /*dimensions=*/{1}));

XlaOp swapped_array = Swap(input_array, i, target_index);
return std::vector<XlaOp>{swapped_array, upper_bound_exclusive};
}

} // namespace

RandPerm::RandPerm(int64_t n, const at::ScalarType dtype,
const at::Layout layout, const at::Device device,
bool pin_memory)
: XlaNode(torch::lazy::OpKind(at::aten::randperm), /*operands=*/{},
[&]() { return NodeOutputShape(n); }, /*num_outputs=*/1,
torch::lazy::MHash(n)),
n_(n) {}

// Fisher-Yates shuffle.
XlaOpVector RandPerm::Lower(LoweringContext* lotcx) const {
xla::XlaBuilder* builder = lotcx->builder();
auto init_tensor = xla::Iota(lotcx->builder(), xla::PrimitiveType::S64, n_);

auto upper_bound_exclusive = xla::ConstantLiteral(
lotcx->builder(), xla::LiteralUtil::CreateR0<int64_t>(n_));
auto fischer_yates_loop = xla::ForEachIndex(
/*num_iterations=*/n_ - 1, xla::PrimitiveType::S64, &LoopBodyFn,
{init_tensor, upper_bound_exclusive}, "Fischer-Yates-Shuffle", builder);

return ReturnOp(fischer_yates_loop.value()[0], lotcx);
}

std::string RandPerm::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", n=" << n_;
return ss.str();
}

} // namespace torch_xla
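
For readability, a rough Python rendering of the shuffle lowered above (a sketch, not the XLA code itself): start from the identity permutation (Iota), then for each index i pick a uniform target in [i, n) (RngUniform) and swap the two elements (the Swap helper's dynamic slices).

import random

def fisher_yates_permutation(n):
    perm = list(range(n))               # xla::Iota: identity permutation
    for i in range(n - 1):              # xla::ForEachIndex over n - 1 iterations
        j = random.randrange(i, n)      # xla::RngUniform(i, n): target index in [i, n)
        perm[i], perm[j] = perm[j], perm[i]  # Swap(): two DynamicUpdateSlices
    return perm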