diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index a76197cc736..e20f1e5a5bf 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -20,16 +20,22 @@ def _fake_while_loop(cond_fn, body_fn, operands): def _fake_fori_loop(lower, upper, body_fun, *init_val): - (plus_value, init_val) = init_val - for i in range((upper - lower)[0]): - plus_value, init_val = body_fun(plus_value, init_val) - return init_val + if len(init_val) > 1: + (a, b) = init_val + for i in range((upper - lower)[0]): + a = body_fun(a, b) + else: + for i in range((upper - lower)[0]): + a = body_fun(*init_val) + return a class WhileLoopTest(unittest.TestCase): +# additional_inputs: () def test_while_loop_tpu_subtraction(self): + print("$$$ test_while_loop_tpu_subtraction !!!") device = xm.xla_device() def cond_fn(init, limit_value): @@ -46,8 +52,10 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) +# additional_inputs: () def test_while_loop_tpu_addition(self): + print("$$$ test_while_loop_tpu_addition !!!") device = xm.xla_device() def cond_fn(init, limit_value): @@ -64,8 +72,10 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) +# additional_inputs: () def test_while_loop_tpu_subtraction_nested(self): + print("$$$ test_while_loop_tpu_subtraction_nested !!!") device = xm.xla_device() def cond_fn(init, limit_value): @@ -82,25 +92,252 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) +### return weight/bias +# additional_inputs: (tensor([1*20], device='xla:0'), tensor([10*20], device='xla:0')) + def test_while_loop_tpu_simple_linear(self): + + print("$$$ test_while_loop_tpu_simple_linear !!!") + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value = linear_0(input_value) + weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone(), bias.clone(), weight.clone( + ), output_value.clone() + + upper = torch.tensor([1], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, l_in_0, output_value)) + + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + + return self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + +### +# + def 
test_while_loop_tpu_simple_linear_wrapper(self): + + print("$$$ test_while_loop_tpu_simple_linear_wrapper !!!") + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value = linear_0(input_value) + # weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + # bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone(), output_value.clone() + + upper = torch.tensor([1], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + # upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop( + # cond_fn, body_fn, + # (upper, lower, one_value, init_val, l_in_0, output_value)) + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, = while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, l_in_0, output_value)) + + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + + return self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + + +### return weight/bias +# additional_inputs: (tensor([ 1*20], device='xla:0'), tensor([10*20], device='xla:0')) + def test_while_loop_tpu_simple_linear_class(self): + + print("$$$ test_while_loop_tpu_simple_linear_class !!!") + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + + class SimpleWithLinear(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def forward(self, upper, lower, one_value, x, input_value, output_value): + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value_real = self.linear(input_value) + weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone( + ), output_value_real, weight.clone(), bias.clone() + + return while_loop( + cond_fn, body_fn, + (upper, lower, one_value, x, input_value, output_value)) + + simple_with_linear = SimpleWithLinear() + upper = torch.tensor([52], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + weight_0 = simple_with_linear.linear.weight + bias_0 = simple_with_linear.linear.bias + + aaa = { + 
"simple_with_linear": + (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, + output_value)) + } + + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear( + upper, lower, one_value, init_val, l_in_0, output_value) + + # create same weight/bias liear model for compare + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + linear_0.weight.data = weight__ + linear_0.bias.data = bias__ + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + + self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + return aaa + +### +# + def test_while_loop_tpu_simple_linear_class_wrapper(self): + + print("$$$ test_while_loop_tpu_simple_linear_class_wrapper !!!") + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + + class SimpleWithLinear(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def forward(self, upper, lower, one_value, x, input_value, output_value): + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value_real = self.linear(input_value) + weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone( + ), output_value_real + + return while_loop( + cond_fn, body_fn, + (upper, lower, one_value, x, input_value, output_value)) + + simple_with_linear = SimpleWithLinear() + upper = torch.tensor([52], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + weight_0 = simple_with_linear.linear.weight + bias_0 = simple_with_linear.linear.bias + + aaa = { + "simple_with_linear": + (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, + output_value)) + } + + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear( + upper, lower, one_value, init_val, l_in_0, output_value) + + # create same weight/bias liear model for compare + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + linear_0.weight.data = weight__ + linear_0.bias.data = bias__ + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + + self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + return aaa + +# additional_inputs: () def test_fori_loop_tpu_addition(self): + print("$$$ test_fori_loop_tpu_addition !!!") xm.mark_step() device = xm.xla_device() lower = torch.tensor([2], dtype=torch.int32, device=device) upper = torch.tensor([52], dtype=torch.int32, device=device) - plus_value = torch.tensor([1], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + + def body_fun(a, b): + return torch.add(a, b) + + upper_, new_lower_, one_value_, add_res_x_, res_ = fori_loop( + upper, lower, body_fun, one_value, init_val) + 
expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) + self.assertEqual(expected, res_) + +# additional_inputs: (tensor([1*20], device='xla:0'), tensor([[10*20], device='xla:0')) + def test_fori_loop_tpu_simple_linear(self): + + print("$$$ test_fori_loop_tpu_simple_linear !!!") + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + + upper = torch.tensor([52], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.randn(10, device=xm.xla_device()) + + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( + upper, lower, linear_0, init_val, l_in_0) - def body_fun(*argus): - plus_value, init_val = argus - return plus_value, torch.add(plus_value, init_val) + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) - _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val) - expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val) - self.assertEqual(expected, actual) + self.assertTrue(torch.all(torch.eq(expected, l_out_))) if __name__ == '__main__': test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) \ No newline at end of file + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py new file mode 100644 index 00000000000..c3da2020635 --- /dev/null +++ b/test/test_test_mnist.py @@ -0,0 +1,583 @@ +import torch +import torchvision +import os +import shutil +import sys +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +import torch_xla +import torch_xla.debug.metrics as met +import torch_xla.distributed.parallel_loader as pl +import torch_xla.utils.utils as xu +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp +import torch_xla.test.test_utils as test_utils +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import torch_xla.distributed.xla_backend +import torch_xla.experimental.fori_loop +from torch_xla.experimental.fori_loop import _xla_while_loop +from torch._higher_order_ops.while_loop import while_loop + +n_epochs = 3 +batch_size_train = 8 # 64 +batch_size_test = 10 # 1000 +learning_rate = 0.01 +momentum = 0.5 +log_interval = 10 +random_seed = 1 +torch.backends.cudnn.enabled = False +torch.manual_seed(random_seed) + +### load data +test_loader = xu.SampleGenerator( + data=(torch.zeros(8, 1, 28,28), torch.zeros(8, dtype=torch.int64)), + sample_count=1000 // 8 // xm.xrt_world_size()) + +# examples = enumerate(test_loader) +# batch_idx, (example_data, example_targets) = next(examples) +# print("shape: ", example_data.shape) + +### build model +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +# model.parameters() + +class SimpleWithLinearPure(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, stride=1, padding=2).to(xm.xla_device()) + self.bn1 = torch.nn.BatchNorm2d(10).to(xm.xla_device()) + self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5).to(xm.xla_device()) + self.bn2 = torch.nn.BatchNorm2d(20).to(xm.xla_device()) + self.fc1 = torch.nn.Linear(500, 50).to(xm.xla_device()) + self.fc2 = torch.nn.Linear(50, 
10).to(xm.xla_device()) + # self.fc1 = torch.nn.Linear(320, 50).to(xm.xla_device()) + # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) + # self.fc1 = nn.Linear(320, 50).to(xm.xla_device()) + # self.fc2 = nn.Linear(50, 10).to(xm.xla_device()) + + # def forward(self, upper, lower, one_value, x, input_value, output_value): + def forward(self, x): + # output_value_real = self.linear(input_value) + # output_value_real_final = self.linear2(output_value_real) + # output_value_real_final = self.conv1(input_value) # conv2d + x = F.relu(F.max_pool2d(self.conv1(x), 2)) # conv2d+mnist-treat + x = self.bn1(x) + x = F.relu(F.max_pool2d(self.conv2(x), 2)) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + # return x + +class SimpleWithLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) + + def forward(self, upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value_real = self.linear(input_value) + output_value_real_final = self.linear2(output_value_real) + # weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + # bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone( + ), output_value_real_final # , weight.clone(), bias.clone() + + return while_loop( + cond_fn, body_fn, + (upper, lower, one_value, x, input_value, output_value)) + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x) + +class MNIST(nn.Module): + + def __init__(self): + super(MNIST, self).__init__() + self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5) + self.bn1 = torch.nn.BatchNorm2d(10) + self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5) + self.bn2 = torch.nn.BatchNorm2d(20) + self.fc1 = torch.nn.Linear(320, 50) + self.fc2 = torch.nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = self.bn1(x) + x = F.relu(F.max_pool2d(self.conv2(x), 2)) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + +device = xm.xla_device() +network = MNIST().to(device) +# network = Net().to(device) +optimizer = optim.SGD(network.parameters(), lr=learning_rate, + momentum=momentum) +# loss_fn = nn.NLLLoss() + +train_losses = [] +train_counter = [] +test_losses = [] +test_counter = [i*20 for i in range(n_epochs + 1)] + +def test(): + network.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + output = network(data) + test_loss 
+= F.nll_loss(output, target, size_average=False).item() + pred = output.data.max(1, keepdim=True)[1] + correct += pred.eq(target.data.view_as(pred)).sum() + test_loss /= 20 + test_losses.append(test_loss) + print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, 20, + 100. * correct / 20)) + return test_loss + +def new_test(): + simple_with_linear = SimpleWithLinear() + upper = torch.tensor([52], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + weight_0 = simple_with_linear.linear.weight + bias_0 = simple_with_linear.linear.bias + + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear( + upper, lower, one_value, init_val, l_in_0, output_value) + print("finish new_test") + +def newnew_test(): + device = xm.xla_device() + torch.set_grad_enabled(False) + + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + # simple_with_linear = SimpleWithLinear() + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value = linear_0(input_value) + weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone(), bias.clone(), weight.clone( + ), output_value.clone() + + upper = torch.tensor([1], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, l_in_0, output_value)) + print("finish newnew_test") + + # # simple_with_linear = SimpleWithLinear() + # simple_with_linear = SimpleWithLinear() + # upper = torch.tensor([52], dtype=torch.int32, device=device) + # lower = torch.tensor([0], dtype=torch.int32, device=device) + # one_value = torch.tensor([1], dtype=torch.int32, device=device) + # init_val = torch.tensor([1], dtype=torch.int32, device=device) + # l_in_0 = torch.rand(10, device=xm.xla_device()) + # output_value = torch.zeros([20], dtype=torch.float32, device=device) + # weight_0 = simple_with_linear.linear.weight + # bias_0 = simple_with_linear.linear.bias + + # upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear( + # upper, lower, one_value, init_val, l_in_0, output_value) + # print("finish new_test") + + # def cond_fn(upper, lower, one_value, x, input_value, output_value): + # return lower[0] < upper[0] + + # def body_fn(upper, lower, one_value, x, input_value, output_value): + # new_lower = torch.add(one_value, lower) + # output_value_real = 
self.linear(input_value) + # output_value_real_final = self.linear2(output_value_real) + # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + # one_value, x), input_value.clone( + # ), output_value_real_final # , weight.clone(), bias.clone() + + # return while_loop( + # cond_fn, body_fn, + # (upper, lower, one_value, x, input_value, output_value)) + +def newnewnew_test(): + device = xm.xla_device() + torch.set_grad_enabled(False) + + # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + # simple_with_linear = SimpleWithLinear() + simple_with_linear = SimpleWithLinearPure() + + def cond_fn(upper, lower, one_value, x, input_value, output_value, *args): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value, *args): + new_lower = torch.add(one_value, lower) + output_value = simple_with_linear(input_value) + # weight = simple_with_linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + # bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()] + bn_list = [] + # bn_flag = False + for name, param in simple_with_linear.named_parameters(): + if name[:2]=='bn': + # bn_flag = True + # bn_list.insert(-1, param) # dumpicate # continue # skip bn + bn_list.append(param) + # else: + # bn_flag = False + + res.insert(-1, param) + + # if (not bn_flag) and (len(bn_list) !=0): # False + # output_value = res[-1] + # res = res[:-1] + bn_list # + res[-1] + # res.append(output_value) + # bn_list = [] + # # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) + + ### !!! add still exist bn_list if the last additional_inputs is bn- pre + # if bn_flag and (len(bn_list) !=0): + ### !!! 
add at the tile + if len(bn_list) !=0: + output_value = res[-1] + # res = res[:-1] + bn_list # + res[-1] + bn_list.reverse() + res = res[:-1] + bn_list + res.append(output_value) + bn_list = [] + # bn_flag = False + + return tuple(res) + # return (upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + # one_value, x), input_value.clone(), output_value.clone(), simple_with_linear.linear.weight) # bias.clone(), weight.clone(), output_value.clone() + + # print("simple_with_linear weight: ", simple_with_linear.weight) + # print("simple_with_linear bias: ", simple_with_linear.bias) + # print("print all things!!!") + # print(type(simple_with_linear.parameters())) + # print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) + # import pdb; pdb.set_trace() + + # for name, param in simple_with_linear.named_parameters(): + # # print("arrive the loop") + # print("name: ", name) + # print("param: ", param) + + # if name in ['bias']: + # print(param.size()) + + upper = torch.tensor([50], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + ### linear 10*20 + 20*30 input&output + # l_in_0 = torch.rand(10, device=xm.xla_device()) + # output_value = torch.zeros([30], dtype=torch.float32, device=device) + ### conv2d input&output + bs=16 + l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device) + # c = nn.Conv2d(3,10,kernel_size=5,stride=1,padding=2) + # out = c(x) + # print(out.nelement()) + # output_value = torch.zeros([16,10,28,28], dtype=torch.float32, device=device) # conv2d + # output_value = torch.zeros([16,10,14,14], dtype=torch.float32, device=device) # conv2d+mnist-treat # conv1 + bn1 + # output_value = torch.zeros([16,20,5,5], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + # output_value = torch.zeros([16,500], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + # output_value = torch.zeros([16,50], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + fc1 + output_value = torch.zeros([16,10], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + fc1 + + + additional_inputs = [] + bn_list = [] + # bn_flag = False + for name, param in simple_with_linear.named_parameters(): + # if name[:2]=='bn': + # additional_inputs.append(param) # dumplicate + if name[:2]=='bn': + # print("catch: ", name) + # bn_flag = True + # bn_list.insert(-1, param) # dumpicate # continue # skip bn + bn_list.append(param) + # print("newest bn_list: ", bn_list) + # else: + # bn_flag = False + + # additional_inputs.insert(-1, param) + additional_inputs.append(param) + + # if (not bn_flag) and (len(bn_list) !=0): # False + # additional_inputs =additional_inputs + bn_list + # # print("added bn_list: ", bn_list) + # bn_list = [] + + ### !!! add still exist bn_list if the last additional_inputs is bn- pre + # if bn_flag and (len(bn_list) !=0): + ### !!! add duplicated bn argus as the tile of the list + if len(bn_list) !=0: + # additional_inputs = additional_inputs + bn_list + bn_list.reverse() ### !!! 
reverse list for bn duplicate lists + additional_inputs = additional_inputs + bn_list + # print("added bn_list: ", bn_list) + bn_list = [] + # bn_flag = False + + # print("final additional_inputs: ", additional_inputs) + + # print("in mnist additional_inputs: ", additional_inputs) + ### linear 10*20 + 20*30 + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, w2_, b2_, output_value_real__, = _xla_while_loop( + ### conv2d + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, output_value_real__, = _xla_while_loop( + #### conv1+bn1 + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, output_value_real__, = _xla_while_loop( + ##### conv1 + bn1 + conv2 + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, output_value_real__, = _xla_while_loop( + ##### conv1 + bn1 + conv2 + bn2 + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, output_value_real__, = _xla_while_loop( + ##### conv1 + bn1 + conv2 + bn2 + fc1 + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, p7, p8, output_value_real__, = _xla_while_loop( + ##### conv1 + bn1 + conv2 + bn2 + fc1 + fc2 + softmax + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, output_value_real__, = _xla_while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) + # (upper, lower, one_value, init_val, l_in_0, output_value), ()) + print("finish newnewnew_test") + print("torch_add_res__: run times: ", torch_add_res__) + print("actual res: ", output_value_real__[0][0]) + expected_ = simple_with_linear(l_in_0) + print("expected res: ", expected_[0][0]) + +# run test model +def test_mnist(): + torch.manual_seed(1) + + print("before test_mnist") + newnewnew_test() # newnew_test() # new_test() # test() + # # target fori_loop + # for epoch in range(1, n_epochs + 1): + # newnewnew_test() # newnew_test() # new_test() # test() + + print("after test_mnist") + +if __name__ == '__main__': + test_mnist() + +# torch.set_default_dtype(torch.float32) +# accuracy = test_mnist() + +# ///////////////////////////////////////////////////////////////////////////////////////////////////////// + +# import args_parse +# from torch_xla import runtime as xr + +# # MODEL_OPTS = { +# # '--ddp': { +# # 'action': 'store_true', +# # }, +# # '--pjrt_distributed': { +# # 'action': 'store_true', +# # }, +# # } + +# FLAGS = args_parse.parse_common_options( +# datadir='/tmp/mnist-data', +# batch_size=128, +# momentum=0.5, +# lr=0.01, +# target_accuracy=98.0, +# num_epochs=18, +# # opts=MODEL_OPTS.items(), +# ) + +# import os +# import shutil +# import sys +# import numpy as np +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# import torch.optim as optim +# from torchvision import datasets, transforms +# import torch_xla +# import torch_xla.debug.metrics as met +# import torch_xla.distributed.parallel_loader as pl +# import torch_xla.utils.utils as xu +# import torch_xla.core.xla_model as xm +# import torch_xla.distributed.xla_multiprocessing as xmp +# import torch_xla.test.test_utils as test_utils + +# import torch.distributed as dist +# from torch.nn.parallel import 
DistributedDataParallel as DDP +# import torch_xla.distributed.xla_backend + +# import torch_xla.experimental.fori_loop +# from torch_xla.experimental.fori_loop import fori_loop + + +# class MNIST(nn.Module): + +# def __init__(self): +# super(MNIST, self).__init__() +# self.conv1 = nn.Conv2d(1, 10, kernel_size=5) +# self.bn1 = nn.BatchNorm2d(10) +# self.conv2 = nn.Conv2d(10, 20, kernel_size=5) +# self.bn2 = nn.BatchNorm2d(20) +# self.fc1 = nn.Linear(320, 50) +# self.fc2 = nn.Linear(50, 10) + +# def forward(self, x): +# x = F.relu(F.max_pool2d(self.conv1(x), 2)) +# x = self.bn1(x) +# x = F.relu(F.max_pool2d(self.conv2(x), 2)) +# x = self.bn2(x) +# x = torch.flatten(x, 1) +# x = F.relu(self.fc1(x)) +# x = self.fc2(x) +# return F.log_softmax(x, dim=1) + + +# def train_mnist(flags, **kwargs): +# torch.manual_seed(1) + +# test_loader = xu.SampleGenerator( +# data=(torch.zeros(flags.batch_size, 1, 28, 28), torch.zeros(flags.batch_size, dtype=torch.int64)), +# sample_count=10000 // flags.batch_size // xm.xrt_world_size()) + +# # Scale learning rate to num cores +# lr = flags.lr * xm.xrt_world_size() +# device = xm.xla_device() +# model = MNIST().to(device) + +# # Initialization is nondeterministic with multiple threads in PjRt. +# # Synchronize model parameters across replicas manually. +# if xr.using_pjrt(): +# xm.broadcast_master_param(model) + +# writer = None +# if xm.is_master_ordinal(): +# writer = test_utils.get_summary_writer(flags.logdir) +# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum) +# loss_fn = nn.NLLLoss() + +# def test_loop_fn(): # loader): +# total_samples = 0 +# correct = 0 +# model.eval() +# # print("loader: ", loader) +# # print("type loader: ", type(loader)) +# # for data, target in loader: +# for data, target in test_loader: +# output = model(data) +# pred = output.max(1, keepdim=True)[1] +# correct += pred.eq(target.view_as(pred)).sum() +# total_samples += data.size()[0] + +# accuracy = 100.0 * correct.item() / total_samples +# accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) +# return accuracy + +# # train_device_loader = pl.MpDeviceLoader(train_loader, device) +# test_device_loader = pl.MpDeviceLoader(test_loader, device) +# accuracy, max_accuracy = 0.0, 0.0 + +# for epoch in range(1, flags.num_epochs + 1): +# # xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) +# # train_loop_fn(train_device_loader, epoch) +# # xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) +# accuracy = test_loop_fn() # test_device_loader) +# xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy)) +# max_accuracy = max(accuracy, max_accuracy) +# # test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True) +# # if flags.metrics_debug: xm.master_print(met.metrics_report()) + +# ### fori_loop +# # torch.set_grad_enabled(False) +# # new_test_device_loader = pl.MpDeviceLoader(test_loader, device) +# upper = torch.tensor([flags.num_epochs + 1], dtype=torch.int32, device=device) # flags.num_epochs + 1 +# lower = torch.tensor([1], dtype=torch.int32, device=device) # 1 +# init_val = torch.tensor([1], dtype=torch.int32, device=device) +# # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader +# # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) +# def body_fun(): +# res1 = torch.tensor([2], dtype=torch.int32, device=device) +# res2 = torch.tensor([2], dtype=torch.int32, device=device) +# res3 = res1 + res2 
+# return res3 +# # def body_fun(test_device_loader): +# # accuracy = test_loop_fn(test_device_loader) +# # max_accuracy = max(accuracy, max_accuracy) +# # return max_accuracy + +# upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( +# upper, lower, body_fun, ()) + +# test_utils.close_summary_writer(writer) +# xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy)) +# return max_accuracy + + +# # def _mp_fn(index, flags): +# def main_fun(flags): +# torch.set_default_dtype(torch.float32) +# accuracy = train_mnist(flags) +# if flags.tidy and os.path.isdir(flags.datadir): +# shutil.rmtree(flags.datadir) +# if accuracy < flags.target_accuracy: +# print('Accuracy {} is below target {}'.format(accuracy, +# flags.target_accuracy)) +# sys.exit(21) + + +# if __name__ == '__main__': +# # xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores) +# # _mp_fn() +# main_fun(FLAGS) diff --git a/test/test_torch_xla_while_loop_test.py b/test/test_torch_xla_while_loop_test.py new file mode 100644 index 00000000000..6573bcb3ede --- /dev/null +++ b/test/test_torch_xla_while_loop_test.py @@ -0,0 +1,33 @@ +import time +start_time = time.time() + +import torch +import torch_xla +import torch_xla.experimental.fori_loop +from torch._higher_order_ops.while_loop import while_loop +import torch_xla.core.xla_model as xm +import torch_xla.core.xla_builder as xb +import torch_xla.debug.profiler as xp + +# server = xp.start_server(9012) + +# xp.trace_detached( +# f'localhost:9012', +# '/root/profiles/', +# duration_ms=2000) + +device = xm.xla_device() + +def cond_fn(init, limit_value): + return limit_value[0] >= init[0] + +def body_fn(init, limit_value): + one_value = torch.ones(1, dtype=torch.int32, device=device) + return (torch.add(init, one_value), limit_value.clone()) + +init = torch.tensor([0], dtype=torch.int32, device=device) +limit_value = torch.tensor([1000], dtype=torch.int32, device=device) +res = while_loop(cond_fn, body_fn, (init, limit_value)) +print("res: ", res) + +print("--- %s seconds ---" % (time.time() - start_time)) diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 3b078d22fab..f27e1f58ed6 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -170,6 +170,8 @@ def test_loop_fn(loader): total_samples = 0 correct = 0 model.eval() + # print("loader: ", loader) + # print("type loader: ", type(loader)) for data, target in loader: output = model(data) pred = output.max(1, keepdim=True)[1] @@ -180,25 +182,35 @@ def test_loop_fn(loader): accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) return accuracy - train_device_loader = pl.MpDeviceLoader(train_loader, device) + # train_device_loader = pl.MpDeviceLoader(train_loader, device) test_device_loader = pl.MpDeviceLoader(test_loader, device) accuracy, max_accuracy = 0.0, 0.0 for epoch in range(1, flags.num_epochs + 1): - xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) - train_loop_fn(train_device_loader, epoch) - xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) - + # xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) + # train_loop_fn(train_device_loader, epoch) + # xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) accuracy = test_loop_fn(test_device_loader) - xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format( - epoch, test_utils.now(), accuracy)) + # xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), 
accuracy)) max_accuracy = max(accuracy, max_accuracy) - test_utils.write_to_summary( - writer, - epoch, - dict_to_write={'Accuracy/test': accuracy}, - write_xla_metrics=True) - if flags.metrics_debug: - xm.master_print(met.metrics_report()) + # test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True) + # if flags.metrics_debug: xm.master_print(met.metrics_report()) + + # ### fori_loop + # # torch.set_grad_enabled(False) + # new_test_device_loader = pl.MpDeviceLoader(test_loader, device) + # upper = torch.tensor([flags.num_epochs + 1], dtype=torch.int32, device=device) # flags.num_epochs + 1 + # lower = torch.tensor([1], dtype=torch.int32, device=device) # 1 + # init_val = torch.tensor([1], dtype=torch.int32, device=device) + # # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader + # # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + # def body_fun(test_device_loader): + # accuracy = test_loop_fn(test_device_loader) + # max_accuracy = max(accuracy, max_accuracy) + # return max_accuracy + + # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( + # upper, lower, body_fun, init_val, new_test_device_loader) + test_utils.close_summary_writer(writer) xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy)) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c603e5d27a5..a1adab814fb 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -913,14 +913,41 @@ class PyLoweringContext { // Builds a HLO graph given a set of output tensors, and add unused parameters // needed in xlacomputation. void BuildForiLoop(std::vector tensors, - std::vector input_arguments = {}) { + std::vector additional_inputs_list = {}) { + // hard-code modify cond xlacomputation input arguments with unusedarguments + // for xla::while requriement if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); - // hard-code parameter_idx to 2 to skip existing upper/lower arguments - int64_t parameter_idx = 2; - for (at::Tensor input_argument : input_arguments) { - xla::Shape shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); + int64_t parameter_idx = + 2; // parameter_idx start from 2 after used upper and lower + for (auto& additional_input_tensor : additional_inputs_list) { + XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + xla::Shape shape = xtensor->shape().get(); + xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + "UnusedArgumentsPlaceholder"); + parameter_idx += 1; + } + } + + // hard-code modify body xlacomputation input arguments with unusedarguments + // for xla::while requriement + // !!! actually weight/bias don't need to be added here as dummy arguments by additional_inputs_list, + // !!! they will be added automatically added here, we need to add dummy argument for output/return_value + // !!! 
+ if (GetNameString() == "bodyctx") { + xla::XlaBuilder* local_builder = lowering_ctx.builder(); + // TODO(@manfei): treat hard code parameter_idx value + // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); + // int64_t parameter_idx = 7; // conv2d + // int64_t parameter_idx = 11; // conv2d+mnist-treat // conv1 + bn1 + // int64_t parameter_idx = 13; // conv1 + bn1 + conv2 + // int64_t parameter_idx = 19; // conv1 + bn1 + conv2 + bn2 + int64_t parameter_idx = 21; // conv1 + bn1 + conv2 + bn2 + // int64_t parameter_idx = 9; // linear + // int64_t parameter_idx = tensors.size(); + for (auto& additional_input_tensor : additional_inputs_list) { + XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + xla::Shape shape = xtensor->shape().get(); xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, "UnusedArgumentsPlaceholder"); parameter_idx += 1; diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index fe12e392ea4..8bb054f0fe1 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1284,6 +1284,9 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( {{"graph_hash", torch::lazy::HashToString(coll.hash)}}); }, tsl::profiler::TraceMeLevel::kInfo); + + TF_VLOG(3) << "We are running XLAGraphExecutor::Compile now"; + static const bool enable_aliasing = runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true); static const size_t parameter_wrapping_threadshold = diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index bf32a712f3e..1424ce62853 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -10,40 +10,134 @@ from torch._ops import HigherOrderOperator import torch._higher_order_ops.while_loop from torch._higher_order_ops.while_loop import while_loop_op +from torch._higher_order_ops.while_loop import while_loop as torch_while_loop -def fori_loop(lower, upper, user_body_func, *init_val): +# TODO(@manfei): treat *input_value +def fori_loop(upper, lower, body_fun, init_val, input_value): device = xm.xla_device() - def cond_fn(upper, lower, *init_val): - return lower[0] < upper[0] + one_value = torch.tensor([1], dtype=torch.int32, device=device) - def body_fn(upper, lower, *init_val): - one_value_i = torch.ones(1, dtype=torch.int32, device=device) - res_list = list(user_body_func(*init_val)) - res_list.insert(0, lower) - res_list.insert(0, torch.sub(upper, one_value_i)) - return res_list + if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value = body_fun(input_value) + weight = body_fun.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = body_fun.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone(), bias.clone(), weight.clone( + ), output_value.clone() + + res = torch_while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, input_value, output_value)) + else: + output_value = torch.tensor([1], dtype=torch.int32, 
device=device) + + def cond_fn(upper, lower, one_value, x, input_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value): + new_lower = torch.add(one_value, lower) + output_val = body_fun(one_value, input_value) + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), output_val.clone() + + res = torch_while_loop(cond_fn, body_fn, + (upper, lower, one_value, init_val, input_value)) - res = while_loop(cond_fn, body_fn, (lower, upper, *init_val)) return res @while_loop_op.py_impl(DispatchKey.XLA) -def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): +def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') # cond_fn&body_fn: callable # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) if additional_inputs is None: additional_inputs = tuple() - return _xla_while_loop( - cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) - - -def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): - # untuple carried_inputs from while_loop - carried_inputs = carried_inputs[0] + return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) + else: + # modify body_fn return with additional_inputs + def new_body_fn(*carried_inputs): + # new_lower = torch.add(one_value, lower) + # output_value = linear_0(input_value) + # weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + # bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + # one_value, x), input_value.clone(), bias.clone(), weight.clone( + # ), output_value.clone() + # return body_fn(carried_inputs), weight.clone(), bias.clone() + # print("carried_inputs: ", carried_inputs) + # print("additional_inputs: ", additional_inputs) + # res1 = body_fn(*carried_inputs) + # print("res1: ", res1) + # print("type res1: ", type(res1)) + # print("type additional_inputs: ", type(additional_inputs)) + # print("*additional_inputs: ", *additional_inputs) + # res2 = (res1, ) + additional_inputs + # print("res2: ", res2) + # print("type res2: ", type(res2)) + # print("before it") + # print("body_fn(*carried_inputs): ", body_fn(*carried_inputs)) + # print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) + # print("additional_inputs: ", additional_inputs) + # print("type additional_inputs: ", type(additional_inputs)) + # print("list(body_fn(*carried_inputs)).extend(list(additional_inputs)): ", list(body_fn(*carried_inputs)).extend(list(additional_inputs))) + # aaa = [1, 2, 3] + # bbb = [4, 5] + # ccc = (4, 5) + # aaa.extend(bbb) + # print(aaa) + # aaa.extend(ccc) + # print(aaa) + # res0 = [1, 2, 3].extend((4, 5)) + # res1 = [1, 2, 3].extend([4, 5]) + # print("res0: ", res0) + # print("res1: ", res1) + # thislist = ["apple", "banana", "cherry"] + # tropical = ["mango", "pineapple", "papaya"] + # thislist.extend(tropical) + # print(thislist) + # thislist = ["apple", "banana", "cherry"] + # thistuple = ("kiwi", "orange") + # thislist.extend(thistuple) + # print(thislist) + # mid = body_fn(*carried_inputs) + # res = mid.extend(list(additional_inputs)) + # res = list(body_fn(*carried_inputs)) + # res.extend(additional_inputs) + # print("res: ", res) + # 
return list(body_fn(*carried_inputs)).extend(additional_inputs) + # self.named_parameters + # weight = self.linear.weight + res = list(body_fn(*carried_inputs)) + # print("res: ", res) + # trynewres = res[:-1] + [res[-1]] + # print("trynewres: ", trynewres) + newres = res[:-1] + list(additional_inputs) + [res[-1]] + # print("newres: ", newres) + # res.insert(-2, *additional_inputs) + # print("new res: ", res) + # res.extend(additional_inputs) + # res.append(body_fn.bias) + # res.append(body_fn.weight) + # return res + return newres + # return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) + return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) + print("$$$ additional_inputs: ", additional_inputs) + # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) + + +def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: @@ -51,37 +145,76 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): fake_carried_inputs.append( torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) - fake_carried_inputs = tuple(fake_carried_inputs) - - # trans fake_carried_inputs from list(tensor) to list(xla::op) - kwargs = {} - if type(fake_carried_inputs) is tuple: - shapes = xb.tensor_shape(fake_carried_inputs) - else: - shapes = xb.tensor_shape((fake_carried_inputs)) - builder = xb.create_builder('test_while') - params = [] - for shape in shapes: - p = xb.mkparam(builder, len(params), shape) - params.append(p) + for additional_input in additional_inputs: + device = additional_input.device + fake_carried_inputs.append( + torch.randint( + 10, additional_input.size(), + dtype=additional_input.dtype).to(device)) + # print("fake_carried_inputs: ", fake_carried_inputs) - # generate cond_fn xlacomputation + # TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") - cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:])) + + # TODO(@manfei): treat hard-code cond xlacomputation change: currently switch output_value and weight position if additional_inputs(weight/bias) exists + additional_inputs_list_cond = list( + fake_carried_inputs[2:] + ) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor + if additional_inputs: + tmp_bias = additional_inputs_list_cond[3] # additional_inputs_list_cond[-3] # not used, change order doesn't affect logic + del additional_inputs_list_cond[3] # additional_inputs_list_cond[-3] # not used, change order doesn't affect logic + additional_inputs_list_cond.append(tmp_bias) # not used, change order doesn't affect logic + + cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) + # cond_hlo_print = xb.get_computation_hlo(cond_computation) + # print("cond computation: !!!!!!!!!") + # print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - body_ctx.buildforiloop(list(body_result), []) + + # TODO(@manfei): treat hard-code body xlacomputation change: currently add non-changed output_value argument if 
additional_inputs(weight/bias) exists + if additional_inputs: + additional_inputs_list_body = [fake_carried_inputs[-3]] + else: + additional_inputs_list_body = [] + + # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body + # body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) + # TODO(@manfei): get index of output_value, then trasfer them into buildforiloop for + body_ctx.buildforiloop(list(body_result), [body_result[-1],]) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) + # body_hlo_print = xb.get_computation_hlo(body_computation) + # print("body computation: !!!!!!!!!") + # print(body_hlo_print) + + # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while + total_inputs = carried_inputs + additional_inputs + kwargs = {} + if type(total_inputs) is tuple: + shapes = xb.tensor_shape(total_inputs) + else: + shapes = xb.tensor_shape((total_inputs)) + builder = xb.create_builder('test_while') + params = [] + for shape in shapes: + p = xb.mkparam(builder, len(params), shape) + params.append(p) + + # TODO(@manfei): treat hard-code input arguments, currently switch bias and output_value if additional_inputs(weight/bias) exists + if additional_inputs: + tmp_bias = params[5] # params[-3] + del params[5] # params[-3] + params.append(tmp_bias) # generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) @@ -91,9 +224,12 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) + # hlo_print = xb.get_computation_hlo(computation) + # print("while computation: !!!!!!!!!") + # print(hlo_print) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', - (carried_inputs), computation) + (total_inputs), computation) - return result \ No newline at end of file + return result
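
Note on the loop semantics these changes exercise: the tests validate the XLA `fori_loop`/`while_loop` lowering against pure-Python references (`_fake_while_loop`, `_fake_fori_loop`), and the new `fori_loop` wrapper itself reduces a counted loop to `torch._higher_order_ops.while_loop` with the counter carried in the loop state. The sketch below is a minimal, framework-free restatement of that relationship for reference; the `_py_*` names are illustrative only and are not part of the torch_xla API.

# Minimal sketch of the loop semantics checked by these tests:
# fori_loop(lower, upper, body) modeled as a while_loop whose carried
# state is (iteration, limit, value), mirroring _fake_while_loop /
# _fake_fori_loop in test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py.
import torch


def _py_while_loop(cond_fn, body_fn, operands):
    # Thread the carried operands through body_fn while cond_fn holds.
    while cond_fn(*operands):
        operands = body_fn(*operands)
    return operands


def _py_fori_loop(lower, upper, body_fn, init_val):
    # Express the counted loop as a while loop over (it, limit, val),
    # the same reduction the XLA fori_loop wrapper performs.
    def cond_fn(it, limit, val):
        return it < limit

    def wrapped_body(it, limit, val):
        return it + 1, limit, body_fn(val)

    _, _, final_val = _py_while_loop(cond_fn, wrapped_body,
                                     (lower, upper, init_val))
    return final_val


if __name__ == "__main__":
    # Add 1 on each of the (upper - lower) iterations, analogous to
    # test_fori_loop_tpu_addition above.
    one = torch.tensor([1], dtype=torch.int32)
    res = _py_fori_loop(0, 50, lambda v: torch.add(v, one),
                        torch.tensor([1], dtype=torch.int32))
    print(res)  # tensor([51], dtype=torch.int32)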