Skip to content

Commit

Permalink
Add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi committed Mar 23, 2024
1 parent 302cd4c commit 6e94d7b
Showing 1 changed file with 97 additions and 26 deletions.
123 changes: 97 additions & 26 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def skipIfFunctionalizationDisabled(reason):
return _skipIfFunctionalization(value=True, reason=reason)


def onlyOnCUDA(fn):
accelerator = os.environ.get("PJRT_DEVICE").lower()
return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn)


def _gen_tensor(*args, **kwargs):
return torch.randn(*args, **kwargs)

Expand Down Expand Up @@ -2280,32 +2285,6 @@ def test_send_to_device_single(self):
self.assertEqual(dt[0].device, xla_device)
self.assertTrue(torch.all(torch.eq(dt[0].cpu(), t)))

def test_nms(self):
BOXES = (
(0, 0, 3, 2),
(3, 3, 11, 7),
(2, 2, 5, 7),
(7, 4, 15, 12),
)
SCORES = (0.9, 0.5, 0.95, 0.4)
SCORE_THRESHOLD = 0.1
IOU_THRESHOLD = 0.08

xla_device = xm.xla_device()
boxes = torch.tensor(BOXES, dtype=torch.float).to(xla_device)
scores = torch.tensor(SCORES, dtype=torch.float).to(xla_device)
score_threshold = torch.tensor(
SCORE_THRESHOLD, dtype=torch.float).to(xla_device)
iou_threshold = torch.tensor(
IOU_THRESHOLD, dtype=torch.float).to(xla_device)

selected_indices, num_valid = xf.nms(boxes, scores, score_threshold,
iou_threshold, len(BOXES))

self.assertEqual(selected_indices,
torch.tensor([2, 0, 3, 1], dtype=torch.int32))
self.assertEqual(num_valid.item(), 3)

def test_util_foreach_api(self):

class ForTest(object):
Expand Down Expand Up @@ -2473,6 +2452,98 @@ def test_dropout(self):
f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")


# These tests were extracted and adapted from torchvision.
# Source: vision/test/test_ops.py
class TestNMS(test_utils.XlaTestCase):

def _reference_nms(self, boxes, scores, iou_threshold):
import torchvision
return torchvision.ops.nms(boxes.cpu(), scores.cpu(), iou_threshold)

def _nms(self, boxes, scores, iou_threshold):
import torchvision
device = xm.xla_device()
return torchvision.ops.nms(
boxes.to(device), scores.to(device), iou_threshold).cpu()

def _create_tensors_with_iou(self, N, iou_thresh):
# force last box to have a pre-defined iou with the first box
# let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
# then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
# we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
# Adjust the threshold upward a bit with the intent of creating
# at least one box that exceeds (barely) the threshold and so
# should be suppressed.
boxes = torch.rand(N, 4) * 100
boxes[:, 2:] += boxes[:, :2]
boxes[-1, :] = boxes[0, :]
x0, y0, x1, y1 = boxes[-1].tolist()
iou_thresh += 1e-5
boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
scores = torch.rand(N)
return boxes, scores

def test_nms_ref(self):

def _test(iou, seed):
torch.random.manual_seed(seed)
err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
boxes, scores = self._create_tensors_with_iou(1000, iou)
keep_ref = self._reference_nms(boxes, scores, iou)
keep = self._nms(boxes, scores, iou)
self.assertEqual(keep, keep_ref, message=err_msg.format(iou))

for iou in (0.2, 0.5, 0.8):
for seed in range(10):
with self.subTest(iou=iou, seed=seed):
_test(iou, seed)

def test_nms_input_errors(self):
with self.assertRaisesRegex(RuntimeError, "boxes should be a 2D tensor."):
self._nms(torch.rand(4), torch.rand(3), 0.5)
with self.assertRaisesRegex(RuntimeError,
"boxes should be a 2D tensor of shape \[N, 4\]."):
self._nms(torch.rand(3, 5), torch.rand(3), 0.5)
with self.assertRaisesRegex(RuntimeError, "scores should be a 1D tensor."):
self._nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
with self.assertRaisesRegex(
RuntimeError,
"boxes and scores should have the same size for dimension 0."):
self._nms(torch.rand(3, 4), torch.rand(4), 0.5)

@onlyOnCUDA
def test_nms_float16(self):
boxes = torch.tensor([
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
])
scores = torch.tensor([0.6370, 0.7569, 0.3966])

iou_thres = 0.2
keep32 = self._nms(boxes, scores, iou_thres)
keep16 = self._nms(
boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
assert keep32.eq(keep16).all().item(), f"{keep32} != {keep16}"

def test_legacy(self):
BOXES = (
(0, 0, 3, 2),
(3, 3, 11, 7),
(2, 2, 5, 7),
(7, 4, 15, 12),
)
SCORES = (0.9, 0.5, 0.95, 0.4)
IOU_THRESHOLD = 0.08

def fn(boxes, scores):
return self._reference_nms(boxes, scores, IOU_THRESHOLD)

boxes = torch.tensor(BOXES, dtype=torch.float)
scores = torch.tensor(SCORES, dtype=torch.float)
self.runAtenTest((boxes, scores), fn)


if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
Expand Down

0 comments on commit 6e94d7b

Please sign in to comment.