Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[test] Fori loop simple case test testnewone #7029

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
546 commits
Select commit Hold shift + click to select a range
cfa5f7e
update
ManfeiBai Apr 17, 2024
0cb1dce
update
ManfeiBai Apr 17, 2024
2dcfab0
update
ManfeiBai Apr 17, 2024
d0270b6
update
ManfeiBai Apr 17, 2024
2377b8b
update
ManfeiBai Apr 17, 2024
8947db7
update
ManfeiBai Apr 17, 2024
0fbb23d
update
ManfeiBai Apr 17, 2024
53f8185
update
ManfeiBai Apr 17, 2024
52435ec
update
ManfeiBai Apr 17, 2024
116a68b
update
ManfeiBai Apr 17, 2024
3cb631a
update
ManfeiBai Apr 17, 2024
fe3a530
update
ManfeiBai Apr 17, 2024
ac574f3
update
ManfeiBai Apr 17, 2024
ca3e757
update
ManfeiBai Apr 17, 2024
1819420
update
ManfeiBai Apr 17, 2024
1b91fc8
update
ManfeiBai Apr 17, 2024
673145d
update
ManfeiBai Apr 17, 2024
12f6f71
update
ManfeiBai Apr 17, 2024
89605cb
format
ManfeiBai Apr 17, 2024
2e9c979
format
ManfeiBai Apr 17, 2024
6db8139
format
ManfeiBai Apr 17, 2024
da35561
format
ManfeiBai Apr 17, 2024
431ab66
format
ManfeiBai Apr 17, 2024
6244d21
format
ManfeiBai Apr 17, 2024
04ca72d
format
ManfeiBai Apr 17, 2024
33fa1fb
format
ManfeiBai Apr 17, 2024
ea0e663
down into cpp
ManfeiBai Apr 23, 2024
cf9bb2f
down into cpp
ManfeiBai Apr 24, 2024
0e64276
down into cpp
ManfeiBai Apr 24, 2024
6fbc330
down into cpp
ManfeiBai Apr 24, 2024
3d2bd41
down into cpp
ManfeiBai Apr 24, 2024
d9d6358
down into cpp
ManfeiBai Apr 24, 2024
8ac3fc8
down into cpp
ManfeiBai Apr 24, 2024
387b9fd
down into cpp
ManfeiBai Apr 24, 2024
dc10d64
down into cpp
ManfeiBai Apr 24, 2024
a2c01a1
down into cpp
ManfeiBai Apr 24, 2024
6f7dcc5
down into cpp
ManfeiBai Apr 24, 2024
83aa4b0
down into cpp
ManfeiBai Apr 24, 2024
73ab5c1
down into cpp
ManfeiBai Apr 24, 2024
6ce33bd
down into cpp
ManfeiBai Apr 24, 2024
ea4cc8c
down into cpp
ManfeiBai Apr 24, 2024
362bdc3
down into cpp
ManfeiBai Apr 24, 2024
2f29c1b
down into cpp
ManfeiBai Apr 24, 2024
32dd2ea
down into cpp
ManfeiBai Apr 24, 2024
2bae9fe
down into cpp
ManfeiBai Apr 24, 2024
3a06a3e
down into cpp
ManfeiBai Apr 24, 2024
c7fdd16
down into cpp
ManfeiBai Apr 24, 2024
0f8d315
down into cpp
ManfeiBai Apr 24, 2024
9f6912c
down into cpp
ManfeiBai Apr 24, 2024
9f16c7d
down into cpp
ManfeiBai Apr 24, 2024
77656a1
down into cpp
ManfeiBai Apr 24, 2024
3cd1a6b
down into cpp
ManfeiBai Apr 24, 2024
1841bd2
down into cpp
ManfeiBai Apr 24, 2024
98a7e7c
down into cpp
ManfeiBai Apr 24, 2024
98d044b
down into cpp
ManfeiBai Apr 24, 2024
8bbd848
down into cpp
ManfeiBai Apr 24, 2024
b47a022
down into cpp
ManfeiBai Apr 24, 2024
d705daf
down into cpp
ManfeiBai Apr 24, 2024
3c17839
down into cpp
ManfeiBai Apr 24, 2024
b4f479b
down into cpp
ManfeiBai Apr 24, 2024
c53b9b9
down into cpp
ManfeiBai Apr 24, 2024
4ded9a3
down into cpp
ManfeiBai Apr 24, 2024
ffbe9a5
down into cpp
ManfeiBai Apr 24, 2024
2be4254
down into cpp
ManfeiBai Apr 24, 2024
7df2af4
down into cpp
ManfeiBai Apr 24, 2024
1be8e3a
down into cpp
ManfeiBai Apr 24, 2024
929ff17
down into cpp
ManfeiBai Apr 24, 2024
0d8a23c
down into cpp
ManfeiBai Apr 24, 2024
be617cc
down into cpp
ManfeiBai Apr 24, 2024
9925fd2
down into cpp
ManfeiBai Apr 24, 2024
bf038a5
down into cpp
ManfeiBai Apr 24, 2024
3727dbf
down into cpp
ManfeiBai Apr 24, 2024
e210bc7
down into cpp
ManfeiBai Apr 24, 2024
7a084c5
down into cpp
ManfeiBai Apr 24, 2024
b7d90a6
down into cpp
ManfeiBai Apr 24, 2024
dff1e9d
down into cpp
ManfeiBai Apr 24, 2024
251bd6f
down into cpp
ManfeiBai Apr 24, 2024
87f9eca
down into cpp
ManfeiBai Apr 24, 2024
d15a492
down into cpp
ManfeiBai Apr 24, 2024
15ec185
down into cpp
ManfeiBai Apr 24, 2024
31afff5
down into cpp
ManfeiBai Apr 24, 2024
6e3c699
down into cpp
ManfeiBai Apr 24, 2024
8405149
down into cpp
ManfeiBai Apr 24, 2024
25efff6
down into cpp
ManfeiBai Apr 24, 2024
8790af2
down into cpp
ManfeiBai Apr 24, 2024
be5710d
down into cpp
ManfeiBai Apr 24, 2024
9cbdb40
down into cpp
ManfeiBai Apr 24, 2024
5b8fcb7
down into cpp
ManfeiBai Apr 24, 2024
7d07234
down into cpp
ManfeiBai Apr 24, 2024
f11daa0
down into cpp
ManfeiBai Apr 24, 2024
3aa67e7
down into cpp
ManfeiBai Apr 24, 2024
760d49c
down into cpp
ManfeiBai Apr 24, 2024
9f7ddc2
down into cpp
ManfeiBai Apr 24, 2024
76b4b82
down into cpp
ManfeiBai Apr 24, 2024
042d37c
down into cpp
ManfeiBai Apr 24, 2024
088355b
down into cpp
ManfeiBai Apr 24, 2024
67b5840
down into cpp
ManfeiBai Apr 24, 2024
1b81829
down into cpp
ManfeiBai Apr 24, 2024
bee343f
down into cpp
ManfeiBai Apr 24, 2024
2de7464
down into cpp
ManfeiBai Apr 24, 2024
e25e9d2
down into cpp
ManfeiBai Apr 24, 2024
58359ca
down into cpp
ManfeiBai Apr 24, 2024
6ee65d1
down into cpp
ManfeiBai Apr 24, 2024
72a8d07
down into cpp
ManfeiBai Apr 24, 2024
5734811
down into cpp
ManfeiBai Apr 24, 2024
f244aab
down into cpp
ManfeiBai Apr 24, 2024
85a7975
down into cpp
ManfeiBai Apr 25, 2024
c7a2d9a
down into cpp
ManfeiBai Apr 25, 2024
6afbc66
down into cpp
ManfeiBai Apr 25, 2024
6c71661
down into cpp
ManfeiBai Apr 25, 2024
b492c01
down into cpp
ManfeiBai Apr 25, 2024
8dc49d5
down into cpp
ManfeiBai Apr 25, 2024
32c0e1a
down into cpp
ManfeiBai Apr 25, 2024
b2b14a3
format
ManfeiBai Apr 26, 2024
187330a
format
ManfeiBai Apr 26, 2024
53a1856
format
ManfeiBai Apr 26, 2024
fac3f1a
format
ManfeiBai Apr 26, 2024
82dedf7
format
ManfeiBai Apr 26, 2024
289a7b0
format
ManfeiBai Apr 26, 2024
d2fe907
format
ManfeiBai Apr 26, 2024
1bd65bf
test
ManfeiBai Apr 26, 2024
08f06f6
test
ManfeiBai Apr 26, 2024
e7fa8f8
test
ManfeiBai Apr 26, 2024
74bfd5f
test
ManfeiBai Apr 26, 2024
4fcbceb
test
ManfeiBai Apr 26, 2024
4525f2e
test
ManfeiBai Apr 26, 2024
0f7faed
test
ManfeiBai Apr 26, 2024
f9fdb2a
test
ManfeiBai Apr 26, 2024
443c28a
test
ManfeiBai Apr 26, 2024
5d8ea80
test
ManfeiBai Apr 26, 2024
9e6e12e
test
ManfeiBai Apr 26, 2024
41ddcfc
test
ManfeiBai Apr 26, 2024
c6acfe3
test
ManfeiBai Apr 26, 2024
4a0303f
test
ManfeiBai Apr 26, 2024
9b4b93a
test
ManfeiBai Apr 26, 2024
9eaf3fd
test
ManfeiBai Apr 26, 2024
7dca87a
test
ManfeiBai Apr 26, 2024
8c115b0
test
ManfeiBai Apr 26, 2024
76893a3
test
ManfeiBai Apr 26, 2024
d6c0025
test
ManfeiBai Apr 26, 2024
48a2b17
test
ManfeiBai Apr 26, 2024
7b65514
test
ManfeiBai Apr 26, 2024
af6d44b
test
ManfeiBai Apr 26, 2024
d3bea05
test
ManfeiBai Apr 26, 2024
afdac6c
test
ManfeiBai Apr 26, 2024
1f052a5
test
ManfeiBai Apr 26, 2024
2c8a445
test
ManfeiBai Apr 26, 2024
83d5e4f
test
ManfeiBai Apr 26, 2024
19eaf2c
test
ManfeiBai Apr 26, 2024
58d92f8
test
ManfeiBai Apr 29, 2024
1265107
test
ManfeiBai Apr 29, 2024
b15d721
test
ManfeiBai Apr 29, 2024
000d5ca
test
ManfeiBai Apr 29, 2024
f0678f7
test
ManfeiBai Apr 29, 2024
66b6027
test
ManfeiBai Apr 29, 2024
9953a7f
test
ManfeiBai Apr 29, 2024
a6e28e0
test
ManfeiBai Apr 29, 2024
a30db8e
test
ManfeiBai Apr 29, 2024
a54ba4c
test
ManfeiBai Apr 29, 2024
b86a03a
test
ManfeiBai Apr 29, 2024
407adf8
test
ManfeiBai Apr 29, 2024
532b9ce
test
ManfeiBai Apr 29, 2024
205f2e3
test
ManfeiBai Apr 29, 2024
c4e547d
test
ManfeiBai Apr 29, 2024
d4f86c7
test
ManfeiBai Apr 29, 2024
98a5b7b
test
ManfeiBai Apr 29, 2024
93c5db0
test
ManfeiBai Apr 29, 2024
9977250
test
ManfeiBai Apr 29, 2024
ad19fca
test
ManfeiBai Apr 29, 2024
47eb44f
test
ManfeiBai Apr 29, 2024
7d66521
test
ManfeiBai Apr 30, 2024
bcbed01
test
ManfeiBai Apr 30, 2024
a3744e4
test
ManfeiBai Apr 30, 2024
4783e72
test
ManfeiBai Apr 30, 2024
e4e0066
test
ManfeiBai Apr 30, 2024
906292a
test
ManfeiBai Apr 30, 2024
1b7d3af
test
ManfeiBai Apr 30, 2024
82ba323
test
ManfeiBai Apr 30, 2024
f2307aa
test
ManfeiBai Apr 30, 2024
74e69de
test
ManfeiBai Apr 30, 2024
d596c60
test
ManfeiBai Apr 30, 2024
cea7a98
test
ManfeiBai Apr 30, 2024
1da2f8e
test
ManfeiBai Apr 30, 2024
59152dd
test
ManfeiBai Apr 30, 2024
8d795b0
test
ManfeiBai Apr 30, 2024
fa0632d
test
ManfeiBai Apr 30, 2024
8f4b9a3
test
ManfeiBai Apr 30, 2024
ba14f8d
test
ManfeiBai Apr 30, 2024
4e00273
test
ManfeiBai Apr 30, 2024
499e304
test
ManfeiBai Apr 30, 2024
04009c4
test
ManfeiBai Apr 30, 2024
eb4aeec
test
ManfeiBai Apr 30, 2024
ee2a0d5
test
ManfeiBai Apr 30, 2024
5ef8dea
test
ManfeiBai Apr 30, 2024
7079301
test
ManfeiBai Apr 30, 2024
584d6cc
test
ManfeiBai Apr 30, 2024
0affbb4
test
ManfeiBai Apr 30, 2024
4fe7a62
test
ManfeiBai Apr 30, 2024
7f9dfa3
test
ManfeiBai Apr 30, 2024
d9bb401
test
ManfeiBai Apr 30, 2024
bebed0a
test
ManfeiBai Apr 30, 2024
881dbca
test
ManfeiBai Apr 30, 2024
bab82ee
test
ManfeiBai Apr 30, 2024
76b4ba4
test
ManfeiBai Apr 30, 2024
1e6cb5b
test
ManfeiBai Apr 30, 2024
a80e747
test
ManfeiBai Apr 30, 2024
bb7f682
test
ManfeiBai Apr 30, 2024
9e257f6
test
ManfeiBai Apr 30, 2024
42da876
test
ManfeiBai Apr 30, 2024
c31b3ec
test
ManfeiBai Apr 30, 2024
cf33820
test
ManfeiBai Apr 30, 2024
65f2edb
test
ManfeiBai Apr 30, 2024
2154f65
test
ManfeiBai Apr 30, 2024
214ac1f
test
ManfeiBai Apr 30, 2024
ddb69ee
test
ManfeiBai Apr 30, 2024
75dad47
test
ManfeiBai Apr 30, 2024
7225a3c
test
ManfeiBai Apr 30, 2024
77c0e42
test
ManfeiBai Apr 30, 2024
5454e20
test
ManfeiBai Apr 30, 2024
843758f
test
ManfeiBai Apr 30, 2024
558f80f
test
ManfeiBai Apr 30, 2024
af06fa6
test
ManfeiBai Apr 30, 2024
e1b6c8a
test
ManfeiBai Apr 30, 2024
8f5572f
test
ManfeiBai Apr 30, 2024
00f4990
test
ManfeiBai Apr 30, 2024
e109619
test
ManfeiBai Apr 30, 2024
ea1ce95
test
ManfeiBai Apr 30, 2024
abafd6a
test
ManfeiBai Apr 30, 2024
95227f0
test
ManfeiBai Apr 30, 2024
06cb773
test
ManfeiBai Apr 30, 2024
f55e4b5
test
ManfeiBai Apr 30, 2024
47d06eb
test
ManfeiBai Apr 30, 2024
68892b4
test
ManfeiBai Apr 30, 2024
fa12ead
test
ManfeiBai Apr 30, 2024
f4b79fc
test
ManfeiBai Apr 30, 2024
532e8c0
test
ManfeiBai Apr 30, 2024
1222705
test
ManfeiBai Apr 30, 2024
814f561
test
ManfeiBai Apr 30, 2024
84f2791
test
ManfeiBai Apr 30, 2024
113dacc
test
ManfeiBai Apr 30, 2024
94e7599
test
ManfeiBai Apr 30, 2024
d38b494
test
ManfeiBai Apr 30, 2024
180d7c5
test
ManfeiBai Apr 30, 2024
a904700
test
ManfeiBai Apr 30, 2024
1831d2c
test
ManfeiBai Apr 30, 2024
2fad582
test
ManfeiBai Apr 30, 2024
6e7ae09
test
ManfeiBai Apr 30, 2024
5004a74
test
ManfeiBai Apr 30, 2024
5fb1769
mnist
May 1, 2024
94025fa
addd
ManfeiBai May 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 249 additions & 12 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,22 @@ def _fake_while_loop(cond_fn, body_fn, operands):


