From dca9877c78e955ee3f6b4c00a824bea3a5343f04 Mon Sep 17 00:00:00 2001
From: Mason Chang
Date: Tue, 27 Feb 2024 18:27:12 +0000
Subject: [PATCH 1/3] Automove CUDA tensors to the XLA device if they aren't
 already there

---
 test/dynamo/test_dynamo.py      | 33 ++++++++++++++-
 torch_xla/core/dynamo_bridge.py | 73 ++++++++++++++++++++++++++++++++-
 2 files changed, 103 insertions(+), 3 deletions(-)

diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index fb13ba94323..d46f44d6386 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -83,7 +83,7 @@ def test_simple_model(self):
     res_xla_dynamo = fn_simple_dynamo(xla_x, xla_y)
     self.assertIn('xla::add', met.counter_names())
     self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
-    # verifiy that tracing is skipped in following runs
+    # verify that tracing is skipped in following runs
     met.clear_counters()
     res_xla_dynamo_2 = fn_simple_dynamo(xla_x, xla_y)
     self.assertNotIn('xla::add', met.counter_names())
@@ -103,6 +103,35 @@ def test_simple_model(self):
     self.assertNotIn('XLAData: None',
                      torch_xla._XLAC._get_xla_tensor_debug_info(xla_y3))
 
+  # Tests that the dynamo bridge automatically moves tensors to the XLA
+  # device, then back to the original device.
+  @unittest.skipIf(xr.device_type() != "CUDA",
+                   f"GPU tests should only run on GPU devices.")
+  def test_simple_model_automoves_tensors(self):
+    x = torch.tensor(100.0).to(device="cuda")
+    y = torch.tensor(200.0).to(device="cuda")
+    original_device = x.device
+    eager_result = self.fn_simple(x, y)
+
+    fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla")
+    res_xla_dynamo = fn_simple_dynamo(x, y)
+    self.assertIn('xla::add', met.counter_names())
+    self.assertTrue(res_xla_dynamo.device == original_device)
+    self.assertTrue(torch.allclose(eager_result, res_xla_dynamo))
+
+    # verify that tracing is skipped in following runs
+    met.clear_counters()
+    res_xla_dynamo_reused = fn_simple_dynamo(x, y)
+    self.assertNotIn('xla::add', met.counter_names())
+    self.assertTrue(res_xla_dynamo_reused.device == original_device)
+    self.assertTrue(torch.allclose(eager_result, res_xla_dynamo_reused))
+
+    # verify that dynamo can handle different inputs
+    res_xla_dynamo_different = fn_simple_dynamo(x + y, y * 3)
+    res_cpu_3 = self.fn_simple(x + y, y * 3)
+    self.assertTrue(res_xla_dynamo_different.device == original_device)
+    self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_different))
+
   def test_fn_without_input(self):
 
     def fn_without_input(device):
@@ -526,7 +555,7 @@ def test_resnet18(self):
         met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3)
 
 
-class DynamErrorMessageTest(unittest.TestCase):
+class DynamoErrorMessageTest(unittest.TestCase):
 
   def test_mixed_cpu_tensor(self):
     device = xm.xla_device()
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index c87ede2ff5b..7b31ebbbe1b 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -111,6 +111,65 @@ def get_fallback_ops():
   return fallback_ops
 
 
+# Returns the device shared by all tensor args; asserts that they all match.
+def _get_input_arg_device(input_args: tuple) -> torch.device:
+  device = None
+  for arg in input_args:
+    if not isinstance(arg, torch.Tensor):
+      continue
+
+    if device is None:
+      device = arg.device
+    else:
+      assert arg.device == device, "Not all args are on the same device."
+
+  return device
+
+
+# Returns True if all the input args are on a CUDA device.
+def _args_on_cuda(input_args: tuple) -> bool:
+  input_device: torch.device = _get_input_arg_device(input_args)
+  if input_device is None:
+    return False
+
+  return input_device.type == "cuda"
+
+
+# Given an input list, moves the tensors to the given target_device.
+# The output order is the same as the input; non-tensor entries are passed
+# through unchanged.
+def _maybe_move_tensors_to_device(tensors: tuple,
+                                  target_device: torch.device) -> tuple:
+  if not torch.cuda.is_available():
+    return tensors
+
+  moved_tensors = []
+  cpu_device: torch.device = torch.device("cpu")
+
+  for tensor in tensors:
+    if not isinstance(tensor, torch.Tensor):
+      moved_tensors.append(tensor)
+      continue
+
+    if tensor.device == target_device:
+      moved_tensors.append(tensor)
+      continue
+
+    if dynamo_debug:
+      print("Moving Tensor {} to device {}".format(tensor, target_device))
+
+    # Have to move to CPU before moving to the target device.
+    moved_tensor = tensor.to(cpu_device)
+    moved_tensor = moved_tensor.to(target_device)
+
+    # Explicitly copy the requires_grad attribute because it's dropped
+    # by Tensor.to(..).
+    moved_tensor.requires_grad = tensor.requires_grad
+    moved_tensors.append(moved_tensor)
+
+  return tuple(moved_tensors)
+
+
 class Deduper:
 
   def __init__(self):
@@ -375,7 +434,7 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   skip_checking_input_sharding_threashold = xu.getenv_as(
       'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
 
-  def optimized_mod(*args):
+  def optimized_mod(*args: tuple):
     nonlocal xla_model
     nonlocal xla_args_sharding_spec
     nonlocal args_and_out
@@ -387,6 +446,12 @@ def optimized_mod(*args: tuple):
     nonlocal xla_args_need_update
     nonlocal skip_checking_input_sharding_threashold
 
+    original_device: torch.device = _get_input_arg_device(args)
+    is_cuda_args: bool = _args_on_cuda(args)
+
+    if is_cuda_args:
+      args = _maybe_move_tensors_to_device(args, xm.xla_device())
+
     # mark_step needs to be blocking since we want to access args's XLADatas
     # and they can't be placeholder.
     input_tensors_to_sync = [
@@ -437,6 +502,9 @@ def optimized_mod(*args: tuple):
       result = res[len(xla_args_need_update):]
       none_remover.add_nones(result)
 
+    if is_cuda_args:
+      result = _maybe_move_tensors_to_device(tuple(result), original_device)
+
     if len(result) == 1:
       return result[0]
     else:
@@ -532,6 +600,9 @@ def allow_cpu_device(self, node: torch.fx.Node):
 
 
 def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
+  if _args_on_cuda(xla_args):
+    xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device()))
+
   # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
   # value reference before actually computing it.
   for a in xla_args:

From e03223c6d5fb31325d96a07d2adcdc6cdf9473f3 Mon Sep 17 00:00:00 2001
From: Mason Chang
Date: Mon, 11 Mar 2024 22:31:54 +0000
Subject: [PATCH 2/3] Clear metrics for test and remove redundant call

---
 test/dynamo/test_dynamo.py      | 3 +++
 torch_xla/core/dynamo_bridge.py | 4 +++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index d46f44d6386..9f93f9b803b 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -113,6 +113,9 @@ def test_simple_model_automoves_tensors(self):
     original_device = x.device
     eager_result = self.fn_simple(x, y)
 
+    # Since all tests run in the same process, we have to reset the metrics report.
+    met.clear_all()
+
     fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla")
     res_xla_dynamo = fn_simple_dynamo(x, y)
     self.assertIn('xla::add', met.counter_names())
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index 7b31ebbbe1b..f720bb5e338 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -447,7 +447,9 @@ def optimized_mod(*args: tuple):
     nonlocal skip_checking_input_sharding_threashold
 
     original_device: torch.device = _get_input_arg_device(args)
-    is_cuda_args: bool = _args_on_cuda(args)
+    is_cuda_args: bool = False
+    if original_device:
+      is_cuda_args = original_device.type == "cuda"
 
     if is_cuda_args:
       args = _maybe_move_tensors_to_device(args, xm.xla_device())

From db1adae5ddc8f705ff3fe8b0e5cb662111c2d365 Mon Sep 17 00:00:00 2001
From: Mason Chang
Date: Mon, 11 Mar 2024 22:39:29 +0000
Subject: [PATCH 3/3] Assert device exists when moving to a target device

---
 torch_xla/core/dynamo_bridge.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index f720bb5e338..5a20aa2389d 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -155,6 +155,8 @@ def _maybe_move_tensors_to_device(tensors: tuple,
       moved_tensors.append(tensor)
       continue
 
+    assert target_device is not None, "Moving tensors to None device not supported"
+
     if dynamo_debug:
       print("Moving Tensor {} to device {}".format(tensor, target_device))
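
Usage sketch (not part of the patch series): a minimal example of the behavior these patches add, assuming a torch_xla build with CUDA support and a GPU runtime (e.g. PJRT_DEVICE=CUDA); the function name and tensor values below are illustrative only.

    import torch
    import torch_xla  # importing torch_xla makes the "openxla" dynamo backend available

    def fn(a, b):
      # Plain eager PyTorch; no explicit XLA calls are needed.
      return a + b * 2.0

    # Inputs live on the CUDA device. With these patches the dynamo bridge moves
    # them to the XLA device for tracing/execution and moves the result back, so
    # the output device matches the input device.
    x = torch.tensor(100.0, device="cuda")
    y = torch.tensor(200.0, device="cuda")

    compiled = torch.compile(fn, backend="openxla")
    out = compiled(x, y)

    assert out.device == x.device  # result is returned on the original CUDA device
    print(out)  # tensor(500., device='cuda:0')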