diff --git a/examples/train_decoder_only_base.py b/examples/train_decoder_only_base.py index 5210031a300..a55d8e39989 100644 --- a/examples/train_decoder_only_base.py +++ b/examples/train_decoder_only_base.py @@ -18,7 +18,10 @@ class TrainDecoderOnlyBase(): def __init__(self): self.config = DecoderOnlyConfig() - self.batch_size = 16 + if xr.device_type() == 'NEURON': + self.batch_size = 4 + else: + self.batch_size = 16 self.seq_len = 512 self.num_steps = 200 self.num_epochs = 1 diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index f601ff536af..8200b70d32c 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -30,7 +30,12 @@ def _is_on_tpu(): return xr.device_type() == 'TPU' +def _is_on_neuron(): + return xr.device_type() == 'NEURON' + + skipOnTpu = unittest.skipIf(_is_on_tpu(), 'Not supported on TPU') +skipOnNeuron = unittest.skipIf(_is_on_neuron(), 'Not supported on NEURON') class DynamoInPlaceTest(unittest.TestCase): @@ -152,6 +157,7 @@ def _choose_proper_device(self, initialize_on_cuda): }) return "cuda:0" + @skipOnNeuron def test_simple_model(self): device = xm.xla_device() x = torch.tensor(100.0) @@ -361,6 +367,7 @@ def get_loader(self, device, sample_count, batch_size=4): return loader @skipOnTpu + @skipOnNeuron @parameterized.parameters( True, False, @@ -393,6 +400,7 @@ def test_resnet18(self, initialize_on_cuda): self.assertEqual( met.metric_data('RunCachedGraphOutputData')[0], sample_count) + @skipOnNeuron def test_resnet18_lazy_vs_dynamo(self): sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) device = torch_xla.device() @@ -555,6 +563,7 @@ def test_simple_model(self): input.grad, xla_input.grad.cpu(), rtol=1e-05, atol=1e-04)) @skipOnTpu + @skipOnNeuron def test_resnet18(self): torch._dynamo.reset() met.clear_counters() diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index 9c988f56b60..8f7293e3f84 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -1,4 +1,318 @@ #!/bin/bash -set -xue +set -exo pipefail +CDIR="$(cd "$(dirname "$0")"/../ ; pwd -P)" +LOGFILE=/tmp/pytorch_py_test.log +MAX_GRAPH_SIZE=500 +GRAPH_CHECK_FREQUENCY=100 +VERBOSITY=2 -python3 test/neuron/test_neuron_utils.py +# Note [Keep Going] +# +# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error. +# This will allow you to see all the failures on your PR, not stopping with the first +# test failure like the default behavior. 
+CONTINUE_ON_ERROR="${CONTINUE_ON_ERROR:-0}" +if [[ "$CONTINUE_ON_ERROR" == "1" ]]; then + set +e +fi + +while getopts 'LM:C:V:' OPTION +do + case $OPTION in + L) + LOGFILE= + ;; + M) + MAX_GRAPH_SIZE=$OPTARG + ;; + C) + GRAPH_CHECK_FREQUENCY=$OPTARG + ;; + V) + VERBOSITY=$OPTARG + ;; + esac +done +shift $(($OPTIND - 1)) + +export TRIM_GRAPH_SIZE=$MAX_GRAPH_SIZE +export TRIM_GRAPH_CHECK_FREQUENCY=$GRAPH_CHECK_FREQUENCY +export TORCH_TEST_DEVICES="$CDIR/pytorch_test_base.py" +export PYTORCH_TEST_WITH_SLOW=1 +export XLA_DUMP_FATAL_STACK=1 +export CPU_NUM_DEVICES=4 + +TORCH_XLA_DIR=$(cd ~; dirname "$(python -c 'import torch_xla; print(torch_xla.__file__)')") +COVERAGE_FILE="$CDIR/../.coverage" + +function run_coverage { + if [ "${USE_COVERAGE:-0}" != "0" ]; then + coverage run --source="$TORCH_XLA_DIR" -p "$@" + else + python3 "$@" + fi +} + +function run_test { + echo "Running in PjRt runtime: $@" + PJRT_DEVICE=NEURON NEURON_NUM_DEVICES=1 run_coverage "$@" +} + +function run_test_without_functionalization { + echo "Running with XLA_DISABLE_FUNCTIONALIZATION: $@" + XLA_DISABLE_FUNCTIONALIZATION=1 run_test "$@" +} + +function run_xla_ir_debug { + echo "Running with XLA_IR_DEBUG: $@" + XLA_IR_DEBUG=1 run_test "$@" +} + +function run_use_bf16 { + echo "Running with XLA_USE_BF16: $@" + XLA_USE_BF16=1 run_test "$@" +} + +function run_downcast_bf16 { + echo "Running with XLA_DOWNCAST_BF16: $@" + XLA_DOWNCAST_BF16=1 run_test "$@" +} + +function run_xla_hlo_debug { + echo "Running with XLA_IR_DEBUG: $@" + XLA_HLO_DEBUG=1 run_test "$@" +} + +function run_dynamic { + echo "Running in DynamicShape mode: $@" + XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter:nms" run_test "$@" +} + +function run_eager_debug { + echo "Running in Eager Debug mode: $@" + XLA_USE_EAGER_DEBUG_MODE=1 run_test "$@" +} + +function run_save_tensor_ir { + echo "Running in save tensor file mode: $@" + XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="text" run_test "$@" +} + +function run_save_tensor_hlo { + echo "Running in save tensor file mode: $@" + XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@" +} + +function run_pt_xla_debug { + echo "Running in save tensor file mode: $@" + PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@" +} + +function run_pt_xla_debug_level1 { + echo "Running in save tensor file mode: $@" + PT_XLA_DEBUG_LEVEL=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@" +} + +function run_torchrun { + PJRT_DEVICE=NEURON torchrun --nnodes 1 --nproc-per-node 2 $@ +} + +function run_torch_op_tests { + run_dynamic "$CDIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA + run_test_without_functionalization "$CDIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA + run_test "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA + run_dynamic "$CDIR/../../test/test_torch.py" "$@" -v TestDevicePrecisionXLA + run_test "$CDIR/../../test/test_torch.py" "$@" -v TestTensorDeviceOpsXLA + run_test "$CDIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA + run_test "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA + # run_dynamic "$CDIR/../../test/test_nn.py" "$@" -v TestNNDeviceTypeXLA + run_dynamic "$CDIR/../../test/nn/test_dropout.py" "$@" -v TestDropoutNNDeviceTypeXLA + run_dynamic "$CDIR/../../test/nn/test_pooling.py" "$@" -v TestPoolingNNDeviceTypeXLA + run_dynamic "$CDIR/../../test/nn/test_embedding.py" "$@" -v TestEmbeddingNNDeviceTypeXLA + run_dynamic "$CDIR/../../test/nn/test_convolution.py" "$@" -v 
TestConvolutionNNDeviceTypeXLA + run_dynamic "$CDIR/../../test/nn/test_multihead_attention.py" "$@" -v TestMultiheadAttentionNNDeviceTypeXLA + run_dynamic "$CDIR/../../test/test_type_promotion.py" "$@" -v TestTypePromotionXLA +} + +####################################################################################### +################################# XLA OP TESTS SHARDS ################################# +####################################################################################### + +# DO NOT MODIFY +function run_xla_op_tests1 { + #run_dynamic "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY + #run_dynamic "$CDIR/ds/test_dynamic_shapes.py" + #run_dynamic "$CDIR/ds/test_dynamic_shape_models.py" "$@" --verbosity=$VERBOSITY + #run_eager_debug "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY + #run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY + #run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY + run_pt_xla_debug "$CDIR/debug_tool/test_pt_xla_debug.py" + run_pt_xla_debug_level1 "$CDIR/debug_tool/test_pt_xla_debug.py" + run_test "$CDIR/test_async_closures.py" + run_test "$CDIR/test_hlo_metadata.py" + #run_test "$CDIR/test_profiler.py" + run_test "$CDIR/pjrt/test_runtime.py" + #NEURONCORE_NUM_DEVICES=2 python "$CDIR/pjrt/test_ddp.py" + run_test "$CDIR/pjrt/test_mesh_service.py" + #run_test "$CDIR/test_python_ops.py" + #run_test "$CDIR/test_ops.py" + run_test "$CDIR/test_metrics.py" + run_test "$CDIR/test_deprecation.py" + run_test "$CDIR/dynamo/test_dynamo_integrations_util.py" + #run_test "$CDIR/dynamo/test_dynamo_aliasing.py" + run_test "$CDIR/dynamo/test_dynamo.py" + run_test "$CDIR/dynamo/test_dynamo_dynamic_shape.py" + run_test "$CDIR/dynamo/test_bridge.py" + run_test "$CDIR/dynamo/test_num_output.py" + run_test "$CDIR/dynamo/test_graph_input_matcher.py" + run_test "$CDIR/dynamo/test_dynamo_config.py" + run_save_tensor_ir "$CDIR/dynamo/test_dynamo_graph_dump.py" + #run_test "$CDIR/test_data_type.py" + run_use_bf16 "$CDIR/test_data_type.py" + run_downcast_bf16 "$CDIR/test_data_type.py" + #run_test "$CDIR/test_fp8.py" + run_xla_ir_debug "$CDIR/test_env_var_mapper.py" + run_xla_hlo_debug "$CDIR/test_env_var_mapper.py" + run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py" + run_save_tensor_ir "$CDIR/spmd/test_spmd_graph_dump.py" + run_save_tensor_hlo "$CDIR/spmd/test_spmd_graph_dump.py" +} + +function run_xla_op_tests2 { + run_test "$CDIR/pjrt/test_dtypes.py" + #run_test "$CDIR/test_while_loop.py" + run_test "$CDIR/test_scan.py" + run_test "$CDIR/test_autocast.py" + run_test "$CDIR/test_grad_checkpoint.py" + #run_test "$CDIR/eager/test_eager.py" + run_test "$CDIR/eager/test_eager_with_xla_compile.py" + run_test "$CDIR/eager/test_eager_with_torch_compile.py" + #run_test "$CDIR/eager/test_eager_all_reduce_in_place.py" + run_test "$CDIR/eager/test_eager_spmd.py" + run_test "$CDIR/test_callback.py" + XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py" +} + +# All the new xla op tests should go to run_xla_op_tests3 +function run_xla_op_tests3 { + # TODO(qihqi): this test require tensorflow to run. need to setup separate + # CI with tf. 
+  run_test "$CDIR/stablehlo/test_exports.py"
+  run_test "$CDIR/stablehlo/test_export_fx_passes.py"
+  run_test "$CDIR/stablehlo/test_implicit_broadcasting.py"
+  run_test "$CDIR/stablehlo/test_composite.py"
+  run_test "$CDIR/stablehlo/test_pt2e_qdq.py"
+  run_test "$CDIR/stablehlo/test_stablehlo_custom_call.py"
+  #run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_inference.py"
+  #=run_test "$CDIR/stablehlo/test_stablehlo_compile.py"
+  run_test "$CDIR/stablehlo/test_unbounded_dynamism.py"
+  #run_test "$CDIR/quantized_ops/test_quantized_matmul.py"
+  #run_test "$CDIR/quantized_ops/test_dot_general.py"
+  #run_test "$CDIR/spmd/test_xla_sharding.py"
+  run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
+  #run_test "$CDIR/spmd/test_xla_virtual_device.py"
+  #run_test "$CDIR/spmd/test_dynamo_spmd.py"
+  run_test "$CDIR/spmd/test_spmd_debugging.py"
+  #=run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
+  run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
+  #run_test "$CDIR/spmd/test_dtensor_integration.py"
+  #run_test "$CDIR/spmd/test_dtensor_integration2.py"
+  run_test "$CDIR/spmd/test_xla_auto_sharding.py"
+  #run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
+  run_test "$CDIR/spmd/test_train_spmd_linear_model.py"
+  run_test "$CDIR/spmd/test_fsdp_v2.py"
+  run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
+  run_test "$CDIR/test_input_output_aliases.py"
+  run_test "$CDIR/test_torch_distributed_xla_backend.py"
+  run_torchrun "$CDIR/pjrt/test_torchrun.py"
+  run_test "$CDIR/test_persistent_cache.py"
+  run_test "$CDIR/test_devices.py"
+
+  run_test "$CDIR/neuron/test_neuron_utils.py"
+
+  #python3 examples/data_parallel/train_resnet_xla_ddp.py # compiler error
+  #python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py
+  #python3 examples/eager/train_decoder_only_eager.py # OOM
+  #python3 examples/eager/train_decoder_only_eager_spmd_data_parallel.py # compiler err due to f64
+  PJRT_DEVICE=NEURON NEURONCORE_NUM_DEVICES=1 python3 examples/eager/train_decoder_only_eager_with_compile.py # nan loss expected?
+  PJRT_DEVICE=NEURON NEURONCORE_NUM_DEVICES=1 python3 examples/eager/train_decoder_only_eager_multi_process.py
+}
+
+#######################################################################################
+
+function run_op_tests {
+  run_torch_op_tests
+  run_xla_op_tests1
+  run_xla_op_tests2
+  run_xla_op_tests3
+}
+
+function run_mp_op_tests {
+  run_test "$CDIR/test_mp_replication.py"
+  #run_test "$CDIR/test_mp_all_to_all.py"
+  run_test "$CDIR/test_mp_collective_permute.py"
+  #run_test "$CDIR/test_mp_all_gather.py" # "wrong reductions"?
+  run_test "$CDIR/test_mp_reduce_scatter.py"
+  run_test "$CDIR/test_zero1.py"
+  run_test "$CDIR/test_mp_distributed_mm.py"
+  run_test "$CDIR/test_mp_save.py"
+  run_test "$CDIR/test_mp_mesh_reduce.py"
+  run_test "$CDIR/test_mp_sync_batch_norm.py"
+  run_test "$CDIR/dynamo/test_traceable_collectives.py"
+  run_test "$CDIR/test_fsdp_auto_wrap.py"
+  # run_torchrun "$CDIR/test_mp_early_exit.py"
+  run_pt_xla_debug "$CDIR/debug_tool/test_mp_pt_xla_debug.py"
+  run_test "$CDIR/torch_distributed/test_torch_distributed_all_gather_xla_backend.py"
+  run_test "$CDIR/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py"
+  #run_test "$CDIR/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py" # crash without NEURONCORE_NUM_DEVICES=2
+  run_test "$CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py"
+  run_test "$CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py"
+  run_test "$CDIR/torch_distributed/test_ddp.py"
+  #run_test "$CDIR/torch_distributed/test_torch_distributed_fsdp_meta.py" # crash without NEURONCORE_NUM_DEVICES=2
+  PJRT_DEVICE=NEURON NEURONCORE_NUM_DEVICES=2 python3 $CDIR/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
+  PJRT_DEVICE=NEURON NEURONCORE_NUM_DEVICES=2 python3 $CDIR/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
+  PJRT_DEVICE=NEURON NEURONCORE_NUM_DEVICES=2 python3 $CDIR/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py
+  PJRT_DEVICE=NEURON NEURONCORE_NUM_DEVICES=2 python3 $CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
+  PJRT_DEVICE=NEURON NEURONCORE_NUM_DEVICES=2 python3 $CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py
+  PJRT_DEVICE=NEURON NEURONCORE_NUM_DEVICES=2 python3 $CDIR/torch_distributed/test_torch_distributed_fsdp_meta.py
+}
+
+function run_tests {
+  # RUN_ flags filter an explicit test type to run, XLA_SKIP_ flags exclude one.
+  if [[ "$RUN_XLA_OP_TESTS1" == "xla_op1" ]]; then
+    echo "Running xla op tests..."
+    run_xla_op_tests1
+  elif [[ "$RUN_XLA_OP_TESTS2" == "xla_op2" ]]; then
+    echo "Running xla op tests..."
+    run_xla_op_tests2
+  elif [[ "$RUN_XLA_OP_TESTS3" == "xla_op3" ]]; then
+    echo "Running xla op tests..."
+    run_xla_op_tests3
+  elif [[ "$RUN_TORCH_MP_OP_TESTS" == "torch_mp_op" ]]; then
+    echo "Running torch mp op tests..."
+ #run_torch_op_tests + run_mp_op_tests + else + # Run full tests without sharding, respects XLA_SKIP_* + if [[ "$XLA_SKIP_XLA_OP_TESTS" != "1" ]]; then + run_xla_op_tests1 + run_xla_op_tests2 + run_xla_op_tests3 + fi + #if [[ "$XLA_SKIP_TORCH_OP_TESTS" != "1" ]]; then + # run_torch_op_tests + #fi + if [[ "$XLA_SKIP_MP_OP_TESTS" != "1" ]]; then + run_mp_op_tests + fi + fi +} + +if [ "$LOGFILE" != "" ]; then + run_tests 2>&1 | tee $LOGFILE +else + run_tests +fi diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index acf96cf8efb..c9cfd66836e 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -34,7 +34,9 @@ def _mp_fn(index): cpu_result = result.cpu() expected = torch.arange(0, world_size, dtype=torch.float) if not cpu_result.allclose(expected): - print('xm.all_gather() produced wrong reductions', file=sys.stderr) + print( + 'xm.all_gather() produced wrong reductions (torch.compile)', + file=sys.stderr) print(f'[{index}] {cpu_result}', file=sys.stderr) sys.exit(1) diff --git a/test/test_mp_all_to_all.py b/test/test_mp_all_to_all.py index 6b2811a58bb..f7e4a2f0c08 100644 --- a/test/test_mp_all_to_all.py +++ b/test/test_mp_all_to_all.py @@ -7,7 +7,7 @@ def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) == 'TPU': + if xm.xla_device_hw(device) in ('TPU', 'NEURON'): slots_per_device = 4 size = slots_per_device * xr.world_size() ordinal = xr.global_ordinal() diff --git a/test/test_mp_collective_permute.py b/test/test_mp_collective_permute.py index 3ad0568c583..07f99712f27 100644 --- a/test/test_mp_collective_permute.py +++ b/test/test_mp_collective_permute.py @@ -7,7 +7,7 @@ def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) == 'TPU': + if xm.xla_device_hw(device) in ['TPU', 'NEURON']: world_size = xr.world_size() ordinal = xr.global_ordinal() value = torch.tensor([ordinal] * 100, dtype=torch.int32, device=device) diff --git a/test/test_persistent_cache.py b/test/test_persistent_cache.py index 1683d15ddbb..a0619b45f14 100644 --- a/test/test_persistent_cache.py +++ b/test/test_persistent_cache.py @@ -93,7 +93,7 @@ def _spmd_sharded_test(tmpdir, metrics): _assert_correctness_and_metrics(t, xt, metrics) -@absltest.skipUnless(xr.device_type() in {'TPU', 'CUDA'}, +@absltest.skipUnless(xr.device_type() in {'TPU', 'CUDA', 'NEURON'}, 'Device type does not support persistent caching') class PersistentCacheTest(parameterized.TestCase): """ diff --git a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py index a3fc373bc49..80dc7f9eabf 100644 --- a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA'): + if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py index 0171a7f17a3..38ba9a559bf 100644 --- a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA'): + if xm.xla_device_hw(device) in ('TPU', 'CUDA', 
'NEURON'): world_size = xr.world_size() dist.init_process_group('xla', init_method='xla://') # note that we can't use torch.tensor(torch.distributed.get_rank()) directly diff --git a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py index f29a48479c3..3d5736b0ec4 100644 --- a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA'): + if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_fsdp_meta.py b/test/torch_distributed/test_torch_distributed_fsdp_meta.py index 9b3e214d302..6969c4e9c77 100644 --- a/test/torch_distributed/test_torch_distributed_fsdp_meta.py +++ b/test/torch_distributed/test_torch_distributed_fsdp_meta.py @@ -143,7 +143,7 @@ def meta_module_fn(): def _mp_fn(index): device = xm.xla_device() # This test fails on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840) - if xm.xla_device_hw(device) in ('TPU'): + if xm.xla_device_hw(device) in ('TPU', 'NEURON'): dist.init_process_group('xla', init_method='xla://') test = TestFSDPWithMetaDevice() test.test_simple_model_with_meta_device_reset_params() diff --git a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py index 1a136b29391..c10be21ef42 100644 --- a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA'): + if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py index 8d3be719590..867dddd07bb 100644 --- a/torch_xla/amp/autocast_mode.py +++ b/torch_xla/amp/autocast_mode.py @@ -47,7 +47,7 @@ def __init__(self, enabled=enabled, dtype=self._dtype, cache_enabled=cache_enabled) - elif self._xla_device == 'TPU': + elif self._xla_device == 'TPU' or self._xla_device == 'NEURON': if dtype is None: dtype = torch.bfloat16 if dtype != torch.bfloat16: