diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index bdfc1e5c4e4..38e58da8f68 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -34,7 +34,7 @@ def _xla_while_loop(cond_fn, body_fn, operands): cond_result = cond_fn(operands) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") - cond_ctx.build_for_while(list(cond_result)) + cond_ctx.build(list(cond_result)) cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) @@ -43,7 +43,7 @@ def _xla_while_loop(cond_fn, body_fn, operands): body_result = body_fn(operands) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - body_ctx.build_for_while(list(body_result)) + body_ctx.build(list(body_result)) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo)