diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 544105a6de4..ba2ddaa7294 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -191,7 +191,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # a, # generate cond_fn xlacomputation # TODO(@manfei): specify which element is for which argument like a,b,c # print("cond fake_carried_inputs[0]: ", fake_carried_inputs[0]) - cond_result = cond_fn(*fake_carried_inputs[:-1], output_value=fake_carried_inputs[-1]) + cond_result = cond_fn(*fake_carried_inputs) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list = list(fake_carried_inputs[2:]) @@ -207,7 +207,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # a, # generate body_fn xlacomputation # body_result = body_fn(*fake_carried_inputs) # , a=additional_inputs[0], b=additional_inputs[1], c=additional_inputs[2]) - body_result = body_fn(*fake_carried_inputs[:-1], output_value=fake_carried_inputs[-1]) + body_result = body_fn(*fake_carried_inputs) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") body_ctx.buildforiloop(list(body_result), [])