diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3db18f4bc0f..f80da7156bd 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -20,9 +20,9 @@ def fori_loop(lower, upper, body_fun, *init_vals): # *init_val): # iterator = lower # one_value_original = torch.tensor([1], dtype=torch.int32, device=device) - (a, b) = init_vals + # (a, b) = init_vals - def cond_fn(upper, lower, *init_vals): + def cond_fn(upper, lowers): # lower, *init_vals): # init_val_compy = init_val.clone() # one_value1 = torch.tensor([0], dtype=torch.int32, device=device) # one_value2 = torch.tensor([0], dtype=torch.int32, device=device) @@ -30,27 +30,30 @@ def cond_fn(upper, lower, *init_vals): # lower = torch.sub(lower, one_value2[0]) # assert isinstance(init_vals[0], torch.Tensor) # assert isinstance(init_vals[1], torch.Tensor) - bool_value = isinstance(init_vals[0], torch.Tensor) and isinstance(init_vals[1], torch.Tensor) + # bool_value = isinstance(init_vals[0], torch.Tensor) and isinstance(init_vals[1], torch.Tensor) # body_fun(*init_vals) # result = True # if (lower[0] <= upper[0]) and bool_value: # return True # return False - bool_result = ((lower[0] <= upper[0]) and bool_value) - bool_tensor = torch.tensor(bool_result, dtype=torch.bool) - return False # bool_tensor # (lower[0] <= upper[0]) and bool_tensor + # bool_result = ((lower[0] <= upper[0]) and bool_value) + # bool_tensor = torch.tensor(bool_result, dtype=torch.bool) + # return bool_tensor # (lower[0] <= upper[0]) and bool_tensor + # return lower[0] <= upper[0] + return lowers[0] <= upper[0] - def body_fn(upper, lower, *init_vals): + def body_fn(upper, lowers): # , *init_vals): # one_value_original = torch.tensor(1, dtype=torch.int32, device=device) - (a, b) = init_vals + # (a, b) = init_vals # return (upper, torch.add(lower, 1), body_fun(a, b), b.clone()) # return (upper.clone(), (torch.add(lower.clone(), init_vals[1].clone())).clone(), (body_fun(*init_vals)).clone(), init_vals[1].clone()) # init_vals[1:]) - return (upper, torch.add(lower, b), body_fun(a, b), b) # init_vals[1:]) + return (upper, (torch.add(lowers[0], lowers[2]), body_fun(lowers[1], lowers[2]), lowers[2])) # init_vals[1:]) # (body_fun(*init_vals)).clone(), init_vals[1].clone()) # body_fun(one_value_original, init_val)) # body_fun(lower, init_val)) # res = while_loop(cond_fn, body_fn, (upper, lower, *init_vals)) - res = _xla_while_loop(cond_fn, body_fn, (upper, lower, *init_vals)) + lowers = (lower, *init_vals) + res = _xla_while_loop(cond_fn, body_fn, (upper, lowers)) # , *init_vals)) # print("upper: ", upper) # print("lower: ", lower) return res