Commit

Automatically move CUDA non XLA Tensors to XLA Device and back to CUDA device (pytorch#6644)
changm authored and yitongh committed Dec 11, 2024
1 parent 5140bcd commit d45968c
Showing 2 changed files with 110 additions and 3 deletions.
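For orientation, below is a minimal usage sketch (not part of the diff) of the behavior this commit enables, mirroring the test added further down: CUDA tensors passed to a function compiled with the "openxla" backend are moved to the XLA device for execution, and the result is returned on the original CUDA device. It assumes CUDA-enabled builds of PyTorch and torch_xla; `simple_fn` is a hypothetical stand-in for the test's `fn_simple`.

```python
# Minimal usage sketch; assumes CUDA-enabled PyTorch and torch_xla builds.
import torch
import torch_xla  # ensures the XLA ("openxla") dynamo backend is available


def simple_fn(a, b):
  # Hypothetical stand-in for the test's self.fn_simple.
  return a + b


if torch.cuda.is_available():
  x = torch.tensor(100.0, device="cuda")
  y = torch.tensor(200.0, device="cuda")

  compiled = torch.compile(simple_fn, backend="openxla")
  out = compiled(x, y)  # inputs are auto-moved to the XLA device, output moved back

  assert out.device == x.device                # result lands back on the CUDA device
  assert torch.allclose(out, simple_fn(x, y))  # matches eager CUDA execution
```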
36 changes: 34 additions & 2 deletions test/dynamo/test_dynamo.py
@@ -83,7 +83,7 @@ def test_simple_model(self):
res_xla_dynamo = fn_simple_dynamo(xla_x, xla_y)
self.assertIn('xla::add', met.counter_names())
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
# verifiy that tracing is skipped in following runs
# verify that tracing is skipped in following runs
met.clear_counters()
res_xla_dynamo_2 = fn_simple_dynamo(xla_x, xla_y)
self.assertNotIn('xla::add', met.counter_names())
@@ -103,6 +103,38 @@ def test_simple_model(self):
self.assertNotIn('XLAData: None',
torch_xla._XLAC._get_xla_tensor_debug_info(xla_y3))

# Tests that the dynamo bridge automatically moves tensors to XLA device,
# then back to the original device.
@unittest.skipIf(xr.device_type() != "CUDA",
f"GPU tests should only run on GPU devices.")
def test_simple_model_automoves_tensors(self):
x = torch.tensor(100.0).to(device="cuda")
y = torch.tensor(200.0).to(device="cuda")
original_device = x.device
eager_result = self.fn_simple(x, y)

# Since all tests run in the same process, have to reset the metrics report.
met.clear_all()

fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla")
res_xla_dynamo = fn_simple_dynamo(x, y)
self.assertIn('xla::add', met.counter_names())
self.assertTrue(res_xla_dynamo.device == original_device)
self.assertTrue(torch.allclose(eager_result, res_xla_dynamo))

# verify that tracing is skipped in following runs
met.clear_counters()
res_xla_dynamo_reused = fn_simple_dynamo(x, y)
self.assertNotIn('xla::add', met.counter_names())
self.assertTrue(res_xla_dynamo_reused.device == original_device)
self.assertTrue(torch.allclose(eager_result, res_xla_dynamo_reused))

# verify that dynamo can handle different inputs
res_xla_dynamo_different = fn_simple_dynamo(x + y, y * 3)
res_cpu_3 = self.fn_simple(x + y, y * 3)
self.assertTrue(res_xla_dynamo_different.device == original_device)
self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_different))

def test_fn_without_input(self):

def fn_without_input(device):
@@ -526,7 +558,7 @@ def test_resnet18(self):
met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3)


class DynamErrorMessageTest(unittest.TestCase):
class DynamoErrorMessageTest(unittest.TestCase):

def test_mixed_cpu_tensor(self):
device = xm.xla_device()
77 changes: 76 additions & 1 deletion torch_xla/core/dynamo_bridge.py
@@ -111,6 +111,67 @@ def get_fallback_ops():
return fallback_ops


# Returns the device shared by all tensor input args, asserting they are all on the same device.
def _get_input_arg_device(input_args: tuple) -> torch.device:
device = None
for arg in input_args:
if not isinstance(arg, torch.Tensor):
continue

if device is None:
device = arg.device
else:
assert arg.device == device, "Not all args are on the same device."

return device


# Returns True if all the input args are on a CUDA device.
def _args_on_cuda(input_args: tuple) -> bool:
input_device: torch.device = _get_input_arg_device(input_args)
if input_device is None:
return False

return input_device.type == "cuda"


# Given an input tuple, moves the tensors to the given target_device.
# The output order matches the input; non-tensor entries are passed through unchanged.
def _maybe_move_tensors_to_device(tensors: tuple,
target_device: torch.device) -> tuple:
if not torch.cuda.is_available():
return tensors

moved_tensors = []
cpu_device: torch.device = torch.device("cpu")

for tensor in tensors:
if not isinstance(tensor, torch.Tensor):
moved_tensors.append(tensor)
continue

if tensor.device == target_device:
moved_tensors.append(tensor)
continue

assert target_device is not None, "Moving tensors to None device not supported"

if dynamo_debug:
print("Moving Tensor {} to device {}".format(tensor, target_device))

# Have to move to CPU before moving it to target device.
moved_tensor = tensor.to(cpu_device)
moved_tensor = moved_tensor.to(target_device)

# Explicitly have to copy requires_grad attribute because it's dropped
# with torch.to(..)
moved_tensor.requires_grad = tensor.requires_grad
moved_tensors.append(moved_tensor)

return tuple(moved_tensors)
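As a rough illustration (not part of the commit) of how this helper behaves, the sketch below moves a mixed tuple to the XLA device and back; it assumes a CUDA-capable PyTorch plus torch_xla install and imports the private helper directly from the module changed here.

```python
# Illustrative sketch only; assumes CUDA-capable PyTorch and torch_xla are installed.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.core.dynamo_bridge import _maybe_move_tensors_to_device

if torch.cuda.is_available():
  t = torch.tensor([1.0, 2.0], device="cuda")
  mixed = (t, 42, None)  # non-tensor entries pass through untouched

  moved = _maybe_move_tensors_to_device(mixed, xm.xla_device())
  assert str(moved[0].device).startswith("xla")  # tensor now lives on the XLA device
  assert moved[1:] == (42, None)                 # order and non-tensors preserved

  # Moving back restores the original CUDA placement.
  restored = _maybe_move_tensors_to_device(moved, t.device)
  assert restored[0].device == t.device
```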


class Deduper:

def __init__(self):
@@ -375,7 +436,7 @@ def extract_internal(xla_model: torch.fx.GraphModule):
skip_checking_input_sharding_threashold = xu.getenv_as(
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)

def optimized_mod(*args):
def optimized_mod(*args: tuple):
nonlocal xla_model
nonlocal xla_args_sharding_spec
nonlocal args_and_out
@@ -387,6 +448,14 @@ def optimized_mod(*args):
nonlocal xla_args_need_update
nonlocal skip_checking_input_sharding_threashold

original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = False
if original_device:
is_cuda_args = original_device.type == "cuda"

if is_cuda_args:
args = _maybe_move_tensors_to_device(args, xm.xla_device())

# mark_step needs to be blocking since we want to access args's XLADatas
# and they can't be placeholder.
input_tensors_to_sync = [
@@ -437,6 +506,9 @@ def optimized_mod(*args):
result = res[len(xla_args_need_update):]

none_remover.add_nones(result)
if is_cuda_args:
result = _maybe_move_tensors_to_device(tuple(result), original_device)

if len(result) == 1:
return result[0]
else:
@@ -532,6 +604,9 @@ def allow_cpu_device(self, node: torch.fx.Node):


def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
if _args_on_cuda(xla_args):
xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device()))

# Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
# value reference before actually computing it.
for a in xla_args:
