diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 104242cb384..7a4f48b7287 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -369,11 +369,13 @@ def __init__(self): self.fc2 = torch.nn.Linear(50, 10) # .to(xm.xla_device()) self.weight_bias_lists = [] self.bn_weight_bias_lists = [] + self.register_buffer("dec", torch.tensor(1)) def forward(self, iter, x, y): def cond_fn(iter, x, y): - return iter > 0 + # return iter > 0 + return iter > self.dec def body_fn(iter, x, y): # def body_fn(iter, original_x, y): # x = original_x.clone() @@ -418,6 +420,7 @@ def body_fn(iter, x, y): # def body_fn(iter, original_x, y): # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn1.named_parameters()) # # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn2.named_parameters()) + # self.bn_weight_bias_lists.append(self.dec) insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn2.named_parameters()) insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn1.named_parameters()) @@ -429,7 +432,7 @@ def body_fn(iter, x, y): # def body_fn(iter, original_x, y): # for i in range(len(self.weight_bias_lists)): print("self.weight_bias_lists ", i, " size: ", self.weight_bias_lists[i].size()) # return iter-1, x, F.log_softmax(x, dim=1) # return iter-1, original_x, F.log_softmax(x, dim=1) - return iter-1, x, F.log_softmax(y, dim=1) + return iter - self.dec, x, F.log_softmax(y, dim=1) # self.bn_weight_bias_lists.reverse() # self.weight_bias_lists = self.weight_bias_lists + self.bn_weight_bias_lists diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 5bf628b80c2..2812f5beebf 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ 
-582,6 +582,23 @@ def ninth_try_new_body_fn(*carried_inputs): res = [iter, ] + additional_inputs + [one, ] + bn_additional_inputs + inputs_items + [outputs_items, ] return res + def tenth_try_new_body_fn(*carried_inputs): + # Tenth try: run body_fn, then rebuild its result as [iter] + additional_inputs + bn_additional_inputs + inputs + [output]; unlike ninth_try_new_body_fn, no extra constant-1 (s64[]) tensor `one` is inserted. NOTE(review): total_inputs in _xla_while_loop_target_second still carries `one` - confirm the two layouts are meant to differ. + res = list(body_fn(*carried_inputs)) + iter = res[0] + inputs_and_outputs = res[1:] + inputs_items = inputs_and_outputs[:-1] + outputs_items = inputs_and_outputs[-1] + # if len(inputs_and_outputs)==1: + # inputs_and_outputs = [inputs_and_outputs,] + # res = res + list(additional_inputs) + # xla_device = carried_inputs[0].device + # one = torch.tensor(1, dtype=torch.int64, device=xla_device) + # res = [iter, ] + additional_inputs + [one, ] + bn_additional_inputs + inputs_items + [outputs_items, ] + # res = [iter, ] + additional_inputs + bn_additional_inputs[1:] + inputs_items + [outputs_items, ] + res = [iter, ] + additional_inputs + bn_additional_inputs + inputs_items + [outputs_items, ] + return res + # new_additional_inputs = additional_inputs[0] + additional_inputs[1] # return _xla_while_loop_target(cond_fn, new_body_fn, carried_inputs, additional_inputs) @@ -605,6 +622,8 @@ def _xla_while_loop_target_second(cond_fn, body_fn, carried_inputs, additional_i # print("additional_inputs: ", additional_inputs) ### use output as input now case, so we could get output in the return value from original inpjut position + # for i in range(len(bn_additional_inputs)): print("bn_additional_inputs ", i, " size: ", bn_additional_inputs[i].size()) + # fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: @@ -690,6 +709,10 @@ def _xla_while_loop_target_second(cond_fn, body_fn, carried_inputs, additional_i # === add bn_additional_inputs === fake_additiona_args += bn_additional_inputs + # for i in range(len(bn_additional_inputs)): print("bn_additional_inputs ", i, " size: ", bn_additional_inputs[i].size()) + + # for i in range(len(additional_inputs)): print("additional_inputs ", i, " size: ",
additional_inputs[i].size()) + # TODO(@manfei): specify which element is for which argument like a,b,c # cond_result = cond_fn(*fake_carried_inputs) # cond_result = cond_fn(*fake_carried_inputs_all_args) @@ -756,7 +779,14 @@ def _xla_while_loop_target_second(cond_fn, body_fn, carried_inputs, additional_i # total_inputs = carried_inputs + tuple(additional_inputs) iter_value = carried_inputs[0] input_and_outputs_value = carried_inputs[1:] - total_inputs = tuple([iter_value,]) + tuple(additional_inputs) + tuple([one, ]) + tuple(bn_additional_inputs) + tuple(carried_inputs[1:]) + + # # === add one === + # # xla_device = carried_inputs[0].device + # one = torch.tensor(1, dtype=torch.int64, device=device) # xla_device) + # fake_additiona_args.append(one) + + # total_inputs = tuple([iter_value,]) + tuple(additional_inputs) + tuple([one, ]) + tuple(bn_additional_inputs) + tuple(carried_inputs[1:]) + total_inputs = tuple([iter_value,]) + tuple(additional_inputs) + tuple([one,]) + tuple(bn_additional_inputs) + tuple(carried_inputs[1:]) # for i in range(len(carried_inputs)): print("2 carried_inputs ", i, " size: ", carried_inputs[i].size()) # for i in range(len(additional_inputs)): print("2 additional_inputs ", i, " size: ", additional_inputs[i].size())