Skip to content

Commit

Permalink
clean version
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Mar 21, 2024
1 parent 5480942 commit 74eb89b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,9 @@ def _fake_while_loop(cond_fn, body_fn, operands):
return operands

def _fake_fori_loop(lower, upper, body_fun, *init_val):
# operands need to be more than one here
# print("upper - lower: ", upper - lower)
# print("init_val: ", init_val)
# print("type init_val: ", type(init_val))
(a, b) = init_val
# print("a: ", a)
# print("b: ", b)
for i in range((upper - lower)[0]):
a = body_fun(a, b)
# print("a: ", a)
# print("i: ", i)
return a

class WhileLoopTest(unittest.TestCase):
Expand Down Expand Up @@ -81,22 +73,13 @@ def test_fori_loop_tpu_addition(self):
upper = torch.tensor([30], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val_list = (init_val, one_value)
# lowers = torch.tensor([[1], [1], [1]], dtype=torch.int32, device=device) # lower, init_val, one_value

def body_fun(a, b):
return torch.add(a, b) # [0])
# _, _, res, _ = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val)
# A, B, res, D = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val)
# A, B, res, D = fori_loop(upper, body_fun, lowers) # lower, upper, body_fun, init_val, one_value)
return torch.add(a, b)

res, _ = fori_loop(lower, upper, body_fun, init_val, one_value)
print("result: ", res) # init_val_
# print("A: ", A) # lower_
# print("B: ", B) # upper_
# print("D: ", D) # one_value_
# print("lower[0] <= upper[0]: ", lower[0] <= upper[0])
# print("lower: ", lower)
# print("upper: ", upper)
# fori_loop(cond_fn, body_fn, (init, limit_value))
print("result: ", res)

expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value)
self.assertEqual(expected, res)

Expand Down
59 changes: 2 additions & 57 deletions torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,70 +12,20 @@
from torch._higher_order_ops.while_loop import while_loop_op


# def fori_loop(upper, body_fun, lowers):# upper, body_fun, *init_vals): # *init_val):
def fori_loop(lower, upper, body_fun, init_val, one_value):

device = xm.xla_device()
# limit_value = upper
# init = lower
# iterator = lower

# one_value is actually not used here, but actually redefined in body_fn to avoid introduce new argument in body_xlacomputation
# lower == init_val
# assert(lower == init_val)
init = lower # = init_val
init = lower
limit_value = upper

# one_value_original = torch.tensor([1], dtype=torch.int32, device=device)
# (a, b) = init_vals

# def cond_fn(init, limit_value):
# return limit_value[0] >= init[0]

# def body_fn(init, limit_value):
# one_value = torch.ones(1, dtype=torch.int32, device=device)
# return (torch.add(init, one_value), limit_value.clone())

# def cond_fn(upper, lowers): # lower, *init_vals):
def cond_fn(init, limit_value): # lower, *init_vals):
# init_val_compy = init_val.clone()
# one_value1 = torch.tensor([0], dtype=torch.int32, device=device)
# one_value2 = torch.tensor([0], dtype=torch.int32, device=device)
# lower = torch.add(lower, one_value1[0])
# lower = torch.sub(lower, one_value2[0])
# assert isinstance(init_vals[0], torch.Tensor)
# assert isinstance(init_vals[1], torch.Tensor)
# bool_value = isinstance(init_vals[0], torch.Tensor) and isinstance(init_vals[1], torch.Tensor)
# body_fun(*init_vals)
# result = True
# if (lower[0] <= upper[0]) and bool_value:
# return True
# return False
# bool_result = ((lower[0] <= upper[0]) and bool_value)
# bool_tensor = torch.tensor(bool_result, dtype=torch.bool)
# return bool_tensor # (lower[0] <= upper[0]) and bool_tensor
# return lower[0] <= upper[0]
# return lowers[0] <= upper[0]
def cond_fn(init, limit_value):
return limit_value[0] >= init[0]

# def body_fn(upper, lowers): # , *init_vals):
def body_fn(init, limit_value):
# one_value_original = torch.tensor(1, dtype=torch.int32, device=device)
# (a, b) = init_vals
# return (upper, torch.add(lower, 1), body_fun(a, b), b.clone())
# return (upper.clone(), (torch.add(lower.clone(), init_vals[1].clone())).clone(), (body_fun(*init_vals)).clone(), init_vals[1].clone()) # init_vals[1:])
# return (upper, (torch.add(lowers[0], lowers[2]), body_fun(lowers[1], lowers[2]), lowers[2])) # init_vals[1:])
# (body_fun(*init_vals)).clone(), init_vals[1].clone())
# body_fun(one_value_original, init_val)) # body_fun(lower, init_val))
one_value = torch.ones(1, dtype=torch.int32, device=device)
return (body_fun(init, one_value), limit_value.clone())

# res = while_loop(cond_fn, body_fn, (upper, lower, *init_vals))
# lowers = (lower, *init_vals)
# res = _xla_while_loop(cond_fn, body_fn, (upper, lowers)) # , *init_vals))
res = _xla_while_loop(cond_fn, body_fn, (init, limit_value))
# print("upper: ", upper)
# print("lower: ", lower)
return res

@while_loop_op.py_impl(DispatchKey.XLA)
Expand Down Expand Up @@ -133,9 +83,4 @@ def _xla_while_loop(cond_fn, body_fn, operands):
result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
tuple(operands), computation)

# print("operands: ", operands)
# print("upper: ", operands[0])
# print("lower: ", operands[1])
# print("init: ", operands[2])

return result

0 comments on commit 74eb89b

Please sign in to comment.