Skip to content

Commit

Permalink
mnist test run
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed May 15, 2024
1 parent f211b5f commit c1ec552
Show file tree
Hide file tree
Showing 3 changed files with 658 additions and 48 deletions.
81 changes: 56 additions & 25 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch_xla
# We need to import the underlying implementation function to register with the dispatcher
import torch_xla.experimental.fori_loop
from torch_xla.experimental.fori_loop import fori_loop, _xla_while_loop, _xla_while_loop_target, _xla_while_loop_target_first, insert_model_pars_into_additional_inputs
from torch_xla.experimental.fori_loop import fori_loop, _xla_while_loop, _xla_while_loop_target, _xla_while_loop_target_first, insert_model_pars_into_additional_inputs, _xla_while_loop_target_first_second
# from torch_xla.experimental.fori_loop import _post_order_get_xla_computation_target_first, _xla_while_loop_get_xla_computation
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
Expand Down Expand Up @@ -375,40 +375,67 @@ def forward(self, iter, x, y):
def cond_fn(iter, x, y):
    # Loop condition for the XLA while-loop: keep iterating while the
    # countdown counter is still positive; x and y are carried through
    # unchanged and are not consulted here.
    still_running = iter > 0
    return still_running

def body_fn(iter, x, y):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.conv1.named_parameters())
x = self.bn1(x)
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn1.named_parameters())
# insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn1.named_parameters())
x = F.relu(F.max_pool2d(self.conv2(x), 2))
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.conv2.named_parameters())
x = self.bn2(x)
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn2.named_parameters())
# insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn2.named_parameters())
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.fc1.named_parameters())
x = self.fc2(x)
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.fc2.named_parameters())
def body_fn(iter, x, y):  # def body_fn(iter, original_x, y):
    # NOTE(review): this text was recovered from a scraped diff view. Removed
    # and added lines appear to be interleaved below (there are two reachable
    # `return` statements and an orphaned `_xla_while_loop_target_first` call),
    # so only the code up to the first `return` actually executes here; it is
    # unclear which `return` belongs to the final version -- confirm against
    # the real commit before relying on this block.
    # x = original_x.clone()
    # x = original_x.clone()

    # x = F.relu(F.max_pool2d(self.conv1(x), 2))
    # # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.conv1.named_parameters())
    # x = self.bn1(x)
    # # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn1.named_parameters())
    # # insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn1.named_parameters())
    # x = F.relu(F.max_pool2d(self.conv2(x), 2))
    # # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.conv2.named_parameters())
    # x = self.bn2(x)
    # # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn2.named_parameters())
    # # insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn2.named_parameters())
    # x = torch.flatten(x, 1)
    # x = F.relu(self.fc1(x))
    # # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.fc1.named_parameters())
    # x = self.fc2(x)
    # # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.fc2.named_parameters())

    # MNIST forward pass written into `y` so the loop-carried `x` input is
    # left untouched: conv1 -> pool/relu -> bn1 -> conv2 -> pool/relu -> bn2
    # -> flatten -> fc1/relu -> fc2.
    y = F.relu(F.max_pool2d(self.conv1(x), 2))
    y = self.bn1(y)
    y = F.relu(F.max_pool2d(self.conv2(y), 2))
    y = self.bn2(y)
    y = torch.flatten(y, 1)
    y = F.relu(self.fc1(y))
    y = self.fc2(y)

    # Record each layer's parameters as additional inputs for the lowered
    # while-loop computation (presumably captured during tracing -- TODO
    # confirm ordering requirements against _xla_while_loop_target_first*).
    insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.conv1.named_parameters())
    insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn1.named_parameters())
    insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn1.named_parameters())
    # insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn1.named_parameters())
    # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn1.named_parameters())
    insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.conv2.named_parameters())
    insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn2.named_parameters())
    insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn2.named_parameters())
    # insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn2.named_parameters())
    # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn2.named_parameters())
    insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.fc1.named_parameters())
    insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.fc2.named_parameters())

    # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn2.named_parameters())
    # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn1.named_parameters())
    # # insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn2.named_parameters())

    # BatchNorm parameters are appended again to a separate list, bn2 before
    # bn1 -- ordering looks deliberate (see the reverse() note below).
    insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn2.named_parameters())
    insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn1.named_parameters())

    # keep this modification here due to the additional_inputs would be modified after body_xlacomputation triggered
    # self.bn_weight_bias_lists.reverse()
    # self.weight_bias_lists = self.weight_bias_lists + self.bn_weight_bias_lists
    # self.weight_bias_lists = [self.weight_bias_lists, self.bn_weight_bias_lists]
    # print("weight_bias_lists: ", weight_bias_lists)
    return iter-1, x, F.log_softmax(x, dim=1)

    # NOTE(review): everything below this point is unreachable as written --
    # most likely removed-diff lines captured by the scrape. TODO confirm
    # against the actual commit c1ec552.
    self.bn_weight_bias_lists.reverse()
    self.weight_bias_lists = self.weight_bias_lists + self.bn_weight_bias_lists
    return _xla_while_loop_target_first(cond_fn, body_fn, (iter, x, y), self.weight_bias_lists)
    # for i in range(len(self.weight_bias_lists)): print("self.weight_bias_lists ", i, " size: ", self.weight_bias_lists[i].size())
    # return iter-1, x, F.log_softmax(x, dim=1)
    # return iter-1, original_x, F.log_softmax(x, dim=1)
    return iter-1, x, F.log_softmax(y, dim=1)

# self.bn_weight_bias_lists.reverse()
# self.weight_bias_lists = self.weight_bias_lists + self.bn_weight_bias_lists
# for i in range(len(self.weight_bias_lists)): print("now self.weight_bias_lists ", i, " size: ", self.weight_bias_lists[i].size())
# return _xla_while_loop_target_first(cond_fn, body_fn, (iter, x, y), self.weight_bias_lists, self.bn_weight_bias_lists)
return _xla_while_loop_target_first_second(cond_fn, body_fn, (iter, x, y), self.weight_bias_lists, self.bn_weight_bias_lists)
# return while_loop(cond_fn, body_fn, (iter, x, y))
# return _xla_while_loop(cond_fn, body_fn, (iter, x, y), self.weight_bias_lists)

Expand All @@ -421,8 +448,12 @@ def body_fn(iter, x, y):
l_out = torch.randn(bs, 10, dtype=torch.float32, device=device)
iter = torch.tensor(3, device=device)
res = mnist(iter, l_in_0, l_out)
print("res: ", res)
print("res: ", res[-1])
# print("act-res: ", res[-1])
# expected = _fake_fori_loop(0, 3, mnist, l_in_0)
# for i in range(3):
# expected = mnist(l_in_0)
# print("expected: ", expected)

@unittest.skip("skip _get_xlacomputation now")
def test_while_loop_tpu_MNIST_outside_loop(self):
Expand Down
8 changes: 6 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,13 +965,15 @@ class PyLoweringContext {
lowering_ctx.AddResult(root);
}

xla::XlaBuilder* local_builder = lowering_ctx.builder();
int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size();
// xla::XlaBuilder* local_builder = lowering_ctx.builder();
// int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size();
// XLA_ERROR() << "for fori_loop, we have args now: " << parameter_idx;

// hard-code modify cond xlacomputation input arguments with unused arguments
// for xla::while requirement
if (GetNameString() == "condctx") {
xla::XlaBuilder* local_builder = lowering_ctx.builder();
int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size();
// xla::XlaBuilder* local_builder = lowering_ctx.builder();
// int64_t parameter_idx = 2; // parameter_idx start from 2 after used upper and lower // param_count
for (auto& additional_input_tensor : additional_inputs_list) {
Expand All @@ -987,6 +989,8 @@ class PyLoweringContext {
// for xla::while requirement
if (GetNameString() == "bodyctx") {
xla::XlaBuilder* local_builder = lowering_ctx.builder();
int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size();
// xla::XlaBuilder* local_builder = lowering_ctx.builder();
// TODO(@manfei): treat hard code parameter_idx value
// int64_t parameter_idx = 21;
for (auto& additional_input_tensor : additional_inputs_list) {
Expand Down
Loading

0 comments on commit c1ec552

Please sign in to comment.