-
Notifications
You must be signed in to change notification settings - Fork 491
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Fori_loop|While_loop] Placeholder lower torch.while_loop with python…
… dispatch for simple addition test case (#6532)
- Loading branch information
Showing
5 changed files
with
84 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
39 changes: 39 additions & 0 deletions
39
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import os | ||
import unittest | ||
from typing import Callable, Dict, List | ||
|
||
import torch | ||
import torch_xla | ||
# We need to import the underlying implementation function to register with the dispatcher | ||
import torch_xla.experimental.fori_loop | ||
from torch._higher_order_ops.while_loop import while_loop | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.core.xla_builder as xb | ||
|
||
|
||
def _fake_while_loop(cond_fn, body_fn, operands): | ||
while cond_fn(*operands): | ||
operands = body_fn(*operands) | ||
return operands | ||
|
||
|
||
class WhileLoopTest(unittest.TestCase): | ||
|
||
def test_while_loop_tpu(self): | ||
|
||
def cond_fn(x): | ||
return x.sum() <= 10 | ||
|
||
def body_fn(x): | ||
return (x + 1,) | ||
|
||
device = xm.xla_device() | ||
x = torch.ones(1, dtype=torch.int, device=device) | ||
res = while_loop(cond_fn, body_fn, (x,)) | ||
expected = _fake_while_loop(cond_fn, body_fn, x) | ||
self.assertEqual(expected, res) | ||
|
||
|
||
if __name__ == '__main__': | ||
test = unittest.main() | ||
sys.exit(0 if test.result.wasSuccessful() else 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import numpy as np | ||
import torch | ||
import torch_xla | ||
import torch_xla.core.xla_builder as xb | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.utils.utils as xu | ||
import torch_xla.core.xla_op_registry as xor | ||
|
||
from torch._C import DispatchKey | ||
from torch._ops import HigherOrderOperator | ||
import torch._higher_order_ops.while_loop | ||
from torch._higher_order_ops.while_loop import while_loop_op | ||
|
||
|
||
@while_loop_op.py_impl(DispatchKey.XLA) | ||
def while_loop(cond_fn, body_fn, operands): | ||
# cond_fn&body_fn: callable | ||
# operands: (Tuple of possibly nested dict/list/tuple of tensors) | ||
return _xla_while_loop(cond_fn, body_fn, operands) | ||
|
||
|
||
def _xla_while_loop(cond_fn, body_fn, operands): | ||
|
||
def op_fn(internal_x): | ||
# TODO(manfei) replace cond_fn_placeholder and body_fn_placeholder once xla::while lowering in the backend is available | ||
def cond_fn_placeholder(counter, internal_x): | ||
return counter < xb.Op.scalar(internal_x.builder(), 10, dtype=xb.Type.S32) | ||
|
||
def body_fn_placeholder(counter, internal_x): | ||
next_counter = counter + xb.Op.scalar( | ||
counter.builder(), 1, dtype=xb.Type.S32) | ||
internal_x = internal_x + xb.Op.scalar( | ||
internal_x.builder(), 1, dtype=xb.Type.S32) | ||
return xb.Op.tuple((next_counter, internal_x)) | ||
|
||
zero = xb.Op.scalar(internal_x.builder(), 0, dtype=xb.Type.S32) | ||
w = xb.Op.mkwhile((zero, internal_x), cond_fn_placeholder, | ||
body_fn_placeholder) | ||
return w.get_tuple_element(1) | ||
|
||
op = xor.register('test_while', op_fn) | ||
return xu.as_list(op(operands[0])) |