From 92a6e00cdda0996ceb328afe6f78ac472043c6cb Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 6 Feb 2024 21:42:33 +0000 Subject: [PATCH] add missing tests to ci scripts --- test/run_tests.sh | 12 +++++------- test/stablehlo/test_pt2e_qdq.py | 1 + test/stablehlo/test_stablehlo_compile.py | 11 +++++++---- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/test/run_tests.sh b/test/run_tests.sh index 84debb84b25..be4f948d27b 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -113,11 +113,6 @@ function run_pt_xla_debug { PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@" } -function run_stablehlo_compile { - echo "Running in StableHlo Compile mode: $@" - XLA_STABLEHLO_COMPILE=1 run_test "$@" -} - function run_xla_backend_mp { echo "Running XLA backend multiprocessing test: $@" MASTER_ADDR=localhost MASTER_PORT=6000 run_test "$@" @@ -200,8 +195,11 @@ function run_xla_op_tests3 { # TODO(qihqi): this test require tensorflow to run. need to setup separate # CI with tf. run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_inference.py" - run_stablehlo_compile "$CDIR/stablehlo/test_stablehlo_compile.py" - run_stablehlo_compile "$CDIR/stablehlo/test_implicit_broadcasting.py" + run_test "$CDIR/stablehlo/test_stablehlo_compile.py" + run_test "$CDIR/stablehlo/test_implicit_broadcasting.py" + run_test "$CDIR/stablehlo/test_unbounded_dynamism.py" + run_test "$CDIR/stablehlo/test_mark_pattern.py" + run_test "$CDIR/stablehlo/test_pt2e_qdq.py" run_test "$CDIR/spmd/test_xla_sharding.py" run_test "$CDIR/spmd/test_xla_sharding_hlo.py" run_test "$CDIR/spmd/test_xla_virtual_device.py" diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py index c901f3684b1..a7d2cfbbe53 100644 --- a/test/stablehlo/test_pt2e_qdq.py +++ b/test/stablehlo/test_pt2e_qdq.py @@ -82,6 +82,7 @@ def test_per_channel_qdq(self): self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1) self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1) + @unittest.skip("Failed because PT2E BC break change on constant folding.") def test_resnet18(self): # Step 1: export resnet18 args = (torch.randn(1, 3, 224, 224),) diff --git a/test/stablehlo/test_stablehlo_compile.py b/test/stablehlo/test_stablehlo_compile.py index d406f733c39..9d38fb9efbb 100644 --- a/test/stablehlo/test_stablehlo_compile.py +++ b/test/stablehlo/test_stablehlo_compile.py @@ -1,12 +1,15 @@ +import os +import unittest + +import numpy as np +import torch import torch_xla import torch_xla.core.xla_model as xm -import torch -import torchvision -import unittest import torch_xla.debug.metrics as met import torch_xla.debug.metrics_compare_utils as mcu -import numpy as np +import torchvision +os.environ['XLA_STABLEHLO_COMPILE'] = '1' class StableHloCompileTest(unittest.TestCase):