diff --git a/test/test_operations.py b/test/test_operations.py index b77ca7630587..8b59a20c2f26 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -83,6 +83,11 @@ def skipIfFunctionalizationDisabled(reason): return _skipIfFunctionalization(value=True, reason=reason) +def onlyOnCUDA(fn): + accelerator = os.environ.get("PJRT_DEVICE").lower() + return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn) + + def _gen_tensor(*args, **kwargs): return torch.randn(*args, **kwargs) @@ -2280,32 +2285,6 @@ def test_send_to_device_single(self): self.assertEqual(dt[0].device, xla_device) self.assertTrue(torch.all(torch.eq(dt[0].cpu(), t))) - def test_nms(self): - BOXES = ( - (0, 0, 3, 2), - (3, 3, 11, 7), - (2, 2, 5, 7), - (7, 4, 15, 12), - ) - SCORES = (0.9, 0.5, 0.95, 0.4) - SCORE_THRESHOLD = 0.1 - IOU_THRESHOLD = 0.08 - - xla_device = xm.xla_device() - boxes = torch.tensor(BOXES, dtype=torch.float).to(xla_device) - scores = torch.tensor(SCORES, dtype=torch.float).to(xla_device) - score_threshold = torch.tensor( - SCORE_THRESHOLD, dtype=torch.float).to(xla_device) - iou_threshold = torch.tensor( - IOU_THRESHOLD, dtype=torch.float).to(xla_device) - - selected_indices, num_valid = xf.nms(boxes, scores, score_threshold, - iou_threshold, len(BOXES)) - - self.assertEqual(selected_indices, - torch.tensor([2, 0, 3, 1], dtype=torch.int32)) - self.assertEqual(num_valid.item(), 3) - def test_util_foreach_api(self): class ForTest(object): @@ -2473,6 +2452,98 @@ def test_dropout(self): f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}") +# These tests were extracted and adapted from torchvision. +# Source: vision/test/test_ops.py +class TestNMS(test_utils.XlaTestCase): + + def _reference_nms(self, boxes, scores, iou_threshold): + import torchvision + return torchvision.ops.nms(boxes.cpu(), scores.cpu(), iou_threshold) + + def _nms(self, boxes, scores, iou_threshold): + import torchvision + device = xm.xla_device() + return torchvision.ops.nms( + boxes.to(device), scores.to(device), iou_threshold).cpu() + + def _create_tensors_with_iou(self, N, iou_thresh): + # force last box to have a pre-defined iou with the first box + # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1], + # then, in order to satisfy ops.iou(b0, b1) == iou_thresh, + # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh + # Adjust the threshold upward a bit with the intent of creating + # at least one box that exceeds (barely) the threshold and so + # should be suppressed. + boxes = torch.rand(N, 4) * 100 + boxes[:, 2:] += boxes[:, :2] + boxes[-1, :] = boxes[0, :] + x0, y0, x1, y1 = boxes[-1].tolist() + iou_thresh += 1e-5 + boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh + scores = torch.rand(N) + return boxes, scores + + def test_nms_ref(self): + + def _test(iou, seed): + torch.random.manual_seed(seed) + err_msg = "NMS incompatible between CPU and reference implementation for IoU={}" + boxes, scores = self._create_tensors_with_iou(1000, iou) + keep_ref = self._reference_nms(boxes, scores, iou) + keep = self._nms(boxes, scores, iou) + self.assertEqual(keep, keep_ref, message=err_msg.format(iou)) + + for iou in (0.2, 0.5, 0.8): + for seed in range(10): + with self.subTest(iou=iou, seed=seed): + _test(iou, seed) + + def test_nms_input_errors(self): + with self.assertRaisesRegex(RuntimeError, "boxes should be a 2D tensor."): + self._nms(torch.rand(4), torch.rand(3), 0.5) + with self.assertRaisesRegex(RuntimeError, + "boxes should be a 2D tensor of shape \[N, 4\]."): + self._nms(torch.rand(3, 5), torch.rand(3), 0.5) + with self.assertRaisesRegex(RuntimeError, "scores should be a 1D tensor."): + self._nms(torch.rand(3, 4), torch.rand(3, 2), 0.5) + with self.assertRaisesRegex( + RuntimeError, + "boxes and scores should have the same size for dimension 0."): + self._nms(torch.rand(3, 4), torch.rand(4), 0.5) + + @onlyOnCUDA + def test_nms_float16(self): + boxes = torch.tensor([ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ]) + scores = torch.tensor([0.6370, 0.7569, 0.3966]) + + iou_thres = 0.2 + keep32 = self._nms(boxes, scores, iou_thres) + keep16 = self._nms( + boxes.to(torch.float16), scores.to(torch.float16), iou_thres) + assert keep32.eq(keep16).all().item(), f"{keep32} != {keep16}" + + def test_legacy(self): + BOXES = ( + (0, 0, 3, 2), + (3, 3, 11, 7), + (2, 2, 5, 7), + (7, 4, 15, 12), + ) + SCORES = (0.9, 0.5, 0.95, 0.4) + IOU_THRESHOLD = 0.08 + + def fn(boxes, scores): + return self._reference_nms(boxes, scores, IOU_THRESHOLD) + + boxes = torch.tensor(BOXES, dtype=torch.float) + scores = torch.tensor(SCORES, dtype=torch.float) + self.runAtenTest((boxes, scores), fn) + + if __name__ == '__main__': torch.set_default_dtype(torch.float32) torch.manual_seed(42)