Skip to content

Commit

Permalink
Refactor nms into TorchVision variant. (#6814)
Browse files Browse the repository at this point in the history
ysiraichi authored Apr 1, 2024
1 parent b0ba29f commit 6cf9b91
Showing 16 changed files with 398 additions and 458 deletions.
1 change: 0 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
@@ -33,7 +33,6 @@ xla_model
.. automodule:: torch_xla.core.functions
.. autofunction:: all_reduce
.. autofunction:: all_gather
.. autofunction:: nms

distributed
----------------------------------
109 changes: 83 additions & 26 deletions test/test_operations.py
Original file line number Diff line number Diff line change
@@ -83,6 +83,11 @@ def skipIfFunctionalizationDisabled(reason):
return _skipIfFunctionalization(value=True, reason=reason)


def onlyOnCUDA(fn):
  """Skip `fn` unless the PJRT_DEVICE environment variable selects CUDA.

  The environment is read once, when the decorator is applied. The comparison
  is case-insensitive, so both "CUDA" and "cuda" enable the test.

  Args:
    fn: the test function (or class) to decorate.

  Returns:
    `fn` itself when PJRT_DEVICE is CUDA; otherwise a version of `fn` that
    unittest reports as skipped.
  """
  # Default to "" so that an unset PJRT_DEVICE skips the test instead of
  # raising AttributeError (`None.lower()`) while the module is being imported.
  accelerator = os.environ.get("PJRT_DEVICE", "").lower()
  return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn)


def _gen_tensor(*args, **kwargs):
  """Return a tensor of normally distributed samples.

  Every positional and keyword argument is forwarded unchanged to
  `torch.randn`.
  """
  sample = torch.randn(*args, **kwargs)
  return sample

@@ -2280,32 +2285,6 @@ def test_send_to_device_single(self):
self.assertEqual(dt[0].device, xla_device)
self.assertTrue(torch.all(torch.eq(dt[0].cpu(), t)))

def test_nms(self):
  """Run `xf.nms` on four hand-written boxes and check indices and count."""
  SCORE_THRESHOLD = 0.1
  IOU_THRESHOLD = 0.08
  box_coords = (
      (0, 0, 3, 2),
      (3, 3, 11, 7),
      (2, 2, 5, 7),
      (7, 4, 15, 12),
  )
  box_scores = (0.9, 0.5, 0.95, 0.4)

  device = xm.xla_device()

  def on_device(values):
    # Every input is built as a float tensor on CPU, then moved to the device.
    return torch.tensor(values, dtype=torch.float).to(device)

  selected_indices, num_valid = xf.nms(
      on_device(box_coords),
      on_device(box_scores),
      on_device(SCORE_THRESHOLD),
      on_device(IOU_THRESHOLD),
      len(box_coords),
  )

  expected_indices = torch.tensor([2, 0, 3, 1], dtype=torch.int32)
  self.assertEqual(selected_indices, expected_indices)
  self.assertEqual(num_valid.item(), 3)

def test_util_foreach_api(self):

class ForTest(object):
@@ -2473,6 +2452,84 @@ def test_dropout(self):
f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")


# These tests were extracted and adapted from torchvision.
# Source: vision/test/test_ops.py
class TestNMS(test_utils.XlaTestCase):
  """Compares `torchvision.ops.nms` on the XLA device with the CPU kernel."""

  def _reference_nms(self, boxes, scores, iou_threshold):
    """Run torchvision's NMS on CPU copies of `boxes` and `scores`."""
    import torchvision
    return torchvision.ops.nms(boxes.cpu(), scores.cpu(), iou_threshold)

  def _nms(self, boxes, scores, iou_threshold):
    """Run torchvision's NMS on the XLA device and return the result on CPU."""
    import torchvision
    device = xm.xla_device()
    return torchvision.ops.nms(
        boxes.to(device), scores.to(device), iou_threshold).cpu()

  def _create_tensors_with_iou(self, N, iou_thresh):
    """Create `N` random boxes and scores for an NMS test.

    The last box is a copy of the first one, widened so that their IoU sits
    just below `iou_thresh + 1e-5`.

    Returns:
      A `(boxes, scores)` tuple with shapes `[N, 4]` and `[N]`.
    """
    # force last box to have a pre-defined iou with the first box
    # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
    # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
    # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
    # Adjust the threshold upward a bit with the intent of creating
    # at least one box that exceeds (barely) the threshold and so
    # should be suppressed.
    boxes = torch.rand(N, 4) * 100
    boxes[:, 2:] += boxes[:, :2]
    boxes[-1, :] = boxes[0, :]
    # Only the horizontal extent is needed to compute the widening `d`.
    x0, _, x1, _ = boxes[-1].tolist()
    iou_thresh += 1e-5
    boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
    scores = torch.rand(N)
    return boxes, scores

  @skipOnEagerDebug
  def test_nms_ref(self):
    """XLA NMS must agree with the CPU reference on random boxes."""

    def _test(iou, seed):
      torch.random.manual_seed(seed)
      # `keep` comes from the XLA device, `keep_ref` from the CPU kernel.
      err_msg = "NMS incompatible between XLA and reference implementation for IoU={}"
      boxes, scores = self._create_tensors_with_iou(1000, iou)
      keep_ref = self._reference_nms(boxes, scores, iou)
      keep = self._nms(boxes, scores, iou)
      self.assertEqual(keep, keep_ref, message=err_msg.format(iou))

    for iou in (0.2, 0.5, 0.8):
      for seed in range(10):
        with self.subTest(iou=iou, seed=seed):
          _test(iou, seed)

  def test_nms_input_errors(self):
    """Malformed `boxes`/`scores` shapes must raise a RuntimeError."""
    with self.assertRaisesRegex(RuntimeError, "boxes should be a 2D tensor."):
      self._nms(torch.rand(4), torch.rand(3), 0.5)
    # Raw string: `\[` and `\]` are regex escapes, not Python string escapes.
    with self.assertRaisesRegex(
        RuntimeError, r"boxes should be a 2D tensor of shape \[N, 4\]."):
      self._nms(torch.rand(3, 5), torch.rand(3), 0.5)
    with self.assertRaisesRegex(RuntimeError, "scores should be a 1D tensor."):
      self._nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
    with self.assertRaisesRegex(
        RuntimeError,
        "boxes and scores should have the same size for dimension 0."):
      self._nms(torch.rand(3, 4), torch.rand(4), 0.5)

  def test_legacy(self):
    """Check NMS on four hand-written boxes through `runAtenTest`."""
    import torchvision

    BOXES = (
        (0, 0, 3, 2),
        (3, 3, 11, 7),
        (2, 2, 5, 7),
        (7, 4, 15, 12),
    )
    SCORES = (0.9, 0.5, 0.95, 0.4)
    IOU_THRESHOLD = 0.08

    def fn(boxes, scores):
      # Call the operator directly so that it runs on whatever device the
      # inputs live on. Going through `_reference_nms` would move the inputs
      # to CPU first, so the XLA kernel would never be exercised.
      # NOTE(review): assumes `runAtenTest` (not shown here) calls `fn` with
      # both the CPU tensors and device copies of them - confirm.
      return torchvision.ops.nms(boxes, scores, IOU_THRESHOLD)

    boxes = torch.tensor(BOXES, dtype=torch.float)
    scores = torch.tensor(SCORES, dtype=torch.float)
    self.runAtenTest((boxes, scores), fn)


if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
23 changes: 0 additions & 23 deletions torch_xla/core/functions.py
Original file line number Diff line number Diff line change
@@ -84,29 +84,6 @@ def all_gather(value, dim=0):
return AllGather.apply(value, dim)


def nms(boxes, scores, score_threshold, iou_threshold, output_size):
  """Runs Non Maximal Suppression over a set of scored boxes.

  Args:
    boxes (torch.Tensor): Tensor of shape `[N, 4]` holding the box coordinates
      in `(y0, x0, y1, x1)` form.
    scores (torch.Tensor): Tensor of shape `[N]` holding the score of each
      box.
    score_threshold (torch.Tensor): Minimum score a box needs in order to
      qualify as valid.
    iou_threshold (torch.Tensor): Minimum IOU (Intersection Over Union) score
      that triggers the overlap logic.
    output_size (int): Maximum number of returned indices (must be lower or
      equal to N).

  Returns:
    A tuple of two `torch.Tensor`: the selected box indices, followed by the
    number of valid boxes.
  """
  selected_and_count = torch_xla._XLAC._xla_nms(boxes, scores, score_threshold,
                                                iou_threshold, output_size)
  return selected_and_count


def distributed_mm(w, x, split=1):
"""Performs a matrix multiplication with sharded weight.
3 changes: 1 addition & 2 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
@@ -47,7 +47,6 @@ ptxla_cc_library(
"ir_dump_util.cpp",
"matrix.cpp",
"nll_loss.cpp",
"nms_op.cpp",
"pooling.cpp",
"quant_util.cpp",
"random.cpp",
@@ -67,6 +66,7 @@ ptxla_cc_library(
"xla_lower_util.cpp",
"xla_op_builder.cpp",
"xla_sharding_util.cpp",
"xla_manual_registration.cpp",
":RegisterAutogradXLA.cpp",
":RegisterXLA.cpp",
":XLANativeFunctions.cpp",
@@ -87,7 +87,6 @@ ptxla_cc_library(
"ir_dump_util.h",
"matrix.h",
"nll_loss.h",
"nms_op.h",
"pooling.h",
"quant_util.h",
"random.h",
6 changes: 4 additions & 2 deletions torch_xla/csrc/aten_xla_bridge.cpp
Original file line number Diff line number Diff line change
@@ -400,12 +400,14 @@ at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor,
return tensor.to(tensor_options, /*non_blocking=*/false, /*copy=*/true);
}

at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor) {
at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor,
bool skip_functionalization) {
if (xla_tensor) {
auto out =
at::Tensor(c10::make_intrusive<XLATensorImpl>(std::move(xla_tensor)));
// See Note [Lazy Tensor Functionalization]
if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
if (skip_functionalization ||
c10::impl::tls_local_dispatch_key_set().excluded_.has(
c10::DispatchKey::Functionalize)) {
// Invariant: if the functionalization key is in the exclude set, then
// we're expected to return an ordinary tensor, which will be "lifted"
3 changes: 2 additions & 1 deletion torch_xla/csrc/aten_xla_bridge.h
Original file line number Diff line number Diff line change
@@ -112,7 +112,8 @@ at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor,
const at::TensorOptions& tensor_options);

// Creates an ATen tensor with XLA type id from an XLATensorPtr.
at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor);
at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor,
bool skip_functionalization = false);

std::vector<at::Tensor> AtenFromXlaTensors(
absl::Span<const XLATensorPtr> xla_tensors);
27 changes: 0 additions & 27 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
@@ -666,28 +666,6 @@ py::object GetRevisions() {
return py_dict;
}

// Python-facing NMS entry point.
//
// Converts the five inputs to XLA tensors, runs tensor_methods::nms on them
// and returns a Python 2-tuple: element 0 wraps the selected indices, element
// 1 wraps the number of valid boxes. Both are wrapped with
// requires_grad=false.
py::object XlaNms(const at::Tensor& boxes, const at::Tensor& scores,
                  const at::Tensor& score_threshold,
                  const at::Tensor& iou_threshold, int64_t output_size) {
  at::Tensor indices_out;
  at::Tensor valid_count_out;
  {
    // The NoGilSection guard lives only for this scope (presumably releasing
    // the Python GIL while the tensors are computed - confirm against its
    // definition). The Python tuple below is built after the guard is gone.
    NoGilSection nogil;
    auto xla_boxes = bridge::GetXlaTensor(boxes);
    auto xla_scores = bridge::GetXlaTensor(scores);
    auto xla_score_threshold = bridge::GetXlaTensor(score_threshold);
    auto xla_iou_threshold = bridge::GetXlaTensor(iou_threshold);
    auto outputs =
        tensor_methods::nms(xla_boxes, xla_scores, xla_score_threshold,
                            xla_iou_threshold, output_size);
    indices_out = bridge::AtenFromXlaTensor(std::move(outputs.first));
    valid_count_out = bridge::AtenFromXlaTensor(std::move(outputs.second));
  }
  py::tuple packed(2);
  packed[0] =
      torch::autograd::make_variable(indices_out, /*requires_grad=*/false);
  packed[1] =
      torch::autograd::make_variable(valid_count_out, /*requires_grad=*/false);
  return packed;
}

std::vector<at::Tensor> XlaUserComputation(
const std::string& opname, const std::vector<at::Tensor>& inputs,
runtime::ComputationClient::ComputationPtr computation) {
@@ -1086,11 +1064,6 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& tensor, int dim) {
return GetXlaTensorDimensionSize(tensor, dim);
});
m.def("_xla_nms", [](const at::Tensor& boxes, const at::Tensor& scores,
const at::Tensor& score_threshold,
const at::Tensor& iou_threshold, int64_t output_size) {
return XlaNms(boxes, scores, score_threshold, iou_threshold, output_size);
});
m.def("_xla_user_computation",
[](const std::string& opname, const std::vector<at::Tensor>& inputs,
const runtime::ComputationClient::ComputationPtr& computation) {
299 changes: 0 additions & 299 deletions torch_xla/csrc/nms_op.cpp

This file was deleted.

19 changes: 0 additions & 19 deletions torch_xla/csrc/nms_op.h

This file was deleted.

45 changes: 13 additions & 32 deletions torch_xla/csrc/ops/nms.cpp
Original file line number Diff line number Diff line change
@@ -1,65 +1,46 @@
#include "torch_xla/csrc/ops/nms.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/nms_op.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace {

xla::Shape NodeOutputShape(const torch::lazy::Value& boxes,
const torch::lazy::Value& scores,
const torch::lazy::Value& score_threshold,
const torch::lazy::Value& iou_threshold,
int64_t output_size) {
const torch::lazy::Value& iou_threshold) {
auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
NmsResult result = BuildNms(operands[0], operands[1], operands[2],
operands[3], output_size);
return xla::Tuple(result.selected_indices.builder(),
{result.selected_indices, result.num_valid});
return BuildNms(operands[0], operands[1], operands[2]);
};

return InferOutputShape(
{GetXlaShape(boxes), GetXlaShape(scores), GetXlaShape(score_threshold),
GetXlaShape(iou_threshold)},
{GetXlaShape(boxes), GetXlaShape(scores), GetXlaShape(iou_threshold)},
shape_fn);
}

} // namespace

Nms::Nms(const torch::lazy::Value& boxes, const torch::lazy::Value& scores,
const torch::lazy::Value& score_threshold,
const torch::lazy::Value& iou_threshold, int64_t output_size)
const torch::lazy::Value& iou_threshold)
: XlaNode(
xla_nms, {boxes, scores, score_threshold, iou_threshold},
[&]() {
return NodeOutputShape(boxes, scores, score_threshold,
iou_threshold, output_size);
},
/*num_outputs=*/2, torch::lazy::MHash(output_size)),
output_size_(output_size) {}
/*op=*/xla_nms,
/*operands=*/{boxes, scores, iou_threshold},
/*xla_shape_fn=*/
[&]() { return NodeOutputShape(boxes, scores, iou_threshold); }) {}

torch::lazy::NodePtr Nms::Clone(torch::lazy::OpList operands) const {
return torch::lazy::MakeNode<Nms>(operands.at(0), operands.at(1),
operands.at(2), operands.at(3),
output_size_);
operands.at(2));
}

XlaOpVector Nms::Lower(LoweringContext* loctx) const {
xla::XlaOp boxes = loctx->GetOutputOp(operand(0));
xla::XlaOp scores = loctx->GetOutputOp(operand(1));
xla::XlaOp score_threshold = loctx->GetOutputOp(operand(2));
xla::XlaOp iou_threshold = loctx->GetOutputOp(operand(3));
NmsResult result =
BuildNms(boxes, scores, score_threshold, iou_threshold, output_size_);
return ReturnOps({result.selected_indices, result.num_valid}, loctx);
}

std::string Nms::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", output_size=" << output_size_;
return ss.str();
xla::XlaOp iou_threshold = loctx->GetOutputOp(operand(2));
return ReturnOp(BuildNms(boxes, scores, iou_threshold), loctx);
}

} // namespace torch_xla
10 changes: 1 addition & 9 deletions torch_xla/csrc/ops/nms.h
Original file line number Diff line number Diff line change
@@ -8,19 +8,11 @@ namespace torch_xla {
class Nms : public XlaNode {
public:
Nms(const torch::lazy::Value& boxes, const torch::lazy::Value& scores,
const torch::lazy::Value& score_threshold,
const torch::lazy::Value& iou_threshold, int64_t output_size);
const torch::lazy::Value& iou_threshold);

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

int64_t output_size() const { return output_size_; }

private:
int64_t output_size_;
};

} // namespace torch_xla
19 changes: 7 additions & 12 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
@@ -2089,19 +2089,14 @@ XLATensorPtr nll_loss_backward(const XLATensorPtr& grad_output,
GetXlaReductionMode(reduction), ignore_index));
}

std::pair<XLATensorPtr, XLATensorPtr> nms(const XLATensorPtr& boxes,
const XLATensorPtr& scores,
const XLATensorPtr& score_threshold,
const XLATensorPtr& iou_threshold,
int64_t output_size) {
XLATensorPtr nms(const XLATensorPtr& boxes, const XLATensorPtr& scores,
double iou_threshold) {
const torch::lazy::BackendDevice& device = boxes->GetDevice();
torch::lazy::NodePtr xla_iou_threshold =
ScalarOp(iou_threshold, MakeXlaPrimitiveType(at::kDouble, &device));
torch::lazy::NodePtr node = torch::lazy::MakeNode<Nms>(
boxes->GetIrValue(), scores->GetIrValue(), score_threshold->GetIrValue(),
iou_threshold->GetIrValue(), output_size);
return std::pair<XLATensorPtr, XLATensorPtr>(
XLATensor::Create(torch::lazy::Value(node, 0), boxes->GetDevice(),
at::ScalarType::Int),
XLATensor::Create(torch::lazy::Value(node, 1), boxes->GetDevice(),
at::ScalarType::Int));
boxes->GetIrValue(), scores->GetIrValue(), xla_iou_threshold);
return XLATensor::Create(node, device, at::ScalarType::Long);
}

XLATensorPtr nonzero(const XLATensorPtr& input) {
7 changes: 2 additions & 5 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
@@ -678,11 +678,8 @@ XLATensorPtr nll_loss_backward(const XLATensorPtr& grad_output,
int ignore_index,
const XLATensorPtr& total_weight);

std::pair<XLATensorPtr, XLATensorPtr> nms(const XLATensorPtr& boxes,
const XLATensorPtr& scores,
const XLATensorPtr& score_threshold,
const XLATensorPtr& iou_threshold,
int64_t output_size);
XLATensorPtr nms(const XLATensorPtr& boxes, const XLATensorPtr& scores,
double iou_threshold);

XLATensorPtr nonzero(const XLATensorPtr& input);

243 changes: 243 additions & 0 deletions torch_xla/csrc/xla_lower_util.cpp
Original file line number Diff line number Diff line change
@@ -1295,4 +1295,247 @@ std::vector<xla::XlaOp> BuildTpuCustomCall(
return result;
}

// Builds the XLA while-loop that decides, box by box, which boxes are kept.
//
// Args:
//   num_boxes: number of boxes; the loop runs for box_index in [0, num_boxes).
//   iou_threshold_mask: boolean matrix that is sliced as {1, num_boxes} rows
//     below. BuildNms builds it as `iou > threshold`, with boxes already
//     ordered by descending score.
//
// Returns:
//   The values produced by the while loop. The loop carries three values in
//   this order: the box counter, the per-box inclusion state and the IoU mask
//   (BuildNms reads index 1 as the final inclusion mask).
std::vector<xla::XlaOp> BuildBoxSelectionLoop(int64_t num_boxes,
xla::XlaOp iou_threshold_mask) {
using IndexType = int32_t;
const xla::PrimitiveType XLAIndexType = xla::PrimitiveType::S32;

xla::XlaBuilder* builder = iou_threshold_mask.builder();

const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType);
const xla::XlaOp TRUE = xla::ConstantR0<bool>(builder, true);

// Initial values to the while loop.
std::vector<xla::XlaOp> init_values(3);
// 1. Loop counter: represents the actual box being processed.
init_values[0] = ZERO;
// 2. State of each box (i.e. whether it was included or not).
//    Every box starts out as included.
init_values[1] = xla::Broadcast(TRUE, {num_boxes});
// 3. The actual IoU threshold matrix.
init_values[2] = iou_threshold_mask;

return ConsumeValue(xla::WhileLoopHelper(
// Loop condition.
[=](absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) {
xla::XlaOp box_index = values[0];
// Check: current loop counter is within bounds, i.e. has a
// corresponding box.
return xla::Lt(box_index,
xla::ConstantR0<IndexType>(builder, num_boxes));
},
// Loop body.
[=](absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) {
// These constants are created on the `builder` passed to this lambda;
// they intentionally shadow the outer ZERO, which was created on the
// outer builder.
const xla::XlaOp ONE = xla::One(builder, XLAIndexType);
const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType);

xla::XlaOp box_index = values[0];
xla::XlaOp state = values[1];
xla::XlaOp iou_threshold_mask = values[2];

// Retrieve the IoU mask row corresponding to this box.
xla::XlaOp box_iou_threshold_mask = xla::DynamicSlice(
iou_threshold_mask, {box_index, ZERO}, {1, num_boxes});

// Update the current state with the IoU mask.
// Basically, sets to false every box X whose IoU with the current box
// is greater than the given threshold (the mask row is negated below, so
// only boxes with an IoU lower-than or equal to the threshold stay true).
xla::XlaOp updated_state = xla::And(
state,
// Update the mask so that if we select this box
// (i.e. state[box] == true), we don't de-select it.
xla::DynamicUpdateSlice(
// Before that, we need to pre-process the mask.
// 1. Negate the mask: if this box is selected, we only want
// those that have a low intersection ratio.
// 2. Reshape it to: [num_boxes].
xla::Reshape(xla::Not(box_iou_threshold_mask), {num_boxes}),
xla::ConstantR1<bool>(builder, {true}), {box_index}));

// Flag: should this box (loop counter) be included in the output?
xla::XlaOp should_include = xla::DynamicSlice(state, {box_index}, {1});
// Pick the new values of state, depending on whether we should include
// this box or not. A box that was already suppressed leaves the state
// untouched.
xla::XlaOp new_state =
xla::Select(xla::BroadcastInDim(should_include, {num_boxes}, {0}),
updated_state, state);

xla::XlaOp next_box_index = box_index + ONE;
return std::vector<xla::XlaOp>{next_box_index, new_state,
iou_threshold_mask};
},
init_values, "BoxSelectionLoop", builder));
}

// Builds the XLA computation for non-maximum suppression (NMS).
//
// Args:
//   boxes: rank-2 operand of shape [num_boxes, 4] (checked below).
//   scores: rank-1 operand of shape [num_boxes] (checked below).
//   iou_threshold: operand that is converted to the element type of the IoU
//     matrix before the comparison. A box is suppressed by an already
//     selected, higher scoring box when their IoU is strictly greater than
//     this value.
//
// Returns:
//   An S32 vector with the original indices of the kept boxes, ordered by
//   descending score. Dimension 0 is given a dynamic size equal to the number
//   of kept boxes (via xla::SetDimensionSize).
//
// The coordinate rows are named (y0, x0, y1, x1) below. The IoU does not
// change when the two axes are swapped consistently, so boxes given as
// (x0, y0, x1, y1) produce the same result.
xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores,
                    xla::XlaOp iou_threshold) {
  const xla::PrimitiveType XLAIndexType = xla::PrimitiveType::S32;

  xla::XlaBuilder* builder = boxes.builder();

  const int64_t COORDINATES = 4;
  const xla::XlaOp ONE = xla::One(builder, XLAIndexType);
  const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType);

  const xla::Shape& boxes_shape = ShapeHelper::ShapeOfXlaOp(boxes);
  XLA_CHECK_EQ(boxes_shape.rank(), 2);
  XLA_CHECK_EQ(boxes_shape.dimensions(1), COORDINATES);
  int64_t num_boxes = boxes_shape.dimensions(0);

  const xla::Shape& scores_shape = ShapeHelper::ShapeOfXlaOp(scores);
  XLA_CHECK_EQ(scores_shape.rank(), 1);
  XLA_CHECK_EQ(scores_shape.dimensions(0), num_boxes);

  // 1. Order the boxes according to their scores.
  //    Also remember the order of the boxes original indices by having an
  //    extra Iota operand.
  //
  //    The sort must be stable. Each of the COORDINATES rows is sorted
  //    independently and the comparator only looks at the scores. With an
  //    unstable sort, boxes with equal scores may be ordered arbitrarily, and
  //    nothing guarantees that every row picks the same order, which would
  //    mix the coordinates of different boxes. Stability makes ties keep
  //    their original order in every row.
  xla::XlaOp sorted = xla::Sort(
      {
          // Here, we need to broadcast both the scores and Iota operands, so
          // as to have the same dimensions as boxes:
          // {COORDINATES, num_boxes}.
          xla::Broadcast(scores, {COORDINATES}),
          xla::Broadcast(xla::Iota(builder, XLAIndexType, num_boxes),
                         {COORDINATES}),
          // Transpose boxes, so as to manipulate its values in an easier way.
          xla::Transpose(boxes, {1, 0}),
      },
      xla::CreateScalarGtComputation(
          {
              scores_shape.element_type(),
              XLAIndexType,
              boxes_shape.element_type(),
          },
          builder),
      /*dimension=*/1, /*is_stable=*/true);

  // 1.1. De-construct the returned tuple.
  //      Specifically, we only need one of the rows of the sorted index
  //      tensor, since the others were only broadcasted. The sorted scores
  //      are not needed past this point.
  //
  // Shape: [1, num_boxes]
  xla::XlaOp sorted_indices =
      xla::SliceInDim(xla::GetTupleElement(sorted, 1), /*start_index=*/0,
                      /*limit_index=*/1, /*stride=*/1, /*dimno=*/0);
  // Shape: [COORDINATES, num_boxes]
  xla::XlaOp sorted_boxes = xla::GetTupleElement(sorted, 2);

  // 1.2. Retrieve each coordinate, in their own tensor.
  //      Since we transposed boxes tensor, each row corresponds to a
  //      coordinate.
  //
  // Shape: [1, num_boxes]
  xla::XlaOp y0 = xla::SliceInDim(sorted_boxes, /*start_index=*/0,
                                  /*limit_index=*/1, /*stride=*/1, /*dimno=*/0);
  xla::XlaOp x0 = xla::SliceInDim(sorted_boxes, /*start_index=*/1,
                                  /*limit_index=*/2, /*stride=*/1, /*dimno=*/0);
  xla::XlaOp y1 = xla::SliceInDim(sorted_boxes, /*start_index=*/2,
                                  /*limit_index=*/3, /*stride=*/1, /*dimno=*/0);
  xla::XlaOp x1 = xla::SliceInDim(sorted_boxes, /*start_index=*/3,
                                  /*limit_index=*/4, /*stride=*/1, /*dimno=*/0);

  // 2. Create the IoU (Intersection over Union) ratio mask
  // 2.1. First, compute the area of each box.
  //
  // Shape: [1, num_boxes]
  xla::XlaOp area = (y1 - y0) * (x1 - x0);

  // 2.2. Get the corners of the intersection box created by every pair of
  //      boxes.
  //      Basically, given 2 boxes, what the <direction> corner of their
  //      intersection box would be?
  //
  // Shape: [num_boxes, num_boxes]
  xla::XlaOp left = xla::Max(x0, xla::Transpose(x0, {1, 0}));
  xla::XlaOp bottom = xla::Max(y0, xla::Transpose(y0, {1, 0}));
  xla::XlaOp right = xla::Min(x1, xla::Transpose(x1, {1, 0}));
  xla::XlaOp top = xla::Min(y1, xla::Transpose(y1, {1, 0}));

  // 2.3. Compute the intersection area.
  //      Whenever 2 boxes don't intersect, either their width or height will
  //      be negative.
  //
  // Shape: [num_boxes, num_boxes]
  xla::XlaOp zeros = xla::ZerosLike(left);
  xla::XlaOp intersection_area =
      xla::Max(right - left, zeros) * xla::Max(top - bottom, zeros);

  // 2.4. Compute the union area.
  //      Sum of the areas of every pair of boxes, minus their intersection
  //      area.
  //
  // Shape: [num_boxes, num_boxes]
  xla::XlaOp union_area =
      area + xla::Transpose(area, {1, 0}) - intersection_area;

  // 2.5. Compute the IoU ratio.
  //
  // Shape: [num_boxes, num_boxes]
  xla::XlaOp iou = intersection_area / union_area;

  // 2.6. Create the mask by comparing it with the given threshold.
  //
  // Shape: [num_boxes, num_boxes]
  xla::XlaOp casted_threshold =
      xla::ConvertElementType(iou_threshold, XlaHelpers::TypeOfXlaOp(iou));
  xla::XlaOp iou_threshold_mask = xla::Gt(iou, casted_threshold);

  // 3. Iteratively select the highest scoring box, and eliminate those whose
  //    IoU is greater than the threshold.
  //
  //    state: a [num_boxes] tensor, where, at the end of the loop, for
  //           each box i, state[i] represents whether box i should be
  //           included in the output or not
  //
  //    Loop Invariant: for every i in [0..current iteration], state[i]
  //                    represents whether box i is included or not in the
  //                    output.
  //
  //    Rough idea: at every iteration i, we:
  //      - Check if state[i] == true (i.e. box i should be included)
  //
  //      - If so, retrieve and negate the i-th row from the IoU mask
  //        (i.e. what are the boxes that have an IoU ratio lower-than or
  //        equal the given threshold?).
  //
  //      - Update state[i+1..] by computing a logical and operation with
  //        the retrieved negated IoU mask. Note that this won't modify
  //        state[0..i]. The next box j > i, where state[j] == true is
  //        the next box that will be included.
  std::vector<xla::XlaOp> loop_result =
      BuildBoxSelectionLoop(num_boxes, iou_threshold_mask);

  // Only the final state is needed; the loop counter and the mask that the
  // loop also carries are ignored.
  xla::XlaOp included_mask = loop_result[1];

  // 4. Retrieve the included box indices.
  // 4.1. Compute the number of included boxes.
  // 4.1.1. Transform that mask into a 0-1 tensor.
  xla::XlaOp one_if_included =
      xla::Select(included_mask, xla::Broadcast(ONE, {num_boxes}),
                  xla::Broadcast(ZERO, {num_boxes}));
  // 4.1.2. Sum it up.
  xla::XlaOp included_boxes =
      xla::Reduce(one_if_included, ZERO,
                  xla::CreateScalarAddComputation(XLAIndexType, builder), {0});

  // 4.2. Move the indices of the included boxes to the beginning.
  //      Doing so alongside the previously sorted indices gives us an index
  //      tensor with the original indices of the selected boxes at the
  //      beginning. The sort is stable, so the included boxes keep their
  //      descending-score order.
  xla::XlaOp included_indices_first = xla::GetTupleElement(
      xla::Sort(
          {
              one_if_included,
              xla::Reshape(sorted_indices, {num_boxes}),
          },
          xla::CreateScalarGtComputation({XLAIndexType, XLAIndexType}, builder),
          /*dimension=*/0, /*is_stable=*/true),
      1);

  // 4.3. Get only the first included_boxes indices.
  return xla::SetDimensionSize(included_indices_first, included_boxes, 0);
}

} // namespace torch_xla
3 changes: 3 additions & 0 deletions torch_xla/csrc/xla_lower_util.h
Original file line number Diff line number Diff line change
@@ -156,6 +156,9 @@ std::vector<xla::XlaOp> BuildTpuCustomCall(
const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
const std::string& payload);

xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores,
xla::XlaOp iou_threshold);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_
39 changes: 39 additions & 0 deletions torch_xla/csrc/xla_manual_registration.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <ATen/ATen.h>
#include <torch/library.h>

#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/ops/nms.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/tensor_methods.h"
#include "torch_xla/csrc/tensor_util.h"

namespace torch_xla {
namespace manual {
namespace {

// XLA kernel for the `nms` operator.
//
// Validates the shapes of `boxes` ([N, 4]) and `scores` ([N]), then hands both
// tensors and `iou_threshold` to tensor_methods::nms. The result is wrapped
// with skip_functionalization=true.
at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores,
                      double iou_threshold) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");

  // The checks run in this exact order: boxes.size(1) is only read once boxes
  // is known to be 2D, and the first failing check decides the error message.
  XLA_CHECK_EQ(boxes.dim(), 2) << "nms(): boxes should be a 2D tensor.";
  XLA_CHECK_EQ(boxes.size(1), 4)
      << "nms(): boxes should be a 2D tensor of shape [N, 4].";
  XLA_CHECK_EQ(scores.dim(), 1) << "nms(): scores should be a 1D tensor.";
  XLA_CHECK_EQ(boxes.size(0), scores.size(0))
      << "nms(): boxes and scores should have the same size for dimension 0.";

  const XLATensorPtr boxes_xla = bridge::GetXlaTensor(boxes);
  const XLATensorPtr scores_xla = bridge::GetXlaTensor(scores);
  XLATensorPtr kept_indices =
      tensor_methods::nms(boxes_xla, scores_xla, iou_threshold);
  return bridge::AtenFromXlaTensor(kept_indices,
                                   /*skip_functionalization=*/true);
}

} // namespace

// Registers nms_kernel as the implementation of the `torchvision::nms`
// operator for the XLA dispatch key.
TORCH_LIBRARY_IMPL(torchvision, XLA, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

} // namespace manual
} // namespace torch_xla

0 comments on commit 6cf9b91

Please sign in to comment.