diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index 62141e531d7..57da0ff799a 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -230,9 +230,6 @@ function run_xla_op_tests3 { run_test "$CDIR/test_persistent_cache.py" run_test "$CDIR/test_devices.py" - run_test "$CDIR/neuron/test_neuron_utils.py" - run_test "$CDIR/neuron/test_neuron_data_types.py" - #python3 examples/data_parallel/train_resnet_xla_ddp.py # compiler error #python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py #python3 examples/eager/train_decoder_only_eager.py # OOM @@ -241,6 +238,12 @@ function run_xla_op_tests3 { PJRT_DEVICE=NEURON NEURONCORE_NUM_DEVICES=1 python3 examples/eager/train_decoder_only_eager_multi_process.py } +# Neuron specific tests +function run_xla_neuron_tests { + run_test "$CDIR/neuron/test_neuron_utils.py" + run_test "$CDIR/neuron/test_neuron_data_types.py" +} + ####################################################################################### function run_op_tests { @@ -248,6 +251,7 @@ function run_op_tests { run_xla_op_tests1 run_xla_op_tests2 run_xla_op_tests3 + run_xla_neuron_tests } function run_mp_op_tests { @@ -292,6 +296,9 @@ function run_tests { elif [[ "$RUN_XLA_OP_TESTS3" == "xla_op3" ]]; then echo "Running xla op tests..." run_xla_op_tests3 + elif [[ "$RUN_XLA_NEURON_TESTS" == "xla_neuron" ]]; then + echo "Running xla neuron tests..." + run_xla_neuron_tests elif [[ "$RUN_TORCH_MP_OP_TESTS" == "torch_mp_op" ]]; then echo "Running torch op tests..." #run_torch_op_tests @@ -309,6 +316,9 @@ function run_tests { if [[ "$XLA_SKIP_MP_OP_TESTS" != "1" ]]; then run_mp_op_tests fi + if [[ "$XLA_SKIP_NEURON_TESTS" != "1" ]]; then + run_xla_neuron_tests + fi fi }