Automatically move CUDA non XLA Tensors to XLA Device and back to CUDA device #6644

Merged · 3 commits · Mar 13, 2024
Changes from 1 commit
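For context, a minimal usage sketch of the behavior this PR enables, modeled on the new test added below. It assumes a CUDA-capable torch_xla build (for example with PJRT_DEVICE=CUDA set), which is an environment assumption rather than part of this diff:

# Sketch: pass CUDA tensors straight into a torch.compile'd function using the
# openxla backend; the dynamo bridge moves the inputs to the XLA device and
# moves the result back to the original CUDA device.
import torch

def add(a, b):
  return a + b

compiled_add = torch.compile(add, backend="openxla")

x = torch.tensor(100.0, device="cuda")
y = torch.tensor(200.0, device="cuda")

out = compiled_add(x, y)
assert out.device == x.device  # the result comes back on the CUDA device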
33 changes: 31 additions & 2 deletions test/dynamo/test_dynamo.py
@@ -83,7 +83,7 @@ def test_simple_model(self):
res_xla_dynamo = fn_simple_dynamo(xla_x, xla_y)
self.assertIn('xla::add', met.counter_names())
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
# verifiy that tracing is skipped in following runs
# verify that tracing is skipped in following runs
met.clear_counters()
res_xla_dynamo_2 = fn_simple_dynamo(xla_x, xla_y)
self.assertNotIn('xla::add', met.counter_names())
@@ -103,6 +103,35 @@ def test_simple_model(self):
self.assertNotIn('XLAData: None',
torch_xla._XLAC._get_xla_tensor_debug_info(xla_y3))

# Tests that the dynamo bridge automatically moves tensors to XLA device,
# then back to the original device.
@unittest.skipIf(xr.device_type() != "CUDA",
f"GPU tests should only run on GPU devices.")
def test_simple_model_automoves_tensors(self):
x = torch.tensor(100.0).to(device="cuda")
y = torch.tensor(200.0).to(device="cuda")
original_device = x.device
eager_result = self.fn_simple(x, y)

fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla")
res_xla_dynamo = fn_simple_dynamo(x, y)
self.assertIn('xla::add', met.counter_names())
self.assertTrue(res_xla_dynamo.device == original_device)
self.assertTrue(torch.allclose(eager_result, res_xla_dynamo))

# verify that tracing is skipped in following runs
met.clear_counters()
res_xla_dynamo_reused = fn_simple_dynamo(x, y)
self.assertNotIn('xla::add', met.counter_names())
self.assertTrue(res_xla_dynamo_reused.device == original_device)
self.assertTrue(torch.allclose(eager_result, res_xla_dynamo_reused))

# verify that dynamo can handle different inputs
res_xla_dynamo_different = fn_simple_dynamo(x + y, y * 3)
res_cpu_3 = self.fn_simple(x + y, y * 3)
self.assertTrue(res_xla_dynamo_different.device == original_device)
self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_different))

def test_fn_without_input(self):

def fn_without_input(device):
@@ -526,7 +555,7 @@ def test_resnet18(self):
met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3)


class DynamErrorMessageTest(unittest.TestCase):
class DynamoErrorMessageTest(unittest.TestCase):

def test_mixed_cpu_tensor(self):
device = xm.xla_device()
73 changes: 72 additions & 1 deletion torch_xla/core/dynamo_bridge.py
@@ -111,6 +111,65 @@ def get_fallback_ops():
return fallback_ops


# Checks that all tensor input args are on the same device and returns that
# device, or None if there are no tensor args.
def _get_input_arg_device(input_args: tuple) -> torch.device:
device = None
for arg in input_args:
if not isinstance(arg, torch.Tensor):
continue

if device is None:
device = arg.device
else:
assert arg.device == device, "Not all args are on the same device."

return device


# Returns True if all the input args are on a CUDA device.
def _args_on_cuda(input_args: tuple) -> bool:
input_device: torch.device = _get_input_arg_device(input_args)
if input_device is None:
return False

return input_device.type == "cuda"


# Moves any tensors in the input list to the given target_device. The output
# order matches the input order; non-tensor elements are passed through
# unchanged.
def _maybe_move_tensors_to_device(tensors: tuple,
target_device: torch.device) -> tuple:
if not torch.cuda.is_available():
return tensors

moved_tensors = []
cpu_device: torch.device = torch.device("cpu")

for tensor in tensors:
if not isinstance(tensor, torch.Tensor):
moved_tensors.append(tensor)
continue

if tensor.device == target_device:
moved_tensors.append(tensor)
continue

if dynamo_debug:
print("Moving Tensor {} to device {}".format(tensor, target_device))

# Have to move to CPU before moving it to the target device.
moved_tensor = tensor.to(cpu_device)
moved_tensor = moved_tensor.to(target_device)

# Explicitly copy the requires_grad attribute because it is dropped by
# Tensor.to(..).
moved_tensor.requires_grad = tensor.requires_grad
moved_tensors.append(moved_tensor)

return tuple(moved_tensors)


class Deduper:

def __init__(self):
@@ -375,7 +434,7 @@ def extract_internal(xla_model: torch.fx.GraphModule):
skip_checking_input_sharding_threashold = xu.getenv_as(
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)

def optimized_mod(*args):
def optimized_mod(*args: tuple):
nonlocal xla_model
nonlocal xla_args_sharding_spec
nonlocal args_and_out
@@ -387,6 +446,12 @@ def optimized_mod(*args):
nonlocal xla_args_need_update
nonlocal skip_checking_input_sharding_threashold

original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = _args_on_cuda(args)
Collaborator:
_args_on_cuda will call _get_input_arg_device, which is redundant.

Collaborator (Author):
I still think it's a little cleaner to do the redundant call, but removed the call here.
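For illustration, a sketch (not the code in this commit) of how the redundant lookup could be avoided by deriving the CUDA check from the device computed above:

# Sketch only: reuse original_device instead of calling _args_on_cuda,
# which would re-run _get_input_arg_device over the same args.
original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = original_device is not None and original_device.type == "cuda"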


if is_cuda_args:
args = _maybe_move_tensors_to_device(args, xm.xla_device())

# mark_step needs to be blocking since we want to access args' XLAData values
# and they can't be placeholders.
input_tensors_to_sync = [
@@ -437,6 +502,9 @@ def optimized_mod(*args):
result = res[len(xla_args_need_update):]

none_remover.add_nones(result)
if is_cuda_args:
result = _maybe_move_tensors_to_device(tuple(result), original_device)

if len(result) == 1:
return result[0]
else:
@@ -532,6 +600,9 @@ def allow_cpu_device(self, node: torch.fx.Node):


def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
if _args_on_cuda(xla_args):
xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device()))

# Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
# value reference before actually computing it.
for a in xla_args: