diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 42d0de34469..cd6fdf1130e 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -19,17 +19,9 @@ def _fake_while_loop(cond_fn, body_fn, operands): return operands def _fake_fori_loop(lower, upper, body_fun, *init_val): - # operands need to be more than one here - # print("upper - lower: ", upper - lower) - # print("init_val: ", init_val) - # print("type init_val: ", type(init_val)) (a, b) = init_val - # print("a: ", a) - # print("b: ", b) for i in range((upper - lower)[0]): a = body_fun(a, b) - # print("a: ", a) - # print("i: ", i) return a class WhileLoopTest(unittest.TestCase): @@ -81,22 +73,13 @@ def test_fori_loop_tpu_addition(self): upper = torch.tensor([30], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val_list = (init_val, one_value) - # lowers = torch.tensor([[1], [1], [1]], dtype=torch.int32, device=device) # lower, init_val, one_value def body_fun(a, b): - return torch.add(a, b) # [0]) - # _, _, res, _ = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val) - # A, B, res, D = fori_loop(lower, upper, body_fun, init_val, one_value) # init_val_list) # init_val) - # A, B, res, D = fori_loop(upper, body_fun, lowers) # lower, upper, body_fun, init_val, one_value) + return torch.add(a, b) + res, _ = fori_loop(lower, upper, body_fun, init_val, one_value) - print("result: ", res) # init_val_ - # print("A: ", A) # lower_ - # print("B: ", B) # upper_ - # print("D: ", D) # one_value_ - # print("lower[0] <= upper[0]: ", lower[0] <= upper[0]) - # print("lower: ", lower) - # print("upper: ", upper) - # fori_loop(cond_fn, body_fn, (init, limit_value)) + print("result: ", res) + expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) self.assertEqual(expected, res) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 75856fac111..445157c05c7 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -12,70 +12,20 @@ from torch._higher_order_ops.while_loop import while_loop_op -# def fori_loop(upper, body_fun, lowers):# upper, body_fun, *init_vals): # *init_val): def fori_loop(lower, upper, body_fun, init_val, one_value): device = xm.xla_device() - # limit_value = upper - # init = lower - # iterator = lower - - # one_value is actually not used here, but actually redefined in body_fn to avoid introduce new argument in body_xlacomputation - # lower == init_val - # assert(lower == init_val) - init = lower # = init_val + init = lower limit_value = upper - # one_value_original = torch.tensor([1], dtype=torch.int32, device=device) - # (a, b) = init_vals - - # def cond_fn(init, limit_value): - # return limit_value[0] >= init[0] - - # def body_fn(init, limit_value): - # one_value = torch.ones(1, dtype=torch.int32, device=device) - # return (torch.add(init, one_value), limit_value.clone()) - - # def cond_fn(upper, lowers): # lower, *init_vals): - def cond_fn(init, limit_value): # lower, *init_vals): - # init_val_compy = init_val.clone() - # one_value1 = torch.tensor([0], dtype=torch.int32, device=device) - # one_value2 = torch.tensor([0], dtype=torch.int32, device=device) - # lower = torch.add(lower, one_value1[0]) - # lower = torch.sub(lower, one_value2[0]) - # assert isinstance(init_vals[0], torch.Tensor) - # assert isinstance(init_vals[1], torch.Tensor) - # bool_value = isinstance(init_vals[0], torch.Tensor) and isinstance(init_vals[1], torch.Tensor) - # body_fun(*init_vals) - # result = True - # if (lower[0] <= upper[0]) and bool_value: - # return True - # return False - # bool_result = ((lower[0] <= upper[0]) and bool_value) - # bool_tensor = torch.tensor(bool_result, dtype=torch.bool) - # return bool_tensor # (lower[0] <= upper[0]) and bool_tensor - # return lower[0] <= upper[0] - # return lowers[0] <= upper[0] + def cond_fn(init, limit_value): return limit_value[0] >= init[0] - # def body_fn(upper, lowers): # , *init_vals): def body_fn(init, limit_value): - # one_value_original = torch.tensor(1, dtype=torch.int32, device=device) - # (a, b) = init_vals - # return (upper, torch.add(lower, 1), body_fun(a, b), b.clone()) - # return (upper.clone(), (torch.add(lower.clone(), init_vals[1].clone())).clone(), (body_fun(*init_vals)).clone(), init_vals[1].clone()) # init_vals[1:]) - # return (upper, (torch.add(lowers[0], lowers[2]), body_fun(lowers[1], lowers[2]), lowers[2])) # init_vals[1:]) - # (body_fun(*init_vals)).clone(), init_vals[1].clone()) - # body_fun(one_value_original, init_val)) # body_fun(lower, init_val)) one_value = torch.ones(1, dtype=torch.int32, device=device) return (body_fun(init, one_value), limit_value.clone()) - # res = while_loop(cond_fn, body_fn, (upper, lower, *init_vals)) - # lowers = (lower, *init_vals) - # res = _xla_while_loop(cond_fn, body_fn, (upper, lowers)) # , *init_vals)) res = _xla_while_loop(cond_fn, body_fn, (init, limit_value)) - # print("upper: ", upper) - # print("lower: ", lower) return res @while_loop_op.py_impl(DispatchKey.XLA) @@ -133,9 +83,4 @@ def _xla_while_loop(cond_fn, body_fn, operands): result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', tuple(operands), computation) - # print("operands: ", operands) - # print("upper: ", operands[0]) - # print("lower: ", operands[1]) - # print("init: ", operands[2]) - return result