Skip to content

Commit

Permalink
save it before change
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed May 20, 2024
1 parent 76c2e7a commit 53c72f7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -381,28 +381,32 @@ def forward(self, iteri, x, y):
def cond_fn(iteri, x, y):
return iteri > 0

def body_fn(iteri, x):
def body_fn(iteri, x, y):

y = F.relu(F.max_pool2d(self.conv1(x), 2))
# y = self.bn1(y)
y = self.conv1(x)

# y = F.relu(F.max_pool2d(self.conv1(x), 2))
# y = self.bn1(y.clone())
# y = F.relu(F.max_pool2d(self.conv2(y), 2))
# y = self.bn2(y)
# y = torch.flatten(y, 1)
# y = F.relu(self.fc1(y))
# y = self.fc2(y)

# return iteri - 1, F.log_softmax(y, dim=1)
return iteri - 1, y
return iteri - 1, x.clone(), y # torch.while_loop's body_fn might be aliasing the input!
# return iteri - 1, x.clone(), F.log_softmax(y, dim=1) # torch.while_loop's body_fn might be aliasing the input!

return while_loop(cond_fn, body_fn, (iteri, x))
return while_loop(cond_fn, body_fn, (iteri, x, y))

mnist = MNIST()
mnist.to(device)
bs=16
# l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device)
l_in_0 = torch.randn(16, 1, 28, 28, dtype=torch.float32, device=device)
l_out = torch.randn(16, 10, 14, 14, dtype=torch.float32, device=device)
iteri = torch.tensor(3, dtype=torch.int64, device=device)
res = mnist(iteri, l_in_0)
res = mnist(iteri, l_in_0, l_out)
print("res: ", res[-1])


Expand Down
22 changes: 16 additions & 6 deletions torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,8 @@ def new_body_fn_may16_1426pm(*carried_inputs):
res = res # list(res_iter_inputs) + [res_outputs, ]
return res

print("wrapper carried_inputs: ", carried_inputs)
# print("wrapper carried_inputs: ", carried_inputs)
# for i in range(len(carried_inputs)): print("wrapper carried_inputs: ", i, " size: ", carried_inputs[i].size())
def new_body_fn_may19_2208pm(*carried_inputs):
res = list(body_fn(*carried_inputs))

Expand Down Expand Up @@ -1305,11 +1306,19 @@ def _xla_while_loop_target_second_clean_version_s32_may19_2206pm(cond_fn, body_f

# ============================= body_fn ==========================================
# body_result = body_fn(*carried_inputs) # fake would miss iter
body_result = body_fn(*carried_inputs, *additional_inputs) # fake would miss iter # right inputs
body_result = body_fn(*carried_inputs, *additional_inputs) # fake would miss iter # right inputs # right one
# kkk = (carried_inputs[0], *additional_inputs, *carried_inputs[1:])
# for i in range(len(kkk)): print("kkk: ", i, " size: ", kkk[i].size())
# body_result = body_fn(carried_inputs[0], *additional_inputs, *carried_inputs[1:]) # fake would miss iter
# for i in range(len(body_result)): print("body_result: ", i, " size: ", body_result[i].size())
# body_result = body_fn(carried_inputs[0], *additional_inputs, *carried_inputs[1:]) # fake would miss iter
print("body carried_inputs: ", carried_inputs)
print("body additional_inputs: ", additional_inputs)
print("body inputs: ", (*carried_inputs, *additional_inputs))
# print("body carried_inputs: ", carried_inputs)
# for i in range(len(carried_inputs)): print("body carried_inputs: ", i, " size: ", carried_inputs[i].size())
# print("body additional_inputs: ", additional_inputs)
# for i in range(len(additional_inputs)): print("body additional_inputs: ", i, " size: ", additional_inputs[i].size())
# print("body inputs: ", (*carried_inputs, *additional_inputs))
# lll = (*carried_inputs, *additional_inputs)
# for i in range(len(lll)): print("body inputs: ", i, " size: ", lll[i].size())
# body_result = body_fn(*newest_fake_inputs) # fake would miss iter
# body_result = body_fn(*modified_carried_inputs) # fake would miss iter
body_ctx = torch_xla._XLAC.lowering.LoweringContext()
Expand Down Expand Up @@ -1365,7 +1374,8 @@ def _xla_while_loop_target_second_clean_version_s32_may19_2206pm(cond_fn, body_f
iter_value = carried_inputs[0]
input_and_outputs_value = carried_inputs[1:]
total_inputs = tuple([iter_value,]) + tuple(additional_inputs) + tuple(bn_additional_inputs) + tuple(input_and_outputs_value)
print("total_inputs: ", total_inputs)
# print("total_inputs: ", total_inputs)
# for i in range(len(total_inputs)): print("total_inputs: ", i, " size: ", total_inputs[i].size())

print("get total_inputs !!!")

Expand Down

0 comments on commit 53c72f7

Please sign in to comment.