From 5a0ded86ecef329c729c30a5910ac7269bf59111 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 10 Apr 2024 22:24:13 +0000 Subject: [PATCH] test --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 544105a6de4..ba2ddaa7294 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -191,7 +191,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # a, # generate cond_fn xlacomputation # TODO(@manfei): specify which element is for which argument like a,b,c # print("cond fake_carried_inputs[0]: ", fake_carried_inputs[0]) - cond_result = cond_fn(*fake_carried_inputs[:-1], output_value=fake_carried_inputs[-1]) + cond_result = cond_fn(*fake_carried_inputs) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list = list(fake_carried_inputs[2:]) @@ -207,7 +207,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # a, # generate body_fn xlacomputation # body_result = body_fn(*fake_carried_inputs) # , a=additional_inputs[0], b=additional_inputs[1], c=additional_inputs[2]) - body_result = body_fn(*fake_carried_inputs[:-1], output_value=fake_carried_inputs[-1]) + body_result = body_fn(*fake_carried_inputs) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") body_ctx.buildforiloop(list(body_result), [])