diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index 38720143f00..1cfb794c4e1 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -280,6 +280,7 @@ supported:
   - random_
   - random_.from
   - random_.to
+  - randperm
   - reflection_pad2d
   - reflection_pad2d_backward
   - remainder.Scalar
diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp
index 92f3d79be99..2444839af4a 100644
--- a/test/cpp/test_aten_xla_tensor_1.cpp
+++ b/test/cpp/test_aten_xla_tensor_1.cpp
@@ -1484,14 +1484,47 @@ TEST_F(AtenXlaTensorTest, TestNativeDropoutZeroProbability) {
 TEST_F(AtenXlaTensorTest, TestRandperm) {
   int n = 5;
-  torch::Tensor shuffle = torch::randperm(
-      n, torch::TensorOptions(torch::kLong).device(torch::kXLA));
-  torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);
-  std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
-                                    shuffle_cpu.data_ptr<int64_t>() + n);
-  EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
-  ExpectCounterNotChanged("aten::(?!randperm.generator_out).*",
-                          cpp_test::GetIgnoredCounters());
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor shuffle =
+        torch::randperm(n, torch::TensorOptions(torch::kLong).device(device));
+    torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);
+
+    std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
+                                      shuffle_cpu.data_ptr<int64_t>() + n);
+    EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::randperm", cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestRandpermZeroDoesntCrash) {
+  int n = 0;
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor shuffle =
+        torch::randperm(n, torch::TensorOptions(torch::kLong).device(device));
+    torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);
+
+    std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
+                                      shuffle_cpu.data_ptr<int64_t>() + n);
+    EXPECT_TRUE(shuffle_data.empty());
+  });
+}
+
+TEST_F(AtenXlaTensorTest, TestRandpermCPUFallback) {
+  int n = 5;
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor shuffle = torch::randperm(
+        n,
+        torch::TensorOptions(torch::kLong).device(device).pinned_memory(true));
+    torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);
+
+    std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
+                                      shuffle_cpu.data_ptr<int64_t>() + n);
+    EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
+  });
+
+  ExpectCounterChanged("aten::.*", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestSlice) {
diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py
index 64c8e41e4f4..7fae41f5ac3 100644
--- a/test/pytorch_test_base.py
+++ b/test/pytorch_test_base.py
@@ -4,11 +4,16 @@
 import re
 import sys
 import runpy
+import torch
+import unittest
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.utils.utils as xu
 
+from functools import wraps
+from torch.testing._internal.common_device_type import (DeviceTypeTestBase)
+
 DEFAULT_FLOATING_PRECISION = 1e-3
 
 TORCH_TEST_PRECIIONS = {
@@ -18,6 +23,8 @@
     'test_var_neg_dim_xla_bfloat16': 0.01,
     'test_sum_xla_bfloat16': 0.1,
     'test_put_xla_bfloat16': 0.05,
+    # Note: test_put_* is local to the PyTorch/XLA repo and not upstream.
+    'test_put_cpu_bfloat16': 0.05,
     'test_take_xla_bfloat16': 0.05,
 }
 
@@ -119,6 +126,8 @@
     'test_resize_as_all_dtypes_and_devices',  # uses half
     'test_resize_all_dtypes_and_devices',  # uses half
     'test_pinverse',  # lowering
+    'test_put',  # Due to randperm and LTC, not deterministic.
+    'test_index_copy',  # Due to randperm and LTC, not deterministic.
     'test_norm',
     'test_multinomial',
     'test_multinomial_alias',
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 84debb84b25..129659532ad 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -172,6 +172,7 @@ function run_xla_op_tests1 {
   run_test "$CDIR/pjrt/test_internal_tpu.py"
   run_test "$CDIR/pjrt/test_ddp.py"
   run_test "$CDIR/pjrt/test_mesh_service.py"
+  run_test "$CDIR/test_python_ops.py"
   run_test "$CDIR/test_ops.py"
   run_test "$CDIR/test_metrics.py"
   run_test "$CDIR/test_zero1.py"
diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index ac36f88ebe6..4f2fa3d20ee 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -3050,6 +3050,21 @@ def test_aten_prod_dim_int_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.prod.dim_int, args, kwargs)
 
+  # randperm's output is random, so it can't be checked against an eager CPU
+  # run with the existing test harness; assert the expected properties
+  # manually instead.
+  def test_aten_randperm_0(self):
+    args = (20,)
+    kwargs = dict()
+    pytorch = torch.randperm(20)
+
+    xla = torch.randperm(20, device=xm.xla_device())
+    xla_detached = xla.detach().cpu()
+
+    # Check equal lengths and that the sorted sets are equal. Since these
+    # numbers are randomly generated, there's no way to check that
+    # pytorch == pytorch/xla.
+    self.assertEqual(len(pytorch), len(xla))
+    self.assertEqual(sorted(set(pytorch)), sorted(set(xla_detached)))
+
   def test_aten_reciprocal_0(self):
     args = (torch.randn((10, 10)).to(torch.float32),)
     kwargs = dict()
diff --git a/test/test_python_ops.py b/test/test_python_ops.py
new file mode 100644
index 00000000000..d939d97b93f
--- /dev/null
+++ b/test/test_python_ops.py
@@ -0,0 +1,156 @@
+import torch
+import torch.nn as nn
+import torch_xla
+import torch_xla.core.xla_model as xm
+import unittest
+import test_utils
+import pytorch_test_base
+
+from torch.testing import make_tensor
+from itertools import product
+from functools import partial
+from torch.testing._internal.common_utils import TestCase, run_tests, IS_JETSON
+from torch.testing._internal.common_device_type import (
+    instantiate_device_type_tests, dtypes)
+from torch.testing._internal.common_dtype import (all_types_and_complex_and)
+
+
+# These tests are copies of upstream PyTorch tests, adapted for the way lazy
+# tensors work. The randperm op generates a random tensor, and every iteration
+# of the test recompiles the randperm op, producing a different random tensor
+# and making the test non-deterministic. To force determinism, these tests
+# call PyTorch/XLA's mark_step() to materialize the tensor rather than
+# recompile.
+class TestPythonOps(pytorch_test_base.XLATestBase):
+
+  @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
+  def test_put(self, dtype):
+    if dtype in self.unsupported_dtypes:
+      raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format(
+          str(dtype)))
+
+    device = xm.xla_device()
+    real_device_type = xm.xla_device_hw(str(xm.xla_device()))
+    if real_device_type == "TPU":
+      raise unittest.SkipTest("TestPut is too slow on TPU. Skipped")
+
+    src_size = (4,)
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype)
+    make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64)
+
+    def ref_put(dst, idx, src, accumulate):
+      new_dst = dst.clone(memory_format=torch.contiguous_format).view(-1)
+      new_idx = idx.contiguous().view(-1)
+      new_src = src.contiguous().view(-1)
+      method = new_dst.index_add_ if accumulate else new_dst.index_copy_
+      return method(0, new_idx, new_src).view_as(dst)
+
+    for dst_contig, src_contig, idx_contig, idx_reshape, accumulate in product(
+        [True, False], repeat=5):
+      for dst_size in ((5,), (4, 5)):
+        dst = make_arg(dst_size, noncontiguous=not dst_contig)
+        src = make_arg(src_size, noncontiguous=not src_contig)
+
+        # If accumulate=True, `put_` should be deterministic regardless of the
+        # inputs on CPU. On CUDA it may not be, but the test has enough
+        # tolerance to account for this.
+        if accumulate:
+          idx = make_idx(src_size, high=dst.numel())
+        else:
+          idx = torch.randperm(
+              dst.numel(), dtype=torch.int64, device=device)[:src_size[0]]
+        if not idx_contig:
+          idx = torch.repeat_interleave(idx, 2, dim=-1)[..., ::2]
+        if idx_reshape:
+          idx = idx.reshape(2, 2)
+        out = torch.put(dst, idx, src, accumulate)
+
+        xm.mark_step()
+
+        # out-place
+        reference = ref_put(dst, idx, src, accumulate)
+        self.assertEqual(out, reference)
+
+        # in-place
+        dst.put_(idx, src, accumulate)
+        self.assertEqual(dst, reference)
+
+    # Create the 8 possible combinations of scalar sizes for target / index / source
+    scalars = ((make_arg(size_t), make_idx(size_i, high=1), make_arg(size_s))
+               for size_t, size_i, size_s in product([(), (1,)], repeat=3))
+    for (dest, idx, source), accumulate in product(scalars, [True, False]):
+      dest_init = dest.clone()
+      # out-place
+      out = torch.put(dest, idx, source, accumulate=accumulate)
+      # in-place
+      dest1 = dest.clone()
+      dest1.put_(idx, source, accumulate=accumulate)
+      for d in [out, dest1]:
+        if accumulate:
+          self.assertEqual(d.item(), (dest_init + source).item())
+        else:
+          self.assertEqual(d.item(), source.item())
+
+    # Empty case
+    dest = make_arg((3, 2))
+    reference = dest.clone()
+    idx = make_idx((0,), high=1)
+    source = make_arg((0,))
+    for accumulate in [True, False]:
+      out = torch.put(dest, idx, source, accumulate=accumulate)
+      self.assertEqual(out, reference)
+      dest.put_(idx, source, accumulate=accumulate)
+      self.assertEqual(dest, reference)
+
+  @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
+  def test_index_copy(self, dtype):
+    if dtype in self.unsupported_dtypes:
+      raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format(
+          str(dtype)))
+
+    device = xm.xla_device()
+
+    # We just test for num_copy <= num_dest, as otherwise there are repeated
+    # indices and the behavior is undefined.
+    num_copy, num_dest = 3, 5
+
+    def make_arg(batch_sizes, n, dim, contig):
+      size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
+      return make_tensor(
+          size_arg,
+          dtype=dtype,
+          device=device,
+          low=None,
+          high=None,
+          noncontiguous=not contig)
+
+    def ref_index_copy(tgt, dim, idx, src):
+      for i in range(idx.size(0)):
+        idx_dest = dim * (slice(None),) + (idx[i],)
+        idx_src = dim * (slice(None),) + (i,)
+        tgt[idx_dest] = src[idx_src]
+
+    # More thorough testing as in index_add
+    for dest_contig, src_contig, index_contig in product([True, False],
+                                                         repeat=3):
+      for other_sizes in ((), (4, 5)):
+        for dim in range(len(other_sizes)):
+          dest = make_arg(other_sizes, num_dest, dim, dest_contig)
+          src = make_arg(other_sizes, num_copy, dim, src_contig)
+          idx = torch.randperm(
+              num_dest, dtype=torch.int64, device=device)[:num_copy]
+          if not index_contig:
+            idx = torch.repeat_interleave(idx, 2, dim=-1)
+            idx = idx[..., ::2]
+
+          xm.mark_step()
+
+          dest2 = dest.clone()
+          dest.index_copy_(dim, idx, src)
+          ref_index_copy(dest2, dim, idx, src)
+          self.assertEqual(dest, dest2)
+
+
+instantiate_device_type_tests(TestPythonOps, globals())
+
+if __name__ == '__main__':
+  run_tests()
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 1c61fb266e9..314a0365006 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2470,6 +2470,32 @@ at::Tensor& XLANativeFunctions::random_(
   return self;
 }
 
