Skip to content

Commit

Permalink
add missing tests to ci scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
Siyuan Liu committed Feb 6, 2024
1 parent ff09ccb commit 92a6e00
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
12 changes: 5 additions & 7 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,6 @@ function run_pt_xla_debug {
PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
}

function run_stablehlo_compile {
echo "Running in StableHlo Compile mode: $@"
XLA_STABLEHLO_COMPILE=1 run_test "$@"
}

function run_xla_backend_mp {
echo "Running XLA backend multiprocessing test: $@"
MASTER_ADDR=localhost MASTER_PORT=6000 run_test "$@"
Expand Down Expand Up @@ -200,8 +195,11 @@ function run_xla_op_tests3 {
# TODO(qihqi): this test require tensorflow to run. need to setup separate
# CI with tf.
run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_inference.py"
run_stablehlo_compile "$CDIR/stablehlo/test_stablehlo_compile.py"
run_stablehlo_compile "$CDIR/stablehlo/test_implicit_broadcasting.py"
run_test "$CDIR/stablehlo/test_stablehlo_compile.py"
run_test "$CDIR/stablehlo/test_implicit_broadcasting.py"
run_test "$CDIR/stablehlo/test_unbounded_dynamism.py"
run_test "$CDIR/stablehlo/test_mark_pattern.py"
run_test "$CDIR/stablehlo/test_pt2e_qdq.py"
run_test "$CDIR/spmd/test_xla_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
run_test "$CDIR/spmd/test_xla_virtual_device.py"
Expand Down
1 change: 1 addition & 0 deletions test/stablehlo/test_pt2e_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def test_per_channel_qdq(self):
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1)
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1)

@unittest.skip("Failed because PT2E BC break change on constant folding.")
def test_resnet18(self):
# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
Expand Down
11 changes: 7 additions & 4 deletions test/stablehlo/test_stablehlo_compile.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import unittest

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch
import torchvision
import unittest
import torch_xla.debug.metrics as met
import torch_xla.debug.metrics_compare_utils as mcu
import numpy as np
import torchvision

os.environ['XLA_STABLEHLO_COMPILE'] = '1'

class StableHloCompileTest(unittest.TestCase):

Expand Down

0 comments on commit 92a6e00

Please sign in to comment.