From 302cd4c82ba4c2839bd8d5652186b9f92b374b08 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 23 Mar 2024 20:21:14 -0300 Subject: [PATCH] Refactor and implement new `nms`. --- torch_xla/csrc/BUILD | 1 + torch_xla/csrc/nms_op.cpp | 299 --------------------- torch_xla/csrc/nms_op.h | 19 -- torch_xla/csrc/ops/nms.cpp | 45 +--- torch_xla/csrc/ops/nms.h | 10 +- torch_xla/csrc/tensor_methods.cpp | 19 +- torch_xla/csrc/tensor_methods.h | 7 +- torch_xla/csrc/xla_lower_util.cpp | 243 +++++++++++++++++ torch_xla/csrc/xla_lower_util.h | 3 + torch_xla/csrc/xla_manual_registration.cpp | 38 +++ 10 files changed, 308 insertions(+), 376 deletions(-) delete mode 100644 torch_xla/csrc/nms_op.cpp delete mode 100644 torch_xla/csrc/nms_op.h create mode 100644 torch_xla/csrc/xla_manual_registration.cpp diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index ce718a1ebfe3..2faf483f067a 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -66,6 +66,7 @@ ptxla_cc_library( "xla_lower_util.cpp", "xla_op_builder.cpp", "xla_sharding_util.cpp", + "xla_manual_registration.cpp", ":RegisterAutogradXLA.cpp", ":RegisterXLA.cpp", ":XLANativeFunctions.cpp", diff --git a/torch_xla/csrc/nms_op.cpp b/torch_xla/csrc/nms_op.cpp deleted file mode 100644 index cf21710bef3c..000000000000 --- a/torch_xla/csrc/nms_op.cpp +++ /dev/null @@ -1,299 +0,0 @@ -#include "torch_xla/csrc/nms_op.h" - -#include - -#include - -#include "torch_xla/csrc/helpers.h" -#include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/util.h" -#include "torch_xla/csrc/shape_helper.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/sorting.h" -#include "xla/util.h" - -// Code extracted from: -// https://github.com/tensorflow/tensorflow/blob/dc4c6d305ba3d2de4a795ec77b483b0fa695b9ee/tensorflow/compiler/tf2xla/kernels/image_ops.cc#L399 - -namespace torch_xla { -namespace { - -struct WhileCondFn { - WhileCondFn(int64_t num_boxes, int64_t output_size) - : num_boxes(num_boxes), output_size(output_size) {} - - xla::StatusOr operator()(absl::Span values, - xla::XlaBuilder* builder) const { - xla::XlaOp row_idx = values[0]; - xla::XlaOp row_in_bounds = - xla::Lt(row_idx, xla::ConstantR0(builder, num_boxes)); - xla::XlaOp num_outputs_so_far = values[1]; - xla::XlaOp results_not_full = xla::Lt( - num_outputs_so_far, xla::ConstantR0(builder, output_size)); - return xla::And(row_in_bounds, results_not_full); - } - - int64_t num_boxes; - int64_t output_size; -}; - -// Process the boxes one-by-one using the iou matrix mask. -// This implementation uses a correct, but greedy, sequential algorithm -// to ensure that suppressed boxes cannot themselves suppress other -// boxes. -struct SuppressBodyFn { - explicit SuppressBodyFn(int64_t num_boxes) : num_boxes(num_boxes) {} - - xla::StatusOr> operator()( - absl::Span values, xla::XlaBuilder* builder) const { - xla::XlaOp row_idx = values[0]; - xla::XlaOp num_outputs_so_far = values[1]; - xla::XlaOp iou_mask = values[2]; - xla::XlaOp included_iou = values[3]; - xla::XlaOp zero = xla::Zero(builder, xla::PrimitiveType::S32); - xla::XlaOp one = xla::One(builder, xla::PrimitiveType::S32); - // Determine if current elem is active using a slice. - // The only reason we need an explicit vector is because some old GCCs can't - // deduce the right type for MakeConstSpan, and providing a single-value - // initializer list directly uses the wrong overload. Delete this once the - // deprecated overload is gone. - std::vector row_idx_vector = {row_idx}; - xla::XlaOp active_elem = - xla::DynamicSlice(included_iou, row_idx_vector, {1}); - active_elem = xla::Reshape(active_elem, {}); - // Increment output count iff current elem is not suppressed. - num_outputs_so_far = - xla::Select(active_elem, num_outputs_so_far + one, num_outputs_so_far); - // Slice out the row_idx. - xla::XlaOp row_iou = - xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes}); - // Remove the diagonal from consideration. An elem cannot suppress - // itself. - row_iou = xla::DynamicUpdateSlice( - row_iou, xla::ConstantR2FromArray2D(builder, {{false}}), - {zero, row_idx}); - // Create a suppression by inverting polarity. - row_iou = xla::Reshape(row_iou, {num_boxes}); - xla::XlaOp supp_mask = xla::Not(row_iou); - // Update mask iff current elem is not suppressed. - included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}), - xla::And(included_iou, supp_mask), included_iou); - return std::vector{row_idx + one, num_outputs_so_far, iou_mask, - included_iou}; - } - - int64_t num_boxes; -}; - -xla::XlaOp NmsGather(xla::XlaOp input, absl::Span input_sizes, - xla::XlaOp indices, - absl::Span indices_sizes, int64_t axis) { - const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); - int64_t num_indices = runtime::util::Multiply(indices_sizes); - if (num_indices == 0) { - std::vector output_sizes = - torch::lazy::ToVector(input_sizes); - output_sizes.erase(std::next(output_sizes.begin(), axis)); - return xla::Broadcast( - xla::Zero(input.builder(), input_shape.element_type()), output_sizes); - } - - // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a - // tensor of shape [3,3]. - // - // operand = s32[3,3] parameter(0) - // indices = s32[2] parameter(1) - // gather = s32[3,2] gather(operand, indices), - // offset_dims={0}, - // collapsed_slice_dims={1}, - // start_index_map={1}, - // index_vector_dim=1, - // slice_sizes={3, 1} - // - // - // Example of an N-D gather pulling out slices of shape [1,1,2] out of a - // tensor of shape [3,3,2]. - // - // operand = s32[3,3,2] parameter(0) - // indices = s32[2,2] parameter(1) - // gather = s32[2,2] gather(operand, indices), - // offset_dims={1}, - // collapsed_slice_dims={0,1}, - // start_index_map={0,1}, - // index_vector_dim=0, - // slice_sizes={1,1,2} - xla::GatherDimensionNumbers dim_numbers; - std::vector slice_sizes; - for (int64_t i = 0; i < input_sizes.size(); ++i) { - int64_t window_bound; - if (i == axis) { - dim_numbers.add_collapsed_slice_dims(i); - window_bound = 1; - } else { - window_bound = input_sizes[i]; - } - slice_sizes.push_back(window_bound); - if (i < axis) { - dim_numbers.add_offset_dims(i); - } else if (i > axis) { - dim_numbers.add_offset_dims(i + indices_sizes.size() - 1); - } - } - dim_numbers.set_index_vector_dim(indices_sizes.size()); - dim_numbers.add_start_index_map(axis); - return xla::Gather(input, indices, dim_numbers, slice_sizes); -} - -} // namespace - -NmsResult BuildNms(xla::XlaOp boxes, xla::XlaOp scores, - xla::XlaOp score_threshold, xla::XlaOp iou_threshold, - int64_t output_size) { - const xla::Shape& boxes_shape = ShapeHelper::ShapeOfXlaOp(boxes); - int64_t num_boxes = boxes_shape.dimensions(0); - const xla::Shape& scores_shape = ShapeHelper::ShapeOfXlaOp(scores); - XLA_CHECK_EQ(boxes_shape.rank(), 2); - XLA_CHECK_EQ(boxes_shape.dimensions(1), 4); - XLA_CHECK_EQ(scores_shape.rank(), 1); - XLA_CHECK_EQ(scores_shape.dimensions(0), num_boxes); - XLA_CHECK_LT(num_boxes, std::numeric_limits::max()); - XLA_CHECK_GE(output_size, 0); - XLA_CHECK_LT(output_size, std::numeric_limits::max()); - - xla::XlaBuilder* builder = boxes.builder(); - // Choose a more convenient layout. - xla::XlaOp boxes_transposed = xla::Transpose(boxes, {1, 0}); - xla::XlaOp boxes_sorted = xla::GetTupleElement( - xla::Sort({xla::Broadcast(scores, {4}), boxes_transposed}, - xla::CreateScalarGtComputation( - {scores_shape.element_type(), boxes_shape.element_type()}, - builder), - /*dimension=*/1), - 1); - // Track the mapping of indices into sorted domain. - xla::XlaOp iota_indices = - xla::Iota(builder, xla::PrimitiveType::S32, num_boxes); - xla::XlaOp indices_sort = xla::Sort( - {scores, iota_indices}, - xla::CreateScalarGtComputation( - {scores_shape.element_type(), xla::PrimitiveType::S32}, builder)); - xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1); - xla::XlaOp scores_sorted = xla::GetTupleElement(indices_sort, 0); - - // Shapes are henceforth [1, num_boxes]. 'c_y0' denotes 'coordinate' y0. - xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted, - /*start_index=*/0, - /*limit_index=*/1, - /*stride=*/1, - /*dimno=*/0), - {num_boxes}); - xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted, - /*start_index=*/1, - /*limit_index=*/2, - /*stride=*/1, - /*dimno=*/0), - {num_boxes}); - xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted, - /*start_index=*/2, - /*limit_index=*/3, - /*stride=*/1, - /*dimno=*/0), - {num_boxes}); - xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted, - /*start_index=*/3, - /*limit_index=*/4, - /*stride=*/1, - /*dimno=*/0), - {num_boxes}); - - xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1); - xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0); - xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1); - xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0); - xla::XlaOp area = (y2 - y1) * (x2 - x1); - - // Shapes are henceforth [1, num_boxes]. - y1 = xla::Broadcast(y1, {1}); - y2 = xla::Broadcast(y2, {1}); - x1 = xla::Broadcast(x1, {1}); - x2 = xla::Broadcast(x2, {1}); - area = xla::Broadcast(area, {1}); - - // Shapes are henceforth [num_boxes, num_boxes]. - xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0})); - xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0})); - xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0})); - xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0})); - auto square_zero = xla::ZerosLike(i_xmin); - - xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) * - xla::Max(i_ymax - i_ymin, square_zero); - xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area; - xla::XlaOp iou = i_area / u_area; - - xla::XlaOp iou_threshold_mask = xla::Gt(iou, iou_threshold + square_zero); - xla::XlaOp included_iou = - xla::Broadcast(xla::ConstantR0(builder, true), {num_boxes}); - if (boxes_shape.is_dynamic_dimension(0)) { - // Update included_iou's size to match boxes actual size. - included_iou = xla::SetDimensionSize( - included_iou, XlaHelpers::GetDimensionsSize({boxes}, {0}).size, 0); - } - - xla::XlaOp zero_s32 = xla::Zero(builder, xla::PrimitiveType::S32); - xla::XlaOp one_s32 = xla::One(builder, xla::PrimitiveType::S32); - std::vector init_values; - init_values.reserve(4); - init_values.push_back(zero_s32); // col_idx - init_values.push_back(zero_s32); // num_outputs - init_values.push_back(iou_threshold_mask); - init_values.push_back(included_iou); - - auto suppress_loop_result = ConsumeValue(xla::WhileLoopHelper( - WhileCondFn(num_boxes, output_size), SuppressBodyFn(num_boxes), - init_values, "BoxSuppressLoop", builder)); - - xla::XlaOp included_score = - xla::Gt(scores_sorted, xla::Broadcast(score_threshold, {num_boxes})); - xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]); - - // Only consider boxes over which we have iterated. This allows for accurate - // counting. DynamicSlice would require knowledge of the size of the output. - xla::XlaOp valid_elem = xla::Lt( - iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes})); - included = xla::And(included, valid_elem); - - xla::XlaOp neg_inf = xla::Broadcast( - xla::MinValue(builder, scores_shape.element_type()), {num_boxes}); - xla::XlaOp scores_included = xla::Select(included, scores_sorted, neg_inf); - xla::XlaOp output_tuple = xla::TopK(scores_included, output_size); - xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1); - // Calculate num_valid. - // Note: num_valid cannot be taken from the loop outputs, because outputs - // can be suppressed by score threshold. - xla::XlaOp ones_included = - xla::Select(included, xla::Broadcast(one_s32, {num_boxes}), - xla::Broadcast(zero_s32, {num_boxes})); - // num_valid is scalar. torch::lazy::Value should be bound by output_size. - xla::XlaOp num_valid_total = xla::Reduce( - ones_included, - /*init_value=*/zero_s32, - /*computation=*/ - xla::CreateScalarAddComputation(xla::PrimitiveType::S32, builder), - /*dimensions_to_reduce=*/{0}); - xla::XlaOp num_valid = - xla::Min(num_valid_total, xla::ConstantR0(builder, output_size)); - - // Re-index into the original scores input tensor, using a Gather. - // Boxes were suppressed in the sorted domain. - xla::XlaOp selected_indices = - NmsGather(indices_sorted, scores_shape.dimensions(), - selected_indices_sorted, {output_size}, - /*axis=*/0); - return {selected_indices, num_valid}; -} - -} // namespace torch_xla diff --git a/torch_xla/csrc/nms_op.h b/torch_xla/csrc/nms_op.h deleted file mode 100644 index 42b2bf0016ee..000000000000 --- a/torch_xla/csrc/nms_op.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef XLA_TORCH_XLA_CSRC_NMS_OP_H_ -#define XLA_TORCH_XLA_CSRC_NMS_OP_H_ - -#include "xla/client/xla_builder.h" - -namespace torch_xla { - -struct NmsResult { - xla::XlaOp selected_indices; - xla::XlaOp num_valid; -}; - -NmsResult BuildNms(xla::XlaOp boxes, xla::XlaOp scores, - xla::XlaOp score_threshold, xla::XlaOp iou_threshold, - int64_t output_size); - -} // namespace torch_xla - -#endif // XLA_TORCH_XLA_CSRC_NMS_OP_H_ \ No newline at end of file diff --git a/torch_xla/csrc/ops/nms.cpp b/torch_xla/csrc/ops/nms.cpp index 64df29ce4605..70a80a9173bf 100644 --- a/torch_xla/csrc/ops/nms.cpp +++ b/torch_xla/csrc/ops/nms.cpp @@ -1,65 +1,46 @@ #include "torch_xla/csrc/ops/nms.h" #include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/nms_op.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/xla_lower_util.h" namespace torch_xla { namespace { xla::Shape NodeOutputShape(const torch::lazy::Value& boxes, const torch::lazy::Value& scores, - const torch::lazy::Value& score_threshold, - const torch::lazy::Value& iou_threshold, - int64_t output_size) { + const torch::lazy::Value& iou_threshold) { auto shape_fn = [&](absl::Span operands) -> xla::XlaOp { - NmsResult result = BuildNms(operands[0], operands[1], operands[2], - operands[3], output_size); - return xla::Tuple(result.selected_indices.builder(), - {result.selected_indices, result.num_valid}); + return BuildNms(operands[0], operands[1], operands[2]); }; + return InferOutputShape( - {GetXlaShape(boxes), GetXlaShape(scores), GetXlaShape(score_threshold), - GetXlaShape(iou_threshold)}, + {GetXlaShape(boxes), GetXlaShape(scores), GetXlaShape(iou_threshold)}, shape_fn); } } // namespace Nms::Nms(const torch::lazy::Value& boxes, const torch::lazy::Value& scores, - const torch::lazy::Value& score_threshold, - const torch::lazy::Value& iou_threshold, int64_t output_size) + const torch::lazy::Value& iou_threshold) : XlaNode( - xla_nms, {boxes, scores, score_threshold, iou_threshold}, - [&]() { - return NodeOutputShape(boxes, scores, score_threshold, - iou_threshold, output_size); - }, - /*num_outputs=*/2, torch::lazy::MHash(output_size)), - output_size_(output_size) {} + /*op=*/xla_nms, + /*operands=*/{boxes, scores, iou_threshold}, + /*xla_shape_fn=*/ + [&]() { return NodeOutputShape(boxes, scores, iou_threshold); }) {} torch::lazy::NodePtr Nms::Clone(torch::lazy::OpList operands) const { return torch::lazy::MakeNode(operands.at(0), operands.at(1), - operands.at(2), operands.at(3), - output_size_); + operands.at(2)); } XlaOpVector Nms::Lower(LoweringContext* loctx) const { xla::XlaOp boxes = loctx->GetOutputOp(operand(0)); xla::XlaOp scores = loctx->GetOutputOp(operand(1)); - xla::XlaOp score_threshold = loctx->GetOutputOp(operand(2)); - xla::XlaOp iou_threshold = loctx->GetOutputOp(operand(3)); - NmsResult result = - BuildNms(boxes, scores, score_threshold, iou_threshold, output_size_); - return ReturnOps({result.selected_indices, result.num_valid}, loctx); -} - -std::string Nms::ToString() const { - std::stringstream ss; - ss << XlaNode::ToString() << ", output_size=" << output_size_; - return ss.str(); + xla::XlaOp iou_threshold = loctx->GetOutputOp(operand(2)); + return ReturnOp(BuildNms(boxes, scores, iou_threshold), loctx); } } // namespace torch_xla diff --git a/torch_xla/csrc/ops/nms.h b/torch_xla/csrc/ops/nms.h index b825b27f1877..c45dd84cdf4d 100644 --- a/torch_xla/csrc/ops/nms.h +++ b/torch_xla/csrc/ops/nms.h @@ -8,19 +8,11 @@ namespace torch_xla { class Nms : public XlaNode { public: Nms(const torch::lazy::Value& boxes, const torch::lazy::Value& scores, - const torch::lazy::Value& score_threshold, - const torch::lazy::Value& iou_threshold, int64_t output_size); + const torch::lazy::Value& iou_threshold); torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; XlaOpVector Lower(LoweringContext* loctx) const override; - - std::string ToString() const override; - - int64_t output_size() const { return output_size_; } - - private: - int64_t output_size_; }; } // namespace torch_xla diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index bb9ea602541c..aefe04458b7b 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2079,19 +2079,14 @@ XLATensorPtr nll_loss_backward(const XLATensorPtr& grad_output, GetXlaReductionMode(reduction), ignore_index)); } -std::pair nms(const XLATensorPtr& boxes, - const XLATensorPtr& scores, - const XLATensorPtr& score_threshold, - const XLATensorPtr& iou_threshold, - int64_t output_size) { +XLATensorPtr nms(const XLATensorPtr& boxes, const XLATensorPtr& scores, + double iou_threshold) { + const torch::lazy::BackendDevice& device = boxes->GetDevice(); + torch::lazy::NodePtr xla_iou_threshold = + ScalarOp(iou_threshold, MakeXlaPrimitiveType(at::kDouble, &device)); torch::lazy::NodePtr node = torch::lazy::MakeNode( - boxes->GetIrValue(), scores->GetIrValue(), score_threshold->GetIrValue(), - iou_threshold->GetIrValue(), output_size); - return std::pair( - XLATensor::Create(torch::lazy::Value(node, 0), boxes->GetDevice(), - at::ScalarType::Int), - XLATensor::Create(torch::lazy::Value(node, 1), boxes->GetDevice(), - at::ScalarType::Int)); + boxes->GetIrValue(), scores->GetIrValue(), xla_iou_threshold); + return XLATensor::Create(node, device, at::ScalarType::Long); } XLATensorPtr nonzero(const XLATensorPtr& input) { diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 65c4324f0605..70b42b5c2ac2 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -678,11 +678,8 @@ XLATensorPtr nll_loss_backward(const XLATensorPtr& grad_output, int ignore_index, const XLATensorPtr& total_weight); -std::pair nms(const XLATensorPtr& boxes, - const XLATensorPtr& scores, - const XLATensorPtr& score_threshold, - const XLATensorPtr& iou_threshold, - int64_t output_size); +XLATensorPtr nms(const XLATensorPtr& boxes, const XLATensorPtr& scores, + double iou_threshold); XLATensorPtr nonzero(const XLATensorPtr& input); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index e5358614bcf5..49cfa4ef5d79 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1264,4 +1264,247 @@ xla::XlaOp BuildTpuCustomCall(const std::vector& inputs, inputs, output_shape, input_shapes, payload); } +std::vector BuildBoxSelectionLoop(int64_t num_boxes, + xla::XlaOp iou_threshold_mask) { + using IndexType = int32_t; + const xla::PrimitiveType XLAIndexType = xla::PrimitiveType::S32; + + xla::XlaBuilder* builder = iou_threshold_mask.builder(); + + const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType); + const xla::XlaOp TRUE = xla::ConstantR0(builder, true); + + // Initial values to the while loop. + std::vector init_values(3); + // 1. Loop counter: represents the actual box being processed. + init_values[0] = ZERO; + // 2. State of each box (i.e. whether it was included or not). + init_values[1] = xla::Broadcast(TRUE, {num_boxes}); + // 3. The actual IoU threshold matrix. + init_values[2] = iou_threshold_mask; + + return ConsumeValue(xla::WhileLoopHelper( + [=](absl::Span values, xla::XlaBuilder* builder) { + xla::XlaOp box_index = values[0]; + // Check: current loop counter is within bounds, i.e. has a + // corresponding box. + return xla::Lt(box_index, + xla::ConstantR0(builder, num_boxes)); + }, + [=](absl::Span values, xla::XlaBuilder* builder) { + const xla::XlaOp ONE = xla::One(builder, XLAIndexType); + const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType); + + xla::XlaOp box_index = values[0]; + xla::XlaOp state = values[1]; + xla::XlaOp iou_threshold_mask = values[2]; + + // Retrieve the IoU mask row corresponding to this box. + xla::XlaOp box_iou_threshold_mask = xla::DynamicSlice( + iou_threshold_mask, {box_index, ZERO}, {1, num_boxes}); + + // Update the current state with the IoU mask. + // Basically, sets to false every box X whose IoU with the current box + // is less-than or equal than the given threshold. + xla::XlaOp updated_state = xla::And( + state, + // Update the mask so that if we select this box + // (i.e. state[box] == true), we don't de-select it. + xla::DynamicUpdateSlice( + // Before that, we need to pre-process the mask. + // 1. Negate the mask: if this box is selected, we only want + // those that have a low intersection ratio. + // 2. Reshape it to: [num_boxes]. + xla::Reshape(xla::Not(box_iou_threshold_mask), {num_boxes}), + xla::ConstantR1(builder, {true}), {box_index})); + + // Flag: should this box (loop counter) be included in the output? + xla::XlaOp should_include = xla::DynamicSlice(state, {box_index}, {1}); + // Pick the new values of state, depending on whether we should include + // this box or not. + xla::XlaOp new_state = + xla::Select(xla::BroadcastInDim(should_include, {num_boxes}, {0}), + updated_state, state); + + xla::XlaOp next_box_index = box_index + ONE; + return std::vector{next_box_index, new_state, + iou_threshold_mask}; + }, + init_values, "BoxSelectionLoop", builder)); +} + +xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores, + xla::XlaOp iou_threshold) { + using IndexType = int32_t; + const xla::PrimitiveType XLAIndexType = xla::PrimitiveType::S32; + + xla::XlaBuilder* builder = boxes.builder(); + + const int64_t COORDINATES = 4; + const xla::XlaOp ONE = xla::One(builder, XLAIndexType); + const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType); + + const xla::Shape& boxes_shape = ShapeHelper::ShapeOfXlaOp(boxes); + XLA_CHECK_EQ(boxes_shape.rank(), 2); + XLA_CHECK_EQ(boxes_shape.dimensions(1), COORDINATES); + int64_t num_boxes = boxes_shape.dimensions(0); + + const xla::Shape& scores_shape = ShapeHelper::ShapeOfXlaOp(scores); + XLA_CHECK_EQ(scores_shape.rank(), 1); + XLA_CHECK_EQ(scores_shape.dimensions(0), num_boxes); + + // 1. Order the boxes according to their scores. + // Also remember the order of the boxes original indices by having an + // extra Iota operand. + xla::XlaOp sorted = xla::Sort( + { + // Here, we need to broadcast both the scores and Iota operands, so + // as to have the same dimensions as boxes: {COORDINATES, num_boxes}. + xla::Broadcast(scores, {COORDINATES}), + xla::Broadcast(xla::Iota(builder, XLAIndexType, num_boxes), + {COORDINATES}), + // Transpose boxes, so as to manipulate its values in an easier way. + xla::Transpose(boxes, {1, 0}), + }, + xla::CreateScalarGtComputation( + { + scores_shape.element_type(), + XLAIndexType, + boxes_shape.element_type(), + }, + builder), + /*dimension=*/1); + + // 1.1. De-construct the returned tuple. + // Specifically, we only need one of the rows of the sorted index tensor + // and of the sorted scores, since the others were only broadcasted. + // + // Shape: [1, num_boxes] + xla::XlaOp sorted_scores = + xla::SliceInDim(xla::GetTupleElement(sorted, 0), /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, /*dimno=*/0); + // Shape: [1, num_boxes] + xla::XlaOp sorted_indices = + xla::SliceInDim(xla::GetTupleElement(sorted, 1), /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, /*dimno=*/0); + // Shape: [COORDINATES, num_boxes] + xla::XlaOp sorted_boxes = xla::GetTupleElement(sorted, 2); + + // 1.2. Retrieve each coordinate, in their own tensor. + // Since we transposed boxes tensor, each row corresponds to a + // coordinate. + // + // Shape: [1, num_boxes] + xla::XlaOp y0 = xla::SliceInDim(sorted_boxes, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, /*dimno=*/0); + xla::XlaOp x0 = xla::SliceInDim(sorted_boxes, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, /*dimno=*/0); + xla::XlaOp y1 = xla::SliceInDim(sorted_boxes, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, /*dimno=*/0); + xla::XlaOp x1 = xla::SliceInDim(sorted_boxes, /*start_index=*/3, + /*limit_index=*/4, /*stride=*/1, /*dimno=*/0); + + // 2. Create the IoU (Intersection over Union) ratio mask + // 2.1. First, compute the area of each box. + // + // Shape: [1, num_boxes] + xla::XlaOp area = (y1 - y0) * (x1 - x0); + + // 2.2. Get the corners of the intersection box created by every pair of + // boxes. + // Basically, given 2 boxes, what the corner of their + // intersection box would be? + // + // Shape: [num_boxes, num_boxes] + xla::XlaOp left = xla::Max(x0, xla::Transpose(x0, {1, 0})); + xla::XlaOp bottom = xla::Max(y0, xla::Transpose(y0, {1, 0})); + xla::XlaOp right = xla::Min(x1, xla::Transpose(x1, {1, 0})); + xla::XlaOp top = xla::Min(y1, xla::Transpose(y1, {1, 0})); + + // 2.3. Compute the intersection area. + // Whenever 2 boxes don't intersect, either their width or height will be + // negative. + // + // Shape: [num_boxes, num_boxes] + xla::XlaOp zeros = xla::ZerosLike(left); + xla::XlaOp intersection_area = + xla::Max(right - left, zeros) * xla::Max(top - bottom, zeros); + + // 2.4. Compute the union area. + // Sum of the areas of every pair of boxes, minus their intersection + // area. + // + // Shape: [num_boxes, num_boxes] + xla::XlaOp union_area = + area + xla::Transpose(area, {1, 0}) - intersection_area; + + // 2.5. Compute the IoU ratio. + // + // Shape: [num_boxes, num_boxes] + xla::XlaOp iou = intersection_area / union_area; + + // 2.6. Create the mask by comparing it with the given threshold. + // + // Shape: [num_boxes, num_boxes] + xla::XlaOp casted_threshold = + xla::ConvertElementType(iou_threshold, XlaHelpers::TypeOfXlaOp(iou)); + xla::XlaOp iou_threshold_mask = xla::Gt(iou, casted_threshold); + + // 3. Iteratively select the highest scoring box, and eliminate those whose + // IoU is greater than the threshold. + // + // state: a [num_boxes] tensor, where, at the end of the loop, for + // each box i, state[i] represents whether box i should be + // included in the output or not + // + // Loop Invariant: for every i in [0..current iteration], state[i] + // represents whether box i is included or not in the + // output. + // + // Rough idea: at every iteration i, we: + // - Check if state[i] == true (i.e. box i should be included) + // + // - If so, retrieve and negate the i-th row from the IoU mask + // (i.e. what are the boxes that have an IoU ratio lower-than or + // equal the given threshold?). + // + // - Update state[i+1..] by computing a logical and operation with + // the retrieved negated IoU mask. Note that this won't modify + // state[0..i]. The next box j > i, where state[j] == true is + // the next box that will be included. + std::vector loop_result = + BuildBoxSelectionLoop(num_boxes, iou_threshold_mask); + + xla::XlaOp loop_counter = loop_result[0]; + xla::XlaOp included_mask = loop_result[1]; + + // 4. Retrieve the included box indices. + // 4.1. Compute the number of included boxes. + // 4.1.1. Transform that mask into a 0-1 tensor. + xla::XlaOp one_if_included = + xla::Select(included_mask, xla::Broadcast(ONE, {num_boxes}), + xla::Broadcast(ZERO, {num_boxes})); + // 4.1.2. Sum it up. + xla::XlaOp included_boxes = + xla::Reduce(one_if_included, ZERO, + xla::CreateScalarAddComputation(XLAIndexType, builder), {0}); + + // 4.2. Move the indices of the included boxes to the beginning. + // Doing so alongside the previously sorted indices gives us an index + // tensor with the original indices of the selected boxes at the + // beginning. + xla::XlaOp included_indices_first = xla::GetTupleElement( + xla::Sort( + { + one_if_included, + xla::Reshape(sorted_indices, {num_boxes}), + }, + xla::CreateScalarGtComputation({XLAIndexType, XLAIndexType}, builder), + /*dimension=*/0, /*is_stable=*/true), + 1); + + // 4.3. Get only the first included_boxes indices. + return xla::SetDimensionSize(included_indices_first, included_boxes, 0); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index f0c74dff9918..ae26200e0104 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -156,6 +156,9 @@ xla::XlaOp BuildTpuCustomCall(const std::vector& inputs, const xla::Shape& output_shape, const std::string& payload); +xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores, + xla::XlaOp iou_threshold); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_ diff --git a/torch_xla/csrc/xla_manual_registration.cpp b/torch_xla/csrc/xla_manual_registration.cpp new file mode 100644 index 000000000000..881e51ca1fee --- /dev/null +++ b/torch_xla/csrc/xla_manual_registration.cpp @@ -0,0 +1,38 @@ +#include +#include + +#include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/ops/nms.h" +#include "torch_xla/csrc/ops/ops.h" +#include "torch_xla/csrc/tensor_methods.h" +#include "torch_xla/csrc/tensor_util.h" + +namespace torch_xla { +namespace manual { +namespace { + +at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores, + double iou_threshold) { + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + + XLA_CHECK_EQ(boxes.dim(), 2) << "nms(): boxes should be a 2D tensor."; + XLA_CHECK_EQ(boxes.size(1), 4) + << "nms(): boxes should be a 2D tensor of shape [N, 4]."; + XLA_CHECK_EQ(scores.dim(), 1) << "nms(): scores should be a 1D tensor."; + XLA_CHECK_EQ(boxes.size(0), scores.size(0)) + << "nms(): boxes and scores should have the same size for dimension 0."; + + XLATensorPtr xla_boxes = bridge::GetXlaTensor(boxes); + XLATensorPtr xla_scores = bridge::GetXlaTensor(scores); + return bridge::AtenFromXlaTensor( + tensor_methods::nms(xla_boxes, xla_scores, iou_threshold)); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, XLA, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +} + +} // namespace manual +} // namespace torch_xla