+at::Tensor XLANativeFunctions::randperm(int64_t n,
+                                        c10::optional<at::ScalarType> dtype,
+                                        c10::optional<at::Layout> layout,
+                                        c10::optional<at::Device> device,
+                                        c10::optional<bool> pin_memory) {
+  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
+
+  // Only support the basic version of randperm(int64_t) to start. If there are
+  // any other parameters, fall back to CPU.
+  bool fallback_to_cpu = false;
+  fallback_to_cpu |= layout.has_value();
+  fallback_to_cpu |= pin_memory.has_value() && pin_memory.value() == true;
+  fallback_to_cpu |= dtype.value() != at::ScalarType::Long;
+  fallback_to_cpu |= n == 0;
+
+  if (fallback_to_cpu) {
+    return at::native::call_fallback_fn<&xla_cpu_fallback,
+                                        ATEN_OP(randperm)>::call(n, dtype,
+                                                                 layout, device,
+                                                                 pin_memory);
+  }
+
+  return bridge::AtenFromXlaTensor(tensor_methods::randperm(
+      n, GetXlaDeviceOrCurrent(device), at::ScalarType::Long));
+}
+
 at::Tensor XLANativeFunctions::reflection_pad2d(const at::Tensor& self,
                                                 at::IntArrayRef padding) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
