From 5b8fcb7fc93702705d8b6104e002242cca86aee9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:03:56 +0000 Subject: [PATCH] down into cpp --- torch_xla/experimental/fori_loop.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 265f42bf7ad5..6d406fdecaf7 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -171,9 +171,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - cond_hlo_print = xb.get_computation_hlo(cond_computation) - print("cond computation: !!!!!!!!!") - print(cond_hlo_print) + # cond_hlo_print = xb.get_computation_hlo(cond_computation) + # print("cond computation: !!!!!!!!!") + # print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) @@ -192,9 +192,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - body_hlo_print = xb.get_computation_hlo(body_computation) - print("body computation: !!!!!!!!!") - print(body_hlo_print) + # body_hlo_print = xb.get_computation_hlo(body_computation) + # print("body computation: !!!!!!!!!") + # print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs @@ -223,9 +223,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) - hlo_print = xb.get_computation_hlo(computation) - print("while computation: !!!!!!!!!") - print(hlo_print) + # hlo_print = xb.get_computation_hlo(computation) + # print("while computation: !!!!!!!!!") + # print(hlo_print) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',