-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor
nms
into TorchVision variant. (#6814)
Showing
16 changed files
with
398 additions
and
458 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,65 +1,46 @@ | ||
#include "torch_xla/csrc/ops/nms.h" | ||
|
||
#include "torch_xla/csrc/lowering_context.h" | ||
#include "torch_xla/csrc/nms_op.h" | ||
#include "torch_xla/csrc/ops/infer_output_shape.h" | ||
#include "torch_xla/csrc/ops/xla_ops.h" | ||
#include "torch_xla/csrc/runtime/debug_macros.h" | ||
#include "torch_xla/csrc/xla_lower_util.h" | ||
|
||
namespace torch_xla { | ||
namespace { | ||
|
||
xla::Shape NodeOutputShape(const torch::lazy::Value& boxes, | ||
const torch::lazy::Value& scores, | ||
const torch::lazy::Value& score_threshold, | ||
const torch::lazy::Value& iou_threshold, | ||
int64_t output_size) { | ||
const torch::lazy::Value& iou_threshold) { | ||
auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp { | ||
NmsResult result = BuildNms(operands[0], operands[1], operands[2], | ||
operands[3], output_size); | ||
return xla::Tuple(result.selected_indices.builder(), | ||
{result.selected_indices, result.num_valid}); | ||
return BuildNms(operands[0], operands[1], operands[2]); | ||
}; | ||
|
||
return InferOutputShape( | ||
{GetXlaShape(boxes), GetXlaShape(scores), GetXlaShape(score_threshold), | ||
GetXlaShape(iou_threshold)}, | ||
{GetXlaShape(boxes), GetXlaShape(scores), GetXlaShape(iou_threshold)}, | ||
shape_fn); | ||
} | ||
|
||
} // namespace | ||
|
||
Nms::Nms(const torch::lazy::Value& boxes, const torch::lazy::Value& scores, | ||
const torch::lazy::Value& score_threshold, | ||
const torch::lazy::Value& iou_threshold, int64_t output_size) | ||
const torch::lazy::Value& iou_threshold) | ||
: XlaNode( | ||
xla_nms, {boxes, scores, score_threshold, iou_threshold}, | ||
[&]() { | ||
return NodeOutputShape(boxes, scores, score_threshold, | ||
iou_threshold, output_size); | ||
}, | ||
/*num_outputs=*/2, torch::lazy::MHash(output_size)), | ||
output_size_(output_size) {} | ||
/*op=*/xla_nms, | ||
/*operands=*/{boxes, scores, iou_threshold}, | ||
/*xla_shape_fn=*/ | ||
[&]() { return NodeOutputShape(boxes, scores, iou_threshold); }) {} | ||
|
||
torch::lazy::NodePtr Nms::Clone(torch::lazy::OpList operands) const { | ||
return torch::lazy::MakeNode<Nms>(operands.at(0), operands.at(1), | ||
operands.at(2), operands.at(3), | ||
output_size_); | ||
operands.at(2)); | ||
} | ||
|
||
XlaOpVector Nms::Lower(LoweringContext* loctx) const { | ||
xla::XlaOp boxes = loctx->GetOutputOp(operand(0)); | ||
xla::XlaOp scores = loctx->GetOutputOp(operand(1)); | ||
xla::XlaOp score_threshold = loctx->GetOutputOp(operand(2)); | ||
xla::XlaOp iou_threshold = loctx->GetOutputOp(operand(3)); | ||
NmsResult result = | ||
BuildNms(boxes, scores, score_threshold, iou_threshold, output_size_); | ||
return ReturnOps({result.selected_indices, result.num_valid}, loctx); | ||
} | ||
|
||
std::string Nms::ToString() const { | ||
std::stringstream ss; | ||
ss << XlaNode::ToString() << ", output_size=" << output_size_; | ||
return ss.str(); | ||
xla::XlaOp iou_threshold = loctx->GetOutputOp(operand(2)); | ||
return ReturnOp(BuildNms(boxes, scores, iou_threshold), loctx); | ||
} | ||
|
||
} // namespace torch_xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#include <ATen/ATen.h> | ||
#include <torch/library.h> | ||
|
||
#include "torch_xla/csrc/aten_xla_bridge.h" | ||
#include "torch_xla/csrc/ops/nms.h" | ||
#include "torch_xla/csrc/ops/ops.h" | ||
#include "torch_xla/csrc/tensor_methods.h" | ||
#include "torch_xla/csrc/tensor_util.h" | ||
|
||
namespace torch_xla { | ||
namespace manual { | ||
namespace { | ||
|
||
at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores, | ||
double iou_threshold) { | ||
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); | ||
|
||
XLA_CHECK_EQ(boxes.dim(), 2) << "nms(): boxes should be a 2D tensor."; | ||
XLA_CHECK_EQ(boxes.size(1), 4) | ||
<< "nms(): boxes should be a 2D tensor of shape [N, 4]."; | ||
XLA_CHECK_EQ(scores.dim(), 1) << "nms(): scores should be a 1D tensor."; | ||
XLA_CHECK_EQ(boxes.size(0), scores.size(0)) | ||
<< "nms(): boxes and scores should have the same size for dimension 0."; | ||
|
||
XLATensorPtr xla_boxes = bridge::GetXlaTensor(boxes); | ||
XLATensorPtr xla_scores = bridge::GetXlaTensor(scores); | ||
return bridge::AtenFromXlaTensor( | ||
tensor_methods::nms(xla_boxes, xla_scores, iou_threshold), | ||
/*skip_functionalization=*/true); | ||
} | ||
|
||
} // namespace | ||
|
||
TORCH_LIBRARY_IMPL(torchvision, XLA, m) { | ||
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); | ||
} | ||
|
||
} // namespace manual | ||
} // namespace torch_xla |