diff --git a/torch_xla/csrc/ops/randperm.cpp b/torch_xla/csrc/ops/randperm.cpp
new file mode 100644
index 00000000000..b8a793eebe8
--- /dev/null
+++ b/torch_xla/csrc/ops/randperm.cpp
@@ -0,0 +1,72 @@
+#include "torch_xla/csrc/ops/randperm.h"
+
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+#include "torch_xla/csrc/ops/xla_ops.h"
+#include "tsl/platform/stacktrace.h"
+#include "xla/client/lib/loops.h"
+#include "xla/shape_util.h"
+
+namespace torch_xla {
+namespace {
+
+using namespace xla;
+
+xla::Shape NodeOutputShape(int64_t n) {
+  return xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {n});
+}
+
+XlaOp Swap(XlaOp input, XlaOp i, XlaOp j) {
+  XlaOp i_value = xla::DynamicSlice(input, {i}, /*slice_sizes=*/{1});
+  XlaOp j_value = xla::DynamicSlice(input, {j}, /*slice_sizes=*/{1});
+
+  XlaOp write_i = xla::DynamicUpdateSlice(input, j_value, {i});
+  XlaOp write_j = xla::DynamicUpdateSlice(write_i, i_value, {j});
+
+  return write_j;
+}
+
+StatusOr<std::vector<XlaOp>> LoopBodyFn(XlaOp i, absl::Span<const XlaOp> values,
+                                        XlaBuilder* builder) {
+  XlaOp input_array = values[0];
+  XlaOp upper_bound_exclusive = values[1];
+
+  XlaOp target_index = xla::RngUniform(
+      i, upper_bound_exclusive,
+      ShapeUtil::MakeShape(xla::PrimitiveType::S64, /*dimensions=*/{1}));
+
+  XlaOp swapped_array = Swap(input_array, i, target_index);
+  return std::vector<XlaOp>{swapped_array, upper_bound_exclusive};
+}
+
+}  // namespace
+
+RandPerm::RandPerm(int64_t n, const at::ScalarType dtype,
+                   const at::Layout layout, const at::Device device,
+                   bool pin_memory)
+    : XlaNode(torch::lazy::OpKind(at::aten::randperm), /*operands=*/{},
+              [&]() { return NodeOutputShape(n); }, /*num_outputs=*/1,
+              torch::lazy::MHash(n)),
+      n_(n) {}
+
+// Fisher-Yates shuffle.
+XlaOpVector RandPerm::Lower(LoweringContext* loctx) const {
+  xla::XlaBuilder* builder = loctx->builder();
+  auto init_tensor = xla::Iota(loctx->builder(), xla::PrimitiveType::S64, n_);
+
+  auto upper_bound_exclusive = xla::ConstantLiteral(
+      loctx->builder(), xla::LiteralUtil::CreateR0<int64_t>(n_));
+  auto fisher_yates_loop = xla::ForEachIndex(
+      /*num_iterations=*/n_ - 1, xla::PrimitiveType::S64, &LoopBodyFn,
+      {init_tensor, upper_bound_exclusive}, "Fisher-Yates-Shuffle", builder);
+
+  return ReturnOp(fisher_yates_loop.value()[0], loctx);
+}
+
+std::string RandPerm::ToString() const {
+  std::stringstream ss;
+  ss << XlaNode::ToString() << ", n=" << n_;
+  return ss.str();
+}
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/randperm.h b/torch_xla/csrc/ops/randperm.h
new file mode 100644
index 00000000000..d29777807da
--- /dev/null
+++ b/torch_xla/csrc/ops/randperm.h
@@ -0,0 +1,24 @@
+#ifndef XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_
+#define XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_
+
+#include 
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+
+class RandPerm : public XlaNode {
+ public:
+  RandPerm(int64_t n, const at::ScalarType dtype, const at::Layout layout,
+           const at::Device device, bool pin_memory);
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+  std::string ToString() const override;
+
+ private:
+  int64_t n_;
+};
+
+}  // namespace torch_xla
+
+#endif  // XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index eb5a361f1c5..453d3a1a895 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -94,6 +94,7 @@
 #include "torch_xla/csrc/ops/put.h"
 #include "torch_xla/csrc/ops/qr.h"
 #include "torch_xla/csrc/ops/quant_tensor.h"
+#include "torch_xla/csrc/ops/randperm.h"
 #include "torch_xla/csrc/ops/recv.h"
 #include "torch_xla/csrc/ops/reduce_scatter.h"
 #include "torch_xla/csrc/ops/reflection_pad2d.h"
@@ -2258,6 +2259,16 @@ void random_(XLATensorPtr& input, int64_t from, int64_t to) {
       XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape));
 }
 
+XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device,
+                      at::ScalarType scalar_type) {
+  // These are all PyTorch defaults. PyTorch/XLA doesn't support non-default
+  // params here yet.
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<RandPerm>(
+      n, at::ScalarType::Long, at::Layout::Strided, at::DeviceType::XLA,
+      /*pin_memory=*/false);
+  return XLATensor::Create(node, device, scalar_type);
+}
+
 XLATensorPtr reflection_pad2d(const XLATensorPtr& input,
                               std::vector<int64_t> padding) {
   return input->CreateFrom(torch::lazy::MakeNode<ReflectionPad2d>(