diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp
index 9c4135f64c5..b8599c0e7d6 100644
--- a/test/cpp/test_aten_xla_tensor_2.cpp
+++ b/test/cpp/test_aten_xla_tensor_2.cpp
@@ -1512,7 +1512,7 @@ TEST_F(AtenXlaTensorTest, TestGroupNormBackward) {
                        /*cudnn_enabled=*/false);
   };
   torch::Tensor undef;
-  ForEachDevice({XlaDeviceType::GPU, XlaDeviceType::TPU},
+  ForEachDevice({XlaDeviceType::CUDA, XlaDeviceType::TPU},
                 [&](const torch::Device& device) {
                   TestBackward({input, undef_weight ? undef : weight,
                                 undef_weight ? undef : bias},
diff --git a/test/cpp/test_aten_xla_tensor_6.cpp b/test/cpp/test_aten_xla_tensor_6.cpp
index d2e3e284f5b..d7eb32619c4 100644
--- a/test/cpp/test_aten_xla_tensor_6.cpp
+++ b/test/cpp/test_aten_xla_tensor_6.cpp
@@ -873,7 +873,7 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) {
 TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) {
   XlaDeviceType hw_type =
       static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
-  if (hw_type != XlaDeviceType::GPU && hw_type != XlaDeviceType::CPU) {
+  if (hw_type != XlaDeviceType::CUDA && hw_type != XlaDeviceType::CPU) {
     return;
   }
   torch::Tensor growth_tracker =
diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp
index 08b039b9e5f..6d7a54add0c 100644
--- a/test/cpp/test_replication.cpp
+++ b/test/cpp/test_replication.cpp
@@ -94,7 +94,7 @@ class ReplicationTest : public AtenXlaTensorTestBase {};
 
 TEST_F(ReplicationTest, TestNSingleReplication) {
   WithAllDevices(
-      {XlaDeviceType::TPU, XlaDeviceType::GPU},
+      {XlaDeviceType::TPU, XlaDeviceType::CUDA},
       [&](const std::vector<torch::lazy::BackendDevice>& devices,
           const std::vector<torch::lazy::BackendDevice>& all_devices) {
         TestSingleReplication(devices, all_devices);
diff --git a/test/pjrt/test_ddp.py b/test/pjrt/test_ddp.py
index f84cc30ec9e..7b359311c8f 100644
--- a/test/pjrt/test_ddp.py
+++ b/test/pjrt/test_ddp.py
@@ -32,7 +32,7 @@ def _ddp_init(index: int = ...):
   def test_ddp_init(self):
     pjrt.run_multiprocess(self._ddp_init)
 
-  @absltest.skipIf(xr.device_type() == 'GPU',
+  @absltest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
                    "GPU device is not supported by pjrt.spawn_threads")
   def test_ddp_init_threaded(self):
     pjrt.spawn_threads(self._ddp_init)
diff --git a/test/pjrt/test_runtime.py b/test/pjrt/test_runtime.py
index 8e500ea4ef0..8cb930714e0 100644
--- a/test/pjrt/test_runtime.py
+++ b/test/pjrt/test_runtime.py
@@ -16,7 +16,7 @@ class TestExperimentalPjrt(parameterized.TestCase):
   def setUp(self):
     xr.set_device_type('CPU')
 
-  @parameterized.parameters(('CPU', 'CPU'), ('GPU', 'GPU'), ('TPU', 'TPU'),
+  @parameterized.parameters(('CPU', 'CPU'), ('CUDA', 'CUDA'), ('TPU', 'TPU'),
                             ('TPU_C_API', 'TPU'), ('TPU_LEGACY', 'TPU'))
   def test_device_type(self, pjrt_device, expected):
     with mock.patch.dict(os.environ, {'PJRT_DEVICE': pjrt_device}, clear=True):
@@ -61,7 +61,7 @@ def test_xla_device_error(self):
       }, True), ('gpu_num_devives', {
           'GPU_NUM_DEVICES': '4'
       }, True), ('pjrt_gpu', {
-          'PJRT_DEVICE': 'GPU',
+          'PJRT_DEVICE': 'CUDA',
           'GPU_NUM_DEVICES': '4'
       }, True))
   def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
@@ -77,7 +77,7 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
       xr.using_pjrt()
 
     if expect_using_pjrt:
-      self.assertIn(xr.device_type(), ['CPU', 'GPU', 'TPU'])
+      self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU', 'ROCM', 'GPU'])
     else:
       self.assertIsNone(xr.device_type())
 
diff --git a/test/pjrt/test_runtime_gpu.py b/test/pjrt/test_runtime_gpu.py
index d82144b2c1a..77fd4d94fb7 100644
--- a/test/pjrt/test_runtime_gpu.py
+++ b/test/pjrt/test_runtime_gpu.py
@@ -17,12 +17,12 @@ from absl.testing import absltest, parameterized
 
 
-@unittest.skipIf(xr.device_type() != 'GPU',
+@unittest.skipIf(xr.device_type() not in ('GPU', 'CUDA', 'ROCM'),
                  f"GPU tests should only run on GPU devices.")
 class TestExperimentalPjrtGpu(parameterized.TestCase):
 
   def setUp(self):
-    xr.set_device_type('GPU')
+    xr.set_device_type('CUDA')
 
     os.environ.update({
         xenv.PJRT_GPU_ASYNC_CLIENT: 'true',
diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py
index 1c77e85b9f6..6835ebe7993 100644
--- a/test/pytorch_test_base.py
+++ b/test/pytorch_test_base.py
@@ -519,7 +519,7 @@ def union_of_disabled_tests(sets):
 DISABLED_TORCH_TESTS = {
     'TPU': prepare_match_set(DISABLED_TORCH_TESTS_TPU),
     'CPU': prepare_match_set(DISABLED_TORCH_TESTS_CPU),
-    'GPU': prepare_match_set(DISABLED_TORCH_TESTS_GPU),
+    'CUDA': prepare_match_set(DISABLED_TORCH_TESTS_GPU),
 }
 
 
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 9b13aa4494f..0bb9b965b83 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -56,7 +56,7 @@ function run_coverage {
 function run_test {
   echo "Running in PjRt runtime: $@"
   if [ -x "$(command -v nvidia-smi)" ] && [ "$XLA_CUDA" != "0" ]; then
-    PJRT_DEVICE=GPU run_coverage "$@"
+    PJRT_DEVICE=CUDA run_coverage "$@"
   else
     # TODO(darisoy): run these tests with multiple CPU devices, this fails due to TF issue.
    PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$@"
diff --git a/test/spmd/test_xla_sharding_base.py b/test/spmd/test_xla_sharding_base.py
index 4b83368d380..54067512ce2 100644
--- a/test/spmd/test_xla_sharding_base.py
+++ b/test/spmd/test_xla_sharding_base.py
@@ -10,7 +10,7 @@
 
 
 @unittest.skipIf(not xr.using_pjrt() or
-                 xu.getenv_as(xenv.PJRT_DEVICE, str) == "GPU",
+                 xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
                  f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
 class XlaShardingTest(unittest.TestCase):
 
diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py
index 8ea4db3e051..9228558f5e2 100644
--- a/test/spmd/test_xla_spmd_python_api_interaction.py
+++ b/test/spmd/test_xla_spmd_python_api_interaction.py
@@ -120,7 +120,7 @@ def setUpClass(cls):
     xr.use_spmd()
     super().setUpClass()
 
-  @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU'],
+  @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU', 'CUDA', 'ROCM'],
                    f"TPU/GPU autocast test.")
   def test_xla_autocast_api(self):
     device = xm.xla_device()
diff --git a/test/test_autocast.py b/test/test_autocast.py
index 9caa3017ea8..edbd834b61b 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -341,7 +341,8 @@ def compare(first, second):
     self.assertFalse(self.is_autocast_enabled())
 
 
-@unittest.skipIf(not xm.get_xla_supported_devices("GPU"), f"GPU autocast test.")
+@unittest.skipIf(not xm.get_xla_supported_devices("CUDA"),
+                 f"CUDA autocast test.")
 class TestAutocastCuda(TestAutocastBase):
 
   def setUp(self):
diff --git a/test/test_ddp.py b/test/test_ddp.py
index 2389cc51f0d..25e53790cc5 100644
--- a/test/test_ddp.py
+++ b/test/test_ddp.py
@@ -16,7 +16,7 @@ def _ddp_correctness(rank, use_large_net: bool, debug: bool):
   # We cannot run this guard before XMP,
   # see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing.
   device = xm.xla_device()
-  if xm.xla_device_hw(device) not in ('GPU', 'TPU'):
+  if xm.xla_device_hw(device) not in ('GPU', 'TPU', 'CUDA', 'ROCM'):
     print(
         'Default device {} is not a TPU device'.format(device),
         file=sys.stderr)
diff --git a/test/test_fsdp_auto_wrap.py b/test/test_fsdp_auto_wrap.py
index 5bd85bb6b94..b14fb769bc0 100644
--- a/test/test_fsdp_auto_wrap.py
+++ b/test/test_fsdp_auto_wrap.py
@@ -31,10 +31,10 @@ def forward(self, x):
       hidden2 = self.fc2(x)
       return hidden1, hidden2
 
-  @unittest.skipIf(
-      xr.device_type() == 'GPU',
-      "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
-  )
+  @unittest.skipIf(xr.device_type() in (
+      'GPU', 'ROCM', 'CUDA'
+  ), "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
+                  )
   def test(self):
     dev = xm.xla_device()
     input = torch.zeros([16, 16], device=dev)
@@ -50,12 +50,12 @@ def test(self):
 
 
 def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     test = unittest.main(exit=False)
     sys.exit(0 if test.result.wasSuccessful() else 1)
   else:
     print(
-        'Default device {} is not a TPU or GPU device'.format(device),
+        'Default device {} is not a TPU or CUDA device'.format(device),
         file=sys.stderr)
 
diff --git a/test/test_metrics.py b/test/test_metrics.py
index 8f5b0cfd850..3037692ea65 100644
--- a/test/test_metrics.py
+++ b/test/test_metrics.py
@@ -164,7 +164,9 @@ def test_metrics_report(self):
     self.assertIn("CachedCompile", report)
 
   @unittest.skipIf(
+      xm.get_xla_supported_devices("CUDA") or
       xm.get_xla_supported_devices("GPU") or
+      xm.get_xla_supported_devices("ROCM") or
       xm.get_xla_supported_devices("TPU"), f"This test only works on CPU.")
   def test_execute_time_metric(self):
     # Initialize the client before starting the timer.
diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py
index b8fee7e29ba..3ffeebc963d 100644
--- a/test/test_mp_all_gather.py
+++ b/test/test_mp_all_gather.py
@@ -13,7 +13,7 @@ def all_gather(tensor, dim):
 def _mp_fn(index):
   device = xm.xla_device()
   world_size = xm.xrt_world_size()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     # Testing with a single replica group
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
     result = xm.all_gather(ordinal_tensor, dim=0)
diff --git a/test/test_mp_distributed_mm.py b/test/test_mp_distributed_mm.py
index 3dc45732ac3..eb292f7a53d 100644
--- a/test/test_mp_distributed_mm.py
+++ b/test/test_mp_distributed_mm.py
@@ -9,7 +9,7 @@
 
 def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     world_size = xm.xrt_world_size()
     torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
         use_full_mat_mul_precision=True)
diff --git a/test/test_operations.py b/test/test_operations.py
index 6a1be16ede3..db46574bd8c 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -434,7 +434,8 @@ def test_get_real_xla_devices(self):
     devices = xm.get_xla_supported_devices()
     xla_devices = torch_xla._XLAC._xla_real_devices(devices)
     for device, xdevice in zip(devices, xla_devices):
-      self.assertTrue(re.match(r'(CPU|GPU|TPU):\d+$', xdevice) is not None)
+      self.assertTrue(
+          re.match(r'(CPU|GPU|TPU|CUDA|ROCM):\d+$', xdevice) is not None)
 
   def test_negative_slice(self):
     t = _gen_tensor(32, 24, 32)
diff --git a/test/test_torch_distributed_all_gather_xla_backend.py b/test/test_torch_distributed_all_gather_xla_backend.py
index 763c15d6f5b..f75a019db86 100644
--- a/test/test_torch_distributed_all_gather_xla_backend.py
+++ b/test/test_torch_distributed_all_gather_xla_backend.py
@@ -10,7 +10,7 @@
 
 def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     world_size = xm.xrt_world_size()
     rank = xm.get_ordinal()
 
diff --git a/test/test_torch_distributed_all_reduce_xla_backend.py b/test/test_torch_distributed_all_reduce_xla_backend.py
index 3f0bca31b8f..9962c824b7d 100644
--- a/test/test_torch_distributed_all_reduce_xla_backend.py
+++ b/test/test_torch_distributed_all_reduce_xla_backend.py
@@ -10,7 +10,7 @@
 
 def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     world_size = xm.xrt_world_size()
     rank = xm.get_ordinal()
 
diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py
index 79b65a46999..c626faf7447 100644
--- a/test/test_torch_distributed_fsdp_frozen_weight.py
+++ b/test/test_torch_distributed_fsdp_frozen_weight.py
@@ -8,9 +8,9 @@
 
 def _mp_fn(index):
   dev = xm.xla_device()
-  if xm.xla_device_hw(dev) not in ('TPU', 'GPU'):
+  if xm.xla_device_hw(dev) not in ('TPU', 'CUDA'):
     print(
-        'Default device {} is not a TPU or GPU device'.format(dev),
+        'Default device {} is not a TPU or CUDA device'.format(dev),
         file=sys.stderr)
     return
 
diff --git a/test/test_torch_distributed_multi_all_reduce_xla_backend.py b/test/test_torch_distributed_multi_all_reduce_xla_backend.py
index cf16311ca98..e576c3ffb0f 100644
--- a/test/test_torch_distributed_multi_all_reduce_xla_backend.py
+++ b/test/test_torch_distributed_multi_all_reduce_xla_backend.py
@@ -10,7 +10,7 @@
 
 def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     world_size = xm.xrt_world_size()
     rank = xm.get_ordinal()
 
diff --git a/test/test_torch_distributed_reduce_scatter_xla_backend.py b/test/test_torch_distributed_reduce_scatter_xla_backend.py
index f278567379e..fd146d98af7 100644
--- a/test/test_torch_distributed_reduce_scatter_xla_backend.py
+++ b/test/test_torch_distributed_reduce_scatter_xla_backend.py
@@ -10,7 +10,7 @@
 
 def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
     world_size = xm.xrt_world_size()
     rank = xm.get_ordinal()
 
diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py
index bec112a9378..3ed92389715 100644
--- a/test/test_train_mp_imagenet_amp.py
+++ b/test/test_train_mp_imagenet_amp.py
@@ -221,7 +221,7 @@ def train_imagenet():
   if FLAGS.amp:
     if device_hw == 'TPU':
       scaler = None
-    elif device_hw == 'GPU':
+    elif device_hw in ('GPU', 'CUDA', 'ROCM'):
       scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
 
   def train_loop_fn(loader, epoch):
diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py
index ae4db118300..3c9363f8d09 100644
--- a/test/test_train_mp_mnist_amp.py
+++ b/test/test_train_mp_mnist_amp.py
@@ -142,7 +142,7 @@ def train_mnist(flags, **kwargs):
 
   if device_hw == 'TPU':
     scaler = None
-  elif device_hw == 'GPU':
+  elif device_hw == 'CUDA':
     # GradScaler only used for GPU
     scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
   else:
diff --git a/test/test_zero1.py b/test/test_zero1.py
index cb751726577..e9c3a3eeee6 100644
--- a/test/test_zero1.py
+++ b/test/test_zero1.py
@@ -13,7 +13,7 @@
 class XlaZeRO1Test(TestCase):
 
   @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU")
-  @unittest.skipIf(xr.device_type() == 'GPU',
+  @unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
                    "TODO(alanwaketan): Fix it for the token change.")
   def test_zero1(self):
     device = xm.xla_device()
diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py
index 9e7533955e4..67844b3ab39 100644
--- a/torch_xla/_internal/pjrt.py
+++ b/torch_xla/_internal/pjrt.py
@@ -138,7 +138,7 @@ def run_multiprocess(fn: Callable[..., R],
   """
   if runtime.device_type() == 'TPU':
     num_processes = tpu.num_local_processes()
-  elif runtime.device_type() == 'GPU':
+  elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
     num_processes = gpu.num_local_processes()
     gpu.initialize_distributed_runtime(num_processes)
   elif runtime.device_type() == 'NEURON':
@@ -160,7 +160,7 @@ def run_multiprocess(fn: Callable[..., R],
       itertools.chain.from_iterable(
           result.items() for result in process_results))
 
-  if runtime.device_type() == 'GPU':
+  if runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
     gpu.shutdown_distributed_runtime()
 
   return _merge_replica_results(replica_results)
@@ -210,8 +210,8 @@ def _initialize_single_process(local_rank: int, local_world_size: int):
 
 def spawn_threads(fn: Callable, args: Tuple = ()) -> None:
   """Run function in one process with one thread per addressable device."""
-  assert runtime.device_type(
-  ) != 'GPU', "spawn_threads does not support GPU device"
+  assert runtime.device_type() not in (
+      'GPU', 'ROCM', 'CUDA'), "spawn_threads does not support GPU device"
   spawn_fn = _SpawnFn(fn, *args)
   _run_thread_per_device(
       local_rank=0,
diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py
index fcdd4a40840..db7dc3d5d9b 100644
--- a/torch_xla/amp/autocast_mode.py
+++ b/torch_xla/amp/autocast_mode.py
@@ -25,7 +25,7 @@ def __init__(self,
 
     self._enabled = enabled
     self._xla_device = xm.xla_device_hw(device)
-    if self._xla_device == 'GPU':
+    if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
       backend = 'cuda'
       self._xla_bfloat16 = False  # True if xla backend with bfloat16 dtype.
       if dtype is None:
@@ -70,7 +70,7 @@ def __init__(self,
   def __enter__(self):
     # This ensures that xla autocast is enabled even for XLA:GPU, which calls
     # `torch.amp.autocast_mode.autocast` with `cuda` backend.
-    if self._xla_device == 'GPU':
+    if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
       self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
       self.prev_dtype = torch.get_autocast_xla_dtype(
       )  # type: ignore[attr-defined]
@@ -86,7 +86,7 @@ def __enter__(self):
 
   def __exit__(self, exc_type: Any, exc_val: Any,
                exc_tb: Any):  # type: ignore[override]
-    if self._xla_device == 'GPU':
+    if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
       if self._xla_bfloat16:
         # autocast_xla flags will be set by `torch.autocast` and we need to
         # set autocast flags as we call into `torch.autocast` apis.
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 7f9682c1000..ff6d015b2b2 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -71,7 +71,7 @@ def is_xla_tensor(tensor):
 
 
 def parse_xla_device(device):
-  m = re.match(r'(CPU|TPU|GPU|XPU|NEURON):(\d+)$', device)
+  m = re.match(r'(CPU|TPU|GPU|ROCM|CUDA|XPU|NEURON):(\d+)$', device)
   if m:
     return (m.group(1), int(m.group(2)))
 
@@ -89,7 +89,9 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
     The list of device strings.
   """
   xla_devices = _DEVICES.value
-  devkind = [devkind] if devkind else ['TPU', 'GPU', 'XPU', 'NEURON', 'CPU']
+  devkind = [devkind] if devkind else [
+      'TPU', 'GPU', 'XPU', 'NEURON', 'CPU', 'CUDA', 'ROCM'
+  ]
   for kind in devkind:
     kind_devices = []
     for i, device in enumerate(xla_devices):
@@ -181,8 +183,8 @@ def xla_device(n=None, devkind=None):
     n (int, optional): The specific instance (ordinal) to be returned. If
       specified, the specific XLA device instance will be returned. Otherwise
       the first device of `devkind` will be returned.
-    devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU`
-      `NEURON` or `CPU`.
+    devkind (string..., optional): If specified, one of `TPU`, `CUDA`, `XPU`
+      `NEURON`, `ROCM` or `CPU`.
 
   Returns:
     A `torch.device` with the requested instance.
@@ -217,7 +219,7 @@ def xla_device_hw(device):
       real device.
 
   Returns:
-    A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`)
+    A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`, `CUDA`, `ROCM`)
     of the given device.
   """
   real_device = _xla_real_device(device)
diff --git a/torch_xla/csrc/random.cpp b/torch_xla/csrc/random.cpp
index 44564c1d24f..78376797bbe 100644
--- a/torch_xla/csrc/random.cpp
+++ b/torch_xla/csrc/random.cpp
@@ -21,6 +21,8 @@ std::string GetDefaultGitGeneratorName() {
       static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
   switch (hw_type) {
     case XlaDeviceType::GPU:
+    case XlaDeviceType::CUDA:
+    case XlaDeviceType::ROCM:
       return "three_fry";
     default:
      return "default";
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 30ebd247e93..237531d43cf 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -118,7 +118,8 @@ PjRtComputationClient::PjRtComputationClient() {
     client_ = std::move(xla::GetCApiClient("TPU").value());
   } else if (device_type == "TPU_LEGACY") {
     XLA_ERROR() << "TPU_LEGACY client is no longer available.";
-  } else if (device_type == "GPU") {
+  } else if (device_type == "GPU" || device_type == "CUDA" ||
+             device_type == "ROCM") {
     TF_VLOG(1) << "Initializing PjRt GPU client...";
     bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true);
     int local_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0);
diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp
index 04bd60ce9a0..6322f052265 100644
--- a/torch_xla/csrc/tensor_impl.cpp
+++ b/torch_xla/csrc/tensor_impl.cpp
@@ -75,7 +75,9 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor)
   // Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU
   // so we must manually update Autocast to AutocastCUDA on XLA:GPU.
   torch::lazy::BackendDevice current_device = bridge::GetCurrentDevice();
-  if (static_cast<XlaDeviceType>(current_device.type()) == XlaDeviceType::GPU) {
+  auto dev_type = static_cast<XlaDeviceType>(current_device.type());
+  if (dev_type == XlaDeviceType::GPU || dev_type == XlaDeviceType::CUDA ||
+      dev_type == XlaDeviceType::ROCM) {
     auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA);
     auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA);
     key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks;
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index 3087f3c80f6..649c538bd33 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -41,7 +41,7 @@ def _maybe_select_default_device():
   # TODO(wcromar): Detect GPU device
   elif xu.getenv_as(xenv.GPU_NUM_DEVICES, int, 0) > 0:
     logging.warning('GPU_NUM_DEVICES is set. Setting PJRT_DEVICE=GPU')
-    os.environ[xenv.PJRT_DEVICE] = 'GPU'
+    os.environ[xenv.PJRT_DEVICE] = 'CUDA'
   else:
     logging.warning('Defaulting to PJRT_DEVICE=CPU')
     os.environ[xenv.PJRT_DEVICE] = 'CPU'