
Commit

Update torch.compile usage in test (#5584)
JackCaoG authored Sep 18, 2023
1 parent 786722c commit 83692be
Showing 1 changed file with 14 additions and 22 deletions.
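Note: the change is mechanical. Previously each test class defined a helper decorated with `@torch.compile(backend='openxla')` that forwarded to the real function; after this commit the tests build the compiled callable inline with `torch.compile(self.fn_simple, backend="openxla")`. A minimal sketch of the two styles outside the test harness (the body of `fn_simple` is abbreviated here, and a working torch_xla install with an XLA device is assumed):

import torch
import torch_xla.core.xla_model as xm  # standard torch_xla import used by this test file

def fn_simple(x, y):
    # stand-in body; only part of the real fn_simple is visible in the hunk below
    a = torch.cos(x)
    b = torch.sin(y)
    return a + b

# Old style: a dedicated wrapper decorated up front.
@torch.compile(backend='openxla')
def fn_simple_dynamo_old(x, y):
    return fn_simple(x, y)

# New style: compile the existing function inline where it is needed.
fn_simple_dynamo = torch.compile(fn_simple, backend="openxla")

device = xm.xla_device()
x = torch.tensor(100.0).to(device)
y = torch.tensor(200.0).to(device)
print(fn_simple_dynamo(x, y))  # same result as the decorated wrapper; only the wrapping differs
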
36 changes: 14 additions & 22 deletions test/dynamo/test_dynamo.py
@@ -72,27 +72,24 @@ def fn_simple(self, x, y):
b = torch.sin(y)
return a + b

@torch.compile(backend='openxla')
def fn_simple_dynamo(self, x, y):
return self.fn_simple(x, y)

def test_simple_model(self):
device = xm.xla_device()
x = torch.tensor(100.0)
y = torch.tensor(200.0)
xla_x = x.to(device)
xla_y = y.to(device)
res_cpu = self.fn_simple(x, y)
res_xla_dynamo = self.fn_simple_dynamo(xla_x, xla_y)
fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla")
res_xla_dynamo = fn_simple_dynamo(xla_x, xla_y)
self.assertIn('xla::add', met.counter_names())
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
# verify that tracing is skipped in the following runs
met.clear_counters()
res_xla_dynamo_2 = self.fn_simple_dynamo(xla_x, xla_y)
res_xla_dynamo_2 = fn_simple_dynamo(xla_x, xla_y)
self.assertNotIn('xla::add', met.counter_names())
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo_2.cpu()))
# verify that dynamo can handle different inputs
res_xla_dynamo_3 = self.fn_simple_dynamo(xla_x + xla_y, xla_y * 3)
res_xla_dynamo_3 = fn_simple_dynamo(xla_x + xla_y, xla_y * 3)
res_cpu_3 = self.fn_simple(x + y, y * 3)
self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu()))

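Note: the assertions above (and the `CompileTime` checks in the next hunk) rely on torch_xla's debug metrics: op counters such as `xla::add` reveal whether a call was re-traced, and `CompileTime` counts XLA compilations. A minimal standalone sketch of that idiom, assuming the same `torch_xla.debug.metrics` import the test file uses:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
compiled = torch.compile(lambda x, y: torch.cos(x) + torch.sin(y), backend="openxla")

x = torch.tensor(100.0).to(device)
y = torch.tensor(200.0).to(device)

compiled(x, y)                                   # first call: traces and compiles
assert 'xla::add' in met.counter_names()
compiles = met.metric_data('CompileTime')[0]

met.clear_counters()
compiled(x, y)                                   # same shapes: cached graph, no re-trace
assert 'xla::add' not in met.counter_names()
assert met.metric_data('CompileTime')[0] == compiles
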
@@ -178,14 +175,15 @@ def test_simple_model_with_different_input_shape(self):
xla_x = torch.randn(5, 5).to(device)
xla_y = torch.randn(5, 5).to(device)
xla_z = torch.randn(10, 10).to(device)
self.fn_simple_dynamo(xla_x, xla_x)
fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla")
fn_simple_dynamo(xla_x, xla_x)
compile_count = met.metric_data('CompileTime')[0]
# Executing with an input of the same shape should not trigger additional compilation
self.fn_simple_dynamo(xla_y, xla_y)
fn_simple_dynamo(xla_y, xla_y)
self.assertEqual(met.metric_data('CompileTime')[0], compile_count)
# Give `fn_simple_dynamo` an input with a different shape; we expect
# dynamo to recognize this as a different graph and let XLA retrace/recompile
res_xla_dynamo_3 = self.fn_simple_dynamo(xla_z, xla_z)
res_xla_dynamo_3 = fn_simple_dynamo(xla_z, xla_z)
self.assertEqual(met.metric_data('CompileTime')[0], compile_count + 1)
self.assertTrue(
torch.allclose(
@@ -319,10 +317,6 @@ def fn_simple(self, input):
loss.backward()
return loss

@torch.compile(backend='openxla')
def fn_simple_dynamo(self, input):
return self.fn_simple(input)

def train_model(self, model, data, target):
loss_fn = torch.nn.CrossEntropyLoss()
pred = model(data)
@@ -337,7 +331,8 @@ def test_simple_model(self):
xla_input = input.detach().to(device)
xla_input.requires_grad = True
res_cpu = self.fn_simple(input)
res_xla_dynamo = self.fn_simple_dynamo(xla_input)
fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla")
res_xla_dynamo = fn_simple_dynamo(xla_input)
self.assertIn('xla::nll_loss_backward', met.counter_names())
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
self.assertTrue(
@@ -346,7 +341,7 @@ def test_simple_model(self):
# verify that tracing is skipped in the following runs
xla_input.grad = None
met.clear_counters()
res_xla_dynamo_2 = self.fn_simple_dynamo(xla_input)
res_xla_dynamo_2 = fn_simple_dynamo(xla_input)
self.assertNotIn('xla::nll_loss_backward', met.counter_names())
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo_2.cpu()))
self.assertTrue(
@@ -355,7 +350,7 @@ def test_simple_model(self):
# verify that dynamo can handle different inputs
input.grad = None
xla_input.grad = None
res_xla_dynamo_3 = self.fn_simple_dynamo(xla_input * 2)
res_xla_dynamo_3 = fn_simple_dynamo(xla_input * 2)
res_cpu_3 = self.fn_simple(input * 2)
self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu()))
self.assertTrue(
@@ -429,10 +424,6 @@ def fn_simple(self, input, optimizer):
optimizer.step()
return loss

@torch.compile(backend='openxla')
def fn_simple_dynamo(self, input, optimizer):
return self.fn_simple(input, optimizer)

def train_model(self, model, data, target, optimizer):
loss_fn = torch.nn.CrossEntropyLoss()
optimizer.zero_grad(True)
@@ -457,7 +448,8 @@ def test_simple_model(self):
# fwd + bwd is not being captured, hence we will get one lazy graph
# + one dynamo optimizer graph
res_cpu = self.fn_simple(input, optimizer)
res_xla_dynamo = self.fn_simple_dynamo(xla_input, xla_optimizer)
fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla")
res_xla_dynamo = fn_simple_dynamo(xla_input, xla_optimizer)
assert torch.allclose(res_cpu, res_xla_dynamo.cpu())
assert torch.allclose(
input.grad, xla_input.grad.cpu(), rtol=1e-04, atol=1e-04)
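Note: the optimizer test above applies the same inline pattern to a full training step; as the comment in the hunk says, forward + backward plus the optimizer step are not captured as one graph, so one lazy graph and one Dynamo optimizer graph are expected. A rough sketch of that usage, with a hypothetical model and optimizer standing in for the test's helpers, which are only partially visible in this diff:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(10, 2).to(device)            # hypothetical stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

def train_step(data, target):
    optimizer.zero_grad(True)
    pred = model(data)
    loss = loss_fn(pred, target)
    loss.backward()
    optimizer.step()
    return loss

# Compile the step inline, mirroring the updated tests.
compiled_step = torch.compile(train_step, backend="openxla")

data = torch.randn(8, 10).to(device)
target = torch.randint(0, 2, (8,)).to(device)
loss = compiled_step(data, target)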

