diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index a3b69782343b..0a2394e839c2 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -354,7 +354,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): ### !!! add duplicated bn argus as the tile of the list if len(bn_list) !=0: # additional_inputs = additional_inputs + bn_list - bn_list.reverse() + bn_list.reverse() ### !!! reverse list for bn duplicate lists additional_inputs = additional_inputs + bn_list # print("added bn_list: ", bn_list) bn_list = [] @@ -370,7 +370,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): #### conv1+bn1 # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, output_value_real__, = _xla_while_loop( ##### conv1 + bn1 + conv2 - upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, output_value_real__, = _xla_while_loop( + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, output_value_real__, = _xla_while_loop( + ##### conv1 + bn1 + conv2 + bn2 + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, output_value_real__, = _xla_while_loop( + cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) # (upper, lower, one_value, init_val, l_in_0, output_value), ())