Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Apr 10, 2024
1 parent bc2e035 commit 5a0ded8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # a,
# generate cond_fn xlacomputation
# TODO(@manfei): specify which element is for which argument like a,b,c
# print("cond fake_carried_inputs[0]: ", fake_carried_inputs[0])
cond_result = cond_fn(*fake_carried_inputs[:-1], output_value=fake_carried_inputs[-1])
cond_result = cond_fn(*fake_carried_inputs)
cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
cond_ctx.set_name_string("condctx")
additional_inputs_list = list(fake_carried_inputs[2:])
Expand All @@ -207,7 +207,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # a,

# generate body_fn xlacomputation
# body_result = body_fn(*fake_carried_inputs) # , a=additional_inputs[0], b=additional_inputs[1], c=additional_inputs[2])
body_result = body_fn(*fake_carried_inputs[:-1], output_value=fake_carried_inputs[-1])
body_result = body_fn(*fake_carried_inputs)
body_ctx = torch_xla._XLAC.lowering.LoweringContext()
body_ctx.set_name_string("bodyctx")
body_ctx.buildforiloop(list(body_result), [])
Expand Down

0 comments on commit 5a0ded8

Please sign in to comment.