diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 8ed3a7832009..7c033ad93b98 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -119,6 +119,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) + body_hlo_print = xb.get_computation_hlo(body_computation) + print("body computation: !!!!!!!!!") + print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs