Skip to content

Commit

Permalink
test range
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Mar 21, 2024
1 parent 80d28c4 commit a9060e2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ def test_fori_loop_tpu_addition(self):
upper = torch.tensor([30], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val_list = (init_val, one_value)
lowers = torch.tensor([[1], [1], [1]], dtype=torch.int32, device=device) # lower, init_val, one_value
# lowers = torch.tensor([[1], [1], [1]], dtype=torch.int32, device=device) # lower, init_val, one_value

def body_fun(a, b):
return torch.add(a, b) # [0])
# _, _, res, _ = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val)
# A, B, res, D = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val)
A, B, res, D = fori_loop(upper, body_fun, lowers) # lower, upper, body_fun, init_val, one_value)
A, B, res, D = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val)
# A, B, res, D = fori_loop(upper, body_fun, lowers) # lower, upper, body_fun, init_val, one_value)
print("result: ", res) # init_val_
print("A: ", A) # lower_
print("B: ", B) # upper_
Expand Down
29 changes: 24 additions & 5 deletions torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,32 @@
from torch._higher_order_ops.while_loop import while_loop_op


def fori_loop(upper, body_fun, lowers):# upper, body_fun, *init_vals): # *init_val):
# def fori_loop(upper, body_fun, lowers):# upper, body_fun, *init_vals): # *init_val):
def fori_loop(lower, upper, body_fun, init_val, one_value):

device = xm.xla_device()
# limit_value = upper
# init = lower
# iterator = lower

# one_value is actually not used here, but actually redefined in body_fn to avoid introduce new argument in body_xlacomputation
# lower == init_val
assert(lower == init_val)
init = lower # = init_val
limit_value = upper

# one_value_original = torch.tensor([1], dtype=torch.int32, device=device)
# (a, b) = init_vals

def cond_fn(upper, lowers): # lower, *init_vals):
# def cond_fn(init, limit_value):
# return limit_value[0] >= init[0]

# def body_fn(init, limit_value):
# one_value = torch.ones(1, dtype=torch.int32, device=device)
# return (torch.add(init, one_value), limit_value.clone())

# def cond_fn(upper, lowers): # lower, *init_vals):
def cond_fn(init, limit_value): # lower, *init_vals):
# init_val_compy = init_val.clone()
# one_value1 = torch.tensor([0], dtype=torch.int32, device=device)
# one_value2 = torch.tensor([0], dtype=torch.int32, device=device)
Expand All @@ -40,16 +55,20 @@ def cond_fn(upper, lowers): # lower, *init_vals):
# bool_tensor = torch.tensor(bool_result, dtype=torch.bool)
# return bool_tensor # (lower[0] <= upper[0]) and bool_tensor
# return lower[0] <= upper[0]
return lowers[0] <= upper[0]
# return lowers[0] <= upper[0]
return limit_value[0] >= init[0]

def body_fn(upper, lowers): # , *init_vals):
# def body_fn(upper, lowers): # , *init_vals):
def body_fn(init, limit_value):
# one_value_original = torch.tensor(1, dtype=torch.int32, device=device)
# (a, b) = init_vals
# return (upper, torch.add(lower, 1), body_fun(a, b), b.clone())
# return (upper.clone(), (torch.add(lower.clone(), init_vals[1].clone())).clone(), (body_fun(*init_vals)).clone(), init_vals[1].clone()) # init_vals[1:])
return (upper, (torch.add(lowers[0], lowers[2]), body_fun(lowers[1], lowers[2]), lowers[2])) # init_vals[1:])
# return (upper, (torch.add(lowers[0], lowers[2]), body_fun(lowers[1], lowers[2]), lowers[2])) # init_vals[1:])
# (body_fun(*init_vals)).clone(), init_vals[1].clone())
# body_fun(one_value_original, init_val)) # body_fun(lower, init_val))
one_value = torch.ones(1, dtype=torch.int32, device=device)
return (body_fun(init, one_value), limit_value.clone())

# res = while_loop(cond_fn, body_fn, (upper, lower, *init_vals))
# lowers = (lower, *init_vals)
Expand Down

0 comments on commit a9060e2

Please sign in to comment.