Skip to content

Commit

Permalink
pass again
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed May 15, 2024
1 parent c1ec552 commit 21c006e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,13 @@ def __init__(self):
self.fc2 = torch.nn.Linear(50, 10) # .to(xm.xla_device())
self.weight_bias_lists = []
self.bn_weight_bias_lists = []
self.register_buffer("dec", torch.tensor(1))

def forward(self, iter, x, y):

def cond_fn(iter, x, y):
return iter > 0
# return iter > 0
return iter > self.dec

def body_fn(iter, x, y): # def body_fn(iter, original_x, y):
# x = original_x.clone()
Expand Down Expand Up @@ -418,6 +420,7 @@ def body_fn(iter, x, y): # def body_fn(iter, original_x, y):
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn1.named_parameters())
# # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn2.named_parameters())

# self.bn_weight_bias_lists.append(self.dec)
insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn2.named_parameters())
insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn1.named_parameters())

Expand All @@ -429,7 +432,7 @@ def body_fn(iter, x, y): # def body_fn(iter, original_x, y):
# for i in range(len(self.weight_bias_lists)): print("self.weight_bias_lists ", i, " size: ", self.weight_bias_lists[i].size())
# return iter-1, x, F.log_softmax(x, dim=1)
# return iter-1, original_x, F.log_softmax(x, dim=1)
return iter-1, x, F.log_softmax(y, dim=1)
return iter - self.dec, x, F.log_softmax(y, dim=1)

# self.bn_weight_bias_lists.reverse()
# self.weight_bias_lists = self.weight_bias_lists + self.bn_weight_bias_lists
Expand Down
32 changes: 31 additions & 1 deletion torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,23 @@ def ninth_try_new_body_fn(*carried_inputs):
res = [iter, ] + additional_inputs + [one, ] + bn_additional_inputs + inputs_items + [outputs_items, ]
return res

def tenth_try_new_body_fn(*carried_inputs):
    """Adapt `body_fn`'s result to the carried-value layout the XLA while-loop expects.

    The user `body_fn` returns ``(counter, *inputs, output)``. This wrapper
    re-inserts the extra loop state between the counter and the inputs,
    producing::

        [counter, *additional_inputs, *bn_additional_inputs, *inputs, output]

    ``body_fn``, ``additional_inputs`` and ``bn_additional_inputs`` are
    closed over from the enclosing scope — presumably lists of XLA tensors;
    TODO(review): confirm against the caller.
    """
    body_result = list(body_fn(*carried_inputs))
    # Avoid shadowing the builtin `iter`; the first element is the loop counter.
    loop_counter = body_result[0]
    # Everything after the counter is the carried inputs followed by a
    # single output value.
    inputs_items = body_result[1:-1]
    outputs_item = body_result[-1]
    return ([loop_counter, ] + additional_inputs + bn_additional_inputs +
            inputs_items + [outputs_item, ])

# new_additional_inputs = additional_inputs[0] + additional_inputs[1]

# return _xla_while_loop_target(cond_fn, new_body_fn, carried_inputs, additional_inputs)
Expand All @@ -605,6 +622,8 @@ def _xla_while_loop_target_second(cond_fn, body_fn, carried_inputs, additional_i
# print("additional_inputs: ", additional_inputs)
    ### output-as-input case: the output can be read from the return value at the original input's position

# for i in range(len(bn_additional_inputs)): print("bn_additional_inputs ", i, " size: ", bn_additional_inputs[i].size())

# fake carried_inputs to split formal code
fake_carried_inputs = []
for carried_input in carried_inputs:
Expand Down Expand Up @@ -690,6 +709,10 @@ def _xla_while_loop_target_second(cond_fn, body_fn, carried_inputs, additional_i
# === add bn_additional_inputs ===
fake_additiona_args += bn_additional_inputs

# for i in range(len(bn_additional_inputs)): print("bn_additional_inputs ", i, " size: ", bn_additional_inputs[i].size())

# for i in range(len(additional_inputs)): print("additional_inputs ", i, " size: ", additional_inputs[i].size())

# TODO(@manfei): specify which element is for which argument like a,b,c
# cond_result = cond_fn(*fake_carried_inputs)
# cond_result = cond_fn(*fake_carried_inputs_all_args)
Expand Down Expand Up @@ -756,7 +779,14 @@ def _xla_while_loop_target_second(cond_fn, body_fn, carried_inputs, additional_i
# total_inputs = carried_inputs + tuple(additional_inputs)
iter_value = carried_inputs[0]
input_and_outputs_value = carried_inputs[1:]
total_inputs = tuple([iter_value,]) + tuple(additional_inputs) + tuple([one, ]) + tuple(bn_additional_inputs) + tuple(carried_inputs[1:])

# # === add one ===
# # xla_device = carried_inputs[0].device
# one = torch.tensor(1, dtype=torch.int64, device=device) # xla_device)
# fake_additiona_args.append(one)

# total_inputs = tuple([iter_value,]) + tuple(additional_inputs) + tuple([one, ]) + tuple(bn_additional_inputs) + tuple(carried_inputs[1:])
total_inputs = tuple([iter_value,]) + tuple(additional_inputs) + tuple([one,]) + tuple(bn_additional_inputs) + tuple(carried_inputs[1:])

# for i in range(len(carried_inputs)): print("2 carried_inputs ", i, " size: ", carried_inputs[i].size())
# for i in range(len(additional_inputs)): print("2 additional_inputs ", i, " size: ", additional_inputs[i].size())
Expand Down

0 comments on commit 21c006e

Please sign in to comment.