[Fori_loop|While_loop] Enable while_loop/fori_loop, Add linear/MNIST test case #6867
Changes from 250 commits
In the Python test file:
@@ -20,10 +20,14 @@ def _fake_while_loop(cond_fn, body_fn, operands):


 def _fake_fori_loop(lower, upper, body_fun, *init_val):
-  (plus_value, init_val) = init_val
-  for i in range((upper - lower)[0]):
-    plus_value, init_val = body_fun(plus_value, init_val)
-  return init_val
+  if len(init_val) > 1:
+    (a, b) = init_val
+    for i in range((upper - lower)[0]):
+      a = body_fun(a, b)
+  else:
+    for i in range((upper - lower)[0]):
+      a = body_fun(*init_val)
+  return a


 class WhileLoopTest(unittest.TestCase):
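For example, a hypothetical call of the updated helper (plain torch, values illustrative): with two loop-carried values it repeatedly folds body_fun(a, b) into a; with a single value it applies body_fun to the same input every iteration, which is how the linear-layer tests below use it.

    import torch

    lower = torch.tensor([0], dtype=torch.int32)
    upper = torch.tensor([3], dtype=torch.int32)

    # Two carried values: a becomes 1 + 3 * 1 = 4 after three iterations.
    out = _fake_fori_loop(lower, upper, torch.add, torch.tensor([1]),
                          torch.tensor([1]))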
@@ -82,25 +86,139 @@ def body_fn(init, limit_value):
     expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
     self.assertEqual(expected, res)

+  def test_while_loop_tpu_simple_linear(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+    torch.set_grad_enabled(False)
+
+    linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
+
+    def cond_fn(upper, lower, one_value, x, input_value, output_value):
+      return lower[0] < upper[0]
+
+    def body_fn(upper, lower, one_value, x, input_value, output_value):
+      new_lower = torch.add(one_value, lower)
+      output_value = linear_0(input_value)
+      weight = linear_0.weight  # not be used actually, initialized as placeholder xlacomputation requirement
+      bias = linear_0.bias  # not be used actually, initialized as placeholder xlacomputation requirement
+      return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+          one_value, x), input_value.clone(), bias.clone(), weight.clone(
+          ), output_value.clone()
Comment on lines +105 to +107:
  Reviewer: remind me why do we need to return weight and bias in this function?
  Author: we need to make sure body_fn's xlacomputation has the same inputs and outputs. The inputs include weight automatically, so we return weight and bias from the Python level to ensure they are included in the outputs too. We return bias as well to avoid output_value being used as the bias in the calculation, since bias has the same shape and value as output_value; we also have a plan to lower add.
  Reviewer: This is too confusing, we need to think of a better UX. body_fn also should take [...]. Instead of manually returning each parameter in the module, I think we should just return the module's [...].
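One possible shape for that suggestion, as a hypothetical sketch (not part of this PR; it assumes the linear_0 module and the tuple layout from the test above): echo all of the module's parameters back generically instead of naming weight and bias by hand, so body_fn's outputs stay aligned with whatever inputs the tracer captures.

    def body_fn(upper, lower, one_value, x, input_value, output_value):
      new_lower = torch.add(one_value, lower)
      new_output = linear_0(input_value)
      # Echo every traced parameter back unchanged so the body computation's
      # output tuple matches its input tuple for any module, not just Linear.
      # The exact position of each echoed parameter would still have to match
      # what the lowering expects.
      params = tuple(p.clone() for p in linear_0.parameters())
      return (upper.clone(), new_lower.clone(), one_value.clone(),
              torch.add(one_value, x), input_value.clone()) + params + (
                  new_output.clone(),)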
+
+    upper = torch.tensor([1], dtype=torch.int32, device=device)
+    lower = torch.tensor([0], dtype=torch.int32, device=device)
+    one_value = torch.tensor([1], dtype=torch.int32, device=device)
+    init_val = torch.tensor([1], dtype=torch.int32, device=device)
+    l_in_0 = torch.rand(10, device=xm.xla_device())
+    output_value = torch.zeros([20], dtype=torch.float32, device=device)
+
+    upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(
+        cond_fn, body_fn,
+        (upper, lower, one_value, init_val, l_in_0, output_value))
+
+    expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)
+
+    return self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
+
+  def test_while_loop_tpu_simple_linear_class(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+    torch.set_grad_enabled(False)
+
+    class SimpleWithLinear(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(10, 20).to(xm.xla_device())
+
+      def forward(self, upper, lower, one_value, x, input_value, output_value):
+
+        def cond_fn(upper, lower, one_value, x, input_value, output_value):
+          return lower[0] < upper[0]
+
+        def body_fn(upper, lower, one_value, x, input_value, output_value):
+          new_lower = torch.add(one_value, lower)
+          output_value_real = self.linear(input_value)
+          weight = self.linear.weight  # not be used actually, initialized as placeholder xlacomputation requirement
+          bias = self.linear.bias  # not be used actually, initialized as placeholder xlacomputation requirement
+          return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+              one_value, x), input_value.clone(
+              ), output_value_real, weight.clone(), bias.clone()
+
+        return while_loop(
+            cond_fn, body_fn,
+            (upper, lower, one_value, x, input_value, output_value))
+
+    simple_with_linear = SimpleWithLinear()
+    upper = torch.tensor([52], dtype=torch.int32, device=device)
+    lower = torch.tensor([0], dtype=torch.int32, device=device)
+    one_value = torch.tensor([1], dtype=torch.int32, device=device)
+    init_val = torch.tensor([1], dtype=torch.int32, device=device)
+    l_in_0 = torch.rand(10, device=xm.xla_device())
+    output_value = torch.zeros([20], dtype=torch.float32, device=device)
+
+    weight_0 = simple_with_linear.linear.weight
+    bias_0 = simple_with_linear.linear.bias
+
+    aaa = {
+        "simple_with_linear":
+            (simple_with_linear, (upper, lower, one_value, init_val, l_in_0,
+                                  output_value))
+    }
+
+    upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(
+        upper, lower, one_value, init_val, l_in_0, output_value)
+
+    # create the same weight/bias linear model for comparison
+    linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
+    linear_0.weight.data = weight__
+    linear_0.bias.data = bias__
+    expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)
+
+    self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
+    return aaa
 
   def test_fori_loop_tpu_addition(self):
 
     xm.mark_step()
     device = xm.xla_device()
 
     lower = torch.tensor([2], dtype=torch.int32, device=device)
     upper = torch.tensor([52], dtype=torch.int32, device=device)
-    plus_value = torch.tensor([1], dtype=torch.int32, device=device)
+    one_value = torch.tensor([1], dtype=torch.int32, device=device)
     init_val = torch.tensor([1], dtype=torch.int32, device=device)
 
+    def body_fun(a, b):
+      return torch.add(a, b)
+
+    upper_, new_lower_, one_value_, add_res_x_, res_ = fori_loop(
+        upper, lower, body_fun, one_value, init_val)
+    expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value)
+    self.assertEqual(expected, res_)
+
+  def test_fori_loop_tpu_simple_linear(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+    torch.set_grad_enabled(False)
+
+    upper = torch.tensor([52], dtype=torch.int32, device=device)
+    lower = torch.tensor([0], dtype=torch.int32, device=device)
+    init_val = torch.tensor([1], dtype=torch.int32, device=device)
+    l_in_0 = torch.randn(10, device=xm.xla_device())
+
+    linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
+
+    upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop(
+        upper, lower, linear_0, init_val, l_in_0)
+
-    def body_fun(*argus):
-      plus_value, init_val = argus
-      return plus_value, torch.add(plus_value, init_val)
+    expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)
+
-    _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val)
-    expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val)
-    self.assertEqual(expected, actual)
+    self.assertTrue(torch.all(torch.eq(expected, l_out_)))


 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
In the C++ lowering context (PyLoweringContext):
@@ -913,14 +913,31 @@ class PyLoweringContext {
   // Builds a HLO graph given a set of output tensors, and add unused parameters
   // needed in xlacomputation.
   void BuildForiLoop(std::vector<at::Tensor> tensors,
-                     std::vector<at::Tensor> input_arguments = {}) {
+                     std::vector<at::Tensor> additional_inputs_list = {}) {
+    // hard-code modify cond xlacomputation input arguments with unusedarguments
+    // for xla::while requriement
     if (GetNameString() == "condctx") {
       xla::XlaBuilder* local_builder = lowering_ctx.builder();
-      // hard-code parameter_idx to 2 to skip existing upper/lower arguments
-      int64_t parameter_idx = 2;
-      for (at::Tensor input_argument : input_arguments) {
-        xla::Shape shape =
-            xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1});
+      int64_t parameter_idx =
+          2;  // parameter_idx start from 2 after used upper and lower

Comment on the parameter_idx line:
  Reviewer: can you move the comment above the code? Thanks!

+      for (auto& additional_input_tensor : additional_inputs_list) {
+        XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor);
+        xla::Shape shape = xtensor->shape().get();
         xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
                                       "UnusedArgumentsPlaceholder");

Comment on lines +929 to +930:
  Reviewer: I don't quite understand this part, [...]
  Author: yes, x is added here to supply the expected shape/type arguments in the body/cond xlacomputation, because unused input arguments are missed when built via LTC; this is needed to meet the xla::While requirement.

         parameter_idx += 1;
       }
     }

+    // hard-code modify body xlacomputation input arguments with unusedarguments

Comment on the line above:
  Reviewer: nit, space between "unused" and "arguments".

+    // for xla::while requriement
+    if (GetNameString() == "bodyctx") {
+      xla::XlaBuilder* local_builder = lowering_ctx.builder();
+      // TODO(@manfei): treat hard code parameter_idx value
+      int64_t parameter_idx = 7;
+      for (auto& additional_input_tensor : additional_inputs_list) {
+        XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor);
+        xla::Shape shape = xtensor->shape().get();
+        xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
+                                      "UnusedArgumentsPlaceholder");
+        parameter_idx += 1;
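What the placeholder parameters are for, sketched in Python (a rough illustration using the xb helpers that appear later in this PR, not part of the diff): xla::While requires the init tuple, the cond computation, and the body computation to agree on one parameter list, so even arguments a computation never reads must still be declared.

    import torch
    import torch_xla.core.xla_builder as xb

    # Illustrative carried inputs: upper, lower, plus an extra tensor that a
    # cond computation would never read but must still declare.
    carried = (torch.tensor([52]), torch.tensor([0]), torch.rand(10))

    builder = xb.create_builder('placeholder_sketch')
    params = []
    for shape in xb.tensor_shape(carried):
      # One parameter slot per carried input, used or not; this is what the
      # "UnusedArgumentsPlaceholder" parameters above emulate for traced graphs.
      params.append(xb.mkparam(builder, len(params), shape))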
In the Python fori_loop/while_loop implementation:
@@ -10,66 +10,95 @@
 from torch._ops import HigherOrderOperator
 import torch._higher_order_ops.while_loop
 from torch._higher_order_ops.while_loop import while_loop_op
+from torch._higher_order_ops.while_loop import while_loop as torch_while_loop


-def fori_loop(lower, upper, user_body_func, *init_val):
+# TODO(@manfei): treat *input_value
+def fori_loop(upper, lower, body_fun, init_val, input_value):

   device = xm.xla_device()

-  def cond_fn(upper, lower, *init_val):
-    return lower[0] < upper[0]
+  one_value = torch.tensor([1], dtype=torch.int32, device=device)

-  def body_fn(upper, lower, *init_val):
-    one_value_i = torch.ones(1, dtype=torch.int32, device=device)
-    res_list = list(user_body_func(*init_val))
-    res_list.insert(0, lower)
-    res_list.insert(0, torch.sub(upper, one_value_i))
-    return res_list
+  if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')):
Comment on the hasattr condition:
  Reviewer: what is the plan to remove this condition?
  Reviewer: actually, why are we special-casing the weight and bias? Is it only for the linear layer?
  Author: yes, it's only for the linear layer. We special-case the weight and bias due to the different body_fn return: weight/bias are not mentioned in the inputs, but need to be returned or added in the xlacomputation's return arguments. The plan to remove this condition is: [...]
+    output_value = torch.zeros([20], dtype=torch.float32, device=device)
+
+    def cond_fn(upper, lower, one_value, x, input_value, output_value):
+      return lower[0] < upper[0]
+
+    def body_fn(upper, lower, one_value, x, input_value, output_value):
+      new_lower = torch.add(one_value, lower)
+      output_value = body_fun(input_value)
+      weight = body_fun.weight  # not be used actually, initialized as placeholder xlacomputation requirement
+      bias = body_fun.bias  # not be used actually, initialized as placeholder xlacomputation requirement
Comment on lines +32 to +33:
  Reviewer: nit, can we move the comment above the code?
+      return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+          one_value, x), input_value.clone(), bias.clone(), weight.clone(
+          ), output_value.clone()
Comment on lines +34 to +36:
  Reviewer: why do we need to clone?
  Author: miss [...]
+
+    res = torch_while_loop(
+        cond_fn, body_fn,
+        (upper, lower, one_value, init_val, input_value, output_value))
+  else:
+    output_value = torch.tensor([1], dtype=torch.int32, device=device)
+
+    def cond_fn(upper, lower, one_value, x, input_value):
+      return lower[0] < upper[0]
+
+    def body_fn(upper, lower, one_value, x, input_value):
+      new_lower = torch.add(one_value, lower)
+      output_val = body_fun(one_value, input_value)
+      return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+          one_value, x), output_val.clone()
+
+    res = torch_while_loop(cond_fn, body_fn,
+                           (upper, lower, one_value, init_val, input_value))

-  res = while_loop(cond_fn, body_fn, (lower, upper, *init_val))
   return res


 @while_loop_op.py_impl(DispatchKey.XLA)
-def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None):
+def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
+  # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '')
+  # cond_fn&body_fn: callable
+  # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors)
   if additional_inputs is None:
     additional_inputs = tuple()
-  return _xla_while_loop(
-      cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs)
+  return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs)
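For reference, a minimal call under the new signature (a sketch assuming an XLA device is available; values illustrative): carried_inputs is passed as a single tuple, matching PyTorch's while_loop contract, and the call dispatches to the XLA implementation above.

    import torch
    import torch_xla.core.xla_model as xm
    from torch._higher_order_ops.while_loop import while_loop

    device = xm.xla_device()
    init = torch.tensor([0], dtype=torch.int32, device=device)
    limit = torch.tensor([5], dtype=torch.int32, device=device)

    def cond_fn(init, limit):
      return init[0] < limit[0]

    def body_fn(init, limit):
      return init + 1, limit.clone()

    # Runs five iterations on the XLA device.
    res = while_loop(cond_fn, body_fn, (init, limit))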


-def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
-  # untuple carried_inputs from while_loop
-  carried_inputs = carried_inputs[0]
+def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
   # fake carried_inputs to split formal code
   fake_carried_inputs = []
   for carried_input in carried_inputs:
     device = carried_input.device
     fake_carried_inputs.append(
         torch.randint(10, carried_input.size(),
                       dtype=carried_input.dtype).to(device))
-  fake_carried_inputs = tuple(fake_carried_inputs)
-
-  # trans fake_carried_inputs from list(tensor) to list(xla::op)
-  kwargs = {}
-  if type(fake_carried_inputs) is tuple:
-    shapes = xb.tensor_shape(fake_carried_inputs)
-  else:
-    shapes = xb.tensor_shape((fake_carried_inputs))
-  builder = xb.create_builder('test_while')
-  params = []
-  for shape in shapes:
-    p = xb.mkparam(builder, len(params), shape)
-    params.append(p)
+  for additional_input in additional_inputs:
+    device = additional_input.device
+    fake_carried_inputs.append(
+        torch.randint(
+            10, additional_input.size(),
+            dtype=additional_input.dtype).to(device))

   # generate cond_fn xlacomputation
   # TODO(@manfei): specify which element is for which argument like a,b,c
   cond_result = cond_fn(*fake_carried_inputs)
   cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
   cond_ctx.set_name_string("condctx")
-  cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:]))
+
+  # TODO(@manfei): treat hard-code cond xlacomputation change: currently switch output_value and weight position if additional_inputs(weight/bias) exists
+  additional_inputs_list_cond = list(
+      fake_carried_inputs[2:]
+  )  # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor
+  if additional_inputs:
+    tmp_bias = additional_inputs_list_cond[
+        -3]  # not used, change order doesn't affect logic
+    del additional_inputs_list_cond[
+        -3]  # not used, change order doesn't affect logic
+    additional_inputs_list_cond.append(
+        tmp_bias)  # not used, change order doesn't affect logic
+
+  cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond)
   cond_hlo = cond_ctx.hlo()
   cond_computation = xb.computation_from_module_proto("condcomputation",
                                                       cond_hlo)

@@ -78,11 +107,38 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
   body_result = body_fn(*fake_carried_inputs)
   body_ctx = torch_xla._XLAC.lowering.LoweringContext()
   body_ctx.set_name_string("bodyctx")
-  body_ctx.buildforiloop(list(body_result), [])
+
+  # TODO(@manfei): treat hard-code body xlacomputation change: currently add non-changed output_value argument if additional_inputs(weight/bias) exists
+  if additional_inputs:
+    additional_inputs_list_body = [fake_carried_inputs[-3]]
+  else:
+    additional_inputs_list_body = []
+
+  # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body
+  body_ctx.buildforiloop(list(body_result), additional_inputs_list_body)
   body_hlo = body_ctx.hlo()
   body_computation = xb.computation_from_module_proto("bodycomputation",
                                                       body_hlo)

+  # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while
+  total_inputs = carried_inputs + additional_inputs
+  kwargs = {}
+  if type(total_inputs) is tuple:
+    shapes = xb.tensor_shape(total_inputs)
+  else:
+    shapes = xb.tensor_shape((total_inputs))
+  builder = xb.create_builder('test_while')
+  params = []
+  for shape in shapes:
+    p = xb.mkparam(builder, len(params), shape)
+    params.append(p)
+
+  # TODO(@manfei): treat hard-code input arguments, currently switch bias and output_value if additional_inputs(weight/bias) exists
+  if additional_inputs:
+    tmp_bias = params[-3]
+    del params[-3]
+    params.append(tmp_bias)
+
   # generate while xlacomputation
   input_tuple = xb.Op.tuple(tuple(params))
   w = xb.mkop(

@@ -94,6 +150,6 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):

   # gain final result with generated while xlacomputation
   result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
-                                                 (carried_inputs), computation)
+                                                 (total_inputs), computation)

   return result
Comment on the body_fn return value:
  Reviewer: why are we returning a torch.add(one_value, x) here?
  Author: it was used here to confirm the calculation runs the expected number of times, like a timer.
  Reviewer: I think our test case is too complicated; we should aim to support what PyTorch supports, similar to https://github.com/pytorch/pytorch/blob/8573d9551a7694b9313310412867ac3b6b751f26/test/functorch/test_control_flow.py#L137-L150.
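For comparison, the referenced PyTorch tests are roughly of this shape (a paraphrased sketch, not a verbatim copy of the linked file): a single carried tensor, no module parameters, and the same reference-loop check used elsewhere in this PR's test file.

    def test_while_loop_simple(self):

      def cond_fn(x):
        return x.sum() < 10

      def body_fn(x):
        return (x + 1,)

      x = torch.zeros(1)
      res = while_loop(cond_fn, body_fn, (x,))
      expected = _fake_while_loop(cond_fn, body_fn, (x,))
      self.assertEqual(expected, res)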