diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 46c9b29e1a0..599dcb80c7f 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -39,7 +39,8 @@ def _xla_while_loop(cond_fn, body_fn, operands): cond_ctx.setnamestring("condctx") cond_ctx.build(list(cond_result)) cond_hlo = cond_ctx.hlo() - cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) + cond_computation = xb.computation_from_module_proto("condcomputation", + cond_hlo) # generate body_fn xlacomputation xm.mark_step() @@ -48,18 +49,22 @@ def _xla_while_loop(cond_fn, body_fn, operands): body_ctx.setnamestring("bodyctx") body_ctx.build(list(body_result)) body_hlo = body_ctx.hlo() - body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) + body_computation = xb.computation_from_module_proto("bodycomputation", + body_hlo) # create xla:While op with cond_computation and body_computation input_tuple = xb.Op.tuple(params) aaa_tuple = xb.Op.get_tuple_element(input_tuple, 0) - w = xb.mkop('While', [aaa_tuple.op], condition_computation=cond_computation, body_computation=body_computation) + w = xb.mkop( + 'While', [aaa_tuple.op], + condition_computation=cond_computation, + body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) # operands would be changed from torch.tensor([1]) to torch.tensor(1) after torch.compile when call torch._higher_order_ops.while_loop, so create a new input tesor here localoperands = torch.tensor([1], dtype=torch.int32, device=xm.xla_device()) - result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', (localoperands,), - computation) + result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', + (localoperands,), computation) return result