Skip to content

Commit

Permalink
test range
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Mar 21, 2024
1 parent 070728b commit bf91ab5
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,40 @@ def fori_loop(lower, upper, body_fun, *init_vals): # *init_val):
# iterator = lower

# one_value_original = torch.tensor([1], dtype=torch.int32, device=device)
(a, b) = init_vals
# (a, b) = init_vals

def cond_fn(upper, lower, *init_vals):
def cond_fn(upper, lowers): # lower, *init_vals):
# init_val_compy = init_val.clone()
# one_value1 = torch.tensor([0], dtype=torch.int32, device=device)
# one_value2 = torch.tensor([0], dtype=torch.int32, device=device)
# lower = torch.add(lower, one_value1[0])
# lower = torch.sub(lower, one_value2[0])
# assert isinstance(init_vals[0], torch.Tensor)
# assert isinstance(init_vals[1], torch.Tensor)
bool_value = isinstance(init_vals[0], torch.Tensor) and isinstance(init_vals[1], torch.Tensor)
# bool_value = isinstance(init_vals[0], torch.Tensor) and isinstance(init_vals[1], torch.Tensor)
# body_fun(*init_vals)
# result = True
# if (lower[0] <= upper[0]) and bool_value:
# return True
# return False
bool_result = ((lower[0] <= upper[0]) and bool_value)
bool_tensor = torch.tensor(bool_result, dtype=torch.bool)
return False # bool_tensor # (lower[0] <= upper[0]) and bool_tensor
# bool_result = ((lower[0] <= upper[0]) and bool_value)
# bool_tensor = torch.tensor(bool_result, dtype=torch.bool)
# return bool_tensor # (lower[0] <= upper[0]) and bool_tensor
# return lower[0] <= upper[0]
return lowers[0] <= upper[0]

def body_fn(upper, lower, *init_vals):
def body_fn(upper, lowers): # , *init_vals):
# one_value_original = torch.tensor(1, dtype=torch.int32, device=device)
(a, b) = init_vals
# (a, b) = init_vals
# return (upper, torch.add(lower, 1), body_fun(a, b), b.clone())
# return (upper.clone(), (torch.add(lower.clone(), init_vals[1].clone())).clone(), (body_fun(*init_vals)).clone(), init_vals[1].clone()) # init_vals[1:])
return (upper, torch.add(lower, b), body_fun(a, b), b) # init_vals[1:])
return (upper, (torch.add(lowers[0], lowers[2]), body_fun(lowers[1], lowers[2]), lowers[2])) # init_vals[1:])
# (body_fun(*init_vals)).clone(), init_vals[1].clone())
# body_fun(one_value_original, init_val)) # body_fun(lower, init_val))

# res = while_loop(cond_fn, body_fn, (upper, lower, *init_vals))
res = _xla_while_loop(cond_fn, body_fn, (upper, lower, *init_vals))
lowers = (lower, *init_vals)
res = _xla_while_loop(cond_fn, body_fn, (upper, lowers)) # , *init_vals))
# print("upper: ", upper)
# print("lower: ", lower)
return res
Expand Down

0 comments on commit bf91ab5

Please sign in to comment.