add test code for xlacomputation
ManfeiBai committed May 8, 2024
1 parent 531d5c1 commit 43527e6
Showing 1 changed file with 132 additions and 2 deletions.
134 changes: 132 additions & 2 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
@@ -154,7 +154,7 @@ def test_while_loop_get_xlacomputation(self):
# print("print computation from _get_stablehlo: !!!!!!!!!")
# print(hlo_print)

- def test_while_loop_get_xlacomputation(self):
+ def test_while_loop_get_xlacomputation_directly(self):

xm.mark_step()
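# mark_step() above flushes any pending lazy ops so the trace below starts from a clean graph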
device = xm.xla_device()
@@ -173,7 +173,137 @@ def test_while_loop_get_xlacomputation(self):
else:
print("print computation from _get_xla_computation: null !!!!!!!!!!!!!")

def test_while_loop_get_xlacomputation_tpu_simple_linear_without_while_loop(self):

xm.mark_step()
device = xm.xla_device()
#device = ''
torch.set_grad_enabled(False)
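# grads are disabled so only the forward graph is traced into the computation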

class SimpleWithLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
self.register_buffer("dec", torch.tensor(1))

def forward(self, x):
x = self.linear(x)
return x
# def cond_fn(it, x):
# return it - self.dec > 0

# def body_fn(it, x):
# return it - 1, self.linear(x)

# return while_loop(cond_fn, body_fn, (iter, x))

simple_with_linear = SimpleWithLinear()
simple_with_linear.to(device)
#breakpoint()
input = torch.randn(2, 2).to(device)
# iter = torch.tensor(3, device=device)
# res = simple_with_linear(iter, input)
t3 = simple_with_linear(input)
# t1 = torch.randn(20, 5).to(device)
# t2 = torch.randn(20, 5).to(device)
# t3 = torch.add(t1, t2)

### implement one new function for xlacomputation generation with post-order
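# _get_xla_computation is the new C++ binding under test here; as exercised
# below it takes the output tensors to trace ([t3]) and returns the lowered
# XLA computation (the remaining two arguments follow the binding added
# alongside this test)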
print("before run _get_xla_computation: !!!!!!!!!")
res_xla_computation = torch_xla._XLAC._get_xla_computation([t3], [], True)
print("after run _get_xla_computation: !!!!!!!!!")
if res_xla_computation:
hlo_print = xb.get_computation_hlo(res_xla_computation)
print("print computation from _get_xla_computation: !!!!!!!!!")
print(hlo_print)
else:
print("print computation from _get_xla_computation: null !!!!!!!!!!!!!")

def test_while_loop_get_xlacomputation_tpu_simple_linear_while_loop(self):

xm.mark_step()
device = xm.xla_device()
#device = ''
torch.set_grad_enabled(False)

class SimpleWithLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
self.register_buffer("dec", torch.tensor(1))

def forward(self, x):
x = self.linear(x)
return x
# def cond_fn(it, x):
# return it - self.dec > 0

# def body_fn(it, x):
# return it - 1, self.linear(x)

# return while_loop(cond_fn, body_fn, (iter, x))

simple_with_linear = SimpleWithLinear()
simple_with_linear.to(device)
#breakpoint()
input = torch.randn(2, 2).to(device)
# iter = torch.tensor(3, device=device)
# res = simple_with_linear(iter, input)
t3 = simple_with_linear(input)
# t1 = torch.randn(20, 5).to(device)
# t2 = torch.randn(20, 5).to(device)
# t3 = torch.add(t1, t2)

def cond_fn(upper, lower, one_value, x, input_value, output_value, *args):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value, output_value, *args):
new_lower = torch.add(one_value, lower)
output_value = simple_with_linear(input_value)
res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()]
return tuple(res)
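# note: cond_fn/body_fn above are defined for the eventual while_loop path,
# but this test never calls while_loop; only t3 is lowered below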
# bn_list = []
# for name, param in simple_with_linear.named_parameters():
# if name[:2]=='bn':
# bn_list.append(param)

# res.insert(-1, param)

# # append params still left in bn_list if the last additional_inputs is bn - pre
# # add at the tail
# if len(bn_list) !=0:
# output_value = res[-1]
# bn_list.reverse()
# res = res[:-1] + bn_list
# res.append(output_value)
# bn_list = []
# return tuple(res)

### implement one new function for xlacomputation generation with post-order
print("before run _get_xla_computation: !!!!!!!!!")
res_xla_computation = torch_xla._XLAC._get_xla_computation([t3], [], True)
print("after run _get_xla_computation: !!!!!!!!!")
if res_xla_computation:
hlo_print = xb.get_computation_hlo(res_xla_computation)
print("print computation from _get_xla_computation: !!!!!!!!!")
print(hlo_print)
else:
print("print computation from _get_xla_computation: null !!!!!!!!!!!!!")

### get xlacomputation via PyLoweringContext
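# second route to the same HLO: build a lowering context, register the outputs,
# export the module proto, and wrap it back into a computation so both paths
# can be compared via their printed HLO text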
# body_result = body_fn(*fake_carried_inputs)
body_ctx = torch_xla._XLAC.lowering.LoweringContext()
body_ctx.set_name_string("bodyctx")
# body_ctx.buildforiloop(list(body_result), [])
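# list(t3) splits the 2x2 result into its row tensors, which are registered
# as the outputs to lower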
body_ctx.buildforiloop(list(t3), [])
body_hlo = body_ctx.hlo()
body_computation = xb.computation_from_module_proto("bodycomputation",
body_hlo)
body_hlo_print = xb.get_computation_hlo(body_computation)
print("print computation from PyLoweringContext: !!!!!!!!!")
print(body_hlo_print)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
