Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor nms into TorchVision variant. #6814

Merged
merged 7 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ xla_model
.. automodule:: torch_xla.core.functions
.. autofunction:: all_reduce
.. autofunction:: all_gather
.. autofunction:: nms

distributed
----------------------------------
Expand Down
109 changes: 83 additions & 26 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def skipIfFunctionalizationDisabled(reason):
return _skipIfFunctionalization(value=True, reason=reason)


def onlyOnCUDA(fn):
  """Skips the decorated test unless `PJRT_DEVICE` selects CUDA.

  The device name is read from the `PJRT_DEVICE` environment variable and
  compared case-insensitively against `"cuda"`.

  Args:
    fn: the test function (or class) to decorate.

  Returns:
    `fn` wrapped by `unittest.skipIf`, skipped when the device is not CUDA.
  """
  # Default to "" so that an unset PJRT_DEVICE skips the test instead of
  # raising AttributeError (None has no .lower()) at decoration time.
  accelerator = os.environ.get("PJRT_DEVICE", "").lower()
  return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn)


def _gen_tensor(*args, **kwargs):
  """Returns `torch.randn(*args, **kwargs)`; all arguments are forwarded unchanged."""
  return torch.randn(*args, **kwargs)

Expand Down Expand Up @@ -2280,32 +2285,6 @@ def test_send_to_device_single(self):
self.assertEqual(dt[0].device, xla_device)
self.assertTrue(torch.all(torch.eq(dt[0].cpu(), t)))

def test_nms(self):
  """Checks `xf.nms` on a fixed set of four boxes placed on the XLA device."""
  BOXES = (
      (0, 0, 3, 2),
      (3, 3, 11, 7),
      (2, 2, 5, 7),
      (7, 4, 15, 12),
  )
  SCORES = (0.9, 0.5, 0.95, 0.4)
  SCORE_THRESHOLD = 0.1
  IOU_THRESHOLD = 0.08

  # Boxes, scores and both thresholds are passed as float tensors on the
  # XLA device.
  xla_device = xm.xla_device()
  boxes = torch.tensor(BOXES, dtype=torch.float).to(xla_device)
  scores = torch.tensor(SCORES, dtype=torch.float).to(xla_device)
  score_threshold = torch.tensor(
      SCORE_THRESHOLD, dtype=torch.float).to(xla_device)
  iou_threshold = torch.tensor(
      IOU_THRESHOLD, dtype=torch.float).to(xla_device)

  # The output size is the number of input boxes.
  selected_indices, num_valid = xf.nms(boxes, scores, score_threshold,
                                       iou_threshold, len(BOXES))

  # NOTE(review): four indices are returned but num_valid is 3; presumably
  # only the first num_valid entries are kept boxes -- confirm.
  self.assertEqual(selected_indices,
                   torch.tensor([2, 0, 3, 1], dtype=torch.int32))
  self.assertEqual(num_valid.item(), 3)

def test_util_foreach_api(self):

class ForTest(object):
Expand Down Expand Up @@ -2473,6 +2452,84 @@ def test_dropout(self):
f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")


# These tests were extracted and adapted from torchvision.
# Source: vision/test/test_ops.py
class TestNMS(test_utils.XlaTestCase):

def _reference_nms(self, boxes, scores, iou_threshold):
  """Runs `torchvision.ops.nms` on CPU copies of `boxes` and `scores`."""
  # Imported lazily, inside the helper, rather than at module level.
  import torchvision
  return torchvision.ops.nms(boxes.cpu(), scores.cpu(), iou_threshold)

def _nms(self, boxes, scores, iou_threshold):
  """Runs `torchvision.ops.nms` on the XLA device and returns the result on CPU."""
  # Imported lazily, inside the helper, rather than at module level.
  import torchvision
  device = xm.xla_device()
  # Inputs are moved to the XLA device; the result is brought back to CPU.
  return torchvision.ops.nms(
      boxes.to(device), scores.to(device), iou_threshold).cpu()

def _create_tensors_with_iou(self, N, iou_thresh):
# force last box to have a pre-defined iou with the first box
# let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
# then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
# we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
# Adjust the threshold upward a bit with the intent of creating
# at least one box that exceeds (barely) the threshold and so
# should be suppressed.
boxes = torch.rand(N, 4) * 100
boxes[:, 2:] += boxes[:, :2]
boxes[-1, :] = boxes[0, :]
x0, y0, x1, y1 = boxes[-1].tolist()
iou_thresh += 1e-5
boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
scores = torch.rand(N)
return boxes, scores

@skipOnEagerDebug
def test_nms_ref(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ysiraichi can you please add a dynamic shape input test scenario for nms? Because this op is dynamic, it's been falling back to CPU. Of course, there is a lot of interest to bring the op to the xla device; though, this requires correct functionality for dynamism.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To clarify: the number of boxes is currently fixed at 1000; we want that number to be dynamic.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be clear, the output of this test appears to be dynamic, but the number of input boxes can also be dynamic for nms. This test currently covers only one dynamism scenario (i.e., output dynamism).


def _test(iou, seed):
  # Seed first so each (iou, seed) pair generates the same boxes/scores.
  torch.random.manual_seed(seed)
  # `keep` is computed on the XLA device (self._nms) and `keep_ref` by
  # torchvision on CPU (self._reference_nms), so the message names XLA.
  err_msg = "NMS incompatible between XLA and reference implementation for IoU={}"
  boxes, scores = self._create_tensors_with_iou(1000, iou)
  keep_ref = self._reference_nms(boxes, scores, iou)
  keep = self._nms(boxes, scores, iou)
  self.assertEqual(keep, keep_ref, message=err_msg.format(iou))

for iou in (0.2, 0.5, 0.8):
for seed in range(10):
with self.subTest(iou=iou, seed=seed):
_test(iou, seed)

def test_nms_input_errors(self):
  """Checks that malformed `boxes`/`scores` inputs raise `RuntimeError`."""
  # boxes is 1D instead of 2D.
  with self.assertRaisesRegex(RuntimeError, "boxes should be a 2D tensor."):
    self._nms(torch.rand(4), torch.rand(3), 0.5)
  # boxes has 5 columns instead of 4. The pattern is a raw string so that
  # `\[` and `\]` reach the regex engine as escaped brackets instead of being
  # treated as (invalid) string-literal escape sequences.
  with self.assertRaisesRegex(
      RuntimeError, r"boxes should be a 2D tensor of shape \[N, 4\]."):
    self._nms(torch.rand(3, 5), torch.rand(3), 0.5)
  # scores is 2D instead of 1D.
  with self.assertRaisesRegex(RuntimeError, "scores should be a 1D tensor."):
    self._nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
  # boxes and scores disagree on the number of entries.
  with self.assertRaisesRegex(
      RuntimeError,
      "boxes and scores should have the same size for dimension 0."):
    self._nms(torch.rand(3, 4), torch.rand(4), 0.5)

def test_legacy(self):
  """Runs NMS on a fixed four-box case through `runAtenTest`."""
  import torchvision

  BOXES = (
      (0, 0, 3, 2),
      (3, 3, 11, 7),
      (2, 2, 5, 7),
      (7, 4, 15, 12),
  )
  SCORES = (0.9, 0.5, 0.95, 0.4)
  IOU_THRESHOLD = 0.08

  def fn(boxes, scores):
    # Run NMS on whatever device the inputs live on. Going through
    # self._reference_nms here would move the inputs to CPU first, so the
    # XLA path would never be exercised.
    # NOTE(review): assumes runAtenTest calls `fn` once with the CPU tensors
    # and once with XLA copies, then compares the results -- confirm.
    return torchvision.ops.nms(boxes, scores, IOU_THRESHOLD)

  boxes = torch.tensor(BOXES, dtype=torch.float)
  scores = torch.tensor(SCORES, dtype=torch.float)
  self.runAtenTest((boxes, scores), fn)


if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
Expand Down
23 changes: 0 additions & 23 deletions torch_xla/core/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,29 +84,6 @@ def all_gather(value, dim=0):
return AllGather.apply(value, dim)


def nms(boxes, scores, score_threshold, iou_threshold, output_size):
  """Performs a Non Maximal Suppression operation.

  Args:
    boxes (torch.Tensor): A `torch.Tensor` of shape `[N, 4]` listing the boxes
      coordinates in `(y0, x0, y1, x1)` form.
    scores (torch.Tensor): A `torch.Tensor` of shape `[N]` listing the scores
      of each box.
    score_threshold (torch.Tensor): The minimum score for a box to qualify as
      valid.
    iou_threshold (torch.Tensor): The minimum IOU (Intersection Over Union)
      score to trigger overlap logic.
    output_size (int): The maximum number of returned indices (must be lower or
      equal to N).

  Returns:
    A tuple of `torch.Tensor` with the first element being the selected box
    indices, and the second element being the number of valid boxes.
  """
  # Thin wrapper: every argument is forwarded unchanged to the `_xla_nms`
  # binding exposed by the `torch_xla._XLAC` extension module.
  return torch_xla._XLAC._xla_nms(boxes, scores, score_threshold, iou_threshold,
                                  output_size)


def distributed_mm(w, x, split=1):
"""Performs a matrix multiplication with sharded weight.

Expand Down
3 changes: 1 addition & 2 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ ptxla_cc_library(
"ir_dump_util.cpp",
"matrix.cpp",
"nll_loss.cpp",
"nms_op.cpp",
"pooling.cpp",
"quant_util.cpp",
"random.cpp",
Expand All @@ -67,6 +66,7 @@ ptxla_cc_library(
"xla_lower_util.cpp",
"xla_op_builder.cpp",
"xla_sharding_util.cpp",
"xla_manual_registration.cpp",
":RegisterAutogradXLA.cpp",
":RegisterXLA.cpp",
":XLANativeFunctions.cpp",
Expand All @@ -87,7 +87,6 @@ ptxla_cc_library(
"ir_dump_util.h",
"matrix.h",
"nll_loss.h",
"nms_op.h",
"pooling.h",
"quant_util.h",
"random.h",
Expand Down
6 changes: 4 additions & 2 deletions torch_xla/csrc/aten_xla_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,12 +400,14 @@ at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor,
return tensor.to(tensor_options, /*non_blocking=*/false, /*copy=*/true);
}

at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor) {
at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor,
bool skip_functionalization) {
if (xla_tensor) {
auto out =
at::Tensor(c10::make_intrusive<XLATensorImpl>(std::move(xla_tensor)));
// See Note [Lazy Tensor Functionalization]
if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
if (skip_functionalization ||
c10::impl::tls_local_dispatch_key_set().excluded_.has(
c10::DispatchKey::Functionalize)) {
// Invariant: if the functionalization key is in the exclude set, then
// we're expected to return an ordinary tensor, which will be "lifted"
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/aten_xla_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ at::Tensor XlaToAtenTensor(XLATensorPtr xla_tensor,
const at::TensorOptions& tensor_options);

// Creates an ATen tensor with XLA type id from an XLATensorPtr.
at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor);
at::Tensor AtenFromXlaTensor(XLATensorPtr xla_tensor,
bool skip_functionalization = false);

std::vector<at::Tensor> AtenFromXlaTensors(
absl::Span<const XLATensorPtr> xla_tensors);
Expand Down
27 changes: 0 additions & 27 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,28 +666,6 @@ py::object GetRevisions() {
return py_dict;
}

// Runs tensor_methods::nms on the XLA tensors backing the given ATen tensors
// and returns a Python 2-tuple (selected_indices, num_valid). Both outputs
// are wrapped as variables with requires_grad=false.
py::object XlaNms(const at::Tensor& boxes, const at::Tensor& scores,
                  const at::Tensor& score_threshold,
                  const at::Tensor& iou_threshold, int64_t output_size) {
  at::Tensor selected_indices;
  at::Tensor num_valid;
  {
    // NOTE(review): NoGilSection presumably releases the Python GIL for the
    // duration of this scope -- confirm against its definition.
    NoGilSection nogil;
    auto nms_result = tensor_methods::nms(
        bridge::GetXlaTensor(boxes), bridge::GetXlaTensor(scores),
        bridge::GetXlaTensor(score_threshold),
        bridge::GetXlaTensor(iou_threshold), output_size);
    // The pair holds (selected indices, number of valid boxes); convert each
    // XLA tensor back into an ATen tensor.
    selected_indices = bridge::AtenFromXlaTensor(std::move(nms_result.first));
    num_valid = bridge::AtenFromXlaTensor(std::move(nms_result.second));
  }
  // The Python tuple is built after the nogil scope has ended.
  auto result_tuple = py::tuple(2);
  result_tuple[0] =
      torch::autograd::make_variable(selected_indices, /*requires_grad=*/false);
  result_tuple[1] =
      torch::autograd::make_variable(num_valid, /*requires_grad=*/false);
  return result_tuple;
}

std::vector<at::Tensor> XlaUserComputation(
const std::string& opname, const std::vector<at::Tensor>& inputs,
runtime::ComputationClient::ComputationPtr computation) {
Expand Down Expand Up @@ -1086,11 +1064,6 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& tensor, int dim) {
return GetXlaTensorDimensionSize(tensor, dim);
});
m.def("_xla_nms", [](const at::Tensor& boxes, const at::Tensor& scores,
const at::Tensor& score_threshold,
const at::Tensor& iou_threshold, int64_t output_size) {
return XlaNms(boxes, scores, score_threshold, iou_threshold, output_size);
});
m.def("_xla_user_computation",
[](const std::string& opname, const std::vector<at::Tensor>& inputs,
const runtime::ComputationClient::ComputationPtr& computation) {
Expand Down
Loading
Loading