From 6cf9b915b1d94a690cbebded69fc756ffeaf3248 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 1 Apr 2024 11:03:36 -0300 Subject: [PATCH] Refactor `nms` into TorchVision variant. (#6814) --- docs/source/index.rst | 1 - test/test_operations.py | 109 ++++++-- torch_xla/core/functions.py | 23 -- torch_xla/csrc/BUILD | 3 +- torch_xla/csrc/aten_xla_bridge.cpp | 6 +- torch_xla/csrc/aten_xla_bridge.h | 3 +- torch_xla/csrc/init_python_bindings.cpp | 27 -- torch_xla/csrc/nms_op.cpp | 299 --------------------- torch_xla/csrc/nms_op.h | 19 -- torch_xla/csrc/ops/nms.cpp | 45 +--- torch_xla/csrc/ops/nms.h | 10 +- torch_xla/csrc/tensor_methods.cpp | 19 +- torch_xla/csrc/tensor_methods.h | 7 +- torch_xla/csrc/xla_lower_util.cpp | 243 +++++++++++++++++ torch_xla/csrc/xla_lower_util.h | 3 + torch_xla/csrc/xla_manual_registration.cpp | 39 +++ 16 files changed, 398 insertions(+), 458 deletions(-) delete mode 100644 torch_xla/csrc/nms_op.cpp delete mode 100644 torch_xla/csrc/nms_op.h create mode 100644 torch_xla/csrc/xla_manual_registration.cpp diff --git a/docs/source/index.rst b/docs/source/index.rst index e29053f8538..4246fbbe820 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -33,7 +33,6 @@ xla_model .. automodule:: torch_xla.core.functions .. autofunction:: all_reduce .. autofunction:: all_gather -.. autofunction:: nms distributed ---------------------------------- diff --git a/test/test_operations.py b/test/test_operations.py index b77ca763058..7fb9f5bc3e3 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -83,6 +83,11 @@ def skipIfFunctionalizationDisabled(reason): return _skipIfFunctionalization(value=True, reason=reason) +def onlyOnCUDA(fn): + accelerator = os.environ.get("PJRT_DEVICE").lower() + return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn) + + def _gen_tensor(*args, **kwargs): return torch.randn(*args, **kwargs) @@ -2280,32 +2285,6 @@ def test_send_to_device_single(self): self.assertEqual(dt[0].device, xla_device) self.assertTrue(torch.all(torch.eq(dt[0].cpu(), t))) - def test_nms(self): - BOXES = ( - (0, 0, 3, 2), - (3, 3, 11, 7), - (2, 2, 5, 7), - (7, 4, 15, 12), - ) - SCORES = (0.9, 0.5, 0.95, 0.4) - SCORE_THRESHOLD = 0.1 - IOU_THRESHOLD = 0.08 - - xla_device = xm.xla_device() - boxes = torch.tensor(BOXES, dtype=torch.float).to(xla_device) - scores = torch.tensor(SCORES, dtype=torch.float).to(xla_device) - score_threshold = torch.tensor( - SCORE_THRESHOLD, dtype=torch.float).to(xla_device) - iou_threshold = torch.tensor( - IOU_THRESHOLD, dtype=torch.float).to(xla_device) - - selected_indices, num_valid = xf.nms(boxes, scores, score_threshold, - iou_threshold, len(BOXES)) - - self.assertEqual(selected_indices, - torch.tensor([2, 0, 3, 1], dtype=torch.int32)) - self.assertEqual(num_valid.item(), 3) - def test_util_foreach_api(self): class ForTest(object): @@ -2473,6 +2452,84 @@ def test_dropout(self): f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}") +# These tests were extracted and adapted from torchvision. +# Source: vision/test/test_ops.py +class TestNMS(test_utils.XlaTestCase): + + def _reference_nms(self, boxes, scores, iou_threshold): + import torchvision + return torchvision.ops.nms(boxes.cpu(), scores.cpu(), iou_threshold) + + def _nms(self, boxes, scores, iou_threshold): + import torchvision + device = xm.xla_device() + return torchvision.ops.nms( + boxes.to(device), scores.to(device), iou_threshold).cpu() + + def _create_tensors_with_iou(self, N, iou_thresh): + # force last box to have a pre-defined iou with the first box + # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1], + # then, in order to satisfy ops.iou(b0, b1) == iou_thresh, + # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh + # Adjust the threshold upward a bit with the intent of creating + # at least one box that exceeds (barely) the threshold and so + # should be suppressed. + boxes = torch.rand(N, 4) * 100 + boxes[:, 2:] += boxes[:, :2] + boxes[-1, :] = boxes[0, :] + x0, y0, x1, y1 = boxes[-1].tolist() + iou_thresh += 1e-5 + boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh + scores = torch.rand(N) + return boxes, scores + + @skipOnEagerDebug + def test_nms_ref(self): + + def _test(iou, seed): + torch.random.manual_seed(seed) + err_msg = "NMS incompatible between CPU and reference implementation for IoU={}" + boxes, scores = self._create_tensors_with_iou(1000, iou) + keep_ref = self._reference_nms(boxes, scores, iou) + keep = self._nms(boxes, scores, iou) + self.assertEqual(keep, keep_ref, message=err_msg.format(iou)) + + for iou in (0.2, 0.5, 0.8): + for seed in range(10): + with self.subTest(iou=iou, seed=seed): + _test(iou, seed) + + def test_nms_input_errors(self): + with self.assertRaisesRegex(RuntimeError, "boxes should be a 2D tensor."): + self._nms(torch.rand(4), torch.rand(3), 0.5) + with self.assertRaisesRegex( + RuntimeError, "boxes should be a 2D tensor of shape \[N, 4\]."): + self._nms(torch.rand(3, 5), torch.rand(3), 0.5) + with self.assertRaisesRegex(RuntimeError, "scores should be a 1D tensor."): + self._nms(torch.rand(3, 4), torch.rand(3, 2), 0.5) + with self.assertRaisesRegex( + RuntimeError, + "boxes and scores should have the same size for dimension 0."): + self._nms(torch.rand(3, 4), torch.rand(4), 0.5) + + def test_legacy(self): + BOXES = ( + (0, 0, 3, 2), + (3, 3, 11, 7), + (2, 2, 5, 7), + (7, 4, 15, 12), + ) + SCORES = (0.9, 0.5, 0.95, 0.4) + IOU_THRESHOLD = 0.08 + + def fn(boxes, scores): + return self._reference_nms(boxes, scores, IOU_THRESHOLD) + + boxes = torch.tensor(BOXES, dtype=torch.float) + scores = torch.tensor(SCORES, dtype=torch.float) + self.runAtenTest((boxes, scores), fn) + + if __name__ == '__main__': torch.set_default_dtype(torch.float32) torch.manual_seed(42) diff --git a/torch_xla/core/functions.py b/torch_xla/core/functions.py index 59f82c06665..e868ddd3695 100644 --- a/torch_xla/core/functions.py +++ b/torch_xla/core/functions.py @@ -84,29 +84,6 @@ def all_gather(value, dim=0): return AllGather.apply(value, dim) -def nms(boxes, scores, score_threshold, iou_threshold, output_size): - """Performs a Non Maximal Suppression operation. - - Args: - boxes (torch.Tensor): A `torch.Tensor` of shape `[N, 4]` listing the boxes - coordinates in `(y0, x0, y1, x1)` form. - scores (torch.Tensor): A `torch.Tensor` of shape `[N]` listing the scores - of each box. - score_threshold (torch.Tensor): The minimum score for a box to qualify as - valid. - iou_threshold (torch.Tensor): The minimum IOU (Intersection Over Union) - score to trigger overlap logic. - output_size (int): The maximum number of returned indices (must be lower or - equal to N). - - Returns: - A tuple of `torch.Tensor` with the first element being the selected box - indices, and the second element being the number of valid boxes. - """ - return torch_xla._XLAC._xla_nms(boxes, scores, score_threshold, iou_threshold, - output_size) - - def distributed_mm(w, x, split=1): """Performs a matrix multiplication with sharded weight. diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 7f78c534af4..2faf483f067 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -47,7 +47,6 @@ ptxla_cc_library( "ir_dump_util.cpp", "matrix.cpp", "nll_loss.cpp", - "nms_op.cpp", "pooling.cpp", "quant_util.cpp", "random.cpp", @@ -67,6 +66,7 @@ ptxla_cc_library( "xla_lower_util.cpp", "xla_op_builder.cpp", "xla_sharding_util.cpp", + "xla_manual_registration.cpp", ":RegisterAutogradXLA.cpp", ":RegisterXLA.cpp", ":XLANativeFunctions.cpp", @@ -87,7 +87,6 @@ ptxla_cc_library( "ir_dump_util.h", "matrix.h", "nll_loss.h", - "nms_op.h", "pooling.h", "quant_util.h", "random.h", diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index bd9b46b0a04..d091e616c40 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -400,12 +400,14 @@ at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor, return tensor.to(tensor_options, /*non_blocking=*/false, /*copy=*/true); } -at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor) { +at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor, + bool skip_functionalization) { if (xla_tensor) { auto out = at::Tensor(c10::make_intrusive(std::move(xla_tensor))); // See Note [Lazy Tensor Functionalization] - if (c10::impl::tls_local_dispatch_key_set().excluded_.has( + if (skip_functionalization || + c10::impl::tls_local_dispatch_key_set().excluded_.has( c10::DispatchKey::Functionalize)) { // Invariant: if the functionalization key is in the exclude set, then // we're expected to return an ordinary tensor, which will be "lifted" diff --git a/torch_xla/csrc/aten_xla_bridge.h b/torch_xla/csrc/aten_xla_bridge.h index 7d6188809c0..19dd2b81412 100644 --- a/torch_xla/csrc/aten_xla_bridge.h +++ b/torch_xla/csrc/aten_xla_bridge.h @@ -112,7 +112,8 @@ at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor, const at::TensorOptions& tensor_options); // Creates an ATen tensor with XLA type id from an XLATensorPtr. -at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor); +at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor, + bool skip_functionalization = false); std::vector AtenFromXlaTensors( absl::Span xla_tensors); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a11fb837479..228b9e96246 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -666,28 +666,6 @@ py::object GetRevisions() { return py_dict; } -py::object XlaNms(const at::Tensor& boxes, const at::Tensor& scores, - const at::Tensor& score_threshold, - const at::Tensor& iou_threshold, int64_t output_size) { - at::Tensor selected_indices; - at::Tensor num_valid; - { - NoGilSection nogil; - auto nms_result = tensor_methods::nms( - bridge::GetXlaTensor(boxes), bridge::GetXlaTensor(scores), - bridge::GetXlaTensor(score_threshold), - bridge::GetXlaTensor(iou_threshold), output_size); - selected_indices = bridge::AtenFromXlaTensor(std::move(nms_result.first)); - num_valid = bridge::AtenFromXlaTensor(std::move(nms_result.second)); - } - auto result_tuple = py::tuple(2); - result_tuple[0] = - torch::autograd::make_variable(selected_indices, /*requires_grad=*/false); - result_tuple[1] = - torch::autograd::make_variable(num_valid, /*requires_grad=*/false); - return result_tuple; -} - std::vector XlaUserComputation( const std::string& opname, const std::vector& inputs, runtime::ComputationClient::ComputationPtr computation) { @@ -1086,11 +1064,6 @@ void InitXlaModuleBindings(py::module m) { [](const at::Tensor& tensor, int dim) { return GetXlaTensorDimensionSize(tensor, dim); }); - m.def("_xla_nms", [](const at::Tensor& boxes, const at::Tensor& scores, - const at::Tensor& score_threshold, - const at::Tensor& iou_threshold, int64_t output_size) { - return XlaNms(boxes, scores, score_threshold, iou_threshold, output_size); - }); m.def("_xla_user_computation", [](const std::string& opname, const std::vector& inputs, const runtime::ComputationClient::ComputationPtr& computation) { diff --git a/torch_xla/csrc/nms_op.cpp b/torch_xla/csrc/nms_op.cpp deleted file mode 100644 index cf21710bef3..00000000000 --- a/torch_xla/csrc/nms_op.cpp +++ /dev/null @@ -1,299 +0,0 @@ -#include "torch_xla/csrc/nms_op.h" - -#include - -#include - -#include "torch_xla/csrc/helpers.h" -#include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/util.h" -#include "torch_xla/csrc/shape_helper.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/sorting.h" -#include "xla/util.h" - -// Code extracted from: -// https://github.com/tensorflow/tensorflow/blob/dc4c6d305ba3d2de4a795ec77b483b0fa695b9ee/tensorflow/compiler/tf2xla/kernels/image_ops.cc#L399 - -namespace torch_xla { -namespace { - -struct WhileCondFn { - WhileCondFn(int64_t num_boxes, int64_t output_size) - : num_boxes(num_boxes), output_size(output_size) {} - - xla::StatusOr operator()(absl::Span values, - xla::XlaBuilder* builder) const { - xla::XlaOp row_idx = values[0]; - xla::XlaOp row_in_bounds = - xla::Lt(row_idx, xla::ConstantR0(builder, num_boxes)); - xla::XlaOp num_outputs_so_far = values[1]; - xla::XlaOp results_not_full = xla::Lt( - num_outputs_so_far, xla::ConstantR0(builder, output_size)); - return xla::And(row_in_bounds, results_not_full); - } - - int64_t num_boxes; - int64_t output_size; -}; - -// Process the boxes one-by-one using the iou matrix mask. -// This implementation uses a correct, but greedy, sequential algorithm -// to ensure that suppressed boxes cannot themselves suppress other -// boxes. -struct SuppressBodyFn { - explicit SuppressBodyFn(int64_t num_boxes) : num_boxes(num_boxes) {} - - xla::StatusOr> operator()( - absl::Span values, xla::XlaBuilder* builder) const { - xla::XlaOp row_idx = values[0]; - xla::XlaOp num_outputs_so_far = values[1]; - xla::XlaOp iou_mask = values[2]; - xla::XlaOp included_iou = values[3]; - xla::XlaOp zero = xla::Zero(builder, xla::PrimitiveType::S32); - xla::XlaOp one = xla::One(builder, xla::PrimitiveType::S32); - // Determine if current elem is active using a slice. - // The only reason we need an explicit vector is because some old GCCs can't - // deduce the right type for MakeConstSpan, and providing a single-value - // initializer list directly uses the wrong overload. Delete this once the - // deprecated overload is gone. - std::vector row_idx_vector = {row_idx}; - xla::XlaOp active_elem = - xla::DynamicSlice(included_iou, row_idx_vector, {1}); - active_elem = xla::Reshape(active_elem, {}); - // Increment output count iff current elem is not suppressed. - num_outputs_so_far = - xla::Select(active_elem, num_outputs_so_far + one, num_outputs_so_far); - // Slice out the row_idx. - xla::XlaOp row_iou = - xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes}); - // Remove the diagonal from consideration. An elem cannot suppress - // itself. - row_iou = xla::DynamicUpdateSlice( - row_iou, xla::ConstantR2FromArray2D(builder, {{false}}), - {zero, row_idx}); - // Create a suppression by inverting polarity. - row_iou = xla::Reshape(row_iou, {num_boxes}); - xla::XlaOp supp_mask = xla::Not(row_iou); - // Update mask iff current elem is not suppressed. - included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}), - xla::And(included_iou, supp_mask), included_iou); - return std::vector{row_idx + one, num_outputs_so_far, iou_mask, - included_iou}; - } - - int64_t num_boxes; -}; - -xla::XlaOp NmsGather(xla::XlaOp input, absl::Span input_sizes, - xla::XlaOp indices, - absl::Span indices_sizes, int64_t axis) { - const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); - int64_t num_indices = runtime::util::Multiply(indices_sizes); - if (num_indices == 0) { - std::vector output_sizes = - torch::lazy::ToVector(input_sizes); - output_sizes.erase(std::next(output_sizes.begin(), axis)); - return xla::Broadcast( - xla::Zero(input.builder(), input_shape.element_type()), output_sizes); - } - - // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a - // tensor of shape [3,3]. - // - // operand = s32[3,3] parameter(0) - // indices = s32[2] parameter(1) - // gather = s32[3,2] gather(operand, indices), - // offset_dims={0}, - // collapsed_slice_dims={1}, - // start_index_map={1}, - // index_vector_dim=1, - // slice_sizes={3, 1} - // - // - // Example of an N-D gather pulling out slices of shape [1,1,2] out of a - // tensor of shape [3,3,2]. - // - // operand = s32[3,3,2] parameter(0) - // indices = s32[2,2] parameter(1) - // gather = s32[2,2] gather(operand, indices), - // offset_dims={1}, - // collapsed_slice_dims={0,1}, - // start_index_map={0,1}, - // index_vector_dim=0, - // slice_sizes={1,1,2} - xla::GatherDimensionNumbers dim_numbers; - std::vector slice_sizes; - for (int64_t i = 0; i < input_sizes.size(); ++i) { - int64_t window_bound; - if (i == axis) { - dim_numbers.add_collapsed_slice_dims(i); - window_bound = 1; - } else { - window_bound = input_sizes[i]; - } - slice_sizes.push_back(window_bound); - if (i < axis) { - dim_numbers.add_offset_dims(i); - } else if (i > axis) { - dim_numbers.add_offset_dims(i + indices_sizes.size() - 1); - } - } - dim_numbers.set_index_vector_dim(indices_sizes.size()); - dim_numbers.add_start_index_map(axis); - return xla::Gather(input, indices, dim_numbers, slice_sizes); -} - -} // namespace - -NmsResult BuildNms(xla::XlaOp boxes, xla::XlaOp scores, - xla::XlaOp score_threshold, xla::XlaOp iou_threshold, - int64_t output_size) { - const xla::Shape& boxes_shape = ShapeHelper::ShapeOfXlaOp(boxes); - int64_t num_boxes = boxes_shape.dimensions(0); - const xla::Shape& scores_shape = ShapeHelper::ShapeOfXlaOp(scores); - XLA_CHECK_EQ(boxes_shape.rank(), 2); - XLA_CHECK_EQ(boxes_shape.dimensions(1), 4); - XLA_CHECK_EQ(scores_shape.rank(), 1); - XLA_CHECK_EQ(scores_shape.dimensions(0), num_boxes); - XLA_CHECK_LT(num_boxes, std::numeric_limits::max()); - XLA_CHECK_GE(output_size, 0); - XLA_CHECK_LT(output_size, std::numeric_limits::max()); - - xla::XlaBuilder* builder = boxes.builder(); - // Choose a more convenient layout. - xla::XlaOp boxes_transposed = xla::Transpose(boxes, {1, 0}); - xla::XlaOp boxes_sorted = xla::GetTupleElement( - xla::Sort({xla::Broadcast(scores, {4}), boxes_transposed}, - xla::CreateScalarGtComputation( - {scores_shape.element_type(), boxes_shape.element_type()}, - builder), - /*dimension=*/1), - 1); - // Track the mapping of indices into sorted domain. - xla::XlaOp iota_indices = - xla::Iota(builder, xla::PrimitiveType::S32, num_boxes); - xla::XlaOp indices_sort = xla::Sort( - {scores, iota_indices}, - xla::CreateScalarGtComputation( - {scores_shape.element_type(), xla::PrimitiveType::S32}, builder)); - xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1); - xla::XlaOp scores_sorted = xla::GetTupleElement(indices_sort, 0); - - // Shapes are henceforth [1, num_boxes]. 'c_y0' denotes 'coordinate' y0. - xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted, - /*start_index=*/0, - /*limit_index=*/1, - /*stride=*/1, - /*dimno=*/0), - {num_boxes}); - xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted, - /*start_index=*/1, - /*limit_index=*/2, - /*stride=*/1, - /*dimno=*/0), - {num_boxes}); - xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted, - /*start_index=*/2, - /*limit_index=*/3, - /*stride=*/1, - /*dimno=*/0), - {num_boxes}); - xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted, - /*start_index=*/3, - /*limit_index=*/4, - /*stride=*/1, - /*dimno=*/0), - {num_boxes}); - - xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1); - xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0); - xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1); - xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0); - xla::XlaOp area = (y2 - y1) * (x2 - x1); - - // Shapes are henceforth [1, num_boxes]. - y1 = xla::Broadcast(y1, {1}); - y2 = xla::Broadcast(y2, {1}); - x1 = xla::Broadcast(x1, {1}); - x2 = xla::Broadcast(x2, {1}); - area = xla::Broadcast(area, {1}); - - // Shapes are henceforth [num_boxes, num_boxes]. - xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0})); - xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0})); - xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0})); - xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0})); - auto square_zero = xla::ZerosLike(i_xmin); - - xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) * - xla::Max(i_ymax - i_ymin, square_zero); - xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area; - xla::XlaOp iou = i_area / u_area; - - xla::XlaOp iou_threshold_mask = xla::Gt(iou, iou_threshold + square_zero); - xla::XlaOp included_iou = - xla::Broadcast(xla::ConstantR0(builder, true), {num_boxes}); - if (boxes_shape.is_dynamic_dimension(0)) { - // Update included_iou's size to match boxes actual size. - included_iou = xla::SetDimensionSize( - included_iou, XlaHelpers::GetDimensionsSize({boxes}, {0}).size, 0); - } - - xla::XlaOp zero_s32 = xla::Zero(builder, xla::PrimitiveType::S32); - xla::XlaOp one_s32 = xla::One(builder, xla::PrimitiveType::S32); - std::vector init_values; - init_values.reserve(4); - init_values.push_back(zero_s32); // col_idx - init_values.push_back(zero_s32); // num_outputs - init_values.push_back(iou_threshold_mask); - init_values.push_back(included_iou); - - auto suppress_loop_result = ConsumeValue(xla::WhileLoopHelper( - WhileCondFn(num_boxes, output_size), SuppressBodyFn(num_boxes), - init_values, "BoxSuppressLoop", builder)); - - xla::XlaOp included_score = - xla::Gt(scores_sorted, xla::Broadcast(score_threshold, {num_boxes})); - xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]); - - // Only consider boxes over which we have iterated. This allows for accurate - // counting. DynamicSlice would require knowledge of the size of the output. - xla::XlaOp valid_elem = xla::Lt( - iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes})); - included = xla::And(included, valid_elem); - - xla::XlaOp neg_inf = xla::Broadcast( - xla::MinValue(builder, scores_shape.element_type()), {num_boxes}); - xla::XlaOp scores_included = xla::Select(included, scores_sorted, neg_inf); - xla::XlaOp output_tuple = xla::TopK(scores_included, output_size); - xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1); - // Calculate num_valid. - // Note: num_valid cannot be taken from the loop outputs, because outputs - // can be suppressed by score threshold. - xla::XlaOp ones_included = - xla::Select(included, xla::Broadcast(one_s32, {num_boxes}), - xla::Broadcast(zero_s32, {num_boxes})); - // num_valid is scalar. torch::lazy::Value should be bound by output_size. - xla::XlaOp num_valid_total = xla::Reduce( - ones_included, - /*init_value=*/zero_s32, - /*computation=*/ - xla::CreateScalarAddComputation(xla::PrimitiveType::S32, builder), - /*dimensions_to_reduce=*/{0}); - xla::XlaOp num_valid = - xla::Min(num_valid_total, xla::ConstantR0(builder, output_size)); - - // Re-index into the original scores input tensor, using a Gather. - // Boxes were suppressed in the sorted domain. - xla::XlaOp selected_indices = - NmsGather(indices_sorted, scores_shape.dimensions(), - selected_indices_sorted, {output_size}, - /*axis=*/0); - return {selected_indices, num_valid}; -} - -} // namespace torch_xla diff --git a/torch_xla/csrc/nms_op.h b/torch_xla/csrc/nms_op.h deleted file mode 100644 index 42b2bf0016e..00000000000 --- a/torch_xla/csrc/nms_op.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef XLA_TORCH_XLA_CSRC_NMS_OP_H_ -#define XLA_TORCH_XLA_CSRC_NMS_OP_H_ - -#include "xla/client/xla_builder.h" - -namespace torch_xla { - -struct NmsResult { - xla::XlaOp selected_indices; - xla::XlaOp num_valid; -}; - -NmsResult BuildNms(xla::XlaOp boxes, xla::XlaOp scores, - xla::XlaOp score_threshold, xla::XlaOp iou_threshold, - int64_t output_size); - -} // namespace torch_xla - -#endif // XLA_TORCH_XLA_CSRC_NMS_OP_H_ \ No newline at end of file diff --git a/torch_xla/csrc/ops/nms.cpp b/torch_xla/csrc/ops/nms.cpp index 64df29ce460..70a80a9173b 100644 --- a/torch_xla/csrc/ops/nms.cpp +++ b/torch_xla/csrc/ops/nms.cpp @@ -1,65 +1,46 @@ #include "torch_xla/csrc/ops/nms.h" #include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/nms_op.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/xla_lower_util.h" namespace torch_xla { namespace { xla::Shape NodeOutputShape(const torch::lazy::Value& boxes, const torch::lazy::Value& scores, - const torch::lazy::Value& score_threshold, - const torch::lazy::Value& iou_threshold, - int64_t output_size) { + const torch::lazy::Value& iou_threshold) { auto shape_fn = [&](absl::Span operands) -> xla::XlaOp { - NmsResult result = BuildNms(operands[0], operands[1], operands[2], - operands[3], output_size); - return xla::Tuple(result.selected_indices.builder(), - {result.selected_indices, result.num_valid}); + return BuildNms(operands[0], operands[1], operands[2]); }; + return InferOutputShape( - {GetXlaShape(boxes), GetXlaShape(scores), GetXlaShape(score_threshold), - GetXlaShape(iou_threshold)}, + {GetXlaShape(boxes), GetXlaShape(scores), GetXlaShape(iou_threshold)}, shape_fn); } } // namespace Nms::Nms(const torch::lazy::Value& boxes, const torch::lazy::Value& scores, - const torch::lazy::Value& score_threshold, - const torch::lazy::Value& iou_threshold, int64_t output_size) + const torch::lazy::Value& iou_threshold) : XlaNode( - xla_nms, {boxes, scores, score_threshold, iou_threshold}, - [&]() { - return NodeOutputShape(boxes, scores, score_threshold, - iou_threshold, output_size); - }, - /*num_outputs=*/2, torch::lazy::MHash(output_size)), - output_size_(output_size) {} + /*op=*/xla_nms, + /*operands=*/{boxes, scores, iou_threshold}, + /*xla_shape_fn=*/ + [&]() { return NodeOutputShape(boxes, scores, iou_threshold); }) {} torch::lazy::NodePtr Nms::Clone(torch::lazy::OpList operands) const { return torch::lazy::MakeNode(operands.at(0), operands.at(1), - operands.at(2), operands.at(3), - output_size_); + operands.at(2)); } XlaOpVector Nms::Lower(LoweringContext* loctx) const { xla::XlaOp boxes = loctx->GetOutputOp(operand(0)); xla::XlaOp scores = loctx->GetOutputOp(operand(1)); - xla::XlaOp score_threshold = loctx->GetOutputOp(operand(2)); - xla::XlaOp iou_threshold = loctx->GetOutputOp(operand(3)); - NmsResult result = - BuildNms(boxes, scores, score_threshold, iou_threshold, output_size_); - return ReturnOps({result.selected_indices, result.num_valid}, loctx); -} - -std::string Nms::ToString() const { - std::stringstream ss; - ss << XlaNode::ToString() << ", output_size=" << output_size_; - return ss.str(); + xla::XlaOp iou_threshold = loctx->GetOutputOp(operand(2)); + return ReturnOp(BuildNms(boxes, scores, iou_threshold), loctx); } } // namespace torch_xla diff --git a/torch_xla/csrc/ops/nms.h b/torch_xla/csrc/ops/nms.h index b825b27f187..c45dd84cdf4 100644 --- a/torch_xla/csrc/ops/nms.h +++ b/torch_xla/csrc/ops/nms.h @@ -8,19 +8,11 @@ namespace torch_xla { class Nms : public XlaNode { public: Nms(const torch::lazy::Value& boxes, const torch::lazy::Value& scores, - const torch::lazy::Value& score_threshold, - const torch::lazy::Value& iou_threshold, int64_t output_size); + const torch::lazy::Value& iou_threshold); torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; XlaOpVector Lower(LoweringContext* loctx) const override; - - std::string ToString() const override; - - int64_t output_size() const { return output_size_; } - - private: - int64_t output_size_; }; } // namespace torch_xla diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index d89a4afbf28..9ef81db5b29 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2089,19 +2089,14 @@ XLATensorPtr nll_loss_backward(const XLATensorPtr& grad_output, GetXlaReductionMode(reduction), ignore_index)); } -std::pair nms(const XLATensorPtr& boxes, - const XLATensorPtr& scores, - const XLATensorPtr& score_threshold, - const XLATensorPtr& iou_threshold, - int64_t output_size) { +XLATensorPtr nms(const XLATensorPtr& boxes, const XLATensorPtr& scores, + double iou_threshold) { + const torch::lazy::BackendDevice& device = boxes->GetDevice(); + torch::lazy::NodePtr xla_iou_threshold = + ScalarOp(iou_threshold, MakeXlaPrimitiveType(at::kDouble, &device)); torch::lazy::NodePtr node = torch::lazy::MakeNode( - boxes->GetIrValue(), scores->GetIrValue(), score_threshold->GetIrValue(), - iou_threshold->GetIrValue(), output_size); - return std::pair( - XLATensor::Create(torch::lazy::Value(node, 0), boxes->GetDevice(), - at::ScalarType::Int), - XLATensor::Create(torch::lazy::Value(node, 1), boxes->GetDevice(), - at::ScalarType::Int)); + boxes->GetIrValue(), scores->GetIrValue(), xla_iou_threshold); + return XLATensor::Create(node, device, at::ScalarType::Long); } XLATensorPtr nonzero(const XLATensorPtr& input) { diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 96a43cb675c..0a704dea636 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -678,11 +678,8 @@ XLATensorPtr nll_loss_backward(const XLATensorPtr& grad_output, int ignore_index, const XLATensorPtr& total_weight); -std::pair nms(const XLATensorPtr& boxes, - const XLATensorPtr& scores, - const XLATensorPtr& score_threshold, - const XLATensorPtr& iou_threshold, - int64_t output_size); +XLATensorPtr nms(const XLATensorPtr& boxes, const XLATensorPtr& scores, + double iou_threshold); XLATensorPtr nonzero(const XLATensorPtr& input); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 34bfa1e306e..47b438a6e35 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1295,4 +1295,247 @@ std::vector BuildTpuCustomCall( return result; } +std::vector BuildBoxSelectionLoop(int64_t num_boxes, + xla::XlaOp iou_threshold_mask) { + using IndexType = int32_t; + const xla::PrimitiveType XLAIndexType = xla::PrimitiveType::S32; + + xla::XlaBuilder* builder = iou_threshold_mask.builder(); + + const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType); + const xla::XlaOp TRUE = xla::ConstantR0(builder, true); + + // Initial values to the while loop. + std::vector init_values(3); + // 1. Loop counter: represents the actual box being processed. + init_values[0] = ZERO; + // 2. State of each box (i.e. whether it was included or not). + init_values[1] = xla::Broadcast(TRUE, {num_boxes}); + // 3. The actual IoU threshold matrix. + init_values[2] = iou_threshold_mask; + + return ConsumeValue(xla::WhileLoopHelper( + [=](absl::Span values, xla::XlaBuilder* builder) { + xla::XlaOp box_index = values[0]; + // Check: current loop counter is within bounds, i.e. has a + // corresponding box. + return xla::Lt(box_index, + xla::ConstantR0(builder, num_boxes)); + }, + [=](absl::Span values, xla::XlaBuilder* builder) { + const xla::XlaOp ONE = xla::One(builder, XLAIndexType); + const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType); + + xla::XlaOp box_index = values[0]; + xla::XlaOp state = values[1]; + xla::XlaOp iou_threshold_mask = values[2]; + + // Retrieve the IoU mask row corresponding to this box. + xla::XlaOp box_iou_threshold_mask = xla::DynamicSlice( + iou_threshold_mask, {box_index, ZERO}, {1, num_boxes}); + + // Update the current state with the IoU mask. + // Basically, sets to false every box X whose IoU with the current box + // is less-than or equal than the given threshold. + xla::XlaOp updated_state = xla::And( + state, + // Update the mask so that if we select this box + // (i.e. state[box] == true), we don't de-select it. + xla::DynamicUpdateSlice( + // Before that, we need to pre-process the mask. + // 1. Negate the mask: if this box is selected, we only want + // those that have a low intersection ratio. + // 2. Reshape it to: [num_boxes]. + xla::Reshape(xla::Not(box_iou_threshold_mask), {num_boxes}), + xla::ConstantR1(builder, {true}), {box_index})); + + // Flag: should this box (loop counter) be included in the output? + xla::XlaOp should_include = xla::DynamicSlice(state, {box_index}, {1}); + // Pick the new values of state, depending on whether we should include + // this box or not. + xla::XlaOp new_state = + xla::Select(xla::BroadcastInDim(should_include, {num_boxes}, {0}), + updated_state, state); + + xla::XlaOp next_box_index = box_index + ONE; + return std::vector{next_box_index, new_state, + iou_threshold_mask}; + }, + init_values, "BoxSelectionLoop", builder)); +} + +xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores, + xla::XlaOp iou_threshold) { + using IndexType = int32_t; + const xla::PrimitiveType XLAIndexType = xla::PrimitiveType::S32; + + xla::XlaBuilder* builder = boxes.builder(); + + const int64_t COORDINATES = 4; + const xla::XlaOp ONE = xla::One(builder, XLAIndexType); + const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType); + + const xla::Shape& boxes_shape = ShapeHelper::ShapeOfXlaOp(boxes); + XLA_CHECK_EQ(boxes_shape.rank(), 2); + XLA_CHECK_EQ(boxes_shape.dimensions(1), COORDINATES); + int64_t num_boxes = boxes_shape.dimensions(0); + + const xla::Shape& scores_shape = ShapeHelper::ShapeOfXlaOp(scores); + XLA_CHECK_EQ(scores_shape.rank(), 1); + XLA_CHECK_EQ(scores_shape.dimensions(0), num_boxes); + + // 1. Order the boxes according to their scores. + // Also remember the order of the boxes original indices by having an + // extra Iota operand. + xla::XlaOp sorted = xla::Sort( + { + // Here, we need to broadcast both the scores and Iota operands, so + // as to have the same dimensions as boxes: {COORDINATES, num_boxes}. + xla::Broadcast(scores, {COORDINATES}), + xla::Broadcast(xla::Iota(builder, XLAIndexType, num_boxes), + {COORDINATES}), + // Transpose boxes, so as to manipulate its values in an easier way. + xla::Transpose(boxes, {1, 0}), + }, + xla::CreateScalarGtComputation( + { + scores_shape.element_type(), + XLAIndexType, + boxes_shape.element_type(), + }, + builder), + /*dimension=*/1); + + // 1.1. De-construct the returned tuple. + // Specifically, we only need one of the rows of the sorted index tensor + // and of the sorted scores, since the others were only broadcasted. + // + // Shape: [1, num_boxes] + xla::XlaOp sorted_scores = + xla::SliceInDim(xla::GetTupleElement(sorted, 0), /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, /*dimno=*/0); + // Shape: [1, num_boxes] + xla::XlaOp sorted_indices = + xla::SliceInDim(xla::GetTupleElement(sorted, 1), /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, /*dimno=*/0); + // Shape: [COORDINATES, num_boxes] + xla::XlaOp sorted_boxes = xla::GetTupleElement(sorted, 2); + + // 1.2. Retrieve each coordinate, in their own tensor. + // Since we transposed boxes tensor, each row corresponds to a + // coordinate. + // + // Shape: [1, num_boxes] + xla::XlaOp y0 = xla::SliceInDim(sorted_boxes, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, /*dimno=*/0); + xla::XlaOp x0 = xla::SliceInDim(sorted_boxes, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, /*dimno=*/0); + xla::XlaOp y1 = xla::SliceInDim(sorted_boxes, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, /*dimno=*/0); + xla::XlaOp x1 = xla::SliceInDim(sorted_boxes, /*start_index=*/3, + /*limit_index=*/4, /*stride=*/1, /*dimno=*/0); + + // 2. Create the IoU (Intersection over Union) ratio mask + // 2.1. First, compute the area of each box. + // + // Shape: [1, num_boxes] + xla::XlaOp area = (y1 - y0) * (x1 - x0); + + // 2.2. Get the corners of the intersection box created by every pair of + // boxes. + // Basically, given 2 boxes, what the corner of their + // intersection box would be? + // + // Shape: [num_boxes, num_boxes] + xla::XlaOp left = xla::Max(x0, xla::Transpose(x0, {1, 0})); + xla::XlaOp bottom = xla::Max(y0, xla::Transpose(y0, {1, 0})); + xla::XlaOp right = xla::Min(x1, xla::Transpose(x1, {1, 0})); + xla::XlaOp top = xla::Min(y1, xla::Transpose(y1, {1, 0})); + + // 2.3. Compute the intersection area. + // Whenever 2 boxes don't intersect, either their width or height will be + // negative. + // + // Shape: [num_boxes, num_boxes] + xla::XlaOp zeros = xla::ZerosLike(left); + xla::XlaOp intersection_area = + xla::Max(right - left, zeros) * xla::Max(top - bottom, zeros); + + // 2.4. Compute the union area. + // Sum of the areas of every pair of boxes, minus their intersection + // area. + // + // Shape: [num_boxes, num_boxes] + xla::XlaOp union_area = + area + xla::Transpose(area, {1, 0}) - intersection_area; + + // 2.5. Compute the IoU ratio. + // + // Shape: [num_boxes, num_boxes] + xla::XlaOp iou = intersection_area / union_area; + + // 2.6. Create the mask by comparing it with the given threshold. + // + // Shape: [num_boxes, num_boxes] + xla::XlaOp casted_threshold = + xla::ConvertElementType(iou_threshold, XlaHelpers::TypeOfXlaOp(iou)); + xla::XlaOp iou_threshold_mask = xla::Gt(iou, casted_threshold); + + // 3. Iteratively select the highest scoring box, and eliminate those whose + // IoU is greater than the threshold. + // + // state: a [num_boxes] tensor, where, at the end of the loop, for + // each box i, state[i] represents whether box i should be + // included in the output or not + // + // Loop Invariant: for every i in [0..current iteration], state[i] + // represents whether box i is included or not in the + // output. + // + // Rough idea: at every iteration i, we: + // - Check if state[i] == true (i.e. box i should be included) + // + // - If so, retrieve and negate the i-th row from the IoU mask + // (i.e. what are the boxes that have an IoU ratio lower-than or + // equal the given threshold?). + // + // - Update state[i+1..] by computing a logical and operation with + // the retrieved negated IoU mask. Note that this won't modify + // state[0..i]. The next box j > i, where state[j] == true is + // the next box that will be included. + std::vector loop_result = + BuildBoxSelectionLoop(num_boxes, iou_threshold_mask); + + xla::XlaOp loop_counter = loop_result[0]; + xla::XlaOp included_mask = loop_result[1]; + + // 4. Retrieve the included box indices. + // 4.1. Compute the number of included boxes. + // 4.1.1. Transform that mask into a 0-1 tensor. + xla::XlaOp one_if_included = + xla::Select(included_mask, xla::Broadcast(ONE, {num_boxes}), + xla::Broadcast(ZERO, {num_boxes})); + // 4.1.2. Sum it up. + xla::XlaOp included_boxes = + xla::Reduce(one_if_included, ZERO, + xla::CreateScalarAddComputation(XLAIndexType, builder), {0}); + + // 4.2. Move the indices of the included boxes to the beginning. + // Doing so alongside the previously sorted indices gives us an index + // tensor with the original indices of the selected boxes at the + // beginning. + xla::XlaOp included_indices_first = xla::GetTupleElement( + xla::Sort( + { + one_if_included, + xla::Reshape(sorted_indices, {num_boxes}), + }, + xla::CreateScalarGtComputation({XLAIndexType, XLAIndexType}, builder), + /*dimension=*/0, /*is_stable=*/true), + 1); + + // 4.3. Get only the first included_boxes indices. + return xla::SetDimensionSize(included_indices_first, included_boxes, 0); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 45014c5f4fb..8e632796c23 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -156,6 +156,9 @@ std::vector BuildTpuCustomCall( const std::vector& inputs, const xla::Shape& output_shape, const std::string& payload); +xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores, + xla::XlaOp iou_threshold); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_ diff --git a/torch_xla/csrc/xla_manual_registration.cpp b/torch_xla/csrc/xla_manual_registration.cpp new file mode 100644 index 00000000000..dc7df436ec7 --- /dev/null +++ b/torch_xla/csrc/xla_manual_registration.cpp @@ -0,0 +1,39 @@ +#include +#include + +#include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/ops/nms.h" +#include "torch_xla/csrc/ops/ops.h" +#include "torch_xla/csrc/tensor_methods.h" +#include "torch_xla/csrc/tensor_util.h" + +namespace torch_xla { +namespace manual { +namespace { + +at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores, + double iou_threshold) { + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + + XLA_CHECK_EQ(boxes.dim(), 2) << "nms(): boxes should be a 2D tensor."; + XLA_CHECK_EQ(boxes.size(1), 4) + << "nms(): boxes should be a 2D tensor of shape [N, 4]."; + XLA_CHECK_EQ(scores.dim(), 1) << "nms(): scores should be a 1D tensor."; + XLA_CHECK_EQ(boxes.size(0), scores.size(0)) + << "nms(): boxes and scores should have the same size for dimension 0."; + + XLATensorPtr xla_boxes = bridge::GetXlaTensor(boxes); + XLATensorPtr xla_scores = bridge::GetXlaTensor(scores); + return bridge::AtenFromXlaTensor( + tensor_methods::nms(xla_boxes, xla_scores, iou_threshold), + /*skip_functionalization=*/true); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, XLA, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +} + +} // namespace manual +} // namespace torch_xla