diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 340b8f3cee4..46b748e2285 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -154,7 +154,7 @@ def test_while_loop_get_xlacomputation(self): # print("print computation from _get_stablehlo: !!!!!!!!!") # print(hlo_print) - def test_while_loop_get_xlacomputation(self): + def test_while_loop_get_xlacomputation_directly(self): xm.mark_step() device = xm.xla_device() @@ -173,7 +173,137 @@ def test_while_loop_get_xlacomputation(self): else: print("print computation from _get_xla_computation: null !!!!!!!!!!!!!") + def test_while_loop_get_xlacomputation_tpu_simple_linear_without_while_loop(self): + + xm.mark_step() + device = xm.xla_device() + #device = '' + torch.set_grad_enabled(False) + + class SimpleWithLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + self.register_buffer("dec", torch.tensor(1)) + + def forward(self, x): + x = self.linear(x) + return x + # def cond_fn(it, x): + # return it - self.dec > 0 + + # def body_fn(it, x): + # return it - 1, self.linear(x) + + # return while_loop(cond_fn, body_fn, (iter, x)) + + simple_with_linear = SimpleWithLinear() + simple_with_linear.to(device) + #breakpoint() + input = torch.randn(2, 2).to(device) + # iter = torch.tensor(3, device=device) + # res = simple_with_linear(iter, input) + t3 = simple_with_linear(input) + # t1 = torch.randn(20, 5).to(device) + # t2 = torch.randn(20, 5).to(device) + # t3 = torch.add(t1, t2) + + ### implement one new function for xlacomputation generation with post-order + print("before run _get_xla_computation: !!!!!!!!!") + res_xla_computation = torch_xla._XLAC._get_xla_computation([t3], [], True) + print("after run _get_xla_computation: !!!!!!!!!") + if res_xla_computation: + hlo_print = xb.get_computation_hlo(res_xla_computation) + print("print computation from _get_xla_computation: !!!!!!!!!") + print(hlo_print) + else: + print("print computation from _get_xla_computation: null !!!!!!!!!!!!!") + + def test_while_loop_get_xlacomputation_tpu_simple_linear_while_loop(self): + + xm.mark_step() + device = xm.xla_device() + #device = '' + torch.set_grad_enabled(False) + + class SimpleWithLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + self.register_buffer("dec", torch.tensor(1)) + + def forward(self, x): + x = self.linear(x) + return x + # def cond_fn(it, x): + # return it - self.dec > 0 + + # def body_fn(it, x): + # return it - 1, self.linear(x) + + # return while_loop(cond_fn, body_fn, (iter, x)) + + simple_with_linear = SimpleWithLinear() + simple_with_linear.to(device) + #breakpoint() + input = torch.randn(2, 2).to(device) + # iter = torch.tensor(3, device=device) + # res = simple_with_linear(iter, input) + t3 = simple_with_linear(input) + # t1 = torch.randn(20, 5).to(device) + # t2 = torch.randn(20, 5).to(device) + # t3 = torch.add(t1, t2) + + def cond_fn(upper, lower, one_value, x, input_value, output_value, *args): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value, *args): + new_lower = torch.add(one_value, lower) + output_value = simple_with_linear(input_value) + res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()] + return tuple(res) + # bn_list = [] + # for name, param in simple_with_linear.named_parameters(): + # if name[:2]=='bn': + # bn_list.append(param) + + # res.insert(-1, param) + + # # add still exist bn_list if the last additional_inputs is bn- pre + # # add at the tile + # if len(bn_list) !=0: + # output_value = res[-1] + # bn_list.reverse() + # res = res[:-1] + bn_list + # res.append(output_value) + # bn_list = [] + # return tuple(res) + + ### implement one new function for xlacomputation generation with post-order + print("before run _get_xla_computation: !!!!!!!!!") + res_xla_computation = torch_xla._XLAC._get_xla_computation([t3], [], True) + print("after run _get_xla_computation: !!!!!!!!!") + if res_xla_computation: + hlo_print = xb.get_computation_hlo(res_xla_computation) + print("print computation from _get_xla_computation: !!!!!!!!!") + print(hlo_print) + else: + print("print computation from _get_xla_computation: null !!!!!!!!!!!!!") + + ### get xlacomputation via PyLoweringContext + # body_result = body_fn(*fake_carried_inputs) + body_ctx = torch_xla._XLAC.lowering.LoweringContext() + body_ctx.set_name_string("bodyctx") + # body_ctx.buildforiloop(list(body_result), []) + body_ctx.buildforiloop(list(t3), []) + body_hlo = body_ctx.hlo() + body_computation = xb.computation_from_module_proto("bodycomputation", + body_hlo) + body_hlo_print = xb.get_computation_hlo(body_computation) + print("print computation from PyLoweringContext: !!!!!!!!!") + print(body_hlo_print) + if __name__ == '__main__': test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) \ No newline at end of file + sys.exit(0 if test.result.wasSuccessful() else 1)