test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py (103 additions, 96 deletions)
import numpy as np
import os
import sys
import unittest
from typing import Callable, Dict, List

import torch
import torch_xla
import torch_xla.core.xla_builder as xb
# We need to import the underlying implementation function to register it with the dispatcher
import torch_xla.experimental.fori_loop
from torch_xla.experimental.fori_loop import fori_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.core.xla_op_registry as xor

from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
import torch._higher_order_ops.while_loop
from torch._higher_order_ops.while_loop import while_loop_op


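# Test-local fori_loop helper: it shadows the fori_loop imported from
# torch_xla.experimental.fori_loop above and lowers a fori-style loop onto the
# while_loop higher-order op, carrying the loop bounds together with the user
# values as the loop state.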
def fori_loop(lower, upper, user_body_func, *init_val):

  device = xm.xla_device()

  def cond_fn(upper, lower, *init_val):
    return lower[0] < upper[0]

  def body_fn(upper, lower, *init_val):
    one_value_i = torch.ones(1, dtype=torch.int32, device=device)
    res_list = list(user_body_func(*init_val))
    res_list.insert(0, lower)
    res_list.insert(0, torch.sub(upper, one_value_i))
    return res_list

  res = while_loop(cond_fn, body_fn, (lower, upper, *init_val))
  return res


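# Registering a py_impl for DispatchKey.XLA makes this function the XLA-backend
# implementation of torch's while_loop higher-order op, so while_loop calls on
# XLA tensors dispatch here instead of the default implementation.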
@while_loop_op.py_impl(DispatchKey.XLA)
def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None):
  # TODO(@manfei): PyTorch requires carried_inputs to be a list/tuple, while
  # PyTorch/XLA's _xla_while_loop only accepts *operands, which would wrap the
  # items in a tuple again: (a, '')
  # cond_fn & body_fn: callables
  # carried_inputs: tuple of (possibly nested dict/list/tuple of) tensors
  if additional_inputs is None:
    additional_inputs = tuple()
  return _xla_while_loop(
      cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs)


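# _xla_while_loop builds the XLA `While` op by hand: it creates fake inputs that
# mirror the real carried inputs, declares them as XLA parameters, lowers cond_fn
# and body_fn to separate HLO computations via LoweringContext, stitches them into
# a `While` op, and finally executes the resulting computation on the real inputs.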
def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
  # untuple carried_inputs from while_loop
  carried_inputs = carried_inputs[0]
  # build fake carried_inputs (same shapes/dtypes, random values) so the cond/body
  # computations can be traced without touching the real inputs
  fake_carried_inputs = []
  for carried_input in carried_inputs:
    device = carried_input.device
    fake_carried_inputs.append(
        torch.randint(10, carried_input.size(),
                      dtype=carried_input.dtype).to(device))
  fake_carried_inputs = tuple(fake_carried_inputs)

  # translate fake_carried_inputs from a list of tensors into a list of xla::op parameters
  kwargs = {}
  if type(fake_carried_inputs) is tuple:
    shapes = xb.tensor_shape(fake_carried_inputs)
  else:
    shapes = xb.tensor_shape((fake_carried_inputs,))
  builder = xb.create_builder('test_while')
  params = []
  for shape in shapes:
    p = xb.mkparam(builder, len(params), shape)
    params.append(p)

  # generate cond_fn xlacomputation
  cond_result = cond_fn(*fake_carried_inputs)
  cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
  cond_ctx.set_name_string("condctx")
  cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:]))
  cond_hlo = cond_ctx.hlo()
  cond_computation = xb.computation_from_module_proto("condcomputation",
                                                      cond_hlo)

  # generate body_fn xlacomputation
  body_result = body_fn(*fake_carried_inputs)
  body_ctx = torch_xla._XLAC.lowering.LoweringContext()
  body_ctx.set_name_string("bodyctx")
  body_ctx.buildforiloop(list(body_result), [])
  body_hlo = body_ctx.hlo()
  body_computation = xb.computation_from_module_proto("bodycomputation",
                                                      body_hlo)

  # generate while xlacomputation
  input_tuple = xb.Op.tuple(tuple(params))
  w = xb.mkop(
      'While', (input_tuple.op,),
      condition_computation=cond_computation,
      body_computation=body_computation)
  name = 'fori_loop_ed_torch_func'
  computation = w.build(name)

  # run the generated while xlacomputation to obtain the final result
  result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
                                                 (carried_inputs), computation)

  return result


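# Pure-Python reference implementation used to compute the expected results in
# the tests below.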
def _fake_while_loop(cond_fn, body_fn, operands):
  # operands needs to contain more than one tensor here
  while cond_fn(*operands):
    operands = body_fn(*operands)
  return operands


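# Pure-Python fori_loop reference: unpacks (plus_value, init_val) and applies
# body_fun (upper - lower) times in a plain Python loop.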
def _fake_fori_loop(lower, upper, body_fun, *init_val):
  (plus_value, init_val) = init_val
  for i in range((upper - lower)[0]):
    plus_value, init_val = body_fun(plus_value, init_val)
  return init_val


class WhileLoopTest(unittest.TestCase):

  def test_while_loop_tpu_subtraction(self):

    device = xm.xla_device()

    def cond_fn(init, limit_value):
      return limit_value[0] <= init[0]

    def body_fn(init, limit_value):
      one_value = torch.ones(1, dtype=torch.int32, device=device)
      two_value = limit_value.clone()
      return (torch.sub(init, one_value), two_value)

    init = torch.tensor([10], dtype=torch.int32, device=device)
    limit_value = torch.tensor([0], dtype=torch.int32, device=device)
    res = while_loop(cond_fn, body_fn, (init, limit_value))
    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
    self.assertEqual(expected, res)

  def test_while_loop_tpu_addition(self):

    device = xm.xla_device()

    def cond_fn(init, limit_value):
      return limit_value[0] >= init[0]

    def body_fn(init, limit_value):
      one_value = torch.ones(1, dtype=torch.int32, device=device)
      return (torch.add(init, one_value), limit_value.clone())

    # TODO(@manfei): init and limit_value have to be torch.tensor.
    init = torch.tensor([0], dtype=torch.int32, device=device)
    limit_value = torch.tensor([10], dtype=torch.int32, device=device)
    res = while_loop(cond_fn, body_fn, (init, limit_value))
    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
    self.assertEqual(expected, res)

  def test_while_loop_tpu_subtraction_nested(self):

    device = xm.xla_device()

    def cond_fn(init, limit_value):
      return limit_value[0] <= init[0]

    def body_fn(init, limit_value):
      one_value = torch.ones(1, dtype=torch.int32, device=device)
      two_value = limit_value.clone()
      return (torch.sub(torch.sub(init, one_value), one_value), two_value)

    init = torch.tensor([10], dtype=torch.int32, device=device)
    limit_value = torch.tensor([0], dtype=torch.int32, device=device)
    res = while_loop(cond_fn, body_fn, (init, limit_value))
    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
    self.assertEqual(expected, res)

  def test_fori_loop_tpu_addition(self):

    xm.mark_step()
    device = xm.xla_device()

    lower = torch.tensor([2], dtype=torch.int32, device=device)
    upper = torch.tensor([52], dtype=torch.int32, device=device)
    plus_value = torch.tensor([1], dtype=torch.int32, device=device)
    init_val = torch.tensor([1], dtype=torch.int32, device=device)

    def body_fun(*argus):
      plus_value, init_val = argus
      return plus_value, torch.add(plus_value, init_val)

    _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val)
    expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val)
    self.assertEqual(expected, actual)


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
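To run this test on a TPU VM, something like PJRT_DEVICE=TPU python test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py should work (assuming a PJRT-based TPU setup; the exact runtime configuration depends on your torch_xla install).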