Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed May 7, 2024
1 parent 4752586 commit 759653d
Showing 1 changed file with 103 additions and 96 deletions.
199 changes: 103 additions & 96 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Original file line number Diff line number Diff line change
@@ -1,99 +1,106 @@
import numpy as np
import os
import unittest
from typing import Callable, Dict, List

import torch
import torch_xla
import torch_xla.core.xla_builder as xb
# We need to import the underlying implementation function to register with the dispatcher
import torch_xla.experimental.fori_loop
from torch_xla.experimental.fori_loop import fori_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.core.xla_op_registry as xor

from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
import torch._higher_order_ops.while_loop
from torch._higher_order_ops.while_loop import while_loop_op


def fori_loop(lower, upper, user_body_func, *init_val):

device = xm.xla_device()

def cond_fn(upper, lower, *init_val):
return lower[0] < upper[0]

def body_fn(upper, lower, *init_val):
one_value_i = torch.ones(1, dtype=torch.int32, device=device)
res_list = list(user_body_func(*init_val))
res_list.insert(0, lower)
res_list.insert(0, torch.sub(upper, one_value_i))
return res_list

res = while_loop(cond_fn, body_fn, (lower, upper, *init_val))
return res


@while_loop_op.py_impl(DispatchKey.XLA)
def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None):
# TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '')
# cond_fn&body_fn: callable
# carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors)
if additional_inputs is None:
additional_inputs = tuple()
return _xla_while_loop(
cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs)


def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
# untuple carried_inputs from while_loop
carried_inputs = carried_inputs[0]
# fake carried_inputs to split formal code
fake_carried_inputs = []
for carried_input in carried_inputs:
device = carried_input.device
fake_carried_inputs.append(
torch.randint(10, carried_input.size(),
dtype=carried_input.dtype).to(device))
fake_carried_inputs = tuple(fake_carried_inputs)

# trans fake_carried_inputs from list(tensor) to list(xla::op)
kwargs = {}
if type(fake_carried_inputs) is tuple:
shapes = xb.tensor_shape(fake_carried_inputs)
else:
shapes = xb.tensor_shape((fake_carried_inputs))
builder = xb.create_builder('test_while')
params = []
for shape in shapes:
p = xb.mkparam(builder, len(params), shape)
params.append(p)

# generate cond_fn xlacomputation
cond_result = cond_fn(*fake_carried_inputs)
cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
cond_ctx.set_name_string("condctx")
cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:]))
cond_hlo = cond_ctx.hlo()
cond_computation = xb.computation_from_module_proto("condcomputation",
cond_hlo)

# generate body_fn xlacomputation
body_result = body_fn(*fake_carried_inputs)
body_ctx = torch_xla._XLAC.lowering.LoweringContext()
body_ctx.set_name_string("bodyctx")
body_ctx.buildforiloop(list(body_result), [])
body_hlo = body_ctx.hlo()
body_computation = xb.computation_from_module_proto("bodycomputation",
body_hlo)

# generate while xlacomputation
input_tuple = xb.Op.tuple(tuple(params))
w = xb.mkop(
'While', (input_tuple.op,),
condition_computation=cond_computation,
body_computation=body_computation)
name = 'fori_loop_ed_torch_func'
computation = w.build(name)

# gain final result with generated while xlacomputation
result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
(carried_inputs), computation)

return result
import torch_xla.core.xla_builder as xb


def _fake_while_loop(cond_fn, body_fn, operands):
# operands need to be more than one here
while cond_fn(*operands):
operands = body_fn(*operands)
return operands


def _fake_fori_loop(lower, upper, body_fun, *init_val):
(plus_value, init_val) = init_val
for i in range((upper - lower)[0]):
plus_value, init_val = body_fun(plus_value, init_val)
return init_val


class WhileLoopTest(unittest.TestCase):

def test_while_loop_tpu_subtraction(self):

device = xm.xla_device()

def cond_fn(init, limit_value):
return limit_value[0] <= init[0]

def body_fn(init, limit_value):
one_value = torch.ones(1, dtype=torch.int32, device=device)
two_value = limit_value.clone()
return (torch.sub(init, one_value), two_value)

init = torch.tensor([10], dtype=torch.int32, device=device)
limit_value = torch.tensor([0], dtype=torch.int32, device=device)
res = while_loop(cond_fn, body_fn, (init, limit_value))
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)

def test_while_loop_tpu_addition(self):

device = xm.xla_device()

def cond_fn(init, limit_value):
return limit_value[0] >= init[0]

def body_fn(init, limit_value):
one_value = torch.ones(1, dtype=torch.int32, device=device)
return (torch.add(init, one_value), limit_value.clone())

# TODO(@manfei): init and limit_value has to be torch.tensor.
init = torch.tensor([0], dtype=torch.int32, device=device)
limit_value = torch.tensor([10], dtype=torch.int32, device=device)
res = while_loop(cond_fn, body_fn, (init, limit_value))
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)

def test_while_loop_tpu_subtraction_nested(self):

device = xm.xla_device()

def cond_fn(init, limit_value):
return limit_value[0] <= init[0]

def body_fn(init, limit_value):
one_value = torch.ones(1, dtype=torch.int32, device=device)
two_value = limit_value.clone()
return (torch.sub(torch.sub(init, one_value), one_value), two_value)

init = torch.tensor([10], dtype=torch.int32, device=device)
limit_value = torch.tensor([0], dtype=torch.int32, device=device)
res = while_loop(cond_fn, body_fn, (init, limit_value))
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)

def test_fori_loop_tpu_addition(self):

xm.mark_step()
device = xm.xla_device()

lower = torch.tensor([2], dtype=torch.int32, device=device)
upper = torch.tensor([52], dtype=torch.int32, device=device)
plus_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)

def body_fun(*argus):
plus_value, init_val = argus
return plus_value, torch.add(plus_value, init_val)

_, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val)
expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val)
self.assertEqual(expected, actual)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)

0 comments on commit 759653d

Please sign in to comment.