Skip to content

Commit

Permalink
mnist
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed May 1, 2024
1 parent 5004a74 commit 5fb1769
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
16 changes: 12 additions & 4 deletions test/test_test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self):
self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5).to(xm.xla_device())
self.bn2 = torch.nn.BatchNorm2d(20).to(xm.xla_device())
self.fc1 = torch.nn.Linear(500, 50).to(xm.xla_device())
self.fc2 = torch.nn.Linear(50, 10).to(xm.xla_device())
# self.fc1 = torch.nn.Linear(320, 50).to(xm.xla_device())
# self.linear = torch.nn.Linear(10, 20).to(xm.xla_device())
# self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device())
Expand All @@ -74,7 +75,9 @@ def forward(self, x):
x = self.bn2(x)
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
return x
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# return x

class SimpleWithLinear(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -328,7 +331,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args):
# output_value = torch.zeros([16,10,28,28], dtype=torch.float32, device=device) # conv2d
# output_value = torch.zeros([16,10,14,14], dtype=torch.float32, device=device) # conv2d+mnist-treat # conv1 + bn1
# output_value = torch.zeros([16,20,5,5], dtype=torch.float32, device=device) # conv1 + bn1 + conv2
output_value = torch.zeros([16,500], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1
# output_value = torch.zeros([16,500], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1
# output_value = torch.zeros([16,50], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + fc1
output_value = torch.zeros([16,10], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + fc1


additional_inputs = []
Expand Down Expand Up @@ -377,8 +382,11 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args):
##### conv1 + bn1 + conv2
# upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, output_value_real__, = _xla_while_loop(
##### conv1 + bn1 + conv2 + bn2
upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, output_value_real__, = _xla_while_loop(

# upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, output_value_real__, = _xla_while_loop(
##### conv1 + bn1 + conv2 + bn2 + fc1
# upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, p7, p8, output_value_real__, = _xla_while_loop(
##### conv1 + bn1 + conv2 + bn2 + fc1 + fc2 + softmax
upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, output_value_real__, = _xla_while_loop(
cond_fn, body_fn,
(upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs))
# (upper, lower, one_value, init_val, l_in_0, output_value), ())
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,8 @@ class PyLoweringContext {
// int64_t parameter_idx = 7; // conv2d
// int64_t parameter_idx = 11; // conv2d+mnist-treat // conv1 + bn1
// int64_t parameter_idx = 13; // conv1 + bn1 + conv2
int64_t parameter_idx = 19; // conv1 + bn1 + conv2 + bn2
// int64_t parameter_idx = 19; // conv1 + bn1 + conv2 + bn2
int64_t parameter_idx = 21; // conv1 + bn1 + conv2 + bn2
// int64_t parameter_idx = 9; // linear
// int64_t parameter_idx = tensors.size();
for (auto& additional_input_tensor : additional_inputs_list) {
Expand Down

0 comments on commit 5fb1769

Please sign in to comment.