diff --git a/test/test_fori_loop_simple_linear_model_test_code.py b/test/test_fori_loop_simple_linear_model_test_code.py
index 3fa988ded846..f43a040fb2d6 100644
--- a/test/test_fori_loop_simple_linear_model_test_code.py
+++ b/test/test_fori_loop_simple_linear_model_test_code.py
@@ -81,7 +81,7 @@ def body_fun(l_in_i):
 weight_0 =linear_0.weight
 bias_0 = linear_0.bias
 # one_value_, lower_, upper_, add_res_x_, bias_, weight_, l_in_i_plus_1_, l_out_= fori_loop(one_value, lower, upper, body_fun, init_val, l_in_0, weight_0=weight_0, bias_0=bias_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
-one_value_, lower_, upper_, add_res_x_, bias_, weight_, l_in_i_plus_1_, l_out_= fori_loop(one_value, lower, upper, linear_0, init_val, l_in_0, weight_0=weight_0, bias_0=bias_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
+one_value_, lower_, upper_, add_res_x_, bias_, weight_, l_in_i_plus_1_, l_out_= fori_loop(one_value, lower, upper, linear_0, init_val, l_in_0) #, weight_0=weight_0, bias_0=bias_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
 #one_value, [upper],[lower],x, [bias],[new_weight], [l_in_i+1], l_out
 print("one_value_: ", one_value_)
 
diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py
index d98b5d54412c..f07f88825341 100644
--- a/torch_xla/experimental/fori_loop.py
+++ b/torch_xla/experimental/fori_loop.py
@@ -90,7 +90,8 @@ def body_fn(one_value, lower, upper, x, bias_0, weight_0, *input_value, output_v
     return_list.append(bias)
     # TODO(@manfei): should initialize weight with torch.nn.linear's real weight, currently we use placeholder for all weight are 1
     # weight = torch.ones([20, 10], dtype=torch.float32, device=device) # f32[20,10] # ???
-    weight = weight_0
+    # weight = weight_0
+    weight = body_fun.weight
     # return_list.append(weight)
     # return_list.insert(-1, weight)
     return_list.append(weight)
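
To make the intent of the two hunks easier to follow, below is a minimal plain-PyTorch sketch, not the torch_xla implementation: `fori_loop_sketch` and its single-iteration usage are hypothetical names, and they only illustrate the pattern this diff moves toward, namely passing the `nn.Linear` module itself as the loop body and reading `body_fun.weight` / `body_fun.bias` from that module inside the loop, instead of threading `weight_0` / `bias_0` through as extra `fori_loop` arguments.

```python
import torch
import torch.nn as nn


def fori_loop_sketch(lower, upper, body_fun, l_in):
  # Hypothetical stand-in for the experimental torch_xla fori_loop; it only
  # mirrors the calling pattern changed in this diff: the nn.Linear module is
  # the loop body, and its parameters are read from the module
  # (body_fun.weight / body_fun.bias) rather than passed as weight_0 / bias_0.
  l_out = l_in
  for _ in range(int(lower), int(upper)):
    l_out = torch.matmul(l_out, body_fun.weight.t()) + body_fun.bias
  return body_fun.bias, body_fun.weight, l_out


linear_0 = nn.Linear(10, 20)
l_in_0 = torch.randn(10)
# One iteration, matching the shape constraint of a 10 -> 20 linear layer.
bias_, weight_, l_out_ = fori_loop_sketch(
    torch.tensor(0), torch.tensor(1), linear_0, l_in_0)
print("l_out_ shape: ", l_out_.shape)  # torch.Size([20])
```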