diff --git a/test/test_fori_loop_simple_linear_model_test_code.py b/test/test_fori_loop_simple_linear_model_test_code.py
index f43a040fb2d..e57c1e3f013 100644
--- a/test/test_fori_loop_simple_linear_model_test_code.py
+++ b/test/test_fori_loop_simple_linear_model_test_code.py
@@ -78,8 +78,8 @@ def body_fun(l_in_i):
 # lower_, upper_, one_value_, add_res_, l_out_res_, weight_, final_one_= fori_loop(upper, lower, body_fun, one_value, init_val, l_in_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
 # lower_, upper_, one_value_, add_res_, l_out_res_, weight_, final_one_= fori_loop(lower, upper, body_fun, one_value, init_val, l_in_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
 # one_value_, upper_, lower_, add_res_x_, bias_, weight_, l_in_i_plus_1_, l_out_= fori_loop(upper, lower, body_fun, one_value, init_val, l_in_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
-weight_0 =linear_0.weight
-bias_0 = linear_0.bias
+# weight_0 =linear_0.weight
+# bias_0 = linear_0.bias
 # one_value_, lower_, upper_, add_res_x_, bias_, weight_, l_in_i_plus_1_, l_out_= fori_loop(one_value, lower, upper, body_fun, init_val, l_in_0, weight_0=weight_0, bias_0=bias_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
 one_value_, lower_, upper_, add_res_x_, bias_, weight_, l_in_i_plus_1_, l_out_= fori_loop(one_value, lower, upper, linear_0, init_val, l_in_0) #, weight_0=weight_0, bias_0=bias_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
 #one_value, [upper],[lower],x, [bias],[new_weight], [l_in_i+1], l_out
diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py
index f07f8882534..6baccd7253f 100644
--- a/torch_xla/experimental/fori_loop.py
+++ b/torch_xla/experimental/fori_loop.py
@@ -13,9 +13,11 @@
 # TODO(@manfei): delete one_value?
-def fori_loop(one_value, lower, upper, body_fun, init_val, *input_value, weight_0, bias_0):
+def fori_loop(one_value, lower, upper, body_fun, init_val, *input_value): #, weight_0, bias_0):
   device = xm.xla_device()
+  weight_0 =linear_0.weight
+  bias_0 = linear_0.bias
   # a = torch.tensor(1, dtype=torch.int32, device=device)
   # b = torch.tensor(1, dtype=torch.int32, device=device)
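
Reviewer sketch (not part of the patch): after this change, `linear_0` is referenced inside fori_loop() in fori_loop.py but appears to be defined only in the test file, while the test now passes `linear_0` in the body_fun position. One possible reading, illustrated below, is that the weight and bias would be read off the module passed in as body_fun. This is a minimal, standalone sketch in plain PyTorch (no XLA); extract_params and the getattr fallback are assumptions for illustration, not code from the patch.

    import torch.nn as nn

    def extract_params(body_fun):
      # If the loop body is an nn.Module such as nn.Linear (as in the test),
      # reuse its own parameters instead of separately passed weight_0/bias_0.
      # Plain callables without .weight/.bias fall back to None.
      weight_0 = getattr(body_fun, 'weight', None)
      bias_0 = getattr(body_fun, 'bias', None)
      return weight_0, bias_0

    linear_0 = nn.Linear(10, 20)
    weight_0, bias_0 = extract_params(linear_0)
    print(weight_0.shape, bias_0.shape)  # torch.Size([20, 10]) torch.Size([20])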