From 9541a214545ab3e1ba5043ab3b0a516fc615cd0f Mon Sep 17 00:00:00 2001
From: manfei
Date: Fri, 10 May 2024 17:21:21 +0000
Subject: [PATCH] merge while_loop linear layer

---
 ...while_loop_simple_add_dispatch_in_torch.py | 493 +++++++++++++++++-
 test/test_test_mnist.py                       | 145 ++++++
 torch_xla/csrc/init_python_bindings.cpp       |  79 ++-
 torch_xla/experimental/fori_loop.py           | 226 ++++++--
 4 files changed, 871 insertions(+), 72 deletions(-)
 create mode 100644 test/test_test_mnist.py

diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
index 2bc88e75c77..9e35acec7fd 100644
--- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
+++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
@@ -6,7 +6,7 @@ import torch_xla
 # We need to import the underlying implementation function to register with the dispatcher
 import torch_xla.experimental.fori_loop
-from torch_xla.experimental.fori_loop import fori_loop
+from torch_xla.experimental.fori_loop import fori_loop, _xla_while_loop_get_xla_computation
 from torch._higher_order_ops.while_loop import while_loop
 import torch_xla.core.xla_model as xm
 import torch_xla.core.xla_builder as xb
@@ -18,16 +18,25 @@ def _fake_while_loop(cond_fn, body_fn, operands):
     operands = body_fn(*operands)
   return operands
 
+# def _fake_fori_loop(lower, upper, body_fun, *init_val):
+#   (plus_value, init_val) = init_val
+#   for i in range((upper - lower)[0]):
+#     plus_value, init_val = body_fun(plus_value, init_val)
+#   return init_val
+
 def _fake_fori_loop(lower, upper, body_fun, *init_val):
-  (plus_value, init_val) = init_val
-  for i in range((upper - lower)[0]):
-    plus_value, init_val = body_fun(plus_value, init_val)
-  return init_val
-
+  if len(init_val) > 1:
+    (a, b) = init_val
+    for i in range((upper - lower)[0]):
+      a = body_fun(a, b)
+  else:
+    for i in range((upper - lower)[0]):
+      a = body_fun(*init_val)
+  return a
 
 class WhileLoopTest(unittest.TestCase):
-
+  # --------------------------------------
+  # while_loop + PyLoweringContext
   def test_while_loop_tpu_subtraction(self):
 
     device = xm.xla_device()
@@ -82,24 +91,285 @@ def body_fn(init, limit_value):
     expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
     self.assertEqual(expected, res)
 
-  def test_fori_loop_tpu_addition(self):
+  # while_loop + PyLoweringContext + linear
+  def test_while_loop_tpu_simple_linear_outside_loop(self):
 
     xm.mark_step()
     device = xm.xla_device()
+    torch.set_grad_enabled(False)
 
-    lower = torch.tensor([2], dtype=torch.int32, device=device)
+    linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
+
+    def cond_fn(upper, lower, one_value, x, input_value, output_value):
+      return lower[0] < upper[0]
+
+    def body_fn(upper, lower, one_value, x, input_value, output_value):
+      new_lower = torch.add(one_value, lower)
+      output_value = linear_0(input_value)
+      weight = linear_0.weight  # not actually used; initialized as a placeholder required by xlacomputation
+      bias = linear_0.bias  # not actually used; initialized as a placeholder required by xlacomputation
+      return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+          one_value, x), input_value.clone(), bias.clone(), weight.clone(
+          ), output_value.clone()
+
+    upper = torch.tensor([1], dtype=torch.int32, device=device)
+    lower = torch.tensor([0], dtype=torch.int32, device=device)
+    one_value = torch.tensor([1], dtype=torch.int32, device=device)
+    init_val = torch.tensor([1], dtype=torch.int32, device=device)
+    l_in_0 = torch.rand(10, device=xm.xla_device())
+    output_value = torch.zeros([20], dtype=torch.float32, device=device)
+
+    upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(
+        cond_fn, body_fn,
+        (upper, lower, one_value, init_val, l_in_0, output_value))
+
+    expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)
+
+    return self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
+
+  def test_while_loop_tpu_simple_linear_class_inside_loop(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+    torch.set_grad_enabled(False)
+
+    class SimpleWithLinear(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(10, 20).to(xm.xla_device())
+
+      def forward(self, upper, lower, one_value, x, input_value, output_value):
+
+        def cond_fn(upper, lower, one_value, x, input_value, output_value):
+          return lower[0] < upper[0]
+
+        def body_fn(upper, lower, one_value, x, input_value, output_value):
+          new_lower = torch.add(one_value, lower)
+          output_value_real = self.linear(input_value)
+          weight = self.linear.weight  # not actually used; initialized as a placeholder required by xlacomputation
+          bias = self.linear.bias  # not actually used; initialized as a placeholder required by xlacomputation
+          return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+              one_value, x), input_value.clone(
+              ), output_value_real, weight.clone(), bias.clone()
+
+        return while_loop(
+            cond_fn, body_fn,
+            (upper, lower, one_value, x, input_value, output_value))
+
+    simple_with_linear = SimpleWithLinear()
     upper = torch.tensor([52], dtype=torch.int32, device=device)
-    plus_value = torch.tensor([1], dtype=torch.int32, device=device)
+    lower = torch.tensor([0], dtype=torch.int32, device=device)
+    one_value = torch.tensor([1], dtype=torch.int32, device=device)
     init_val = torch.tensor([1], dtype=torch.int32, device=device)
+    l_in_0 = torch.rand(10, device=xm.xla_device())
+    output_value = torch.zeros([20], dtype=torch.float32, device=device)
 
-    def body_fun(*argus):
-      plus_value, init_val = argus
-      return plus_value, torch.add(plus_value, init_val)
+    weight_0 = simple_with_linear.linear.weight
+    bias_0 = simple_with_linear.linear.bias
 
-    _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val)
-    expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val)
-    self.assertEqual(expected, actual)
+    aaa = {
+        "simple_with_linear":
+            (simple_with_linear, (upper, lower, one_value, init_val, l_in_0,
+                                  output_value))
+    }
+
+    upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(
+        upper, lower, one_value, init_val, l_in_0, output_value)
+
+    # create a linear model with the same weight/bias for comparison
+    linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
+    linear_0.weight.data = weight__
+    linear_0.bias.data = bias__
+    expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)
+
+    self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
+    return aaa
+
+  # WIP for target while_loop + PyLoweringContext + linear
+  def test_while_loop_tpu_simple_linear_class_inside_loop(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+    #device = ''
+    torch.set_grad_enabled(False)
+
+    class SimpleWithLinear(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device())
+        self.linear = torch.nn.Linear(2, 2)
+        self.register_buffer("dec", torch.tensor(1))
+
+      # def forward(self, iter, x):
+      #   def cond_fn(it, x):
+      #     return it - self.dec > 0
+
+      #   def body_fn(it, x):
+      #     return it - 1, self.linear(x)
+
+      #   return while_loop(cond_fn, body_fn, (iter, x))
+
+      def forward(self, iter, x):
+
+        def cond_fn(it, x):
+          return it - self.dec > 0
+
+        def body_fn(it, x):
+          return it - 1, self.linear(x)
+
+        # return while_loop(cond_fn, body_fn, (iter, x))
+        return _xla_while_loop_get_xla_computation(cond_fn, body_fn, (iter, x), ())
+
+      # def forward(self, upper, lower, one_value, x, input_value, output_value):
+      #   def cond_fn(upper, lower, one_value, x, input_value, output_value):
+      #     return lower[0] < upper[0]
+
+      #   def body_fn(upper, lower, one_value, x, input_value, output_value):
+      #     new_lower = torch.add(one_value, lower)
+      #     output_value_real = self.linear(input_value)
+      #     weight = self.linear.weight  # not actually used; initialized as a placeholder required by xlacomputation
+      #     bias = self.linear.bias  # not actually used; initialized as a placeholder required by xlacomputation
+      #     return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+      #         one_value, x), input_value.clone(
+      #         ), output_value_real, weight.clone(), bias.clone()
+
+      #   return while_loop(
+      #       cond_fn, body_fn,
+      #       (upper, lower, one_value, x, input_value, output_value))
+
+    simple_with_linear = SimpleWithLinear()
+    simple_with_linear.to(device)
+    #breakpoint()
+    input = torch.randn(2, 2).to(device)
+    iter = torch.tensor(3, device=device)
+    res = simple_with_linear(iter, input)
+
+    return res
+
+    # upper = torch.tensor([52], dtype=torch.int32, device=device)
+    # lower = torch.tensor([0], dtype=torch.int32, device=device)
+    # one_value = torch.tensor([1], dtype=torch.int32, device=device)
+    # init_val = torch.tensor([1], dtype=torch.int32, device=device)
+    # l_in_0 = torch.rand(10, device=xm.xla_device())
+    # output_value = torch.zeros([20], dtype=torch.float32, device=device)
+
+    # upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(
+    #     upper, lower, one_value, init_val, l_in_0, output_value)
+
+    # # create a linear model with the same weight/bias for comparison
+    # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
+    # linear_0.weight.data = weight__
+    # linear_0.bias.data = bias__
+    # expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)
+
+    # self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
+    # return aaa
+
+  def test_while_loop_tpu_simple_linear_target_inside_loop(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+    #device = ''
+    torch.set_grad_enabled(False)
+
+    class SimpleWithLinear(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(2, 2)
+        self.register_buffer("dec", torch.tensor(1))
+
+      def forward(self, iter, x):
+
+        def cond_fn(it, x):
+          return it - self.dec > 0
+
+        def body_fn(it, x):
+          return it - 1, self.linear(x)
+
+        return while_loop(cond_fn, body_fn, (iter, x))
+
+    simple_with_linear = SimpleWithLinear()
+    simple_with_linear.to(device)
+    #breakpoint()
+    input = torch.randn(2, 2).to(device)
+    iter = torch.tensor(3, device=device)
+    res = simple_with_linear(iter, input)
+
+  def test_while_loop_tpu_MNIST_outside_loop(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+    torch.set_grad_enabled(False)
+
+    linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
+
+    def cond_fn(upper, lower, one_value, x, input_value, output_value):
+      return lower[0] < upper[0]
+
+    def body_fn(upper, lower, one_value, x, input_value, output_value):
+      new_lower = torch.add(one_value, lower)
+      output_value = linear_0(input_value)
+      weight = linear_0.weight  # not actually used; initialized as a placeholder required by xlacomputation
+      bias = linear_0.bias  # not actually used; initialized as a placeholder required by xlacomputation
+      return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+          one_value, x), input_value.clone(), bias.clone(), weight.clone(
+          ), output_value.clone()
+
+    upper = torch.tensor([1], dtype=torch.int32, device=device)
+    lower = torch.tensor([0], dtype=torch.int32, device=device)
+    one_value = torch.tensor([1], dtype=torch.int32, device=device)
+    init_val = torch.tensor([1], dtype=torch.int32, device=device)
+    l_in_0 = torch.rand(10, device=xm.xla_device())
+    output_value = torch.zeros([20], dtype=torch.float32, device=device)
+
+    upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(
+        cond_fn, body_fn,
+        (upper, lower, one_value, init_val, l_in_0, output_value))
+
+    expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)
+
+    return self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
+
+  # WIP for while_loop + PyLoweringContext + MNIST & inside_loop
+  def test_while_loop_tpu_MNIST_outside_loop(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+    torch.set_grad_enabled(False)
+
+    linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
+
+    def cond_fn(upper, lower, one_value, x, input_value, output_value):
+      return lower[0] < upper[0]
+
+    def body_fn(upper, lower, one_value, x, input_value, output_value):
+      new_lower = torch.add(one_value, lower)
+      output_value = linear_0(input_value)
+      weight = linear_0.weight  # not actually used; initialized as a placeholder required by xlacomputation
+      bias = linear_0.bias  # not actually used; initialized as a placeholder required by xlacomputation
+      return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+          one_value, x), input_value.clone(), bias.clone(), weight.clone(
+          ), output_value.clone()
+
+    upper = torch.tensor([1], dtype=torch.int32, device=device)
+    lower = torch.tensor([0], dtype=torch.int32, device=device)
+    one_value = torch.tensor([1], dtype=torch.int32, device=device)
+    init_val = torch.tensor([1], dtype=torch.int32, device=device)
+    l_in_0 = torch.rand(10, device=xm.xla_device())
+    output_value = torch.zeros([20], dtype=torch.float32, device=device)
+
+    upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(
+        cond_fn, body_fn,
+        (upper, lower, one_value, init_val, l_in_0, output_value))
+
+    expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)
+
+    return self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
+
+  # ------------------------
+  # _get_xla_computation
+  # pass
   def test_while_loop_get_xlacomputation(self):
 
     xm.mark_step()
@@ -170,7 +440,157 @@ def forward(self, x):
     else:
       print("print computation from _get_xla_computation: null !!!!!!!!!!!!!")
 
-  # this test should be modified/enabled after merge with the PR #6867
+  # _xla_while_loop_get_xla_computation + _get_xla_computation
+  def test_while_loop_tpu_subtraction_get_xla_computation(self):
+
+    device = xm.xla_device()
+
+    def cond_fn(init, limit_value):
+      return limit_value[0] <= init[0]
+
+    def body_fn(init, limit_value):
+      one_value = torch.ones(1, dtype=torch.int32, device=device)
+      two_value = limit_value.clone()
+      return (torch.sub(init, one_value), two_value)
+
+    init = torch.tensor([10], dtype=torch.int32, device=device)
+    limit_value = torch.tensor([0], dtype=torch.int32, device=device)
+    res = _xla_while_loop_get_xla_computation(cond_fn, body_fn, (init, limit_value), ())
+    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
+    self.assertEqual(expected, res)
+
+  def test_while_loop_tpu_addition_get_xla_computation(self):
+
+    device = xm.xla_device()
+
+    def cond_fn(init, limit_value):
+      return limit_value[0] >= init[0]
+
+    def body_fn(init, limit_value):
+      one_value = torch.ones(1, dtype=torch.int32, device=device)
+      return (torch.add(init, one_value), limit_value.clone())
+
+    # TODO(@manfei): init and limit_value has to be torch.tensor.
+    init = torch.tensor([0], dtype=torch.int32, device=device)
+    limit_value = torch.tensor([10], dtype=torch.int32, device=device)
+    res = _xla_while_loop_get_xla_computation(cond_fn, body_fn, (init, limit_value), ())
+    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
+    self.assertEqual(expected, res)
+
+  def test_while_loop_tpu_subtraction_nested_get_xla_computation(self):
+
+    device = xm.xla_device()
+
+    def cond_fn(init, limit_value):
+      return limit_value[0] <= init[0]
+
+    def body_fn(init, limit_value):
+      one_value = torch.ones(1, dtype=torch.int32, device=device)
+      two_value = limit_value.clone()
+      return (torch.sub(torch.sub(init, one_value), one_value), two_value)
+
+    init = torch.tensor([10], dtype=torch.int32, device=device)
+    limit_value = torch.tensor([0], dtype=torch.int32, device=device)
+    res = _xla_while_loop_get_xla_computation(cond_fn, body_fn, (init, limit_value), ())
+    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
+    self.assertEqual(expected, res)
+
+  # _xla_while_loop_get_xla_computation + _get_xla_computation + linear
+  def test_while_loop_tpu_simple_linear_outside_loop_get_xla_computation(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+    torch.set_grad_enabled(False)
+
+    linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
+
+    def cond_fn(upper, lower, one_value, x, input_value, output_value):
+      return lower[0] < upper[0]
+
+    def body_fn(upper, lower, one_value, x, input_value, output_value):
+      new_lower = torch.add(one_value, lower)
+      output_value = linear_0(input_value)
+      weight = linear_0.weight  # not actually used; initialized as a placeholder required by xlacomputation
+      bias = linear_0.bias  # not actually used; initialized as a placeholder required by xlacomputation
+      return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+          one_value, x), input_value.clone(), bias.clone(), weight.clone(
+          ), output_value.clone()
+
+    upper = torch.tensor([1], dtype=torch.int32, device=device)
+    lower = torch.tensor([0], dtype=torch.int32, device=device)
+    one_value = torch.tensor([1], dtype=torch.int32, device=device)
+    init_val = torch.tensor([1], dtype=torch.int32, device=device)
+    l_in_0 = torch.rand(10, device=xm.xla_device())
+    output_value = torch.zeros([20], dtype=torch.float32, device=device)
+
+    upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = _xla_while_loop_get_xla_computation(
+        cond_fn, body_fn,
+        (upper, lower, one_value, init_val, l_in_0, output_value), ())
+
+    expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)
+
+    return self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
+
+  def test_while_loop_tpu_simple_linear_class_inside_loop_get_xla_computation(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+    torch.set_grad_enabled(False)
+
+    class SimpleWithLinear(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(10, 20).to(xm.xla_device())
+
+      def forward(self, upper, lower, one_value, x, input_value, output_value):
+
+        def cond_fn(upper, lower, one_value, x, input_value, output_value):
+          return lower[0] < upper[0]
+
+        def body_fn(upper, lower, one_value, x, input_value, output_value):
+          new_lower = torch.add(one_value, lower)
+          output_value_real = self.linear(input_value)
+          weight = self.linear.weight  # not actually used; initialized as a placeholder required by xlacomputation
+          bias = self.linear.bias  # not actually used; initialized as a placeholder required by xlacomputation
+          return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+              one_value, x), input_value.clone(
+              ), output_value_real, weight.clone(), bias.clone()
+
+        return _xla_while_loop_get_xla_computation(
+            cond_fn, body_fn,
+            (upper, lower, one_value, x, input_value, output_value), ())
+
+    simple_with_linear = SimpleWithLinear()
+    upper = torch.tensor([52], dtype=torch.int32, device=device)
+    lower = torch.tensor([0], dtype=torch.int32, device=device)
+    one_value = torch.tensor([1], dtype=torch.int32, device=device)
+    init_val = torch.tensor([1], dtype=torch.int32, device=device)
+    l_in_0 = torch.rand(10, device=xm.xla_device())
+    output_value = torch.zeros([20], dtype=torch.float32, device=device)
+
+    weight_0 = simple_with_linear.linear.weight
+    bias_0 = simple_with_linear.linear.bias
+
+    aaa = {
+        "simple_with_linear":
+            (simple_with_linear, (upper, lower, one_value, init_val, l_in_0,
+                                  output_value))
+    }
+
+    upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(
+        upper, lower, one_value, init_val, l_in_0, output_value)
+
+    # create a linear model with the same weight/bias for comparison
+    linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
+    linear_0.weight.data = weight__
+    linear_0.bias.data = bias__
+    expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)
+
+    self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
+    return aaa
+
+  # while_loop + _get_xla_computation: WIP
   def test_while_loop_get_xlacomputation_tpu_simple_linear_while_loop(self):
 
     xm.mark_step()
@@ -223,6 +643,45 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args):
       print("print computation from PyLoweringContext: !!!!!!!!!")
       print(body_hlo_print)
 
+  # fori_loop + PyLoweringContext: WIP
+  def test_fori_loop_tpu_addition(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+
+    lower = torch.tensor([2], dtype=torch.int32, device=device)
+    upper = torch.tensor([52], dtype=torch.int32, device=device)
+    plus_value = torch.tensor([1], dtype=torch.int32, device=device)
+    init_val = torch.tensor([1], dtype=torch.int32, device=device)
+
+    def body_fun(*args):
+      plus_value, init_val = args
+      return plus_value.clone(), torch.add(plus_value, init_val).clone()
+
+    _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val)
+    expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val)
+    self.assertEqual(expected, actual)
+
+  def test_fori_loop_tpu_simple_linear(self):
+
+    xm.mark_step()
+    device = xm.xla_device()
+    torch.set_grad_enabled(False)
+
+    upper = torch.tensor([52], dtype=torch.int32, device=device)
+    lower = torch.tensor([0], dtype=torch.int32, device=device)
+    init_val = torch.tensor([1], dtype=torch.int32, device=device)
+    l_in_0 = torch.randn(10, device=xm.xla_device())
+
+    linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
+
+    upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop(
+        upper, lower, linear_0, init_val, l_in_0)
+
+    expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)
+
+    self.assertTrue(torch.all(torch.eq(expected, l_out_)))
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py
new file mode 100644
index 00000000000..1cb6de86641
--- /dev/null
+++ b/test/test_test_mnist.py
@@ -0,0 +1,145 @@
+import torch
+# import torchvision
+import os
+import shutil
+import sys
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+# from torchvision import datasets, transforms
+import torch_xla
+import torch_xla.debug.metrics as met
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.utils.utils as xu
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.test.test_utils as test_utils
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch_xla.distributed.xla_backend
+import torch_xla.experimental.fori_loop
+from torch_xla.experimental.fori_loop import _xla_while_loop, _xla_while_loop_get_xla_computation
+from torch._higher_order_ops.while_loop import while_loop
+
+n_epochs = 3
+batch_size_train = 8  # 64
+batch_size_test = 10  # 1000
+learning_rate = 0.01
+momentum = 0.5
+log_interval = 10
+random_seed = 1
+torch.backends.cudnn.enabled = False
+torch.manual_seed(random_seed)
+
+### load data
+test_loader = xu.SampleGenerator(
+    data=(torch.zeros(8, 1, 28, 28), torch.zeros(8, dtype=torch.int64)),
+    sample_count=1000 // 8 // xm.xrt_world_size())
+
+### build model
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+class MNIST(torch.nn.Module):
+
+  def __init__(self):
+    # super().__init__()
+    super(MNIST, self).__init__()
+    self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, stride=1, padding=2).to(xm.xla_device())
+    self.bn1 = torch.nn.BatchNorm2d(10).to(xm.xla_device())
+    self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5).to(xm.xla_device())
+    self.bn2 = torch.nn.BatchNorm2d(20).to(xm.xla_device())
+    self.fc1 = torch.nn.Linear(500, 50).to(xm.xla_device())
+    self.fc2 = torch.nn.Linear(50, 10).to(xm.xla_device())
+
+  def forward(self, x):
+    x = F.relu(F.max_pool2d(self.conv1(x), 2))
+    x = self.bn1(x)
+    x = F.relu(F.max_pool2d(self.conv2(x), 2))
+    x = self.bn2(x)
+    x = torch.flatten(x, 1)
+    x = F.relu(self.fc1(x))
+    x = self.fc2(x)
+    return F.log_softmax(x, dim=1)
+
+def newnewnew_test():
+  device = xm.xla_device()
+  torch.set_grad_enabled(False)
+
+  simple_with_linear = MNIST()
+
+  def cond_fn(upper, lower, one_value, x, input_value, output_value, *args):
+    return lower[0] < upper[0]
+
+  def body_fn(upper, lower, one_value, x, input_value, output_value, *args):
+    new_lower = torch.add(one_value, lower)
+    output_value = simple_with_linear(input_value)
+    res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()]
+    bn_list = []
+    for name, param in simple_with_linear.named_parameters():
+      if name[:2] == 'bn':
+        bn_list.append(param)
+
+      res.insert(-1, param)
+
+    # if any BatchNorm params were collected, append them again (reversed) before output_value so that output_value stays the last element
+    if len(bn_list) != 0:
+      output_value = res[-1]
+      bn_list.reverse()
+      res = res[:-1] + bn_list
+      res.append(output_value)
+      bn_list = []
+
+    return tuple(res)
+
+  upper = torch.tensor([50], dtype=torch.int32, device=device)
+  lower = torch.tensor([0], dtype=torch.int32, device=device)
+  one_value = torch.tensor([1], dtype=torch.int32, device=device)
+  init_val = torch.tensor([1], dtype=torch.int32, device=device)
+  bs = 16
+  l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device)
+  output_value = torch.zeros([16, 10], dtype=torch.float32, device=device)
+
+  for name, param in simple_with_linear.named_parameters():
+    print("name: ", name)
+    print("param: ", param.size())
+
+  additional_inputs = []
+  bn_list = []
+  for name, param in simple_with_linear.named_parameters():
+    if name[:2] == 'bn':
+      bn_list.append(param)
+
+    additional_inputs.append(param)
+
+  # if any BatchNorm params were collected, append them again (reversed) at the tail of additional_inputs, duplicating the bn arguments
+  if len(bn_list) != 0:
+    bn_list.reverse()  # reverse the list for the duplicated bn arguments
+    additional_inputs = additional_inputs + bn_list
+    bn_list = []
+
+  upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, output_value_real__, = _xla_while_loop(
+      cond_fn, body_fn,
+      (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs))
+  # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, output_value_real__, = _xla_while_loop_get_xla_computation(
+  #     cond_fn, body_fn,
+  #     (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs))
+  print("finish newnewnew_test")
+  print("torch_add_res__: run times: ", torch_add_res__)
+  print("actual res: ", output_value_real__[0][0])
+  expected_ = simple_with_linear(l_in_0)
+  print("expected res: ", expected_[0][0])
+
+# run the test model
+def test_mnist():
+  torch.manual_seed(1)
+
+  print("before test_mnist")
+  newnewnew_test()
+
+  print("after test_mnist")
+
+if __name__ == '__main__':
+  test_mnist()
\ No newline at end of file
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index dacd6717994..4cae79c3a4f 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -914,12 +914,66 @@ class PyLoweringContext {
   // needed in xlacomputation.
   void BuildForiLoop(std::vector<at::Tensor> tensors,
                      std::vector<at::Tensor> additional_inputs_list = {}) {
+    // // hard-code modify cond xlacomputation input arguments with unusedarguments
+    // // for xla::while requriement
+    // if (GetNameString() == "condctx") {
+    //   xla::XlaBuilder* local_builder = lowering_ctx.builder();
+    //   // XLA_ERROR() << "for condctx, we want test here! ";
+    //   // int64_t parameter_idx0 = local_builder->GetProgramShape()->parameters_size();
+    //   // XLA_ERROR() << "for condctx, we have args now: " << parameter_idx0;
+    //   int64_t parameter_idx =
+    //       2;  // parameter_idx start from 2 after used upper and lower // param_count
+    //   for (auto& additional_input_tensor : additional_inputs_list) {
+    //     XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor);
+    //     xla::Shape shape = xtensor->shape().get();
+    //     xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
+    //                                   "UnusedArgumentsPlaceholder");
+    //     parameter_idx += 1;
+    //   }
+    // }
+
+    // // hard-code modify body xlacomputation input arguments with unusedarguments
+    // // for xla::while requriement
+    // if (GetNameString() == "bodyctx") {
+    //   xla::XlaBuilder* local_builder = lowering_ctx.builder();
+    //   // TODO(@manfei): treat hard code parameter_idx value
+    //   int64_t parameter_idx = 21;
+    //   for (auto& additional_input_tensor : additional_inputs_list) {
+    //     XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor);
+    //     xla::Shape shape = xtensor->shape().get();
+    //     xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
+    //                                   "UnusedArgumentsPlaceholder");
+    //     parameter_idx += 1;
+    //   }
+    // }
+
+    // Get the backing XLA tensors from the output torch tensor handles
+    std::vector<XLATensorPtr> xtensors =
+        GetXlaTensors(tensors, /*want_all=*/true);
+
+    // Get the lazy IR value from the output XLA tensors
+    std::vector<torch::lazy::Value> ir_values;
+    for (auto& xtensor : xtensors) {
+      torch::lazy::Value value = xtensor->GetIrValue();
+      ir_values.push_back(value);
+    }
+
+    // Lower the graph using the output IR values
+    for (auto& ir_value : ir_values) {
+      xla::XlaOp root = lowering_ctx.GetOutputOp(
+          torch::lazy::Output(ir_value.node.get(), ir_value.index));
+      lowering_ctx.AddResult(root);
+    }
+
+    xla::XlaBuilder* local_builder = lowering_ctx.builder();
+    int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size();
+    // XLA_ERROR() << "for fori_loop, we have args now: " << parameter_idx;
+
     // hard-code modify cond xlacomputation input arguments with unusedarguments
     // for xla::while requriement
     if (GetNameString() == "condctx") {
-      xla::XlaBuilder* local_builder = lowering_ctx.builder();
-      int64_t parameter_idx =
-          2;  // parameter_idx start from 2 after used upper and lower
+      // xla::XlaBuilder* local_builder = lowering_ctx.builder();
+      // int64_t parameter_idx = 2;  // parameter_idx start from 2 after used upper and lower // param_count
       for (auto& additional_input_tensor : additional_inputs_list) {
         XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor);
         xla::Shape shape = xtensor->shape().get();
@@ -934,7 +988,7 @@ class PyLoweringContext {
     if (GetNameString() == "bodyctx") {
       xla::XlaBuilder* local_builder = lowering_ctx.builder();
       // TODO(@manfei): treat hard code parameter_idx value
-      int64_t parameter_idx = 7;
+      // int64_t parameter_idx = 21;
       for (auto& additional_input_tensor : additional_inputs_list) {
         XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor);
         xla::Shape shape = xtensor->shape().get();
@@ -944,23 +998,6 @@ class PyLoweringContext {
       }
     }
 
-    // Get the backing XLA tensors from the output torch tensor handles
-    std::vector<XLATensorPtr> xtensors =
-        GetXlaTensors(tensors, /*want_all=*/true);
-
-    // Get the lazy IR value from the output XLA tensors
-    std::vector<torch::lazy::Value> ir_values;
-    for (auto& xtensor : xtensors) {
-      torch::lazy::Value value = xtensor->GetIrValue();
-      ir_values.push_back(value);
-    }
-
-    // Lower the graph using the output IR values
-    for (auto& ir_value : ir_values) {
-      xla::XlaOp root = lowering_ctx.GetOutputOp(
-          torch::lazy::Output(ir_value.node.get(), ir_value.index));
-      lowering_ctx.AddResult(root);
-    }
 
     computation = ConsumeValue(lowering_ctx.BuildXla());
 
     // wrap inputs of cond/body_computation
diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py
index bf32a712f3e..500ea416e14 100644
--- a/torch_xla/experimental/fori_loop.py
+++ b/torch_xla/experimental/fori_loop.py
@@ -1,3 +1,4 @@
+
 import numpy as np
 import torch
 import torch_xla
@@ -10,40 +11,65 @@ from torch._ops import HigherOrderOperator
 import torch._higher_order_ops.while_loop
 from torch._higher_order_ops.while_loop import while_loop_op
+from torch._higher_order_ops.while_loop import while_loop as torch_while_loop
 
-def fori_loop(lower, upper, user_body_func, *init_val):
+# TODO(@manfei): treat *input_value
+def fori_loop(upper, lower, body_fun, init_val, input_value):
 
   device = xm.xla_device()
 
-  def cond_fn(upper, lower, *init_val):
-    return lower[0] < upper[0]
+  one_value = torch.tensor([1], dtype=torch.int32, device=device)
+
+  if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')):
+    output_value = torch.zeros([20], dtype=torch.float32, device=device)
+
+    def cond_fn(upper, lower, one_value, x, input_value, output_value):
+      return lower[0] < upper[0]
 
-  def body_fn(upper, lower, *init_val):
-    one_value_i = torch.ones(1, dtype=torch.int32, device=device)
-    res_list = list(user_body_func(*init_val))
-    res_list.insert(0, lower)
-    res_list.insert(0, torch.sub(upper, one_value_i))
-    return res_list
+    def body_fn(upper, lower, one_value, x, input_value, output_value):
+      new_lower = torch.add(one_value, lower)
+      output_value = body_fun(input_value)
+      weight = body_fun.weight  # not actually used; initialized as a placeholder required by xlacomputation
+      bias = body_fun.bias  # not actually used; initialized as a placeholder required by xlacomputation
+      return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+          one_value, x), input_value.clone(), bias.clone(), weight.clone(
+          ), output_value.clone()
+
+    res = torch_while_loop(
+        cond_fn, body_fn,
+        (upper, lower, one_value, init_val, input_value, output_value))
+  else:
+    output_value = torch.tensor([1], dtype=torch.int32, device=device)
+
+    def cond_fn(upper, lower, one_value, x, input_value):
+      return lower[0] < upper[0]
+
+    def body_fn(upper, lower, one_value, x, input_value):
+      new_lower = torch.add(one_value, lower)
+      output_val = body_fun(one_value, input_value)
+      return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
+          one_value, x), output_val.clone()
+
+    res = torch_while_loop(cond_fn, body_fn,
+                           (upper, lower, one_value, init_val, input_value))
 
-  res = while_loop(cond_fn, body_fn, (lower, upper, *init_val))
   return res
 
 
 @while_loop_op.py_impl(DispatchKey.XLA)
-def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None):
+def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
   # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '')
   # cond_fn&body_fn: callable
   # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors)
   if additional_inputs is None:
     additional_inputs = tuple()
-  return _xla_while_loop(
-      cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs)
+  print("arrive @while_loop_op.py_impl(DispatchKey.XLA)")
+  return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs)
 
 
-def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
-  # untuple carried_inputs from while_loop
-  carried_inputs = carried_inputs[0]
+def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
+  print("arrive _xla_while_loop")
   # fake carried_inputs to split formal code
   fake_carried_inputs = []
   for carried_input in carried_inputs:
@@ -51,37 +77,169 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
     fake_carried_inputs.append(
         torch.randint(10, carried_input.size(),
                       dtype=carried_input.dtype).to(device))
-  fake_carried_inputs = tuple(fake_carried_inputs)
-
-  # trans fake_carried_inputs from list(tensor) to list(xla::op)
-  kwargs = {}
-  if type(fake_carried_inputs) is tuple:
-    shapes = xb.tensor_shape(fake_carried_inputs)
-  else:
-    shapes = xb.tensor_shape((fake_carried_inputs))
-  builder = xb.create_builder('test_while')
-  params = []
-  for shape in shapes:
-    p = xb.mkparam(builder, len(params), shape)
-    params.append(p)
+  for additional_input in additional_inputs:
+    device = additional_input.device
+    fake_carried_inputs.append(
+        torch.randint(
+            10, additional_input.size(),
+            dtype=additional_input.dtype).to(device))
 
-  # generate cond_fn xlacomputation
+  # TODO(@manfei): specify which element is for which argument like a,b,c
   cond_result = cond_fn(*fake_carried_inputs)
   cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
   cond_ctx.set_name_string("condctx")
-  cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:]))
+
+  # TODO(@manfei): treat hard-coded cond xlacomputation change: currently switch output_value and weight position if additional_inputs (weight/bias) exist
+  additional_inputs_list_cond = list(
+      fake_carried_inputs[2:]
+  )  # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor
+  if additional_inputs:
+    tmp_bias = additional_inputs_list_cond[
+        -3]  # not used, change order doesn't affect logic
+    del additional_inputs_list_cond[
+        -3]  # not used, change order doesn't affect logic
+    additional_inputs_list_cond.append(
+        tmp_bias)  # not used, change order doesn't affect logic
+
+  # cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond)
+  cond_ctx.buildforiloop([cond_result], ())
   cond_hlo = cond_ctx.hlo()
   cond_computation = xb.computation_from_module_proto("condcomputation",
                                                       cond_hlo)
+  cond_hlo_print = xb.get_computation_hlo(cond_computation)
+  print("cond computation: !!!!!!!!!")
+  print(cond_hlo_print)
 
   # generate body_fn xlacomputation
   body_result = body_fn(*fake_carried_inputs)
   body_ctx = torch_xla._XLAC.lowering.LoweringContext()
   body_ctx.set_name_string("bodyctx")
-  body_ctx.buildforiloop(list(body_result), [])
+
+  # TODO(@manfei): treat hard-coded body xlacomputation change: currently add the unchanged output_value argument if additional_inputs (weight/bias) exist
+  if additional_inputs:
+    additional_inputs_list_body = [fake_carried_inputs[-3]]
+  else:
+    additional_inputs_list_body = []
+
+  # TODO(@manfei): treat hard-coded parameters: additional_inputs_list_body
+  # body_ctx.buildforiloop(list(body_result), additional_inputs_list_body)
+  body_ctx.buildforiloop(list(body_result), ())
   body_hlo = body_ctx.hlo()
   body_computation = xb.computation_from_module_proto("bodycomputation",
                                                       body_hlo)
+  body_hlo_print = xb.get_computation_hlo(body_computation)
+  print("body computation: !!!!!!!!!")
+  print(body_hlo_print)
+
+  # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while
+  total_inputs = carried_inputs + additional_inputs
+  kwargs = {}
+  if type(total_inputs) is tuple:
+    shapes = xb.tensor_shape(total_inputs)
+  else:
+    shapes = xb.tensor_shape((total_inputs))
+  builder = xb.create_builder('test_while')
+  params = []
+  for shape in shapes:
+    p = xb.mkparam(builder, len(params), shape)
+    params.append(p)
+
+  # TODO(@manfei): treat hard-coded input arguments, currently switch bias and output_value if additional_inputs (weight/bias) exist
+  if additional_inputs:
+    tmp_bias = params[-3]
+    del params[-3]
+    params.append(tmp_bias)
+
+  # generate while xlacomputation
+  input_tuple = xb.Op.tuple(tuple(params))
+  w = xb.mkop(
+      'While', (input_tuple.op,),
+      condition_computation=cond_computation,
+      body_computation=body_computation)
+  name = 'fori_loop_ed_torch_func'
+  computation = w.build(name)
+
+  # gain final result with generated while xlacomputation
+  result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
+                                                 (total_inputs), computation)
+
+  return result
+
+
+def _xla_while_loop_get_xla_computation(cond_fn, body_fn, carried_inputs, additional_inputs=None):
+  # fake carried_inputs to split formal code
+  fake_carried_inputs = []
+  for carried_input in carried_inputs:
+    device = carried_input.device
+    fake_carried_inputs.append(
+        torch.randint(10, carried_input.size(),
+                      dtype=carried_input.dtype).to(device))
+  for additional_input in additional_inputs:
+    device = additional_input.device
+    fake_carried_inputs.append(
+        torch.randint(
+            10, additional_input.size(),
+            dtype=additional_input.dtype).to(device))
+
+  # cond_fn xlacomputation
+  additional_inputs_list_cond = list(
+      fake_carried_inputs[2:]
+  )  # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor
+  if additional_inputs:
+    tmp_bias = additional_inputs_list_cond[
+        -3]  # not used, change order doesn't affect logic
+    del additional_inputs_list_cond[
+        -3]  # not used, change order doesn't affect logic
+    additional_inputs_list_cond.append(
+        tmp_bias)  # not used, change order doesn't affect logic
+
+  cond_result = cond_fn(*fake_carried_inputs)
+  # cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
+  # cond_ctx.set_name_string("condctx")
+  # cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond)
+  # cond_hlo = cond_ctx.hlo()
+  # cond_computation = xb.computation_from_module_proto("condcomputation",
+  #                                                     cond_hlo)
+  cond_computation = torch_xla._XLAC._get_xla_computation([cond_result], [], True)
+  cond_hlo_print = xb.get_computation_hlo(cond_computation)
+  print("cond computation: !!!!!!!!!")
+  print(cond_hlo_print)
+
+  # generate body_fn xlacomputation
+  if additional_inputs:
+    additional_inputs_list_body = [fake_carried_inputs[-3]]
+  else:
+    additional_inputs_list_body = []
+
+  body_result = body_fn(*fake_carried_inputs)
+  # body_ctx = torch_xla._XLAC.lowering.LoweringContext()
+  # body_ctx.set_name_string("bodyctx")
+  # body_ctx.buildforiloop(list(body_result), additional_inputs_list_body)
+  # body_hlo = body_ctx.hlo()
+  # body_computation = xb.computation_from_module_proto("bodycomputation",
+  #                                                     body_hlo)
+  body_computation = torch_xla._XLAC._get_xla_computation(list(body_result), [], True)
+  body_hlo_print = xb.get_computation_hlo(body_computation)
+  print("body computation: !!!!!!!!!")
+  print(body_hlo_print)
+
+  # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while
+  total_inputs = carried_inputs + additional_inputs
+  kwargs = {}
+  if type(total_inputs) is tuple:
+    shapes = xb.tensor_shape(total_inputs)
+  else:
+    shapes = xb.tensor_shape((total_inputs))
+  builder = xb.create_builder('test_while')
+  params = []
+  for shape in shapes:
+    p = xb.mkparam(builder, len(params), shape)
+    params.append(p)
+
+  # TODO(@manfei): treat hard-coded input arguments, currently switch bias and output_value if additional_inputs (weight/bias) exist
+  if additional_inputs:
+    tmp_bias = params[-3]
+    del params[-3]
+    params.append(tmp_bias)
 
   # generate while xlacomputation
   input_tuple = xb.Op.tuple(tuple(params))
@@ -94,6 +252,6 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
 
   # gain final result with generated while xlacomputation
   result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
                                                  (carried_inputs), computation)
+                                                 (total_inputs), computation)
 
-  return result
\ No newline at end of file
+  return result
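
Usage sketch (reviewer note, not part of the patch): with the fori_loop signature introduced above, a caller passes the bounds, a body module, an initial value, and the loop input, and the last element of the returned tuple is the loop output. The snippet below only mirrors test_fori_loop_tpu_simple_linear from this patch and is illustrative, not an additional API.

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.experimental.fori_loop import fori_loop

    device = xm.xla_device()
    torch.set_grad_enabled(False)

    upper = torch.tensor([52], dtype=torch.int32, device=device)
    lower = torch.tensor([0], dtype=torch.int32, device=device)
    init_val = torch.tensor([1], dtype=torch.int32, device=device)
    l_in_0 = torch.randn(10, device=device)
    linear_0 = torch.nn.Linear(10, 20).to(device)

    # fori_loop returns the full carried tuple; the final element is the loop output
    *_, l_out = fori_loop(upper, lower, linear_0, init_val, l_in_0)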