Fixup test dynamo failures
changm committed Feb 29, 2024
1 parent ef25d39 commit 476edaf
Showing 2 changed files with 19 additions and 19 deletions.
36 changes: 18 additions & 18 deletions test/dynamo/test_dynamo.py
@@ -119,6 +119,23 @@ def test_simple_model_automoves_tensors(self):
self.assertTrue(res_cpu_3.device == res_xla_dynamo_different.device)
self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_different))

def test_resnet_all_cpu_tensor_moved_to_xla(self):
met.clear_all()
input = torch.randn(4, 3, 224, 224)
resnet18 = torchvision.models.resnet18()
resnet18.eval()
dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla')

# input and model weight on cpu
with warnings.catch_warnings(record=True) as w:
res = dynamo_resnet18_cpu(input)
# there should be 18 parameters + 1 input, all moved to the XLA device.
self.assertTrue(len(w) == 0)

# Ops should work automatically and XLA should "just work"
self.assertTrue(len(met.counter_names()) > 1)
self.assertIn('MarkStep', met.counter_names())

def test_fn_without_input(self):

def fn_without_input(device):
@@ -542,7 +559,7 @@ def test_resnet18(self):
met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3)


class DynamErrorMessageTest(unittest.TestCase):
class DynamoErrorMessageTest(unittest.TestCase):

def test_mixed_cpu_tensor(self):
device = xm.xla_device()
@@ -566,23 +583,6 @@ def test_mixed_cpu_tensor(self):
self.assertTrue(
'found two different devices' in context.exception.__str__())

def test_all_cpu_tensor(self):
met.clear_all()
input = torch.randn(4, 3, 224, 224)
resnet18 = torchvision.models.resnet18()
resnet18.eval()
dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla')
# input and model weight on cpu
with warnings.catch_warnings(record=True) as w:
res = dynamo_resnet18_cpu(input)
# there should be 18 paramters + 1 input
self.assertGreater(len(w), 15)
self.assertIn('Found tensor with shape torch.Size', str(w[0].message))
# no XLA operation should happens except a empty mark_step. Partitioner should offload all CPU
# ops to CPU.
self.assertEqual(len(met.counter_names()), 1)
self.assertIn('MarkStep', met.counter_names())


if __name__ == '__main__':
test = unittest.main()
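For context, the behavior that the new test_resnet_all_cpu_tensor_moved_to_xla test locks in can be exercised outside the test suite. The following is a minimal sketch, assuming torch, torchvision, and torch_xla with the 'openxla' Dynamo backend are installed; variable names outside the diff are illustrative.

import warnings

import torch
import torchvision
import torch_xla.debug.metrics as met

# Model and input both start on CPU; no explicit move to the XLA device.
model = torchvision.models.resnet18().eval()
inp = torch.randn(4, 3, 224, 224)

compiled = torch.compile(model, backend='openxla')

with warnings.catch_warnings(record=True) as w:
  out = compiled(inp)

# With this commit, no "Found tensor with shape ..." warnings are expected,
# and XLA counters such as 'MarkStep' should be populated, since CPU tensors
# are moved to the XLA device rather than offloaded back to CPU.
print(len(w), 'MarkStep' in met.counter_names())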
2 changes: 1 addition & 1 deletion torch_xla/core/dynamo_bridge.py
@@ -613,7 +613,7 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, input_args):

for xla_arg in xla_args:
assert xla_arg.device.type == 'xla', "Found tensor with shape " + str(
xla_arg.size()) + " on " + str(xla_arg.device)
xla_arg.size()) + " on non-XLA device: " + str(xla_arg.device)

cloned_args = [
torch.clone(xla_arg) if isinstance(xla_arg, torch.Tensor) else xla_arg
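A quick sketch of the message the updated assert produces when a CPU tensor reaches it (illustrative only, reusing the exact string construction from the diff, not a captured failure):

import torch

t = torch.randn(2, 2)  # plain CPU tensor
assert t.device.type == 'xla', "Found tensor with shape " + str(
    t.size()) + " on non-XLA device: " + str(t.device)
# AssertionError: Found tensor with shape torch.Size([2, 2]) on non-XLA device: cpu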
