diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 1c151ee9ec1..f1cad90ff27 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -84,6 +84,9 @@ def _xla_while_loop(cond_fn, body_fn, operands): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) + cond_hlo_print = xb.get_computation_hlo(cond_computation) + print("cond computation: !!!!!!!!!") + print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*operands)