Skip to content

Commit

Permalink
Refactor nms into TorchVision variant. (#6814)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored Apr 1, 2024
1 parent b0ba29f commit 6cf9b91
Show file tree
Hide file tree
Showing 16 changed files with 398 additions and 458 deletions.
1 change: 0 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ xla_model
.. automodule:: torch_xla.core.functions
.. autofunction:: all_reduce
.. autofunction:: all_gather
.. autofunction:: nms

distributed
----------------------------------
Expand Down
109 changes: 83 additions & 26 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def skipIfFunctionalizationDisabled(reason):
return _skipIfFunctionalization(value=True, reason=reason)


def onlyOnCUDA(fn):
accelerator = os.environ.get("PJRT_DEVICE").lower()
return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn)


def _gen_tensor(*args, **kwargs):
return torch.randn(*args, **kwargs)

Expand Down Expand Up @@ -2280,32 +2285,6 @@ def test_send_to_device_single(self):
self.assertEqual(dt[0].device, xla_device)
self.assertTrue(torch.all(torch.eq(dt[0].cpu(), t)))

def test_nms(self):
BOXES = (
(0, 0, 3, 2),
(3, 3, 11, 7),
(2, 2, 5, 7),
(7, 4, 15, 12),
)
SCORES = (0.9, 0.5, 0.95, 0.4)
SCORE_THRESHOLD = 0.1
IOU_THRESHOLD = 0.08

xla_device = xm.xla_device()
boxes = torch.tensor(BOXES, dtype=torch.float).to(xla_device)
scores = torch.tensor(SCORES, dtype=torch.float).to(xla_device)
score_threshold = torch.tensor(
SCORE_THRESHOLD, dtype=torch.float).to(xla_device)
iou_threshold = torch.tensor(
IOU_THRESHOLD, dtype=torch.float).to(xla_device)

selected_indices, num_valid = xf.nms(boxes, scores, score_threshold,
iou_threshold, len(BOXES))

self.assertEqual(selected_indices,
torch.tensor([2, 0, 3, 1], dtype=torch.int32))
self.assertEqual(num_valid.item(), 3)

def test_util_foreach_api(self):

class ForTest(object):
Expand Down Expand Up @@ -2473,6 +2452,84 @@ def test_dropout(self):
f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")


# These tests were extracted and adapted from torchvision.
# Source: vision/test/test_ops.py
class TestNMS(test_utils.XlaTestCase):

def _reference_nms(self, boxes, scores, iou_threshold):
import torchvision
return torchvision.ops.nms(boxes.cpu(), scores.cpu(), iou_threshold)

def _nms(self, boxes, scores, iou_threshold):
import torchvision
device = xm.xla_device()
return torchvision.ops.nms(
boxes.to(device), scores.to(device), iou_threshold).cpu()

def _create_tensors_with_iou(self, N, iou_thresh):
# force last box to have a pre-defined iou with the first box
# let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
# then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
# we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
# Adjust the threshold upward a bit with the intent of creating
# at least one box that exceeds (barely) the threshold and so
# should be suppressed.
boxes = torch.rand(N, 4) * 100
boxes[:, 2:] += boxes[:, :2]
boxes[-1, :] = boxes[0, :]
x0, y0, x1, y1 = boxes[-1].tolist()
iou_thresh += 1e-5
boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
scores = torch.rand(N)
return boxes, scores

@skipOnEagerDebug
def test_nms_ref(self):

def _test(iou, seed):
torch.random.manual_seed(seed)
err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
boxes, scores = self._create_tensors_with_iou(1000, iou)
keep_ref = self._reference_nms(boxes, scores, iou)
keep = self._nms(boxes, scores, iou)
self.assertEqual(keep, keep_ref, message=err_msg.format(iou))

for iou in (0.2, 0.5, 0.8):
for seed in range(10):
with self.subTest(iou=iou, seed=seed):
_test(iou, seed)

def test_nms_input_errors(self):
with self.assertRaisesRegex(RuntimeError, "boxes should be a 2D tensor."):
self._nms(torch.rand(4), torch.rand(3), 0.5)
with self.assertRaisesRegex(
RuntimeError, "boxes should be a 2D tensor of shape \[N, 4\]."):
self._nms(torch.rand(3, 5), torch.rand(3), 0.5)
with self.assertRaisesRegex(RuntimeError, "scores should be a 1D tensor."):
self._nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
with self.assertRaisesRegex(
RuntimeError,
"boxes and scores should have the same size for dimension 0."):
self._nms(torch.rand(3, 4), torch.rand(4), 0.5)

def test_legacy(self):
BOXES = (
(0, 0, 3, 2),
(3, 3, 11, 7),
(2, 2, 5, 7),
(7, 4, 15, 12),
)
SCORES = (0.9, 0.5, 0.95, 0.4)
IOU_THRESHOLD = 0.08

def fn(boxes, scores):
return self._reference_nms(boxes, scores, IOU_THRESHOLD)

boxes = torch.tensor(BOXES, dtype=torch.float)
scores = torch.tensor(SCORES, dtype=torch.float)
self.runAtenTest((boxes, scores), fn)


if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
Expand Down
23 changes: 0 additions & 23 deletions torch_xla/core/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,29 +84,6 @@ def all_gather(value, dim=0):
return AllGather.apply(value, dim)


def nms(boxes, scores, score_threshold, iou_threshold, output_size):
"""Performs a Non Maximal Suppression operation.
Args:
boxes (torch.Tensor): A `torch.Tensor` of shape `[N, 4]` listing the boxes
coordinates in `(y0, x0, y1, x1)` form.
scores (torch.Tensor): A `torch.Tensor` of shape `[N]` listing the scores
of each box.
score_threshold (torch.Tensor): The minimum score for a box to qualify as
valid.
iou_threshold (torch.Tensor): The minimum IOU (Intersection Over Union)
score to trigger overlap logic.
output_size (int): The maximum number of returned indices (must be lower or
equal to N).
Returns:
A tuple of `torch.Tensor` with the first element being the selected box
indices, and the second element being the number of valid boxes.
"""
return torch_xla._XLAC._xla_nms(boxes, scores, score_threshold, iou_threshold,
output_size)


def distributed_mm(w, x, split=1):
"""Performs a matrix multiplication with sharded weight.
Expand Down
3 changes: 1 addition & 2 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ ptxla_cc_library(
"ir_dump_util.cpp",
"matrix.cpp",
"nll_loss.cpp",
"nms_op.cpp",
"pooling.cpp",
"quant_util.cpp",
"random.cpp",
Expand All @@ -67,6 +66,7 @@ ptxla_cc_library(
"xla_lower_util.cpp",
"xla_op_builder.cpp",
"xla_sharding_util.cpp",
"xla_manual_registration.cpp",
":RegisterAutogradXLA.cpp",
":RegisterXLA.cpp",
":XLANativeFunctions.cpp",
Expand All @@ -87,7 +87,6 @@ ptxla_cc_library(
"ir_dump_util.h",
"matrix.h",
"nll_loss.h",
"nms_op.h",
"pooling.h",
"quant_util.h",
"random.h",
Expand Down
6 changes: 4 additions & 2 deletions torch_xla/csrc/aten_xla_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,12 +400,14 @@ at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor,
return tensor.to(tensor_options, /*non_blocking=*/false, /*copy=*/true);
}

at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor) {
at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor,
bool skip_functionalization) {
if (xla_tensor) {
auto out =
at::Tensor(c10::make_intrusive<XLATensorImpl>(std::move(xla_tensor)));
// See Note [Lazy Tensor Functionalization]
if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
if (skip_functionalization ||
c10::impl::tls_local_dispatch_key_set().excluded_.has(
c10::DispatchKey::Functionalize)) {
// Invariant: if the functionalization key is in the exclude set, then
// we're expected to return an ordinary tensor, which will be "lifted"
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/aten_xla_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor,
const at::TensorOptions& tensor_options);

// Creates an ATen tensor with XLA type id from an XLATensorPtr.
at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor);
at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor,
bool skip_functionalization = false);

std::vector<at::Tensor> AtenFromXlaTensors(
absl::Span<const XLATensorPtr> xla_tensors);
Expand Down
27 changes: 0 additions & 27 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,28 +666,6 @@ py::object GetRevisions() {
return py_dict;
}

py::object XlaNms(const at::Tensor& boxes, const at::Tensor& scores,
const at::Tensor& score_threshold,
const at::Tensor& iou_threshold, int64_t output_size) {
at::Tensor selected_indices;
at::Tensor num_valid;
{
NoGilSection nogil;
auto nms_result = tensor_methods::nms(
bridge::GetXlaTensor(boxes), bridge::GetXlaTensor(scores),
bridge::GetXlaTensor(score_threshold),
bridge::GetXlaTensor(iou_threshold), output_size);
selected_indices = bridge::AtenFromXlaTensor(std::move(nms_result.first));
num_valid = bridge::AtenFromXlaTensor(std::move(nms_result.second));
}
auto result_tuple = py::tuple(2);
result_tuple[0] =
torch::autograd::make_variable(selected_indices, /*requires_grad=*/false);
result_tuple[1] =
torch::autograd::make_variable(num_valid, /*requires_grad=*/false);
return result_tuple;
}

std::vector<at::Tensor> XlaUserComputation(
const std::string& opname, const std::vector<at::Tensor>& inputs,
runtime::ComputationClient::ComputationPtr computation) {
Expand Down Expand Up @@ -1086,11 +1064,6 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& tensor, int dim) {
return GetXlaTensorDimensionSize(tensor, dim);
});
m.def("_xla_nms", [](const at::Tensor& boxes, const at::Tensor& scores,
const at::Tensor& score_threshold,
const at::Tensor& iou_threshold, int64_t output_size) {
return XlaNms(boxes, scores, score_threshold, iou_threshold, output_size);
});
m.def("_xla_user_computation",
[](const std::string& opname, const std::vector<at::Tensor>& inputs,
const runtime::ComputationClient::ComputationPtr& computation) {
Expand Down
Loading

0 comments on commit 6cf9b91

Please sign in to comment.