From 029090f651a4440249fea7237aa6a2489b3f9277 Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Wed, 13 Mar 2024 09:41:52 -0700 Subject: [PATCH] Automatically move CUDA non XLA Tensors to XLA Device and back to CUDA device (#6644) --- test/dynamo/test_dynamo.py | 36 ++++++++++++++- torch_xla/core/dynamo_bridge.py | 77 ++++++++++++++++++++++++++++++++- 2 files changed, 110 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index fb13ba94323..9f93f9b803b 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -83,7 +83,7 @@ def test_simple_model(self): res_xla_dynamo = fn_simple_dynamo(xla_x, xla_y) self.assertIn('xla::add', met.counter_names()) self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu())) - # verifiy that tracing is skipped in following runs + # verify that tracing is skipped in following runs met.clear_counters() res_xla_dynamo_2 = fn_simple_dynamo(xla_x, xla_y) self.assertNotIn('xla::add', met.counter_names()) @@ -103,6 +103,38 @@ def test_simple_model(self): self.assertNotIn('XLAData: None', torch_xla._XLAC._get_xla_tensor_debug_info(xla_y3)) + # Tests that the dynamo bridge automatically moves tensors to XLA device, + # then back to the original device. + @unittest.skipIf(xr.device_type() != "CUDA", + f"GPU tests should only run on GPU devices.") + def test_simple_model_automoves_tensors(self): + x = torch.tensor(100.0).to(device="cuda") + y = torch.tensor(200.0).to(device="cuda") + original_device = x.device + eager_result = self.fn_simple(x, y) + + # Since all tests run in the same process, have to reset the metrics report. + met.clear_all() + + fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla") + res_xla_dynamo = fn_simple_dynamo(x, y) + self.assertIn('xla::add', met.counter_names()) + self.assertTrue(res_xla_dynamo.device == original_device) + self.assertTrue(torch.allclose(eager_result, res_xla_dynamo)) + + # verify that tracing is skipped in following runs + met.clear_counters() + res_xla_dynamo_reused = fn_simple_dynamo(x, y) + self.assertNotIn('xla::add', met.counter_names()) + self.assertTrue(res_xla_dynamo_reused.device == original_device) + self.assertTrue(torch.allclose(eager_result, res_xla_dynamo_reused)) + + # verify that dynamo can handle different inputs + res_xla_dynamo_different = fn_simple_dynamo(x + y, y * 3) + res_cpu_3 = self.fn_simple(x + y, y * 3) + self.assertTrue(res_xla_dynamo_different.device == original_device) + self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_different)) + def test_fn_without_input(self): def fn_without_input(device): @@ -526,7 +558,7 @@ def test_resnet18(self): met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3) -class DynamErrorMessageTest(unittest.TestCase): +class DynamoErrorMessageTest(unittest.TestCase): def test_mixed_cpu_tensor(self): device = xm.xla_device() diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py index c87ede2ff5b..5a20aa2389d 100644 --- a/torch_xla/core/dynamo_bridge.py +++ b/torch_xla/core/dynamo_bridge.py @@ -111,6 +111,67 @@ def get_fallback_ops(): return fallback_ops +# Checks that all input args that are tensors are on the same device. +def _get_input_arg_device(input_args: tuple) -> torch.device: + device = None + for arg in input_args: + if not isinstance(arg, torch.Tensor): + continue + + if device is None: + device = arg.device + else: + assert arg.device == device, "Not all args are on the same device." + + return device + + +# Returns True if all the input args are on a CUDA device. +def _args_on_cuda(input_args: tuple) -> bool: + input_device: torch.device = _get_input_arg_device(input_args) + if input_device is None: + return False + + return input_device.type == "cuda" + + +# Given an input list, moves the tensors to the given target_device. +# The output order will be the same as the input. Non tensors will also still +# be in the list. +def _maybe_move_tensors_to_device(tensors: tuple, + target_device: torch.device) -> tuple: + if not torch.cuda.is_available(): + return tensors + + moved_tensors = [] + cpu_device: torch.device = torch.device("cpu") + + for tensor in tensors: + if not isinstance(tensor, torch.Tensor): + moved_tensors.append(tensor) + continue + + if tensor.device == target_device: + moved_tensors.append(tensor) + continue + + assert target_device is not None, "Moving tensors to None device not supported" + + if dynamo_debug: + print("Moving Tensor {} to device {}".format(tensor, target_device)) + + # Have to move to CPU before moving it to target device. + moved_tensor = tensor.to(cpu_device) + moved_tensor = moved_tensor.to(target_device) + + # Explicitly have to copy requires_grad attribute because it's dropped + # with torch.to(..) + moved_tensor.requires_grad = tensor.requires_grad + moved_tensors.append(moved_tensor) + + return tuple(moved_tensors) + + class Deduper: def __init__(self): @@ -375,7 +436,7 @@ def extract_internal(xla_model: torch.fx.GraphModule): skip_checking_input_sharding_threashold = xu.getenv_as( 'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5) - def optimized_mod(*args): + def optimized_mod(*args: tuple): nonlocal xla_model nonlocal xla_args_sharding_spec nonlocal args_and_out @@ -387,6 +448,14 @@ def optimized_mod(*args): nonlocal xla_args_need_update nonlocal skip_checking_input_sharding_threashold + original_device: torch.device = _get_input_arg_device(args) + is_cuda_args: bool = False + if original_device: + is_cuda_args = original_device.type == "cuda" + + if is_cuda_args: + args = _maybe_move_tensors_to_device(args, xm.xla_device()) + # mark_step needs to be blocking since we want to access args's XLADatas # and they can't be placeholder. input_tensors_to_sync = [ @@ -437,6 +506,9 @@ def optimized_mod(*args): result = res[len(xla_args_need_update):] none_remover.add_nones(result) + if is_cuda_args: + result = _maybe_move_tensors_to_device(tuple(result), original_device) + if len(result) == 1: return result[0] else: @@ -532,6 +604,9 @@ def allow_cpu_device(self, node: torch.fx.Node): def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): + if _args_on_cuda(xla_args): + xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device())) + # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its # value reference before actually computing it. for a in xla_args: