diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
index bf32a712f3e..a76197cc736 100644
--- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
+++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
@@ -1,99 +1,113 @@
-import numpy as np
+import os
+import sys
+import unittest
+from typing import Callable, Dict, List
+
 import torch
 import torch_xla
-import torch_xla.core.xla_builder as xb
+# Importing torch_xla.experimental.fori_loop registers the XLA implementation of while_loop with the dispatcher.
+import torch_xla.experimental.fori_loop
+from torch_xla.experimental.fori_loop import fori_loop
+from torch._higher_order_ops.while_loop import while_loop
 import torch_xla.core.xla_model as xm
-import torch_xla.utils.utils as xu
-import torch_xla.core.xla_op_registry as xor
-
-from torch._C import DispatchKey
-from torch._ops import HigherOrderOperator
-import torch._higher_order_ops.while_loop
-from torch._higher_order_ops.while_loop import while_loop_op
-
-
-def fori_loop(lower, upper, user_body_func, *init_val):
-
-  device = xm.xla_device()
-
-  def cond_fn(upper, lower, *init_val):
-    return lower[0] < upper[0]
-
-  def body_fn(upper, lower, *init_val):
-    one_value_i = torch.ones(1, dtype=torch.int32, device=device)
-    res_list = list(user_body_func(*init_val))
-    res_list.insert(0, lower)
-    res_list.insert(0, torch.sub(upper, one_value_i))
-    return res_list
-
-  res = while_loop(cond_fn, body_fn, (lower, upper, *init_val))
-  return res
-
-
-@while_loop_op.py_impl(DispatchKey.XLA)
-def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None):
-  # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '')
-  # cond_fn&body_fn: callable
-  # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors)
-  if additional_inputs is None:
-    additional_inputs = tuple()
-  return _xla_while_loop(
-      cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs)
-
-
-def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
-  # untuple carried_inputs from while_loop
-  carried_inputs = carried_inputs[0]
-  # fake carried_inputs to split formal code
-  fake_carried_inputs = []
-  for carried_input in carried_inputs:
-    device = carried_input.device
-    fake_carried_inputs.append(
-        torch.randint(10, carried_input.size(),
-                      dtype=carried_input.dtype).to(device))
-  fake_carried_inputs = tuple(fake_carried_inputs)
-
-  # trans fake_carried_inputs from list(tensor) to list(xla::op)
-  kwargs = {}
-  if type(fake_carried_inputs) is tuple:
-    shapes = xb.tensor_shape(fake_carried_inputs)
-  else:
-    shapes = xb.tensor_shape((fake_carried_inputs))
-  builder = xb.create_builder('test_while')
-  params = []
-  for shape in shapes:
-    p = xb.mkparam(builder, len(params), shape)
-    params.append(p)
-
-  # generate cond_fn xlacomputation
-  cond_result = cond_fn(*fake_carried_inputs)
-  cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
-  cond_ctx.set_name_string("condctx")
-  cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:]))
-  cond_hlo = cond_ctx.hlo()
-  cond_computation = xb.computation_from_module_proto("condcomputation",
-                                                      cond_hlo)
-
-  # generate body_fn xlacomputation
-  body_result = body_fn(*fake_carried_inputs)
-  body_ctx = torch_xla._XLAC.lowering.LoweringContext()
-  body_ctx.set_name_string("bodyctx")
-  body_ctx.buildforiloop(list(body_result), [])
-  body_hlo = body_ctx.hlo()
-  body_computation = xb.computation_from_module_proto("bodycomputation",
-                                                      body_hlo)
-
-  # generate while xlacomputation
-  input_tuple = xb.Op.tuple(tuple(params))
-  w = xb.mkop(
-      'While', (input_tuple.op,),
-      condition_computation=cond_computation,
-      body_computation=body_computation)
-  name = 'fori_loop_ed_torch_func'
-  computation = w.build(name)
-
-  # gain final result with generated while xlacomputation
-  result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
-                                                 (carried_inputs), computation)
-
-  return result
\ No newline at end of file
+import torch_xla.core.xla_builder as xb
+
+
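+# Plain-Python reference implementations of while_loop/fori_loop, executed
+# eagerly; each test below compares the XLA-lowered loop result against the
+# expected value these helpers produce.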
+def _fake_while_loop(cond_fn, body_fn, operands):
+  # operands must be a tuple containing more than one tensor here
+  while cond_fn(*operands):
+    operands = body_fn(*operands)
+  return operands
+
+
+def _fake_fori_loop(lower, upper, body_fun, *init_val):
+  (plus_value, init_val) = init_val
+  for i in range((upper - lower)[0]):
+    plus_value, init_val = body_fun(plus_value, init_val)
+  return init_val
+
+
+class WhileLoopTest(unittest.TestCase):
+
+  def test_while_loop_tpu_subtraction(self):
+
+    device = xm.xla_device()
+
+    def cond_fn(init, limit_value):
+      return limit_value[0] <= init[0]
+
+    def body_fn(init, limit_value):
+      one_value = torch.ones(1, dtype=torch.int32, device=device)
+      two_value = limit_value.clone()
+      return (torch.sub(init, one_value), two_value)
+
+    init = torch.tensor([10], dtype=torch.int32, device=device)
+    limit_value = torch.tensor([0], dtype=torch.int32, device=device)
+    res = while_loop(cond_fn, body_fn, (init, limit_value))
+    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
+    self.assertEqual(expected, res)
+
+  def test_while_loop_tpu_addition(self):
+
+    device = xm.xla_device()
+
+    def cond_fn(init, limit_value):
+      return limit_value[0] >= init[0]
+
+    def body_fn(init, limit_value):
+      one_value = torch.ones(1, dtype=torch.int32, device=device)
+      return (torch.add(init, one_value), limit_value.clone())
+
+    # TODO(@manfei): init and limit_value have to be torch.Tensors.
+    init = torch.tensor([0], dtype=torch.int32, device=device)
+    limit_value = torch.tensor([10], dtype=torch.int32, device=device)
+    res = while_loop(cond_fn, body_fn, (init, limit_value))
+    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
+    self.assertEqual(expected, res)
+
+  def test_while_loop_tpu_subtraction_nested(self):
+
+    device = xm.xla_device()
+
+    def cond_fn(init, limit_value):
+      return limit_value[0] <= init[0]
+
+    def body_fn(init, limit_value):
+      one_value = torch.ones(1, dtype=torch.int32, device=device)
+      two_value = limit_value.clone()
+      return (torch.sub(torch.sub(init, one_value), one_value), two_value)
+
+    init = torch.tensor([10], dtype=torch.int32, device=device)
+    limit_value = torch.tensor([0], dtype=torch.int32, device=device)
+    res = while_loop(cond_fn, body_fn, (init, limit_value))
+    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
+    self.assertEqual(expected, res)
+
+  def test_fori_loop_tpu_addition(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+
+    lower = torch.tensor([2], dtype=torch.int32, device=device)
+    upper = torch.tensor([52], dtype=torch.int32, device=device)
+    plus_value = torch.tensor([1], dtype=torch.int32, device=device)
+    init_val = torch.tensor([1], dtype=torch.int32, device=device)
+
+    def body_fun(*args):
+      plus_value, init_val = args
+      return plus_value, torch.add(plus_value, init_val)
+
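+    # fori_loop is built on while_loop, so it returns the entire carried
+    # tuple; the leading entries are the loop bookkeeping values (bounds and
+    # loop state) and the final element holds the accumulated result.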
+    _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val)
+    expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val)
+    self.assertEqual(expected, actual)
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
\ No newline at end of file