From bb598d979f75cc129dbeac61adf87ba034af171e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 10 Apr 2024 23:45:21 +0000 Subject: [PATCH] test --- test/test_fori_loop_simple_linear_model_test_code.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/test/test_fori_loop_simple_linear_model_test_code.py b/test/test_fori_loop_simple_linear_model_test_code.py index 4b98af619c4..0a6d63b94fb 100644 --- a/test/test_fori_loop_simple_linear_model_test_code.py +++ b/test/test_fori_loop_simple_linear_model_test_code.py @@ -1,6 +1,4 @@ import os -# import unittest -# from typing import Callable, Dict, List import torch import torch_xla @@ -17,15 +15,13 @@ device = xm.xla_device() # --- linear one --- -# l_in = torch.randn(10, device=xm.xla_device()) -# linear = torch.nn.Linear(10, 20).to(xm.xla_device()) -# l_out = linear(l_in) -# print("linear one: ", l_out) +l_in = torch.randn(10, device=xm.xla_device()) +linear = torch.nn.Linear(10, 20).to(xm.xla_device()) +l_out = linear(l_in) +print("linear one: ", l_out) # --- while test case --- -# lower = torch.tensor([2], dtype=torch.int32, device=device) -# upper = torch.tensor([52], dtype=torch.int32, device=device) upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([2], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device)