diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 0f01d170e87..32a093e55a4 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -81,13 +81,13 @@ def test_fori_loop_tpu_addition(self): upper = torch.tensor([30], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val_list = (init_val, one_value) - lowers = torch.tensor([[1], [1], [1]], dtype=torch.int32, device=device) # lower, init_val, one_value + # lowers = torch.tensor([[1], [1], [1]], dtype=torch.int32, device=device) # lower, init_val, one_value def body_fun(a, b): return torch.add(a, b) # [0]) # _, _, res, _ = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val) - # A, B, res, D = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val) - A, B, res, D = fori_loop(upper, body_fun, lowers) # lower, upper, body_fun, init_val, one_value) + A, B, res, D = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val) + # A, B, res, D = fori_loop(upper, body_fun, lowers) # lower, upper, body_fun, init_val, one_value) print("result: ", res) # init_val_ print("A: ", A) # lower_ print("B: ", B) # upper_ diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index f1cad90ff27..594e125577d 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -12,17 +12,32 @@ from torch._higher_order_ops.while_loop import while_loop_op -def fori_loop(upper, body_fun, lowers):# upper, body_fun, *init_vals): # *init_val): +# def fori_loop(upper, body_fun, lowers):# upper, body_fun, *init_vals): # *init_val): +def fori_loop(lower, upper, body_fun, init_val, one_value): device = xm.xla_device() # limit_value = upper # init = lower # iterator = lower + # one_value is actually not used here, but actually redefined in body_fn to avoid introduce new argument in body_xlacomputation + # lower == init_val + assert(lower == init_val) + init = lower # = init_val + limit_value = upper + # one_value_original = torch.tensor([1], dtype=torch.int32, device=device) # (a, b) = init_vals - def cond_fn(upper, lowers): # lower, *init_vals): + # def cond_fn(init, limit_value): + # return limit_value[0] >= init[0] + + # def body_fn(init, limit_value): + # one_value = torch.ones(1, dtype=torch.int32, device=device) + # return (torch.add(init, one_value), limit_value.clone()) + + # def cond_fn(upper, lowers): # lower, *init_vals): + def cond_fn(init, limit_value): # lower, *init_vals): # init_val_compy = init_val.clone() # one_value1 = torch.tensor([0], dtype=torch.int32, device=device) # one_value2 = torch.tensor([0], dtype=torch.int32, device=device) @@ -40,16 +55,20 @@ def cond_fn(upper, lowers): # lower, *init_vals): # bool_tensor = torch.tensor(bool_result, dtype=torch.bool) # return bool_tensor # (lower[0] <= upper[0]) and bool_tensor # return lower[0] <= upper[0] - return lowers[0] <= upper[0] + # return lowers[0] <= upper[0] + return limit_value[0] >= init[0] - def body_fn(upper, lowers): # , *init_vals): + # def body_fn(upper, lowers): # , *init_vals): + def body_fn(init, limit_value): # one_value_original = torch.tensor(1, dtype=torch.int32, device=device) # (a, b) = init_vals # return (upper, torch.add(lower, 1), body_fun(a, b), b.clone()) # return (upper.clone(), (torch.add(lower.clone(), init_vals[1].clone())).clone(), (body_fun(*init_vals)).clone(), init_vals[1].clone()) # init_vals[1:]) - return (upper, (torch.add(lowers[0], lowers[2]), body_fun(lowers[1], lowers[2]), lowers[2])) # init_vals[1:]) + # return (upper, (torch.add(lowers[0], lowers[2]), body_fun(lowers[1], lowers[2]), lowers[2])) # init_vals[1:]) # (body_fun(*init_vals)).clone(), init_vals[1].clone()) # body_fun(one_value_original, init_val)) # body_fun(lower, init_val)) + one_value = torch.ones(1, dtype=torch.int32, device=device) + return (body_fun(init, one_value), limit_value.clone()) # res = while_loop(cond_fn, body_fn, (upper, lower, *init_vals)) # lowers = (lower, *init_vals)