Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make nms fallback by default. #6933

Merged
merged 4 commits into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/cpp/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ BUILDTYPE="opt"
VERB=
FILTER=
LOGFILE=/tmp/pytorch_cpp_test.log
XLA_EXPERIMENTAL="nonzero:masked_select"
XLA_EXPERIMENTAL="nonzero:masked_select:nms"
BAZEL_REMOTE_CACHE="0"
BAZEL_VERB="test"

Expand Down
2 changes: 1 addition & 1 deletion test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function run_xla_hlo_debug {

function run_dynamic {
echo "Running in DynamicShape mode: $@"
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" run_test "$@"
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter:nms" run_test "$@"
}

function run_eager_debug {
Expand Down
7 changes: 7 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ def onlyOnCUDA(fn):
return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn)


def onlyIfXLAExperimentalContains(feat):
experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":")
return unittest.skipIf(feat not in experimental,
f"XLA_EXPERIMENTAL={feat} required")


def _gen_tensor(*args, **kwargs):
return torch.randn(*args, **kwargs)

Expand Down Expand Up @@ -2454,6 +2460,7 @@ def test_dropout(self):

# These tests were extracted and adapted from torchvision.
# Source: vision/test/test_ops.py
@onlyIfXLAExperimentalContains("nms")
class TestNMS(test_utils.XlaTestCase):

def _reference_nms(self, boxes, scores, iou_threshold):
Expand Down
4 changes: 2 additions & 2 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ python3 test/spmd/test_xla_distributed_checkpoint.py
python3 test/spmd/test_train_spmd_linear_model.py
python3 test/spmd/test_xla_spmd_python_api_interaction.py
python3 test/spmd/test_xla_auto_sharding.py
XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shape_models.py -v
XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shapes.py -v
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
python3 test/test_autocast.py
python3 test/dynamo/test_dynamo.py
python3 test/spmd/test_spmd_debugging.py
Expand Down
14 changes: 14 additions & 0 deletions torch_xla/csrc/xla_manual_registration.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include <ATen/ATen.h>
#include <torch/library.h>

#include "torch_xla/csrc/aten_cpu_fallback.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/debug_util.h"
#include "torch_xla/csrc/ops/nms.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/tensor_methods.h"
Expand All @@ -11,10 +13,22 @@ namespace torch_xla {
namespace manual {
namespace {

struct NmsOp {
using schema = at::Tensor(const at::Tensor&, const at::Tensor&, double);
using ptr_schema = schema*;
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "torchvision::nms")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
};

at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores,
double iou_threshold) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");

if (!DebugUtil::ExperimentEnabled("nms")) {
return at::native::call_fallback_fn<&xla_cpu_fallback, NmsOp>::call(
boxes, scores, iou_threshold);
}

XLA_CHECK_EQ(boxes.dim(), 2) << "nms(): boxes should be a 2D tensor.";
XLA_CHECK_EQ(boxes.size(1), 4)
<< "nms(): boxes should be a 2D tensor of shape [N, 4].";
Expand Down
Loading