From 5fb1769182183dc20edb85025a461e0b131caa7d Mon Sep 17 00:00:00 2001 From: root Date: Wed, 1 May 2024 18:41:45 +0000 Subject: [PATCH] mnist --- test/test_test_mnist.py | 16 ++++++++++++---- torch_xla/csrc/init_python_bindings.cpp | 3 ++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index a171942cf3d..c3da2020635 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -57,6 +57,7 @@ def __init__(self): self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5).to(xm.xla_device()) self.bn2 = torch.nn.BatchNorm2d(20).to(xm.xla_device()) self.fc1 = torch.nn.Linear(500, 50).to(xm.xla_device()) + self.fc2 = torch.nn.Linear(50, 10).to(xm.xla_device()) # self.fc1 = torch.nn.Linear(320, 50).to(xm.xla_device()) # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) @@ -74,7 +75,9 @@ def forward(self, x): x = self.bn2(x) x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) - return x + x = self.fc2(x) + return F.log_softmax(x, dim=1) + # return x class SimpleWithLinear(torch.nn.Module): def __init__(self): @@ -328,7 +331,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # output_value = torch.zeros([16,10,28,28], dtype=torch.float32, device=device) # conv2d # output_value = torch.zeros([16,10,14,14], dtype=torch.float32, device=device) # conv2d+mnist-treat # conv1 + bn1 # output_value = torch.zeros([16,20,5,5], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 - output_value = torch.zeros([16,500], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + # output_value = torch.zeros([16,500], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + # output_value = torch.zeros([16,50], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + fc1 + output_value = torch.zeros([16,10], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + fc1 additional_inputs = [] @@ -377,8 +382,11 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): ##### conv1 + bn1 + conv2 # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, output_value_real__, = _xla_while_loop( ##### conv1 + bn1 + conv2 + bn2 - upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, output_value_real__, = _xla_while_loop( - + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, output_value_real__, = _xla_while_loop( + ##### conv1 + bn1 + conv2 + bn2 + fc1 + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, p7, p8, output_value_real__, = _xla_while_loop( + ##### conv1 + bn1 + conv2 + bn2 + fc1 + fc2 + softmax + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, output_value_real__, = _xla_while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) # (upper, lower, one_value, init_val, l_in_0, output_value), ()) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ddd7045669b..a1adab814fb 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -941,7 +941,8 @@ class PyLoweringContext { // int64_t parameter_idx = 7; // conv2d // int64_t parameter_idx = 11; // conv2d+mnist-treat // conv1 + bn1 // int64_t parameter_idx = 13; // conv1 + bn1 + conv2 - int64_t parameter_idx = 19; // conv1 + bn1 + conv2 + bn2 + // int64_t parameter_idx = 19; // conv1 + bn1 + conv2 + bn2 + int64_t parameter_idx = 21; // conv1 + bn1 + conv2 + bn2 // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) {