Support NMS op lowering #3871

Merged
merged 12 commits into from
Nov 27, 2024
Collaborator

@jinchen62 jinchen62 commented Nov 13, 2024

TODO: support multiple batches and classes

@zjgarvey
Collaborator

Like we were talking about on a call, I definitely think we should tensorize some of this.

In the outer loop, we pick a single box to compare to the other boxes.
Instead of computing the IOU for each other box individually in the inner loop, let's do:

// outside all loops:
Value x1Slice /*=slice boxes to get x1 values for each box*/;
Value x2Slice ..;
Value y1Slice ..;
Value y2Slice ..;
Value xDistance = x2Slice - x1Slice;
Value yDistance = y2Slice - y1Slice;
Value boxAreas = xDistance * yDistance;
// inside loop over sorted boxes:
Value currBox /*=the box with the highest score among available boxes*/;
// The elementwise tensor arithmetic below allows broadcasting
Value innerX1 = max(x1Slice, currBox[0]);
Value innerX2 = min(x2Slice, currBox[2]);
Value innerY1 ..;
Value innerY2 ..;
// clamp at zero so boxes that do not overlap get zero intersection area
Value intersectionDistanceX = max(innerX2 - innerX1, 0);
Value intersectionDistanceY = max(innerY2 - innerY1, 0);
Value intersectionArea = intersectionDistanceX * intersectionDistanceY;
Value currArea = boxAreas[currBoxIdx];
Value unionArea = boxAreas + currArea - intersectionArea;
Value IOU = intersectionArea / unionArea;

Actually, based on the ordering of [x1,y1,x2,y2] we could even do:

// outside all loops:
lowSlice = [x1,y1] slice;
highSlice = [x2,y2] slice;
distances = highSlice - lowSlice;
area = reduceProd(distances, the dim that has size two);
// inside loop
innerLow = max(lowSlice, currBox[x1,y1]);
innerHigh = min(highSlice, currBox[x2,y2]);
innerDistance = max(innerHigh - innerLow, 0); // clamp so disjoint boxes give zero area
intersectionArea = reduceProd(innerDistance);
currArea = area[currBoxIdx];
unionArea = area + currArea - intersectionArea;
IOUs = intersectionArea / unionArea;
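
For illustration only, the second variant can be sketched in NumPy. This is not the code from the PR; the function name iou_one_vs_all and the sample boxes are made up here, and boxes are assumed to be laid out as [x1, y1, x2, y2] with positive area. Intersection extents are clamped at zero so that disjoint boxes get an IOU of 0.

```python
import numpy as np

def iou_one_vs_all(boxes, curr_idx):
    """IoU of boxes[curr_idx] against every box in `boxes` (shape (N, 4))."""
    low = boxes[:, :2]    # (N, 2): [x1, y1] of every box
    high = boxes[:, 2:]   # (N, 2): [x2, y2] of every box
    # Box areas do not depend on the current box, so a real lowering
    # could compute them once, outside the loop over sorted boxes.
    areas = np.prod(high - low, axis=1)

    curr = boxes[curr_idx]
    # Broadcasting compares the single current box with all N boxes at once.
    inner_low = np.maximum(low, curr[:2])
    inner_high = np.minimum(high, curr[2:])
    # Clamp negative extents to zero: disjoint boxes have no intersection.
    inner_dist = np.clip(inner_high - inner_low, 0.0, None)
    inter = np.prod(inner_dist, axis=1)
    union = areas + areas[curr_idx] - inter
    return inter / union

# Hypothetical sample data: box 1 overlaps box 0, box 2 is disjoint.
boxes = np.array([[0, 0, 2, 2], [1, 1, 3, 3], [4, 4, 5, 5]], dtype=float)
ious = iou_one_vs_all(boxes, 0)
print(ious)
```

Without the clamp, the disjoint pair (box 0, box 2) would multiply two negative extents into a positive "intersection", which is why the clip is there.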

That said, I'm not sure how we can skip the IOU calculations that are redundant. I'm curious whether skipping them is even worth it if it requires us to extract individual elements and do the arithmetic one at a time. At the very least, computing each of the box areas outside the loops is going to be an improvement.
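
One way to picture that trade-off is a greedy loop that never skips individual IOU computations: each kept box is compared against all boxes in one tensorized step, and already-suppressed boxes are simply masked out. The sketch below is a NumPy illustration under assumptions, not the lowering in this PR; the name nms, the strict `>` comparison against the threshold, and the sample data are all hypothetical.

```python
import numpy as np

def nms(boxes, scores, iou_threshold):
    """Greedy NMS over boxes (N, 4) as [x1, y1, x2, y2]; returns kept indices."""
    low, high = boxes[:, :2], boxes[:, 2:]
    areas = np.prod(high - low, axis=1)         # hoisted out of the loop
    order = np.argsort(-scores, kind="stable")  # highest score first
    suppressed = np.zeros(len(boxes), dtype=bool)
    keep = []
    for idx in order:
        # The only data-dependent scalar branch: skip boxes already suppressed.
        if suppressed[idx]:
            continue
        keep.append(int(idx))
        # IoU of the current box against all boxes, including ones that are
        # already suppressed; that redundant work is masked, not skipped.
        inner = np.clip(
            np.minimum(high, boxes[idx, 2:]) - np.maximum(low, boxes[idx, :2]),
            0.0, None)
        inter = np.prod(inner, axis=1)
        iou = inter / (areas + areas[idx] - inter)
        suppressed |= iou > iou_threshold
    return keep

# Hypothetical sample data: box 1 nearly coincides with box 0.
boxes = np.array([[0, 0, 2, 2], [0.1, 0.1, 2.1, 2.1], [4, 4, 5, 5]], dtype=float)
scores = np.array([0.9, 0.8, 0.7])
keep = nms(boxes, scores, iou_threshold=0.5)
print(keep)
```

The redundant arithmetic costs O(N) per kept box but stays elementwise, which is the shape of computation the comment above argues for.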

@jinchen62 jinchen62 marked this pull request as draft November 16, 2024 00:15
@jinchen62 jinchen62 force-pushed the nms_decomp branch 3 times, most recently from f2997d3 to e07ec39 Compare November 19, 2024 17:50
@jinchen62 jinchen62 changed the title Add torchvision.nms decomposition Support NMS op lowering Nov 19, 2024
@jinchen62 jinchen62 marked this pull request as ready for review November 20, 2024 00:57
@jinchen62 jinchen62 force-pushed the nms_decomp branch 5 times, most recently from bbbfe87 to e1c19ec Compare November 26, 2024 05:14
@jinchen62
Collaborator Author

jinchen62 commented Nov 26, 2024

Three e2e tests pass with #3892 and without the flag --iree-input-demote-f64-to-f32=false (iree-org/iree#19299).

  • test_nonmaxsuppression_identical_boxes
  • test_nonmaxsuppression_single_box
  • test_nonmaxsuppression_suppress_by_IOU

Collaborator

@zjgarvey zjgarvey left a comment

The implementation looks good to me, and if we can pass e2e tests, I think this is good.

My only lingering concern is related to the signature of the onnx op when one of the other optional args is missing. It might be good to add some test cases to check what happens there and make sure we won't mis-compile those examples.
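
To make the optional-argument concern concrete: in ONNX, an omitted optional input in the middle of the list is written as an empty-string name, while trailing optional inputs may be left off entirely. The sketch below shows one way to resolve a NonMaxSuppression input list into named slots. The helper name resolve_nms_inputs and the tensor names are hypothetical, and the slot order is taken from the ONNX operator documentation, not from this PR.

```python
# Input order of ONNX NonMaxSuppression; the last three are optional.
NMS_SLOTS = [
    "boxes",
    "scores",
    "max_output_boxes_per_class",
    "iou_threshold",
    "score_threshold",
]

def resolve_nms_inputs(input_names):
    """Map positional input names to slots; None marks an omitted input."""
    if len(input_names) > len(NMS_SLOTS):
        raise ValueError("too many inputs for NonMaxSuppression")
    # Trailing optional inputs may be missing altogether: pad with "".
    padded = list(input_names) + [""] * (len(NMS_SLOTS) - len(input_names))
    # An empty name means "not provided", even in the middle of the list.
    return {slot: (name or None) for slot, name in zip(NMS_SLOTS, padded)}

# Hypothetical node: max_output_boxes_per_class skipped, iou_threshold given.
resolved = resolve_nms_inputs(["boxes_t", "scores_t", "", "iou_thr_t"])
print(resolved)
```

A test case shaped like this one (a gap before iou_threshold) is the kind of input where a conversion that indexes operands positionally could pick up the wrong value.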

Another point of concern is that this whole implementation could likely be done in the onnx-to-torch lowering, which might allow support for the batched version of NMS that onnx uses (we have some unsqueezes in this conversion that cannot reasonably be assumed to succeed). I don't think we need to make that change now, but it's something to keep in mind if we end up seeing this op in models we care about with non-trivial batching.
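
For reference, the batched form that onnx uses takes boxes of shape [num_batches, num_boxes, 4] and scores of shape [num_batches, num_classes, num_boxes], and returns rows of [batch_index, class_index, box_index]. A rough NumPy sketch of that layering follows; it is an assumption-laden illustration, not this PR's code. The names nms_single and nms_batched are hypothetical, and max_output_boxes_per_class and score_threshold are left out for brevity.

```python
import numpy as np

def nms_single(boxes, scores, iou_threshold):
    """Greedy NMS for one batch and one class; returns kept box indices."""
    low, high = boxes[:, :2], boxes[:, 2:]
    areas = np.prod(high - low, axis=1)
    suppressed = np.zeros(len(boxes), dtype=bool)
    keep = []
    for idx in np.argsort(-scores, kind="stable"):
        if suppressed[idx]:
            continue
        keep.append(int(idx))
        inner = np.clip(
            np.minimum(high, boxes[idx, 2:]) - np.maximum(low, boxes[idx, :2]),
            0.0, None)
        inter = np.prod(inner, axis=1)
        suppressed |= inter / (areas + areas[idx] - inter) > iou_threshold
    return keep

def nms_batched(boxes, scores, iou_threshold):
    """boxes: (B, N, 4), scores: (B, C, N) -> int64 array of shape (K, 3)."""
    selected = []
    for b in range(boxes.shape[0]):        # batches share nothing
        for c in range(scores.shape[1]):   # classes share boxes, not scores
            for i in nms_single(boxes[b], scores[b, c], iou_threshold):
                selected.append([b, c, i])
    return np.array(selected, dtype=np.int64).reshape(-1, 3)

# Hypothetical data: one batch, two classes, three boxes.
boxes = np.array([[[0, 0, 2, 2], [0.1, 0.1, 2.1, 2.1], [4, 4, 5, 5]]], dtype=float)
scores = np.array([[[0.9, 0.8, 0.7], [0.1, 0.9, 0.5]]])
result = nms_batched(boxes, scores, iou_threshold=0.5)
print(result)
```

The single-batch, single-class case handled in this PR corresponds to B = 1 and C = 1, which is where the unsqueezes mentioned above come from.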

@jinchen62 jinchen62 merged commit c9ed993 into llvm:main Nov 27, 2024
3 checks passed
@jinchen62 jinchen62 deleted the nms_decomp branch November 27, 2024 00:49

2 participants