[Fori_loop|While_loop] Enable while_loop/fori_loop, Add linear/MNIST test case #6867

Closed
wants to merge 323 commits
Changes from 250 commits
Commits
323 commits
c4fe122 through 12f6f71: "update" (ManfeiBai, Apr 12 to Apr 17, 2024)
89605cb through 33fa1fb: "format" (ManfeiBai, Apr 17, 2024)
332bd40
[rebase] rebase fori_loop_simple_case_test (#7165)
ManfeiBai May 31, 2024
142 changes: 130 additions & 12 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
@@ -20,10 +20,14 @@ def _fake_while_loop(cond_fn, body_fn, operands):


def _fake_fori_loop(lower, upper, body_fun, *init_val):
if len(init_val) > 1:
(a, b) = init_val
for i in range((upper - lower)[0]):
a = body_fun(a, b)
else:
for i in range((upper - lower)[0]):
a = body_fun(*init_val)
return a
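For intuition, the reference helper above just unrolls the loop in plain Python: run the body `(upper - lower)` times, threading a carried value through each step. A hedged, torch-free sketch of that semantics (the name `fori_loop_reference` is ours, not part of the PR):

```python
def fori_loop_reference(lower, upper, body_fun, init_val):
    # Apply body_fun (upper - lower) times, carrying the value forward.
    val = init_val
    for _ in range(upper - lower):
        val = body_fun(val)
    return val

# Adding 1 ten times starting from 0 yields 10.
result = fori_loop_reference(0, 10, lambda v: v + 1, 0)
```

The real helper also handles a two-argument case, where the body consumes a pair; the sketch keeps only the single-carry path.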


class WhileLoopTest(unittest.TestCase):
@@ -82,25 +86,139 @@ def body_fn(init, limit_value):
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)

def test_while_loop_tpu_simple_linear(self):

xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)

linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())

def cond_fn(upper, lower, one_value, x, input_value, output_value):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value, output_value):
new_lower = torch.add(one_value, lower)
output_value = linear_0(input_value)
weight = linear_0.weight # not actually used; placeholder required by the xlacomputation
bias = linear_0.bias # not actually used; placeholder required by the xlacomputation
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), input_value.clone(), bias.clone(), weight.clone(
Comment on lines +105 to +106
Reviewer (Collaborator): why are we returning a torch.add(one_value, x) here?
Author (ManfeiBai): It was used here to confirm that the calculation runs the expected number of times, like a timer.
Reviewer (Collaborator): I think our test case is too complicated; we should aim to support what PyTorch supports, similar to https://github.com/pytorch/pytorch/blob/8573d9551a7694b9313310412867ac3b6b751f26/test/functorch/test_control_flow.py#L137-L150.

), output_value.clone()
Comment on lines +105 to +107
Reviewer (Collaborator): remind me why we need to return weight and bias in this function?
Author (ManfeiBai, Apr 19, 2024): We need to make sure body_fn's xlacomputation has identical input and output signatures. The input includes weight automatically, so we return weight and bias at the Python level to ensure they are included in the output too. bias is returned to avoid output_value being used as the bias in the calculation, since bias has the same shape and value as output_value. We also plan to lower the addition of weight and bias to the xlacomputation arguments to the C++ level; if local tests pass, we could avoid returning weight and bias from the Python level.
Reviewer (Collaborator): This is too confusing; we need to think of a better UX. body_fn should also take linear_0 as an input instead of calling it from the parent scope. Instead of manually returning each parameter of the module, we should just return the module's named_parameters. The user also shouldn't need to manually order the returned parameters (that would be super confusing); we should do it in our layer.


upper = torch.tensor([1], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
l_in_0 = torch.rand(10, device=xm.xla_device())
output_value = torch.zeros([20], dtype=torch.float32, device=device)

upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(
cond_fn, body_fn,
(upper, lower, one_value, init_val, l_in_0, output_value))

expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)

self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
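The test above threads six values through body_fn on every iteration. The underlying constraint, discussed in the review comments, is that XLA's While requires the cond/body parameters, the body result, and the initial values to share one structure. A torch-free sketch of a loop that enforces that contract (helper names are ours):

```python
def run_while_checked(cond_fn, body_fn, carried):
    # Run a while_loop-style iteration, checking the structural contract:
    # body_fn must return exactly one value per carried input.
    vals = tuple(carried)
    while cond_fn(*vals):
        out = tuple(body_fn(*vals))
        if len(out) != len(vals):
            raise ValueError("body_fn output arity must match its input arity")
        vals = out
    return vals

# Carry (i, acc); the body returns both, so the contract holds.
final = run_while_checked(lambda i, acc: i < 3,
                          lambda i, acc: (i + 1, acc + i),
                          (0, 0))
# final == (3, 3)
```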

def test_while_loop_tpu_simple_linear_class(self):

xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)

class SimpleWithLinear(torch.nn.Module):

def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 20).to(xm.xla_device())

def forward(self, upper, lower, one_value, x, input_value, output_value):

def cond_fn(upper, lower, one_value, x, input_value, output_value):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value, output_value):
new_lower = torch.add(one_value, lower)
output_value_real = self.linear(input_value)
weight = self.linear.weight # not actually used; placeholder required by the xlacomputation
bias = self.linear.bias # not actually used; placeholder required by the xlacomputation
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), input_value.clone(
), output_value_real, weight.clone(), bias.clone()

return while_loop(
cond_fn, body_fn,
(upper, lower, one_value, x, input_value, output_value))

simple_with_linear = SimpleWithLinear()
upper = torch.tensor([52], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
l_in_0 = torch.rand(10, device=xm.xla_device())
output_value = torch.zeros([20], dtype=torch.float32, device=device)

weight_0 = simple_with_linear.linear.weight
bias_0 = simple_with_linear.linear.bias

aaa = {
"simple_with_linear":
(simple_with_linear, (upper, lower, one_value, init_val, l_in_0,
output_value))
}

upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(
upper, lower, one_value, init_val, l_in_0, output_value)

# create a linear model with the same weight/bias for comparison
linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
linear_0.weight.data = weight__
linear_0.bias.data = bias__
expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)

self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
return aaa

def test_fori_loop_tpu_addition(self):

xm.mark_step()
device = xm.xla_device()

lower = torch.tensor([2], dtype=torch.int32, device=device)
upper = torch.tensor([52], dtype=torch.int32, device=device)
plus_value = torch.tensor([1], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)

def body_fun(a, b):
return torch.add(a, b)

upper_, new_lower_, one_value_, add_res_x_, res_ = fori_loop(
upper, lower, body_fun, one_value, init_val)
expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value)
self.assertEqual(expected, res_)

def test_fori_loop_tpu_simple_linear(self):

xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)

upper = torch.tensor([52], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
l_in_0 = torch.randn(10, device=xm.xla_device())

linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())

upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop(
upper, lower, linear_0, init_val, l_in_0)

expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)

self.assertTrue(torch.all(torch.eq(expected, l_out_)))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
29 changes: 23 additions & 6 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -913,14 +913,31 @@ class PyLoweringContext {
// Builds a HLO graph given a set of output tensors, and adds unused parameters
// needed in the xlacomputation.
void BuildForiLoop(std::vector<at::Tensor> tensors,
std::vector<at::Tensor> additional_inputs_list = {}) {
// hard-coded: modify cond xlacomputation input arguments with unused arguments
// for the xla::while requirement
if (GetNameString() == "condctx") {
xla::XlaBuilder* local_builder = lowering_ctx.builder();
// parameter_idx starts from 2, after the used upper and lower arguments
int64_t parameter_idx = 2;
Reviewer (Collaborator): can you move the comment above the code? Thanks!

for (auto& additional_input_tensor : additional_inputs_list) {
XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor);
xla::Shape shape = xtensor->shape().get();
xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
"UnusedArgumentsPlaceholder");
Comment on lines +929 to +930
Reviewer (Collaborator): I don't quite understand this part. x is a local variable that will be released after the for loop; is the intention of calling xla::Parameter to init a parameter at position parameter_idx with the given shape? If so, can you add a comment to make it clearer?
Author (ManfeiBai, Apr 19, 2024): Yes, x is added here to register arguments of the expected shape/type in the body/cond xlacomputation, since unused input arguments are missed when building via LTC. This meets the XLA::While requirement that the parameters of condition and body, the result of the body, and init must all have the same shape.

parameter_idx += 1;
}
}

// hard-code modify body xlacomputation input arguments with unused arguments
Reviewer (Collaborator): nit, space between unused and arguments

// for the xla::while requirement
if (GetNameString() == "bodyctx") {
xla::XlaBuilder* local_builder = lowering_ctx.builder();
// TODO(@manfei): treat hard code parameter_idx value
int64_t parameter_idx = 7;
for (auto& additional_input_tensor : additional_inputs_list) {
XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor);
xla::Shape shape = xtensor->shape().get();
xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
"UnusedArgumentsPlaceholder");
parameter_idx += 1;
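The C++ hunks above pad the cond and body computations with dummy parameters so that every computation declares the same parameter list. A hypothetical Python sketch of that bookkeeping (the function and names are ours; the real code emits xla::Parameter ops starting at index 2, after upper and lower):

```python
def declare_parameters(used_names, num_unused, start_idx=2):
    # Known parameters keep their positions; unused slots get placeholder
    # entries numbered consecutively from start_idx.
    params = [(name, i) for i, name in enumerate(used_names)]
    idx = start_idx
    for _ in range(num_unused):
        params.append(("UnusedArgumentsPlaceholder", idx))
        idx += 1
    return params

params = declare_parameters(["upper", "lower"], 2)
# [("upper", 0), ("lower", 1),
#  ("UnusedArgumentsPlaceholder", 2), ("UnusedArgumentsPlaceholder", 3)]
```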
124 changes: 90 additions & 34 deletions torch_xla/experimental/fori_loop.py
@@ -10,66 +10,95 @@
from torch._ops import HigherOrderOperator
import torch._higher_order_ops.while_loop
from torch._higher_order_ops.while_loop import while_loop_op
from torch._higher_order_ops.while_loop import while_loop as torch_while_loop


# TODO(@manfei): treat *input_value
def fori_loop(upper, lower, body_fun, init_val, input_value):

device = xm.xla_device()

one_value = torch.tensor([1], dtype=torch.int32, device=device)

if hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias'):
Comment (on the hasattr condition)
Reviewer (Collaborator): what is the plan to remove this condition?
Reviewer (Collaborator): actually, why are we special-casing weight and bias? Is it only for the linear layer?
Author (ManfeiBai, Apr 19, 2024): Yes, it's only for the linear layer. We special-case weight and bias because the body_fn return differs: weight/bias are not mentioned in the inputs, but need to be returned or added to the xlacomputation's return arguments. Plans to remove this condition:
  • A: check additional_inputs (weight/bias) like PyTorch does and add them to the xlacomputation arguments at the C++ level
  • B: only check whether additional_inputs (weight/bias) exist for the weight/bias code in the next step

output_value = torch.zeros([20], dtype=torch.float32, device=device)

def cond_fn(upper, lower, one_value, x, input_value, output_value):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value, output_value):
new_lower = torch.add(one_value, lower)
output_value = body_fun(input_value)
weight = body_fun.weight # not actually used; placeholder required by the xlacomputation
bias = body_fun.bias # not actually used; placeholder required by the xlacomputation
Comment on lines +32 to +33
Reviewer (Collaborator): nit, can we move the comment above the code?

return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), input_value.clone(), bias.clone(), weight.clone(
), output_value.clone()
Comment on lines +34 to +36
Reviewer (Collaborator): why do we need to clone?
Author (ManfeiBai): Missing .clone() here causes the PyTorch error "torch.while_loop's body_fn might be aliasing the input!", so .clone() is added to avoid returning the input arguments directly.


res = torch_while_loop(
cond_fn, body_fn,
(upper, lower, one_value, init_val, input_value, output_value))
else:
output_value = torch.tensor([1], dtype=torch.int32, device=device)

def cond_fn(upper, lower, one_value, x, input_value):
return lower[0] < upper[0]

def body_fn(upper, lower, one_value, x, input_value):
new_lower = torch.add(one_value, lower)
output_val = body_fun(one_value, input_value)
return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
one_value, x), output_val.clone()

res = torch_while_loop(cond_fn, body_fn,
(upper, lower, one_value, init_val, input_value))

return res


@while_loop_op.py_impl(DispatchKey.XLA)
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
# TODO(@manfei): PyTorch requires carried_inputs to be a list/tuple; PyTorch/XLA's _xla_while_loop only accepts *operands, and *operands would tuple the items again: (a, '')
# cond_fn & body_fn: callable
# carried_inputs: tuple of possibly nested dict/list/tuple of tensors
if additional_inputs is None:
additional_inputs = tuple()
return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs)
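The TODO in the dispatcher records why carried_inputs is passed as a single tuple: a *carried_inputs signature wraps an already-tupled argument again. A minimal, torch-free illustration (hypothetical function names):

```python
def takes_varargs(*carried_inputs):
    # A starred parameter re-wraps a tuple argument.
    return carried_inputs

def takes_tuple(carried_inputs):
    # A plain parameter receives the tuple as-is.
    return carried_inputs

inputs = (1, 2)
wrapped = takes_varargs(inputs)  # ((1, 2),): needs un-tupling downstream
direct = takes_tuple(inputs)     # (1, 2)
```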


def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
# fake carried_inputs to split formal code
fake_carried_inputs = []
for carried_input in carried_inputs:
device = carried_input.device
fake_carried_inputs.append(
torch.randint(10, carried_input.size(),
dtype=carried_input.dtype).to(device))
for additional_input in additional_inputs:
device = additional_input.device
fake_carried_inputs.append(
torch.randint(
10, additional_input.size(),
dtype=additional_input.dtype).to(device))
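Tracing here needs stand-ins with the right shapes and dtypes rather than real data, which is why random tensors are fabricated for every carried and additional input. A torch-free sketch of the idea (`fake_like` is a hypothetical stand-in for the torch.randint calls above):

```python
import random

def fake_like(shape, low=0, high=10):
    # Build a nested list of random ints with the given shape.
    if not shape:
        return random.randrange(low, high)
    return [fake_like(shape[1:], low, high) for _ in range(shape[0])]

fake = fake_like((2, 3))  # two rows of three random ints each
```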

# generate cond_fn xlacomputation
# TODO(@manfei): specify which element is for which argument like a,b,c
cond_result = cond_fn(*fake_carried_inputs)
cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
cond_ctx.set_name_string("condctx")

# TODO(@manfei): treat hard-code cond xlacomputation change: currently switch output_value and weight position if additional_inputs(weight/bias) exists
additional_inputs_list_cond = list(
fake_carried_inputs[2:]
) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor
if additional_inputs:
tmp_bias = additional_inputs_list_cond[
-3] # not used, change order doesn't affect logic
del additional_inputs_list_cond[
-3] # not used, change order doesn't affect logic
additional_inputs_list_cond.append(
tmp_bias) # not used, change order doesn't affect logic

cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond)
cond_hlo = cond_ctx.hlo()
cond_computation = xb.computation_from_module_proto("condcomputation",
cond_hlo)
@@ -78,11 +107,38 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
body_result = body_fn(*fake_carried_inputs)
body_ctx = torch_xla._XLAC.lowering.LoweringContext()
body_ctx.set_name_string("bodyctx")

# TODO(@manfei): treat hard-code body xlacomputation change: currently add non-changed output_value argument if additional_inputs(weight/bias) exists
if additional_inputs:
additional_inputs_list_body = [fake_carried_inputs[-3]]
else:
additional_inputs_list_body = []

# TODO(@manfei): treat hard-code parameters: additional_inputs_list_body
body_ctx.buildforiloop(list(body_result), additional_inputs_list_body)
body_hlo = body_ctx.hlo()
body_computation = xb.computation_from_module_proto("bodycomputation",
body_hlo)

# trans total_inputs from list(tensor) to list(xla::op); this part could change the init of xla::while
total_inputs = carried_inputs + additional_inputs
kwargs = {}
if type(total_inputs) is tuple:
shapes = xb.tensor_shape(total_inputs)
else:
shapes = xb.tensor_shape((total_inputs))
builder = xb.create_builder('test_while')
params = []
for shape in shapes:
p = xb.mkparam(builder, len(params), shape)
params.append(p)

# TODO(@manfei): treat hard-code input arguments, currently switch bias and output_value if additional_inputs(weight/bias) exists
if additional_inputs:
tmp_bias = params[-3]
del params[-3]
params.append(tmp_bias)
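The swap above moves the third-from-last parameter (the bias placeholder) to the end of params so the parameter order matches the body computation. The same list manipulation in isolation (the helper name is ours):

```python
def move_third_from_last_to_end(params):
    # del/append mirrors the tmp_bias handling in _xla_while_loop.
    params = list(params)
    tmp = params[-3]
    del params[-3]
    params.append(tmp)
    return params

order = ["upper", "lower", "x", "bias", "weight", "out"]
reordered = move_third_from_last_to_end(order)
# ["upper", "lower", "x", "weight", "out", "bias"]
```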

# generate while xlacomputation
input_tuple = xb.Op.tuple(tuple(params))
w = xb.mkop(
@@ -94,6 +150,6 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):

# gain final result with generated while xlacomputation
result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
(total_inputs), computation)

return result