Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Apr 10, 2024
1 parent e8d227e commit 81d3c88
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions test/test_fori_loop_simple_linear_model_test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
# We need to import the underlying implementation function to register with the dispatcher
import torch_xla.experimental.fori_loop
from torch_xla.experimental.fori_loop import fori_loop
# from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
# import torch_xla.core.xla_builder as xb
import torch_xla.utils.utils as xu

torch.set_grad_enabled(False)
Expand All @@ -21,10 +19,8 @@
print("linear one: ", l_out)

# --- while test case ---

upper = torch.tensor([52], dtype=torch.int32, device=device)
lower = torch.tensor([2], dtype=torch.int32, device=device)
# one_value = torch.tensor([1], dtype=torch.int32, device=device)
upper = torch.tensor([5], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)

linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
Expand All @@ -44,6 +40,8 @@ def body_fun(l_in_i):
print("l_in_i_plus_1_: ", l_in_i_plus_1_)
print("l_out_: ", l_out_)



# --- linear two ---
l_in_2 = torch.randn(10, device=xm.xla_device())
linear_2 = torch.nn.Linear(10, 20).to(xm.xla_device())
Expand Down

0 comments on commit 81d3c88

Please sign in to comment.