Adding more CUDA instead of GPU
qihqi committed Oct 10, 2023
1 parent 6d3a97a commit a40767a
Showing 32 changed files with 59 additions and 48 deletions.
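
The pattern throughout this commit: the bare 'GPU' device-type string is replaced by 'CUDA' (with 'ROCM' added where AMD GPUs are also covered), and equality checks against 'GPU' become membership checks so the legacy spelling keeps working. A minimal sketch of that guard, assuming torch_xla.runtime is importable as xr, as the tests below do (the helper name is illustrative, not part of the API):

    import torch_xla.runtime as xr

    # Legacy name plus the two PJRT GPU backends used by this commit.
    GPU_LIKE_DEVICES = ('GPU', 'CUDA', 'ROCM')

    def running_on_gpu() -> bool:
        # An equality check such as `xr.device_type() == 'GPU'` would miss
        # CUDA/ROCM after this change, so the tests switch to membership tests.
        return xr.device_type() in GPU_LIKE_DEVICES
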
2 changes: 1 addition & 1 deletion test/cpp/test_aten_xla_tensor_2.cpp
@@ -1512,7 +1512,7 @@ TEST_F(AtenXlaTensorTest, TestGroupNormBackward) {
/*cudnn_enabled=*/false);
};
torch::Tensor undef;
- ForEachDevice({XlaDeviceType::GPU, XlaDeviceType::TPU},
+ ForEachDevice({XlaDeviceType::CUDA, XlaDeviceType::TPU},
[&](const torch::Device& device) {
TestBackward({input, undef_weight ? undef : weight,
undef_weight ? undef : bias},
2 changes: 1 addition & 1 deletion test/cpp/test_aten_xla_tensor_6.cpp
@@ -873,7 +873,7 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) {
TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) {
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
- if (hw_type != XlaDeviceType::GPU && hw_type != XlaDeviceType::CPU) {
+ if (hw_type != XlaDeviceType::CUDA && hw_type != XlaDeviceType::CPU) {
return;
}
torch::Tensor growth_tracker =
2 changes: 1 addition & 1 deletion test/cpp/test_replication.cpp
@@ -94,7 +94,7 @@ class ReplicationTest : public AtenXlaTensorTestBase {};

TEST_F(ReplicationTest, TestNSingleReplication) {
WithAllDevices(
- {XlaDeviceType::TPU, XlaDeviceType::GPU},
+ {XlaDeviceType::TPU, XlaDeviceType::CUDA},
[&](const std::vector<torch::lazy::BackendDevice>& devices,
const std::vector<torch::lazy::BackendDevice>& all_devices) {
TestSingleReplication(devices, all_devices);
2 changes: 1 addition & 1 deletion test/pjrt/test_ddp.py
@@ -32,7 +32,7 @@ def _ddp_init(index: int = ...):
def test_ddp_init(self):
pjrt.run_multiprocess(self._ddp_init)

- @absltest.skipIf(xr.device_type() == 'GPU',
+ @absltest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
"GPU device is not supported by pjrt.spawn_threads")
def test_ddp_init_threaded(self):
pjrt.spawn_threads(self._ddp_init)
6 changes: 3 additions & 3 deletions test/pjrt/test_runtime.py
@@ -16,7 +16,7 @@ class TestExperimentalPjrt(parameterized.TestCase):
def setUp(self):
xr.set_device_type('CPU')

- @parameterized.parameters(('CPU', 'CPU'), ('GPU', 'GPU'), ('TPU', 'TPU'),
+ @parameterized.parameters(('CPU', 'CPU'), ('CUDA', 'CUDA'), ('TPU', 'TPU'),
('TPU_C_API', 'TPU'), ('TPU_LEGACY', 'TPU'))
def test_device_type(self, pjrt_device, expected):
with mock.patch.dict(os.environ, {'PJRT_DEVICE': pjrt_device}, clear=True):
@@ -61,7 +61,7 @@ def test_xla_device_error(self):
}, True), ('gpu_num_devives', {
'GPU_NUM_DEVICES': '4'
}, True), ('pjrt_gpu', {
- 'PJRT_DEVICE': 'GPU',
+ 'PJRT_DEVICE': 'CUDA',
'GPU_NUM_DEVICES': '4'
}, True))
def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
@@ -77,7 +77,7 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
xr.using_pjrt()

if expect_using_pjrt:
- self.assertIn(xr.device_type(), ['CPU', 'GPU', 'TPU'])
+ self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU', 'ROCM', 'GPU'])
else:
self.assertIsNone(xr.device_type())

4 changes: 2 additions & 2 deletions test/pjrt/test_runtime_gpu.py
@@ -17,12 +17,12 @@
from absl.testing import absltest, parameterized


- @unittest.skipIf(xr.device_type() != 'GPU',
+ @unittest.skipIf(xr.device_type() not in ('GPU', 'CUDA', 'ROCM'),
f"GPU tests should only run on GPU devices.")
class TestExperimentalPjrtGpu(parameterized.TestCase):

def setUp(self):
- xr.set_device_type('GPU')
+ xr.set_device_type('CUDA')

os.environ.update({
xenv.PJRT_GPU_ASYNC_CLIENT: 'true',
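
The PJRT backend is still selected through the PJRT_DEVICE environment variable; this commit makes CUDA the canonical value (the parameterized test above maps 'CUDA' to 'CUDA'). A hedged sketch of selecting the CUDA client before touching the runtime, mirroring setUp above and assuming the variable is read at client initialization:

    import os

    os.environ['PJRT_DEVICE'] = 'CUDA'  # previously 'GPU'; run_tests.sh below now exports CUDA

    import torch_xla.runtime as xr

    # The guards in these tests tolerate any of the GPU-family spellings.
    assert xr.device_type() in ('CUDA', 'GPU', 'ROCM')
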
2 changes: 1 addition & 1 deletion test/pytorch_test_base.py
@@ -519,7 +519,7 @@ def union_of_disabled_tests(sets):
DISABLED_TORCH_TESTS = {
'TPU': prepare_match_set(DISABLED_TORCH_TESTS_TPU),
'CPU': prepare_match_set(DISABLED_TORCH_TESTS_CPU),
- 'GPU': prepare_match_set(DISABLED_TORCH_TESTS_GPU),
+ 'CUDA': prepare_match_set(DISABLED_TORCH_TESTS_GPU),
}


2 changes: 1 addition & 1 deletion test/run_tests.sh
@@ -56,7 +56,7 @@ function run_coverage {
function run_test {
echo "Running in PjRt runtime: $@"
if [ -x "$(command -v nvidia-smi)" ] && [ "$XLA_CUDA" != "0" ]; then
- PJRT_DEVICE=GPU run_coverage "$@"
+ PJRT_DEVICE=CUDA run_coverage "$@"
else
# TODO(darisoy): run these tests with multiple CPU devices, this fails due to TF issue.
PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$@"
2 changes: 1 addition & 1 deletion test/spmd/test_xla_sharding_base.py
@@ -10,7 +10,7 @@


@unittest.skipIf(not xr.using_pjrt() or
- xu.getenv_as(xenv.PJRT_DEVICE, str) == "GPU",
+ xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'),
f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
class XlaShardingTest(unittest.TestCase):

2 changes: 1 addition & 1 deletion test/spmd/test_xla_spmd_python_api_interaction.py
@@ -120,7 +120,7 @@ def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

- @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU'],
+ @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU', 'CUDA', 'ROCM'],
f"TPU/GPU autocast test.")
def test_xla_autocast_api(self):
device = xm.xla_device()
3 changes: 2 additions & 1 deletion test/test_autocast.py
@@ -341,7 +341,8 @@ def compare(first, second):
self.assertFalse(self.is_autocast_enabled())


@unittest.skipIf(not xm.get_xla_supported_devices("GPU"), f"GPU autocast test.")
@unittest.skipIf(not xm.get_xla_supported_devices("CUDA"),
f"CUDA autocast test.")
class TestAutocastCuda(TestAutocastBase):

def setUp(self):
2 changes: 1 addition & 1 deletion test/test_ddp.py
@@ -16,7 +16,7 @@ def _ddp_correctness(rank, use_large_net: bool, debug: bool):
# We cannot run this guard before XMP,
# see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing.
device = xm.xla_device()
- if xm.xla_device_hw(device) not in ('GPU', 'TPU'):
+ if xm.xla_device_hw(device) not in ('GPU', 'TPU', 'CUDA', 'ROCM'):
print(
'Default device {} is not a TPU device'.format(device),
file=sys.stderr)
12 changes: 6 additions & 6 deletions test/test_fsdp_auto_wrap.py
@@ -31,10 +31,10 @@ def forward(self, x):
hidden2 = self.fc2(x)
return hidden1, hidden2

- @unittest.skipIf(
- xr.device_type() == 'GPU',
- "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
- )
+ @unittest.skipIf(xr.device_type() in (
+ 'GPU', 'ROCM', 'CUDA'
+ ), "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
+ )
def test(self):
dev = xm.xla_device()
input = torch.zeros([16, 16], device=dev)
@@ -50,12 +50,12 @@ def test(self):

def _mp_fn(index):
device = xm.xla_device()
- if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+ if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
else:
print(
- 'Default device {} is not a TPU or GPU device'.format(device),
+ 'Default device {} is not a TPU or CUDA device'.format(device),
file=sys.stderr)


2 changes: 2 additions & 0 deletions test/test_metrics.py
@@ -164,7 +164,9 @@ def test_metrics_report(self):
self.assertIn("CachedCompile", report)

@unittest.skipIf(
xm.get_xla_supported_devices("CUDA") or
xm.get_xla_supported_devices("GPU") or
xm.get_xla_supported_devices("ROCM") or
xm.get_xla_supported_devices("TPU"), f"This test only works on CPU.")
def test_execute_time_metric(self):
# Initialize the client before starting the timer.
2 changes: 1 addition & 1 deletion test/test_mp_all_gather.py
@@ -13,7 +13,7 @@ def all_gather(tensor, dim):
def _mp_fn(index):
device = xm.xla_device()
world_size = xm.xrt_world_size()
- if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+ if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
# Testing with a single replica group
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
result = xm.all_gather(ordinal_tensor, dim=0)
2 changes: 1 addition & 1 deletion test/test_mp_distributed_mm.py
@@ -9,7 +9,7 @@
def _mp_fn(index):
device = xm.xla_device()

- if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+ if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
world_size = xm.xrt_world_size()
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
3 changes: 2 additions & 1 deletion test/test_operations.py
@@ -434,7 +434,8 @@ def test_get_real_xla_devices(self):
devices = xm.get_xla_supported_devices()
xla_devices = torch_xla._XLAC._xla_real_devices(devices)
for device, xdevice in zip(devices, xla_devices):
- self.assertTrue(re.match(r'(CPU|GPU|TPU):\d+$', xdevice) is not None)
+ self.assertTrue(
+ re.match(r'(CPU|GPU|TPU|CUDA|ROCM):\d+$', xdevice) is not None)

def test_negative_slice(self):
t = _gen_tensor(32, 24, 32)
2 changes: 1 addition & 1 deletion test/test_torch_distributed_all_gather_xla_backend.py
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
- if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+ if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

2 changes: 1 addition & 1 deletion test/test_torch_distributed_all_reduce_xla_backend.py
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
- if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+ if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

4 changes: 2 additions & 2 deletions test/test_torch_distributed_fsdp_frozen_weight.py
@@ -8,9 +8,9 @@

def _mp_fn(index):
dev = xm.xla_device()
- if xm.xla_device_hw(dev) not in ('TPU', 'GPU'):
+ if xm.xla_device_hw(dev) not in ('TPU', 'CUDA'):
print(
- 'Default device {} is not a TPU or GPU device'.format(dev),
+ 'Default device {} is not a TPU or CUDA device'.format(dev),
file=sys.stderr)
return

@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
- if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+ if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

2 changes: 1 addition & 1 deletion test/test_torch_distributed_reduce_scatter_xla_backend.py
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = xm.xla_device()
- if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+ if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
world_size = xm.xrt_world_size()
rank = xm.get_ordinal()

2 changes: 1 addition & 1 deletion test/test_train_mp_imagenet_amp.py
@@ -221,7 +221,7 @@ def train_imagenet():
if FLAGS.amp:
if device_hw == 'TPU':
scaler = None
- elif device_hw == 'GPU':
+ elif device_hw in ('GPU', 'CUDA', 'ROCM'):
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)

def train_loop_fn(loader, epoch):
2 changes: 1 addition & 1 deletion test/test_train_mp_mnist_amp.py
@@ -142,7 +142,7 @@ def train_mnist(flags, **kwargs):

if device_hw == 'TPU':
scaler = None
- elif device_hw == 'GPU':
+ elif device_hw == 'CUDA':
# GradScaler only used for GPU
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
else:
2 changes: 1 addition & 1 deletion test/test_zero1.py
@@ -13,7 +13,7 @@
class XlaZeRO1Test(TestCase):

@unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU")
- @unittest.skipIf(xr.device_type() == 'GPU',
+ @unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
"TODO(alanwaketan): Fix it for the token change.")
def test_zero1(self):
device = xm.xla_device()
8 changes: 4 additions & 4 deletions torch_xla/_internal/pjrt.py
@@ -138,7 +138,7 @@ def run_multiprocess(fn: Callable[..., R],
"""
if runtime.device_type() == 'TPU':
num_processes = tpu.num_local_processes()
- elif runtime.device_type() == 'GPU':
+ elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
num_processes = gpu.num_local_processes()
gpu.initialize_distributed_runtime(num_processes)
elif runtime.device_type() == 'NEURON':
@@ -160,7 +160,7 @@ def run_multiprocess(fn: Callable[..., R],
itertools.chain.from_iterable(
result.items() for result in process_results))

- if runtime.device_type() == 'GPU':
+ if runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
gpu.shutdown_distributed_runtime()

return _merge_replica_results(replica_results)
@@ -210,8 +210,8 @@ def _initialize_single_process(local_rank: int, local_world_size: int):

def spawn_threads(fn: Callable, args: Tuple = ()) -> None:
"""Run function in one process with one thread per addressable device."""
- assert runtime.device_type(
- ) != 'GPU', "spawn_threads does not support GPU device"
+ assert runtime.device_type() not in (
+ 'GPU', 'ROCM', 'CUDA'), "spawn_threads does not support GPU device"
spawn_fn = _SpawnFn(fn, *args)
_run_thread_per_device(
local_rank=0,
6 changes: 3 additions & 3 deletions torch_xla/amp/autocast_mode.py
@@ -25,7 +25,7 @@ def __init__(self,

self._enabled = enabled
self._xla_device = xm.xla_device_hw(device)
- if self._xla_device == 'GPU':
+ if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
backend = 'cuda'
self._xla_bfloat16 = False # True if xla backend with bfloat16 dtype.
if dtype is None:
@@ -70,7 +70,7 @@ def __enter__(self):
def __enter__(self):
# This ensures that xla autocast is enabled even for XLA:GPU, which calls
# `torch.amp.autocast_mode.autocast` with `cuda` backend.
- if self._xla_device == 'GPU':
+ if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
self.prev = torch.is_autocast_xla_enabled() # type: ignore[attr-defined]
self.prev_dtype = torch.get_autocast_xla_dtype(
) # type: ignore[attr-defined]
@@ -86,7 +86,7 @@ def __exit__(self, exc_type: Any, exc_val: Any,

def __exit__(self, exc_type: Any, exc_val: Any,
exc_tb: Any): # type: ignore[override]
- if self._xla_device == 'GPU':
+ if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
if self._xla_bfloat16:
# autocast_xla flags will be set by `torch.autocast` and we need to
# set autocast flags as we call into `torch.autocast` apis.
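
With these changes, an XLA device whose hardware type is CUDA (or ROCM) is routed through the `cuda` autocast backend exactly as the old 'GPU' string was. A minimal usage sketch, assuming `autocast` is re-exported from `torch_xla.amp` (the class itself lives in torch_xla/amp/autocast_mode.py):

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.amp import autocast  # assumed export path

    device = xm.xla_device()
    a = torch.randn(8, 8, device=device)
    b = torch.randn(8, 8, device=device)

    # On XLA:CUDA this enters the `cuda` autocast backend (float16 by default);
    # on TPU it falls back to the XLA backend with bfloat16.
    with autocast(device):
        c = a @ b
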
12 changes: 7 additions & 5 deletions torch_xla/core/xla_model.py
@@ -71,7 +71,7 @@ def is_xla_tensor(tensor):


def parse_xla_device(device):
- m = re.match(r'(CPU|TPU|GPU|XPU|NEURON):(\d+)$', device)
+ m = re.match(r'(CPU|TPU|GPU|ROCM|CUDA|XPU|NEURON):(\d+)$', device)
if m:
return (m.group(1), int(m.group(2)))

@@ -89,7 +89,9 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
The list of device strings.
"""
xla_devices = _DEVICES.value
- devkind = [devkind] if devkind else ['TPU', 'GPU', 'XPU', 'NEURON', 'CPU']
+ devkind = [devkind] if devkind else [
+ 'TPU', 'GPU', 'XPU', 'NEURON', 'CPU', 'CUDA', 'ROCM'
+ ]
for kind in devkind:
kind_devices = []
for i, device in enumerate(xla_devices):
@@ -181,8 +183,8 @@ def xla_device(n=None, devkind=None):
n (int, optional): The specific instance (ordinal) to be returned. If
specified, the specific XLA device instance will be returned. Otherwise
the first device of `devkind` will be returned.
- devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU`
- `NEURON` or `CPU`.
+ devkind (string..., optional): If specified, one of `TPU`, `CUDA`, `XPU`
+ `NEURON`, `ROCM` or `CPU`.
Returns:
A `torch.device` with the requested instance.
@@ -217,7 +219,7 @@ def xla_device_hw(device):
real device.
Returns:
- A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`)
+ A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`, `CUDA`, `ROCM`)
of the given device.
"""
real_device = _xla_real_device(device)
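
After this hunk, parse_xla_device and xla_device_hw recognize the CUDA and ROCM spellings in real-device strings such as 'CUDA:0'. A small self-contained demonstration of the updated regex from parse_xla_device above (the helper name is illustrative, not part of the API):

    import re

    _XLA_DEVICE_RE = re.compile(r'(CPU|TPU|GPU|ROCM|CUDA|XPU|NEURON):(\d+)$')

    def parse_real_device(device: str):
        # 'CUDA:0' -> ('CUDA', 0); unrecognized prefixes return None.
        m = _XLA_DEVICE_RE.match(device)
        return (m.group(1), int(m.group(2))) if m else None

    assert parse_real_device('CUDA:0') == ('CUDA', 0)
    assert parse_real_device('ROCM:1') == ('ROCM', 1)
    assert parse_real_device('GPU:0') == ('GPU', 0)  # legacy spelling still parses
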
2 changes: 2 additions & 0 deletions torch_xla/csrc/random.cpp
@@ -21,6 +21,8 @@ std::string GetDefaultGitGeneratorName() {
static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
switch (hw_type) {
case XlaDeviceType::GPU:
+ case XlaDeviceType::CUDA:
+ case XlaDeviceType::ROCM:
return "three_fry";
default:
return "default";
3 changes: 2 additions & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -118,7 +118,8 @@ PjRtComputationClient::PjRtComputationClient() {
client_ = std::move(xla::GetCApiClient("TPU").value());
} else if (device_type == "TPU_LEGACY") {
XLA_ERROR() << "TPU_LEGACY client is no longer available.";
} else if (device_type == "GPU") {
} else if (device_type == "GPU" || device_type == "CUDA" ||
device_type == "ROCM") {
TF_VLOG(1) << "Initializing PjRt GPU client...";
bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true);
int local_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0);