Skip to content

Commit

Permalink
clean version
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Mar 21, 2024
1 parent 5480942 commit 5f01f63
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def test_fori_loop_tpu_addition(self):
upper = torch.tensor([30], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val_list = (init_val, one_value)
# lowers = torch.tensor([[1], [1], [1]], dtype=torch.int32, device=device) # lower, init_val, one_value
# lowers = torch.tensor(([1], [1], [1]), dtype=torch.int32, device=device) # lower, init_val, one_value

def body_fun(a, b):
return torch.add(a, b) # [0])
# _, _, res, _ = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val)
# A, B, res, D = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val)
# A, B, res, D = fori_loop(upper, body_fun, lowers) # lower, upper, body_fun, init_val, one_value)
res, _ = fori_loop(lower, upper, body_fun, init_val, one_value)
res, _ = fori_loop(lower, upper, body_fun, (init_val, one_value))
print("result: ", res) # init_val_
# print("A: ", A) # lower_
# print("B: ", B) # upper_
Expand Down
87 changes: 52 additions & 35 deletions torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


# def fori_loop(upper, body_fun, lowers):# upper, body_fun, *init_vals): # *init_val):
def fori_loop(lower, upper, body_fun, init_val, one_value):
def fori_loop(lower, upper, body_fun, init_vals): # (init_val, one_value)):

device = xm.xla_device()
# limit_value = upper
Expand All @@ -23,8 +23,8 @@ def fori_loop(lower, upper, body_fun, init_val, one_value):
# one_value is actually not used here, but actually redefined in body_fn to avoid introduce new argument in body_xlacomputation
# lower == init_val
# assert(lower == init_val)
init = lower # = init_val
limit_value = upper
# init = lower # = init_val
# limit_value = upper

# one_value_original = torch.tensor([1], dtype=torch.int32, device=device)
# (a, b) = init_vals
Expand All @@ -36,44 +36,59 @@ def fori_loop(lower, upper, body_fun, init_val, one_value):
# one_value = torch.ones(1, dtype=torch.int32, device=device)
# return (torch.add(init, one_value), limit_value.clone())

# def cond_fn(upper, lowers): # lower, *init_vals):
def cond_fn(init, limit_value): # lower, *init_vals):
# init_val_compy = init_val.clone()
# one_value1 = torch.tensor([0], dtype=torch.int32, device=device)
# one_value2 = torch.tensor([0], dtype=torch.int32, device=device)
# lower = torch.add(lower, one_value1[0])
# lower = torch.sub(lower, one_value2[0])
# assert isinstance(init_vals[0], torch.Tensor)
# assert isinstance(init_vals[1], torch.Tensor)
# bool_value = isinstance(init_vals[0], torch.Tensor) and isinstance(init_vals[1], torch.Tensor)
# body_fun(*init_vals)
# result = True
# if (lower[0] <= upper[0]) and bool_value:
# return True
# return False
# bool_result = ((lower[0] <= upper[0]) and bool_value)
# bool_tensor = torch.tensor(bool_result, dtype=torch.bool)
# return bool_tensor # (lower[0] <= upper[0]) and bool_tensor
# return lower[0] <= upper[0]
# return lowers[0] <= upper[0]
return limit_value[0] >= init[0]
# cond_fn
def _fori_cond_fun(loop_carry):
i, upper, _ = loop_carry
return torch.lt(i, upper)

def _fori_body_fun(body_fun):
# body_fun = weakref.ref(body_fun)
def while_body_fun(loop_carry):
i, upper, x = loop_carry
one_value = torch.ones(1, dtype=torch.int32, device=device)
return torch.add(i, one_value), upper, body_fun()(i, x)
return while_body_fun

# # def cond_fn(upper, lowers): # lower, *init_vals):
# def cond_fn(init, limit_value): # lower, *init_vals):
# # init_val_compy = init_val.clone()
# # one_value1 = torch.tensor([0], dtype=torch.int32, device=device)
# # one_value2 = torch.tensor([0], dtype=torch.int32, device=device)
# # lower = torch.add(lower, one_value1[0])
# # lower = torch.sub(lower, one_value2[0])
# # assert isinstance(init_vals[0], torch.Tensor)
# # assert isinstance(init_vals[1], torch.Tensor)
# # bool_value = isinstance(init_vals[0], torch.Tensor) and isinstance(init_vals[1], torch.Tensor)
# # body_fun(*init_vals)
# # result = True
# # if (lower[0] <= upper[0]) and bool_value:
# # return True
# # return False
# # bool_result = ((lower[0] <= upper[0]) and bool_value)
# # bool_tensor = torch.tensor(bool_result, dtype=torch.bool)
# # return bool_tensor # (lower[0] <= upper[0]) and bool_tensor
# # return lower[0] <= upper[0]
# # return lowers[0] <= upper[0]
# return limit_value[0] >= init[0]

# def body_fn(upper, lowers): # , *init_vals):
def body_fn(init, limit_value):
# one_value_original = torch.tensor(1, dtype=torch.int32, device=device)
# (a, b) = init_vals
# return (upper, torch.add(lower, 1), body_fun(a, b), b.clone())
# return (upper.clone(), (torch.add(lower.clone(), init_vals[1].clone())).clone(), (body_fun(*init_vals)).clone(), init_vals[1].clone()) # init_vals[1:])
# return (upper, (torch.add(lowers[0], lowers[2]), body_fun(lowers[1], lowers[2]), lowers[2])) # init_vals[1:])
# (body_fun(*init_vals)).clone(), init_vals[1].clone())
# body_fun(one_value_original, init_val)) # body_fun(lower, init_val))
one_value = torch.ones(1, dtype=torch.int32, device=device)
return (body_fun(init, one_value), limit_value.clone())
# def body_fn(init, limit_value):
# # one_value_original = torch.tensor(1, dtype=torch.int32, device=device)
# # (a, b) = init_vals
# # return (upper, torch.add(lower, 1), body_fun(a, b), b.clone())
# # return (upper.clone(), (torch.add(lower.clone(), init_vals[1].clone())).clone(), (body_fun(*init_vals)).clone(), init_vals[1].clone()) # init_vals[1:])
# # return (upper, (torch.add(lowers[0], lowers[2]), body_fun(lowers[1], lowers[2]), lowers[2])) # init_vals[1:])
# # (body_fun(*init_vals)).clone(), init_vals[1].clone())
# # body_fun(one_value_original, init_val)) # body_fun(lower, init_val))
# one_value = torch.ones(1, dtype=torch.int32, device=device)
# return (body_fun(init, one_value), limit_value.clone())

# res = while_loop(cond_fn, body_fn, (upper, lower, *init_vals))
# lowers = (lower, *init_vals)
# res = _xla_while_loop(cond_fn, body_fn, (upper, lowers)) # , *init_vals))
res = _xla_while_loop(cond_fn, body_fn, (init, limit_value))
# res = _xla_while_loop(cond_fn, body_fn, (init, limit_value))
_, _, result = _xla_while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
(lower, upper, init_val))
# print("upper: ", upper)
# print("lower: ", lower)
return res
Expand All @@ -87,6 +102,8 @@ def while_loop(cond_fn, body_fn, operands):

def _xla_while_loop(cond_fn, body_fn, operands):

print("!!! arguments: cond_fn: ", cond_fn, ", body_fn: ", body_fn, ", operands: ", operands)

# create inputs placeholder
kwargs = {}
shapes = xb.tensor_shape(operands)
Expand Down

0 comments on commit 5f01f63

Please sign in to comment.