Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Apr 10, 2024
1 parent 0c792c3 commit 2b67140
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions test/test_fori_loop_simple_linear_model_test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def body_fun(l_in_i):
# lower_, upper_, one_value_, add_res_, l_out_res_, weight_, final_one_= fori_loop(upper, lower, body_fun, one_value, init_val, l_in_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
# lower_, upper_, one_value_, add_res_, l_out_res_, weight_, final_one_= fori_loop(lower, upper, body_fun, one_value, init_val, l_in_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
# one_value_, upper_, lower_, add_res_x_, bias_, weight_, l_in_i_plus_1_, l_out_= fori_loop(upper, lower, body_fun, one_value, init_val, l_in_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
weight_0 =linear_0.weight
bias_0 = linear_0.bias
# weight_0 =linear_0.weight
# bias_0 = linear_0.bias
# one_value_, lower_, upper_, add_res_x_, bias_, weight_, l_in_i_plus_1_, l_out_= fori_loop(one_value, lower, upper, body_fun, init_val, l_in_0, weight_0=weight_0, bias_0=bias_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
one_value_, lower_, upper_, add_res_x_, bias_, weight_, l_in_i_plus_1_, l_out_= fori_loop(one_value, lower, upper, linear_0, init_val, l_in_0) #, weight_0=weight_0, bias_0=bias_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0)
#one_value, [upper],[lower],x, [bias],[new_weight], [l_in_i+1], l_out
Expand Down
4 changes: 3 additions & 1 deletion torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@


# TODO(@manfei): delete one_value?
def fori_loop(one_value, lower, upper, body_fun, init_val, *input_value, weight_0, bias_0):
def fori_loop(one_value, lower, upper, body_fun, init_val, *input_value): #, weight_0, bias_0):

device = xm.xla_device()
weight_0 =linear_0.weight
bias_0 = linear_0.bias

# a = torch.tensor(1, dtype=torch.int32, device=device)
# b = torch.tensor(1, dtype=torch.int32, device=device)
Expand Down

0 comments on commit 2b67140

Please sign in to comment.