From 53c72f74de4a1a2cda79f1c010f17731ec6b5ec8 Mon Sep 17 00:00:00 2001 From: manfei Date: Mon, 20 May 2024 19:23:41 +0000 Subject: [PATCH] save it before change --- ...while_loop_simple_add_dispatch_in_torch.py | 16 +++++++++----- torch_xla/experimental/fori_loop.py | 22 ++++++++++++++----- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 0f15a352bcce..d42c1711209c 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -381,10 +381,12 @@ def forward(self, iteri, x, y): def cond_fn(iteri, x, y): return iteri > 0 - def body_fn(iteri, x): + def body_fn(iteri, x, y): - y = F.relu(F.max_pool2d(self.conv1(x), 2)) - # y = self.bn1(y) + y = self.conv1(x) + + # y = F.relu(F.max_pool2d(self.conv1(x), 2)) + # y = self.bn1(y.clone()) # y = F.relu(F.max_pool2d(self.conv2(y), 2)) # y = self.bn2(y) # y = torch.flatten(y, 1) @@ -392,17 +394,19 @@ def body_fn(iteri, x): # y = self.fc2(y) # return iteri - 1, F.log_softmax(y, dim=1) - return iteri - 1, y + return iteri - 1, x.clone(), y # torch.while_loop's body_fn might be aliasing the input! + # return iteri - 1, x.clone(), F.log_softmax(y, dim=1) # torch.while_loop's body_fn might be aliasing the input! - return while_loop(cond_fn, body_fn, (iteri, x)) + return while_loop(cond_fn, body_fn, (iteri, x, y)) mnist = MNIST() mnist.to(device) bs=16 # l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device) l_in_0 = torch.randn(16, 1, 28, 28, dtype=torch.float32, device=device) + l_out = torch.randn(16, 10, 14, 14, dtype=torch.float32, device=device) iteri = torch.tensor(3, dtype=torch.int64, device=device) - res = mnist(iteri, l_in_0) + res = mnist(iteri, l_in_0, l_out) print("res: ", res[-1]) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 6d0eda445c86..dc095c84260f 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -483,7 +483,8 @@ def new_body_fn_may16_1426pm(*carried_inputs): res = res # list(res_iter_inputs) + [res_outputs, ] return res - print("wrapper carried_inputs: ", carried_inputs) + # print("wrapper carried_inputs: ", carried_inputs) + # for i in range(len(carried_inputs)): print("wrapper carried_inputs: ", i, " size: ", carried_inputs[i].size()) def new_body_fn_may19_2208pm(*carried_inputs): res = list(body_fn(*carried_inputs)) @@ -1305,11 +1306,19 @@ def _xla_while_loop_target_second_clean_version_s32_may19_2206pm(cond_fn, body_f # ============================= body_fn ========================================== # body_result = body_fn(*carried_inputs) # fake would miss iter - body_result = body_fn(*carried_inputs, *additional_inputs) # fake would miss iter # right inputs + body_result = body_fn(*carried_inputs, *additional_inputs) # fake would miss iter # right inputs # right one + # kkk = (carried_inputs[0], *additional_inputs, *carried_inputs[1:]) + # for i in range(len(kkk)): print("kkk: ", i, " size: ", kkk[i].size()) + # body_result = body_fn(carried_inputs[0], *additional_inputs, *carried_inputs[1:]) # fake would miss iter + # for i in range(len(body_result)): print("body_result: ", i, " size: ", body_result[i].size()) # body_result = body_fn(carried_inputs[0], *additional_inputs, *carried_inputs[1:]) # fake would miss iter - print("body carried_inputs: ", carried_inputs) - print("body additional_inputs: ", additional_inputs) - print("body inputs: ", (*carried_inputs, *additional_inputs)) + # print("body carried_inputs: ", carried_inputs) + # for i in range(len(carried_inputs)): print("body carried_inputs: ", i, " size: ", carried_inputs[i].size()) + # print("body additional_inputs: ", additional_inputs) + # for i in range(len(additional_inputs)): print("body additional_inputs: ", i, " size: ", additional_inputs[i].size()) + # print("body inputs: ", (*carried_inputs, *additional_inputs)) + # lll = (*carried_inputs, *additional_inputs) + # for i in range(len(lll)): print("body inputs: ", i, " size: ", lll[i].size()) # body_result = body_fn(*newest_fake_inputs) # fake would miss iter # body_result = body_fn(*modified_carried_inputs) # fake would miss iter body_ctx = torch_xla._XLAC.lowering.LoweringContext() @@ -1365,7 +1374,8 @@ def _xla_while_loop_target_second_clean_version_s32_may19_2206pm(cond_fn, body_f iter_value = carried_inputs[0] input_and_outputs_value = carried_inputs[1:] total_inputs = tuple([iter_value,]) + tuple(additional_inputs) + tuple(bn_additional_inputs) + tuple(input_and_outputs_value) - print("total_inputs: ", total_inputs) + # print("total_inputs: ", total_inputs) + # for i in range(len(total_inputs)): print("total_inputs: ", i, " size: ", total_inputs[i].size()) print("get total_inputs !!!")