diff --git a/test/test_fori_loop_simple_linear_model_test_code.py b/test/test_fori_loop_simple_linear_model_test_code.py index 4b98af619c4..0a6d63b94fb 100644 --- a/test/test_fori_loop_simple_linear_model_test_code.py +++ b/test/test_fori_loop_simple_linear_model_test_code.py @@ -1,6 +1,4 @@ import os -# import unittest -# from typing import Callable, Dict, List import torch import torch_xla @@ -17,15 +15,13 @@ device = xm.xla_device() # --- linear one --- -# l_in = torch.randn(10, device=xm.xla_device()) -# linear = torch.nn.Linear(10, 20).to(xm.xla_device()) -# l_out = linear(l_in) -# print("linear one: ", l_out) +l_in = torch.randn(10, device=xm.xla_device()) +linear = torch.nn.Linear(10, 20).to(xm.xla_device()) +l_out = linear(l_in) +print("linear one: ", l_out) # --- while test case --- -# lower = torch.tensor([2], dtype=torch.int32, device=device) -# upper = torch.tensor([52], dtype=torch.int32, device=device) upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([2], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device)