def _fake_fori_loop(lower, upper, body_fun, *init_val):
(plus_value, init_val) = init_val
for i in range((upper - lower)[0]):
plus_value, init_val = body_fun(plus_value, init_val)
return init_val
if len(init_val) > 1:
(a, b) = init_val
for i in range((upper - lower)[0]):
a = body_fun(a, b)
else:
for i in range((upper - lower)[0]):
a = body_fun(*init_val)
return a


class WhileLoopTest(unittest.TestCase):

# additional_inputs: ()
def test_while_loop_tpu_subtraction(self):

print("$$$ test_while_loop_tpu_subtraction !!!")
device = xm.xla_device()

def cond_fn(init, limit_value):
Expand All @@ -46,8 +52,10 @@ def body_fn(init, limit_value):
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)

# additional_inputs: ()
def test_while_loop_tpu_addition(self):

print("$$$ test_while_loop_tpu_addition !!!")
device = xm.xla_device()

def cond_fn(init, limit_value):
Expand All @@ -64,8 +72,10 @@ def body_fn(init, limit_value):
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)

# additional_inputs: ()
def test_while_loop_tpu_subtraction_nested(self):

print("$$$ test_while_loop_tpu_subtraction_nested !!!")
device = xm.xla_device()

def cond_fn(init, limit_value):
Expand All @@ -82,25 +92,252 @@ def body_fn(init, limit_value):
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)

