diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 92288bd335c..82d2a3ba62e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -118,6 +118,8 @@ def new_body_fn(*carried_inputs): # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) print("res: ", res) + trynewres = res[:-1] + res[-1] + print("trynewres: ", trynewres) newres = res[:-1] + list(additional_inputs) + res[-1] print("newres: ", newres) res.insert(-2, *additional_inputs)