### return weight/bias
# additional_inputs: (tensor([1*20], device='xla:0'), tensor([10*20], device='xla:0'))
def test_while_loop_tpu_simple_linear(self):

print("$$$ test_while_loop_tpu_simple_linear !!!")
xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)

linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())

def cond_fn(upper, lower, one_value, x, input_value, output_value):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value, output_value):
new_lower = torch.add(one_value, lower)
output_value = linear_0(input_value)
weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement
bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), input_value.clone(), bias.clone(), weight.clone(
), output_value.clone()

upper = torch.tensor([1], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
l_in_0 = torch.rand(10, device=xm.xla_device())
output_value = torch.zeros([20], dtype=torch.float32, device=device)

upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(
cond_fn, body_fn,
(upper, lower, one_value, init_val, l_in_0, output_value))

expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)

return self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))

###
#
def test_while_loop_tpu_simple_linear_wrapper(self):

print("$$$ test_while_loop_tpu_simple_linear_wrapper !!!")
xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)

linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())

def cond_fn(upper, lower, one_value, x, input_value, output_value):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value, output_value):
new_lower = torch.add(one_value, lower)
output_value = linear_0(input_value)
# weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement
# bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), input_value.clone(), output_value.clone()

upper = torch.tensor([1], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
l_in_0 = torch.rand(10, device=xm.xla_device())
output_value = torch.zeros([20], dtype=torch.float32, device=device)

# upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(
# cond_fn, body_fn,
# (upper, lower, one_value, init_val, l_in_0, output_value))
upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, = while_loop(
cond_fn, body_fn,
(upper, lower, one_value, init_val, l_in_0, output_value))

expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)

return self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))


### return weight/bias
# additional_inputs: (tensor([ 1*20], device='xla:0'), tensor([10*20], device='xla:0'))
def test_while_loop_tpu_simple_linear_class(self):

print("$$$ test_while_loop_tpu_simple_linear_class !!!")
xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)

class SimpleWithLinear(torch.nn.Module):

def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 20).to(xm.xla_device())

def forward(self, upper, lower, one_value, x, input_value, output_value):

def cond_fn(upper, lower, one_value, x, input_value, output_value):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value, output_value):
new_lower = torch.add(one_value, lower)
output_value_real = self.linear(input_value)
weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement
bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), input_value.clone(
), output_value_real, weight.clone(), bias.clone()

return while_loop(
cond_fn, body_fn,
(upper, lower, one_value, x, input_value, output_value))

simple_with_linear = SimpleWithLinear()
upper = torch.tensor([52], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
l_in_0 = torch.rand(10, device=xm.xla_device())
output_value = torch.zeros([20], dtype=torch.float32, device=device)

weight_0 = simple_with_linear.linear.weight
bias_0 = simple_with_linear.linear.bias

aaa = {
"simple_with_linear":
(simple_with_linear, (upper, lower, one_value, init_val, l_in_0,
output_value))
}

upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(
upper, lower, one_value, init_val, l_in_0, output_value)

# create same weight/bias liear model for compare
linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
linear_0.weight.data = weight__
linear_0.bias.data = bias__
expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)

self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
return aaa

###
#
def test_while_loop_tpu_simple_linear_class_wrapper(self):

print("$$$ test_while_loop_tpu_simple_linear_class_wrapper !!!")
xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)

class SimpleWithLinear(torch.nn.Module):

def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 20).to(xm.xla_device())

def forward(self, upper, lower, one_value, x, input_value, output_value):

def cond_fn(upper, lower, one_value, x, input_value, output_value):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value, output_value):
new_lower = torch.add(one_value, lower)
output_value_real = self.linear(input_value)
weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement
bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), input_value.clone(
), output_value_real

return while_loop(
cond_fn, body_fn,
(upper, lower, one_value, x, input_value, output_value))

simple_with_linear = SimpleWithLinear()
upper = torch.tensor([52], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
l_in_0 = torch.rand(10, device=xm.xla_device())
output_value = torch.zeros([20], dtype=torch.float32, device=device)

weight_0 = simple_with_linear.linear.weight
bias_0 = simple_with_linear.linear.bias

aaa = {
"simple_with_linear":
(simple_with_linear, (upper, lower, one_value, init_val, l_in_0,
output_value))
}

upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(
upper, lower, one_value, init_val, l_in_0, output_value)

# create same weight/bias liear model for compare
linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
linear_0.weight.data = weight__
linear_0.bias.data = bias__
expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)

self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
return aaa

# additional_inputs: ()
def test_fori_loop_tpu_addition(self):

print("$$$ test_fori_loop_tpu_addition !!!")
xm.mark_step()
device = xm.xla_device()

lower = torch.tensor([2], dtype=torch.int32, device=device)
upper = torch.tensor([52], dtype=torch.int32, device=device)
plus_value = torch.tensor([1], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)

def body_fun(a, b):
return torch.add(a, b)

upper_, new_lower_, one_value_, add_res_x_, res_ = fori_loop(
upper, lower, body_fun, one_value, init_val)
expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value)
self.assertEqual(expected, res_)

# additional_inputs: (tensor([1*20], device='xla:0'), tensor([[10*20], device='xla:0'))
def test_fori_loop_tpu_simple_linear(self):

print("$$$ test_fori_loop_tpu_simple_linear !!!")
xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)

upper = torch.tensor([52], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
l_in_0 = torch.randn(10, device=xm.xla_device())

linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())

upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop(
upper, lower, linear_0, init_val, l_in_0)

def body_fun(*argus):
plus_value, init_val = argus
return plus_value, torch.add(plus_value, init_val)
expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)

_, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val)
expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val)
self.assertEqual(expected, actual)
self.assertTrue(torch.all(torch.eq(expected, l_out_)))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
sys.exit(0 if test.result.wasSuccessful() else 1)
Loading
Loading