From 94bbb6cdaf17f4691e6e2750e5ad4cf39044f218 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:11:53 -0700 Subject: [PATCH 001/546] Update test_xla_sharding.cpp --- test/cpp/test_xla_sharding.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index e1f908b5c80..a17031f148e 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -435,5 +435,17 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { ->HasValue()); } +TEST_F(XLAShardingTest, TestForiLoopAddUnusedParameterInXlaComputation) { + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}); + // Build simple addition. + xla::XlaBuilder b("builder"); + auto x = xla::Parameter(&b, /*parameter_number=*/0, shape, "p0"); + xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}); + auto zzz = xla::Parameter(&b, /*parameter_number=*/1, shape2, "p1"); + auto y = xla::Add(x, xla::ConstantR0(&b, 3)); + xla::XlaComputation xla_computation = + ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false)); +} + } // namespace cpp_test } // namespace torch_xla From aacaa6279985aa4812f9bc696b097b8f0fb12574 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:12:40 -0700 Subject: [PATCH 002/546] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py --- ...while_loop_simple_add_dispatch_in_torch.py | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index a76197cc736..e4c5218ccf5 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -18,6 +18,19 @@ def _fake_while_loop(cond_fn, body_fn, operands): operands = body_fn(*operands) return operands +def _fake_fori_loop(lower, upper, body_fun, *init_val): + # operands need to be more than one here + # print("upper - lower: ", upper - lower) + # print("init_val: ", init_val) + # print("type init_val: ", type(init_val)) + (a, b) = init_val + # print("a: ", a) + # print("b: ", b) + for i in range((upper - lower)[0]): + a = body_fun(a, b) + # print("a: ", a) + # print("i: ", i) + return a def _fake_fori_loop(lower, upper, body_fun, *init_val): (plus_value, init_val) = init_val @@ -64,23 +77,23 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) - def test_while_loop_tpu_subtraction_nested(self): + def test_fori_loop_tpu_addition(self): + xm.mark_step() device = xm.xla_device() - def cond_fn(init, limit_value): - return limit_value[0] <= init[0] + lower = torch.tensor([2], dtype=torch.int32, device=device) + upper = torch.tensor([52], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) - def body_fn(init, limit_value): - one_value = torch.ones(1, dtype=torch.int32, device=device) - two_value = limit_value.clone() - return (torch.sub(torch.sub(init, one_value), one_value), two_value) + def body_fun(a, b): + return torch.add(a, b) - init = torch.tensor([10], dtype=torch.int32, device=device) - limit_value = torch.tensor([0], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) - expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) - self.assertEqual(expected, res) + lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) + expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) + print("expected: ", expected) + self.assertEqual(expected, res_) def test_fori_loop_tpu_addition(self): From a2f7062689c98d538c7971618411438b94f8a995 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:14:46 -0700 Subject: [PATCH 003/546] Update init_python_bindings.cpp --- torch_xla/csrc/init_python_bindings.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c603e5d27a5..8890296a61f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -889,7 +889,17 @@ class PyLoweringContext { : lowering_ctx("PyLoweringContext", device) {} // Builds a HLO graph given a set of output tensors. - void Build(std::vector tensors) { + void Build(std::vector tensors, std::vector input_arguments) { + if (GetNameString() == "condctx") { + xla::XlaBuilder* local_builder = lowering_ctx.builder(); + int64_t parameters_number_i = 2; + for (at::Tensor input_argument : input_arguments) { + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); + xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, "UnusedArgumentsPlaceholder"); + parameters_number_i = parameters_number_i + 1; + } + } + // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = GetXlaTensors(tensors, /*want_all=*/true); From bd4ff83036211feb9fd2f106d78de6b46efb8b05 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:15:18 -0700 Subject: [PATCH 004/546] Update fori_loop.py --- torch_xla/experimental/fori_loop.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index bf32a712f3e..3533949fccd 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -11,6 +11,28 @@ import torch._higher_order_ops.while_loop from torch._higher_order_ops.while_loop import while_loop_op +def fori_loop(lower, upper, body_fun, one_value, init_val): + + device = xm.xla_device() + + def cond_fn(upper, lower, x): + return lower[0] < upper[0] + + def body_fn(upper, lower, x): + one_value = torch.ones(1, dtype=torch.int32, device=device) + return (torch.sub(upper, one_value), lower, body_fun(one_value, x)) + + def old_cond_fn(one_value, lower, upper, init_val): + lower_compare = torch.add(lower, one_value) + return lower_compare[0] <= upper[0] + + def old_body_fn(one_value, lower, upper, init_val): + new_lower = torch.add(lower, one_value) + new_init_val = body_fun(init_val, one_value) + return (one_value, new_lower, upper, new_init_val) + + res = _xla_while_loop(cond_fn, body_fn, lower, upper, init_val) + return res def fori_loop(lower, upper, user_body_func, *init_val): From 5765405e9ea751b5f2d786140d20a3e8700b6c44 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:26:59 -0700 Subject: [PATCH 005/546] format --- torch_xla/csrc/init_python_bindings.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 8890296a61f..c876e56bd8e 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -889,13 +889,16 @@ class PyLoweringContext { : lowering_ctx("PyLoweringContext", device) {} // Builds a HLO graph given a set of output tensors. - void Build(std::vector tensors, std::vector input_arguments) { + void Build(std::vector tensors, + std::vector input_arguments) { if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameters_number_i = 2; for (at::Tensor input_argument : input_arguments) { - xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); - xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, "UnusedArgumentsPlaceholder"); + xla::Shape shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); + xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, + "UnusedArgumentsPlaceholder"); parameters_number_i = parameters_number_i + 1; } } From b518a1386df8fcc888599fb705b575fc311e9403 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:29:51 -0700 Subject: [PATCH 006/546] format --- ...fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index e4c5218ccf5..727ec9dadb8 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -18,6 +18,7 @@ def _fake_while_loop(cond_fn, body_fn, operands): operands = body_fn(*operands) return operands + def _fake_fori_loop(lower, upper, body_fun, *init_val): # operands need to be more than one here # print("upper - lower: ", upper - lower) @@ -38,7 +39,6 @@ def _fake_fori_loop(lower, upper, body_fun, *init_val): plus_value, init_val = body_fun(plus_value, init_val) return init_val - class WhileLoopTest(unittest.TestCase): def test_while_loop_tpu_subtraction(self): @@ -90,7 +90,8 @@ def test_fori_loop_tpu_addition(self): def body_fun(a, b): return torch.add(a, b) - lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) + lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, + init_val) expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) print("expected: ", expected) self.assertEqual(expected, res_) From aacb407c4441b20dcdefa809329136a97686ce39 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:32:22 -0700 Subject: [PATCH 007/546] format --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3533949fccd..d34e4e8855d 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -11,6 +11,7 @@ import torch._higher_order_ops.while_loop from torch._higher_order_ops.while_loop import while_loop_op + def fori_loop(lower, upper, body_fun, one_value, init_val): device = xm.xla_device() @@ -22,7 +23,7 @@ def body_fn(upper, lower, x): one_value = torch.ones(1, dtype=torch.int32, device=device) return (torch.sub(upper, one_value), lower, body_fun(one_value, x)) - def old_cond_fn(one_value, lower, upper, init_val): + def old_cond_fn(one_value, lower, upper, init_val): lower_compare = torch.add(lower, one_value) return lower_compare[0] <= upper[0] @@ -51,7 +52,6 @@ def body_fn(upper, lower, *init_val): res = while_loop(cond_fn, body_fn, (lower, upper, *init_val)) return res - @while_loop_op.py_impl(DispatchKey.XLA) def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') From ff115afa9a783e3876f49e2cb1fa4ccc0ab592c2 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Fri, 29 Mar 2024 00:07:57 -0700 Subject: [PATCH 008/546] Update init_python_bindings.cpp --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c876e56bd8e..25ffae673f6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -890,7 +890,7 @@ class PyLoweringContext { // Builds a HLO graph given a set of output tensors. void Build(std::vector tensors, - std::vector input_arguments) { + std::vector input_arguments = std::vector::empty) { if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameters_number_i = 2; From b53e87faf11a7521a59fcea33cc28afa68cae64a Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Fri, 29 Mar 2024 01:27:13 -0700 Subject: [PATCH 009/546] Update init_python_bindings.cpp --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 25ffae673f6..2b0221cb0a1 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -890,7 +890,7 @@ class PyLoweringContext { // Builds a HLO graph given a set of output tensors. void Build(std::vector tensors, - std::vector input_arguments = std::vector::empty) { + std::vector input_arguments = {}) { if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameters_number_i = 2; From 27639c487d503e184cb85f72de25d7e579dc9f14 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 29 Mar 2024 17:47:17 +0000 Subject: [PATCH 010/546] test formal change --- ...p_with_while_loop_simple_add_dispatch_in_torch.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 727ec9dadb8..3578434bef3 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -20,17 +20,9 @@ def _fake_while_loop(cond_fn, body_fn, operands): def _fake_fori_loop(lower, upper, body_fun, *init_val): - # operands need to be more than one here - # print("upper - lower: ", upper - lower) - # print("init_val: ", init_val) - # print("type init_val: ", type(init_val)) (a, b) = init_val - # print("a: ", a) - # print("b: ", b) for i in range((upper - lower)[0]): a = body_fun(a, b) - # print("a: ", a) - # print("i: ", i) return a def _fake_fori_loop(lower, upper, body_fun, *init_val): @@ -55,7 +47,7 @@ def body_fn(init, limit_value): init = torch.tensor([10], dtype=torch.int32, device=device) limit_value = torch.tensor([0], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) + res = while_loop(cond_fn, body_fn, init, limit_value) # (init, limit_value)) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) @@ -73,7 +65,7 @@ def body_fn(init, limit_value): # TODO(@manfei): init and limit_value has to be torch.tensor. init = torch.tensor([0], dtype=torch.int32, device=device) limit_value = torch.tensor([10], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) + res = while_loop(cond_fn, body_fn, init, limit_value) # (init, limit_value)) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) From cf8b7bc8ae5f350d9bd81d8b391fc917ba1a7a2a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 29 Mar 2024 17:48:16 +0000 Subject: [PATCH 011/546] test formal change --- torch_xla/experimental/fori_loop.py | 95 ++++++++++++++--------------- 1 file changed, 45 insertions(+), 50 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index d34e4e8855d..57f7162bf3f 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -12,45 +12,35 @@ from torch._higher_order_ops.while_loop import while_loop_op -def fori_loop(lower, upper, body_fun, one_value, init_val): +# TODO(@manfei): delete one_value? +def fori_loop(upper, lower, body_fun, init_val, *input_value): device = xm.xla_device() + # weight_0 = body_fun.weight + # bias_0 = body_fun.bias + # one_value = torch.tensor([1], dtype=torch.int32, device=device) - def cond_fn(upper, lower, x): + def cond_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): # , bias_0): # , output_value): return lower[0] < upper[0] - def body_fn(upper, lower, x): - one_value = torch.ones(1, dtype=torch.int32, device=device) - return (torch.sub(upper, one_value), lower, body_fun(one_value, x)) - - def old_cond_fn(one_value, lower, upper, init_val): - lower_compare = torch.add(lower, one_value) - return lower_compare[0] <= upper[0] - - def old_body_fn(one_value, lower, upper, init_val): - new_lower = torch.add(lower, one_value) - new_init_val = body_fun(init_val, one_value) - return (one_value, new_lower, upper, new_init_val) - - res = _xla_while_loop(cond_fn, body_fn, lower, upper, init_val) + def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): # , bias_0): + # weight = body_fun.weight + new_lower = torch.add(one_value, lower) ### !!! this matter, torch.add might would change the second argument's value, even we use a new variable to catch the result!!! + output_value = body_fun(*input_value) ### !!! due to the output_value is not actually used here, + # --- !!! its original value would not be used, and it would be replaces by the result of body_fun + # --- !!! so, due to PTXLA is traced from result tensor, so the arguments `output_value` would not be included in the body_xlacomputation + # --- !!! so, we need to modify ini_python_binding.cpp to add a fake arguments in the xlacompputation + weight = body_fun.weight + bias = body_fun.bias + return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value + + output_value = torch.zeros([20], dtype=torch.float32, device=device) + weight_0 = body_fun.weight + bias_0 = body_fun.bias + one_value = torch.tensor([1], dtype=torch.int32, device=device) + res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, weight_0, bias_0, output_value)) return res -def fori_loop(lower, upper, user_body_func, *init_val): - - device = xm.xla_device() - - def cond_fn(upper, lower, *init_val): - return lower[0] < upper[0] - - def body_fn(upper, lower, *init_val): - one_value_i = torch.ones(1, dtype=torch.int32, device=device) - res_list = list(user_body_func(*init_val)) - res_list.insert(0, lower) - res_list.insert(0, torch.sub(upper, one_value_i)) - return res_list - - res = while_loop(cond_fn, body_fn, (lower, upper, *init_val)) - return res @while_loop_op.py_impl(DispatchKey.XLA) def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): @@ -59,8 +49,7 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) if additional_inputs is None: additional_inputs = tuple() - return _xla_while_loop( - cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) + return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): @@ -70,34 +59,27 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): fake_carried_inputs = [] for carried_input in carried_inputs: device = carried_input.device + #TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) fake_carried_inputs = tuple(fake_carried_inputs) - # trans fake_carried_inputs from list(tensor) to list(xla::op) - kwargs = {} - if type(fake_carried_inputs) is tuple: - shapes = xb.tensor_shape(fake_carried_inputs) - else: - shapes = xb.tensor_shape((fake_carried_inputs)) - builder = xb.create_builder('test_while') - params = [] - for shape in shapes: - p = xb.mkparam(builder, len(params), shape) - params.append(p) - # generate cond_fn xlacomputation - cond_result = cond_fn(*fake_carried_inputs) + # TODO(@manfei): specify which element is for which argument like a,b,c + cond_result = cond_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") - cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:])) + additional_inputs_list = list(fake_carried_inputs[2:]) + for i in range(len(additional_inputs)): + additional_inputs_list.append(additional_inputs[0]) + cond_ctx.buildforiloop([cond_result], additional_inputs_list) cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) # generate body_fn xlacomputation - body_result = body_fn(*fake_carried_inputs) + body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") body_ctx.buildforiloop(list(body_result), []) @@ -105,6 +87,18 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) + # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while + kwargs = {} + if type(carried_inputs) is tuple: + shapes = xb.tensor_shape(carried_inputs) + else: + shapes = xb.tensor_shape((carried_inputs)) + builder = xb.create_builder('test_while') + params = [] + for shape in shapes: + p = xb.mkparam(builder, len(params), shape) + params.append(p) + # generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) w = xb.mkop( @@ -116,6 +110,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', - (carried_inputs), computation) + (carried_inputs), + computation) return result \ No newline at end of file From 945ab7ae07e271342c74fdb8173fcb54c35c4e51 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 29 Mar 2024 17:52:16 +0000 Subject: [PATCH 012/546] test formal change --- ...ri_loop_with_while_loop_simple_add_dispatch_in_torch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 3578434bef3..928e70ffc7e 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -55,10 +55,12 @@ def test_while_loop_tpu_addition(self): device = xm.xla_device() - def cond_fn(init, limit_value): + def cond_fn(loop_carry): # init, limit_value): + init, limit_value = loop_carry return limit_value[0] >= init[0] - def body_fn(init, limit_value): + def body_fn(loop_carry): # init, limit_value): + init, limit_value = loop_carry one_value = torch.ones(1, dtype=torch.int32, device=device) return (torch.add(init, one_value), limit_value.clone()) @@ -85,7 +87,6 @@ def body_fun(a, b): lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) - print("expected: ", expected) self.assertEqual(expected, res_) def test_fori_loop_tpu_addition(self): From f561a127a9cff04cacf96b4cfa176ebdb5c5fb29 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 29 Mar 2024 17:52:59 +0000 Subject: [PATCH 013/546] test formal change --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 928e70ffc7e..b418d0f0ba6 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -47,7 +47,7 @@ def body_fn(init, limit_value): init = torch.tensor([10], dtype=torch.int32, device=device) limit_value = torch.tensor([0], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, init, limit_value) # (init, limit_value)) + res = while_loop(cond_fn, body_fn, (init, limit_value)) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) @@ -67,7 +67,7 @@ def body_fn(loop_carry): # init, limit_value): # TODO(@manfei): init and limit_value has to be torch.tensor. init = torch.tensor([0], dtype=torch.int32, device=device) limit_value = torch.tensor([10], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, init, limit_value) # (init, limit_value)) + res = while_loop(cond_fn, body_fn, (init, limit_value)) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) From 834cb251c4972153c9ddb4b644246f48f1ba54b8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 29 Mar 2024 18:24:04 +0000 Subject: [PATCH 014/546] test formal change --- ...ori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index b418d0f0ba6..55a02a55e48 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -55,12 +55,10 @@ def test_while_loop_tpu_addition(self): device = xm.xla_device() - def cond_fn(loop_carry): # init, limit_value): - init, limit_value = loop_carry + def cond_fn(init, limit_value): return limit_value[0] >= init[0] - def body_fn(loop_carry): # init, limit_value): - init, limit_value = loop_carry + def body_fn(init, limit_value): one_value = torch.ones(1, dtype=torch.int32, device=device) return (torch.add(init, one_value), limit_value.clone()) From a9814d2c04bde1729c3a6b015f5462b63024accd Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Mon, 1 Apr 2024 10:57:14 -0700 Subject: [PATCH 015/546] Update test_xla_sharding.cpp --- test/cpp/test_xla_sharding.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index a17031f148e..b59927cdbf7 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -440,11 +440,15 @@ TEST_F(XLAShardingTest, TestForiLoopAddUnusedParameterInXlaComputation) { // Build simple addition. xla::XlaBuilder b("builder"); auto x = xla::Parameter(&b, /*parameter_number=*/0, shape, "p0"); + // Add unused parameter before create xlacomputation xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}); auto zzz = xla::Parameter(&b, /*parameter_number=*/1, shape2, "p1"); auto y = xla::Add(x, xla::ConstantR0(&b, 3)); xla::XlaComputation xla_computation = ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false)); + + // Check whether the unused parameter has been included into xlacomputation or not + EXPECT_EQ(xla_computation.GetProgramShape().parameters_size(), 2); } } // namespace cpp_test From feb39cace9acd0e1cad5384f7cf565aee921e6db Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Mon, 1 Apr 2024 11:49:42 -0700 Subject: [PATCH 016/546] Update test_xla_sharding.cpp --- test/cpp/test_xla_sharding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index b59927cdbf7..55cdc8f1fb4 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -448,7 +448,7 @@ TEST_F(XLAShardingTest, TestForiLoopAddUnusedParameterInXlaComputation) { ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false)); // Check whether the unused parameter has been included into xlacomputation or not - EXPECT_EQ(xla_computation.GetProgramShape().parameters_size(), 2); + EEXPECT_EQ(xla_computation.GetProgramShape()->parameters_size(), 2); } } // namespace cpp_test From 8b2cd86eadde4e4ca7a37b37d93cea770ac35d56 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Mon, 1 Apr 2024 11:52:57 -0700 Subject: [PATCH 017/546] format --- test/cpp/test_xla_sharding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 55cdc8f1fb4..e27be283f66 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -447,7 +447,7 @@ TEST_F(XLAShardingTest, TestForiLoopAddUnusedParameterInXlaComputation) { xla::XlaComputation xla_computation = ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false)); - // Check whether the unused parameter has been included into xlacomputation or not + // Check whether the unused parameter has been included into xlacomputation EEXPECT_EQ(xla_computation.GetProgramShape()->parameters_size(), 2); } From cc1e7ef3a5d19f4c679c63549e762d87b8c9abdf Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Mon, 1 Apr 2024 11:56:01 -0700 Subject: [PATCH 018/546] format --- test/cpp/test_xla_sharding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index e27be283f66..167ffd753e7 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -448,7 +448,7 @@ TEST_F(XLAShardingTest, TestForiLoopAddUnusedParameterInXlaComputation) { ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false)); // Check whether the unused parameter has been included into xlacomputation - EEXPECT_EQ(xla_computation.GetProgramShape()->parameters_size(), 2); + EXPECT_EQ(xla_computation.GetProgramShape()->parameters_size(), 2); } } // namespace cpp_test From 1c54e4b03552c21b4a9dd66002a36253e9050e2d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 2 Apr 2024 18:07:15 +0000 Subject: [PATCH 019/546] upstream --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 55a02a55e48..b92d8e0ee44 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -47,7 +47,7 @@ def body_fn(init, limit_value): init = torch.tensor([10], dtype=torch.int32, device=device) limit_value = torch.tensor([0], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) + res = while_loop(cond_fn, body_fn, init, limit_value) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) From cb3a5f495d0ba41cdc4081b558945a0b87c6db54 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 2 Apr 2024 18:09:11 +0000 Subject: [PATCH 020/546] upstream --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index b92d8e0ee44..55a02a55e48 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -47,7 +47,7 @@ def body_fn(init, limit_value): init = torch.tensor([10], dtype=torch.int32, device=device) limit_value = torch.tensor([0], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, init, limit_value) + res = while_loop(cond_fn, body_fn, (init, limit_value)) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) From 9007442dba389063ae79e0a0ea51cbdbaca25697 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 2 Apr 2024 21:35:42 +0000 Subject: [PATCH 021/546] upstream --- torch_xla/experimental/fori_loop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 57f7162bf3f..da18f2e8683 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -52,7 +52,9 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, -def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): +def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): + print("carried_inputs: ", carried_inputs) + print("additional_inputs: ", additional_inputs) # untuple carried_inputs from while_loop carried_inputs = carried_inputs[0] # fake carried_inputs to split formal code From 5394b464510398cdd120a1e31c9ef1b747a68346 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 2 Apr 2024 23:12:47 +0000 Subject: [PATCH 022/546] test --- ...fori_loop_simple_linear_model_test_code.py | 193 ++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 test/test_fori_loop_simple_linear_model_test_code.py diff --git a/test/test_fori_loop_simple_linear_model_test_code.py b/test/test_fori_loop_simple_linear_model_test_code.py new file mode 100644 index 00000000000..006b602f001 --- /dev/null +++ b/test/test_fori_loop_simple_linear_model_test_code.py @@ -0,0 +1,193 @@ +import os +# import unittest +# from typing import Callable, Dict, List + +import torch +import torch_xla +# We need to import the underlying implementation function to register with the dispatcher +import torch_xla.experimental.fori_loop +from torch_xla.experimental.fori_loop import fori_loop +# from torch._higher_order_ops.while_loop import while_loop +import torch_xla.core.xla_model as xm +# import torch_xla.core.xla_builder as xb +import torch_xla.utils.utils as xu + +torch.set_grad_enabled(False) + +device = xm.xla_device() + +# --- linear one --- +# l_in = torch.randn(10, device=xm.xla_device()) +# linear = torch.nn.Linear(10, 20).to(xm.xla_device()) +# l_out = linear(l_in) +# print("linear one: ", l_out) + +# --- while test case --- + +lower = torch.tensor([2], dtype=torch.int32, device=device) +upper = torch.tensor([52], dtype=torch.int32, device=device) +one_value = torch.tensor([1], dtype=torch.int32, device=device) +init_val = torch.tensor([1], dtype=torch.int32, device=device) +# one_one = torch.one(1, dtype=torch.int32, device=device) + +# def body_fun(l_in): +# # l_in = torch.randn(10, device=xm.xla_device()) +# linear = torch.nn.Linear(10, 20).to(xm.xla_device()) +# # l_out = linear(l_in) +# return linear(l_in) # torch.add(a, b) # [0]) +linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + +def body_fun(y, x, l_in_i): + # l_in = torch.randn(10, device=xm.xla_device()) + # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + l_out = linear_0(l_in_i) + # placeholder_func = torch.rand(size = l_out.size(), device = device) + # placeholder_input = torch.rand(size = l_in_i.size(), device = device) + # return torch.add(y, x), l_out, placeholder_func, placeholder_input # linear_0(l_in_i), linear_0, l_in_i # additional return: body and input-placeholder # linear(l_in) # torch.add(a, b) # [0]) + return torch.add(y, x), l_out + +# TODO(@manfei), need to create new variable to seperate old/formal HLO/IR +l_in_0 = torch.randn(10, device=xm.xla_device()) + +# def body_fun(x, y, l_in): +# # l_in = torch.randn(10, device=xm.xla_device()) +# linear = torch.nn.Linear(10, 20).to(xm.xla_device()) +# # l_out = linear(l_in) +# return torch.add(x, y), linear(l_in) # linear(l_in) # torch.add(a, b) # [0]) + +# placeholder_func = torch.rand(size = l_out.size(), device = device) +# placeholder_input = torch.rand(size = l_in_i.size(), device = device) +print("test code, body_fun: ", body_fun) + +lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val, l_in_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0) + +print("lower_: ", lower_) +print("upper_: ", upper_) +print("res_: ", res_) + +# --- linear two --- +# l_in_2 = torch.randn(10, device=xm.xla_device()) +# linear_2 = torch.nn.Linear(10, 20).to(xm.xla_device()) +# l_out_2 = linear(l_in_2) +# print("linear two: ", l_out_2) + +# ================================================================================= + +# import numpy as np +# # create dummy data for training +# # x_values = [i for i in range(11)] +# # x_train = np.array(x_values, dtype=np.float32) +# # x_train = x_train.reshape(-1, 1) + +# # y_values = [2*i + 1 for i in x_values] +# # y_train = np.array(y_values, dtype=np.float32) +# # y_train = y_train.reshape(-1, 1) + +# batch_size = 2 + +# train_loader = xu.SampleGenerator( +# data=(torch.zeros(batch_size, 1), torch.zeros(batch_size, dtype=torch.float32)), +# sample_count=64 // batch_size // xm.xrt_world_size()) +# test_loader = xu.SampleGenerator( +# data=(torch.zeros(batch_size, 1, torch.zeros(batch_size, dtype=torch.float32)), +# sample_count=32 // batch_size // xm.xrt_world_size()) + +# # import torch +# from torch.autograd import Variable + +# class linearRegression(torch.nn.Module): +# def __init__(self, inputSize, outputSize): +# super(linearRegression, self).__init__() +# self.linear = torch.nn.Linear(inputSize, outputSize).to(device) + +# def forward(self, x): +# out = self.linear(x) +# return out + +# # --- training --- +# inputDim = 1 # takes variable 'x' +# outputDim = 1 # takes variable 'y' +# learningRate = 0.01 * xm.xrt_world_size() +# epochs = 10 # 100 + +# model = linearRegression(inputDim, outputDim).to(device) +# # model = MNIST().to(device) +# ##### For GPU ####### +# # if torch.cuda.is_available(): +# # model.cuda() + +# if xr.using_pjrt(): +# xm.broadcast_master_param(model) + +# criterion = torch.nn.MSELoss() +# optimizer = torch.optim.SGD(model.parameters(), lr=learningRate) + +# for epoch in range(epochs): +# # Converting inputs and labels to Variable +# # if torch.cuda.is_available(): +# # inputs = Variable(torch.from_numpy(x_train).cuda()) +# # labels = Variable(torch.from_numpy(y_train).cuda()) +# # else: +# inputs = Variable(torch.from_numpy(x_train)).to(device) +# labels = Variable(torch.from_numpy(y_train)).to(device) + +# # Clear gradient buffers because we don't want any gradient from previous epoch to carry forward, dont want to cummulate gradients +# optimizer.zero_grad() + +# # get output from the model, given the inputs +# outputs = model(inputs) + +# # get loss for the predicted output +# loss = criterion(outputs, labels) +# print(loss) +# # get gradients w.r.t to parameters +# loss.backward() + +# # update parameters +# # optimizer.step() +# xm.optimizer_step(optimizer) + +# print('epoch {}, loss {}'.format(epoch, loss.item())) + +# # --- while simple test case --- + +# # device = xm.xla_device() + +# lower = torch.tensor([2], dtype=torch.int32, device=device) +# upper = torch.tensor([52], dtype=torch.int32, device=device) +# one_value = torch.tensor([1], dtype=torch.int32, device=device) +# init_val = torch.tensor([1], dtype=torch.int32, device=device) + +# def body_fun(a, b): +# return torch.add(a, b) # [0]) + +# lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) + +# print("lower_: ", lower_) +# print("upper_: ", upper_) +# print("res_: ", res_) + +# # --- test --- +# for epoch in range(epochs): +# with torch.no_grad(): # we don't need gradients in the testing phase +# if torch.cuda.is_available(): +# predicted = model(Variable(torch.from_numpy(x_train).cuda())).cpu().data.numpy() +# else: +# predicted = model(Variable(torch.from_numpy(x_train))).data.numpy() +# print(epoch, "-th prediction finised") # ed result: ", predicted) + +# print("do one more prediction") +# with torch.no_grad(): # we don't need gradients in the testing phase +# if torch.cuda.is_available(): +# predicted = model(Variable(torch.from_numpy(x_train).cuda())).cpu().data.numpy() +# else: +# predicted = model(Variable(torch.from_numpy(x_train))).data.numpy() +# print(predicted) +# print("finished one more prediction") + +# # --- draw --- +# # plt.clf() +# # plt.plot(x_train, y_train, 'go', label='True data', alpha=0.5) +# # plt.plot(x_train, predicted, '--', label='Predictions', alpha=0.5) +# # plt.legend(loc='best') +# # plt.show() \ No newline at end of file From 972bc82f870e26c086670902c8acabf8f08e6938 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 2 Apr 2024 23:27:00 +0000 Subject: [PATCH 023/546] test --- torch_xla/experimental/fori_loop.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index da18f2e8683..cb8866e78ad 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -79,6 +79,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) + cond_hlo_print = xb.get_computation_hlo(cond_computation) + print("cond computation: !!!!!!!!!") + print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) @@ -88,6 +91,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) + body_hlo_print = xb.get_computation_hlo(body_computation) + print("body computation: !!!!!!!!!") + print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while kwargs = {} @@ -109,6 +115,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) + hlo_print = xb.get_computation_hlo(computation) + print("while computation: !!!!!!!!!") + print(hlo_print) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', From 518bdbc124273a6fc547246faab259a7d981b249 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 05:59:59 +0000 Subject: [PATCH 024/546] test --- torch_xla/csrc/init_python_bindings.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 2b0221cb0a1..a7bed9e0bf9 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -901,6 +901,20 @@ class PyLoweringContext { "UnusedArgumentsPlaceholder"); parameters_number_i = parameters_number_i + 1; } + // hard-code to meet requirement + // f32[20], /*index=5*/f32[20,10], s32[10] + parameters_number_i = parameters_number_i + 1; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, + "OutPutTensor"); + parameters_number_i = parameters_number_i + 1; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); + xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, + "WeightTensor"); + parameters_number_i = parameters_number_i + 1; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); + xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, + "FinalOneTensor"); } // Get the backing XLA tensors from the output torch tensor handles From 16977fdc46ffca76b3157f810b14ad4403a2c93a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 06:03:11 +0000 Subject: [PATCH 025/546] test --- torch_xla/csrc/init_python_bindings.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a7bed9e0bf9..36c0e93fd23 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -904,16 +904,16 @@ class PyLoweringContext { // hard-code to meet requirement // f32[20], /*index=5*/f32[20,10], s32[10] parameters_number_i = parameters_number_i + 1; - xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, + xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, "OutPutTensor"); parameters_number_i = parameters_number_i + 1; - xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); - xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, + xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); + xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, "WeightTensor"); parameters_number_i = parameters_number_i + 1; - xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); - xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, + xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); + xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, "FinalOneTensor"); } From ab00e2ba0ecf23f6c113fbde2438147ccbd7b320 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 06:35:29 +0000 Subject: [PATCH 026/546] test --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index cb8866e78ad..715a94cd4be 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -66,6 +66,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) fake_carried_inputs = tuple(fake_carried_inputs) + print("fake_carried_inputs: ", fake_carried_inputs) # generate cond_fn xlacomputation # TODO(@manfei): specify which element is for which argument like a,b,c From 33f9e070f10160a775d210ce284fe05dde7c3a8d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 06:48:24 +0000 Subject: [PATCH 027/546] test --- torch_xla/csrc/init_python_bindings.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 36c0e93fd23..3367968de1a 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -894,7 +894,8 @@ class PyLoweringContext { if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameters_number_i = 2; - for (at::Tensor input_argument : input_arguments) { + // for (at::Tensor input_argument : input_arguments) { + for (int i = 0; i < 2; i++) { xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, From 5fd313ae8ba70190777cd0a5b663f5a77e1ac632 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 06:51:21 +0000 Subject: [PATCH 028/546] test --- torch_xla/csrc/init_python_bindings.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 3367968de1a..0a6ea4af3d7 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -895,13 +895,13 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameters_number_i = 2; // for (at::Tensor input_argument : input_arguments) { - for (int i = 0; i < 2; i++) { - xla::Shape shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); - xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, - "UnusedArgumentsPlaceholder"); - parameters_number_i = parameters_number_i + 1; - } + // for (int i = 0; i < 2; i++) { + // xla::Shape shape = + // xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); + // xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, + // "UnusedArgumentsPlaceholder"); + // parameters_number_i = parameters_number_i + 1; + // } // hard-code to meet requirement // f32[20], /*index=5*/f32[20,10], s32[10] parameters_number_i = parameters_number_i + 1; From a6ba3b892065178d6f7181df5455c6a465671179 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 06:53:47 +0000 Subject: [PATCH 029/546] test --- torch_xla/csrc/init_python_bindings.cpp | 38 ++++++++++++------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 0a6ea4af3d7..dbbcf1ecf57 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -894,28 +894,28 @@ class PyLoweringContext { if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameters_number_i = 2; - // for (at::Tensor input_argument : input_arguments) { + for (at::Tensor input_argument : input_arguments) { // for (int i = 0; i < 2; i++) { - // xla::Shape shape = - // xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); - // xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, - // "UnusedArgumentsPlaceholder"); - // parameters_number_i = parameters_number_i + 1; - // } + xla::Shape shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); + xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, + "UnusedArgumentsPlaceholder"); + parameters_number_i = parameters_number_i + 1; + } // hard-code to meet requirement // f32[20], /*index=5*/f32[20,10], s32[10] - parameters_number_i = parameters_number_i + 1; - xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, - "OutPutTensor"); - parameters_number_i = parameters_number_i + 1; - xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); - xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, - "WeightTensor"); - parameters_number_i = parameters_number_i + 1; - xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); - xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, - "FinalOneTensor"); + // parameters_number_i = parameters_number_i + 1; + // xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + // xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, + // "OutPutTensor"); + // parameters_number_i = parameters_number_i + 1; + // xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); + // xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, + // "WeightTensor"); + // parameters_number_i = parameters_number_i + 1; + // xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); + // xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, + // "FinalOneTensor"); } // Get the backing XLA tensors from the output torch tensor handles From a09457edeb987a5c96a47bd4329f97b28546700b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 07:08:46 +0000 Subject: [PATCH 030/546] test --- torch_xla/csrc/init_python_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index dbbcf1ecf57..140c6755b49 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -894,8 +894,8 @@ class PyLoweringContext { if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameters_number_i = 2; - for (at::Tensor input_argument : input_arguments) { - // for (int i = 0; i < 2; i++) { + // for (at::Tensor input_argument : input_arguments) { + for (int i = 0; i < 5; i++) { xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, From 8ebc7721b8462a7df93dc8ab523c3b2f1b591c34 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 07:12:50 +0000 Subject: [PATCH 031/546] test --- torch_xla/csrc/init_python_bindings.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 140c6755b49..04517683ac2 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -895,7 +895,7 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameters_number_i = 2; // for (at::Tensor input_argument : input_arguments) { - for (int i = 0; i < 5; i++) { + for (int i = 0; i < 4; i++) { xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, @@ -904,10 +904,10 @@ class PyLoweringContext { } // hard-code to meet requirement // f32[20], /*index=5*/f32[20,10], s32[10] - // parameters_number_i = parameters_number_i + 1; - // xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - // xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, - // "OutPutTensor"); + parameters_number_i = parameters_number_i + 1; + xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, + "OutPutTensor"); // parameters_number_i = parameters_number_i + 1; // xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); // xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, From 8bfa5583c14b757c59a4847e4b1eeca9feb06017 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 07:18:06 +0000 Subject: [PATCH 032/546] test --- torch_xla/csrc/init_python_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 04517683ac2..ab1103ad373 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -904,11 +904,11 @@ class PyLoweringContext { } // hard-code to meet requirement // f32[20], /*index=5*/f32[20,10], s32[10] - parameters_number_i = parameters_number_i + 1; + // parameters_number_i = parameters_number_i + 1; xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, "OutPutTensor"); - // parameters_number_i = parameters_number_i + 1; + parameters_number_i = parameters_number_i + 1; // xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); // xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, // "WeightTensor"); From 7ed983c93292bade5701df22cfd2db436973f1f3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 07:22:02 +0000 Subject: [PATCH 033/546] test --- torch_xla/csrc/init_python_bindings.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ab1103ad373..588a213e197 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -895,7 +895,7 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameters_number_i = 2; // for (at::Tensor input_argument : input_arguments) { - for (int i = 0; i < 4; i++) { + for (int i = 0; i < 3; i++) { xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, @@ -909,10 +909,10 @@ class PyLoweringContext { xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, "OutPutTensor"); parameters_number_i = parameters_number_i + 1; - // xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); - // xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, - // "WeightTensor"); - // parameters_number_i = parameters_number_i + 1; + xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); + xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, + "WeightTensor"); + parameters_number_i = parameters_number_i + 1; // xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); // xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, // "FinalOneTensor"); From 608191653023bfd2916ecb5ea044cd5daa44c325 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 07:22:19 +0000 Subject: [PATCH 034/546] test --- torch_xla/csrc/init_python_bindings.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 588a213e197..2c7709bc2e8 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -895,7 +895,7 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameters_number_i = 2; // for (at::Tensor input_argument : input_arguments) { - for (int i = 0; i < 3; i++) { + for (int i = 0; i < 2; i++) { xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, @@ -913,9 +913,9 @@ class PyLoweringContext { xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, "WeightTensor"); parameters_number_i = parameters_number_i + 1; - // xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); - // xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, - // "FinalOneTensor"); + xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); + xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, + "FinalOneTensor"); } // Get the backing XLA tensors from the output torch tensor handles From 9456f7b3bafeeecaed8948e0c95966bd4978a7b1 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 07:41:30 +0000 Subject: [PATCH 035/546] test --- torch_xla/csrc/init_python_bindings.cpp | 13 +++++++++++++ torch_xla/csrc/lowering_context.cpp | 2 ++ 2 files changed, 15 insertions(+) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 2c7709bc2e8..42e1205f95b 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -692,6 +692,19 @@ std::vector XlaUserComputation( runtime::ComputationClient::ComputationPtr CreateComputation( const std::string& name, xla::XlaOp root) { + xla::XlaBuilder* local_builder = root.builder(); + int64_t parameters_number_i = 4; + xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, + "OutPutTensor"); + parameters_number_i = parameters_number_i + 1; + xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); + xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, + "WeightTensor"); + parameters_number_i = parameters_number_i + 1; + xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); + xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, + "FinalOneTensor"); xla::XlaComputation computation = ConsumeValue(root.builder()->Build(root)); return std::make_shared( name, std::move(computation)); diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index a530995ca78..39f82a4887b 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -160,6 +160,8 @@ xla::StatusOr LoweringContext::BuildXla() { if (!root_tuple_.empty() & (root_tuple_.size() == 1) & ((get_name_string() == "condctx") or (get_name_string() == "bodyctx"))) { xla = builder()->Build(root_tuple_.at(0)); + // } else if (!root_tuple_.empty() & (root_tuple_.size() == 1) & ) { + // xla = builder()->Build(root_tuple_.at(0)); } else if (!root_tuple_.empty()) { xla::XlaOp root = xla::Tuple(builder(), root_tuple_); xla = builder()->Build(root); From 68199fb8e219fe7ea8da757dad7c67084786b81c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 3 Apr 2024 23:52:59 +0000 Subject: [PATCH 036/546] test --- torch_xla/csrc/init_python_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 42e1205f95b..958149b677b 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -702,7 +702,7 @@ runtime::ComputationClient::ComputationPtr CreateComputation( xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, "WeightTensor"); parameters_number_i = parameters_number_i + 1; - xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); + xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, "FinalOneTensor"); xla::XlaComputation computation = ConsumeValue(root.builder()->Build(root)); @@ -926,7 +926,7 @@ class PyLoweringContext { xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, "WeightTensor"); parameters_number_i = parameters_number_i + 1; - xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); + xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, "FinalOneTensor"); } From f69403838133703e92f06855244fc0c6fa3f9be6 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 4 Apr 2024 00:23:44 +0000 Subject: [PATCH 037/546] test --- torch_xla/csrc/init_python_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 958149b677b..5f4241e603d 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -915,7 +915,7 @@ class PyLoweringContext { "UnusedArgumentsPlaceholder"); parameters_number_i = parameters_number_i + 1; } - // hard-code to meet requirement + // hard-code to meet requirement by change cond xlacomputation // f32[20], /*index=5*/f32[20,10], s32[10] // parameters_number_i = parameters_number_i + 1; xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); @@ -926,7 +926,7 @@ class PyLoweringContext { xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, "WeightTensor"); parameters_number_i = parameters_number_i + 1; - xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); + xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, "FinalOneTensor"); } From f8b4cb1103b90603dd52857a067057b3881d06d6 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 4 Apr 2024 00:28:14 +0000 Subject: [PATCH 038/546] test --- torch_xla/csrc/init_python_bindings.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 5f4241e603d..a83824e856b 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -693,18 +693,18 @@ std::vector XlaUserComputation( runtime::ComputationClient::ComputationPtr CreateComputation( const std::string& name, xla::XlaOp root) { xla::XlaBuilder* local_builder = root.builder(); - int64_t parameters_number_i = 4; - xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, - "OutPutTensor"); - parameters_number_i = parameters_number_i + 1; - xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); - xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, - "WeightTensor"); - parameters_number_i = parameters_number_i + 1; - xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); - xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, - "FinalOneTensor"); + // int64_t parameters_number_i = 4; + // xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + // xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, + // "OutPutTensor"); + // parameters_number_i = parameters_number_i + 1; + // xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); + // xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, + // "WeightTensor"); + // parameters_number_i = parameters_number_i + 1; + // xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); + // xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, + // "FinalOneTensor"); xla::XlaComputation computation = ConsumeValue(root.builder()->Build(root)); return std::make_shared( name, std::move(computation)); From 0e16a6ed8355b7b4c817601e8d4b1c6cd7934b8f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 4 Apr 2024 01:09:25 +0000 Subject: [PATCH 039/546] test --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 715a94cd4be..9ec4d2923b8 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -61,6 +61,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): fake_carried_inputs = [] for carried_input in carried_inputs: device = carried_input.device + print("type carried_input: ", type(carried_input)) #TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, carried_input.size(), From ce49ba9d18315a7679fdeead6d62fbb5d4ed8cc3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 4 Apr 2024 01:11:11 +0000 Subject: [PATCH 040/546] test --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 9ec4d2923b8..69c7871c052 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -61,7 +61,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): fake_carried_inputs = [] for carried_input in carried_inputs: device = carried_input.device - print("type carried_input: ", type(carried_input)) + print("type carried_input: ", carried_input.type) #TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, carried_input.size(), From 58967a88e710dac2ac9efb58a313906088573b9e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 4 Apr 2024 01:13:15 +0000 Subject: [PATCH 041/546] test --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 69c7871c052..f9cae9f73db 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -62,6 +62,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): for carried_input in carried_inputs: device = carried_input.device print("type carried_input: ", carried_input.type) + print("is torch.int32: ", carried_input.type==torch.int32) #TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, carried_input.size(), From 9b0d8e86863aa0a6063a082a5b133e9ae746f9dc Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 4 Apr 2024 01:14:58 +0000 Subject: [PATCH 042/546] test --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index f9cae9f73db..e72b3791da4 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -61,8 +61,8 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): fake_carried_inputs = [] for carried_input in carried_inputs: device = carried_input.device - print("type carried_input: ", carried_input.type) - print("is torch.int32: ", carried_input.type==torch.int32) + print("type carried_input: ", carried_input.dtype) + print("is torch.int32: ", carried_input.dtype==torch.int32) #TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, carried_input.size(), From 79261794f9915cd3a73be18c3190583af6f80ce2 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 4 Apr 2024 01:15:59 +0000 Subject: [PATCH 043/546] test --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index e72b3791da4..dc93999b5f1 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -61,8 +61,8 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): fake_carried_inputs = [] for carried_input in carried_inputs: device = carried_input.device - print("type carried_input: ", carried_input.dtype) - print("is torch.int32: ", carried_input.dtype==torch.int32) + # print("type carried_input: ", carried_input.dtype) + # print("is torch.int32: ", carried_input.dtype==torch.int32) #TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, carried_input.size(), From b9e2be1ee37528f7936bafe598606e3e6c8884c6 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 4 Apr 2024 01:18:34 +0000 Subject: [PATCH 044/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a83824e856b..b84a6ecac78 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -926,7 +926,7 @@ class PyLoweringContext { xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, "WeightTensor"); parameters_number_i = parameters_number_i + 1; - xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {10}); + xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, "FinalOneTensor"); } From cb41d9b75374daafb1b53b840c7fc0e97d2743f0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 5 Apr 2024 05:58:00 +0000 Subject: [PATCH 045/546] test --- torch_xla/csrc/init_python_bindings.cpp | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b84a6ecac78..4d2f0e905dd 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -966,6 +966,32 @@ class PyLoweringContext { "UnusedArgumentsPlaceholder"); parameter_idx += 1; } + // hard-code to meet requirement by change cond xlacomputation + // f32[20], /*index=5*/f32[20,10], s32[10] + // parameters_number_i = parameters_number_i + 1; + xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, + "BiasTensor"); + parameters_number_i = parameters_number_i + 1; + xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); + xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, + "WeightTensor"); + parameters_number_i = parameters_number_i + 1; + xla::Shape shape4 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); + xla::XlaOp x4 = xla::Parameter(local_builder, parameters_number_i, shape4, + "LInITensor"); + parameters_number_i = parameters_number_i + 1; + xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, + "LOutTensor"); + } + + if (GetNameString() == "bodyctx") { + xla::XlaBuilder* local_builder = lowering_ctx.builder(); + int64_t parameters_number_i = 7; + xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, + "WeightTensor"); } // Get the backing XLA tensors from the output torch tensor handles From 2695cfda8252561f338a9cf7647da10395a50b14 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 10 Apr 2024 19:57:28 +0000 Subject: [PATCH 046/546] test --- test/test_fori_loop_simple_linear_model_test_code.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_fori_loop_simple_linear_model_test_code.py b/test/test_fori_loop_simple_linear_model_test_code.py index 006b602f001..39ffffccf41 100644 --- a/test/test_fori_loop_simple_linear_model_test_code.py +++ b/test/test_fori_loop_simple_linear_model_test_code.py @@ -49,6 +49,8 @@ def body_fun(y, x, l_in_i): # TODO(@manfei), need to create new variable to seperate old/formal HLO/IR l_in_0 = torch.randn(10, device=xm.xla_device()) +print("body_fun.weight: ", body_fun.weight) +print("body_fun.weight_: ", body_fun.weight_) # def body_fun(x, y, l_in): # # l_in = torch.randn(10, device=xm.xla_device()) # linear = torch.nn.Linear(10, 20).to(xm.xla_device()) From 7dd5843241cef4c28c9a7abe5b2b79386fd7c772 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 10 Apr 2024 22:38:47 +0000 Subject: [PATCH 047/546] test --- test/test_fori_loop_simple_linear_model_test_code.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_simple_linear_model_test_code.py b/test/test_fori_loop_simple_linear_model_test_code.py index 39ffffccf41..ce03fccdbb8 100644 --- a/test/test_fori_loop_simple_linear_model_test_code.py +++ b/test/test_fori_loop_simple_linear_model_test_code.py @@ -24,8 +24,10 @@ # --- while test case --- -lower = torch.tensor([2], dtype=torch.int32, device=device) -upper = torch.tensor([52], dtype=torch.int32, device=device) +# lower = torch.tensor([2], dtype=torch.int32, device=device) +# upper = torch.tensor([52], dtype=torch.int32, device=device) +lower = torch.tensor([52], dtype=torch.int32, device=device) +upper = torch.tensor([2], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) # one_one = torch.one(1, dtype=torch.int32, device=device) From 535797e186f81161fce08415db1b198886e01b91 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 11 Apr 2024 04:33:51 +0000 Subject: [PATCH 048/546] test --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index dc93999b5f1..4035e207f11 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -94,9 +94,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - body_hlo_print = xb.get_computation_hlo(body_computation) - print("body computation: !!!!!!!!!") - print(body_hlo_print) + # body_hlo_print = xb.get_computation_hlo(body_computation) + # print("body computation: !!!!!!!!!") + # print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while kwargs = {} From 99709006756c3d553bbe85b349d871dc32a57287 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 11 Apr 2024 05:10:37 +0000 Subject: [PATCH 049/546] test --- test/test_fori_loop_simple_linear_model_test_code.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_fori_loop_simple_linear_model_test_code.py b/test/test_fori_loop_simple_linear_model_test_code.py index ce03fccdbb8..066eb34e91e 100644 --- a/test/test_fori_loop_simple_linear_model_test_code.py +++ b/test/test_fori_loop_simple_linear_model_test_code.py @@ -20,7 +20,11 @@ # l_in = torch.randn(10, device=xm.xla_device()) # linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # l_out = linear(l_in) +<<<<<<< HEAD # print("linear one: ", l_out) +======= +# print("$$$ linear one: ", l_out) +>>>>>>> test # --- while test case --- From f3a9df20e2d2034f2b67340144265fd626e2d5ba Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 11 Apr 2024 07:16:58 +0000 Subject: [PATCH 050/546] test --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 4035e207f11..78afdb10802 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -72,7 +72,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): # generate cond_fn xlacomputation # TODO(@manfei): specify which element is for which argument like a,b,c - cond_result = cond_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) + cond_result = cond_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-1]) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list = list(fake_carried_inputs[2:]) @@ -87,7 +87,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): print(cond_hlo_print) # generate body_fn xlacomputation - body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) + body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-1]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") body_ctx.buildforiloop(list(body_result), []) From a671982d582dfd9e8da5717e3e43a208f40435ef Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 11 Apr 2024 20:54:56 +0000 Subject: [PATCH 051/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 4d2f0e905dd..8e1a800b316 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -988,7 +988,7 @@ class PyLoweringContext { if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); - int64_t parameters_number_i = 7; + int64_t parameters_number_i = 6; // 7; xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, "WeightTensor"); From d143181dd8df7c950a6545553f0fd0df92f11f3f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 11 Apr 2024 21:12:56 +0000 Subject: [PATCH 052/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 8e1a800b316..4d2f0e905dd 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -988,7 +988,7 @@ class PyLoweringContext { if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); - int64_t parameters_number_i = 6; // 7; + int64_t parameters_number_i = 7; xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, "WeightTensor"); From 7128f70fe0d49eb9a2b8f6fa29f3ffb90ab63f34 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 11 Apr 2024 21:42:25 +0000 Subject: [PATCH 053/546] test --- test/test_fori_loop_simple_linear_model_test_code.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_simple_linear_model_test_code.py b/test/test_fori_loop_simple_linear_model_test_code.py index 066eb34e91e..58352e5c548 100644 --- a/test/test_fori_loop_simple_linear_model_test_code.py +++ b/test/test_fori_loop_simple_linear_model_test_code.py @@ -12,7 +12,7 @@ # import torch_xla.core.xla_builder as xb import torch_xla.utils.utils as xu -torch.set_grad_enabled(False) +# torch.set_grad_enabled(False) device = xm.xla_device() From b993c091ca9e6ff27f2512c39b83a77ed02d0f71 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 11 Apr 2024 21:52:29 +0000 Subject: [PATCH 054/546] test --- ...oop_with_while_loop_simple_add_dispatch_in_torch.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 55a02a55e48..5f8d7ec01b5 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -20,9 +20,13 @@ def _fake_while_loop(cond_fn, body_fn, operands): def _fake_fori_loop(lower, upper, body_fun, *init_val): - (a, b) = init_val - for i in range((upper - lower)[0]): - a = body_fun(a, b) + if len(init_val) > 1: + (a, b) = init_val + for i in range((upper - lower)[0]): + a = body_fun(a, b) + else: + for i in range((upper - lower)[0]): + a = body_fun(*init_val) return a def _fake_fori_loop(lower, upper, body_fun, *init_val): From 99d2d78fb7fd001498357b080cdb4e92e6dad591 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 11 Apr 2024 21:54:19 +0000 Subject: [PATCH 055/546] test --- ...est_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 5f8d7ec01b5..e1a06b05d7d 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -95,6 +95,7 @@ def test_fori_loop_tpu_addition(self): xm.mark_step() device = xm.xla_device() + torch.set_grad_enabled(False) lower = torch.tensor([2], dtype=torch.int32, device=device) upper = torch.tensor([52], dtype=torch.int32, device=device) From 812c072ae458fb8f4941dfa6917c47fe567d812a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 00:28:39 +0000 Subject: [PATCH 056/546] rebase --- ...fori_loop_simple_linear_model_test_code.py | 187 ++---------------- ...while_loop_simple_add_dispatch_in_torch.py | 24 +-- torch_xla/experimental/fori_loop.py | 20 +- 3 files changed, 30 insertions(+), 201 deletions(-) diff --git a/test/test_fori_loop_simple_linear_model_test_code.py b/test/test_fori_loop_simple_linear_model_test_code.py index 58352e5c548..07a7636d880 100644 --- a/test/test_fori_loop_simple_linear_model_test_code.py +++ b/test/test_fori_loop_simple_linear_model_test_code.py @@ -1,15 +1,11 @@ import os -# import unittest -# from typing import Callable, Dict, List import torch import torch_xla # We need to import the underlying implementation function to register with the dispatcher import torch_xla.experimental.fori_loop from torch_xla.experimental.fori_loop import fori_loop -# from torch._higher_order_ops.while_loop import while_loop import torch_xla.core.xla_model as xm -# import torch_xla.core.xla_builder as xb import torch_xla.utils.utils as xu # torch.set_grad_enabled(False) @@ -17,181 +13,34 @@ device = xm.xla_device() # --- linear one --- -# l_in = torch.randn(10, device=xm.xla_device()) -# linear = torch.nn.Linear(10, 20).to(xm.xla_device()) -# l_out = linear(l_in) -<<<<<<< HEAD -# print("linear one: ", l_out) -======= -# print("$$$ linear one: ", l_out) ->>>>>>> test +l_in = torch.randn(10, device=xm.xla_device()) +linear = torch.nn.Linear(10, 20).to(xm.xla_device()) +l_out = linear(l_in) +print("$$$ different linear model with different weight/bias: ") +print(l_out) # --- while test case --- - -# lower = torch.tensor([2], dtype=torch.int32, device=device) -# upper = torch.tensor([52], dtype=torch.int32, device=device) -lower = torch.tensor([52], dtype=torch.int32, device=device) -upper = torch.tensor([2], dtype=torch.int32, device=device) -one_value = torch.tensor([1], dtype=torch.int32, device=device) +upper = torch.tensor([52], dtype=torch.int32, device=device) +lower = torch.tensor([0], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) -# one_one = torch.one(1, dtype=torch.int32, device=device) - -# def body_fun(l_in): -# # l_in = torch.randn(10, device=xm.xla_device()) -# linear = torch.nn.Linear(10, 20).to(xm.xla_device()) -# # l_out = linear(l_in) -# return linear(l_in) # torch.add(a, b) # [0]) linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) -def body_fun(y, x, l_in_i): - # l_in = torch.randn(10, device=xm.xla_device()) - # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - l_out = linear_0(l_in_i) - # placeholder_func = torch.rand(size = l_out.size(), device = device) - # placeholder_input = torch.rand(size = l_in_i.size(), device = device) - # return torch.add(y, x), l_out, placeholder_func, placeholder_input # linear_0(l_in_i), linear_0, l_in_i # additional return: body and input-placeholder # linear(l_in) # torch.add(a, b) # [0]) - return torch.add(y, x), l_out +# def body_fun(l_in_i): +# l_out = linear_0(l_in_i) +# return l_out -# TODO(@manfei), need to create new variable to seperate old/formal HLO/IR l_in_0 = torch.randn(10, device=xm.xla_device()) -print("body_fun.weight: ", body_fun.weight) -print("body_fun.weight_: ", body_fun.weight_) -# def body_fun(x, y, l_in): -# # l_in = torch.randn(10, device=xm.xla_device()) -# linear = torch.nn.Linear(10, 20).to(xm.xla_device()) -# # l_out = linear(l_in) -# return torch.add(x, y), linear(l_in) # linear(l_in) # torch.add(a, b) # [0]) - -# placeholder_func = torch.rand(size = l_out.size(), device = device) -# placeholder_input = torch.rand(size = l_in_i.size(), device = device) -print("test code, body_fun: ", body_fun) - -lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val, l_in_0) # , placeholder_func, placeholder_input) # , linear_0, l_in_0) - -print("lower_: ", lower_) -print("upper_: ", upper_) -print("res_: ", res_) - -# --- linear two --- -# l_in_2 = torch.randn(10, device=xm.xla_device()) -# linear_2 = torch.nn.Linear(10, 20).to(xm.xla_device()) -# l_out_2 = linear(l_in_2) -# print("linear two: ", l_out_2) - -# ================================================================================= - -# import numpy as np -# # create dummy data for training -# # x_values = [i for i in range(11)] -# # x_train = np.array(x_values, dtype=np.float32) -# # x_train = x_train.reshape(-1, 1) - -# # y_values = [2*i + 1 for i in x_values] -# # y_train = np.array(y_values, dtype=np.float32) -# # y_train = y_train.reshape(-1, 1) - -# batch_size = 2 - -# train_loader = xu.SampleGenerator( -# data=(torch.zeros(batch_size, 1), torch.zeros(batch_size, dtype=torch.float32)), -# sample_count=64 // batch_size // xm.xrt_world_size()) -# test_loader = xu.SampleGenerator( -# data=(torch.zeros(batch_size, 1, torch.zeros(batch_size, dtype=torch.float32)), -# sample_count=32 // batch_size // xm.xrt_world_size()) - -# # import torch -# from torch.autograd import Variable - -# class linearRegression(torch.nn.Module): -# def __init__(self, inputSize, outputSize): -# super(linearRegression, self).__init__() -# self.linear = torch.nn.Linear(inputSize, outputSize).to(device) - -# def forward(self, x): -# out = self.linear(x) -# return out - -# # --- training --- -# inputDim = 1 # takes variable 'x' -# outputDim = 1 # takes variable 'y' -# learningRate = 0.01 * xm.xrt_world_size() -# epochs = 10 # 100 - -# model = linearRegression(inputDim, outputDim).to(device) -# # model = MNIST().to(device) -# ##### For GPU ####### -# # if torch.cuda.is_available(): -# # model.cuda() - -# if xr.using_pjrt(): -# xm.broadcast_master_param(model) - -# criterion = torch.nn.MSELoss() -# optimizer = torch.optim.SGD(model.parameters(), lr=learningRate) - -# for epoch in range(epochs): -# # Converting inputs and labels to Variable -# # if torch.cuda.is_available(): -# # inputs = Variable(torch.from_numpy(x_train).cuda()) -# # labels = Variable(torch.from_numpy(y_train).cuda()) -# # else: -# inputs = Variable(torch.from_numpy(x_train)).to(device) -# labels = Variable(torch.from_numpy(y_train)).to(device) - -# # Clear gradient buffers because we don't want any gradient from previous epoch to carry forward, dont want to cummulate gradients -# optimizer.zero_grad() - -# # get output from the model, given the inputs -# outputs = model(inputs) - -# # get loss for the predicted output -# loss = criterion(outputs, labels) -# print(loss) -# # get gradients w.r.t to parameters -# loss.backward() - -# # update parameters -# # optimizer.step() -# xm.optimizer_step(optimizer) - -# print('epoch {}, loss {}'.format(epoch, loss.item())) - -# # --- while simple test case --- - -# # device = xm.xla_device() - -# lower = torch.tensor([2], dtype=torch.int32, device=device) -# upper = torch.tensor([52], dtype=torch.int32, device=device) -# one_value = torch.tensor([1], dtype=torch.int32, device=device) -# init_val = torch.tensor([1], dtype=torch.int32, device=device) - -# def body_fun(a, b): -# return torch.add(a, b) # [0]) - -# lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) - -# print("lower_: ", lower_) -# print("upper_: ", upper_) -# print("res_: ", res_) +upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_= fori_loop(upper, lower, linear_0, init_val, l_in_0) -# # --- test --- -# for epoch in range(epochs): -# with torch.no_grad(): # we don't need gradients in the testing phase -# if torch.cuda.is_available(): -# predicted = model(Variable(torch.from_numpy(x_train).cuda())).cpu().data.numpy() -# else: -# predicted = model(Variable(torch.from_numpy(x_train))).data.numpy() -# print(epoch, "-th prediction finised") # ed result: ", predicted) +print("$$$ fori_loop l_out_: ") +print(l_out_) -# print("do one more prediction") -# with torch.no_grad(): # we don't need gradients in the testing phase -# if torch.cuda.is_available(): -# predicted = model(Variable(torch.from_numpy(x_train).cuda())).cpu().data.numpy() -# else: -# predicted = model(Variable(torch.from_numpy(x_train))).data.numpy() -# print(predicted) -# print("finished one more prediction") +range_num = upper - lower +for i in range(range_num[0]): + l_out_expected = linear_0(l_in_0) +print("$$$ without-fori_loop l_out_: ") +print(l_out_expected) # # --- draw --- # # plt.clf() diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index e1a06b05d7d..0bc5013438d 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -29,11 +29,6 @@ def _fake_fori_loop(lower, upper, body_fun, *init_val): a = body_fun(*init_val) return a -def _fake_fori_loop(lower, upper, body_fun, *init_val): - (plus_value, init_val) = init_val - for i in range((upper - lower)[0]): - plus_value, init_val = body_fun(plus_value, init_val) - return init_val class WhileLoopTest(unittest.TestCase): @@ -91,25 +86,24 @@ def body_fun(a, b): expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) self.assertEqual(expected, res_) - def test_fori_loop_tpu_addition(self): + def test_fori_loop_tpu_simple_linear(self): xm.mark_step() device = xm.xla_device() torch.set_grad_enabled(False) - lower = torch.tensor([2], dtype=torch.int32, device=device) upper = torch.tensor([52], dtype=torch.int32, device=device) - plus_value = torch.tensor([1], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.randn(10, device=xm.xla_device()) + + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - def body_fun(*argus): - plus_value, init_val = argus - return plus_value, torch.add(plus_value, init_val) - - _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val) - expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val) - self.assertEqual(expected, actual) + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_= fori_loop(upper, lower, linear_0, init_val, l_in_0) + + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + self.assertTrue(torch.all(torch.eq(expected, l_out_))) if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 78afdb10802..57f7162bf3f 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -52,27 +52,22 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, -def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): - print("carried_inputs: ", carried_inputs) - print("additional_inputs: ", additional_inputs) +def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # untuple carried_inputs from while_loop carried_inputs = carried_inputs[0] # fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: device = carried_input.device - # print("type carried_input: ", carried_input.dtype) - # print("is torch.int32: ", carried_input.dtype==torch.int32) #TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) fake_carried_inputs = tuple(fake_carried_inputs) - print("fake_carried_inputs: ", fake_carried_inputs) # generate cond_fn xlacomputation # TODO(@manfei): specify which element is for which argument like a,b,c - cond_result = cond_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-1]) + cond_result = cond_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list = list(fake_carried_inputs[2:]) @@ -82,21 +77,15 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - cond_hlo_print = xb.get_computation_hlo(cond_computation) - print("cond computation: !!!!!!!!!") - print(cond_hlo_print) # generate body_fn xlacomputation - body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-1]) + body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") body_ctx.buildforiloop(list(body_result), []) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - # body_hlo_print = xb.get_computation_hlo(body_computation) - # print("body computation: !!!!!!!!!") - # print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while kwargs = {} @@ -118,9 +107,6 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) - hlo_print = xb.get_computation_hlo(computation) - print("while computation: !!!!!!!!!") - print(hlo_print) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', From 89dd1e8d05fd4882615c64bc52e76498c49f7d47 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 00:32:45 +0000 Subject: [PATCH 057/546] update --- torch_xla/csrc/init_python_bindings.cpp | 115 ++++++++++-------------- 1 file changed, 46 insertions(+), 69 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 4d2f0e905dd..94c84213fe6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -120,8 +120,6 @@ void PrepareToExit() { runtime::ComputationClient* client = runtime::GetComputationClientIfInitialized(); if (client != nullptr) { - auto xla_device = GetDeviceOrCurrent(""); - SetAllReduceToken(xla_device, nullptr); XLAGraphExecutor::Get()->WaitDeviceOps({}); } } @@ -464,13 +462,12 @@ void SyncLiveTensors(const std::string& device_str, } void StepMarker(const std::string& device_str, - const std::vector& devices, bool wait, - bool reset_scope) { + const std::vector& devices, bool wait) { tsl::profiler::TraceMe activity("StepMarker", tsl::profiler::TraceMeLevel::kInfo); torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str); XLAGraphExecutor::Get()->SyncLiveTensorsGraph(&device, devices, wait); - XLAGraphExecutor::Get()->MarkStep(device, reset_scope); + XLAGraphExecutor::Get()->MarkStep(device); bool debug_mode = runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false); if (TF_PREDICT_FALSE(debug_mode)) { std::string report = runtime::metrics::CreatePerformanceReport( @@ -902,35 +899,7 @@ class PyLoweringContext { : lowering_ctx("PyLoweringContext", device) {} // Builds a HLO graph given a set of output tensors. - void Build(std::vector tensors, - std::vector input_arguments = {}) { - if (GetNameString() == "condctx") { - xla::XlaBuilder* local_builder = lowering_ctx.builder(); - int64_t parameters_number_i = 2; - // for (at::Tensor input_argument : input_arguments) { - for (int i = 0; i < 2; i++) { - xla::Shape shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); - xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, - "UnusedArgumentsPlaceholder"); - parameters_number_i = parameters_number_i + 1; - } - // hard-code to meet requirement by change cond xlacomputation - // f32[20], /*index=5*/f32[20,10], s32[10] - // parameters_number_i = parameters_number_i + 1; - xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, - "OutPutTensor"); - parameters_number_i = parameters_number_i + 1; - xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); - xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, - "WeightTensor"); - parameters_number_i = parameters_number_i + 1; - xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); - xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, - "FinalOneTensor"); - } - + void Build(std::vector tensors) { // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = GetXlaTensors(tensors, /*want_all=*/true); @@ -957,42 +926,60 @@ class PyLoweringContext { std::vector input_arguments = {}) { if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); - // hard-code parameter_idx to 2 to skip existing upper/lower arguments - int64_t parameter_idx = 2; - for (at::Tensor input_argument : input_arguments) { + int64_t parameters_number_i = 2; + // for (at::Tensor input_argument : input_arguments) { + for (int i = 0; i < 2; i++) { xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); - xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, "UnusedArgumentsPlaceholder"); - parameter_idx += 1; + parameters_number_i = parameters_number_i + 1; } // hard-code to meet requirement by change cond xlacomputation // f32[20], /*index=5*/f32[20,10], s32[10] // parameters_number_i = parameters_number_i + 1; - xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + + xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, - "BiasTensor"); + "LInITensor"); + // xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + // xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, + // "BiasTensor"); parameters_number_i = parameters_number_i + 1; xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, "WeightTensor"); + // parameters_number_i = parameters_number_i + 1; + // xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + // xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, + // "LOutTensor"); + // parameters_number_i = parameters_number_i + 1; + // xla::Shape shape4 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); + // xla::XlaOp x4 = xla::Parameter(local_builder, parameters_number_i, shape4, + // "LInITensor"); + // parameters_number_i = parameters_number_i + 1; + // xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + // xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, + // "LOutTensor"); parameters_number_i = parameters_number_i + 1; - xla::Shape shape4 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); + xla::Shape shape4 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); xla::XlaOp x4 = xla::Parameter(local_builder, parameters_number_i, shape4, - "LInITensor"); + "BiasTensor"); parameters_number_i = parameters_number_i + 1; xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, "LOutTensor"); + // // input_value!!!, weight_0, output_value, bias!!! } - if (GetNameString() == "bodyctx") { - xla::XlaBuilder* local_builder = lowering_ctx.builder(); - int64_t parameters_number_i = 7; - xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, - "WeightTensor"); - } + // // hard-code modify body xlacomputation input arguments + // if (GetNameString() == "bodyctx") { + // xla::XlaBuilder* local_builder = lowering_ctx.builder(); + // int64_t parameters_number_i = 7; + // xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + // xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, + // "WeightTensor"); + // } // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = @@ -1761,12 +1748,11 @@ void InitXlaModuleBindings(py::module m) { m.def( "_xla_step_marker", [](const std::string& device, const std::vector& devices, - bool wait, bool reset_scope) { + bool wait) { NoGilSection nogil; - StepMarker(device, devices, wait, reset_scope); + StepMarker(device, devices, wait); }, - py::arg("device") = "", py::arg("devices"), py::arg("wait") = true, - py::arg("reset_scope") = true); + py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); m.def("_get_stablehlo", [](const std::vector& tensors, const std::string& device, const std::vector& devices, @@ -2389,21 +2375,12 @@ void InitXlaModuleBindings(py::module m) { [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& { return XLANativeFunctions::set_(self, source); }); - m.def("_xla_tpu_custom_call", - [](const std::vector& inputs, const std::string& payload, - const std::vector>& output_shapes, - const std::vector& output_dtypes) - -> std::vector { - std::vector dtypes; - dtypes.reserve(output_dtypes.size()); - for (auto& dtype : output_dtypes) { - dtypes.push_back( - reinterpret_cast(dtype.ptr())->scalar_type); - } - - auto xtensors = tensor_methods::tpu_custom_call( - bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes); - return bridge::AtenFromXlaTensors(xtensors); + m.def("_xla_tpu_custom_call_", + [](const std::vector& outputs, + const std::vector& inputs, const std::string& payload) { + auto x_outputs = bridge::GetXlaTensors(outputs); + return tensor_methods::tpu_custom_call_( + x_outputs, bridge::GetXlaTensors(inputs), payload); }); m.def("_set_xla_custom_op_name_prefix", [](const at::Tensor& input, const std::string& op_name_prefix, From cfbe3a6a2ae9d50eb99e66bd4f14ad745a133a12 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 00:34:02 +0000 Subject: [PATCH 058/546] update --- test/cpp/test_xla_sharding.cpp | 16 ------ ...fori_loop_simple_linear_model_test_code.py | 50 ------------------- 2 files changed, 66 deletions(-) delete mode 100644 test/test_fori_loop_simple_linear_model_test_code.py diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 167ffd753e7..e1f908b5c80 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -435,21 +435,5 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { ->HasValue()); } -TEST_F(XLAShardingTest, TestForiLoopAddUnusedParameterInXlaComputation) { - xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}); - // Build simple addition. - xla::XlaBuilder b("builder"); - auto x = xla::Parameter(&b, /*parameter_number=*/0, shape, "p0"); - // Add unused parameter before create xlacomputation - xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}); - auto zzz = xla::Parameter(&b, /*parameter_number=*/1, shape2, "p1"); - auto y = xla::Add(x, xla::ConstantR0(&b, 3)); - xla::XlaComputation xla_computation = - ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false)); - - // Check whether the unused parameter has been included into xlacomputation - EXPECT_EQ(xla_computation.GetProgramShape()->parameters_size(), 2); -} - } // namespace cpp_test } // namespace torch_xla diff --git a/test/test_fori_loop_simple_linear_model_test_code.py b/test/test_fori_loop_simple_linear_model_test_code.py deleted file mode 100644 index 07a7636d880..00000000000 --- a/test/test_fori_loop_simple_linear_model_test_code.py +++ /dev/null @@ -1,50 +0,0 @@ -import os - -import torch -import torch_xla -# We need to import the underlying implementation function to register with the dispatcher -import torch_xla.experimental.fori_loop -from torch_xla.experimental.fori_loop import fori_loop -import torch_xla.core.xla_model as xm -import torch_xla.utils.utils as xu - -# torch.set_grad_enabled(False) - -device = xm.xla_device() - -# --- linear one --- -l_in = torch.randn(10, device=xm.xla_device()) -linear = torch.nn.Linear(10, 20).to(xm.xla_device()) -l_out = linear(l_in) -print("$$$ different linear model with different weight/bias: ") -print(l_out) - -# --- while test case --- -upper = torch.tensor([52], dtype=torch.int32, device=device) -lower = torch.tensor([0], dtype=torch.int32, device=device) -init_val = torch.tensor([1], dtype=torch.int32, device=device) -linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - -# def body_fun(l_in_i): -# l_out = linear_0(l_in_i) -# return l_out - -l_in_0 = torch.randn(10, device=xm.xla_device()) - -upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_= fori_loop(upper, lower, linear_0, init_val, l_in_0) - -print("$$$ fori_loop l_out_: ") -print(l_out_) - -range_num = upper - lower -for i in range(range_num[0]): - l_out_expected = linear_0(l_in_0) -print("$$$ without-fori_loop l_out_: ") -print(l_out_expected) - -# # --- draw --- -# # plt.clf() -# # plt.plot(x_train, y_train, 'go', label='True data', alpha=0.5) -# # plt.plot(x_train, predicted, '--', label='Predictions', alpha=0.5) -# # plt.legend(loc='best') -# # plt.show() \ No newline at end of file From 4bcfa5089c5d61b29183d081ac96c43a61143b83 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 00:40:04 +0000 Subject: [PATCH 059/546] update --- ...while_loop_simple_add_dispatch_in_torch.py | 3 +- torch_xla/csrc/init_python_bindings.cpp | 94 ++++++++----------- 2 files changed, 39 insertions(+), 58 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 0bc5013438d..830b3ff7df7 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -23,7 +23,8 @@ def _fake_fori_loop(lower, upper, body_fun, *init_val): if len(init_val) > 1: (a, b) = init_val for i in range((upper - lower)[0]): - a = body_fun(a, b) + # a = body_fun(a, b) + a = body_fun(*init_val) else: for i in range((upper - lower)[0]): a = body_fun(*init_val) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 94c84213fe6..9fa9e89b191 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -120,6 +120,8 @@ void PrepareToExit() { runtime::ComputationClient* client = runtime::GetComputationClientIfInitialized(); if (client != nullptr) { + auto xla_device = GetDeviceOrCurrent(""); + SetAllReduceToken(xla_device, nullptr); XLAGraphExecutor::Get()->WaitDeviceOps({}); } } @@ -462,12 +464,13 @@ void SyncLiveTensors(const std::string& device_str, } void StepMarker(const std::string& device_str, - const std::vector& devices, bool wait) { + const std::vector& devices, bool wait, + bool reset_scope) { tsl::profiler::TraceMe activity("StepMarker", tsl::profiler::TraceMeLevel::kInfo); torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str); XLAGraphExecutor::Get()->SyncLiveTensorsGraph(&device, devices, wait); - XLAGraphExecutor::Get()->MarkStep(device); + XLAGraphExecutor::Get()->MarkStep(device, reset_scope); bool debug_mode = runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false); if (TF_PREDICT_FALSE(debug_mode)) { std::string report = runtime::metrics::CreatePerformanceReport( @@ -689,19 +692,6 @@ std::vector XlaUserComputation( runtime::ComputationClient::ComputationPtr CreateComputation( const std::string& name, xla::XlaOp root) { - xla::XlaBuilder* local_builder = root.builder(); - // int64_t parameters_number_i = 4; - // xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - // xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, - // "OutPutTensor"); - // parameters_number_i = parameters_number_i + 1; - // xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); - // xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, - // "WeightTensor"); - // parameters_number_i = parameters_number_i + 1; - // xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); - // xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, - // "FinalOneTensor"); xla::XlaComputation computation = ConsumeValue(root.builder()->Build(root)); return std::make_shared( name, std::move(computation)); @@ -926,41 +916,22 @@ class PyLoweringContext { std::vector input_arguments = {}) { if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); - int64_t parameters_number_i = 2; - // for (at::Tensor input_argument : input_arguments) { + // hard-code parameter_idx to 2 to skip existing upper/lower arguments + int64_t parameter_idx = 2; for (int i = 0; i < 2; i++) { xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); - xla::XlaOp x = xla::Parameter(local_builder, parameters_number_i, shape, + xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, "UnusedArgumentsPlaceholder"); - parameters_number_i = parameters_number_i + 1; + parameter_idx += 1; } - // hard-code to meet requirement by change cond xlacomputation - // f32[20], /*index=5*/f32[20,10], s32[10] - // parameters_number_i = parameters_number_i + 1; - xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, "LInITensor"); - // xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - // xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, - // "BiasTensor"); parameters_number_i = parameters_number_i + 1; xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, "WeightTensor"); - // parameters_number_i = parameters_number_i + 1; - // xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - // xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, - // "LOutTensor"); - // parameters_number_i = parameters_number_i + 1; - // xla::Shape shape4 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); - // xla::XlaOp x4 = xla::Parameter(local_builder, parameters_number_i, shape4, - // "LInITensor"); - // parameters_number_i = parameters_number_i + 1; - // xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - // xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, - // "LOutTensor"); parameters_number_i = parameters_number_i + 1; xla::Shape shape4 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); xla::XlaOp x4 = xla::Parameter(local_builder, parameters_number_i, shape4, @@ -969,17 +940,16 @@ class PyLoweringContext { xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, "LOutTensor"); - // // input_value!!!, weight_0, output_value, bias!!! } - // // hard-code modify body xlacomputation input arguments - // if (GetNameString() == "bodyctx") { - // xla::XlaBuilder* local_builder = lowering_ctx.builder(); - // int64_t parameters_number_i = 7; - // xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - // xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, - // "WeightTensor"); - // } + // hard-code modify body xlacomputation input arguments + if (GetNameString() == "bodyctx") { + xla::XlaBuilder* local_builder = lowering_ctx.builder(); + int64_t parameters_number_i = 7; + xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, + "WeightTensor"); + } // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = @@ -1748,11 +1718,12 @@ void InitXlaModuleBindings(py::module m) { m.def( "_xla_step_marker", [](const std::string& device, const std::vector& devices, - bool wait) { + bool wait, bool reset_scope) { NoGilSection nogil; - StepMarker(device, devices, wait); + StepMarker(device, devices, wait, reset_scope); }, - py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); + py::arg("device") = "", py::arg("devices"), py::arg("wait") = true, + py::arg("reset_scope") = true); m.def("_get_stablehlo", [](const std::vector& tensors, const std::string& device, const std::vector& devices, @@ -2375,12 +2346,21 @@ void InitXlaModuleBindings(py::module m) { [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& { return XLANativeFunctions::set_(self, source); }); - m.def("_xla_tpu_custom_call_", - [](const std::vector& outputs, - const std::vector& inputs, const std::string& payload) { - auto x_outputs = bridge::GetXlaTensors(outputs); - return tensor_methods::tpu_custom_call_( - x_outputs, bridge::GetXlaTensors(inputs), payload); + m.def("_xla_tpu_custom_call", + [](const std::vector& inputs, const std::string& payload, + const std::vector>& output_shapes, + const std::vector& output_dtypes) + -> std::vector { + std::vector dtypes; + dtypes.reserve(output_dtypes.size()); + for (auto& dtype : output_dtypes) { + dtypes.push_back( + reinterpret_cast(dtype.ptr())->scalar_type); + } + + auto xtensors = tensor_methods::tpu_custom_call( + bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes); + return bridge::AtenFromXlaTensors(xtensors); }); m.def("_set_xla_custom_op_name_prefix", [](const at::Tensor& input, const std::string& op_name_prefix, @@ -2651,4 +2631,4 @@ void InitXlaBindings(py::module m) { InitXlaModuleBindings(m); } } // namespace torch_xla -PYBIND11_MODULE(_XLAC, m) { torch_xla::InitXlaBindings(m); } +PYBIND11_MODULE(_XLAC, m) { torch_xla::InitXlaBindings(m); } \ No newline at end of file From 09eaa3786a743bc068f8d582dd5ae13279d81ffa Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 00:41:51 +0000 Subject: [PATCH 060/546] update --- torch_xla/csrc/init_python_bindings.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 9fa9e89b191..1fc9ac3ccdb 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -926,28 +926,28 @@ class PyLoweringContext { parameter_idx += 1; } xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); - xla::XlaOp x1 = xla::Parameter(local_builder, parameters_number_i, shape1, + xla::XlaOp x1 = xla::Parameter(local_builder, parameter_idx, shape1, "LInITensor"); - parameters_number_i = parameters_number_i + 1; + parameter_idx = parameter_idx + 1; xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); - xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, + xla::XlaOp x2 = xla::Parameter(local_builder, parameter_idx, shape2, "WeightTensor"); - parameters_number_i = parameters_number_i + 1; + parameter_idx = parameter_idx + 1; xla::Shape shape4 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - xla::XlaOp x4 = xla::Parameter(local_builder, parameters_number_i, shape4, + xla::XlaOp x4 = xla::Parameter(local_builder, parameter_idx, shape4, "BiasTensor"); - parameters_number_i = parameters_number_i + 1; + parameter_idx = parameter_idx + 1; xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - xla::XlaOp x3 = xla::Parameter(local_builder, parameters_number_i, shape3, + xla::XlaOp x3 = xla::Parameter(local_builder, parameter_idx, shape3, "LOutTensor"); } // hard-code modify body xlacomputation input arguments if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); - int64_t parameters_number_i = 7; + int64_t parameter_idx = 7; xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - xla::XlaOp x2 = xla::Parameter(local_builder, parameters_number_i, shape2, + xla::XlaOp x2 = xla::Parameter(local_builder, parameter_idx, shape2, "WeightTensor"); } From 88bdb4a6241759dbdcd15341e43b496a690c3206 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 00:51:38 +0000 Subject: [PATCH 061/546] update --- ..._while_loop_simple_add_dispatch_in_torch.py | 18 ++++++++++++++++++ torch_xla/experimental/fori_loop.py | 3 --- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 830b3ff7df7..e84adbaa194 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -69,6 +69,24 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) + def test_while_loop_tpu_subtraction_nested(self): + + device = xm.xla_device() + + def cond_fn(init, limit_value): + return limit_value[0] <= init[0] + + def body_fn(init, limit_value): + one_value = torch.ones(1, dtype=torch.int32, device=device) + two_value = limit_value.clone() + return (torch.sub(torch.sub(init, one_value), one_value), two_value) + + init = torch.tensor([10], dtype=torch.int32, device=device) + limit_value = torch.tensor([0], dtype=torch.int32, device=device) + res = while_loop(cond_fn, body_fn, (init, limit_value)) + expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) + self.assertEqual(expected, res) + def test_fori_loop_tpu_addition(self): xm.mark_step() diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 57f7162bf3f..8141d777a4e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -16,9 +16,6 @@ def fori_loop(upper, lower, body_fun, init_val, *input_value): device = xm.xla_device() - # weight_0 = body_fun.weight - # bias_0 = body_fun.bias - # one_value = torch.tensor([1], dtype=torch.int32, device=device) def cond_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): # , bias_0): # , output_value): return lower[0] < upper[0] From ec5e999fe4689a7cad750c1473e15a92c898333a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 06:00:18 +0000 Subject: [PATCH 062/546] update --- torch_xla/csrc/init_python_bindings.cpp | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1fc9ac3ccdb..9a7cd811cb3 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -913,11 +913,29 @@ class PyLoweringContext { // Builds a HLO graph given a set of output tensors, and add unused parameters // needed in xlacomputation. void BuildForiLoop(std::vector tensors, - std::vector input_arguments = {}) { + std::vector additional_inputs_list = {}) { if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // hard-code parameter_idx to 2 to skip existing upper/lower arguments - int64_t parameter_idx = 2; + // !!! since cond_fn only compare upper and lower, so it would only use two arguments, due to PyTorch/XLA + // !!! trace xlacomputation from result tensor, so all the other arguments would not be included or generated; + // !!! but to meet xla::while requirement, we would skip first two arguments, + // !!! then add all other arguments like body_fn/init + // !!! --- additional_inputs_list: this list include all other arguments like body_fn/init except upper and lower + // !!! --- next step: we add dump paras according to additional_inputs_list + // ??? --- could we get IRvalue of `additional_inputs_list` in this function to complete xlacomputation? + int64_t parameter_idx = 2; // parameter_idx start from 2 after upper and lower + // ? type, ? shape, + // for (int i = 0; i < additional_inputs_list.size(); i++) { + for (auto& additional_input_tensor : additional_inputs_list) { + XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + xla::Shape shape = xtensor->shape().get().ToString(); + xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + "UnusedArgumentsPlaceholder"); + parameter_idx += 1; + // xtensor->shape().get().ToString() + // xla_tensor->shaped_buffer().on_device_shape(); + } for (int i = 0; i < 2; i++) { xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); From e870e70ad8d4ce055e3c8d2ec77333e736d3c09e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 06:04:23 +0000 Subject: [PATCH 063/546] update --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 9a7cd811cb3..53b55e04601 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -929,7 +929,7 @@ class PyLoweringContext { // for (int i = 0; i < additional_inputs_list.size(); i++) { for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - xla::Shape shape = xtensor->shape().get().ToString(); + xla::Shape shape = xtensor->shape().get(); // .ToString(); xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, "UnusedArgumentsPlaceholder"); parameter_idx += 1; From 349ed260f87d14b6199fd732fde90797701ce0b3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 06:07:34 +0000 Subject: [PATCH 064/546] update --- torch_xla/csrc/init_python_bindings.cpp | 44 ++++++++++++------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 53b55e04601..e1df6657574 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -936,28 +936,28 @@ class PyLoweringContext { // xtensor->shape().get().ToString() // xla_tensor->shaped_buffer().on_device_shape(); } - for (int i = 0; i < 2; i++) { - xla::Shape shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); - xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - "UnusedArgumentsPlaceholder"); - parameter_idx += 1; - } - xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); - xla::XlaOp x1 = xla::Parameter(local_builder, parameter_idx, shape1, - "LInITensor"); - parameter_idx = parameter_idx + 1; - xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); - xla::XlaOp x2 = xla::Parameter(local_builder, parameter_idx, shape2, - "WeightTensor"); - parameter_idx = parameter_idx + 1; - xla::Shape shape4 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - xla::XlaOp x4 = xla::Parameter(local_builder, parameter_idx, shape4, - "BiasTensor"); - parameter_idx = parameter_idx + 1; - xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - xla::XlaOp x3 = xla::Parameter(local_builder, parameter_idx, shape3, - "LOutTensor"); + // for (int i = 0; i < 2; i++) { + // xla::Shape shape = + // xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); + // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + // "UnusedArgumentsPlaceholder"); + // parameter_idx += 1; + // } + // xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); + // xla::XlaOp x1 = xla::Parameter(local_builder, parameter_idx, shape1, + // "LInITensor"); + // parameter_idx = parameter_idx + 1; + // xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); + // xla::XlaOp x2 = xla::Parameter(local_builder, parameter_idx, shape2, + // "WeightTensor"); + // parameter_idx = parameter_idx + 1; + // xla::Shape shape4 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + // xla::XlaOp x4 = xla::Parameter(local_builder, parameter_idx, shape4, + // "BiasTensor"); + // parameter_idx = parameter_idx + 1; + // xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); + // xla::XlaOp x3 = xla::Parameter(local_builder, parameter_idx, shape3, + // "LOutTensor"); } // hard-code modify body xlacomputation input arguments From 876298a657554dbb76c2f38ade7ad6426db6b932 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:22:56 +0000 Subject: [PATCH 065/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- torch_xla/experimental/fori_loop.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index e84adbaa194..c2ed1669e64 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -126,4 +126,4 @@ def test_fori_loop_tpu_simple_linear(self): if __name__ == '__main__': test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) \ No newline at end of file + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 8141d777a4e..753c68ba455 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -69,7 +69,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): cond_ctx.set_name_string("condctx") additional_inputs_list = list(fake_carried_inputs[2:]) for i in range(len(additional_inputs)): - additional_inputs_list.append(additional_inputs[0]) + additional_inputs_list.append(additional_inputs[i]) cond_ctx.buildforiloop([cond_result], additional_inputs_list) cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", From b73dc9629bfc4fcee805eff96ec3c014a1b7501e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:30:33 +0000 Subject: [PATCH 066/546] update --- torch_xla/csrc/init_python_bindings.cpp | 40 ++++++------------------- torch_xla/experimental/fori_loop.py | 9 ++++-- 2 files changed, 16 insertions(+), 33 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index e1df6657574..27dd7ac3bdf 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -925,48 +925,26 @@ class PyLoweringContext { // !!! --- next step: we add dump paras according to additional_inputs_list // ??? --- could we get IRvalue of `additional_inputs_list` in this function to complete xlacomputation? int64_t parameter_idx = 2; // parameter_idx start from 2 after upper and lower - // ? type, ? shape, - // for (int i = 0; i < additional_inputs_list.size(); i++) { for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - xla::Shape shape = xtensor->shape().get(); // .ToString(); + xla::Shape shape = xtensor->shape().get(); xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, "UnusedArgumentsPlaceholder"); parameter_idx += 1; - // xtensor->shape().get().ToString() - // xla_tensor->shaped_buffer().on_device_shape(); } - // for (int i = 0; i < 2; i++) { - // xla::Shape shape = - // xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); - // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - // "UnusedArgumentsPlaceholder"); - // parameter_idx += 1; - // } - // xla::Shape shape1 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {10}); - // xla::XlaOp x1 = xla::Parameter(local_builder, parameter_idx, shape1, - // "LInITensor"); - // parameter_idx = parameter_idx + 1; - // xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20, 10}); - // xla::XlaOp x2 = xla::Parameter(local_builder, parameter_idx, shape2, - // "WeightTensor"); - // parameter_idx = parameter_idx + 1; - // xla::Shape shape4 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - // xla::XlaOp x4 = xla::Parameter(local_builder, parameter_idx, shape4, - // "BiasTensor"); - // parameter_idx = parameter_idx + 1; - // xla::Shape shape3 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - // xla::XlaOp x3 = xla::Parameter(local_builder, parameter_idx, shape3, - // "LOutTensor"); } // hard-code modify body xlacomputation input arguments if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); - int64_t parameter_idx = 7; - xla::Shape shape2 = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {20}); - xla::XlaOp x2 = xla::Parameter(local_builder, parameter_idx, shape2, - "WeightTensor"); + int64_t parameter_idx = tensors.size(); + for (auto& additional_input_tensor : additional_inputs_list) { + XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + xla::Shape shape = xtensor->shape().get(); + xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + "UnusedArgumentsPlaceholder"); + parameter_idx += 1; + } } // Get the backing XLA tensors from the output torch tensor handles diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 753c68ba455..88fa371a799 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -67,7 +67,8 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): cond_result = cond_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") - additional_inputs_list = list(fake_carried_inputs[2:]) + additional_inputs_list = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor + # treat and pass additional_inputs to cond_fn for i in range(len(additional_inputs)): additional_inputs_list.append(additional_inputs[i]) cond_ctx.buildforiloop([cond_result], additional_inputs_list) @@ -79,7 +80,11 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - body_ctx.buildforiloop(list(body_result), []) + additional_inputs_list = list(fake_carried_inputs[-2]) # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor + # TODO(@manfei): treat and pass additional_inputs to body_fn too + for i in range(len(additional_inputs)): + additional_inputs_list.append(additional_inputs[i]) + body_ctx.buildforiloop(list(body_result), additional_inputs_list) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) From b5accda1bbe8e741e0334a5fa8e650848aa87e9e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:39:07 +0000 Subject: [PATCH 067/546] update --- torch_xla/csrc/init_python_bindings.cpp | 36 ++++++++++++++++--------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 27dd7ac3bdf..f90a5c11b1e 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -934,18 +934,18 @@ class PyLoweringContext { } } - // hard-code modify body xlacomputation input arguments - if (GetNameString() == "bodyctx") { - xla::XlaBuilder* local_builder = lowering_ctx.builder(); - int64_t parameter_idx = tensors.size(); - for (auto& additional_input_tensor : additional_inputs_list) { - XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - xla::Shape shape = xtensor->shape().get(); - xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - "UnusedArgumentsPlaceholder"); - parameter_idx += 1; - } - } + // // hard-code modify body xlacomputation input arguments + // if (GetNameString() == "bodyctx") { + // xla::XlaBuilder* local_builder = lowering_ctx.builder(); + // int64_t parameter_idx = tensors.size(); + // for (auto& additional_input_tensor : additional_inputs_list) { + // XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + // xla::Shape shape = xtensor->shape().get(); + // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + // "UnusedArgumentsPlaceholder"); + // parameter_idx += 1; + // } + // } // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = @@ -972,6 +972,18 @@ class PyLoweringContext { std::vector buffer_donor_indices; xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); + // hard-code modify body xlacomputation input arguments + if (GetNameString() == "bodyctx") { + xla::XlaBuilder* local_builder = lowering_ctx.builder(); + int64_t parameter_idx = program_shape.parameters_size(); // tensors.size(); + for (auto& additional_input_tensor : additional_inputs_list) { + XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + xla::Shape shape = xtensor->shape().get(); + xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + "UnusedArgumentsPlaceholder"); + parameter_idx += 1; + } + } // TODO(@manfei): please confirm whether we check for more than two or use // default value true bool should_wrap_parameter = (program_shape.parameters_size() >= 2); From 8a1f6ad506a41fadd244786bfbbc57aaad37f82c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:42:25 +0000 Subject: [PATCH 068/546] update --- torch_xla/csrc/init_python_bindings.cpp | 48 ++++++++++++------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index f90a5c11b1e..7ce1ed45075 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -934,18 +934,18 @@ class PyLoweringContext { } } - // // hard-code modify body xlacomputation input arguments - // if (GetNameString() == "bodyctx") { - // xla::XlaBuilder* local_builder = lowering_ctx.builder(); - // int64_t parameter_idx = tensors.size(); - // for (auto& additional_input_tensor : additional_inputs_list) { - // XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - // xla::Shape shape = xtensor->shape().get(); - // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - // "UnusedArgumentsPlaceholder"); - // parameter_idx += 1; - // } - // } + // hard-code modify body xlacomputation input arguments + if (GetNameString() == "bodyctx") { + xla::XlaBuilder* local_builder = lowering_ctx.builder(); + int64_t parameter_idx = 7; // tensors.size(); + for (auto& additional_input_tensor : additional_inputs_list) { + XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + xla::Shape shape = xtensor->shape().get(); + xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + "UnusedArgumentsPlaceholder"); + parameter_idx += 1; + } + } // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = @@ -972,18 +972,18 @@ class PyLoweringContext { std::vector buffer_donor_indices; xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); - // hard-code modify body xlacomputation input arguments - if (GetNameString() == "bodyctx") { - xla::XlaBuilder* local_builder = lowering_ctx.builder(); - int64_t parameter_idx = program_shape.parameters_size(); // tensors.size(); - for (auto& additional_input_tensor : additional_inputs_list) { - XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - xla::Shape shape = xtensor->shape().get(); - xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - "UnusedArgumentsPlaceholder"); - parameter_idx += 1; - } - } + // // hard-code modify body xlacomputation input arguments + // if (GetNameString() == "bodyctx") { + // xla::XlaBuilder* local_builder = lowering_ctx.builder(); + // int64_t parameter_idx = program_shape.parameters_size(); // tensors.size(); + // for (auto& additional_input_tensor : additional_inputs_list) { + // XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + // xla::Shape shape = xtensor->shape().get(); + // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + // "UnusedArgumentsPlaceholder"); + // parameter_idx += 1; + // } + // } // TODO(@manfei): please confirm whether we check for more than two or use // default value true bool should_wrap_parameter = (program_shape.parameters_size() >= 2); From 7914abf109fe9399f59d58290f6277c88de24e06 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:43:25 +0000 Subject: [PATCH 069/546] update --- torch_xla/csrc/init_python_bindings.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 7ce1ed45075..145739e1d1a 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -973,6 +973,8 @@ class PyLoweringContext { xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); // // hard-code modify body xlacomputation input arguments + // // xxx: failed due to not change body_xlacomputation, might becase has been traced + // // xxx: after `computation = ConsumeValue(lowering_ctx.BuildXla());` // if (GetNameString() == "bodyctx") { // xla::XlaBuilder* local_builder = lowering_ctx.builder(); // int64_t parameter_idx = program_shape.parameters_size(); // tensors.size(); From 3da0e431daf640f22c69dd17bf26c217febbef14 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:46:08 +0000 Subject: [PATCH 070/546] update --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 88fa371a799..ca1e1e67f08 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -84,6 +84,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # TODO(@manfei): treat and pass additional_inputs to body_fn too for i in range(len(additional_inputs)): additional_inputs_list.append(additional_inputs[i]) + print("len!!!: ", len(additional_inputs_list)) body_ctx.buildforiloop(list(body_result), additional_inputs_list) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", From 5c5ca2c083d847ef0e7b9f1c9a551294dae7e1a1 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:47:05 +0000 Subject: [PATCH 071/546] update --- torch_xla/experimental/fori_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index ca1e1e67f08..5eeb88121f0 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -82,9 +82,11 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): body_ctx.set_name_string("bodyctx") additional_inputs_list = list(fake_carried_inputs[-2]) # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor # TODO(@manfei): treat and pass additional_inputs to body_fn too + print("len0!!!: ", len(additional_inputs_list)) for i in range(len(additional_inputs)): additional_inputs_list.append(additional_inputs[i]) print("len!!!: ", len(additional_inputs_list)) + print("additional_inputs_list: ", additional_inputs_list) body_ctx.buildforiloop(list(body_result), additional_inputs_list) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", From 1c9e92fb47e9ef615ac8f907af1e6035e51a89b4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:51:05 +0000 Subject: [PATCH 072/546] update --- torch_xla/experimental/fori_loop.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 5eeb88121f0..3237ae58a44 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -67,11 +67,11 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): cond_result = cond_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") - additional_inputs_list = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor + additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor # treat and pass additional_inputs to cond_fn for i in range(len(additional_inputs)): - additional_inputs_list.append(additional_inputs[i]) - cond_ctx.buildforiloop([cond_result], additional_inputs_list) + additional_inputs_list_cond.append(additional_inputs[i]) + cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) @@ -80,14 +80,14 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - additional_inputs_list = list(fake_carried_inputs[-2]) # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor + additional_inputs_list_body = list(fake_carried_inputs[-2]) # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor # TODO(@manfei): treat and pass additional_inputs to body_fn too - print("len0!!!: ", len(additional_inputs_list)) + print("len0!!!: ", len(additional_inputs_list_body)) for i in range(len(additional_inputs)): - additional_inputs_list.append(additional_inputs[i]) - print("len!!!: ", len(additional_inputs_list)) - print("additional_inputs_list: ", additional_inputs_list) - body_ctx.buildforiloop(list(body_result), additional_inputs_list) + additional_inputs_list_body.append(additional_inputs[i]) + print("len!!!: ", len(additional_inputs_list_body)) + print("additional_inputs_list_body: ", additional_inputs_list_body) + body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) From 52d1c177bd4b5e104002314dd75d99f42aa81c3c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:51:56 +0000 Subject: [PATCH 073/546] update --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3237ae58a44..d8328176607 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -82,6 +82,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): body_ctx.set_name_string("bodyctx") additional_inputs_list_body = list(fake_carried_inputs[-2]) # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor # TODO(@manfei): treat and pass additional_inputs to body_fn too + print("list(fake_carried_inputs[-2]: ", list(fake_carried_inputs[-2]) print("len0!!!: ", len(additional_inputs_list_body)) for i in range(len(additional_inputs)): additional_inputs_list_body.append(additional_inputs[i]) From c4fe1222a8f5d11f90cfa169d04a07def03e94d4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:52:30 +0000 Subject: [PATCH 074/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index d8328176607..13a4dbf51ee 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -82,7 +82,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): body_ctx.set_name_string("bodyctx") additional_inputs_list_body = list(fake_carried_inputs[-2]) # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor # TODO(@manfei): treat and pass additional_inputs to body_fn too - print("list(fake_carried_inputs[-2]: ", list(fake_carried_inputs[-2]) + print("list(fake_carried_inputs[-2]: ", list(fake_carried_inputs[-2])) print("len0!!!: ", len(additional_inputs_list_body)) for i in range(len(additional_inputs)): additional_inputs_list_body.append(additional_inputs[i]) From 80d20034dd8e19b53c7659971d45b5a9144a29df Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:54:44 +0000 Subject: [PATCH 075/546] update --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 13a4dbf51ee..bb18c1b65cf 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -80,9 +80,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - additional_inputs_list_body = list(fake_carried_inputs[-2]) # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor + additional_inputs_list_body = fake_carried_inputs[-2] # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor # TODO(@manfei): treat and pass additional_inputs to body_fn too - print("list(fake_carried_inputs[-2]: ", list(fake_carried_inputs[-2])) + print("list(fake_carried_inputs[-2]: ", fake_carried_inputs[-2]) print("len0!!!: ", len(additional_inputs_list_body)) for i in range(len(additional_inputs)): additional_inputs_list_body.append(additional_inputs[i]) From 71718e15f0670e273bb5d8022d94aef6d7c77be5 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 08:56:26 +0000 Subject: [PATCH 076/546] update --- torch_xla/experimental/fori_loop.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index bb18c1b65cf..fa50a422008 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -80,14 +80,14 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - additional_inputs_list_body = fake_carried_inputs[-2] # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor + additional_inputs_list_body = [fake_carried_inputs[-2]] # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor # TODO(@manfei): treat and pass additional_inputs to body_fn too - print("list(fake_carried_inputs[-2]: ", fake_carried_inputs[-2]) - print("len0!!!: ", len(additional_inputs_list_body)) + # print("list(fake_carried_inputs[-2]: ", fake_carried_inputs[-2]) + # print("len0!!!: ", len(additional_inputs_list_body)) for i in range(len(additional_inputs)): additional_inputs_list_body.append(additional_inputs[i]) - print("len!!!: ", len(additional_inputs_list_body)) - print("additional_inputs_list_body: ", additional_inputs_list_body) + # print("len!!!: ", len(additional_inputs_list_body)) + # print("additional_inputs_list_body: ", additional_inputs_list_body) body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", From e8e18f764d20a19d99eb102b72590e8c79426e75 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:17:49 +0000 Subject: [PATCH 077/546] update --- ...while_loop_simple_add_dispatch_in_torch.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index c2ed1669e64..3db6d4e93c1 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -87,6 +87,40 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) + def test_while_loop_tpu_simple_linear(self): + + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + + upper = torch.tensor([52], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.randn(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + weight_0 = linear_0.weight + bias_0 = linear_0.bias + + def cond_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): + new_lower = torch.add(one_value, lower) + output_value = body_fun(*input_value) + weight = body_fun.weight + bias = body_fun.bias + return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value + + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_= = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, weight_0, bias_0, output_value)) + + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + + self.assertTrue(torch.all(torch.eq(expected, l_out_))) + + def test_fori_loop_tpu_addition(self): xm.mark_step() From 9741e8d1f106a54ba82a0cf7f6d20b04bed1ea96 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:20:45 +0000 Subject: [PATCH 078/546] update --- ...i_loop_with_while_loop_simple_add_dispatch_in_torch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 3db6d4e93c1..1be3ccdd3b4 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -95,10 +95,10 @@ def test_while_loop_tpu_simple_linear(self): upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) - init_val = torch.tensor([1], dtype=torch.int32, device=device) - l_in_0 = torch.randn(10, device=xm.xla_device()) - output_value = torch.zeros([20], dtype=torch.float32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) # x + l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value + output_value = torch.zeros([20], dtype=torch.float32, device=device) linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) weight_0 = linear_0.weight @@ -114,7 +114,7 @@ def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi bias = body_fun.bias return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_= = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, weight_0, bias_0, output_value)) + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From 26e30a399cbcbc3ba5ecc3db6f1828f81f2a8703 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:22:11 +0000 Subject: [PATCH 079/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 1be3ccdd3b4..0b5d79383af 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -114,7 +114,8 @@ def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi bias = body_fun.bias return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value)) + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = + while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From c6f82592acc58dd87c9ec4546d903e32effbe11f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:24:06 +0000 Subject: [PATCH 080/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 + torch_xla/experimental/fori_loop.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 0b5d79383af..83a345f3760 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -114,6 +114,7 @@ def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi bias = body_fun.bias return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value + print("!!! arrive here !!!") upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value)) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index fa50a422008..20f0f108035 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -44,12 +44,14 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') # cond_fn&body_fn: callable # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) + print("!!! arrive here too !!!") if additional_inputs is None: additional_inputs = tuple() return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): + print("!!! arrive here too too !!!") # untuple carried_inputs from while_loop carried_inputs = carried_inputs[0] # fake carried_inputs to split formal code From 5b9378c437a4be7bdf5bba417cec81d69d099e6c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:24:29 +0000 Subject: [PATCH 081/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 83a345f3760..f0674e6b994 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -115,8 +115,8 @@ def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value print("!!! arrive here !!!") - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = - while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value)) + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop( + cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From 64088798e4f5032b21a7fd2dbcdf300d3c2d1efa Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:26:05 +0000 Subject: [PATCH 082/546] update --- ...i_loop_with_while_loop_simple_add_dispatch_in_torch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index f0674e6b994..bd41135199e 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -104,15 +104,15 @@ def test_while_loop_tpu_simple_linear(self): weight_0 = linear_0.weight bias_0 = linear_0.bias - def cond_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): + def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): + def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): new_lower = torch.add(one_value, lower) - output_value = body_fun(*input_value) + output_value = body_fun(input_value) weight = body_fun.weight bias = body_fun.bias - return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value + return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value print("!!! arrive here !!!") upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop( From 7c408a93358a2c834e14fe5ac62c42184308dcab Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:27:48 +0000 Subject: [PATCH 083/546] update --- ...ori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index bd41135199e..3c103deddbd 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -109,9 +109,9 @@ def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bia def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): new_lower = torch.add(one_value, lower) - output_value = body_fun(input_value) - weight = body_fun.weight - bias = body_fun.bias + output_value = linear_0(input_value) + weight = linear_0.weight + bias = linear_0.bias return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value print("!!! arrive here !!!") From 9c6852877ba8511a6c3d2a1e592b66fbb5122b05 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:29:42 +0000 Subject: [PATCH 084/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 3c103deddbd..88ff7148a97 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -116,7 +116,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bia print("!!! arrive here !!!") upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop( - cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value)) + cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, output_value, bias_0)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From 141c7043ddbbc53715f4ed54ad1f5b7347c0169e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:41:27 +0000 Subject: [PATCH 085/546] update --- ...fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 88ff7148a97..3d5945cdb00 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -114,9 +114,8 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bia bias = linear_0.bias return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value - print("!!! arrive here !!!") - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop( - cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, output_value, bias_0)) + # print("!!! arrive here !!!") + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, output_value, bias_0)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From d5e5537c7a12d9698c4ade31564ac085bed03725 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:43:09 +0000 Subject: [PATCH 086/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 3d5945cdb00..dd8d0701a7e 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -115,7 +115,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bia return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value # print("!!! arrive here !!!") - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, output_value, bias_0)) + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, output_value, bias_0), additional_inputs=None) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From cfbe475113a9cde532fabf61da86052ea4aa40a0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:47:31 +0000 Subject: [PATCH 087/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index dd8d0701a7e..3d5945cdb00 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -115,7 +115,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bia return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value # print("!!! arrive here !!!") - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, output_value, bias_0), additional_inputs=None) + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, output_value, bias_0)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From b11f015586b13119dbc75d99662599c40ed6256c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 20:58:28 +0000 Subject: [PATCH 088/546] update --- ...while_loop_simple_add_dispatch_in_torch.py | 38 ++++++++++++++++++- torch_xla/experimental/fori_loop.py | 13 +++++-- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 3d5945cdb00..e28d4cbd605 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -87,6 +87,39 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) +# ////// +# class SimpleWithLinear(torch.nn.Module): +# def __init__(self): +# super().__init__() +# self.linear = torch.nn.Linear(2, 2) +# self.register_buffer("dec", torch.tensor(1)) + +# def forward(self, iter, x): +# def cond_fn(it, x): +# return it - self.dec > 0 + +# def body_fn(it, x): +# return it - 1, self.linear(x) +# return while_loop(cond_fn, body_fn, (iter, x)) + +# class NestedWithLinear(torch.nn.Module): +# return while_loop(cond_fn, body_fn, (iter, x)) + +# nested2 = Nested() +# simple_with_linear = SimpleWithLinear() +# nested_with_linear = NestedWithLinear() + +# x = torch.zeros(1) +# y = torch.zeros(1) +# z = torch.zeros(1) +# return {"simple": (simple, (x,)), +# "nested": (nested, (x, y, z)), +# "nested2": (nested2, (torch.tensor(2), torch.tensor(2), torch.ones(2, 2), torch.ones(2, 2))), +# "simple_with_mutation": (simple_with_mutation, (x,)), +# "simple_with_linear": (simple_with_linear, (torch.tensor(3), torch.randn(2, 2))), +# "nested_with_linear": (nested_with_linear, (torch.tensor(3), torch.randn(2, 2)))} +# ////// + def test_while_loop_tpu_simple_linear(self): xm.mark_step() @@ -104,10 +137,11 @@ def test_while_loop_tpu_simple_linear(self): weight_0 = linear_0.weight bias_0 = linear_0.bias - def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): + # def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): + def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): + def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value = linear_0(input_value) weight = linear_0.weight diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 20f0f108035..2bbecd42f64 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -50,7 +50,7 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, -def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): +def _xla_while_loop(cond_fn, body_fn, *carried_inputs, *additional_inputs): print("!!! arrive here too too !!!") # untuple carried_inputs from while_loop carried_inputs = carried_inputs[0] @@ -62,11 +62,18 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): fake_carried_inputs.append( torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) + # fake_carried_inputs = tuple(fake_carried_inputs) + for additional_input in additional_inputs: + device = additional_input.device + #TODO(@manfei) type = carried_input.type + fake_carried_inputs.append( + torch.randint(10, additional_input.size(), + dtype=additional_input.dtype).to(device)) fake_carried_inputs = tuple(fake_carried_inputs) # generate cond_fn xlacomputation # TODO(@manfei): specify which element is for which argument like a,b,c - cond_result = cond_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) + cond_result = cond_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-1]) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor @@ -79,7 +86,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): cond_hlo) # generate body_fn xlacomputation - body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-3], output_value=fake_carried_inputs[-2], bias_0=fake_carried_inputs[-1]) + body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-1]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") additional_inputs_list_body = [fake_carried_inputs[-2]] # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor From 9cd8ca054c191ca0d7334e25d80ed8e868291ed1 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:00:29 +0000 Subject: [PATCH 089/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 2bbecd42f64..d9b399d2dfb 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -50,7 +50,7 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, -def _xla_while_loop(cond_fn, body_fn, *carried_inputs, *additional_inputs): +def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): print("!!! arrive here too too !!!") # untuple carried_inputs from while_loop carried_inputs = carried_inputs[0] From 89a2b34f9eb42539ddb4c046315c4bd4f9de8f3e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:01:58 +0000 Subject: [PATCH 090/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index e28d4cbd605..307a2377bb7 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -149,7 +149,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value # print("!!! arrive here !!!") - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, output_value, bias_0)) + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From 9e420ad35a059cf3ceea3eec2a63c79fc79a2eaf Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:05:19 +0000 Subject: [PATCH 091/546] update --- ...while_loop_simple_add_dispatch_in_torch.py | 71 ++++++++++++------- 1 file changed, 47 insertions(+), 24 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 307a2377bb7..9264b43ba6a 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -126,30 +126,53 @@ def test_while_loop_tpu_simple_linear(self): device = xm.xla_device() torch.set_grad_enabled(False) - upper = torch.tensor([52], dtype=torch.int32, device=device) - lower = torch.tensor([0], dtype=torch.int32, device=device) - one_value = torch.tensor([1], dtype=torch.int32, device=device) - init_val = torch.tensor([1], dtype=torch.int32, device=device) # x - l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value - output_value = torch.zeros([20], dtype=torch.float32, device=device) - - linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - weight_0 = linear_0.weight - bias_0 = linear_0.bias - - # def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): - def cond_fn(upper, lower, one_value, x, input_value, output_value): - return lower[0] < upper[0] - - def body_fn(upper, lower, one_value, x, input_value, output_value): - new_lower = torch.add(one_value, lower) - output_value = linear_0(input_value) - weight = linear_0.weight - bias = linear_0.bias - return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value - - # print("!!! arrive here !!!") - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + class SimpleWithLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + # self.register_buffer("dec", torch.tensor(1)) + + def forward(self, upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value = linear_0(input_value) + weight = linear_0.weight + bias = linear_0.bias + return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value + # return while_loop(cond_fn, body_fn, (iter, x)) + return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + + # xm.mark_step() + # device = xm.xla_device() + # torch.set_grad_enabled(False) + + # upper = torch.tensor([52], dtype=torch.int32, device=device) + # lower = torch.tensor([0], dtype=torch.int32, device=device) + # one_value = torch.tensor([1], dtype=torch.int32, device=device) + # init_val = torch.tensor([1], dtype=torch.int32, device=device) # x + # l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value + # output_value = torch.zeros([20], dtype=torch.float32, device=device) + + # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + # weight_0 = linear_0.weight + # bias_0 = linear_0.bias + + # # def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): + # def cond_fn(upper, lower, one_value, x, input_value, output_value): + # return lower[0] < upper[0] + + # def body_fn(upper, lower, one_value, x, input_value, output_value): + # new_lower = torch.add(one_value, lower) + # output_value = linear_0(input_value) + # weight = linear_0.weight + # bias = linear_0.bias + # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value + + # # print("!!! arrive here !!!") + # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From a536a3eec0ff6fdedc02de789cf0baeac93df54c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:05:41 +0000 Subject: [PATCH 092/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 9264b43ba6a..8c78ee4a8e1 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -174,9 +174,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # # print("!!! arrive here !!!") # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) - expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + # expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) - self.assertTrue(torch.all(torch.eq(expected, l_out_))) + # self.assertTrue(torch.all(torch.eq(expected, l_out_))) def test_fori_loop_tpu_addition(self): From b2b246f3804c435ae98962118a8909c9d70a7bd8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:09:39 +0000 Subject: [PATCH 093/546] update --- ...while_loop_simple_add_dispatch_in_torch.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 8c78ee4a8e1..762f6960e58 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -145,6 +145,25 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return while_loop(cond_fn, body_fn, (iter, x)) return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + simple_with_linear = SimpleWithLinear() + upper = torch.tensor([52], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) # x + l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + weight_0 = linear_0.weight + bias_0 = linear_0.bias + + return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} + +# x = torch.zeros(1) +# y = torch.zeros(1) +# z = torch.zeros(1) +# return {"simple_with_linear": (simple_with_linear, (torch.tensor(3), torch.randn(2, 2)))} + # xm.mark_step() # device = xm.xla_device() # torch.set_grad_enabled(False) From 2664e23a2f5053c34cea48eb9b217f4ca1b24a79 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:12:58 +0000 Subject: [PATCH 094/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 762f6960e58..9ef4132acd1 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -157,7 +157,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): weight_0 = linear_0.weight bias_0 = linear_0.bias + # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} + simple_with_linear.apply(upper, lower, one_value, init_val, l_in_0, output_value) # x = torch.zeros(1) # y = torch.zeros(1) From 4f28a541e421aa4fa4ffade31b4c44e4f8a190ca Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:13:31 +0000 Subject: [PATCH 095/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 9ef4132acd1..f79c801f8e7 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -159,7 +159,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} - simple_with_linear.apply(upper, lower, one_value, init_val, l_in_0, output_value) + res = simple_with_linear.apply(upper, lower, one_value, init_val, l_in_0, output_value) + print("res: ", res) # x = torch.zeros(1) # y = torch.zeros(1) From 5c993afa7292577a77baf9b73af80ee24428633c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:14:35 +0000 Subject: [PATCH 096/546] update --- ...est_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index f79c801f8e7..37248be031a 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -161,6 +161,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} res = simple_with_linear.apply(upper, lower, one_value, init_val, l_in_0, output_value) print("res: ", res) + import pdb; pdb.set_trace() # x = torch.zeros(1) # y = torch.zeros(1) From 8298841bf8455c8cd58ed2a102e3879a86695326 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:15:44 +0000 Subject: [PATCH 097/546] update --- ...while_loop_simple_add_dispatch_in_torch.py | 113 +++++++----------- 1 file changed, 41 insertions(+), 72 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 37248be031a..bec06748415 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -87,39 +87,6 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) -# ////// -# class SimpleWithLinear(torch.nn.Module): -# def __init__(self): -# super().__init__() -# self.linear = torch.nn.Linear(2, 2) -# self.register_buffer("dec", torch.tensor(1)) - -# def forward(self, iter, x): -# def cond_fn(it, x): -# return it - self.dec > 0 - -# def body_fn(it, x): -# return it - 1, self.linear(x) -# return while_loop(cond_fn, body_fn, (iter, x)) - -# class NestedWithLinear(torch.nn.Module): -# return while_loop(cond_fn, body_fn, (iter, x)) - -# nested2 = Nested() -# simple_with_linear = SimpleWithLinear() -# nested_with_linear = NestedWithLinear() - -# x = torch.zeros(1) -# y = torch.zeros(1) -# z = torch.zeros(1) -# return {"simple": (simple, (x,)), -# "nested": (nested, (x, y, z)), -# "nested2": (nested2, (torch.tensor(2), torch.tensor(2), torch.ones(2, 2), torch.ones(2, 2))), -# "simple_with_mutation": (simple_with_mutation, (x,)), -# "simple_with_linear": (simple_with_linear, (torch.tensor(3), torch.randn(2, 2))), -# "nested_with_linear": (nested_with_linear, (torch.tensor(3), torch.randn(2, 2)))} -# ////// - def test_while_loop_tpu_simple_linear(self): xm.mark_step() @@ -163,45 +130,6 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): print("res: ", res) import pdb; pdb.set_trace() -# x = torch.zeros(1) -# y = torch.zeros(1) -# z = torch.zeros(1) -# return {"simple_with_linear": (simple_with_linear, (torch.tensor(3), torch.randn(2, 2)))} - - # xm.mark_step() - # device = xm.xla_device() - # torch.set_grad_enabled(False) - - # upper = torch.tensor([52], dtype=torch.int32, device=device) - # lower = torch.tensor([0], dtype=torch.int32, device=device) - # one_value = torch.tensor([1], dtype=torch.int32, device=device) - # init_val = torch.tensor([1], dtype=torch.int32, device=device) # x - # l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value - # output_value = torch.zeros([20], dtype=torch.float32, device=device) - - # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - # weight_0 = linear_0.weight - # bias_0 = linear_0.bias - - # # def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): - # def cond_fn(upper, lower, one_value, x, input_value, output_value): - # return lower[0] < upper[0] - - # def body_fn(upper, lower, one_value, x, input_value, output_value): - # new_lower = torch.add(one_value, lower) - # output_value = linear_0(input_value) - # weight = linear_0.weight - # bias = linear_0.bias - # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value - - # # print("!!! arrive here !!!") - # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) - - # expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) - - # self.assertTrue(torch.all(torch.eq(expected, l_out_))) - - def test_fori_loop_tpu_addition(self): xm.mark_step() @@ -242,3 +170,44 @@ def test_fori_loop_tpu_simple_linear(self): if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) + + +######## --------------------------------------------------------- + +# x = torch.zeros(1) +# y = torch.zeros(1) +# z = torch.zeros(1) +# return {"simple_with_linear": (simple_with_linear, (torch.tensor(3), torch.randn(2, 2)))} + + # xm.mark_step() + # device = xm.xla_device() + # torch.set_grad_enabled(False) + + # upper = torch.tensor([52], dtype=torch.int32, device=device) + # lower = torch.tensor([0], dtype=torch.int32, device=device) + # one_value = torch.tensor([1], dtype=torch.int32, device=device) + # init_val = torch.tensor([1], dtype=torch.int32, device=device) # x + # l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value + # output_value = torch.zeros([20], dtype=torch.float32, device=device) + + # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + # weight_0 = linear_0.weight + # bias_0 = linear_0.bias + + # # def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): + # def cond_fn(upper, lower, one_value, x, input_value, output_value): + # return lower[0] < upper[0] + + # def body_fn(upper, lower, one_value, x, input_value, output_value): + # new_lower = torch.add(one_value, lower) + # output_value = linear_0(input_value) + # weight = linear_0.weight + # bias = linear_0.bias + # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value + + # # print("!!! arrive here !!!") + # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + + # expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + + # self.assertTrue(torch.all(torch.eq(expected, l_out_))) From 32af47b623529f7e0c6021cf92aeb818a8716b18 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:16:23 +0000 Subject: [PATCH 098/546] update --- ...est_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index bec06748415..beb31887e5c 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -92,6 +92,7 @@ def test_while_loop_tpu_simple_linear(self): xm.mark_step() device = xm.xla_device() torch.set_grad_enabled(False) + print("start test !!!") class SimpleWithLinear(torch.nn.Module): def __init__(self): From 448de12d7a5ec519bf9474fbc5dc1dc57e30ee10 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:16:49 +0000 Subject: [PATCH 099/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index beb31887e5c..26438e6bf1e 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -92,7 +92,6 @@ def test_while_loop_tpu_simple_linear(self): xm.mark_step() device = xm.xla_device() torch.set_grad_enabled(False) - print("start test !!!") class SimpleWithLinear(torch.nn.Module): def __init__(self): @@ -113,6 +112,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return while_loop(cond_fn, body_fn, (iter, x)) return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + print("start test !!!") simple_with_linear = SimpleWithLinear() upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From abe4c788d7ce7ca396f0a77ce2e8b0116143e2d7 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:19:36 +0000 Subject: [PATCH 100/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 26438e6bf1e..5bd827626cb 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -92,6 +92,7 @@ def test_while_loop_tpu_simple_linear(self): xm.mark_step() device = xm.xla_device() torch.set_grad_enabled(False) + print("start test 1 !!!") class SimpleWithLinear(torch.nn.Module): def __init__(self): @@ -112,7 +113,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return while_loop(cond_fn, body_fn, (iter, x)) return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) - print("start test !!!") + print("start test 2 !!!") simple_with_linear = SimpleWithLinear() upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From 1bf37d8322fa8aa19d5a635e3fa656dd574f5479 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:19:57 +0000 Subject: [PATCH 101/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 5bd827626cb..d2ad777b869 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -122,6 +122,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) + print("start test 3 !!!") + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) weight_0 = linear_0.weight bias_0 = linear_0.bias From 20215378510d3c65dd04d2679b088517cc0ebeba Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:20:27 +0000 Subject: [PATCH 102/546] update --- ...fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index d2ad777b869..12abf30392f 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -128,8 +128,11 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): weight_0 = linear_0.weight bias_0 = linear_0.bias + print("start test 4 !!!") + + # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} - return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} + print("start test 5 !!!") res = simple_with_linear.apply(upper, lower, one_value, init_val, l_in_0, output_value) print("res: ", res) import pdb; pdb.set_trace() From 83c2f97b2712f9c2d7009afd98246f56e6386456 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:20:57 +0000 Subject: [PATCH 103/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 12abf30392f..596a488cf60 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -135,7 +135,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): print("start test 5 !!!") res = simple_with_linear.apply(upper, lower, one_value, init_val, l_in_0, output_value) print("res: ", res) + print("start test 6 !!!") import pdb; pdb.set_trace() + return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} def test_fori_loop_tpu_addition(self): From 1945952aee79ac87e783cd2a805135741ecfc11e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:22:55 +0000 Subject: [PATCH 104/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 596a488cf60..ab480592e76 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -129,11 +129,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): bias_0 = linear_0.bias print("start test 4 !!!") - # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} print("start test 5 !!!") - res = simple_with_linear.apply(upper, lower, one_value, init_val, l_in_0, output_value) + res = simple_with_linear.apply((upper, lower, one_value, init_val, l_in_0, output_value)) print("res: ", res) print("start test 6 !!!") import pdb; pdb.set_trace() From 6ee5400dd8f7c3c8256ff2045534406f0a2b4cc7 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:25:05 +0000 Subject: [PATCH 105/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index ab480592e76..cb148b60346 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -132,9 +132,11 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} print("start test 5 !!!") + aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} + print("aaa: ", aaa) + print("start test 6 !!!") res = simple_with_linear.apply((upper, lower, one_value, init_val, l_in_0, output_value)) print("res: ", res) - print("start test 6 !!!") import pdb; pdb.set_trace() return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} From 3db127e8bb9b2ce33bf280c87549b1d51870812c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:25:54 +0000 Subject: [PATCH 106/546] update --- ..._loop_with_while_loop_simple_add_dispatch_in_torch.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index cb148b60346..be1310f914d 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -135,10 +135,11 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} print("aaa: ", aaa) print("start test 6 !!!") - res = simple_with_linear.apply((upper, lower, one_value, init_val, l_in_0, output_value)) - print("res: ", res) - import pdb; pdb.set_trace() - return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} + return aaa + # res = simple_with_linear.apply((upper, lower, one_value, init_val, l_in_0, output_value)) + # print("res: ", res) + # import pdb; pdb.set_trace() + # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} def test_fori_loop_tpu_addition(self): From dc1837d999b2a61cb1156d8b4171aa0c4c5f6362 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:30:03 +0000 Subject: [PATCH 107/546] update --- ...ith_while_loop_simple_add_dispatch_in_torch.py | 15 ++++++--------- torch_xla/experimental/fori_loop.py | 1 + 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index be1310f914d..3c33a118a75 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -92,7 +92,6 @@ def test_while_loop_tpu_simple_linear(self): xm.mark_step() device = xm.xla_device() torch.set_grad_enabled(False) - print("start test 1 !!!") class SimpleWithLinear(torch.nn.Module): def __init__(self): @@ -113,7 +112,6 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return while_loop(cond_fn, body_fn, (iter, x)) return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) - print("start test 2 !!!") simple_with_linear = SimpleWithLinear() upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) @@ -122,20 +120,19 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) - print("start test 3 !!!") - linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) weight_0 = linear_0.weight bias_0 = linear_0.bias - print("start test 4 !!!") - # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} - # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} - print("start test 5 !!!") aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = aaa print("aaa: ", aaa) - print("start test 6 !!!") + # print("start test 6 !!!") return aaa + + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + + self.assertTrue(torch.all(torch.eq(expected, l_out_))) # res = simple_with_linear.apply((upper, lower, one_value, init_val, l_in_0, output_value)) # print("res: ", res) # import pdb; pdb.set_trace() diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index d9b399d2dfb..70b4388ab9c 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -52,6 +52,7 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): print("!!! arrive here too too !!!") + import pdb; pdb.set_trace() # untuple carried_inputs from while_loop carried_inputs = carried_inputs[0] # fake carried_inputs to split formal code From ea9bf565ad47e9d7014e3c8660573de5d90f77a4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:30:34 +0000 Subject: [PATCH 108/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 3c33a118a75..563cfd39adf 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -125,7 +125,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): bias_0 = linear_0.bias aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = aaa + # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = aaa print("aaa: ", aaa) # print("start test 6 !!!") return aaa From ee05a67db6108568ea831a89f5cb84fcc3dadc1d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 21:31:40 +0000 Subject: [PATCH 109/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 563cfd39adf..138e1b18b46 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -117,7 +117,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) # x - l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value + # l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value + l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) From 67f26521d10af7f4229cfb531689e821ff709327 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 22:43:46 +0000 Subject: [PATCH 110/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 138e1b18b46..44948605d9c 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -128,6 +128,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = aaa print("aaa: ", aaa) + bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) + print("bbb: ", bbb) # print("start test 6 !!!") return aaa From 3bba9fd7dfff560b07df023788f427f31a775286 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 22:45:40 +0000 Subject: [PATCH 111/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 70b4388ab9c..1e6d2dd732a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -87,7 +87,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): cond_hlo) # generate body_fn xlacomputation - body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-1]) + body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-1], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-2]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") additional_inputs_list_body = [fake_carried_inputs[-2]] # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor From 2e349a601de3dbc1bba418dc5a95e9242df98080 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 22:52:23 +0000 Subject: [PATCH 112/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 44948605d9c..b04d9dd3ef6 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -105,10 +105,10 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) - output_value = linear_0(input_value) + output_value_real = linear_0(input_value) weight = linear_0.weight bias = linear_0.bias - return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value + return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real # return while_loop(cond_fn, body_fn, (iter, x)) return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) From 7b63eb845b858d45651eeff1334b3db32f4f8a0c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 22:53:48 +0000 Subject: [PATCH 113/546] update --- ...est_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index b04d9dd3ef6..8b4804d5e31 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -105,6 +105,7 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) + output_value_real = output_value.copy() output_value_real = linear_0(input_value) weight = linear_0.weight bias = linear_0.bias From 62855cc33940a6ba8df45cce62e12f27592ccc00 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 22:54:25 +0000 Subject: [PATCH 114/546] update --- ...est_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 8b4804d5e31..b04d9dd3ef6 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -105,7 +105,6 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) - output_value_real = output_value.copy() output_value_real = linear_0(input_value) weight = linear_0.weight bias = linear_0.bias From 70a29c4f7a3064155b30516dd780f8f0a40b839b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:02:23 +0000 Subject: [PATCH 115/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index b04d9dd3ef6..b1ad2095c3d 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -7,6 +7,7 @@ # We need to import the underlying implementation function to register with the dispatcher import torch_xla.experimental.fori_loop from torch_xla.experimental.fori_loop import fori_loop +from torch_xla.experimental.fori_loop import _xla_while_loop from torch._higher_order_ops.while_loop import while_loop import torch_xla.core.xla_model as xm import torch_xla.core.xla_builder as xb @@ -110,7 +111,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): bias = linear_0.bias return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real # return while_loop(cond_fn, body_fn, (iter, x)) - return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() upper = torch.tensor([52], dtype=torch.int32, device=device) From 863ba37a872db3f70beabefb97f9f1f42236bae8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:03:04 +0000 Subject: [PATCH 116/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 1e6d2dd732a..979be6c254d 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -50,7 +50,7 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, -def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): +def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): print("!!! arrive here too too !!!") import pdb; pdb.set_trace() # untuple carried_inputs from while_loop From 4da7ede1c84c9d2c2e6fcb64d2a71599146e6d90 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:04:17 +0000 Subject: [PATCH 117/546] update --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 979be6c254d..2de46cdd938 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -50,9 +50,9 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, -def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): +def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): print("!!! arrive here too too !!!") - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() # untuple carried_inputs from while_loop carried_inputs = carried_inputs[0] # fake carried_inputs to split formal code From 261ec24bab6a0e6e954eecc42d3bbc423931132f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:06:06 +0000 Subject: [PATCH 118/546] update --- torch_xla/experimental/fori_loop.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 2de46cdd938..a11af64db3a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -64,6 +64,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) # fake_carried_inputs = tuple(fake_carried_inputs) + print("fake_carried_inputs first: ", fake_carried_inputs) for additional_input in additional_inputs: device = additional_input.device #TODO(@manfei) type = carried_input.type @@ -71,10 +72,11 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) fake_carried_inputs = tuple(fake_carried_inputs) + print("fake_carried_inputs second: ", fake_carried_inputs) # generate cond_fn xlacomputation # TODO(@manfei): specify which element is for which argument like a,b,c - cond_result = cond_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-1]) + cond_result = cond_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-1]) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor @@ -87,7 +89,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): cond_hlo) # generate body_fn xlacomputation - body_result = body_fn(*fake_carried_inputs[:-3], weight_0=fake_carried_inputs[-1], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-2]) + body_result = body_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-1], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-2]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") additional_inputs_list_body = [fake_carried_inputs[-2]] # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor From 334036ed6778841935683c208a29ad7559507add Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:09:01 +0000 Subject: [PATCH 119/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- torch_xla/experimental/fori_loop.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index b1ad2095c3d..4e6f9585dd3 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -129,7 +129,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = aaa - print("aaa: ", aaa) + # print("aaa: ", aaa) bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) print("bbb: ", bbb) # print("start test 6 !!!") diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index a11af64db3a..4396d25f184 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -52,6 +52,8 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): print("!!! arrive here too too !!!") + print("carried_inputs: ", carried_inputs) + print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() # untuple carried_inputs from while_loop carried_inputs = carried_inputs[0] From acc2e205f131fc6b6e7392cbb0c0f470fe2780c0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:10:54 +0000 Subject: [PATCH 120/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 4e6f9585dd3..c3a8a05bd3b 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -112,7 +112,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) - return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + return 1 + # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() upper = torch.tensor([52], dtype=torch.int32, device=device) From 3e7172c80b1b992a31e5f02378433b7cd9f1df34 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:11:44 +0000 Subject: [PATCH 121/546] update --- ...fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index c3a8a05bd3b..cc72cfa995c 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -7,7 +7,7 @@ # We need to import the underlying implementation function to register with the dispatcher import torch_xla.experimental.fori_loop from torch_xla.experimental.fori_loop import fori_loop -from torch_xla.experimental.fori_loop import _xla_while_loop +# from torch_xla.experimental.fori_loop import _xla_while_loop from torch._higher_order_ops.while_loop import while_loop import torch_xla.core.xla_model as xm import torch_xla.core.xla_builder as xb @@ -112,7 +112,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) - return 1 + # return 1 + return upper, lower, one_value, x, input_value, output_value # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From ead598174e2180a28a4bebab7e6180195bfbdd35 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:12:54 +0000 Subject: [PATCH 122/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index cc72cfa995c..c12bcb48120 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -113,7 +113,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 - return upper, lower, one_value, x, input_value, output_value + # return upper, lower, one_value, x, input_value, output_value + return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From 4762c554542215060c0729a3dfbb714e5f02fdda Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:13:55 +0000 Subject: [PATCH 123/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index c12bcb48120..0041e2f8d0a 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -113,8 +113,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 - # return upper, lower, one_value, x, input_value, output_value - return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) + return upper, lower, one_value, x, input_value, output_value + # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From e4ff32fadbdf067d4734a0e92abad51ddd5b9038 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:14:41 +0000 Subject: [PATCH 124/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 0041e2f8d0a..5f56aaa7a9f 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -109,7 +109,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value_real = linear_0(input_value) weight = linear_0.weight bias = linear_0.bias - return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real + # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real + return upper, new_lower, one_value, torch.add(one_value, x), input_value, output_value_real # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 From 7215162b92e088a584c31800646329c5baf9182a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:15:37 +0000 Subject: [PATCH 125/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 5f56aaa7a9f..46b3d8aa25f 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -114,8 +114,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 - return upper, lower, one_value, x, input_value, output_value - # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) + # return upper, lower, one_value, x, input_value, output_value + return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From f275c0ee26805bb19ec2bb941d2b2f612668e200 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:16:07 +0000 Subject: [PATCH 126/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 46b3d8aa25f..5f56aaa7a9f 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -114,8 +114,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 - # return upper, lower, one_value, x, input_value, output_value - return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) + return upper, lower, one_value, x, input_value, output_value + # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From 2958d830163c113afcd53da8a1f14278355f9faa Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:16:53 +0000 Subject: [PATCH 127/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 5f56aaa7a9f..46b3d8aa25f 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -114,8 +114,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 - return upper, lower, one_value, x, input_value, output_value - # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) + # return upper, lower, one_value, x, input_value, output_value + return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From 9db53a3cb8ea4b4120f6ceddbadbdcf4435bc594 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:35:04 +0000 Subject: [PATCH 128/546] update --- ...fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 46b3d8aa25f..db2230ad85f 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -109,8 +109,11 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value_real = linear_0(input_value) weight = linear_0.weight bias = linear_0.bias + new_upper = upper + new_one_value = one_value + new_input_value = input_value # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real - return upper, new_lower, one_value, torch.add(one_value, x), input_value, output_value_real + return new_upper, new_lower, new_one_value, torch.add(one_value, x), new_input_value, output_value_real # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 From d35e48ef4fbf85d03ec86e31e4b855a371ee229f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:36:11 +0000 Subject: [PATCH 129/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index db2230ad85f..666ee191109 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -113,7 +113,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): new_one_value = one_value new_input_value = input_value # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real - return new_upper, new_lower, new_one_value, torch.add(one_value, x), new_input_value, output_value_real + return upper.copy(), lower.copy(), one_value.copy(), torch.add(one_value, x), input_value.copy(), output_value_real # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 From 25ac04ed98217144733c3bc4ab3ac23b6f974f3f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:37:47 +0000 Subject: [PATCH 130/546] update --- ...i_loop_with_while_loop_simple_add_dispatch_in_torch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 666ee191109..19961d30e90 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -109,11 +109,11 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value_real = linear_0(input_value) weight = linear_0.weight bias = linear_0.bias - new_upper = upper - new_one_value = one_value - new_input_value = input_value + # new_upper = upper + # new_one_value = one_value + # new_input_value = input_value # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real - return upper.copy(), lower.copy(), one_value.copy(), torch.add(one_value, x), input_value.copy(), output_value_real + return upper.clone(), lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 From 3f473ad15ad38ebd05dcf8aff1656a03c5afbf30 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:42:21 +0000 Subject: [PATCH 131/546] update --- torch_xla/experimental/fori_loop.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 4396d25f184..cc15991f7a6 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -52,11 +52,13 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): print("!!! arrive here too too !!!") - print("carried_inputs: ", carried_inputs) - print("additional_inputs: ", additional_inputs) + # print("carried_inputs: ", carried_inputs) + # print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() # untuple carried_inputs from while_loop carried_inputs = carried_inputs[0] + # TODO(@manfei): please clear pass additional_inputs in `while_loop`'s defination in this file + additional_inputs = carried_inputs[0] # fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: From b31e28062a05094110d824be21de677ccab89a39 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:44:07 +0000 Subject: [PATCH 132/546] update --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index cc15991f7a6..af014058009 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -52,8 +52,8 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): print("!!! arrive here too too !!!") - # print("carried_inputs: ", carried_inputs) - # print("additional_inputs: ", additional_inputs) + print("carried_inputs: ", carried_inputs) + print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() # untuple carried_inputs from while_loop carried_inputs = carried_inputs[0] From 845d903893ecb04385e17652117ef0fd7795ca20 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:45:32 +0000 Subject: [PATCH 133/546] update --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index af014058009..0da353ba660 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -50,15 +50,15 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, -def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): +def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_inputs=()): print("!!! arrive here too too !!!") print("carried_inputs: ", carried_inputs) print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() # untuple carried_inputs from while_loop - carried_inputs = carried_inputs[0] + carried_inputs = original_carried_inputs[0] # TODO(@manfei): please clear pass additional_inputs in `while_loop`'s defination in this file - additional_inputs = carried_inputs[0] + additional_inputs = original_carried_inputs[1] # fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: From 13157d9b1158a7dcfb67d8e3df638ae1745cffdd Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:46:11 +0000 Subject: [PATCH 134/546] update --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0da353ba660..def0a5a0e38 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -52,8 +52,8 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_inputs=()): print("!!! arrive here too too !!!") - print("carried_inputs: ", carried_inputs) - print("additional_inputs: ", additional_inputs) + # print("carried_inputs: ", carried_inputs) + # print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() # untuple carried_inputs from while_loop carried_inputs = original_carried_inputs[0] From 41d3ba6d11491e9bd09377880d84014ccab0b536 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:48:12 +0000 Subject: [PATCH 135/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- torch_xla/experimental/fori_loop.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 19961d30e90..d5a190279b1 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -113,7 +113,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # new_one_value = one_value # new_input_value = input_value # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real - return upper.clone(), lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real + return upper.clone(), lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight, bias, output_value_real # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index def0a5a0e38..cff55d207bb 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -68,7 +68,7 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) # fake_carried_inputs = tuple(fake_carried_inputs) - print("fake_carried_inputs first: ", fake_carried_inputs) + # print("fake_carried_inputs first: ", fake_carried_inputs) for additional_input in additional_inputs: device = additional_input.device #TODO(@manfei) type = carried_input.type @@ -76,7 +76,7 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) fake_carried_inputs = tuple(fake_carried_inputs) - print("fake_carried_inputs second: ", fake_carried_inputs) + # print("fake_carried_inputs second: ", fake_carried_inputs) # generate cond_fn xlacomputation # TODO(@manfei): specify which element is for which argument like a,b,c From 52cd5bc9a6a7cc1bfae412b674c60561a9cf3887 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:49:08 +0000 Subject: [PATCH 136/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index d5a190279b1..ecfe7cd16f9 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -113,7 +113,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # new_one_value = one_value # new_input_value = input_value # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real - return upper.clone(), lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight, bias, output_value_real + return upper.clone(), lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight.clone(), bias.clone(), output_value_real # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 From a3756e6ef21ef12773bcfc331016b1fdf80226d8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:50:48 +0000 Subject: [PATCH 137/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index ecfe7cd16f9..bf015598af4 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -137,7 +137,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = aaa # print("aaa: ", aaa) - bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) + # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) + bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) print("bbb: ", bbb) # print("start test 6 !!!") return aaa From 21ba83dbff3d6987c91e624523cf895da80abca9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:52:54 +0000 Subject: [PATCH 138/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index bf015598af4..3119a992ac3 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,7 +100,7 @@ def __init__(self): self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.register_buffer("dec", torch.tensor(1)) - def forward(self, upper, lower, one_value, x, input_value, output_value): + def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] From 24f2fe3a3c89faa5ebfa86efc878d348a323c027 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:54:19 +0000 Subject: [PATCH 139/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 3119a992ac3..439c79b3393 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -118,7 +118,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 # return upper, lower, one_value, x, input_value, output_value - return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) + return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From 4687bbeee679def69b6790fe3f6c49b2d827f9a2 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:55:47 +0000 Subject: [PATCH 140/546] update --- torch_xla/experimental/fori_loop.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index cff55d207bb..59c925c83ed 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -69,14 +69,14 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input dtype=carried_input.dtype).to(device)) # fake_carried_inputs = tuple(fake_carried_inputs) # print("fake_carried_inputs first: ", fake_carried_inputs) - for additional_input in additional_inputs: - device = additional_input.device - #TODO(@manfei) type = carried_input.type - fake_carried_inputs.append( - torch.randint(10, additional_input.size(), - dtype=additional_input.dtype).to(device)) - fake_carried_inputs = tuple(fake_carried_inputs) - # print("fake_carried_inputs second: ", fake_carried_inputs) + # for additional_input in additional_inputs: + # device = additional_input.device + # #TODO(@manfei) type = carried_input.type + # fake_carried_inputs.append( + # torch.randint(10, additional_input.size(), + # dtype=additional_input.dtype).to(device)) + # fake_carried_inputs = tuple(fake_carried_inputs) + # # print("fake_carried_inputs second: ", fake_carried_inputs) # generate cond_fn xlacomputation # TODO(@manfei): specify which element is for which argument like a,b,c From 8713b1c4ba79e889e00fb56eee51445f73cae70b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:56:36 +0000 Subject: [PATCH 141/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 439c79b3393..3ae1e5847ed 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -101,10 +101,10 @@ def __init__(self): # self.register_buffer("dec", torch.tensor(1)) def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - def cond_fn(upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, output_value): + def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): new_lower = torch.add(one_value, lower) output_value_real = linear_0(input_value) weight = linear_0.weight From 05f1c33034020b678e8fb3056773ba12ed5331bb Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 12 Apr 2024 23:57:49 +0000 Subject: [PATCH 142/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 3ae1e5847ed..439c79b3393 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -101,10 +101,10 @@ def __init__(self): # self.register_buffer("dec", torch.tensor(1)) def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = linear_0(input_value) weight = linear_0.weight From d96d7355bcf681b5a2ecb137dbad700f7d08a6c0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Sat, 13 Apr 2024 00:02:52 +0000 Subject: [PATCH 143/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 439c79b3393..3ae1e5847ed 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -101,10 +101,10 @@ def __init__(self): # self.register_buffer("dec", torch.tensor(1)) def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - def cond_fn(upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, output_value): + def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): new_lower = torch.add(one_value, lower) output_value_real = linear_0(input_value) weight = linear_0.weight From e12bda67c7f8447ec259e5b6494418c05e946348 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Sat, 13 Apr 2024 00:04:29 +0000 Subject: [PATCH 144/546] update --- ...fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 3ae1e5847ed..608460153a4 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,7 +100,10 @@ def __init__(self): self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.register_buffer("dec", torch.tensor(1)) - def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def forward(self, upper, lower, one_value, x, input_value, output_value): + weight_0 = linear_0.weight + bias_0 = linear_0.bias + def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): return lower[0] < upper[0] From 8ed8f12ec102a3d61d0d6ab8914bcc9486fd2181 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Sat, 13 Apr 2024 00:05:18 +0000 Subject: [PATCH 145/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 608460153a4..88cec10543d 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,10 +100,9 @@ def __init__(self): self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.register_buffer("dec", torch.tensor(1)) - def forward(self, upper, lower, one_value, x, input_value, output_value): + def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): weight_0 = linear_0.weight bias_0 = linear_0.bias - def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): return lower[0] < upper[0] From 094a4cafc2f597652106a2098e55abee9c0f5e39 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Sat, 13 Apr 2024 00:05:43 +0000 Subject: [PATCH 146/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 88cec10543d..51a2cae47b2 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -101,8 +101,8 @@ def __init__(self): # self.register_buffer("dec", torch.tensor(1)) def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - weight_0 = linear_0.weight - bias_0 = linear_0.bias + weight_0 = self.linear.weight + bias_0 = self.linear.bias def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): return lower[0] < upper[0] From 529a8c824f43a72c9f5e035cae512cff88b9c0ca Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Sat, 13 Apr 2024 00:10:15 +0000 Subject: [PATCH 147/546] update --- ...ori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 51a2cae47b2..0e17da9d940 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -101,8 +101,8 @@ def __init__(self): # self.register_buffer("dec", torch.tensor(1)) def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - weight_0 = self.linear.weight - bias_0 = self.linear.bias + weight_1 = self.linear.weight + bias_1 = self.linear.bias def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): return lower[0] < upper[0] @@ -120,7 +120,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 # return upper, lower, one_value, x, input_value, output_value - return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) + return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value), (weight_1, bias_1)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From 57083aef6460ddd81f771be54abeb2fde0b1d25a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Sat, 13 Apr 2024 00:10:59 +0000 Subject: [PATCH 148/546] update --- ...i_loop_with_while_loop_simple_add_dispatch_in_torch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 0e17da9d940..a579d109ed6 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -108,9 +108,9 @@ def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): new_lower = torch.add(one_value, lower) - output_value_real = linear_0(input_value) - weight = linear_0.weight - bias = linear_0.bias + output_value_real = self.linear(input_value) + weight = self.linear.weight + bias = self.linear.bias # new_upper = upper # new_one_value = one_value # new_input_value = input_value @@ -120,7 +120,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 # return upper, lower, one_value, x, input_value, output_value - return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value), (weight_1, bias_1)) + return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From 630caa9b88a38e677cce523b332821817b999172 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Sat, 13 Apr 2024 00:12:54 +0000 Subject: [PATCH 149/546] update --- ...i_loop_with_while_loop_simple_add_dispatch_in_torch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index a579d109ed6..182f5fd6985 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -132,15 +132,15 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) - linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - weight_0 = linear_0.weight - bias_0 = linear_0.bias + # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + weight_0 = simple_with_linear.linear.weight + bias_0 = simple_with_linear.linear.bias aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = aaa # print("aaa: ", aaa) # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) - bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) + bbb = simple_with_linear((upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value), (weight_0, bias_0)) print("bbb: ", bbb) # print("start test 6 !!!") return aaa From 406635b973fed0212d4dd1f851fca581a2613be3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Sat, 13 Apr 2024 00:14:34 +0000 Subject: [PATCH 150/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 182f5fd6985..0fdcc385493 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -140,7 +140,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = aaa # print("aaa: ", aaa) # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) - bbb = simple_with_linear((upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value), (weight_0, bias_0)) + bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value, weight_0, bias_0) print("bbb: ", bbb) # print("start test 6 !!!") return aaa From 992546501e3f896dcea1a07601b80241988d4858 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Sat, 13 Apr 2024 00:15:20 +0000 Subject: [PATCH 151/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 0fdcc385493..b19e0c80243 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -140,7 +140,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = aaa # print("aaa: ", aaa) # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) - bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value, weight_0, bias_0) + bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) # , weight_0, bias_0) print("bbb: ", bbb) # print("start test 6 !!!") return aaa From 9b1b7a72ffeb26a1ebb0044e7efb724c55bc4209 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:28:47 +0000 Subject: [PATCH 152/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index b19e0c80243..2d64dcf1bab 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,7 +100,8 @@ def __init__(self): self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.register_buffer("dec", torch.tensor(1)) - def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def forward(self, upper, lower, one_value, x, input_value, output_value): weight_1 = self.linear.weight bias_1 = self.linear.bias def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): From 3cf020071931568b7df7bc8ff249c977a64c10bd Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:31:32 +0000 Subject: [PATCH 153/546] update --- ...ri_loop_with_while_loop_simple_add_dispatch_in_torch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 2d64dcf1bab..40c81de463c 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,8 +100,8 @@ def __init__(self): self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.register_buffer("dec", torch.tensor(1)) - # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - def forward(self, upper, lower, one_value, x, input_value, output_value): + def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def forward(self, upper, lower, one_value, x, input_value, output_value): weight_1 = self.linear.weight bias_1 = self.linear.bias def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): @@ -141,7 +141,8 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = aaa # print("aaa: ", aaa) # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) - bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) # , weight_0, bias_0) + # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) # , weight_0, bias_0) + bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) print("bbb: ", bbb) # print("start test 6 !!!") return aaa From 1b06687e1fb551725067aa729394fdfd3831fff7 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:32:15 +0000 Subject: [PATCH 154/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 40c81de463c..f663f5fde04 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -142,7 +142,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va # print("aaa: ", aaa) # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) # , weight_0, bias_0) - bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) + bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) print("bbb: ", bbb) # print("start test 6 !!!") return aaa From 13b77670d9e3df788ea01d561e4fcd9d864c973a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:35:35 +0000 Subject: [PATCH 155/546] update --- ...th_while_loop_simple_add_dispatch_in_torch.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index f663f5fde04..63bf159f0c8 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,14 +100,16 @@ def __init__(self): self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.register_buffer("dec", torch.tensor(1)) - def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - # def forward(self, upper, lower, one_value, x, input_value, output_value): + # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def forward(self, upper, lower, one_value, x, input_value, output_value): weight_1 = self.linear.weight bias_1 = self.linear.bias - def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) weight = self.linear.weight @@ -121,7 +123,8 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 # return upper, lower, one_value, x, input_value, output_value - return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) + # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) + return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_1, bias_1, output_value)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() @@ -142,7 +145,8 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va # print("aaa: ", aaa) # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) # , weight_0, bias_0) - bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) + # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) + bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) print("bbb: ", bbb) # print("start test 6 !!!") return aaa From 90f6df1db6c43025994bce2bf01356370736b65d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:36:36 +0000 Subject: [PATCH 156/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 63bf159f0c8..fa69205cf75 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -124,8 +124,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return 1 # return upper, lower, one_value, x, input_value, output_value # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) - return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_1, bias_1, output_value)) - # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_1, bias_1, output_value)) + return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() upper = torch.tensor([52], dtype=torch.int32, device=device) From 732876b7f78015781359dc148a6fef0c97734760 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:37:04 +0000 Subject: [PATCH 157/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index fa69205cf75..94a4e753f10 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -7,7 +7,7 @@ # We need to import the underlying implementation function to register with the dispatcher import torch_xla.experimental.fori_loop from torch_xla.experimental.fori_loop import fori_loop -# from torch_xla.experimental.fori_loop import _xla_while_loop +from torch_xla.experimental.fori_loop import _xla_while_loop from torch._higher_order_ops.while_loop import while_loop import torch_xla.core.xla_model as xm import torch_xla.core.xla_builder as xb From 1ec25fef85e3267f2530923e9ce9582a3d866947 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:38:11 +0000 Subject: [PATCH 158/546] update --- ...ori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 6 +++--- torch_xla/experimental/fori_loop.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 94a4e753f10..63bf159f0c8 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -7,7 +7,7 @@ # We need to import the underlying implementation function to register with the dispatcher import torch_xla.experimental.fori_loop from torch_xla.experimental.fori_loop import fori_loop -from torch_xla.experimental.fori_loop import _xla_while_loop +# from torch_xla.experimental.fori_loop import _xla_while_loop from torch._higher_order_ops.while_loop import while_loop import torch_xla.core.xla_model as xm import torch_xla.core.xla_builder as xb @@ -124,8 +124,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return 1 # return upper, lower, one_value, x, input_value, output_value # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) - # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_1, bias_1, output_value)) - return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_1, bias_1, output_value)) + # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() upper = torch.tensor([52], dtype=torch.int32, device=device) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 59c925c83ed..7d53701a444 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -58,7 +58,8 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input # untuple carried_inputs from while_loop carried_inputs = original_carried_inputs[0] # TODO(@manfei): please clear pass additional_inputs in `while_loop`'s defination in this file - additional_inputs = original_carried_inputs[1] + if len(original_carried_inputs) == 2: + additional_inputs = original_carried_inputs[1] # fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: From fdb7cabb2c77829c2b424dd4ec56215aac4ba4df Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:39:02 +0000 Subject: [PATCH 159/546] update --- ...i_loop_with_while_loop_simple_add_dispatch_in_torch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 63bf159f0c8..1ca0d4aa8fb 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -104,12 +104,12 @@ def __init__(self): def forward(self, upper, lower, one_value, x, input_value, output_value): weight_1 = self.linear.weight bias_1 = self.linear.bias - # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - def cond_fn(upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - def body_fn(upper, lower, one_value, x, input_value, output_value): + def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) weight = self.linear.weight From 17d43eff82536499381458e7bcbea00440274fa9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:41:07 +0000 Subject: [PATCH 160/546] update --- torch_xla/experimental/fori_loop.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 7d53701a444..a11697e33c5 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -79,6 +79,7 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input # fake_carried_inputs = tuple(fake_carried_inputs) # # print("fake_carried_inputs second: ", fake_carried_inputs) + print("!!! arrive here too before cond !!!") # generate cond_fn xlacomputation # TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-1]) @@ -92,7 +93,9 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) + print("!!! arrive here too after cond !!!") + print("!!! arrive here too before body !!!") # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-1], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-2]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() @@ -109,7 +112,9 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) + print("!!! arrive here too after body !!!") + print("!!! arrive here too before args!!!") # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while kwargs = {} if type(carried_inputs) is tuple: @@ -130,10 +135,13 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) + print("!!! arrive here too after args!!!") + print("!!! arrive here too before while!!!") # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', (carried_inputs), computation) + print("!!! arrive here too after while!!!") return result \ No newline at end of file From 0016420a184ca374cb32ba96c568add2cf9a4df4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:42:14 +0000 Subject: [PATCH 161/546] update --- torch_xla/experimental/fori_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index a11697e33c5..c9decb56964 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -87,8 +87,10 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input cond_ctx.set_name_string("condctx") additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor # treat and pass additional_inputs to cond_fn + print("additional_inputs_list_cond one: ", additional_inputs_list_cond) for i in range(len(additional_inputs)): additional_inputs_list_cond.append(additional_inputs[i]) + print("additional_inputs_list_cond two: ", additional_inputs_list_cond) cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", From 42745757d6d08f90622e21f7f8f9012b0311eddf Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:43:01 +0000 Subject: [PATCH 162/546] update --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index c9decb56964..048da5eb9a2 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -83,6 +83,7 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input # generate cond_fn xlacomputation # TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-1]) + print("nnn here ???") cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor From 4fd2e4a0cb6b4edd3c89bd5a2d4223138154a763 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:43:44 +0000 Subject: [PATCH 163/546] update --- torch_xla/experimental/fori_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 048da5eb9a2..0d89dfce829 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -81,9 +81,10 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input print("!!! arrive here too before cond !!!") # generate cond_fn xlacomputation + print("print fake_carried_inputs: ", fake_carried_inputs) # TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-1]) - print("nnn here ???") + # print("nnn here ???") cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor From b06290031ed720bb41fe2ae8da1817faa27c45a6 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:45:29 +0000 Subject: [PATCH 164/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 1ca0d4aa8fb..93eab72bc9b 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,8 +100,8 @@ def __init__(self): self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.register_buffer("dec", torch.tensor(1)) - # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - def forward(self, upper, lower, one_value, x, input_value, output_value): + def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def forward(self, upper, lower, one_value, x, input_value, output_value): weight_1 = self.linear.weight bias_1 = self.linear.bias def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): From c2f29737be0856f08302d13367ba8e49b68b2dc7 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:46:06 +0000 Subject: [PATCH 165/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 93eab72bc9b..1ca0d4aa8fb 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,8 +100,8 @@ def __init__(self): self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.register_buffer("dec", torch.tensor(1)) - def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - # def forward(self, upper, lower, one_value, x, input_value, output_value): + # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def forward(self, upper, lower, one_value, x, input_value, output_value): weight_1 = self.linear.weight bias_1 = self.linear.bias def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): From 8dbde24acbfcd0312728b8d02a1ba7be8be5a7b3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 18:47:08 +0000 Subject: [PATCH 166/546] update --- ...ri_loop_with_while_loop_simple_add_dispatch_in_torch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 1ca0d4aa8fb..4e93f32d48b 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -104,11 +104,14 @@ def __init__(self): def forward(self, upper, lower, one_value, x, input_value, output_value): weight_1 = self.linear.weight bias_1 = self.linear.bias - def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + + # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): # def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): # def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) From 066047c56280cd1204a41c38377fa0eb3bd5ac7b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:19:07 +0000 Subject: [PATCH 167/546] update --- ..._with_while_loop_simple_add_dispatch_in_torch.py | 13 +++++++------ torch_xla/experimental/fori_loop.py | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 4e93f32d48b..7d12ab6f517 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -102,16 +102,16 @@ def __init__(self): # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): def forward(self, upper, lower, one_value, x, input_value, output_value): - weight_1 = self.linear.weight - bias_1 = self.linear.bias + # weight_1 = self.linear.weight + # bias_1 = self.linear.bias - # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): + def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): # def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): + def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): # def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) @@ -128,6 +128,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_va # return upper, lower, one_value, x, input_value, output_value # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_1, bias_1, output_value)) + # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0d89dfce829..600950f48ac 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -84,7 +84,7 @@ def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_input print("print fake_carried_inputs: ", fake_carried_inputs) # TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-1]) - # print("nnn here ???") + print("nnn here ???") cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor From 374d16d06cc541ff5f301764372d487aa74fa4ff Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:19:48 +0000 Subject: [PATCH 168/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 7d12ab6f517..6133d205242 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -127,8 +127,8 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va # return 1 # return upper, lower, one_value, x, input_value, output_value # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) - return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_1, bias_1, output_value)) - # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) + # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_1, bias_1, output_value)) + return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From 1f30e268996448e19935b8289cde2323633e3e03 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:21:34 +0000 Subject: [PATCH 169/546] update --- ...i_loop_with_while_loop_simple_add_dispatch_in_torch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 6133d205242..6162c3b921f 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -105,14 +105,14 @@ def forward(self, upper, lower, one_value, x, input_value, output_value): # weight_1 = self.linear.weight # bias_1 = self.linear.bias - def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): - # def cond_fn(upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): - # def body_fn(upper, lower, one_value, x, input_value, output_value): + def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) weight = self.linear.weight From fc85bc3e1dcf7cf80dea7714d359cf7503b2251a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:23:17 +0000 Subject: [PATCH 170/546] update --- torch_xla/experimental/fori_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 600950f48ac..fab9ba09c50 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -53,12 +53,13 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_inputs=()): print("!!! arrive here too too !!!") # print("carried_inputs: ", carried_inputs) - # print("additional_inputs: ", additional_inputs) + print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() # untuple carried_inputs from while_loop carried_inputs = original_carried_inputs[0] # TODO(@manfei): please clear pass additional_inputs in `while_loop`'s defination in this file if len(original_carried_inputs) == 2: + print("use original_carried_inputs for additional_inputs") additional_inputs = original_carried_inputs[1] # fake carried_inputs to split formal code fake_carried_inputs = [] From 6ebda73b33bb1e61279b513f67bac44f20214c45 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:24:25 +0000 Subject: [PATCH 171/546] update --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index fab9ba09c50..097f67687a3 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -45,6 +45,7 @@ def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): # cond_fn&body_fn: callable # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) print("!!! arrive here too !!!") + print("while_loop additional_inputs: ", additional_inputs) if additional_inputs is None: additional_inputs = tuple() return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, From 55c2ea11432f9c0de4fcf2f6535a2abe072780e2 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:30:32 +0000 Subject: [PATCH 172/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 097f67687a3..31bc893f1d9 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -40,7 +40,7 @@ def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi @while_loop_op.py_impl(DispatchKey.XLA) -def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): +def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') # cond_fn&body_fn: callable # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) From 36d256525fc2bd140abaaec4794e154e3f0e1f81 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:38:12 +0000 Subject: [PATCH 173/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 6162c3b921f..c8276bfb16d 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -92,7 +92,7 @@ def test_while_loop_tpu_simple_linear(self): xm.mark_step() device = xm.xla_device() - torch.set_grad_enabled(False) + # torch.set_grad_enabled(False) class SimpleWithLinear(torch.nn.Module): def __init__(self): From 4a500c019914096ad1398b13b4403b34ab575c54 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:39:39 +0000 Subject: [PATCH 174/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index c8276bfb16d..6162c3b921f 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -92,7 +92,7 @@ def test_while_loop_tpu_simple_linear(self): xm.mark_step() device = xm.xla_device() - # torch.set_grad_enabled(False) + torch.set_grad_enabled(False) class SimpleWithLinear(torch.nn.Module): def __init__(self): From 3a2553e4b9335c2b74ca117cfd500d72dac9823d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:50:00 +0000 Subject: [PATCH 175/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 6162c3b921f..152606a7d6a 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -128,7 +128,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return upper, lower, one_value, x, input_value, output_value # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_1, bias_1, output_value)) - return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) + weight_1 = self.linear.weight + bias_1 = self.linear.bias + return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value), (bias_1, weight_1)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From 0998b374e1344edb7d903be54e76ff9fb3ccb1a6 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:51:08 +0000 Subject: [PATCH 176/546] update --- ...ri_loop_with_while_loop_simple_add_dispatch_in_torch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 152606a7d6a..7fd99631943 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -128,9 +128,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return upper, lower, one_value, x, input_value, output_value # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_1, bias_1, output_value)) - weight_1 = self.linear.weight - bias_1 = self.linear.bias - return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value), (bias_1, weight_1)) + # weight_1 = self.linear.weight + # bias_1 = self.linear.bias + # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value), (bias_1, weight_1)) + return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() From 267e8b3460fe4b2e9a3779383843ec8e19c33acc Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:54:12 +0000 Subject: [PATCH 177/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 31bc893f1d9..68db30dd6fa 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -40,7 +40,7 @@ def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi @while_loop_op.py_impl(DispatchKey.XLA) -def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): +def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') # cond_fn&body_fn: callable # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) From a577859eff905e076e25a06a2b0d70101901f948 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:57:08 +0000 Subject: [PATCH 178/546] update --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 68db30dd6fa..fb8f8234b16 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -45,7 +45,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): # cond_fn&body_fn: callable # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) print("!!! arrive here too !!!") - print("while_loop additional_inputs: ", additional_inputs) + # print("while_loop additional_inputs: ", additional_inputs) if additional_inputs is None: additional_inputs = tuple() return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, @@ -53,7 +53,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_inputs=()): print("!!! arrive here too too !!!") - # print("carried_inputs: ", carried_inputs) + print("carried_inputs: ", carried_inputs) print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() # untuple carried_inputs from while_loop From a3ee72aefe21c863b4cfc052ab82a9468cc85b9e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 20:57:39 +0000 Subject: [PATCH 179/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index fb8f8234b16..8c8d47ca494 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -53,7 +53,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_inputs=()): print("!!! arrive here too too !!!") - print("carried_inputs: ", carried_inputs) + print("original_carried_inputs: ", original_carried_inputs) print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() # untuple carried_inputs from while_loop From 293b87aa9bf7729c995345e3964b3eb311c78a04 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 21:00:34 +0000 Subject: [PATCH 180/546] update --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 8c8d47ca494..aceb585ab45 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -53,11 +53,11 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_inputs=()): print("!!! arrive here too too !!!") - print("original_carried_inputs: ", original_carried_inputs) - print("additional_inputs: ", additional_inputs) + # print("original_carried_inputs: ", original_carried_inputs) + # print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() # untuple carried_inputs from while_loop - carried_inputs = original_carried_inputs[0] + # carried_inputs = original_carried_inputs[0] # due to PyTorch has already treat them , so skip split here # TODO(@manfei): please clear pass additional_inputs in `while_loop`'s defination in this file if len(original_carried_inputs) == 2: print("use original_carried_inputs for additional_inputs") From 128a3dcf6a37dd959716777fcb8edb9447556cad Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 21:02:13 +0000 Subject: [PATCH 181/546] update --- ...op_with_while_loop_simple_add_dispatch_in_torch.py | 8 ++++---- torch_xla/experimental/fori_loop.py | 11 ++++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 7fd99631943..8e5eead8fc3 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -105,14 +105,14 @@ def forward(self, upper, lower, one_value, x, input_value, output_value): # weight_1 = self.linear.weight # bias_1 = self.linear.bias - # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): - def cond_fn(upper, lower, one_value, x, input_value, output_value): + # def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): - def body_fn(upper, lower, one_value, x, input_value, output_value): + # def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) weight = self.linear.weight diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index aceb585ab45..1b874ed9db2 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -51,17 +51,18 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, -def _xla_while_loop(cond_fn, body_fn, *original_carried_inputs, additional_inputs=()): +def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): print("!!! arrive here too too !!!") # print("original_carried_inputs: ", original_carried_inputs) # print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() # untuple carried_inputs from while_loop - # carried_inputs = original_carried_inputs[0] # due to PyTorch has already treat them , so skip split here + # carried_inputs = original_carried_inputs[0] ### due to PyTorch has already treat them , so skip split here # TODO(@manfei): please clear pass additional_inputs in `while_loop`'s defination in this file - if len(original_carried_inputs) == 2: - print("use original_carried_inputs for additional_inputs") - additional_inputs = original_carried_inputs[1] + ### due to PyTorch has already treat them , so skip split here + # if len(original_carried_inputs) == 2: + # print("use original_carried_inputs for additional_inputs") + # additional_inputs = original_carried_inputs[1] # fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: From f5298a59da61247956280cc274f478d0edfc6c39 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 21:03:27 +0000 Subject: [PATCH 182/546] update --- ...i_loop_with_while_loop_simple_add_dispatch_in_torch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 8e5eead8fc3..7fd99631943 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -105,14 +105,14 @@ def forward(self, upper, lower, one_value, x, input_value, output_value): # weight_1 = self.linear.weight # bias_1 = self.linear.bias - def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): - # def cond_fn(upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): - # def body_fn(upper, lower, one_value, x, input_value, output_value): + def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) weight = self.linear.weight From e4104a4719837c6071ee2d1088af757f64ca0cf2 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 21:05:51 +0000 Subject: [PATCH 183/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 7fd99631943..a4566395afe 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -101,7 +101,8 @@ def __init__(self): # self.register_buffer("dec", torch.tensor(1)) # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - def forward(self, upper, lower, one_value, x, input_value, output_value): + # def forward(self, upper, lower, one_value, x, input_value, output_value): + def forward(self, upper, lower, one_value, x, input_value, output_value, bias_0, weight_0): # weight_1 = self.linear.weight # bias_1 = self.linear.bias From 2ef7d32a0a3465de9f0bb4f8b59026df74f0b407 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 21:07:16 +0000 Subject: [PATCH 184/546] update --- ...p_with_while_loop_simple_add_dispatch_in_torch.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index a4566395afe..0bce0779c6b 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,20 +100,20 @@ def __init__(self): self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.register_buffer("dec", torch.tensor(1)) - # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def forward(self, upper, lower, one_value, x, input_value, output_value): - def forward(self, upper, lower, one_value, x, input_value, output_value, bias_0, weight_0): + # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # weight_1 = self.linear.weight # bias_1 = self.linear.bias - # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): - def cond_fn(upper, lower, one_value, x, input_value, output_value): + # def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): - def body_fn(upper, lower, one_value, x, input_value, output_value): + # def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) weight = self.linear.weight From 4affdd72b63047ab0c799bd6bcba595d11aaa158 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 21:09:27 +0000 Subject: [PATCH 185/546] update --- ...p_with_while_loop_simple_add_dispatch_in_torch.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 0bce0779c6b..eb108d59c46 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,20 +100,20 @@ def __init__(self): self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.register_buffer("dec", torch.tensor(1)) - def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - # def forward(self, upper, lower, one_value, x, input_value, output_value): + # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def forward(self, upper, lower, one_value, x, input_value, output_value): # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # weight_1 = self.linear.weight # bias_1 = self.linear.bias - def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): - # def cond_fn(upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): - # def body_fn(upper, lower, one_value, x, input_value, output_value): + def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) weight = self.linear.weight From 325425388cdb06b62d56c13713ee6cc3b30d6238 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 21:10:58 +0000 Subject: [PATCH 186/546] update --- torch_xla/experimental/fori_loop.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 1b874ed9db2..29735ee1d38 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -72,15 +72,15 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) # fake_carried_inputs = tuple(fake_carried_inputs) - # print("fake_carried_inputs first: ", fake_carried_inputs) - # for additional_input in additional_inputs: - # device = additional_input.device - # #TODO(@manfei) type = carried_input.type - # fake_carried_inputs.append( - # torch.randint(10, additional_input.size(), - # dtype=additional_input.dtype).to(device)) + print("fake_carried_inputs first: ", fake_carried_inputs) + for additional_input in additional_inputs: + device = additional_input.device + #TODO(@manfei) type = carried_input.type + fake_carried_inputs.append( + torch.randint(10, additional_input.size(), + dtype=additional_input.dtype).to(device)) # fake_carried_inputs = tuple(fake_carried_inputs) - # # print("fake_carried_inputs second: ", fake_carried_inputs) + print("fake_carried_inputs second: ", fake_carried_inputs) print("!!! arrive here too before cond !!!") # generate cond_fn xlacomputation From 92887e70e25c00f9789fe8e567b9f4e1b97412ae Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 21:12:31 +0000 Subject: [PATCH 187/546] update --- torch_xla/csrc/init_python_bindings.cpp | 28 ++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 145739e1d1a..fa29bfdd4c8 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -925,26 +925,26 @@ class PyLoweringContext { // !!! --- next step: we add dump paras according to additional_inputs_list // ??? --- could we get IRvalue of `additional_inputs_list` in this function to complete xlacomputation? int64_t parameter_idx = 2; // parameter_idx start from 2 after upper and lower - for (auto& additional_input_tensor : additional_inputs_list) { - XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - xla::Shape shape = xtensor->shape().get(); - xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - "UnusedArgumentsPlaceholder"); - parameter_idx += 1; - } + // for (auto& additional_input_tensor : additional_inputs_list) { + // XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + // xla::Shape shape = xtensor->shape().get(); + // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + // "UnusedArgumentsPlaceholder"); + // parameter_idx += 1; + // } } // hard-code modify body xlacomputation input arguments if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameter_idx = 7; // tensors.size(); - for (auto& additional_input_tensor : additional_inputs_list) { - XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - xla::Shape shape = xtensor->shape().get(); - xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - "UnusedArgumentsPlaceholder"); - parameter_idx += 1; - } + // for (auto& additional_input_tensor : additional_inputs_list) { + // XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + // xla::Shape shape = xtensor->shape().get(); + // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + // "UnusedArgumentsPlaceholder"); + // parameter_idx += 1; + // } } // Get the backing XLA tensors from the output torch tensor handles From 98425a88746da7d5d70a26ce02a70780252159c8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 21:14:30 +0000 Subject: [PATCH 188/546] update --- torch_xla/experimental/fori_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 29735ee1d38..d7f54af3ee4 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -133,7 +133,10 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): for shape in shapes: p = xb.mkparam(builder, len(params), shape) params.append(p) + print("args params: ", params) + print("!!! arrive here too after args!!!") + print("!!! arrive here too before while!!!") # generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) w = xb.mkop( @@ -142,9 +145,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) - print("!!! arrive here too after args!!!") - print("!!! arrive here too before while!!!") # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', (carried_inputs), From 862e3f16f1259819ad7d3a322080de8ef30f5916 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 21:18:22 +0000 Subject: [PATCH 189/546] update --- torch_xla/experimental/fori_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index d7f54af3ee4..94080b31979 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -122,6 +122,8 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): print("!!! arrive here too after body !!!") print("!!! arrive here too before args!!!") + total_inputs = carried_inputs + additional_inputs + print("total_inputs: ", total_inputs) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while kwargs = {} if type(carried_inputs) is tuple: From 6b07e22d1d6eddaed524b89d79b4968c0a196815 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 21:20:13 +0000 Subject: [PATCH 190/546] update --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 94080b31979..27d6c214980 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -126,10 +126,10 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): print("total_inputs: ", total_inputs) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while kwargs = {} - if type(carried_inputs) is tuple: - shapes = xb.tensor_shape(carried_inputs) + if type(total_inputs) is tuple: + shapes = xb.tensor_shape(total_inputs) else: - shapes = xb.tensor_shape((carried_inputs)) + shapes = xb.tensor_shape((total_inputs)) builder = xb.create_builder('test_while') params = [] for shape in shapes: From 1edc7bd448559b440036620c2004b7272eeda7e8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 22:10:38 +0000 Subject: [PATCH 191/546] update --- torch_xla/csrc/init_python_bindings.cpp | 28 ++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index fa29bfdd4c8..145739e1d1a 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -925,26 +925,26 @@ class PyLoweringContext { // !!! --- next step: we add dump paras according to additional_inputs_list // ??? --- could we get IRvalue of `additional_inputs_list` in this function to complete xlacomputation? int64_t parameter_idx = 2; // parameter_idx start from 2 after upper and lower - // for (auto& additional_input_tensor : additional_inputs_list) { - // XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - // xla::Shape shape = xtensor->shape().get(); - // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - // "UnusedArgumentsPlaceholder"); - // parameter_idx += 1; - // } + for (auto& additional_input_tensor : additional_inputs_list) { + XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + xla::Shape shape = xtensor->shape().get(); + xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + "UnusedArgumentsPlaceholder"); + parameter_idx += 1; + } } // hard-code modify body xlacomputation input arguments if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameter_idx = 7; // tensors.size(); - // for (auto& additional_input_tensor : additional_inputs_list) { - // XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - // xla::Shape shape = xtensor->shape().get(); - // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - // "UnusedArgumentsPlaceholder"); - // parameter_idx += 1; - // } + for (auto& additional_input_tensor : additional_inputs_list) { + XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + xla::Shape shape = xtensor->shape().get(); + xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + "UnusedArgumentsPlaceholder"); + parameter_idx += 1; + } } // Get the backing XLA tensors from the output torch tensor handles From 4cfa52231c1fb933fa49802e08dfce1f704d1923 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 22:42:01 +0000 Subject: [PATCH 192/546] update --- torch_xla/experimental/fori_loop.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 27d6c214980..118358ff81b 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -44,7 +44,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') # cond_fn&body_fn: callable # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) - print("!!! arrive here too !!!") + print("!!! arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") # print("while_loop additional_inputs: ", additional_inputs) if additional_inputs is None: additional_inputs = tuple() @@ -52,7 +52,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): - print("!!! arrive here too too !!!") + print("!!! arrive here def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): !!!") # print("original_carried_inputs: ", original_carried_inputs) # print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() @@ -72,7 +72,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) # fake_carried_inputs = tuple(fake_carried_inputs) - print("fake_carried_inputs first: ", fake_carried_inputs) + # print("fake_carried_inputs first: ", fake_carried_inputs) for additional_input in additional_inputs: device = additional_input.device #TODO(@manfei) type = carried_input.type @@ -80,14 +80,14 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) # fake_carried_inputs = tuple(fake_carried_inputs) - print("fake_carried_inputs second: ", fake_carried_inputs) + # print("fake_carried_inputs second: ", fake_carried_inputs) print("!!! arrive here too before cond !!!") # generate cond_fn xlacomputation - print("print fake_carried_inputs: ", fake_carried_inputs) + # print("print fake_carried_inputs: ", fake_carried_inputs) # TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-1]) - print("nnn here ???") + # print("nnn here ???") cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor @@ -135,6 +135,13 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): for shape in shapes: p = xb.mkparam(builder, len(params), shape) params.append(p) + tmp_bias = params[-2] + tmp_output_value = params[-3] + del params[-3] + del params[-2] + params.append(tmp_bias) + params.append(tmp_output_value) + print("args params: ", params) print("!!! arrive here too after args!!!") From edb6fc72624998ce5dafe747f6172dfbb7303a14 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 22:42:52 +0000 Subject: [PATCH 193/546] update --- torch_xla/experimental/fori_loop.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 118358ff81b..b44e79e090c 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -93,9 +93,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor # treat and pass additional_inputs to cond_fn print("additional_inputs_list_cond one: ", additional_inputs_list_cond) - for i in range(len(additional_inputs)): - additional_inputs_list_cond.append(additional_inputs[i]) - print("additional_inputs_list_cond two: ", additional_inputs_list_cond) + # for i in range(len(additional_inputs)): + # additional_inputs_list_cond.append(additional_inputs[i]) + # print("additional_inputs_list_cond two: ", additional_inputs_list_cond) cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", @@ -111,8 +111,8 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): # TODO(@manfei): treat and pass additional_inputs to body_fn too # print("list(fake_carried_inputs[-2]: ", fake_carried_inputs[-2]) # print("len0!!!: ", len(additional_inputs_list_body)) - for i in range(len(additional_inputs)): - additional_inputs_list_body.append(additional_inputs[i]) + # for i in range(len(additional_inputs)): + # additional_inputs_list_body.append(additional_inputs[i]) # print("len!!!: ", len(additional_inputs_list_body)) # print("additional_inputs_list_body: ", additional_inputs_list_body) body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) From 99a6589851b4acccd2bc0c279f3f4902e17e3173 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 22:50:09 +0000 Subject: [PATCH 194/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- torch_xla/experimental/fori_loop.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index eb108d59c46..2629b28a543 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -116,13 +116,13 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) - weight = self.linear.weight + weight = self.linear.weight # not be used actually, would be used as bias = self.linear.bias # new_upper = upper # new_one_value = one_value # new_input_value = input_value # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real - return upper.clone(), lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight.clone(), bias.clone(), output_value_real + return upper.clone(), lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, bias.clone(), weight.clone() # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index b44e79e090c..68e62d902ec 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -107,7 +107,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): body_result = body_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-1], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-2]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - additional_inputs_list_body = [fake_carried_inputs[-2]] # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor + additional_inputs_list_body = [fake_carried_inputs[-3]] # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor # TODO(@manfei): treat and pass additional_inputs to body_fn too # print("list(fake_carried_inputs[-2]: ", fake_carried_inputs[-2]) # print("len0!!!: ", len(additional_inputs_list_body)) From a34cd6fe3390039a3e45ea10dbf737828df3d797 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 22:53:05 +0000 Subject: [PATCH 195/546] update --- torch_xla/experimental/fori_loop.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 68e62d902ec..25c3dc917a7 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -63,6 +63,11 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): # if len(original_carried_inputs) == 2: # print("use original_carried_inputs for additional_inputs") # additional_inputs = original_carried_inputs[1] + + # exchange order of bias and weight in additional_inputs + (bias_p, weight_p) = additional_inputs + additional_inputs = (weight_p, bias_p) + # fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: From be562ad96416296b057317b0b69543a5324947ed Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 22:54:09 +0000 Subject: [PATCH 196/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 25c3dc917a7..337b05c75b3 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -85,7 +85,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) # fake_carried_inputs = tuple(fake_carried_inputs) - # print("fake_carried_inputs second: ", fake_carried_inputs) + print("fake_carried_inputs second: ", fake_carried_inputs) print("!!! arrive here too before cond !!!") # generate cond_fn xlacomputation From 11c1b546ea9490a7567093ece87c4344e9c880ca Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 22:56:14 +0000 Subject: [PATCH 197/546] update --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 337b05c75b3..aa3730e7d7e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -64,9 +64,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): # print("use original_carried_inputs for additional_inputs") # additional_inputs = original_carried_inputs[1] - # exchange order of bias and weight in additional_inputs - (bias_p, weight_p) = additional_inputs - additional_inputs = (weight_p, bias_p) + # # exchange order of bias and weight in additional_inputs + # (bias_p, weight_p) = additional_inputs + # additional_inputs = (weight_p, bias_p) # fake carried_inputs to split formal code fake_carried_inputs = [] From 2480b5b798dadacab4fc0d7da2e1ce329224879a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 22:58:08 +0000 Subject: [PATCH 198/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 2629b28a543..c126636b59e 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -122,7 +122,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # new_one_value = one_value # new_input_value = input_value # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real - return upper.clone(), lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, bias.clone(), weight.clone() + return upper.clone(), lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 From cce486d4e9fdd967d2883d1a711e6693ef5be9d5 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:02:13 +0000 Subject: [PATCH 199/546] update --- torch_xla/experimental/fori_loop.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index aa3730e7d7e..0429b97cf3a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -96,6 +96,14 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor + + tmp_bias = additional_inputs_list_cond[-2] + tmp_output_value = additional_inputs_list_cond[-3] + del additional_inputs_list_cond[-3] + del additional_inputs_list_cond[-2] + additional_inputs_list_cond.append(tmp_bias) + additional_inputs_list_cond.append(tmp_output_value) + # treat and pass additional_inputs to cond_fn print("additional_inputs_list_cond one: ", additional_inputs_list_cond) # for i in range(len(additional_inputs)): @@ -140,12 +148,12 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): for shape in shapes: p = xb.mkparam(builder, len(params), shape) params.append(p) - tmp_bias = params[-2] - tmp_output_value = params[-3] - del params[-3] - del params[-2] - params.append(tmp_bias) - params.append(tmp_output_value) + # tmp_bias = params[-2] + # tmp_output_value = params[-3] + # del params[-3] + # del params[-2] + # params.append(tmp_bias) + # params.append(tmp_output_value) print("args params: ", params) print("!!! arrive here too after args!!!") From 38d30b7aa47af946acbc3ef60412c51d40cf0c78 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:02:59 +0000 Subject: [PATCH 200/546] update --- torch_xla/experimental/fori_loop.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0429b97cf3a..4d566b5e768 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -97,12 +97,12 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): cond_ctx.set_name_string("condctx") additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor - tmp_bias = additional_inputs_list_cond[-2] - tmp_output_value = additional_inputs_list_cond[-3] - del additional_inputs_list_cond[-3] - del additional_inputs_list_cond[-2] - additional_inputs_list_cond.append(tmp_bias) - additional_inputs_list_cond.append(tmp_output_value) + tmp_bias = additional_inputs_list_cond[-2] # not used, change order doesn't affect logic + tmp_output_value = additional_inputs_list_cond[-3] # not used, change order doesn't affect logic + del additional_inputs_list_cond[-3] # not used, change order doesn't affect logic + del additional_inputs_list_cond[-2] # not used, change order doesn't affect logic + additional_inputs_list_cond.append(tmp_bias) # not used, change order doesn't affect logic + additional_inputs_list_cond.append(tmp_output_value) # not used, change order doesn't affect logic # treat and pass additional_inputs to cond_fn print("additional_inputs_list_cond one: ", additional_inputs_list_cond) From a85a8e30fb921b3cb577c935e7850332ace44757 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:04:54 +0000 Subject: [PATCH 201/546] update --- torch_xla/experimental/fori_loop.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 4d566b5e768..61d58043175 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -98,11 +98,11 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor tmp_bias = additional_inputs_list_cond[-2] # not used, change order doesn't affect logic - tmp_output_value = additional_inputs_list_cond[-3] # not used, change order doesn't affect logic - del additional_inputs_list_cond[-3] # not used, change order doesn't affect logic + # tmp_output_value = additional_inputs_list_cond[-3] # not used, change order doesn't affect logic + # del additional_inputs_list_cond[-3] # not used, change order doesn't affect logic del additional_inputs_list_cond[-2] # not used, change order doesn't affect logic additional_inputs_list_cond.append(tmp_bias) # not used, change order doesn't affect logic - additional_inputs_list_cond.append(tmp_output_value) # not used, change order doesn't affect logic + # additional_inputs_list_cond.append(tmp_output_value) # not used, change order doesn't affect logic # treat and pass additional_inputs to cond_fn print("additional_inputs_list_cond one: ", additional_inputs_list_cond) @@ -148,11 +148,12 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): for shape in shapes: p = xb.mkparam(builder, len(params), shape) params.append(p) - # tmp_bias = params[-2] + + tmp_bias = params[-2] # tmp_output_value = params[-3] # del params[-3] - # del params[-2] - # params.append(tmp_bias) + del params[-2] + params.append(tmp_bias) # params.append(tmp_output_value) print("args params: ", params) From 8b73e4322c9edc89557646ae40d1b4f4616d3653 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:06:20 +0000 Subject: [PATCH 202/546] update --- torch_xla/experimental/fori_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 61d58043175..8def75a982d 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -169,6 +169,8 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): name = 'fori_loop_ed_torch_func' computation = w.build(name) + print("carried_inputs: ", carried_inputs) + # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', (carried_inputs), From 839f5d1a1029290234dc2c31480b67a533132786 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:07:26 +0000 Subject: [PATCH 203/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 8def75a982d..828035591f4 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -173,7 +173,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', - (carried_inputs), + (total_inputs), # (carried_inputs), computation) print("!!! arrive here too after while!!!") From 4b96b9b6636ea415a5b791b7794dd8e210f8e227 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:10:21 +0000 Subject: [PATCH 204/546] update --- torch_xla/experimental/fori_loop.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 828035591f4..46ee2578cf0 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -44,7 +44,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') # cond_fn&body_fn: callable # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) - print("!!! arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") + # print("!!! arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") # print("while_loop additional_inputs: ", additional_inputs) if additional_inputs is None: additional_inputs = tuple() @@ -52,7 +52,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): - print("!!! arrive here def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): !!!") + # print("!!! arrive here def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): !!!") # print("original_carried_inputs: ", original_carried_inputs) # print("additional_inputs: ", additional_inputs) # import pdb; pdb.set_trace() @@ -85,9 +85,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) # fake_carried_inputs = tuple(fake_carried_inputs) - print("fake_carried_inputs second: ", fake_carried_inputs) + # print("fake_carried_inputs second: ", fake_carried_inputs) - print("!!! arrive here too before cond !!!") + # print("!!! arrive here too before cond !!!") # generate cond_fn xlacomputation # print("print fake_carried_inputs: ", fake_carried_inputs) # TODO(@manfei): specify which element is for which argument like a,b,c @@ -105,7 +105,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): # additional_inputs_list_cond.append(tmp_output_value) # not used, change order doesn't affect logic # treat and pass additional_inputs to cond_fn - print("additional_inputs_list_cond one: ", additional_inputs_list_cond) + # print("additional_inputs_list_cond one: ", additional_inputs_list_cond) # for i in range(len(additional_inputs)): # additional_inputs_list_cond.append(additional_inputs[i]) # print("additional_inputs_list_cond two: ", additional_inputs_list_cond) @@ -113,9 +113,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - print("!!! arrive here too after cond !!!") + # print("!!! arrive here too after cond !!!") - print("!!! arrive here too before body !!!") + # print("!!! arrive here too before body !!!") # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-1], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-2]) body_ctx = torch_xla._XLAC.lowering.LoweringContext() @@ -132,11 +132,11 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - print("!!! arrive here too after body !!!") + # print("!!! arrive here too after body !!!") - print("!!! arrive here too before args!!!") + # print("!!! arrive here too before args!!!") total_inputs = carried_inputs + additional_inputs - print("total_inputs: ", total_inputs) + # print("total_inputs: ", total_inputs) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while kwargs = {} if type(total_inputs) is tuple: @@ -156,10 +156,10 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): params.append(tmp_bias) # params.append(tmp_output_value) - print("args params: ", params) - print("!!! arrive here too after args!!!") + # print("args params: ", params) + # print("!!! arrive here too after args!!!") - print("!!! arrive here too before while!!!") + # print("!!! arrive here too before while!!!") # generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) w = xb.mkop( @@ -169,7 +169,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): name = 'fori_loop_ed_torch_func' computation = w.build(name) - print("carried_inputs: ", carried_inputs) + # print("carried_inputs: ", carried_inputs) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', From d1e141a04509d88ff442bf4c50cc8f27c577b9ab Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:12:15 +0000 Subject: [PATCH 205/546] update --- torch_xla/experimental/fori_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 46ee2578cf0..4bdf3d6c70b 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -170,11 +170,12 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): computation = w.build(name) # print("carried_inputs: ", carried_inputs) + print("total_inputs: ", total_inputs) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', - (total_inputs), # (carried_inputs), + (total_inputs), computation) - print("!!! arrive here too after while!!!") + # print("!!! arrive here too after while!!!") return result \ No newline at end of file From 5f73913d30833b0bb69799bff3842076c06b5248 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:12:30 +0000 Subject: [PATCH 206/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index c126636b59e..aeef18c9841 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -109,7 +109,7 @@ def forward(self, upper, lower, one_value, x, input_value, output_value): # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): def cond_fn(upper, lower, one_value, x, input_value, output_value): - return lower[0] < upper[0] + return lower[0] >= upper[0] # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): From 3fc27584f6f6c0b9ba66f33f2d1063f3932d3468 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:16:12 +0000 Subject: [PATCH 207/546] update --- ...p_with_while_loop_simple_add_dispatch_in_torch.py | 12 ++++++++++-- torch_xla/experimental/fori_loop.py | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index aeef18c9841..c5c2765dbcd 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -154,8 +154,16 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) # , weight_0, bias_0) # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) - bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) - print("bbb: ", bbb) + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) + # print("bbb: ", bbb) + print("upper__: ", upper__) + print("lower__: ", lower__) + print("one_value__: ", one_value__) + print("torch_add_res__: ", torch_add_res__) + print("input_value__: ", input_value__) + print("output_value_real__: ", output_value_real__) + print("weight__: ", weight__) + print("bias__: ", bias__) # print("start test 6 !!!") return aaa diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 4bdf3d6c70b..6ad00f67c7a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -170,7 +170,7 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): computation = w.build(name) # print("carried_inputs: ", carried_inputs) - print("total_inputs: ", total_inputs) + # print("total_inputs: ", total_inputs) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', From 191b6664bdff72c9a27ba0197f0b2a7019c1b533 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:16:40 +0000 Subject: [PATCH 208/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index c5c2765dbcd..844175ac050 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -162,8 +162,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): print("torch_add_res__: ", torch_add_res__) print("input_value__: ", input_value__) print("output_value_real__: ", output_value_real__) - print("weight__: ", weight__) - print("bias__: ", bias__) + # print("weight__: ", weight__) + # print("bias__: ", bias__) # print("start test 6 !!!") return aaa From d6935186149db6ca216a77fdfd391c95844280a0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:20:05 +0000 Subject: [PATCH 209/546] update --- torch_xla/experimental/fori_loop.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 6ad00f67c7a..fb7b37b21a8 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -113,6 +113,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) + cond_hlo_print = xb.get_computation_hlo(cond_computation) + print("cond computation: !!!!!!!!!") + print(cond_hlo_print) # print("!!! arrive here too after cond !!!") # print("!!! arrive here too before body !!!") @@ -132,6 +135,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) + body_hlo_print = xb.get_computation_hlo(body_computation) + print("body computation: !!!!!!!!!") + print(body_hlo_print) # print("!!! arrive here too after body !!!") # print("!!! arrive here too before args!!!") @@ -168,6 +174,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) + hlo_print = xb.get_computation_hlo(computation) + print("while computation: !!!!!!!!!") + print(hlo_print) # print("carried_inputs: ", carried_inputs) # print("total_inputs: ", total_inputs) From f8cc89ec788aaab74bc1b652e0f3359e083d9d76 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:29:37 +0000 Subject: [PATCH 210/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 844175ac050..e13449266ea 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -122,7 +122,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # new_one_value = one_value # new_input_value = input_value # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real - return upper.clone(), lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() # return while_loop(cond_fn, body_fn, (iter, x)) # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) # return 1 From 089a27f9b09bc34aefc1546c3db511ab4088b685 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:30:16 +0000 Subject: [PATCH 211/546] update --- torch_xla/experimental/fori_loop.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index fb7b37b21a8..14ceea7af04 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -113,9 +113,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - cond_hlo_print = xb.get_computation_hlo(cond_computation) - print("cond computation: !!!!!!!!!") - print(cond_hlo_print) + # cond_hlo_print = xb.get_computation_hlo(cond_computation) + # print("cond computation: !!!!!!!!!") + # print(cond_hlo_print) # print("!!! arrive here too after cond !!!") # print("!!! arrive here too before body !!!") @@ -135,9 +135,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - body_hlo_print = xb.get_computation_hlo(body_computation) - print("body computation: !!!!!!!!!") - print(body_hlo_print) + # body_hlo_print = xb.get_computation_hlo(body_computation) + # print("body computation: !!!!!!!!!") + # print(body_hlo_print) # print("!!! arrive here too after body !!!") # print("!!! arrive here too before args!!!") @@ -174,9 +174,9 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) - hlo_print = xb.get_computation_hlo(computation) - print("while computation: !!!!!!!!!") - print(hlo_print) + # hlo_print = xb.get_computation_hlo(computation) + # print("while computation: !!!!!!!!!") + # print(hlo_print) # print("carried_inputs: ", carried_inputs) # print("total_inputs: ", total_inputs) From 3a1fcf1a581708ea4e7345f73e08ee6b4bbb4de8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:30:59 +0000 Subject: [PATCH 212/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index e13449266ea..4d85222ef01 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -109,7 +109,7 @@ def forward(self, upper, lower, one_value, x, input_value, output_value): # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): def cond_fn(upper, lower, one_value, x, input_value, output_value): - return lower[0] >= upper[0] + return lower[0] <= upper[0] # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): From ea49168ee29839b2146c8d755c8187464070a30f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:31:56 +0000 Subject: [PATCH 213/546] update --- ...est_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 4d85222ef01..cb33020965c 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -168,6 +168,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return aaa expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + print("expected: ", expected) self.assertTrue(torch.all(torch.eq(expected, l_out_))) # res = simple_with_linear.apply((upper, lower, one_value, init_val, l_in_0, output_value)) From ab065d5f8e84daf45ab7fcbe54e1889341472474 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:32:39 +0000 Subject: [PATCH 214/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index cb33020965c..c31b8bf0e78 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -165,12 +165,12 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # print("weight__: ", weight__) # print("bias__: ", bias__) # print("start test 6 !!!") - return aaa + # return aaa expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) print("expected: ", expected) - self.assertTrue(torch.all(torch.eq(expected, l_out_))) + return self.assertTrue(torch.all(torch.eq(expected, l_out_))) # res = simple_with_linear.apply((upper, lower, one_value, init_val, l_in_0, output_value)) # print("res: ", res) # import pdb; pdb.set_trace() From 060eaace1bf7950d71e4b162c318f23170a210fb Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:33:04 +0000 Subject: [PATCH 215/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index c31b8bf0e78..9beffd3fe94 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -170,7 +170,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) print("expected: ", expected) - return self.assertTrue(torch.all(torch.eq(expected, l_out_))) + self.assertTrue(torch.all(torch.eq(expected, l_out_))) + return aaa # res = simple_with_linear.apply((upper, lower, one_value, init_val, l_in_0, output_value)) # print("res: ", res) # import pdb; pdb.set_trace() From b649e7e8fb93847409dd105927f02347c3adb319 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:35:20 +0000 Subject: [PATCH 216/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 9beffd3fe94..b82d0f3b994 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -167,6 +167,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # print("start test 6 !!!") # return aaa + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + linear_0.weight_.data = weight__ + linear_0.bias_.data = bias__ expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) print("expected: ", expected) From 1d901a3dd66e1e0cbb2886f177178358682fb078 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:35:53 +0000 Subject: [PATCH 217/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index b82d0f3b994..144fc024549 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -168,8 +168,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return aaa linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - linear_0.weight_.data = weight__ - linear_0.bias_.data = bias__ + linear_0.weight.data = weight__ + linear_0.bias.data = bias__ expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) print("expected: ", expected) From b62ba4621d768bd565e0ca729fb431c042d21dc0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:36:31 +0000 Subject: [PATCH 218/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 144fc024549..34906cc363d 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -173,7 +173,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) print("expected: ", expected) - self.assertTrue(torch.all(torch.eq(expected, l_out_))) + self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return aaa # res = simple_with_linear.apply((upper, lower, one_value, init_val, l_in_0, output_value)) # print("res: ", res) From b0f18f6285075e8e31199bf58da4b564f0b7af1a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:37:49 +0000 Subject: [PATCH 219/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 34906cc363d..401e17f2ca4 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -136,7 +136,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() - upper = torch.tensor([52], dtype=torch.int32, device=device) + upper = torch.tensor([2], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) # x @@ -172,6 +172,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): linear_0.bias.data = bias__ expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) print("expected: ", expected) + print("l_in_0: ", l_in_0) self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return aaa From a37f72beff8271fe8f02ee799644b9ad8f32c475 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:38:49 +0000 Subject: [PATCH 220/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 401e17f2ca4..c79631e94a7 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -109,7 +109,7 @@ def forward(self, upper, lower, one_value, x, input_value, output_value): # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): def cond_fn(upper, lower, one_value, x, input_value, output_value): - return lower[0] <= upper[0] + return lower[0] < upper[0] # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): # def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): From 36fa72effccd2a560909f8579c729a4104350062 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:41:41 +0000 Subject: [PATCH 221/546] update --- ...while_loop_simple_add_dispatch_in_torch.py | 46 ++----------------- 1 file changed, 5 insertions(+), 41 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index c79631e94a7..e569ebf62dc 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -98,83 +98,47 @@ class SimpleWithLinear(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) - # self.register_buffer("dec", torch.tensor(1)) - # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): def forward(self, upper, lower, one_value, x, input_value, output_value): - # def forward(self, upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - # weight_1 = self.linear.weight - # bias_1 = self.linear.bias - - # def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - # def cond_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - # def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): - # def body_fn(upper, lower, one_value, x, input_value, weight_1, bias_1, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) - weight = self.linear.weight # not be used actually, would be used as - bias = self.linear.bias - # new_upper = upper - # new_one_value = one_value - # new_input_value = input_value - # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value_real + weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() - # return while_loop(cond_fn, body_fn, (iter, x)) - # return while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) - # return 1 - # return upper, lower, one_value, x, input_value, output_value - # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_0, bias_0, output_value)) - # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, weight_1, bias_1, output_value)) - # weight_1 = self.linear.weight - # bias_1 = self.linear.bias - # return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value), (bias_1, weight_1)) return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) - # return _xla_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) simple_with_linear = SimpleWithLinear() upper = torch.tensor([2], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) - init_val = torch.tensor([1], dtype=torch.int32, device=device) # x - # l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value + init_val = torch.tensor([1], dtype=torch.int32, device=device) l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) - # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) weight_0 = simple_with_linear.linear.weight bias_0 = simple_with_linear.linear.bias aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} - # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = aaa - # print("aaa: ", aaa) - # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) - # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) # , weight_0, bias_0) - # bbb = simple_with_linear(upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value) upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) - # print("bbb: ", bbb) print("upper__: ", upper__) print("lower__: ", lower__) print("one_value__: ", one_value__) print("torch_add_res__: ", torch_add_res__) print("input_value__: ", input_value__) print("output_value_real__: ", output_value_real__) - # print("weight__: ", weight__) - # print("bias__: ", bias__) - # print("start test 6 !!!") - # return aaa linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) linear_0.weight.data = weight__ linear_0.bias.data = bias__ expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) print("expected: ", expected) - print("l_in_0: ", l_in_0) + # print("l_in_0: ", l_in_0) - self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + self.assertTrue(torch.all(torch.eq(expected, l_in_0))) return aaa # res = simple_with_linear.apply((upper, lower, one_value, init_val, l_in_0, output_value)) # print("res: ", res) From 6312c498b26ff41f25ef5c4b0b3cf15e24742be5 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:42:31 +0000 Subject: [PATCH 222/546] update --- ...ri_loop_with_while_loop_simple_add_dispatch_in_torch.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index e569ebf62dc..9a88267c030 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -136,14 +136,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): linear_0.bias.data = bias__ expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) print("expected: ", expected) - # print("l_in_0: ", l_in_0) - self.assertTrue(torch.all(torch.eq(expected, l_in_0))) + self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return aaa - # res = simple_with_linear.apply((upper, lower, one_value, init_val, l_in_0, output_value)) - # print("res: ", res) - # import pdb; pdb.set_trace() - # return {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} def test_fori_loop_tpu_addition(self): From 4be4c237e526719cfe7f1d7d3076f20bb4a45273 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:48:22 +0000 Subject: [PATCH 223/546] update --- ...while_loop_simple_add_dispatch_in_torch.py | 43 -------- torch_xla/experimental/fori_loop.py | 99 +++---------------- 2 files changed, 16 insertions(+), 126 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 9a88267c030..a3e684f3069 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -7,7 +7,6 @@ # We need to import the underlying implementation function to register with the dispatcher import torch_xla.experimental.fori_loop from torch_xla.experimental.fori_loop import fori_loop -# from torch_xla.experimental.fori_loop import _xla_while_loop from torch._higher_order_ops.while_loop import while_loop import torch_xla.core.xla_model as xm import torch_xla.core.xla_builder as xb @@ -24,7 +23,6 @@ def _fake_fori_loop(lower, upper, body_fun, *init_val): if len(init_val) > 1: (a, b) = init_val for i in range((upper - lower)[0]): - # a = body_fun(a, b) a = body_fun(*init_val) else: for i in range((upper - lower)[0]): @@ -180,44 +178,3 @@ def test_fori_loop_tpu_simple_linear(self): if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) - - -######## --------------------------------------------------------- - -# x = torch.zeros(1) -# y = torch.zeros(1) -# z = torch.zeros(1) -# return {"simple_with_linear": (simple_with_linear, (torch.tensor(3), torch.randn(2, 2)))} - - # xm.mark_step() - # device = xm.xla_device() - # torch.set_grad_enabled(False) - - # upper = torch.tensor([52], dtype=torch.int32, device=device) - # lower = torch.tensor([0], dtype=torch.int32, device=device) - # one_value = torch.tensor([1], dtype=torch.int32, device=device) - # init_val = torch.tensor([1], dtype=torch.int32, device=device) # x - # l_in_0 = torch.randn(10, device=xm.xla_device()) # input_value - # output_value = torch.zeros([20], dtype=torch.float32, device=device) - - # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - # weight_0 = linear_0.weight - # bias_0 = linear_0.bias - - # # def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): - # def cond_fn(upper, lower, one_value, x, input_value, output_value): - # return lower[0] < upper[0] - - # def body_fn(upper, lower, one_value, x, input_value, output_value): - # new_lower = torch.add(one_value, lower) - # output_value = linear_0(input_value) - # weight = linear_0.weight - # bias = linear_0.bias - # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value - - # # print("!!! arrive here !!!") - # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) - - # expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) - - # self.assertTrue(torch.all(torch.eq(expected, l_out_))) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 14ceea7af04..baa2dcebed7 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -17,18 +17,14 @@ def fori_loop(upper, lower, body_fun, init_val, *input_value): device = xm.xla_device() - def cond_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): # , bias_0): # , output_value): + def cond_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): # , bias_0): - # weight = body_fun.weight - new_lower = torch.add(one_value, lower) ### !!! this matter, torch.add might would change the second argument's value, even we use a new variable to catch the result!!! - output_value = body_fun(*input_value) ### !!! due to the output_value is not actually used here, - # --- !!! its original value would not be used, and it would be replaces by the result of body_fun - # --- !!! so, due to PTXLA is traced from result tensor, so the arguments `output_value` would not be included in the body_xlacomputation - # --- !!! so, we need to modify ini_python_binding.cpp to add a fake arguments in the xlacompputation - weight = body_fun.weight - bias = body_fun.bias + def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): + new_lower = torch.add(one_value, lower) + output_value = body_fun(*input_value) + weight = body_fun.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = body_fun.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value output_value = torch.zeros([20], dtype=torch.float32, device=device) @@ -40,34 +36,16 @@ def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi @while_loop_op.py_impl(DispatchKey.XLA) -def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): +def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') # cond_fn&body_fn: callable # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) - # print("!!! arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") - # print("while_loop additional_inputs: ", additional_inputs) if additional_inputs is None: additional_inputs = tuple() - return _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) # a=a, b=b, c=c, - - -def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): - # print("!!! arrive here def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): !!!") - # print("original_carried_inputs: ", original_carried_inputs) - # print("additional_inputs: ", additional_inputs) - # import pdb; pdb.set_trace() - # untuple carried_inputs from while_loop - # carried_inputs = original_carried_inputs[0] ### due to PyTorch has already treat them , so skip split here - # TODO(@manfei): please clear pass additional_inputs in `while_loop`'s defination in this file - ### due to PyTorch has already treat them , so skip split here - # if len(original_carried_inputs) == 2: - # print("use original_carried_inputs for additional_inputs") - # additional_inputs = original_carried_inputs[1] - - # # exchange order of bias and weight in additional_inputs - # (bias_p, weight_p) = additional_inputs - # additional_inputs = (weight_p, bias_p) + return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) + +def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: @@ -76,74 +54,41 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): fake_carried_inputs.append( torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) - # fake_carried_inputs = tuple(fake_carried_inputs) - # print("fake_carried_inputs first: ", fake_carried_inputs) for additional_input in additional_inputs: device = additional_input.device #TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) - # fake_carried_inputs = tuple(fake_carried_inputs) - # print("fake_carried_inputs second: ", fake_carried_inputs) - # print("!!! arrive here too before cond !!!") - # generate cond_fn xlacomputation - # print("print fake_carried_inputs: ", fake_carried_inputs) # TODO(@manfei): specify which element is for which argument like a,b,c - cond_result = cond_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-2], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-1]) - # print("nnn here ???") + cond_result = cond_fn(*fake_carried_inputs) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor tmp_bias = additional_inputs_list_cond[-2] # not used, change order doesn't affect logic - # tmp_output_value = additional_inputs_list_cond[-3] # not used, change order doesn't affect logic - # del additional_inputs_list_cond[-3] # not used, change order doesn't affect logic del additional_inputs_list_cond[-2] # not used, change order doesn't affect logic additional_inputs_list_cond.append(tmp_bias) # not used, change order doesn't affect logic - # additional_inputs_list_cond.append(tmp_output_value) # not used, change order doesn't affect logic - # treat and pass additional_inputs to cond_fn - # print("additional_inputs_list_cond one: ", additional_inputs_list_cond) - # for i in range(len(additional_inputs)): - # additional_inputs_list_cond.append(additional_inputs[i]) - # print("additional_inputs_list_cond two: ", additional_inputs_list_cond) cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - # cond_hlo_print = xb.get_computation_hlo(cond_computation) - # print("cond computation: !!!!!!!!!") - # print(cond_hlo_print) - # print("!!! arrive here too after cond !!!") - # print("!!! arrive here too before body !!!") # generate body_fn xlacomputation - body_result = body_fn(*fake_carried_inputs) # [:-3], weight_0=fake_carried_inputs[-1], output_value=fake_carried_inputs[-3], bias_0=fake_carried_inputs[-2]) + body_result = body_fn(*fake_carried_inputs) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - additional_inputs_list_body = [fake_carried_inputs[-3]] # missed arguments due to given output_value was not used and PyTorch/XLA trace xlacomputation from output tensor - # TODO(@manfei): treat and pass additional_inputs to body_fn too - # print("list(fake_carried_inputs[-2]: ", fake_carried_inputs[-2]) - # print("len0!!!: ", len(additional_inputs_list_body)) - # for i in range(len(additional_inputs)): - # additional_inputs_list_body.append(additional_inputs[i]) - # print("len!!!: ", len(additional_inputs_list_body)) - # print("additional_inputs_list_body: ", additional_inputs_list_body) + additional_inputs_list_body = [fake_carried_inputs[-3]] + # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - # body_hlo_print = xb.get_computation_hlo(body_computation) - # print("body computation: !!!!!!!!!") - # print(body_hlo_print) - # print("!!! arrive here too after body !!!") - # print("!!! arrive here too before args!!!") - total_inputs = carried_inputs + additional_inputs - # print("total_inputs: ", total_inputs) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while + total_inputs = carried_inputs + additional_inputs kwargs = {} if type(total_inputs) is tuple: shapes = xb.tensor_shape(total_inputs) @@ -155,17 +100,11 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): p = xb.mkparam(builder, len(params), shape) params.append(p) + # TODO(@manfei): treat hard-code input arguments tmp_bias = params[-2] - # tmp_output_value = params[-3] - # del params[-3] del params[-2] params.append(tmp_bias) - # params.append(tmp_output_value) - # print("args params: ", params) - # print("!!! arrive here too after args!!!") - - # print("!!! arrive here too before while!!!") # generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) w = xb.mkop( @@ -174,12 +113,6 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=()): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) - # hlo_print = xb.get_computation_hlo(computation) - # print("while computation: !!!!!!!!!") - # print(hlo_print) - - # print("carried_inputs: ", carried_inputs) - # print("total_inputs: ", total_inputs) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', From e3a2fcd8a1daca469296706153c01f23112a8132 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:51:14 +0000 Subject: [PATCH 224/546] update --- ...ri_loop_with_while_loop_simple_add_dispatch_in_torch.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index a3e684f3069..111b4203b8b 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -122,13 +122,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) - print("upper__: ", upper__) - print("lower__: ", lower__) - print("one_value__: ", one_value__) - print("torch_add_res__: ", torch_add_res__) - print("input_value__: ", input_value__) - print("output_value_real__: ", output_value_real__) + # same weight/bias liear model linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) linear_0.weight.data = weight__ linear_0.bias.data = bias__ From abb30456f68e4929070317be8c429572d7d471e3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:52:15 +0000 Subject: [PATCH 225/546] update --- ...est_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 111b4203b8b..d06c8e8d56c 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -128,7 +128,6 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): linear_0.weight.data = weight__ linear_0.bias.data = bias__ expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) - print("expected: ", expected) self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return aaa From 95608ff09b4c54d31a19d52b3e4ba7c95930ca72 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:53:58 +0000 Subject: [PATCH 226/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index d06c8e8d56c..88243eee9b4 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -171,4 +171,4 @@ def test_fori_loop_tpu_simple_linear(self): if __name__ == '__main__': test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) + sys.exit(0 if test.result.wasSuccessful() else 1) \ No newline at end of file From c8d722ee54bb6c74ba5170eecbd2adb14ec4b266 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 15 Apr 2024 23:59:20 +0000 Subject: [PATCH 227/546] update --- torch_xla/csrc/init_python_bindings.cpp | 2 ++ torch_xla/experimental/fori_loop.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 145739e1d1a..c7013569dd1 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -917,6 +917,7 @@ class PyLoweringContext { if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // hard-code parameter_idx to 2 to skip existing upper/lower arguments + // TODO(@manfei): get body xlacomputation arguments' number first then decide items in `additional_inputs_list`, maybe implement in python level // !!! since cond_fn only compare upper and lower, so it would only use two arguments, due to PyTorch/XLA // !!! trace xlacomputation from result tensor, so all the other arguments would not be included or generated; // !!! but to meet xla::while requirement, we would skip first two arguments, @@ -935,6 +936,7 @@ class PyLoweringContext { } // hard-code modify body xlacomputation input arguments + // TODO(@manfei): get body xlacomputation arguments' number first then decide items in `additional_inputs_list`, maybe implement in python level if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameter_idx = 7; // tensors.size(); diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index baa2dcebed7..a0f8024333d 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -75,6 +75,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) + cond_hlo_print = xb.get_computation_hlo(cond_computation) + print("cond computation: !!!!!!!!!") + print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) @@ -86,6 +89,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) + body_hlo_print = xb.get_computation_hlo(body_computation) + print("body computation: !!!!!!!!!") + print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs @@ -113,6 +119,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) + hlo_print = xb.get_computation_hlo(computation) + print("while computation: !!!!!!!!!") + print(hlo_print) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', From d33f14c82ffa25cdfff6ea22add19498a973f126 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:06:12 +0000 Subject: [PATCH 228/546] update --- ...while_loop_simple_add_dispatch_in_torch.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 88243eee9b4..c7e5b9392e7 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -86,12 +86,52 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) +# passed def test_while_loop_tpu_simple_linear(self): xm.mark_step() device = xm.xla_device() torch.set_grad_enabled(False) + # def forward(self, upper, lower, one_value, x, input_value, output_value): + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value_real = linear_0(input_value) + weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() + + simple_with_linear = SimpleWithLinear() + upper = torch.tensor([2], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + # weight_0 = simple_with_linear.linear.weight + # bias_0 = simple_with_linear.linear.bias + + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = + while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) + + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + + self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + return aaa + +# passed + def test_while_loop_tpu_simple_linear_class(self): + + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + class SimpleWithLinear(torch.nn.Module): def __init__(self): super().__init__() From 1f8002a07893545be707b11f4cef5091f721ed59 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:07:05 +0000 Subject: [PATCH 229/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index c7e5b9392e7..e33cd21e476 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -117,8 +117,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # weight_0 = simple_with_linear.linear.weight # bias_0 = simple_with_linear.linear.bias - upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = - while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From 5cffda24f659ed8cdd1b4abd3092b7a03ff82494 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:07:41 +0000 Subject: [PATCH 230/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index e33cd21e476..304ce22b96e 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -106,7 +106,6 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() - simple_with_linear = SimpleWithLinear() upper = torch.tensor([2], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) @@ -114,9 +113,6 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) - # weight_0 = simple_with_linear.linear.weight - # bias_0 = simple_with_linear.linear.bias - upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From f9cd4dc55bdf180c85d68b20bff8a64cbb412a84 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:08:49 +0000 Subject: [PATCH 231/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 304ce22b96e..db73f40a87c 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -113,7 +113,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) - upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From 1fc141409d24c83b5e7c80dcd64bb0484aeae7de Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:09:17 +0000 Subject: [PATCH 232/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index db73f40a87c..436e5a68b55 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -113,7 +113,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) - upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From 90481612e5e838118ce7084834507b57f1d42250 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:10:07 +0000 Subject: [PATCH 233/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 436e5a68b55..411d2d9b41e 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -117,6 +117,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + print("output_value_real__: ", output_value_real__) + print("expected: ", expected) + self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return aaa From 614a57fe461cd79be05a6d5e292aa19963d0b3b8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:10:46 +0000 Subject: [PATCH 234/546] update --- torch_xla/experimental/fori_loop.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index a0f8024333d..b8b0583e10e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -75,9 +75,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - cond_hlo_print = xb.get_computation_hlo(cond_computation) - print("cond computation: !!!!!!!!!") - print(cond_hlo_print) + # cond_hlo_print = xb.get_computation_hlo(cond_computation) + # print("cond computation: !!!!!!!!!") + # print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) @@ -89,9 +89,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - body_hlo_print = xb.get_computation_hlo(body_computation) - print("body computation: !!!!!!!!!") - print(body_hlo_print) + # body_hlo_print = xb.get_computation_hlo(body_computation) + # print("body computation: !!!!!!!!!") + # print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs @@ -119,9 +119,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) - hlo_print = xb.get_computation_hlo(computation) - print("while computation: !!!!!!!!!") - print(hlo_print) + # hlo_print = xb.get_computation_hlo(computation) + # print("while computation: !!!!!!!!!") + # print(hlo_print) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', From eded2acc639cd03d600b2e60ac25d21b68c24ce3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:12:27 +0000 Subject: [PATCH 235/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 411d2d9b41e..738d592035a 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -148,7 +148,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) simple_with_linear = SimpleWithLinear() - upper = torch.tensor([2], dtype=torch.int32, device=device) + upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) From 504ea2ba74e2a599a936d161d371b6ddc1cd1da3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:13:21 +0000 Subject: [PATCH 236/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 738d592035a..aef680f2edf 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -86,7 +86,6 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) -# passed def test_while_loop_tpu_simple_linear(self): xm.mark_step() @@ -106,7 +105,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() - upper = torch.tensor([2], dtype=torch.int32, device=device) + upper = torch.tensor([1], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) From b5889a85f4838fd504afb974e87b16b9e10e9b06 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:14:47 +0000 Subject: [PATCH 237/546] update --- ...est_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index aef680f2edf..f75a87b8cbc 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -116,6 +116,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + print("torch_add_res__: ", torch_add_res__) print("output_value_real__: ", output_value_real__) print("expected: ", expected) From a572c1b6662fd0d59c6aa3e2c7f2cab32a233231 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:15:24 +0000 Subject: [PATCH 238/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index f75a87b8cbc..287797bab62 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -105,7 +105,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() - upper = torch.tensor([1], dtype=torch.int32, device=device) + upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) From aeda7f78f91f70dc235fd4a6dc824c8c1f687669 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:16:07 +0000 Subject: [PATCH 239/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 287797bab62..ef01d7213d0 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -103,7 +103,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value_real = linear_0(input_value) weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight.clone(), bias.clone(), output_value_real upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From c6344a9aa60e763849323fd1910afd7bbd6e4fa0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:17:50 +0000 Subject: [PATCH 240/546] update --- ...ori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- torch_xla/experimental/fori_loop.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index ef01d7213d0..287797bab62 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -103,7 +103,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value_real = linear_0(input_value) weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight.clone(), bias.clone(), output_value_real + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index b8b0583e10e..0160fca2d5a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -119,9 +119,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) - # hlo_print = xb.get_computation_hlo(computation) - # print("while computation: !!!!!!!!!") - # print(hlo_print) + hlo_print = xb.get_computation_hlo(computation) + print("while computation: !!!!!!!!!") + print(hlo_print) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', From 2e7f5fb0beed5539ab2aeab7ba26f439716cc327 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:26:53 +0000 Subject: [PATCH 241/546] update --- ...i_loop_with_while_loop_simple_add_dispatch_in_torch.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 287797bab62..e60de52f2bb 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -94,16 +94,18 @@ def test_while_loop_tpu_simple_linear(self): # def forward(self, upper, lower, one_value, x, input_value, output_value): linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + weight_0 = linear_0.weight_ + bias_0 = linear_0.bias_ - def cond_fn(upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, output_value): + def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): new_lower = torch.add(one_value, lower) output_value_real = linear_0(input_value) weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight.clone(), output_value_real, bias.clone() upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From abc2c9c77b971e6a0cdbff3c7e8a5c3f4d1349cd Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 17:27:27 +0000 Subject: [PATCH 242/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index e60de52f2bb..ec29fa9e0f7 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -94,8 +94,8 @@ def test_while_loop_tpu_simple_linear(self): # def forward(self, upper, lower, one_value, x, input_value, output_value): linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - weight_0 = linear_0.weight_ - bias_0 = linear_0.bias_ + weight_0 = linear_0.weight + bias_0 = linear_0.bias def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): return lower[0] < upper[0] From c7d46545bf0a59795b675a55a0a695541f732d52 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 18:14:45 +0000 Subject: [PATCH 243/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index ec29fa9e0f7..01178fb6fa5 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -86,6 +86,7 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) +# debugging def test_while_loop_tpu_simple_linear(self): xm.mark_step() @@ -105,7 +106,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va output_value_real = linear_0(input_value) weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight.clone(), output_value_real, bias.clone() + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight.clone(), bias.clone(), output_value_real upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From 5dbbffb396f37cfd27aaa8557bdf592c194a7fb1 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 18:18:37 +0000 Subject: [PATCH 244/546] update --- torch_xla/experimental/fori_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0160fca2d5a..d4fc5c96d09 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -40,12 +40,14 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') # cond_fn&body_fn: callable # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) + print("arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") if additional_inputs is None: additional_inputs = tuple() return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): + print("arrive here _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): !!!") # fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: From 5fca07d41090cd6424fa8aa3c0b1220c79c6ac26 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 18:28:38 +0000 Subject: [PATCH 245/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 01178fb6fa5..ba045688258 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -115,7 +115,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) - upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From 35ed14dd5da3cb5a8d5cf221840fb0cc0c54145e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 18:33:44 +0000 Subject: [PATCH 246/546] update --- torch_xla/experimental/fori_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index d4fc5c96d09..418080ed90a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -41,6 +41,8 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # cond_fn&body_fn: callable # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) print("arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") + print("carried_inputs: ", carried_inputs) + print("additional_inputs: ", additional_inputs) if additional_inputs is None: additional_inputs = tuple() return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From b826cf9b0ec0f437e00cf340f374d82b9a76fd72 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 18:38:07 +0000 Subject: [PATCH 247/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index ba045688258..d1028fea36b 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -115,7 +115,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) - upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, weight_0, bias_0, output_value)) + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value, weight_0, bias_0)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From df34a6bd967a0ecae8e94f6f7deaf5acbf9f74ce Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 18:41:37 +0000 Subject: [PATCH 248/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index d1028fea36b..f16abca316b 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -115,7 +115,7 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_va l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) - upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value, weight_0, bias_0)) + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From fde54ac7038908d24042a3460d86be619b2b954e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 18:43:00 +0000 Subject: [PATCH 249/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index f16abca316b..0ac78cffbda 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -98,10 +98,10 @@ def test_while_loop_tpu_simple_linear(self): weight_0 = linear_0.weight bias_0 = linear_0.bias - def cond_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, weight_0, bias_0, output_value): + def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = linear_0(input_value) weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement From 1f2f50e21d00cfc06d1358814b483eaa37a7f90f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 19:28:54 +0000 Subject: [PATCH 250/546] update --- ...while_loop_simple_add_dispatch_in_torch.py | 2 +- torch_xla/experimental/fori_loop.py | 46 ++++++++++--------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 0ac78cffbda..511f8a8d853 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -106,7 +106,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value_real = linear_0(input_value) weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight.clone(), bias.clone(), output_value_real + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() # , output_value_real upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 418080ed90a..a69dd8d99df 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -12,7 +12,7 @@ from torch._higher_order_ops.while_loop import while_loop_op -# TODO(@manfei): delete one_value? +### TODO(@manfei): delete one_value? def fori_loop(upper, lower, body_fun, init_val, *input_value): device = xm.xla_device() @@ -23,8 +23,8 @@ def cond_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): new_lower = torch.add(one_value, lower) output_value = body_fun(*input_value) - weight = body_fun.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = body_fun.bias # not be used actually, initialized as placeholder xlacomputation requirement + weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement + bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value output_value = torch.zeros([20], dtype=torch.float32, device=device) @@ -37,9 +37,9 @@ def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi @while_loop_op.py_impl(DispatchKey.XLA) def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): - # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') - # cond_fn&body_fn: callable - # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) + ### TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') + ### cond_fn&body_fn: callable + ### carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) print("arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") print("carried_inputs: ", carried_inputs) print("additional_inputs: ", additional_inputs) @@ -50,30 +50,30 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): print("arrive here _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): !!!") - # fake carried_inputs to split formal code + ### fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: device = carried_input.device - #TODO(@manfei) type = carried_input.type + ###TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) for additional_input in additional_inputs: device = additional_input.device - #TODO(@manfei) type = carried_input.type + ###TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) - # TODO(@manfei): specify which element is for which argument like a,b,c + ### TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") - additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor - - tmp_bias = additional_inputs_list_cond[-2] # not used, change order doesn't affect logic - del additional_inputs_list_cond[-2] # not used, change order doesn't affect logic - additional_inputs_list_cond.append(tmp_bias) # not used, change order doesn't affect logic + # !!! cond xlacomputation change !!! switch bias and weight position + additional_inputs_list_cond = list(fake_carried_inputs[2:]) ### all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor + tmp_bias = additional_inputs_list_cond[-2] ### not used, change order doesn't affect logic + del additional_inputs_list_cond[-2] ### not used, change order doesn't affect logic + additional_inputs_list_cond.append(tmp_bias) ### not used, change order doesn't affect logic cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() @@ -83,12 +83,13 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # print("cond computation: !!!!!!!!!") # print(cond_hlo_print) - # generate body_fn xlacomputation + ### generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") + # !!! body xlacomputation change !!! add output_value argument additional_inputs_list_body = [fake_carried_inputs[-3]] - # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body + ### TODO(@manfei): treat hard-code parameters: additional_inputs_list_body body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", @@ -97,7 +98,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # print("body computation: !!!!!!!!!") # print(body_hlo_print) - # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while + ### trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs kwargs = {} if type(total_inputs) is tuple: @@ -110,12 +111,13 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): p = xb.mkparam(builder, len(params), shape) params.append(p) - # TODO(@manfei): treat hard-code input arguments + ### TODO(@manfei): treat hard-code input arguments + # !!! init change !!! tmp_bias = params[-2] del params[-2] params.append(tmp_bias) - # generate while xlacomputation + ### generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) w = xb.mkop( 'While', (input_tuple.op,), @@ -127,10 +129,10 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): print("while computation: !!!!!!!!!") print(hlo_print) - # gain final result with generated while xlacomputation + ### gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', (total_inputs), computation) - # print("!!! arrive here too after while!!!") + ### print("!!! arrive here too after while!!!") return result \ No newline at end of file From 27f20b7489f2592d51a0c13781a1c304de831c5b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 22:36:39 +0000 Subject: [PATCH 251/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 511f8a8d853..33b11e8eafe 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -106,7 +106,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value_real = linear_0(input_value) weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() # , output_value_real + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real.clone(), weight.clone(), bias.clone() # , output_value_real upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From 43af2152df0ad7a0ac87bdf4f5b58f6c60431cbe Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 16 Apr 2024 22:38:17 +0000 Subject: [PATCH 252/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 33b11e8eafe..d5bedc787f6 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -155,7 +155,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) - l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value + l_in_0 = torch.randint(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) weight_0 = simple_with_linear.linear.weight From 79bd30345593dde19a3cfbefc1bbe22be5907577 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 05:00:25 +0000 Subject: [PATCH 253/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index d5bedc787f6..1355697b32d 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -155,7 +155,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) - l_in_0 = torch.randint(10, device=xm.xla_device()) # input_value + l_in_0 = torch.randint(10, (10,), device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) weight_0 = simple_with_linear.linear.weight From 08a155ede78326be59ba446c3592a9b9813e17f4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 05:02:08 +0000 Subject: [PATCH 254/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 1355697b32d..c8119721b19 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -155,7 +155,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) - l_in_0 = torch.randint(10, (10,), device=xm.xla_device()) # input_value + l_in_0 = torch.rand(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) weight_0 = simple_with_linear.linear.weight From 2840aefe6af8a5035c45103e763094cbaeb3749f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 17:09:58 +0000 Subject: [PATCH 255/546] update --- ...est_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index c8119721b19..0b578ded851 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -113,6 +113,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value + # l_in_0 = torch.rand(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) From f144d25c3306a0ba45b89a56066b96a93fae6f3f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 17:17:43 +0000 Subject: [PATCH 256/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 0b578ded851..76ad459798e 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -112,8 +112,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) - l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value - # l_in_0 = torch.rand(10, device=xm.xla_device()) # input_value + # l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value + l_in_0 = torch.rand(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) From 03743edbdc91df0739f1ffaab6258686041ad190 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 17:30:53 +0000 Subject: [PATCH 257/546] update --- ...i_loop_with_while_loop_simple_add_dispatch_in_torch.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 76ad459798e..e00f2ffeaa1 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -95,8 +95,8 @@ def test_while_loop_tpu_simple_linear(self): # def forward(self, upper, lower, one_value, x, input_value, output_value): linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - weight_0 = linear_0.weight - bias_0 = linear_0.bias + # weight_0 = linear_0.weight + # bias_0 = linear_0.bias def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] @@ -115,8 +115,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value l_in_0 = torch.rand(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) + # weight_0 = linear_0.weight + # bias_0 = linear_0.bias - upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) From 118e64023d783e47c9c66f467a44a312f575f6bf Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:00:07 +0000 Subject: [PATCH 258/546] update --- ...ori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index e00f2ffeaa1..bbcfdb60d83 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -103,10 +103,12 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) - output_value_real = linear_0(input_value) + # output_value_real = linear_0(input_value) + output_value = linear_0(input_value) weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real.clone(), weight.clone(), bias.clone() # , output_value_real + # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real.clone(), weight.clone(), bias.clone() # , output_value_real + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone(), weight.clone(), bias.clone() # , output_value_real upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From 915ca5a6126e5323340cb55cc2e7e8aa2cf37f3a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:01:17 +0000 Subject: [PATCH 259/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index bbcfdb60d83..1b771d20af1 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -108,7 +108,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real.clone(), weight.clone(), bias.clone() # , output_value_real - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone(), weight.clone(), bias.clone() # , output_value_real + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight.clone(), bias.clone(), output_value.clone() # , output_value_real upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From 0bca16d922038a8e377612ea0102c1d341df86cf Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:02:36 +0000 Subject: [PATCH 260/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 1b771d20af1..4f52e21dc6a 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -108,7 +108,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real.clone(), weight.clone(), bias.clone() # , output_value_real - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), weight.clone(), bias.clone(), output_value.clone() # , output_value_real + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() # , output_value_real upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From 0b04e2819ec1f0065eb833750bd5f7f709200ad3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:03:19 +0000 Subject: [PATCH 261/546] update --- ...est_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 4f52e21dc6a..173afeba744 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -126,6 +126,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): print("torch_add_res__: ", torch_add_res__) print("output_value_real__: ", output_value_real__) + print("bias__: ", bias__) print("expected: ", expected) self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) From 68e9aba51643232884dc216a898aa0a9086f928e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:15:33 +0000 Subject: [PATCH 262/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 173afeba744..58bc78c806e 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -110,7 +110,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real.clone(), weight.clone(), bias.clone() # , output_value_real return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() # , output_value_real - upper = torch.tensor([52], dtype=torch.int32, device=device) + upper = torch.tensor([1], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) From 94d004582913e9e520112782550bd9b01d2d6f9b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:21:02 +0000 Subject: [PATCH 263/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 58bc78c806e..b4344c90d31 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -96,7 +96,8 @@ def test_while_loop_tpu_simple_linear(self): # def forward(self, upper, lower, one_value, x, input_value, output_value): linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) # weight_0 = linear_0.weight - # bias_0 = linear_0.bias + bias_0 = linear_0.bias + print("original bias: ", bias_0) def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] From b9192beb819839246c55d636bf1f75861cc95235 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:21:55 +0000 Subject: [PATCH 264/546] update --- torch_xla/experimental/fori_loop.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index a69dd8d99df..b239d7eb56d 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -94,9 +94,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - # body_hlo_print = xb.get_computation_hlo(body_computation) - # print("body computation: !!!!!!!!!") - # print(body_hlo_print) + body_hlo_print = xb.get_computation_hlo(body_computation) + print("body computation: !!!!!!!!!") + print(body_hlo_print) ### trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs @@ -125,9 +125,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) - hlo_print = xb.get_computation_hlo(computation) - print("while computation: !!!!!!!!!") - print(hlo_print) + # hlo_print = xb.get_computation_hlo(computation) + # print("while computation: !!!!!!!!!") + # print(hlo_print) ### gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', From 1b594e60120dfda7e71c59cf7aba77ca57e9f0e1 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:29:42 +0000 Subject: [PATCH 265/546] update --- torch_xla/experimental/fori_loop.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index b239d7eb56d..5694a1ba1f0 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -64,11 +64,13 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): fake_carried_inputs.append( torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) + print("fake_carried_inputs: ", fake_carried_inputs) ### TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") + # !!! cond xlacomputation change !!! switch bias and weight position additional_inputs_list_cond = list(fake_carried_inputs[2:]) ### all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor tmp_bias = additional_inputs_list_cond[-2] ### not used, change order doesn't affect logic @@ -87,6 +89,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_result = body_fn(*fake_carried_inputs) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") + # !!! body xlacomputation change !!! add output_value argument additional_inputs_list_body = [fake_carried_inputs[-3]] ### TODO(@manfei): treat hard-code parameters: additional_inputs_list_body From 867e488701ec4147331586cd948c16df4bb42a42 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:33:55 +0000 Subject: [PATCH 266/546] update --- torch_xla/experimental/fori_loop.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 5694a1ba1f0..e69323f1c8b 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -71,10 +71,16 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") - # !!! cond xlacomputation change !!! switch bias and weight position + # # !!! cond xlacomputation change !!! switch bias and weight position + # additional_inputs_list_cond = list(fake_carried_inputs[2:]) ### all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor + # tmp_bias = additional_inputs_list_cond[-2] ### not used, change order doesn't affect logic + # del additional_inputs_list_cond[-2] ### not used, change order doesn't affect logic + # additional_inputs_list_cond.append(tmp_bias) ### not used, change order doesn't affect logic + + # !!! cond xlacomputation change !!! switch output_value and weight position additional_inputs_list_cond = list(fake_carried_inputs[2:]) ### all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor - tmp_bias = additional_inputs_list_cond[-2] ### not used, change order doesn't affect logic - del additional_inputs_list_cond[-2] ### not used, change order doesn't affect logic + tmp_bias = additional_inputs_list_cond[-3] ### not used, change order doesn't affect logic + del additional_inputs_list_cond[-3] ### not used, change order doesn't affect logic additional_inputs_list_cond.append(tmp_bias) ### not used, change order doesn't affect logic cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) From 7a1bde9229c780fc6219710d0ac44057eeab2be9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:35:20 +0000 Subject: [PATCH 267/546] update --- torch_xla/experimental/fori_loop.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index e69323f1c8b..c4923c76fe7 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -96,8 +96,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - # !!! body xlacomputation change !!! add output_value argument + # !!! body xlacomputation change !!! add non-changed output_value argument additional_inputs_list_body = [fake_carried_inputs[-3]] + ### TODO(@manfei): treat hard-code parameters: additional_inputs_list_body body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) body_hlo = body_ctx.hlo() @@ -120,10 +121,16 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): p = xb.mkparam(builder, len(params), shape) params.append(p) + # ### TODO(@manfei): treat hard-code input arguments + # # !!! init change !!! + # tmp_bias = params[-2] + # del params[-2] + # params.append(tmp_bias) + ### TODO(@manfei): treat hard-code input arguments - # !!! init change !!! - tmp_bias = params[-2] - del params[-2] + # !!! init change !!! switch bias and output_value + tmp_bias = params[-3] + del params[-3] params.append(tmp_bias) ### generate while xlacomputation From 120c33598e00838e943bd7a82253a86076c93881 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:36:36 +0000 Subject: [PATCH 268/546] update --- ...oop_with_while_loop_simple_add_dispatch_in_torch.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index b4344c90d31..2ddb6af2f17 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -121,14 +121,14 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # weight_0 = linear_0.weight # bias_0 = linear_0.bias - upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) - print("torch_add_res__: ", torch_add_res__) - print("output_value_real__: ", output_value_real__) - print("bias__: ", bias__) - print("expected: ", expected) + # print("torch_add_res__: ", torch_add_res__) + # print("output_value_real__: ", output_value_real__) + # print("bias__: ", bias__) + # print("expected: ", expected) self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return aaa From bc5950a04ad2af1113f657f228e824fef5653b77 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:37:40 +0000 Subject: [PATCH 269/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 2ddb6af2f17..90d23dc2686 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -131,7 +131,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # print("expected: ", expected) self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) - return aaa + return True # passed def test_while_loop_tpu_simple_linear_class(self): From 91a3aa8a11230ce8ec8f3e4ad46871f2f898ea37 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:38:04 +0000 Subject: [PATCH 270/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 90d23dc2686..956acee8b1a 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -130,8 +130,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # print("bias__: ", bias__) # print("expected: ", expected) - self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) - return True + # self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + return self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) # passed def test_while_loop_tpu_simple_linear_class(self): From 7041bc32203a74a298f1a11730aa45629647b5cf Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 18:38:54 +0000 Subject: [PATCH 271/546] update --- torch_xla/experimental/fori_loop.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index c4923c76fe7..3e88a5f4446 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -41,8 +41,8 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): ### cond_fn&body_fn: callable ### carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) print("arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") - print("carried_inputs: ", carried_inputs) - print("additional_inputs: ", additional_inputs) + # print("carried_inputs: ", carried_inputs) + # print("additional_inputs: ", additional_inputs) if additional_inputs is None: additional_inputs = tuple() return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) @@ -64,7 +64,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): fake_carried_inputs.append( torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) - print("fake_carried_inputs: ", fake_carried_inputs) + # print("fake_carried_inputs: ", fake_carried_inputs) ### TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) @@ -104,9 +104,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - body_hlo_print = xb.get_computation_hlo(body_computation) - print("body computation: !!!!!!!!!") - print(body_hlo_print) + # body_hlo_print = xb.get_computation_hlo(body_computation) + # print("body computation: !!!!!!!!!") + # print(body_hlo_print) ### trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs From 3ec8120e9e5f4261427d2da608ded09011edb6cc Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 19:40:59 +0000 Subject: [PATCH 272/546] update --- ..._loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- torch_xla/experimental/fori_loop.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 956acee8b1a..fc803d252a3 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -86,7 +86,7 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) -# debugging +# passed def test_while_loop_tpu_simple_linear(self): xm.mark_step() diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3e88a5f4446..5918fd40900 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -23,9 +23,12 @@ def cond_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): new_lower = torch.add(one_value, lower) output_value = body_fun(*input_value) - weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement - bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement - return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value + if (hasattr(body_fun, 'weight') and hasattr(body_fun, 'bias')): + weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement + bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement + return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value + else: + return upper, new_lower, one_value, torch.add(one_value, x), *input_value, output_value output_value = torch.zeros([20], dtype=torch.float32, device=device) weight_0 = body_fun.weight From 9beac796744f17d1cea55556e843f3d503cb2444 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 19:42:15 +0000 Subject: [PATCH 273/546] update --- torch_xla/experimental/fori_loop.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 5918fd40900..d6832d2e532 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -31,10 +31,13 @@ def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi return upper, new_lower, one_value, torch.add(one_value, x), *input_value, output_value output_value = torch.zeros([20], dtype=torch.float32, device=device) - weight_0 = body_fun.weight - bias_0 = body_fun.bias one_value = torch.tensor([1], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, weight_0, bias_0, output_value)) + if (hasattr(body_fun, 'weight') and hasattr(body_fun, 'bias')): + weight_0 = body_fun.weight + bias_0 = body_fun.bias + res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, weight_0, bias_0, output_value)) + else: + res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, output_value)) return res From 70891937cfc91bc125f5f09d99ebb72195c0e001 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 19:44:22 +0000 Subject: [PATCH 274/546] update --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- torch_xla/experimental/fori_loop.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index fc803d252a3..7ae046f696b 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -97,7 +97,7 @@ def test_while_loop_tpu_simple_linear(self): linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) # weight_0 = linear_0.weight bias_0 = linear_0.bias - print("original bias: ", bias_0) + # print("original bias: ", bias_0) def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index d6832d2e532..80e9b290848 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -70,7 +70,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): fake_carried_inputs.append( torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) - # print("fake_carried_inputs: ", fake_carried_inputs) + print("fake_carried_inputs: ", fake_carried_inputs) ### TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) From 9d9bc32a84bffc10b884a0d30f744ce8b385fc87 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 19:50:11 +0000 Subject: [PATCH 275/546] update --- torch_xla/experimental/fori_loop.py | 30 +++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 80e9b290848..47b998857ff 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -16,28 +16,38 @@ def fori_loop(upper, lower, body_fun, init_val, *input_value): device = xm.xla_device() + + output_value = torch.zeros([20], dtype=torch.float32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) def cond_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): - new_lower = torch.add(one_value, lower) - output_value = body_fun(*input_value) - if (hasattr(body_fun, 'weight') and hasattr(body_fun, 'bias')): + if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): + def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): + new_lower = torch.add(one_value, lower) + output_value = body_fun(*input_value) weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value - else: - return upper, new_lower, one_value, torch.add(one_value, x), *input_value, output_value - - output_value = torch.zeros([20], dtype=torch.float32, device=device) - one_value = torch.tensor([1], dtype=torch.int32, device=device) - if (hasattr(body_fun, 'weight') and hasattr(body_fun, 'bias')): weight_0 = body_fun.weight bias_0 = body_fun.bias res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, weight_0, bias_0, output_value)) else: + def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): + new_lower = torch.add(one_value, lower) + output_value = body_fun(*input_value) + return upper, new_lower, one_value, torch.add(one_value, x), *input_value, output_value res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, output_value)) + + # output_value = torch.zeros([20], dtype=torch.float32, device=device) + # one_value = torch.tensor([1], dtype=torch.int32, device=device) + # if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): + # weight_0 = body_fun.weight + # bias_0 = body_fun.bias + # res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, weight_0, bias_0, output_value)) + # else: + # res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, output_value)) return res From 0d218c8e208247ff339b7011df4209234f3168c3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 19:58:39 +0000 Subject: [PATCH 276/546] update --- torch_xla/experimental/fori_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 47b998857ff..bdc03d21981 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -24,6 +24,8 @@ def cond_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bi return lower[0] < upper[0] if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): + print("body_fun.weight: ", body_fun.weight) + print("body_fun.bias: ", body_fun.bias) def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): new_lower = torch.add(one_value, lower) output_value = body_fun(*input_value) From bae4952783bfe1ac464297bd05c892041229be55 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:02:30 +0000 Subject: [PATCH 277/546] update --- torch_xla/experimental/fori_loop.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index bdc03d21981..b09632dee2a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -12,21 +12,34 @@ from torch._higher_order_ops.while_loop import while_loop_op +# /////////////// +# def cond_fn(upper, lower, one_value, x, input_value, output_value): +# return lower[0] < upper[0] + +# def body_fn(upper, lower, one_value, x, input_value, output_value): +# new_lower = torch.add(one_value, lower) +# output_value = linear_0(input_value) +# weight = linear_0.weight +# bias = linear_0.bias +# return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() +# /////////////// + + ### TODO(@manfei): delete one_value? def fori_loop(upper, lower, body_fun, init_val, *input_value): device = xm.xla_device() - + output_value = torch.zeros([20], dtype=torch.float32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) - def cond_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): + def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): return lower[0] < upper[0] if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): print("body_fun.weight: ", body_fun.weight) print("body_fun.bias: ", body_fun.bias) - def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): + def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): new_lower = torch.add(one_value, lower) output_value = body_fun(*input_value) weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement From 7bb4791dd8c2a2bb755f894fada8c9c6cb22704c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:07:04 +0000 Subject: [PATCH 278/546] update --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index b09632dee2a..73b9362f607 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -37,8 +37,8 @@ def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bia return lower[0] < upper[0] if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): - print("body_fun.weight: ", body_fun.weight) - print("body_fun.bias: ", body_fun.bias) + # print("body_fun.weight: ", body_fun.weight) + # print("body_fun.bias: ", body_fun.bias) def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): new_lower = torch.add(one_value, lower) output_value = body_fun(*input_value) From 114531545a22c2045010b58fc10e5a3faf71e172 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:11:19 +0000 Subject: [PATCH 279/546] update --- torch_xla/experimental/fori_loop.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 73b9362f607..af36c8d11c7 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -22,10 +22,22 @@ # weight = linear_0.weight # bias = linear_0.bias # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() + # upper = torch.tensor([1], dtype=torch.int32, device=device) + # lower = torch.tensor([0], dtype=torch.int32, device=device) + # one_value = torch.tensor([1], dtype=torch.int32, device=device) + # init_val = torch.tensor([1], dtype=torch.int32, device=device) + # # l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value + # l_in_0 = torch.rand(10, device=xm.xla_device()) # input_value + # output_value = torch.zeros([20], dtype=torch.float32, device=device) + # # weight_0 = linear_0.weight + # # bias_0 = linear_0.bias + + # upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + # /////////////// -### TODO(@manfei): delete one_value? +### TODO(@manfei): treat *input_value def fori_loop(upper, lower, body_fun, init_val, *input_value): device = xm.xla_device() @@ -41,13 +53,13 @@ def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bia # print("body_fun.bias: ", body_fun.bias) def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): new_lower = torch.add(one_value, lower) - output_value = body_fun(*input_value) + output_value = body_fun(input_value) weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement - return upper, new_lower, one_value, torch.add(one_value, x), *input_value, weight, bias, output_value + return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value weight_0 = body_fun.weight bias_0 = body_fun.bias - res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, weight_0, bias_0, output_value)) + res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) else: def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): new_lower = torch.add(one_value, lower) From bad4c1fe3338a9dbaac1dd94093dd6785d4e6258 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:13:18 +0000 Subject: [PATCH 280/546] update --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index af36c8d11c7..2f603b427ae 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -84,8 +84,8 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): ### cond_fn&body_fn: callable ### carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) print("arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") - # print("carried_inputs: ", carried_inputs) - # print("additional_inputs: ", additional_inputs) + print("carried_inputs: ", carried_inputs) + print("additional_inputs: ", additional_inputs) if additional_inputs is None: additional_inputs = tuple() return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From 00cd07c103cdd536f4f0032b302889d03c47a190 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:18:23 +0000 Subject: [PATCH 281/546] update --- torch_xla/experimental/fori_loop.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 2f603b427ae..b141717a157 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -45,12 +45,11 @@ def fori_loop(upper, lower, body_fun, init_val, *input_value): output_value = torch.zeros([20], dtype=torch.float32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) - def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): - return lower[0] < upper[0] - if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): # print("body_fun.weight: ", body_fun.weight) # print("body_fun.bias: ", body_fun.bias) + def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): + return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): new_lower = torch.add(one_value, lower) output_value = body_fun(input_value) @@ -61,10 +60,12 @@ def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bia bias_0 = body_fun.bias res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) else: - def body_fn(upper, lower, one_value, x, *input_value, weight_0, output_value, bias_0): + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) - output_value = body_fun(*input_value) - return upper, new_lower, one_value, torch.add(one_value, x), *input_value, output_value + output_value = body_fun(input_value) + return upper, new_lower, one_value, torch.add(one_value, x), input_value, output_value res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, output_value)) # output_value = torch.zeros([20], dtype=torch.float32, device=device) From 629f42e444cdaa39926701c9d0048e52daa061d9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:20:32 +0000 Subject: [PATCH 282/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index b141717a157..fe125f4e54e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -38,7 +38,7 @@ ### TODO(@manfei): treat *input_value -def fori_loop(upper, lower, body_fun, init_val, *input_value): +def fori_loop(upper, lower, body_fun, init_val, input_value): device = xm.xla_device() From d7bab3854838427ed407b8e87ecfb772952c83ef Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:25:18 +0000 Subject: [PATCH 283/546] update --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index fe125f4e54e..0436cb449a0 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -48,9 +48,9 @@ def fori_loop(upper, lower, body_fun, init_val, input_value): if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): # print("body_fun.weight: ", body_fun.weight) # print("body_fun.bias: ", body_fun.bias) - def cond_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): + def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, weight_0, output_value, bias_0): + def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value = body_fun(input_value) weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement From 2e485a50cc49e40e5b4cd8b57b65a26122add77b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:34:43 +0000 Subject: [PATCH 284/546] update --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0436cb449a0..ef5f2156043 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -56,8 +56,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value - weight_0 = body_fun.weight - bias_0 = body_fun.bias + # weight_0 = body_fun.weight + # bias_0 = body_fun.bias res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) else: def cond_fn(upper, lower, one_value, x, input_value, output_value): From 5ec118425c2300dfb9601e93aaba9301344d669c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:38:38 +0000 Subject: [PATCH 285/546] update --- torch_xla/experimental/fori_loop.py | 43 ++++++++++++----------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index ef5f2156043..0073d9d1ccd 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -12,31 +12,6 @@ from torch._higher_order_ops.while_loop import while_loop_op -# /////////////// -# def cond_fn(upper, lower, one_value, x, input_value, output_value): -# return lower[0] < upper[0] - -# def body_fn(upper, lower, one_value, x, input_value, output_value): -# new_lower = torch.add(one_value, lower) -# output_value = linear_0(input_value) -# weight = linear_0.weight -# bias = linear_0.bias -# return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() - # upper = torch.tensor([1], dtype=torch.int32, device=device) - # lower = torch.tensor([0], dtype=torch.int32, device=device) - # one_value = torch.tensor([1], dtype=torch.int32, device=device) - # init_val = torch.tensor([1], dtype=torch.int32, device=device) - # # l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value - # l_in_0 = torch.rand(10, device=xm.xla_device()) # input_value - # output_value = torch.zeros([20], dtype=torch.float32, device=device) - # # weight_0 = linear_0.weight - # # bias_0 = linear_0.bias - - # upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) - -# /////////////// - - ### TODO(@manfei): treat *input_value def fori_loop(upper, lower, body_fun, init_val, input_value): @@ -45,6 +20,21 @@ def fori_loop(upper, lower, body_fun, init_val, input_value): output_value = torch.zeros([20], dtype=torch.float32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) +# ///////// +# def cond_fn(upper, lower, one_value, x, input_value, output_value): +# return lower[0] < upper[0] + +# def body_fn(upper, lower, one_value, x, input_value, output_value): +# new_lower = torch.add(one_value, lower) +# # output_value_real = linear_0(input_value) +# output_value = linear_0(input_value) +# weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement +# bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement +# # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real.clone(), weight.clone(), bias.clone() # , output_value_real +# return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() # , output_value_real + +# ///////// + if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): # print("body_fun.weight: ", body_fun.weight) # print("body_fun.bias: ", body_fun.bias) @@ -55,7 +45,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value = body_fun(input_value) weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement - return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value + # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() # # weight_0 = body_fun.weight # bias_0 = body_fun.bias res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) From db3c8bfc11730a3df3ed0bf65f39f96da478b8bb Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:40:07 +0000 Subject: [PATCH 286/546] update --- ...ith_while_loop_simple_add_dispatch_in_torch.py | 2 +- torch_xla/experimental/fori_loop.py | 15 --------------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 7ae046f696b..1cdc0e054af 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -96,7 +96,7 @@ def test_while_loop_tpu_simple_linear(self): # def forward(self, upper, lower, one_value, x, input_value, output_value): linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) # weight_0 = linear_0.weight - bias_0 = linear_0.bias + # bias_0 = linear_0.bias # print("original bias: ", bias_0) def cond_fn(upper, lower, one_value, x, input_value, output_value): diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0073d9d1ccd..15b12464d2c 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -20,21 +20,6 @@ def fori_loop(upper, lower, body_fun, init_val, input_value): output_value = torch.zeros([20], dtype=torch.float32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) -# ///////// -# def cond_fn(upper, lower, one_value, x, input_value, output_value): -# return lower[0] < upper[0] - -# def body_fn(upper, lower, one_value, x, input_value, output_value): -# new_lower = torch.add(one_value, lower) -# # output_value_real = linear_0(input_value) -# output_value = linear_0(input_value) -# weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement -# bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement -# # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real.clone(), weight.clone(), bias.clone() # , output_value_real -# return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() # , output_value_real - -# ///////// - if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): # print("body_fun.weight: ", body_fun.weight) # print("body_fun.bias: ", body_fun.bias) From 14f056962e16a2c9c0d7d7d3141d9c4dda52c179 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:41:45 +0000 Subject: [PATCH 287/546] update --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 15b12464d2c..c2ba37f2473 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -16,6 +16,7 @@ def fori_loop(upper, lower, body_fun, init_val, input_value): device = xm.xla_device() + body_fun = torch.nn.Linear(10, 20).to(xm.xla_device()) output_value = torch.zeros([20], dtype=torch.float32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) From d07015a03effef894117ad1075bbe93e9e19fa35 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 20:46:33 +0000 Subject: [PATCH 288/546] update --- torch_xla/experimental/fori_loop.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index c2ba37f2473..40992de631e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -16,11 +16,25 @@ def fori_loop(upper, lower, body_fun, init_val, input_value): device = xm.xla_device() - body_fun = torch.nn.Linear(10, 20).to(xm.xla_device()) + # body_fun = torch.nn.Linear(10, 20).to(xm.xla_device()) output_value = torch.zeros([20], dtype=torch.float32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value = body_fun(input_value) + weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement + bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement + # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() + # weight_0 = body_fun.weight + # bias_0 = body_fun.bias + res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) + return res + if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): # print("body_fun.weight: ", body_fun.weight) # print("body_fun.bias: ", body_fun.bias) @@ -32,7 +46,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() # + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() # weight_0 = body_fun.weight # bias_0 = body_fun.bias res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) @@ -43,7 +57,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value = body_fun(input_value) return upper, new_lower, one_value, torch.add(one_value, x), input_value, output_value - res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, output_value)) + res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) # output_value = torch.zeros([20], dtype=torch.float32, device=device) # one_value = torch.tensor([1], dtype=torch.int32, device=device) From 55061c3a17f15d342ae8a0cd13b29b75eae70502 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:00:05 +0000 Subject: [PATCH 289/546] update --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 40992de631e..77eb4485f78 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -10,6 +10,7 @@ from torch._ops import HigherOrderOperator import torch._higher_order_ops.while_loop from torch._higher_order_ops.while_loop import while_loop_op +from torch._higher_order_ops.while_loop import while_loop ### TODO(@manfei): treat *input_value From f1a596a724d12c086d2f1826f12931ddad5ea391 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:01:28 +0000 Subject: [PATCH 290/546] update --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 77eb4485f78..e634fb797a2 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -10,7 +10,7 @@ from torch._ops import HigherOrderOperator import torch._higher_order_ops.while_loop from torch._higher_order_ops.while_loop import while_loop_op -from torch._higher_order_ops.while_loop import while_loop +from torch._higher_order_ops.while_loop import while_loop as torch_while_loop ### TODO(@manfei): treat *input_value @@ -33,7 +33,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() # weight_0 = body_fun.weight # bias_0 = body_fun.bias - res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) + res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) return res if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): From 3818380dcdfc6f3fe2abe309a617d93c1e9ba671 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:02:56 +0000 Subject: [PATCH 291/546] update --- torch_xla/experimental/fori_loop.py | 30 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index e634fb797a2..2ecb88fc1cc 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -22,19 +22,19 @@ def fori_loop(upper, lower, body_fun, init_val, input_value): output_value = torch.zeros([20], dtype=torch.float32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) - def cond_fn(upper, lower, one_value, x, input_value, output_value): - return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, output_value): - new_lower = torch.add(one_value, lower) - output_value = body_fun(input_value) - weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement - bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement - # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() - # weight_0 = body_fun.weight - # bias_0 = body_fun.bias - res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) - return res + # def cond_fn(upper, lower, one_value, x, input_value, output_value): + # return lower[0] < upper[0] + # def body_fn(upper, lower, one_value, x, input_value, output_value): + # new_lower = torch.add(one_value, lower) + # output_value = body_fun(input_value) + # weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement + # bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement + # # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value + # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() + # # weight_0 = body_fun.weight + # # bias_0 = body_fun.bias + # res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) + # return res if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): # print("body_fun.weight: ", body_fun.weight) @@ -50,7 +50,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() # weight_0 = body_fun.weight # bias_0 = body_fun.bias - res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) + res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) else: def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] @@ -58,7 +58,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value = body_fun(input_value) return upper, new_lower, one_value, torch.add(one_value, x), input_value, output_value - res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) + res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) # output_value = torch.zeros([20], dtype=torch.float32, device=device) # one_value = torch.tensor([1], dtype=torch.int32, device=device) From b70a34d5f2e0e203821828dadd9b6c73203625a9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:03:48 +0000 Subject: [PATCH 292/546] update --- torch_xla/experimental/fori_loop.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 2ecb88fc1cc..cf94532ca15 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -17,28 +17,11 @@ def fori_loop(upper, lower, body_fun, init_val, input_value): device = xm.xla_device() - # body_fun = torch.nn.Linear(10, 20).to(xm.xla_device()) output_value = torch.zeros([20], dtype=torch.float32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) - # def cond_fn(upper, lower, one_value, x, input_value, output_value): - # return lower[0] < upper[0] - # def body_fn(upper, lower, one_value, x, input_value, output_value): - # new_lower = torch.add(one_value, lower) - # output_value = body_fun(input_value) - # weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement - # bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement - # # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value - # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() - # # weight_0 = body_fun.weight - # # bias_0 = body_fun.bias - # res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) - # return res - if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): - # print("body_fun.weight: ", body_fun.weight) - # print("body_fun.bias: ", body_fun.bias) def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): @@ -46,10 +29,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value = body_fun(input_value) weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement - # return upper, new_lower, one_value, torch.add(one_value, x), input_value, weight, bias, output_value return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() - # weight_0 = body_fun.weight - # bias_0 = body_fun.bias res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) else: def cond_fn(upper, lower, one_value, x, input_value, output_value): @@ -60,14 +40,6 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return upper, new_lower, one_value, torch.add(one_value, x), input_value, output_value res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) - # output_value = torch.zeros([20], dtype=torch.float32, device=device) - # one_value = torch.tensor([1], dtype=torch.int32, device=device) - # if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): - # weight_0 = body_fun.weight - # bias_0 = body_fun.bias - # res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, weight_0, bias_0, output_value)) - # else: - # res = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, *input_value, output_value)) return res From 26d2fdb57486ff7cd246af243aa72518a31a78b2 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:06:46 +0000 Subject: [PATCH 293/546] update --- ...ri_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 + torch_xla/experimental/fori_loop.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 1cdc0e054af..322829a1e61 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -198,6 +198,7 @@ def body_fun(a, b): expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) self.assertEqual(expected, res_) +# passed def test_fori_loop_tpu_simple_linear(self): xm.mark_step() diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index cf94532ca15..3aac50599a0 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -87,9 +87,10 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # !!! cond xlacomputation change !!! switch output_value and weight position additional_inputs_list_cond = list(fake_carried_inputs[2:]) ### all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor - tmp_bias = additional_inputs_list_cond[-3] ### not used, change order doesn't affect logic - del additional_inputs_list_cond[-3] ### not used, change order doesn't affect logic - additional_inputs_list_cond.append(tmp_bias) ### not used, change order doesn't affect logic + if additional_inputs: + tmp_bias = additional_inputs_list_cond[-3] ### not used, change order doesn't affect logic + del additional_inputs_list_cond[-3] ### not used, change order doesn't affect logic + additional_inputs_list_cond.append(tmp_bias) ### not used, change order doesn't affect logic cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() From c8e39fe8f883c565e1cb1a51201ee5ddbc03d49c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:07:45 +0000 Subject: [PATCH 294/546] update --- torch_xla/experimental/fori_loop.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3aac50599a0..7140bafb039 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -106,7 +106,10 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_ctx.set_name_string("bodyctx") # !!! body xlacomputation change !!! add non-changed output_value argument - additional_inputs_list_body = [fake_carried_inputs[-3]] + if additional_inputs: + additional_inputs_list_body = [fake_carried_inputs[-3]] + else: + additional_inputs_list_body = [] ### TODO(@manfei): treat hard-code parameters: additional_inputs_list_body body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) From 12371764f300045e96241e112b4a8b02c2df5bb0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:08:49 +0000 Subject: [PATCH 295/546] update --- torch_xla/experimental/fori_loop.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 7140bafb039..4c2920bbf1e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -141,9 +141,10 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): ### TODO(@manfei): treat hard-code input arguments # !!! init change !!! switch bias and output_value - tmp_bias = params[-3] - del params[-3] - params.append(tmp_bias) + if additional_inputs: + tmp_bias = params[-3] + del params[-3] + params.append(tmp_bias) ### generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) From 620d6270e71782833a84987cd9ba693739a2775a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:09:51 +0000 Subject: [PATCH 296/546] update --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 4c2920bbf1e..f5af1b8c967 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -85,7 +85,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # del additional_inputs_list_cond[-2] ### not used, change order doesn't affect logic # additional_inputs_list_cond.append(tmp_bias) ### not used, change order doesn't affect logic - # !!! cond xlacomputation change !!! switch output_value and weight position + # !!! cond xlacomputation change !!! switch output_value and weight position if additional_inputs(weight/bias) exists additional_inputs_list_cond = list(fake_carried_inputs[2:]) ### all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor if additional_inputs: tmp_bias = additional_inputs_list_cond[-3] ### not used, change order doesn't affect logic @@ -105,7 +105,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - # !!! body xlacomputation change !!! add non-changed output_value argument + # !!! body xlacomputation change !!! add non-changed output_value argument if additional_inputs(weight/bias) exists if additional_inputs: additional_inputs_list_body = [fake_carried_inputs[-3]] else: @@ -140,7 +140,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # params.append(tmp_bias) ### TODO(@manfei): treat hard-code input arguments - # !!! init change !!! switch bias and output_value + # !!! init change !!! switch bias and output_value if additional_inputs(weight/bias) exists if additional_inputs: tmp_bias = params[-3] del params[-3] From cfa5f7e932a364bfb52eac25ec46e68b7e5019c2 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:10:06 +0000 Subject: [PATCH 297/546] update --- torch_xla/experimental/fori_loop.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index f5af1b8c967..ed6bf8438c5 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -79,12 +79,6 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") - # # !!! cond xlacomputation change !!! switch bias and weight position - # additional_inputs_list_cond = list(fake_carried_inputs[2:]) ### all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor - # tmp_bias = additional_inputs_list_cond[-2] ### not used, change order doesn't affect logic - # del additional_inputs_list_cond[-2] ### not used, change order doesn't affect logic - # additional_inputs_list_cond.append(tmp_bias) ### not used, change order doesn't affect logic - # !!! cond xlacomputation change !!! switch output_value and weight position if additional_inputs(weight/bias) exists additional_inputs_list_cond = list(fake_carried_inputs[2:]) ### all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor if additional_inputs: From 0cb1dce13221b29b5da139facfb8712a7b817ddc Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:13:41 +0000 Subject: [PATCH 298/546] update --- ...ori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 6 ++++-- torch_xla/experimental/fori_loop.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 322829a1e61..84dfe3e11d8 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -32,6 +32,7 @@ def _fake_fori_loop(lower, upper, body_fun, *init_val): class WhileLoopTest(unittest.TestCase): +# passed def test_while_loop_tpu_subtraction(self): device = xm.xla_device() @@ -50,6 +51,7 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) +# passed def test_while_loop_tpu_addition(self): device = xm.xla_device() @@ -68,6 +70,7 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) +# passed def test_while_loop_tpu_subtraction_nested(self): device = xm.xla_device() @@ -193,8 +196,7 @@ def test_fori_loop_tpu_addition(self): def body_fun(a, b): return torch.add(a, b) - lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, - init_val) + lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) self.assertEqual(expected, res_) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index ed6bf8438c5..e99953c9ec4 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -26,7 +26,7 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) - output_value = body_fun(input_value) + output_value = body_fun(input_value, one_value) weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() From 2dcfab0790dd95b103a2805e690d1fd3c71dc473 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:14:21 +0000 Subject: [PATCH 299/546] update --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index e99953c9ec4..da166177d01 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -26,7 +26,7 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) - output_value = body_fun(input_value, one_value) + output_value = body_fun(input_value) weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() @@ -36,7 +36,7 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) - output_value = body_fun(input_value) + output_value = body_fun(input_value, one_value) return upper, new_lower, one_value, torch.add(one_value, x), input_value, output_value res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) From d0270b6cf80720b099f5853fc2e4facace86baff Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:14:57 +0000 Subject: [PATCH 300/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index da166177d01..5c66d48010e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -36,7 +36,7 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) - output_value = body_fun(input_value, one_value) + output_value = body_fun(one_value, input_value) return upper, new_lower, one_value, torch.add(one_value, x), input_value, output_value res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) From 2377b8b12e85d67c1a23759dd0801fa4093843fc Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:23:16 +0000 Subject: [PATCH 301/546] update --- torch_xla/experimental/fori_loop.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 5c66d48010e..4c25db6515c 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -12,6 +12,32 @@ from torch._higher_order_ops.while_loop import while_loop_op from torch._higher_order_ops.while_loop import while_loop as torch_while_loop +///////// + def test_while_loop_tpu_addition(self): + device = xm.xla_device() + def cond_fn(init, limit_value): + return limit_value[0] >= init[0] + def body_fn(init, limit_value): + one_value = torch.ones(1, dtype=torch.int32, device=device) + return (torch.add(init, one_value), limit_value.clone()) + # TODO(@manfei): init and limit_value has to be torch.tensor. + init = torch.tensor([0], dtype=torch.int32, device=device) + limit_value = torch.tensor([10], dtype=torch.int32, device=device) + res = while_loop(cond_fn, body_fn, (init, limit_value)) +///////// +def fori_loop(lower, upper, user_body_func, *init_val): + device = xm.xla_device() + def cond_fn(upper, lower, *init_val): + return lower[0] < upper[0] + def body_fn(upper, lower, *init_val): + one_value_i = torch.ones(1, dtype=torch.int32, device=device) + res_list = list(user_body_func(*init_val)) + res_list.insert(0, lower) + res_list.insert(0, torch.sub(upper, one_value_i)) + return res_list + res = while_loop(cond_fn, body_fn, (lower, upper, *init_val)) + return res +///////// ### TODO(@manfei): treat *input_value def fori_loop(upper, lower, body_fun, init_val, input_value): @@ -37,7 +63,7 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value = body_fun(one_value, input_value) - return upper, new_lower, one_value, torch.add(one_value, x), input_value, output_value + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone() res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) return res From 8947db7e6ec4b0233a0532f1c7b2637107400b8c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:23:40 +0000 Subject: [PATCH 302/546] update --- torch_xla/experimental/fori_loop.py | 52 ++++++++++++++--------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 4c25db6515c..eef46f7f0e6 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -12,32 +12,32 @@ from torch._higher_order_ops.while_loop import while_loop_op from torch._higher_order_ops.while_loop import while_loop as torch_while_loop -///////// - def test_while_loop_tpu_addition(self): - device = xm.xla_device() - def cond_fn(init, limit_value): - return limit_value[0] >= init[0] - def body_fn(init, limit_value): - one_value = torch.ones(1, dtype=torch.int32, device=device) - return (torch.add(init, one_value), limit_value.clone()) - # TODO(@manfei): init and limit_value has to be torch.tensor. - init = torch.tensor([0], dtype=torch.int32, device=device) - limit_value = torch.tensor([10], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) -///////// -def fori_loop(lower, upper, user_body_func, *init_val): - device = xm.xla_device() - def cond_fn(upper, lower, *init_val): - return lower[0] < upper[0] - def body_fn(upper, lower, *init_val): - one_value_i = torch.ones(1, dtype=torch.int32, device=device) - res_list = list(user_body_func(*init_val)) - res_list.insert(0, lower) - res_list.insert(0, torch.sub(upper, one_value_i)) - return res_list - res = while_loop(cond_fn, body_fn, (lower, upper, *init_val)) - return res -///////// +# ///////// +# def test_while_loop_tpu_addition(self): +# device = xm.xla_device() +# def cond_fn(init, limit_value): +# return limit_value[0] >= init[0] +# def body_fn(init, limit_value): +# one_value = torch.ones(1, dtype=torch.int32, device=device) +# return (torch.add(init, one_value), limit_value.clone()) +# # TODO(@manfei): init and limit_value has to be torch.tensor. +# init = torch.tensor([0], dtype=torch.int32, device=device) +# limit_value = torch.tensor([10], dtype=torch.int32, device=device) +# res = while_loop(cond_fn, body_fn, (init, limit_value)) +# ///////// +# def fori_loop(lower, upper, user_body_func, *init_val): +# device = xm.xla_device() +# def cond_fn(upper, lower, *init_val): +# return lower[0] < upper[0] +# def body_fn(upper, lower, *init_val): +# one_value_i = torch.ones(1, dtype=torch.int32, device=device) +# res_list = list(user_body_func(*init_val)) +# res_list.insert(0, lower) +# res_list.insert(0, torch.sub(upper, one_value_i)) +# return res_list +# res = while_loop(cond_fn, body_fn, (lower, upper, *init_val)) +# return res +# ///////// ### TODO(@manfei): treat *input_value def fori_loop(upper, lower, body_fun, init_val, input_value): From 0fbb23d0049d2f61ee9e9b728ee552d8a9f98200 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:27:25 +0000 Subject: [PATCH 303/546] update --- torch_xla/experimental/fori_loop.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index eef46f7f0e6..a492e56a462 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -58,13 +58,13 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) else: - def cond_fn(upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, output_value): + def body_fn(upper, lower, one_value, x, input_value): new_lower = torch.add(one_value, lower) - output_value = body_fun(one_value, input_value) - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone() - res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) + output_val = body_fun(one_value, input_value) + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), output_val.clone() + res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value)) return res From 53f818564c0a46ee18d3395b0c3dde0402318f38 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:29:29 +0000 Subject: [PATCH 304/546] update --- ...i_loop_with_while_loop_simple_add_dispatch_in_torch.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 84dfe3e11d8..ad14be19f40 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -196,7 +196,13 @@ def test_fori_loop_tpu_addition(self): def body_fun(a, b): return torch.add(a, b) - lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) + # lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) + upper_, new_lower_, one_value_, add_res_x_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) + print("upper_: ", upper_) + print("new_lower_: ", new_lower_) + print("one_value_: ", one_value_) + print("add_res_x_: ", add_res_x_) + print("res_: ", res_) expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) self.assertEqual(expected, res_) From 52435ecbf6152d10471b06d0189f3b79528b98db Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:33:18 +0000 Subject: [PATCH 305/546] update --- torch_xla/experimental/fori_loop.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index a492e56a462..0696977b069 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -44,10 +44,11 @@ def fori_loop(upper, lower, body_fun, init_val, input_value): device = xm.xla_device() - output_value = torch.zeros([20], dtype=torch.float32, device=device) + # output_value = torch.zeros([20], dtype=torch.float32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): + output_value = torch.zeros([20], dtype=torch.float32, device=device) def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): @@ -58,13 +59,15 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) else: - def cond_fn(upper, lower, one_value, x, input_value): + # output_value = torch.zeros([1], dtype=torch.float32, device=device) + output_value = torch.tensor([1], dtype=torch.int32, device=device) + def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value): + def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_val = body_fun(one_value, input_value) - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), output_val.clone() - res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value)) + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_val.clone() + res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) return res From 116a68bc9b19ac61ae60e92af25a08fea41a0256 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:39:13 +0000 Subject: [PATCH 306/546] update --- torch_xla/experimental/fori_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0696977b069..950e26d4e3e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -132,7 +132,8 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): if additional_inputs: additional_inputs_list_body = [fake_carried_inputs[-3]] else: - additional_inputs_list_body = [] + # add fake output_value to do map and not reuse output in the next turn + additional_inputs_list_body = [fake_carried_inputs[-1]] ### TODO(@manfei): treat hard-code parameters: additional_inputs_list_body body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) From 3cb631a21bd1d952d9a3d089dac73882de486714 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:40:53 +0000 Subject: [PATCH 307/546] update --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++++ torch_xla/experimental/fori_loop.py | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index ad14be19f40..9a49d6b5c2b 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -49,6 +49,8 @@ def body_fn(init, limit_value): limit_value = torch.tensor([0], dtype=torch.int32, device=device) res = while_loop(cond_fn, body_fn, (init, limit_value)) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) + print("expected: ", expected) + print("res: ", res) self.assertEqual(expected, res) # passed @@ -69,6 +71,8 @@ def body_fn(init, limit_value): res = while_loop(cond_fn, body_fn, (init, limit_value)) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) + print("expected: ", expected) + print("res: ", res) # passed def test_while_loop_tpu_subtraction_nested(self): diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 950e26d4e3e..0696977b069 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -132,8 +132,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): if additional_inputs: additional_inputs_list_body = [fake_carried_inputs[-3]] else: - # add fake output_value to do map and not reuse output in the next turn - additional_inputs_list_body = [fake_carried_inputs[-1]] + additional_inputs_list_body = [] ### TODO(@manfei): treat hard-code parameters: additional_inputs_list_body body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) From fe3a5302d314630af0dfcb3d01f1338acc50238f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:42:22 +0000 Subject: [PATCH 308/546] update --- torch_xla/experimental/fori_loop.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0696977b069..32513a59936 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -61,13 +61,13 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): else: # output_value = torch.zeros([1], dtype=torch.float32, device=device) output_value = torch.tensor([1], dtype=torch.int32, device=device) - def cond_fn(upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value): # , output_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, output_value): + def body_fn(upper, lower, one_value, x, input_value): # , output_value): new_lower = torch.add(one_value, lower) output_val = body_fun(one_value, input_value) - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_val.clone() - res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), output_val.clone() + res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value)) return res From ac574f35f9be4630e532f8d8e8961612c077de7e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:44:08 +0000 Subject: [PATCH 309/546] update --- ...est_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 9a49d6b5c2b..6d7009b04ce 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -208,6 +208,7 @@ def body_fun(a, b): print("add_res_x_: ", add_res_x_) print("res_: ", res_) expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) + print("expected: ", expected) self.assertEqual(expected, res_) # passed From ca3e7576e3885c103cb8d40cac737df13bc4450f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:46:51 +0000 Subject: [PATCH 310/546] update --- ...t_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 6d7009b04ce..1cacbdee9f6 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -23,7 +23,8 @@ def _fake_fori_loop(lower, upper, body_fun, *init_val): if len(init_val) > 1: (a, b) = init_val for i in range((upper - lower)[0]): - a = body_fun(*init_val) + # a = body_fun(*init_val) + a = body_fun(a, b) else: for i in range((upper - lower)[0]): a = body_fun(*init_val) From 18194202a1444db45602395fa544eb51aef54f47 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:49:25 +0000 Subject: [PATCH 311/546] update --- ...while_loop_simple_add_dispatch_in_torch.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 1cacbdee9f6..29ae7e108c0 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -50,8 +50,8 @@ def body_fn(init, limit_value): limit_value = torch.tensor([0], dtype=torch.int32, device=device) res = while_loop(cond_fn, body_fn, (init, limit_value)) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) - print("expected: ", expected) - print("res: ", res) + # print("expected: ", expected) + # print("res: ", res) self.assertEqual(expected, res) # passed @@ -72,8 +72,8 @@ def body_fn(init, limit_value): res = while_loop(cond_fn, body_fn, (init, limit_value)) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) - print("expected: ", expected) - print("res: ", res) + # print("expected: ", expected) + # print("res: ", res) # passed def test_while_loop_tpu_subtraction_nested(self): @@ -188,6 +188,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return aaa +# passed def test_fori_loop_tpu_addition(self): xm.mark_step() @@ -203,13 +204,13 @@ def body_fun(a, b): # lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) upper_, new_lower_, one_value_, add_res_x_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) - print("upper_: ", upper_) - print("new_lower_: ", new_lower_) - print("one_value_: ", one_value_) - print("add_res_x_: ", add_res_x_) - print("res_: ", res_) + # print("upper_: ", upper_) + # print("new_lower_: ", new_lower_) + # print("one_value_: ", one_value_) + # print("add_res_x_: ", add_res_x_) + # print("res_: ", res_) expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) - print("expected: ", expected) + # print("expected: ", expected) self.assertEqual(expected, res_) # passed From 1b91fc82a12eb429e35f2844c35967ff87491910 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:49:51 +0000 Subject: [PATCH 312/546] update --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 32513a59936..863ba0c6d1f 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -78,8 +78,8 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): ### cond_fn&body_fn: callable ### carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) print("arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") - print("carried_inputs: ", carried_inputs) - print("additional_inputs: ", additional_inputs) + # print("carried_inputs: ", carried_inputs) + # print("additional_inputs: ", additional_inputs) if additional_inputs is None: additional_inputs = tuple() return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) @@ -101,7 +101,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): fake_carried_inputs.append( torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) - print("fake_carried_inputs: ", fake_carried_inputs) + # print("fake_carried_inputs: ", fake_carried_inputs) ### TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) From 673145d105d15c965bf72bbe8bc160c7df6307ab Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 21:50:44 +0000 Subject: [PATCH 313/546] update --- torch_xla/experimental/fori_loop.py | 28 +--------------------------- 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 863ba0c6d1f..1c4394d73bd 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -12,32 +12,6 @@ from torch._higher_order_ops.while_loop import while_loop_op from torch._higher_order_ops.while_loop import while_loop as torch_while_loop -# ///////// -# def test_while_loop_tpu_addition(self): -# device = xm.xla_device() -# def cond_fn(init, limit_value): -# return limit_value[0] >= init[0] -# def body_fn(init, limit_value): -# one_value = torch.ones(1, dtype=torch.int32, device=device) -# return (torch.add(init, one_value), limit_value.clone()) -# # TODO(@manfei): init and limit_value has to be torch.tensor. -# init = torch.tensor([0], dtype=torch.int32, device=device) -# limit_value = torch.tensor([10], dtype=torch.int32, device=device) -# res = while_loop(cond_fn, body_fn, (init, limit_value)) -# ///////// -# def fori_loop(lower, upper, user_body_func, *init_val): -# device = xm.xla_device() -# def cond_fn(upper, lower, *init_val): -# return lower[0] < upper[0] -# def body_fn(upper, lower, *init_val): -# one_value_i = torch.ones(1, dtype=torch.int32, device=device) -# res_list = list(user_body_func(*init_val)) -# res_list.insert(0, lower) -# res_list.insert(0, torch.sub(upper, one_value_i)) -# return res_list -# res = while_loop(cond_fn, body_fn, (lower, upper, *init_val)) -# return res -# ///////// ### TODO(@manfei): treat *input_value def fori_loop(upper, lower, body_fun, init_val, input_value): @@ -77,7 +51,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): ### TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') ### cond_fn&body_fn: callable ### carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) - print("arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") + # print("arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") # print("carried_inputs: ", carried_inputs) # print("additional_inputs: ", additional_inputs) if additional_inputs is None: From 12f6f71716785bc2fd7940aa947f28839d074a25 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 22:00:51 +0000 Subject: [PATCH 314/546] update --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 1c4394d73bd..80bc209aacf 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -60,7 +60,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): - print("arrive here _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): !!!") + # print("arrive here _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): !!!") ### fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: From 89605cbb6b91336c3012bc715b0ae870eaa26f2e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 22:38:43 +0000 Subject: [PATCH 315/546] format --- ...while_loop_simple_add_dispatch_in_torch.py | 37 +------------------ torch_xla/experimental/fori_loop.py | 36 +++--------------- 2 files changed, 7 insertions(+), 66 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 29ae7e108c0..73d55d3dc47 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -23,7 +23,6 @@ def _fake_fori_loop(lower, upper, body_fun, *init_val): if len(init_val) > 1: (a, b) = init_val for i in range((upper - lower)[0]): - # a = body_fun(*init_val) a = body_fun(a, b) else: for i in range((upper - lower)[0]): @@ -33,7 +32,6 @@ def _fake_fori_loop(lower, upper, body_fun, *init_val): class WhileLoopTest(unittest.TestCase): -# passed def test_while_loop_tpu_subtraction(self): device = xm.xla_device() @@ -50,11 +48,8 @@ def body_fn(init, limit_value): limit_value = torch.tensor([0], dtype=torch.int32, device=device) res = while_loop(cond_fn, body_fn, (init, limit_value)) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) - # print("expected: ", expected) - # print("res: ", res) self.assertEqual(expected, res) -# passed def test_while_loop_tpu_addition(self): device = xm.xla_device() @@ -72,10 +67,7 @@ def body_fn(init, limit_value): res = while_loop(cond_fn, body_fn, (init, limit_value)) expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) - # print("expected: ", expected) - # print("res: ", res) -# passed def test_while_loop_tpu_subtraction_nested(self): device = xm.xla_device() @@ -94,54 +86,37 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) -# passed def test_while_loop_tpu_simple_linear(self): xm.mark_step() device = xm.xla_device() torch.set_grad_enabled(False) - # def forward(self, upper, lower, one_value, x, input_value, output_value): linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - # weight_0 = linear_0.weight - # bias_0 = linear_0.bias - # print("original bias: ", bias_0) def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) - # output_value_real = linear_0(input_value) output_value = linear_0(input_value) weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement - # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real.clone(), weight.clone(), bias.clone() # , output_value_real return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() # , output_value_real upper = torch.tensor([1], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) - # l_in_0 = torch.ones(10, device=xm.xla_device()) # input_value - l_in_0 = torch.rand(10, device=xm.xla_device()) # input_value + l_in_0 = torch.rand(10, device=xm.xla_device()) output_value = torch.zeros([20], dtype=torch.float32, device=device) - # weight_0 = linear_0.weight - # bias_0 = linear_0.bias upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) - # print("torch_add_res__: ", torch_add_res__) - # print("output_value_real__: ", output_value_real__) - # print("bias__: ", bias__) - # print("expected: ", expected) - - # self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) -# passed def test_while_loop_tpu_simple_linear_class(self): xm.mark_step() @@ -177,6 +152,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): bias_0 = simple_with_linear.linear.bias aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) # same weight/bias liear model @@ -188,7 +164,6 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return aaa -# passed def test_fori_loop_tpu_addition(self): xm.mark_step() @@ -202,18 +177,10 @@ def test_fori_loop_tpu_addition(self): def body_fun(a, b): return torch.add(a, b) - # lower_, upper_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) upper_, new_lower_, one_value_, add_res_x_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) - # print("upper_: ", upper_) - # print("new_lower_: ", new_lower_) - # print("one_value_: ", one_value_) - # print("add_res_x_: ", add_res_x_) - # print("res_: ", res_) expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) - # print("expected: ", expected) self.assertEqual(expected, res_) -# passed def test_fori_loop_tpu_simple_linear(self): xm.mark_step() diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 80bc209aacf..da1ebb4a0c9 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -18,7 +18,6 @@ def fori_loop(upper, lower, body_fun, init_val, input_value): device = xm.xla_device() - # output_value = torch.zeros([20], dtype=torch.float32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): @@ -33,11 +32,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) else: - # output_value = torch.zeros([1], dtype=torch.float32, device=device) output_value = torch.tensor([1], dtype=torch.int32, device=device) - def cond_fn(upper, lower, one_value, x, input_value): # , output_value): + def cond_fn(upper, lower, one_value, x, input_value): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value): # , output_value): + def body_fn(upper, lower, one_value, x, input_value): new_lower = torch.add(one_value, lower) output_val = body_fun(one_value, input_value) return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), output_val.clone() @@ -51,38 +49,31 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): ### TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') ### cond_fn&body_fn: callable ### carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) - # print("arrive here @while_loop_op.py_impl(DispatchKey.XLA) !!!") - # print("carried_inputs: ", carried_inputs) - # print("additional_inputs: ", additional_inputs) if additional_inputs is None: additional_inputs = tuple() return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): - # print("arrive here _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): !!!") ### fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: device = carried_input.device - ###TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) for additional_input in additional_inputs: device = additional_input.device - ###TODO(@manfei) type = carried_input.type fake_carried_inputs.append( torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) - # print("fake_carried_inputs: ", fake_carried_inputs) ### TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") - # !!! cond xlacomputation change !!! switch output_value and weight position if additional_inputs(weight/bias) exists + ### TODO(@manfei): treat hard-code cond xlacomputation change: currently switch output_value and weight position if additional_inputs(weight/bias) exists additional_inputs_list_cond = list(fake_carried_inputs[2:]) ### all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor if additional_inputs: tmp_bias = additional_inputs_list_cond[-3] ### not used, change order doesn't affect logic @@ -93,16 +84,13 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - # cond_hlo_print = xb.get_computation_hlo(cond_computation) - # print("cond computation: !!!!!!!!!") - # print(cond_hlo_print) ### generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - # !!! body xlacomputation change !!! add non-changed output_value argument if additional_inputs(weight/bias) exists + ### TODO(@manfei): treat hard-code body xlacomputation change: currently add non-changed output_value argument if additional_inputs(weight/bias) exists if additional_inputs: additional_inputs_list_body = [fake_carried_inputs[-3]] else: @@ -113,9 +101,6 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - # body_hlo_print = xb.get_computation_hlo(body_computation) - # print("body computation: !!!!!!!!!") - # print(body_hlo_print) ### trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs @@ -130,14 +115,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): p = xb.mkparam(builder, len(params), shape) params.append(p) - # ### TODO(@manfei): treat hard-code input arguments - # # !!! init change !!! - # tmp_bias = params[-2] - # del params[-2] - # params.append(tmp_bias) - - ### TODO(@manfei): treat hard-code input arguments - # !!! init change !!! switch bias and output_value if additional_inputs(weight/bias) exists + ### TODO(@manfei): treat hard-code input arguments, currently switch bias and output_value if additional_inputs(weight/bias) exists if additional_inputs: tmp_bias = params[-3] del params[-3] @@ -151,14 +129,10 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) - # hlo_print = xb.get_computation_hlo(computation) - # print("while computation: !!!!!!!!!") - # print(hlo_print) ### gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', (total_inputs), computation) - ### print("!!! arrive here too after while!!!") return result \ No newline at end of file From 2e9c979997e146b51a420ecc61d6a1b2ce93d4b2 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 22:44:27 +0000 Subject: [PATCH 316/546] format --- torch_xla/csrc/init_python_bindings.cpp | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c7013569dd1..f48dcf9eb68 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -914,18 +914,10 @@ class PyLoweringContext { // needed in xlacomputation. void BuildForiLoop(std::vector tensors, std::vector additional_inputs_list = {}) { + // hard-code modify cond xlacomputation input arguments with unusedarguments for xla::while requriement if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); - // hard-code parameter_idx to 2 to skip existing upper/lower arguments - // TODO(@manfei): get body xlacomputation arguments' number first then decide items in `additional_inputs_list`, maybe implement in python level - // !!! since cond_fn only compare upper and lower, so it would only use two arguments, due to PyTorch/XLA - // !!! trace xlacomputation from result tensor, so all the other arguments would not be included or generated; - // !!! but to meet xla::while requirement, we would skip first two arguments, - // !!! then add all other arguments like body_fn/init - // !!! --- additional_inputs_list: this list include all other arguments like body_fn/init except upper and lower - // !!! --- next step: we add dump paras according to additional_inputs_list - // ??? --- could we get IRvalue of `additional_inputs_list` in this function to complete xlacomputation? - int64_t parameter_idx = 2; // parameter_idx start from 2 after upper and lower + int64_t parameter_idx = 2; // parameter_idx start from 2 after used upper and lower for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); xla::Shape shape = xtensor->shape().get(); @@ -935,11 +927,11 @@ class PyLoweringContext { } } - // hard-code modify body xlacomputation input arguments - // TODO(@manfei): get body xlacomputation arguments' number first then decide items in `additional_inputs_list`, maybe implement in python level + // hard-code modify body xlacomputation input arguments with unusedarguments for xla::while requriement if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); - int64_t parameter_idx = 7; // tensors.size(); + // TODO(@manfei): treat hard code parameter_idx value + int64_t parameter_idx = 7; for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); xla::Shape shape = xtensor->shape().get(); From 6db813948ce2e9bfb3c526f68d5100f7d7d61e06 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 22:45:32 +0000 Subject: [PATCH 317/546] format --- torch_xla/csrc/init_python_bindings.cpp | 16 +--------------- torch_xla/csrc/lowering_context.cpp | 2 -- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index f48dcf9eb68..85545ae267e 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -966,20 +966,6 @@ class PyLoweringContext { std::vector buffer_donor_indices; xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); - // // hard-code modify body xlacomputation input arguments - // // xxx: failed due to not change body_xlacomputation, might becase has been traced - // // xxx: after `computation = ConsumeValue(lowering_ctx.BuildXla());` - // if (GetNameString() == "bodyctx") { - // xla::XlaBuilder* local_builder = lowering_ctx.builder(); - // int64_t parameter_idx = program_shape.parameters_size(); // tensors.size(); - // for (auto& additional_input_tensor : additional_inputs_list) { - // XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - // xla::Shape shape = xtensor->shape().get(); - // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - // "UnusedArgumentsPlaceholder"); - // parameter_idx += 1; - // } - // } // TODO(@manfei): please confirm whether we check for more than two or use // default value true bool should_wrap_parameter = (program_shape.parameters_size() >= 2); @@ -2635,4 +2621,4 @@ void InitXlaBindings(py::module m) { InitXlaModuleBindings(m); } } // namespace torch_xla -PYBIND11_MODULE(_XLAC, m) { torch_xla::InitXlaBindings(m); } \ No newline at end of file +PYBIND11_MODULE(_XLAC, m) { torch_xla::InitXlaBindings(m); } diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 39f82a4887b..a530995ca78 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -160,8 +160,6 @@ xla::StatusOr LoweringContext::BuildXla() { if (!root_tuple_.empty() & (root_tuple_.size() == 1) & ((get_name_string() == "condctx") or (get_name_string() == "bodyctx"))) { xla = builder()->Build(root_tuple_.at(0)); - // } else if (!root_tuple_.empty() & (root_tuple_.size() == 1) & ) { - // xla = builder()->Build(root_tuple_.at(0)); } else if (!root_tuple_.empty()) { xla::XlaOp root = xla::Tuple(builder(), root_tuple_); xla = builder()->Build(root); From da3556144b3f2c93a0b1282da38aed00f4952473 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 22:47:15 +0000 Subject: [PATCH 318/546] format --- torch_xla/experimental/fori_loop.py | 40 ++++++++++++++--------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index da1ebb4a0c9..a36649fcc3e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -13,7 +13,7 @@ from torch._higher_order_ops.while_loop import while_loop as torch_while_loop -### TODO(@manfei): treat *input_value +# TODO(@manfei): treat *input_value def fori_loop(upper, lower, body_fun, init_val, input_value): device = xm.xla_device() @@ -27,8 +27,8 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value = body_fun(input_value) - weight = body_fun.weight ### not be used actually, initialized as placeholder xlacomputation requirement - bias = body_fun.bias ### not be used actually, initialized as placeholder xlacomputation requirement + weight = body_fun.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = body_fun.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) else: @@ -46,16 +46,16 @@ def body_fn(upper, lower, one_value, x, input_value): @while_loop_op.py_impl(DispatchKey.XLA) def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): - ### TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') - ### cond_fn&body_fn: callable - ### carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) + # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') + # cond_fn&body_fn: callable + # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) if additional_inputs is None: additional_inputs = tuple() return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): - ### fake carried_inputs to split formal code + # fake carried_inputs to split formal code fake_carried_inputs = [] for carried_input in carried_inputs: device = carried_input.device @@ -68,41 +68,41 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): torch.randint(10, additional_input.size(), dtype=additional_input.dtype).to(device)) - ### TODO(@manfei): specify which element is for which argument like a,b,c + # TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() cond_ctx.set_name_string("condctx") - ### TODO(@manfei): treat hard-code cond xlacomputation change: currently switch output_value and weight position if additional_inputs(weight/bias) exists - additional_inputs_list_cond = list(fake_carried_inputs[2:]) ### all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor + # TODO(@manfei): treat hard-code cond xlacomputation change: currently switch output_value and weight position if additional_inputs(weight/bias) exists + additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor if additional_inputs: - tmp_bias = additional_inputs_list_cond[-3] ### not used, change order doesn't affect logic - del additional_inputs_list_cond[-3] ### not used, change order doesn't affect logic - additional_inputs_list_cond.append(tmp_bias) ### not used, change order doesn't affect logic + tmp_bias = additional_inputs_list_cond[-3] # not used, change order doesn't affect logic + del additional_inputs_list_cond[-3] # not used, change order doesn't affect logic + additional_inputs_list_cond.append(tmp_bias) # not used, change order doesn't affect logic cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - ### generate body_fn xlacomputation + # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) body_ctx = torch_xla._XLAC.lowering.LoweringContext() body_ctx.set_name_string("bodyctx") - ### TODO(@manfei): treat hard-code body xlacomputation change: currently add non-changed output_value argument if additional_inputs(weight/bias) exists + # TODO(@manfei): treat hard-code body xlacomputation change: currently add non-changed output_value argument if additional_inputs(weight/bias) exists if additional_inputs: additional_inputs_list_body = [fake_carried_inputs[-3]] else: additional_inputs_list_body = [] - ### TODO(@manfei): treat hard-code parameters: additional_inputs_list_body + # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - ### trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while + # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs kwargs = {} if type(total_inputs) is tuple: @@ -115,13 +115,13 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): p = xb.mkparam(builder, len(params), shape) params.append(p) - ### TODO(@manfei): treat hard-code input arguments, currently switch bias and output_value if additional_inputs(weight/bias) exists + # TODO(@manfei): treat hard-code input arguments, currently switch bias and output_value if additional_inputs(weight/bias) exists if additional_inputs: tmp_bias = params[-3] del params[-3] params.append(tmp_bias) - ### generate while xlacomputation + # generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) w = xb.mkop( 'While', (input_tuple.op,), @@ -130,7 +130,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): name = 'fori_loop_ed_torch_func' computation = w.build(name) - ### gain final result with generated while xlacomputation + # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', (total_inputs), computation) From 431ab6627bf51f7a98baa7af6974207115278a8d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 22:50:56 +0000 Subject: [PATCH 319/546] format --- torch_xla/csrc/init_python_bindings.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 85545ae267e..f02fc059609 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -914,10 +914,12 @@ class PyLoweringContext { // needed in xlacomputation. void BuildForiLoop(std::vector tensors, std::vector additional_inputs_list = {}) { - // hard-code modify cond xlacomputation input arguments with unusedarguments for xla::while requriement + // hard-code modify cond xlacomputation input arguments with unusedarguments + // for xla::while requriement if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); - int64_t parameter_idx = 2; // parameter_idx start from 2 after used upper and lower + int64_t parameter_idx = + 2; // parameter_idx start from 2 after used upper and lower for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); xla::Shape shape = xtensor->shape().get(); @@ -927,7 +929,8 @@ class PyLoweringContext { } } - // hard-code modify body xlacomputation input arguments with unusedarguments for xla::while requriement + // hard-code modify body xlacomputation input arguments with unusedarguments + // for xla::while requriement if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value From 6244d21b590373e02c5cb054092c9fd5806434bc Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 22:52:24 +0000 Subject: [PATCH 320/546] format --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index f02fc059609..e20e28fbb8f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -919,7 +919,7 @@ class PyLoweringContext { if (GetNameString() == "condctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); int64_t parameter_idx = - 2; // parameter_idx start from 2 after used upper and lower + 2; // parameter_idx start from 2 after used upper and lower for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); xla::Shape shape = xtensor->shape().get(); From 04ca72dceadd12e33a2abb216090efdee40a9c57 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 23:06:39 +0000 Subject: [PATCH 321/546] format --- ...while_loop_simple_add_dispatch_in_torch.py | 71 ++++++++++++------- torch_xla/experimental/fori_loop.py | 41 +++++++---- 2 files changed, 74 insertions(+), 38 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 73d55d3dc47..3b2b018cada 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,9 +100,11 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value = linear_0(input_value) - weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() # , output_value_real + weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone(), bias.clone(), weight.clone( + ), output_value.clone() # , output_value_real upper = torch.tensor([1], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) @@ -111,7 +113,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): l_in_0 = torch.rand(10, device=xm.xla_device()) output_value = torch.zeros([20], dtype=torch.float32, device=device) - upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) + upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, l_in_0, output_value)) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) @@ -124,36 +128,48 @@ def test_while_loop_tpu_simple_linear_class(self): torch.set_grad_enabled(False) class SimpleWithLinear(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) - - def forward(self, upper, lower, one_value, x, input_value, output_value): - def cond_fn(upper, lower, one_value, x, input_value, output_value): - return lower[0] < upper[0] - - def body_fn(upper, lower, one_value, x, input_value, output_value): - new_lower = torch.add(one_value, lower) - output_value_real = self.linear(input_value) - weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value_real, weight.clone(), bias.clone() - return while_loop(cond_fn, body_fn, (upper, lower, one_value, x, input_value, output_value)) + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def forward(self, upper, lower, one_value, x, input_value, output_value): + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value_real = self.linear(input_value) + weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone( + ), output_value_real, weight.clone(), bias.clone() + + return while_loop( + cond_fn, body_fn, + (upper, lower, one_value, x, input_value, output_value)) simple_with_linear = SimpleWithLinear() upper = torch.tensor([52], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) - l_in_0 = torch.rand(10, device=xm.xla_device()) # input_value + l_in_0 = torch.rand(10, device=xm.xla_device()) # input_value output_value = torch.zeros([20], dtype=torch.float32, device=device) weight_0 = simple_with_linear.linear.weight bias_0 = simple_with_linear.linear.bias - aaa = {"simple_with_linear": (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, output_value))} + aaa = { + "simple_with_linear": + (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, + output_value)) + } - upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(upper, lower, one_value, init_val, l_in_0, output_value) + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear( + upper, lower, one_value, init_val, l_in_0, output_value) # same weight/bias liear model linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) @@ -177,7 +193,8 @@ def test_fori_loop_tpu_addition(self): def body_fun(a, b): return torch.add(a, b) - upper_, new_lower_, one_value_, add_res_x_, res_ = fori_loop(upper, lower, body_fun, one_value, init_val) + upper_, new_lower_, one_value_, add_res_x_, res_ = fori_loop( + upper, lower, body_fun, one_value, init_val) expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) self.assertEqual(expected, res_) @@ -191,15 +208,17 @@ def test_fori_loop_tpu_simple_linear(self): lower = torch.tensor([0], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) l_in_0 = torch.randn(10, device=xm.xla_device()) - + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_= fori_loop(upper, lower, linear_0, init_val, l_in_0) - + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_= fori_loop( + upper, lower, linear_0, init_val, l_in_0) + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) self.assertTrue(torch.all(torch.eq(expected, l_out_))) + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) \ No newline at end of file diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index a36649fcc3e..f07e9062f37 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -22,24 +22,36 @@ def fori_loop(upper, lower, body_fun, init_val, input_value): if (hasattr(body_fun, 'weight') or hasattr(body_fun, 'bias')): output_value = torch.zeros([20], dtype=torch.float32, device=device) + def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] + def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value = body_fun(input_value) weight = body_fun.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = body_fun.bias # not be used actually, initialized as placeholder xlacomputation requirement - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), bias.clone(), weight.clone(), output_value.clone() - res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value, output_value)) + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone(), bias.clone(), weight.clone( + ), output_value.clone() + + res = torch_while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, input_value, output_value)) else: output_value = torch.tensor([1], dtype=torch.int32, device=device) + def cond_fn(upper, lower, one_value, x, input_value): return lower[0] < upper[0] + def body_fn(upper, lower, one_value, x, input_value): new_lower = torch.add(one_value, lower) output_val = body_fun(one_value, input_value) - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), output_val.clone() - res = torch_while_loop(cond_fn, body_fn, (upper, lower, one_value, init_val, input_value)) + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), output_val.clone() + + res = torch_while_loop(cond_fn, body_fn, + (upper, lower, one_value, init_val, input_value)) return res @@ -60,8 +72,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): for carried_input in carried_inputs: device = carried_input.device fake_carried_inputs.append( - torch.randint(10, carried_input.size(), - dtype=carried_input.dtype).to(device)) + torch.randint( + 10, carried_input.size(), + dtype=carried_input.dtype).to(device)) for additional_input in additional_inputs: device = additional_input.device fake_carried_inputs.append( @@ -74,11 +87,16 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_ctx.set_name_string("condctx") # TODO(@manfei): treat hard-code cond xlacomputation change: currently switch output_value and weight position if additional_inputs(weight/bias) exists - additional_inputs_list_cond = list(fake_carried_inputs[2:]) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor + additional_inputs_list_cond = list( + fake_carried_inputs[2:] + ) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor if additional_inputs: - tmp_bias = additional_inputs_list_cond[-3] # not used, change order doesn't affect logic - del additional_inputs_list_cond[-3] # not used, change order doesn't affect logic - additional_inputs_list_cond.append(tmp_bias) # not used, change order doesn't affect logic + tmp_bias = additional_inputs_list_cond[ + -3] # not used, change order doesn't affect logic + del additional_inputs_list_cond[ + -3] # not used, change order doesn't affect logic + additional_inputs_list_cond.append( + tmp_bias) # not used, change order doesn't affect logic cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() @@ -132,7 +150,6 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', - (total_inputs), - computation) + (total_inputs), computation) return result \ No newline at end of file From 33fa1fbb30f3f305dbd2f1fe60d6225fc0abcf79 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 17 Apr 2024 23:12:49 +0000 Subject: [PATCH 322/546] format --- ...while_loop_simple_add_dispatch_in_torch.py | 18 ++++++++--------- torch_xla/experimental/fori_loop.py | 20 +++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 3b2b018cada..8a1f2bdb737 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -100,11 +100,11 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value = linear_0(input_value) - weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( one_value, x), input_value.clone(), bias.clone(), weight.clone( - ), output_value.clone() # , output_value_real + ), output_value.clone() upper = torch.tensor([1], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) @@ -141,8 +141,8 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) - weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( one_value, x), input_value.clone( ), output_value_real, weight.clone(), bias.clone() @@ -156,7 +156,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) - l_in_0 = torch.rand(10, device=xm.xla_device()) # input_value + l_in_0 = torch.rand(10, device=xm.xla_device()) output_value = torch.zeros([20], dtype=torch.float32, device=device) weight_0 = simple_with_linear.linear.weight @@ -171,7 +171,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear( upper, lower, one_value, init_val, l_in_0, output_value) - # same weight/bias liear model + # create same weight/bias liear model for compare linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) linear_0.weight.data = weight__ linear_0.bias.data = bias__ @@ -211,7 +211,7 @@ def test_fori_loop_tpu_simple_linear(self): linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_= fori_loop( + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( upper, lower, linear_0, init_val, l_in_0) expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) @@ -221,4 +221,4 @@ def test_fori_loop_tpu_simple_linear(self): if __name__ == '__main__': test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) \ No newline at end of file + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index f07e9062f37..8ed3a783200 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -72,14 +72,14 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): for carried_input in carried_inputs: device = carried_input.device fake_carried_inputs.append( - torch.randint( - 10, carried_input.size(), - dtype=carried_input.dtype).to(device)) + torch.randint(10, carried_input.size(), + dtype=carried_input.dtype).to(device)) for additional_input in additional_inputs: device = additional_input.device fake_carried_inputs.append( - torch.randint(10, additional_input.size(), - dtype=additional_input.dtype).to(device)) + torch.randint( + 10, additional_input.size(), + dtype=additional_input.dtype).to(device)) # TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) @@ -89,14 +89,14 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # TODO(@manfei): treat hard-code cond xlacomputation change: currently switch output_value and weight position if additional_inputs(weight/bias) exists additional_inputs_list_cond = list( fake_carried_inputs[2:] - ) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor + ) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor if additional_inputs: tmp_bias = additional_inputs_list_cond[ - -3] # not used, change order doesn't affect logic + -3] # not used, change order doesn't affect logic del additional_inputs_list_cond[ - -3] # not used, change order doesn't affect logic + -3] # not used, change order doesn't affect logic additional_inputs_list_cond.append( - tmp_bias) # not used, change order doesn't affect logic + tmp_bias) # not used, change order doesn't affect logic cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() @@ -152,4 +152,4 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', (total_inputs), computation) - return result \ No newline at end of file + return result From ea0e663748465fc51f555d196384825cc123dc5b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 23 Apr 2024 23:36:33 +0000 Subject: [PATCH 323/546] down into cpp --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 8ed3a783200..7119947ad2b 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -63,6 +63,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) if additional_inputs is None: additional_inputs = tuple() + print("$$$ additional_inputs: ", additional_inputs) return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From cf9bb2fbc5bf1062e06992454d3ef23c053f6c0e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 00:27:58 +0000 Subject: [PATCH 324/546] down into cpp --- ...ri_loop_with_while_loop_simple_add_dispatch_in_torch.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 8a1f2bdb737..26440105a26 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -34,6 +34,7 @@ class WhileLoopTest(unittest.TestCase): def test_while_loop_tpu_subtraction(self): + print("$$$ test_while_loop_tpu_subtraction !!!") device = xm.xla_device() def cond_fn(init, limit_value): @@ -52,6 +53,7 @@ def body_fn(init, limit_value): def test_while_loop_tpu_addition(self): + print("$$$ test_while_loop_tpu_addition !!!") device = xm.xla_device() def cond_fn(init, limit_value): @@ -70,6 +72,7 @@ def body_fn(init, limit_value): def test_while_loop_tpu_subtraction_nested(self): + print("$$$ test_while_loop_tpu_subtraction_nested !!!") device = xm.xla_device() def cond_fn(init, limit_value): @@ -88,6 +91,7 @@ def body_fn(init, limit_value): def test_while_loop_tpu_simple_linear(self): + print("$$$ test_while_loop_tpu_simple_linear !!!") xm.mark_step() device = xm.xla_device() torch.set_grad_enabled(False) @@ -123,6 +127,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): def test_while_loop_tpu_simple_linear_class(self): + print("$$$ test_while_loop_tpu_simple_linear_class !!!") xm.mark_step() device = xm.xla_device() torch.set_grad_enabled(False) @@ -182,6 +187,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): def test_fori_loop_tpu_addition(self): + print("$$$ test_fori_loop_tpu_addition !!!") xm.mark_step() device = xm.xla_device() @@ -200,6 +206,7 @@ def body_fun(a, b): def test_fori_loop_tpu_simple_linear(self): + print("$$$ test_fori_loop_tpu_simple_linear !!!") xm.mark_step() device = xm.xla_device() torch.set_grad_enabled(False) From 0e64276a64ebae5ff5a95c5c881896aa7e44016e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:16:55 +0000 Subject: [PATCH 325/546] down into cpp --- ...while_loop_simple_add_dispatch_in_torch.py | 46 +++++++++++++++++++ torch_xla/experimental/fori_loop.py | 22 ++++++++- 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 26440105a26..4be0c547523 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -32,6 +32,7 @@ def _fake_fori_loop(lower, upper, body_fun, *init_val): class WhileLoopTest(unittest.TestCase): +# additional_inputs: () def test_while_loop_tpu_subtraction(self): print("$$$ test_while_loop_tpu_subtraction !!!") @@ -51,6 +52,7 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) +# additional_inputs: () def test_while_loop_tpu_addition(self): print("$$$ test_while_loop_tpu_addition !!!") @@ -70,6 +72,7 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) +# additional_inputs: () def test_while_loop_tpu_subtraction_nested(self): print("$$$ test_while_loop_tpu_subtraction_nested !!!") @@ -89,6 +92,8 @@ def body_fn(init, limit_value): expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) self.assertEqual(expected, res) +### return weight/bias +# additional_inputs: (tensor([1*20], device='xla:0'), tensor([10*20], device='xla:0')) def test_while_loop_tpu_simple_linear(self): print("$$$ test_while_loop_tpu_simple_linear !!!") @@ -125,6 +130,45 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + def test_while_loop_tpu_simple_linear_wrapper(self): + + print("$$$ test_while_loop_tpu_simple_linear_wrapper !!!") + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value = linear_0(input_value) + weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone(), bias.clone(), weight.clone( + ), output_value.clone() + + upper = torch.tensor([1], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, l_in_0, output_value)) + + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + + return self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + + +### return weight/bias +# additional_inputs: (tensor([ 1*20], device='xla:0'), tensor([10*20], device='xla:0')) def test_while_loop_tpu_simple_linear_class(self): print("$$$ test_while_loop_tpu_simple_linear_class !!!") @@ -185,6 +229,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return aaa +# additional_inputs: () def test_fori_loop_tpu_addition(self): print("$$$ test_fori_loop_tpu_addition !!!") @@ -204,6 +249,7 @@ def body_fun(a, b): expected = _fake_fori_loop(lower, upper, body_fun, init_val, one_value) self.assertEqual(expected, res_) +# additional_inputs: (tensor([1*20], device='xla:0'), tensor([[10*20], device='xla:0')) def test_fori_loop_tpu_simple_linear(self): print("$$$ test_fori_loop_tpu_simple_linear !!!") diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 7119947ad2b..1619a73fb2e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -63,8 +63,28 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) if additional_inputs is None: additional_inputs = tuple() + return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) + else: + # modify body_fn return with additional_inputs + def new_body_fn(carried_inputs): + # new_lower = torch.add(one_value, lower) + # output_value = linear_0(input_value) + weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + # one_value, x), input_value.clone(), bias.clone(), weight.clone( + # ), output_value.clone() + # return body_fn(carried_inputs), weight.clone(), bias.clone() + res1 = body_fn(carried_inputs) + print("res1: ", res1) + print("type res1: ", type(res1)) + res2 = res1.append(additional_inputs) + print("res2: ", res2) + print("type res2: ", type(res2)) + return (body_fn(carried_inputs)).append(additional_inputs) + return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) - return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) + # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): From 6fbc3303015dbee2dcfa6de382bef6f2e8294354 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:22:39 +0000 Subject: [PATCH 326/546] down into cpp --- ...fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 4be0c547523..6bf8177123c 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -130,6 +130,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) +### +# def test_while_loop_tpu_simple_linear_wrapper(self): print("$$$ test_while_loop_tpu_simple_linear_wrapper !!!") @@ -148,8 +150,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( - one_value, x), input_value.clone(), bias.clone(), weight.clone( - ), output_value.clone() + one_value, x), input_value.clone(), output_value.clone() upper = torch.tensor([1], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From 3d2bd415ff6355458ce9f58bfacce9bdbb6ff492 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:23:59 +0000 Subject: [PATCH 327/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 1619a73fb2e..85cd710d2ed 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -66,7 +66,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) else: # modify body_fn return with additional_inputs - def new_body_fn(carried_inputs): + def new_body_fn(*carried_inputs): # new_lower = torch.add(one_value, lower) # output_value = linear_0(input_value) weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement From d9d6358b190b7673adc8bfa9ca963921da2ba1eb Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:24:49 +0000 Subject: [PATCH 328/546] down into cpp --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 85cd710d2ed..b2534055e2d 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -69,8 +69,8 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): def new_body_fn(*carried_inputs): # new_lower = torch.add(one_value, lower) # output_value = linear_0(input_value) - weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + # weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + # bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( # one_value, x), input_value.clone(), bias.clone(), weight.clone( # ), output_value.clone() From 8ac3fc85b45a08a13f4e10cb4ebadaada07c7407 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:29:45 +0000 Subject: [PATCH 329/546] down into cpp --- torch_xla/experimental/fori_loop.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index b2534055e2d..6e9a457fdef 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -75,13 +75,15 @@ def new_body_fn(*carried_inputs): # one_value, x), input_value.clone(), bias.clone(), weight.clone( # ), output_value.clone() # return body_fn(carried_inputs), weight.clone(), bias.clone() - res1 = body_fn(carried_inputs) + print("carried_inputs: ", carried_inputs) + print("additional_inputs: ", additional_inputs) + res1 = body_fn(*carried_inputs) print("res1: ", res1) print("type res1: ", type(res1)) - res2 = res1.append(additional_inputs) + res2 = res1.append(*additional_inputs) print("res2: ", res2) print("type res2: ", type(res2)) - return (body_fn(carried_inputs)).append(additional_inputs) + return (body_fn(*carried_inputs)).append(*additional_inputs) return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From 387b9fdb28b363ccbf8ed1f9cba158f915761c2a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:33:35 +0000 Subject: [PATCH 330/546] down into cpp --- ..._fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 4 ++-- torch_xla/experimental/fori_loop.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 6bf8177123c..a4739f6374a 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -147,8 +147,8 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value = linear_0(input_value) - weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + # weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + # bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( one_value, x), input_value.clone(), output_value.clone() diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 6e9a457fdef..5223c46dba3 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -83,7 +83,7 @@ def new_body_fn(*carried_inputs): res2 = res1.append(*additional_inputs) print("res2: ", res2) print("type res2: ", type(res2)) - return (body_fn(*carried_inputs)).append(*additional_inputs) + return list(body_fn(*carried_inputs)).append(*additional_inputs) return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From dc10d64858daface673e574a9ad8d915d5612fc1 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:34:15 +0000 Subject: [PATCH 331/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 5223c46dba3..60372e6657a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -80,7 +80,7 @@ def new_body_fn(*carried_inputs): res1 = body_fn(*carried_inputs) print("res1: ", res1) print("type res1: ", type(res1)) - res2 = res1.append(*additional_inputs) + res2 = list(res1).append(*additional_inputs) print("res2: ", res2) print("type res2: ", type(res2)) return list(body_fn(*carried_inputs)).append(*additional_inputs) From a2c01a1e3695d62ea624e6eddb9a6349d4205b7f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:35:23 +0000 Subject: [PATCH 332/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 60372e6657a..cfc77f0ed85 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -80,7 +80,7 @@ def new_body_fn(*carried_inputs): res1 = body_fn(*carried_inputs) print("res1: ", res1) print("type res1: ", type(res1)) - res2 = list(res1).append(*additional_inputs) + res2 = list(res1).add(*additional_inputs) print("res2: ", res2) print("type res2: ", type(res2)) return list(body_fn(*carried_inputs)).append(*additional_inputs) From 6f7dcc5c8c5f8d1bdabc42ddaabfd445227eaa60 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:36:31 +0000 Subject: [PATCH 333/546] down into cpp --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index cfc77f0ed85..6da0f4d5aba 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -80,6 +80,7 @@ def new_body_fn(*carried_inputs): res1 = body_fn(*carried_inputs) print("res1: ", res1) print("type res1: ", type(res1)) + print("type additional_inputs: ", type(additional_inputs)) res2 = list(res1).add(*additional_inputs) print("res2: ", res2) print("type res2: ", type(res2)) From 83aa4b03c823fce9be68fc1e9817a1342e5d93d6 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:37:25 +0000 Subject: [PATCH 334/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 6da0f4d5aba..83292facd8b 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -81,7 +81,7 @@ def new_body_fn(*carried_inputs): print("res1: ", res1) print("type res1: ", type(res1)) print("type additional_inputs: ", type(additional_inputs)) - res2 = list(res1).add(*additional_inputs) + res2 = (res1, ) + additional_inputs print("res2: ", res2) print("type res2: ", type(res2)) return list(body_fn(*carried_inputs)).append(*additional_inputs) From 73ab5c1629304758a64f3591e6dba1df82487932 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:38:05 +0000 Subject: [PATCH 335/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 83292facd8b..7296a498a10 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -84,7 +84,7 @@ def new_body_fn(*carried_inputs): res2 = (res1, ) + additional_inputs print("res2: ", res2) print("type res2: ", type(res2)) - return list(body_fn(*carried_inputs)).append(*additional_inputs) + return (body_fn(*carried_inputs), ) + additional_inputs return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From 6ce33bd3291b0fc8e031f7f2a3e054c9978cb9e4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:38:31 +0000 Subject: [PATCH 336/546] down into cpp --- torch_xla/experimental/fori_loop.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 7296a498a10..f1149363761 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -75,15 +75,15 @@ def new_body_fn(*carried_inputs): # one_value, x), input_value.clone(), bias.clone(), weight.clone( # ), output_value.clone() # return body_fn(carried_inputs), weight.clone(), bias.clone() - print("carried_inputs: ", carried_inputs) - print("additional_inputs: ", additional_inputs) + # print("carried_inputs: ", carried_inputs) + # print("additional_inputs: ", additional_inputs) res1 = body_fn(*carried_inputs) - print("res1: ", res1) - print("type res1: ", type(res1)) - print("type additional_inputs: ", type(additional_inputs)) + # print("res1: ", res1) + # print("type res1: ", type(res1)) + # print("type additional_inputs: ", type(additional_inputs)) res2 = (res1, ) + additional_inputs - print("res2: ", res2) - print("type res2: ", type(res2)) + # print("res2: ", res2) + # print("type res2: ", type(res2)) return (body_fn(*carried_inputs), ) + additional_inputs return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) From ea4cc8c14220ec4835b2a07a61a246a42a689652 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:42:24 +0000 Subject: [PATCH 337/546] down into cpp --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index f1149363761..4780fd4625c 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -77,14 +77,14 @@ def new_body_fn(*carried_inputs): # return body_fn(carried_inputs), weight.clone(), bias.clone() # print("carried_inputs: ", carried_inputs) # print("additional_inputs: ", additional_inputs) - res1 = body_fn(*carried_inputs) + # res1 = body_fn(*carried_inputs) # print("res1: ", res1) # print("type res1: ", type(res1)) # print("type additional_inputs: ", type(additional_inputs)) - res2 = (res1, ) + additional_inputs + # res2 = (res1, ) + additional_inputs # print("res2: ", res2) # print("type res2: ", type(res2)) - return (body_fn(*carried_inputs), ) + additional_inputs + return list(body_fn(*carried_inputs)).extend(*additional_inputs) return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From 362bdc399c8996e5fcbb6d72f6865e230791d9fc Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:43:06 +0000 Subject: [PATCH 338/546] down into cpp --- torch_xla/experimental/fori_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 4780fd4625c..424de36f2f0 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -80,7 +80,8 @@ def new_body_fn(*carried_inputs): # res1 = body_fn(*carried_inputs) # print("res1: ", res1) # print("type res1: ", type(res1)) - # print("type additional_inputs: ", type(additional_inputs)) + print("type additional_inputs: ", type(additional_inputs)) + print("type *additional_inputs: ", type(*additional_inputs)) # res2 = (res1, ) + additional_inputs # print("res2: ", res2) # print("type res2: ", type(res2)) From 2f29c1b3c8d636a8c05ebd8056daa5951ab5a66f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:43:38 +0000 Subject: [PATCH 339/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 424de36f2f0..e658448ed78 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -81,7 +81,7 @@ def new_body_fn(*carried_inputs): # print("res1: ", res1) # print("type res1: ", type(res1)) print("type additional_inputs: ", type(additional_inputs)) - print("type *additional_inputs: ", type(*additional_inputs)) + print("*additional_inputs: ", *additional_inputs) # res2 = (res1, ) + additional_inputs # print("res2: ", res2) # print("type res2: ", type(res2)) From 32dd2ea75f519b8791131ecadcbe5cfa99374bfe Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:46:04 +0000 Subject: [PATCH 340/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index e658448ed78..281f4fe20f2 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -85,7 +85,7 @@ def new_body_fn(*carried_inputs): # res2 = (res1, ) + additional_inputs # print("res2: ", res2) # print("type res2: ", type(res2)) - return list(body_fn(*carried_inputs)).extend(*additional_inputs) + return list(body_fn(*carried_inputs)).extend(additional_inputs) return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From 2bae9fec6fdb9da3138ed0f4f669f00530ecd87a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:47:14 +0000 Subject: [PATCH 341/546] down into cpp --- torch_xla/experimental/fori_loop.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 281f4fe20f2..b40cdb5247c 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -80,11 +80,13 @@ def new_body_fn(*carried_inputs): # res1 = body_fn(*carried_inputs) # print("res1: ", res1) # print("type res1: ", type(res1)) - print("type additional_inputs: ", type(additional_inputs)) - print("*additional_inputs: ", *additional_inputs) + # print("type additional_inputs: ", type(additional_inputs)) + # print("*additional_inputs: ", *additional_inputs) # res2 = (res1, ) + additional_inputs # print("res2: ", res2) # print("type res2: ", type(res2)) + res = list(body_fn(*carried_inputs)).extend(additional_inputs) + print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) From 3a06a3ea1630514b2824eaa62858d610da02eb32 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:48:20 +0000 Subject: [PATCH 342/546] down into cpp --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index b40cdb5247c..d0a2c403f6d 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -85,6 +85,7 @@ def new_body_fn(*carried_inputs): # res2 = (res1, ) + additional_inputs # print("res2: ", res2) # print("type res2: ", type(res2)) + print("before it") res = list(body_fn(*carried_inputs)).extend(additional_inputs) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) From c7fdd16fbe4642482accd96a9d3f4ee89e0d8f99 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:49:57 +0000 Subject: [PATCH 343/546] down into cpp --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index d0a2c403f6d..27a15f38d10 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -86,6 +86,7 @@ def new_body_fn(*carried_inputs): # print("res2: ", res2) # print("type res2: ", type(res2)) print("before it") + print("body_fn(*carried_inputs): ", body_fn(*carried_inputs)) res = list(body_fn(*carried_inputs)).extend(additional_inputs) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) From 0f8d315cc0763d4767c523cf09d20ce69accd2a2 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:51:24 +0000 Subject: [PATCH 344/546] down into cpp --- torch_xla/experimental/fori_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 27a15f38d10..6089cabee22 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -86,7 +86,8 @@ def new_body_fn(*carried_inputs): # print("res2: ", res2) # print("type res2: ", type(res2)) print("before it") - print("body_fn(*carried_inputs): ", body_fn(*carried_inputs)) + # print("body_fn(*carried_inputs): ", body_fn(*carried_inputs)) + print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) res = list(body_fn(*carried_inputs)).extend(additional_inputs) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) From 9f6912c843b2c42e50bc28bbd50b64f5b1823fc4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:52:18 +0000 Subject: [PATCH 345/546] down into cpp --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 6089cabee22..5e4b2f67441 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -88,6 +88,7 @@ def new_body_fn(*carried_inputs): print("before it") # print("body_fn(*carried_inputs): ", body_fn(*carried_inputs)) print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) + print("additional_inputs: ", additional_inputs) res = list(body_fn(*carried_inputs)).extend(additional_inputs) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) From 9f16c7d2a608904f510d5e121643caf92bc5567a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:53:15 +0000 Subject: [PATCH 346/546] down into cpp --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 5e4b2f67441..b7f54b047c2 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -89,6 +89,7 @@ def new_body_fn(*carried_inputs): # print("body_fn(*carried_inputs): ", body_fn(*carried_inputs)) print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) print("additional_inputs: ", additional_inputs) + print("list(body_fn(*carried_inputs)).extend(additional_inputs): ", list(body_fn(*carried_inputs)).extend(additional_inputs)) res = list(body_fn(*carried_inputs)).extend(additional_inputs) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) From 77656a13cdf80139653cad4513dfa5c1bd70897b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:54:23 +0000 Subject: [PATCH 347/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index b7f54b047c2..5507f19fb12 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -90,6 +90,8 @@ def new_body_fn(*carried_inputs): print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) print("additional_inputs: ", additional_inputs) print("list(body_fn(*carried_inputs)).extend(additional_inputs): ", list(body_fn(*carried_inputs)).extend(additional_inputs)) + res0 = list([1, 2, 3]).extend((4, 5)) + print("res0: ", res0) res = list(body_fn(*carried_inputs)).extend(additional_inputs) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) From 3cd1a6b3884261610d56540e244ac2fee75dc39f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:55:11 +0000 Subject: [PATCH 348/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 5507f19fb12..17ff9480740 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -91,7 +91,9 @@ def new_body_fn(*carried_inputs): print("additional_inputs: ", additional_inputs) print("list(body_fn(*carried_inputs)).extend(additional_inputs): ", list(body_fn(*carried_inputs)).extend(additional_inputs)) res0 = list([1, 2, 3]).extend((4, 5)) + res1 = list([1, 2, 3]).extend([4, 5]) print("res0: ", res0) + print("res1: ", res1) res = list(body_fn(*carried_inputs)).extend(additional_inputs) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) From 1841bd2270c7d7fff3463bc557157c32c8d04606 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:56:05 +0000 Subject: [PATCH 349/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 17ff9480740..7024e59d695 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -91,7 +91,7 @@ def new_body_fn(*carried_inputs): print("additional_inputs: ", additional_inputs) print("list(body_fn(*carried_inputs)).extend(additional_inputs): ", list(body_fn(*carried_inputs)).extend(additional_inputs)) res0 = list([1, 2, 3]).extend((4, 5)) - res1 = list([1, 2, 3]).extend([4, 5]) + res1 = [1, 2, 3].extend([4, 5]) print("res0: ", res0) print("res1: ", res1) res = list(body_fn(*carried_inputs)).extend(additional_inputs) From 98a7e7c2442335a127984423fb7da493b04a666e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 20:56:44 +0000 Subject: [PATCH 350/546] down into cpp --- torch_xla/experimental/fori_loop.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 7024e59d695..8ad613b0ea4 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -94,6 +94,10 @@ def new_body_fn(*carried_inputs): res1 = [1, 2, 3].extend([4, 5]) print("res0: ", res0) print("res1: ", res1) + thislist = ["apple", "banana", "cherry"] + tropical = ["mango", "pineapple", "papaya"] + thislist.extend(tropical) + print(thislist) res = list(body_fn(*carried_inputs)).extend(additional_inputs) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) From 98d044b29b465c2c67f8064754ffb8901ad1436d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:08:22 +0000 Subject: [PATCH 351/546] down into cpp --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 8ad613b0ea4..df1760a1441 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -89,6 +89,7 @@ def new_body_fn(*carried_inputs): # print("body_fn(*carried_inputs): ", body_fn(*carried_inputs)) print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) print("additional_inputs: ", additional_inputs) + print("type additional_inputs: ", type(additional_inputs)) print("list(body_fn(*carried_inputs)).extend(additional_inputs): ", list(body_fn(*carried_inputs)).extend(additional_inputs)) res0 = list([1, 2, 3]).extend((4, 5)) res1 = [1, 2, 3].extend([4, 5]) From 8bbd848ed15e50c387aeb93bda27ef071ba839a9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:08:55 +0000 Subject: [PATCH 352/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index df1760a1441..c00983795a4 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -99,7 +99,7 @@ def new_body_fn(*carried_inputs): tropical = ["mango", "pineapple", "papaya"] thislist.extend(tropical) print(thislist) - res = list(body_fn(*carried_inputs)).extend(additional_inputs) + res = list(body_fn(*carried_inputs)).extend(list(additional_inputs)) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) From b47a02283ab51f966ae8043660068a652db246fe Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:09:33 +0000 Subject: [PATCH 353/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index c00983795a4..1ccb6924c3a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -90,7 +90,7 @@ def new_body_fn(*carried_inputs): print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) print("additional_inputs: ", additional_inputs) print("type additional_inputs: ", type(additional_inputs)) - print("list(body_fn(*carried_inputs)).extend(additional_inputs): ", list(body_fn(*carried_inputs)).extend(additional_inputs)) + print("list(body_fn(*carried_inputs)).extend(list(additional_inputs)): ", list(body_fn(*carried_inputs)).extend(list(additional_inputs))) res0 = list([1, 2, 3]).extend((4, 5)) res1 = [1, 2, 3].extend([4, 5]) print("res0: ", res0) From d705dafec0bf2a25d8fb38f05cb0f07637c98f62 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:10:16 +0000 Subject: [PATCH 354/546] down into cpp --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 1ccb6924c3a..88e67845f23 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -88,8 +88,8 @@ def new_body_fn(*carried_inputs): print("before it") # print("body_fn(*carried_inputs): ", body_fn(*carried_inputs)) print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) - print("additional_inputs: ", additional_inputs) - print("type additional_inputs: ", type(additional_inputs)) + # print("additional_inputs: ", additional_inputs) + # print("type additional_inputs: ", type(additional_inputs)) print("list(body_fn(*carried_inputs)).extend(list(additional_inputs)): ", list(body_fn(*carried_inputs)).extend(list(additional_inputs))) res0 = list([1, 2, 3]).extend((4, 5)) res1 = [1, 2, 3].extend([4, 5]) From 3c178390793b5a929e2a356488502bfcda33bca2 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:11:02 +0000 Subject: [PATCH 355/546] down into cpp --- torch_xla/experimental/fori_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 88e67845f23..2bad8220647 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -99,7 +99,8 @@ def new_body_fn(*carried_inputs): tropical = ["mango", "pineapple", "papaya"] thislist.extend(tropical) print(thislist) - res = list(body_fn(*carried_inputs)).extend(list(additional_inputs)) + mid = list(body_fn(*carried_inputs)) + res = mid.extend(list(additional_inputs)) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) From b4f479b238d848fbaa9f3eb31fdcc4f1a0d609e4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:11:52 +0000 Subject: [PATCH 356/546] down into cpp --- torch_xla/experimental/fori_loop.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 2bad8220647..e3c973bbe2b 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -99,6 +99,10 @@ def new_body_fn(*carried_inputs): tropical = ["mango", "pineapple", "papaya"] thislist.extend(tropical) print(thislist) + thislist = ["apple", "banana", "cherry"] + thistuple = ("kiwi", "orange") + thislist.extend(thistuple) + print(thislist) mid = list(body_fn(*carried_inputs)) res = mid.extend(list(additional_inputs)) print("res: ", res) From c53b9b94018c6c7b46f15c54d9139b3f74f30dd2 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:13:00 +0000 Subject: [PATCH 357/546] down into cpp --- torch_xla/experimental/fori_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index e3c973bbe2b..63536e44cc7 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -90,7 +90,7 @@ def new_body_fn(*carried_inputs): print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) # print("additional_inputs: ", additional_inputs) # print("type additional_inputs: ", type(additional_inputs)) - print("list(body_fn(*carried_inputs)).extend(list(additional_inputs)): ", list(body_fn(*carried_inputs)).extend(list(additional_inputs))) + # print("list(body_fn(*carried_inputs)).extend(list(additional_inputs)): ", list(body_fn(*carried_inputs)).extend(list(additional_inputs))) res0 = list([1, 2, 3]).extend((4, 5)) res1 = [1, 2, 3].extend([4, 5]) print("res0: ", res0) @@ -104,7 +104,8 @@ def new_body_fn(*carried_inputs): thislist.extend(thistuple) print(thislist) mid = list(body_fn(*carried_inputs)) - res = mid.extend(list(additional_inputs)) + # res = mid.extend(list(additional_inputs)) + res = mid.extend(additional_inputs) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) From 4ded9a315112c78ecb5a934f977422bce07fd63d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:13:19 +0000 Subject: [PATCH 358/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 63536e44cc7..e234203855d 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -88,7 +88,7 @@ def new_body_fn(*carried_inputs): print("before it") # print("body_fn(*carried_inputs): ", body_fn(*carried_inputs)) print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) - # print("additional_inputs: ", additional_inputs) + print("additional_inputs: ", additional_inputs) # print("type additional_inputs: ", type(additional_inputs)) # print("list(body_fn(*carried_inputs)).extend(list(additional_inputs)): ", list(body_fn(*carried_inputs)).extend(list(additional_inputs))) res0 = list([1, 2, 3]).extend((4, 5)) From ffbe9a5d0e3f131c0bf19d10bb7405ce625cd2be Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:14:15 +0000 Subject: [PATCH 359/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index e234203855d..10c14025393 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -91,7 +91,7 @@ def new_body_fn(*carried_inputs): print("additional_inputs: ", additional_inputs) # print("type additional_inputs: ", type(additional_inputs)) # print("list(body_fn(*carried_inputs)).extend(list(additional_inputs)): ", list(body_fn(*carried_inputs)).extend(list(additional_inputs))) - res0 = list([1, 2, 3]).extend((4, 5)) + res0 = [1, 2, 3].extend((4, 5)) res1 = [1, 2, 3].extend([4, 5]) print("res0: ", res0) print("res1: ", res1) From 2be42542e82fac6dc85fba713e388a41d3fb8024 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:15:28 +0000 Subject: [PATCH 360/546] down into cpp --- torch_xla/experimental/fori_loop.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 10c14025393..9c2202c3002 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -91,6 +91,14 @@ def new_body_fn(*carried_inputs): print("additional_inputs: ", additional_inputs) # print("type additional_inputs: ", type(additional_inputs)) # print("list(body_fn(*carried_inputs)).extend(list(additional_inputs)): ", list(body_fn(*carried_inputs)).extend(list(additional_inputs))) + aaa = [1, 2, 3] + bbb = [4, 5] + ccc = (4, 5) + aaa.extend(bbb) + print(aaa) + aaa.extend(ccc) + print(aaa) + res0 = [1, 2, 3].extend((4, 5)) res1 = [1, 2, 3].extend([4, 5]) print("res0: ", res0) From 7df2af49345c235e3087cd3c0bd923e709f21c5d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:16:31 +0000 Subject: [PATCH 361/546] down into cpp --- torch_xla/experimental/fori_loop.py | 44 ++++++++++++++--------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 9c2202c3002..0fd51c53a3b 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -91,29 +91,29 @@ def new_body_fn(*carried_inputs): print("additional_inputs: ", additional_inputs) # print("type additional_inputs: ", type(additional_inputs)) # print("list(body_fn(*carried_inputs)).extend(list(additional_inputs)): ", list(body_fn(*carried_inputs)).extend(list(additional_inputs))) - aaa = [1, 2, 3] - bbb = [4, 5] - ccc = (4, 5) - aaa.extend(bbb) - print(aaa) - aaa.extend(ccc) - print(aaa) - - res0 = [1, 2, 3].extend((4, 5)) - res1 = [1, 2, 3].extend([4, 5]) - print("res0: ", res0) - print("res1: ", res1) - thislist = ["apple", "banana", "cherry"] - tropical = ["mango", "pineapple", "papaya"] - thislist.extend(tropical) - print(thislist) - thislist = ["apple", "banana", "cherry"] - thistuple = ("kiwi", "orange") - thislist.extend(thistuple) - print(thislist) - mid = list(body_fn(*carried_inputs)) + # aaa = [1, 2, 3] + # bbb = [4, 5] + # ccc = (4, 5) + # aaa.extend(bbb) + # print(aaa) + # aaa.extend(ccc) + # print(aaa) + # res0 = [1, 2, 3].extend((4, 5)) + # res1 = [1, 2, 3].extend([4, 5]) + # print("res0: ", res0) + # print("res1: ", res1) + # thislist = ["apple", "banana", "cherry"] + # tropical = ["mango", "pineapple", "papaya"] + # thislist.extend(tropical) + # print(thislist) + # thislist = ["apple", "banana", "cherry"] + # thistuple = ("kiwi", "orange") + # thislist.extend(thistuple) + # print(thislist) + # mid = body_fn(*carried_inputs) # res = mid.extend(list(additional_inputs)) - res = mid.extend(additional_inputs) + res = body_fn(*carried_inputs) + res.extend(additional_inputs) print("res: ", res) return list(body_fn(*carried_inputs)).extend(additional_inputs) return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) From 1be8e3a793da68b51f4e186b6aa3bc212c7a4043 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:17:23 +0000 Subject: [PATCH 362/546] down into cpp --- torch_xla/experimental/fori_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0fd51c53a3b..ab2ce8d665f 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -112,10 +112,11 @@ def new_body_fn(*carried_inputs): # print(thislist) # mid = body_fn(*carried_inputs) # res = mid.extend(list(additional_inputs)) - res = body_fn(*carried_inputs) + res = list(body_fn(*carried_inputs)) res.extend(additional_inputs) print("res: ", res) - return list(body_fn(*carried_inputs)).extend(additional_inputs) + # return list(body_fn(*carried_inputs)).extend(additional_inputs) + return res return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From 929ff17889b3ba93a105b67fae649627add3e0d3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:20:34 +0000 Subject: [PATCH 363/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index ab2ce8d665f..93bf36c9ba2 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -114,7 +114,7 @@ def new_body_fn(*carried_inputs): # res = mid.extend(list(additional_inputs)) res = list(body_fn(*carried_inputs)) res.extend(additional_inputs) - print("res: ", res) + # print("res: ", res) # return list(body_fn(*carried_inputs)).extend(additional_inputs) return res return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) From 0d8a23c0b3274836996671a9f106e63e03fcc422 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:22:25 +0000 Subject: [PATCH 364/546] down into cpp --- torch_xla/experimental/fori_loop.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 93bf36c9ba2..605427a0104 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -158,6 +158,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) + cond_hlo_print = xb.get_computation_hlo(cond_computation) + print("cond computation: !!!!!!!!!") + print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) @@ -175,6 +178,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) + body_hlo_print = xb.get_computation_hlo(body_computation) + print("body computation: !!!!!!!!!") + print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs @@ -203,6 +209,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) + hlo_print = xb.get_computation_hlo(computation) + print("while computation: !!!!!!!!!") + print(hlo_print) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', From be617cce176152d48f2792d5ecc267705d424280 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:22:58 +0000 Subject: [PATCH 365/546] down into cpp --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 605427a0104..dfa8b4e64a7 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -85,10 +85,10 @@ def new_body_fn(*carried_inputs): # res2 = (res1, ) + additional_inputs # print("res2: ", res2) # print("type res2: ", type(res2)) - print("before it") + # print("before it") # print("body_fn(*carried_inputs): ", body_fn(*carried_inputs)) - print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) - print("additional_inputs: ", additional_inputs) + # print("list(body_fn(*carried_inputs)): ", list(body_fn(*carried_inputs))) + # print("additional_inputs: ", additional_inputs) # print("type additional_inputs: ", type(additional_inputs)) # print("list(body_fn(*carried_inputs)).extend(list(additional_inputs)): ", list(body_fn(*carried_inputs)).extend(list(additional_inputs))) # aaa = [1, 2, 3] From 9925fd2921ee320ee84c45b508e0c64cce0bfd52 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:23:54 +0000 Subject: [PATCH 366/546] down into cpp --- torch_xla/csrc/init_python_bindings.cpp | 28 ++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index e20e28fbb8f..2112222ac94 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -929,20 +929,20 @@ class PyLoweringContext { } } - // hard-code modify body xlacomputation input arguments with unusedarguments - // for xla::while requriement - if (GetNameString() == "bodyctx") { - xla::XlaBuilder* local_builder = lowering_ctx.builder(); - // TODO(@manfei): treat hard code parameter_idx value - int64_t parameter_idx = 7; - for (auto& additional_input_tensor : additional_inputs_list) { - XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - xla::Shape shape = xtensor->shape().get(); - xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - "UnusedArgumentsPlaceholder"); - parameter_idx += 1; - } - } + // // hard-code modify body xlacomputation input arguments with unusedarguments + // // for xla::while requriement + // if (GetNameString() == "bodyctx") { + // xla::XlaBuilder* local_builder = lowering_ctx.builder(); + // // TODO(@manfei): treat hard code parameter_idx value + // int64_t parameter_idx = 7; + // for (auto& additional_input_tensor : additional_inputs_list) { + // XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + // xla::Shape shape = xtensor->shape().get(); + // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + // "UnusedArgumentsPlaceholder"); + // parameter_idx += 1; + // } + // } // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = From bf038a5f0cce2bdf61b9e709426084593b78f9d4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:36:25 +0000 Subject: [PATCH 367/546] down into cpp --- torch_xla/experimental/fori_loop.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index dfa8b4e64a7..b00c9cb0476 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -112,12 +112,15 @@ def new_body_fn(*carried_inputs): # print(thislist) # mid = body_fn(*carried_inputs) # res = mid.extend(list(additional_inputs)) - res = list(body_fn(*carried_inputs)) - res.extend(additional_inputs) + # res = list(body_fn(*carried_inputs)) + # res.extend(additional_inputs) # print("res: ", res) # return list(body_fn(*carried_inputs)).extend(additional_inputs) + res = list(body_fn(*carried_inputs)) + res.extend(additional_inputs) return res - return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) + # return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) + return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From 3727dbfdf388f4d1fa81042c2afc762a8e2c5a90 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:42:45 +0000 Subject: [PATCH 368/546] down into cpp --- torch_xla/experimental/fori_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index b00c9cb0476..231e674b7aa 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -177,7 +177,8 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): additional_inputs_list_body = [] # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body - body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) + # body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) + body_ctx.buildforiloop(list(body_result), ()) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) From e210bc7c1b6fef52218563863611e21545f520a4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:43:40 +0000 Subject: [PATCH 369/546] down into cpp --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 231e674b7aa..516e65b3247 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -119,8 +119,8 @@ def new_body_fn(*carried_inputs): res = list(body_fn(*carried_inputs)) res.extend(additional_inputs) return res - # return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) - return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) + return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) + # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From 7a084c506f505c7744a95321da72916aebd5e2df Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:50:25 +0000 Subject: [PATCH 370/546] down into cpp --- torch_xla/experimental/fori_loop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 516e65b3247..e4d55d470aa 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -117,7 +117,9 @@ def new_body_fn(*carried_inputs): # print("res: ", res) # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) - res.extend(additional_inputs) + # res.extend(additional_inputs) + res.append(body_fn.bias) + res.append(body_fn.weight) return res return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From b7d90a61973c2581fd9ed1d9e138bae7602fd188 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:53:16 +0000 Subject: [PATCH 371/546] down into cpp --- torch_xla/experimental/fori_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index e4d55d470aa..5c90496b4f0 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -117,9 +117,10 @@ def new_body_fn(*carried_inputs): # print("res: ", res) # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) + res.insert(-1, additional_inputs) # res.extend(additional_inputs) - res.append(body_fn.bias) - res.append(body_fn.weight) + # res.append(body_fn.bias) + # res.append(body_fn.weight) return res return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From dff1e9d6bc30166ecea82aaafb9913a23a5f373c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:54:16 +0000 Subject: [PATCH 372/546] down into cpp --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 5c90496b4f0..3e2fd15a329 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -117,6 +117,7 @@ def new_body_fn(*carried_inputs): # print("res: ", res) # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) + print("res: ", res) res.insert(-1, additional_inputs) # res.extend(additional_inputs) # res.append(body_fn.bias) From 251bd6f2b2daa328af83d0cf27cf818ab12fe4f0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:56:32 +0000 Subject: [PATCH 373/546] down into cpp --- torch_xla/experimental/fori_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3e2fd15a329..ebc754dd6c0 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -118,7 +118,8 @@ def new_body_fn(*carried_inputs): # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) print("res: ", res) - res.insert(-1, additional_inputs) + res.insert(-2, additional_inputs) + print("new res: ", res) # res.extend(additional_inputs) # res.append(body_fn.bias) # res.append(body_fn.weight) From 87f9ecadcf4feae7d48167405ea319b06949700c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:57:23 +0000 Subject: [PATCH 374/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index ebc754dd6c0..fcc80083ae3 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -118,7 +118,7 @@ def new_body_fn(*carried_inputs): # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) print("res: ", res) - res.insert(-2, additional_inputs) + res.insert(-2, *additional_inputs) print("new res: ", res) # res.extend(additional_inputs) # res.append(body_fn.bias) From d15a492a255c373a69b0065534733efc105e1dca Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:58:17 +0000 Subject: [PATCH 375/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index fcc80083ae3..0d46622b5e6 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -118,6 +118,8 @@ def new_body_fn(*carried_inputs): # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) print("res: ", res) + newres = res[:-1] + additional_inputs + res[-1] + print("newres: ", newres) res.insert(-2, *additional_inputs) print("new res: ", res) # res.extend(additional_inputs) From 15ec18555e1aac14cda9f9c4320b4e9edadbfbdc Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:58:47 +0000 Subject: [PATCH 376/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0d46622b5e6..92288bd335c 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -118,7 +118,7 @@ def new_body_fn(*carried_inputs): # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) print("res: ", res) - newres = res[:-1] + additional_inputs + res[-1] + newres = res[:-1] + list(additional_inputs) + res[-1] print("newres: ", newres) res.insert(-2, *additional_inputs) print("new res: ", res) From 31afff57063b746701cf28c066d6ca61c76c5476 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 21:59:56 +0000 Subject: [PATCH 377/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 92288bd335c..82d2a3ba62e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -118,6 +118,8 @@ def new_body_fn(*carried_inputs): # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) print("res: ", res) + trynewres = res[:-1] + res[-1] + print("trynewres: ", trynewres) newres = res[:-1] + list(additional_inputs) + res[-1] print("newres: ", newres) res.insert(-2, *additional_inputs) From 6e3c69905b9bc15c8668a80a570b66d010e0d5ed Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:00:34 +0000 Subject: [PATCH 378/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 82d2a3ba62e..12748cea24a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -118,7 +118,7 @@ def new_body_fn(*carried_inputs): # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) print("res: ", res) - trynewres = res[:-1] + res[-1] + trynewres = res[:-1].extend(res[-1]) print("trynewres: ", trynewres) newres = res[:-1] + list(additional_inputs) + res[-1] print("newres: ", newres) From 8405149fec81579fdad82eedaf910d3bc378c7cc Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:00:53 +0000 Subject: [PATCH 379/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 12748cea24a..474baf1f871 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -118,7 +118,7 @@ def new_body_fn(*carried_inputs): # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) print("res: ", res) - trynewres = res[:-1].extend(res[-1]) + trynewres = res[:-1] + [res[-1]] print("trynewres: ", trynewres) newres = res[:-1] + list(additional_inputs) + res[-1] print("newres: ", newres) From 25efff622cb6d405d770a94656357697de636275 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:01:20 +0000 Subject: [PATCH 380/546] down into cpp --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 474baf1f871..a5b6bee0d31 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -118,9 +118,9 @@ def new_body_fn(*carried_inputs): # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) print("res: ", res) - trynewres = res[:-1] + [res[-1]] - print("trynewres: ", trynewres) - newres = res[:-1] + list(additional_inputs) + res[-1] + # trynewres = res[:-1] + [res[-1]] + # print("trynewres: ", trynewres) + newres = res[:-1] + list(additional_inputs) + [res[-1]] print("newres: ", newres) res.insert(-2, *additional_inputs) print("new res: ", res) From 8790af29b4f047268027780e94fcba18aa8fc195 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:01:42 +0000 Subject: [PATCH 381/546] down into cpp --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index a5b6bee0d31..b41a92f5c76 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -117,7 +117,7 @@ def new_body_fn(*carried_inputs): # print("res: ", res) # return list(body_fn(*carried_inputs)).extend(additional_inputs) res = list(body_fn(*carried_inputs)) - print("res: ", res) + # print("res: ", res) # trynewres = res[:-1] + [res[-1]] # print("trynewres: ", trynewres) newres = res[:-1] + list(additional_inputs) + [res[-1]] From be5710d0da5d091e059ab8c395e8b296ce458f11 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:02:14 +0000 Subject: [PATCH 382/546] down into cpp --- torch_xla/experimental/fori_loop.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index b41a92f5c76..265f42bf7ad 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -122,12 +122,13 @@ def new_body_fn(*carried_inputs): # print("trynewres: ", trynewres) newres = res[:-1] + list(additional_inputs) + [res[-1]] print("newres: ", newres) - res.insert(-2, *additional_inputs) - print("new res: ", res) + # res.insert(-2, *additional_inputs) + # print("new res: ", res) # res.extend(additional_inputs) # res.append(body_fn.bias) # res.append(body_fn.weight) - return res + # return res + return newres return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) From 9cbdb405b25816eebf25f6ce3ef45a9b952a6a61 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:02:57 +0000 Subject: [PATCH 383/546] down into cpp --- ...fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index a4739f6374a..6c14af9eb2a 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -159,7 +159,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): l_in_0 = torch.rand(10, device=xm.xla_device()) output_value = torch.zeros([20], dtype=torch.float32, device=device) - upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop( + # upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop( + # cond_fn, body_fn, + # (upper, lower, one_value, init_val, l_in_0, output_value)) + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, = while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) From 5b8fcb7fc93702705d8b6104e002242cca86aee9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:03:56 +0000 Subject: [PATCH 384/546] down into cpp --- torch_xla/experimental/fori_loop.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 265f42bf7ad..6d406fdecaf 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -171,9 +171,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - cond_hlo_print = xb.get_computation_hlo(cond_computation) - print("cond computation: !!!!!!!!!") - print(cond_hlo_print) + # cond_hlo_print = xb.get_computation_hlo(cond_computation) + # print("cond computation: !!!!!!!!!") + # print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) @@ -192,9 +192,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - body_hlo_print = xb.get_computation_hlo(body_computation) - print("body computation: !!!!!!!!!") - print(body_hlo_print) + # body_hlo_print = xb.get_computation_hlo(body_computation) + # print("body computation: !!!!!!!!!") + # print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs @@ -223,9 +223,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_computation=body_computation) name = 'fori_loop_ed_torch_func' computation = w.build(name) - hlo_print = xb.get_computation_hlo(computation) - print("while computation: !!!!!!!!!") - print(hlo_print) + # hlo_print = xb.get_computation_hlo(computation) + # print("while computation: !!!!!!!!!") + # print(hlo_print) # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', From 7d07234916952a54938ca7a25f9b9259a9bd9d8c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:04:13 +0000 Subject: [PATCH 385/546] down into cpp --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 6c14af9eb2a..254df2d43c1 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -142,7 +142,7 @@ def test_while_loop_tpu_simple_linear_wrapper(self): linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) def cond_fn(upper, lower, one_value, x, input_value, output_value): - return lower[0] < upper[0] + return lower[0] >= upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) From f11daa0c9b09c3c1f4bd23879446dde4163e4ca6 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:05:37 +0000 Subject: [PATCH 386/546] down into cpp --- ...st_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py | 2 +- torch_xla/experimental/fori_loop.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 254df2d43c1..6c14af9eb2a 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -142,7 +142,7 @@ def test_while_loop_tpu_simple_linear_wrapper(self): linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) def cond_fn(upper, lower, one_value, x, input_value, output_value): - return lower[0] >= upper[0] + return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 6d406fdecaf..fcaa337e487 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -121,7 +121,7 @@ def new_body_fn(*carried_inputs): # trynewres = res[:-1] + [res[-1]] # print("trynewres: ", trynewres) newres = res[:-1] + list(additional_inputs) + [res[-1]] - print("newres: ", newres) + # print("newres: ", newres) # res.insert(-2, *additional_inputs) # print("new res: ", res) # res.extend(additional_inputs) From 3aa67e7a51fc635cb0180eaeccf051680fc26555 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:43:56 +0000 Subject: [PATCH 387/546] down into cpp --- ...while_loop_simple_add_dispatch_in_torch.py | 62 +++++++++++++++++++ torch_xla/experimental/fori_loop.py | 1 + 2 files changed, 63 insertions(+) diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 6c14af9eb2a..e20f1e5a5bf 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -233,6 +233,68 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return aaa +### +# + def test_while_loop_tpu_simple_linear_class_wrapper(self): + + print("$$$ test_while_loop_tpu_simple_linear_class_wrapper !!!") + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + + class SimpleWithLinear(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def forward(self, upper, lower, one_value, x, input_value, output_value): + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value_real = self.linear(input_value) + weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone( + ), output_value_real + + return while_loop( + cond_fn, body_fn, + (upper, lower, one_value, x, input_value, output_value)) + + simple_with_linear = SimpleWithLinear() + upper = torch.tensor([52], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + weight_0 = simple_with_linear.linear.weight + bias_0 = simple_with_linear.linear.bias + + aaa = { + "simple_with_linear": + (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, + output_value)) + } + + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear( + upper, lower, one_value, init_val, l_in_0, output_value) + + # create same weight/bias liear model for compare + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + linear_0.weight.data = weight__ + linear_0.bias.data = bias__ + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + + self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + return aaa + # additional_inputs: () def test_fori_loop_tpu_addition(self): diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index fcaa337e487..8b5dc7efcc0 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -116,6 +116,7 @@ def new_body_fn(*carried_inputs): # res.extend(additional_inputs) # print("res: ", res) # return list(body_fn(*carried_inputs)).extend(additional_inputs) + self.named_parameters res = list(body_fn(*carried_inputs)) # print("res: ", res) # trynewres = res[:-1] + [res[-1]] From 760d49c4615545fa1aa5669fc7191f93f019be0d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:45:57 +0000 Subject: [PATCH 388/546] down into cpp --- torch_xla/experimental/fori_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 8b5dc7efcc0..3d0fff3d49f 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -116,7 +116,8 @@ def new_body_fn(*carried_inputs): # res.extend(additional_inputs) # print("res: ", res) # return list(body_fn(*carried_inputs)).extend(additional_inputs) - self.named_parameters + # self.named_parameters + weight = self.linear.weight res = list(body_fn(*carried_inputs)) # print("res: ", res) # trynewres = res[:-1] + [res[-1]] From 9f7ddc264a975f8e292999166d1be73f3779db3a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 22:58:05 +0000 Subject: [PATCH 389/546] down into cpp --- test/test_train_mp_mnist.py | 3 ++- torch_xla/experimental/fori_loop.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 3b078d22fab..661c47f8f8a 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -170,6 +170,7 @@ def test_loop_fn(loader): total_samples = 0 correct = 0 model.eval() + print("loader: ", loader) for data, target in loader: output = model(data) pred = output.max(1, keepdim=True)[1] @@ -185,7 +186,7 @@ def test_loop_fn(loader): accuracy, max_accuracy = 0.0, 0.0 for epoch in range(1, flags.num_epochs + 1): xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) - train_loop_fn(train_device_loader, epoch) + # train_loop_fn(train_device_loader, epoch) xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) accuracy = test_loop_fn(test_device_loader) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3d0fff3d49f..3bab0bbc665 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -117,7 +117,7 @@ def new_body_fn(*carried_inputs): # print("res: ", res) # return list(body_fn(*carried_inputs)).extend(additional_inputs) # self.named_parameters - weight = self.linear.weight + # weight = self.linear.weight res = list(body_fn(*carried_inputs)) # print("res: ", res) # trynewres = res[:-1] + [res[-1]] From 76b4b829bd8905f29d07de1c2b05243250f55062 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:05:36 +0000 Subject: [PATCH 390/546] down into cpp --- test/test_train_mp_mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 661c47f8f8a..03589a21fb3 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -171,6 +171,7 @@ def test_loop_fn(loader): correct = 0 model.eval() print("loader: ", loader) + print("type loader: ", type(loader)) for data, target in loader: output = model(data) pred = output.max(1, keepdim=True)[1] @@ -181,7 +182,7 @@ def test_loop_fn(loader): accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) return accuracy - train_device_loader = pl.MpDeviceLoader(train_loader, device) + # train_device_loader = pl.MpDeviceLoader(train_loader, device) test_device_loader = pl.MpDeviceLoader(test_loader, device) accuracy, max_accuracy = 0.0, 0.0 for epoch in range(1, flags.num_epochs + 1): From 042d37cb6c38c987a1b5e510d967fbfd8872e59f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:14:19 +0000 Subject: [PATCH 391/546] down into cpp --- test/test_train_mp_mnist.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 03589a21fb3..7d1f83afe0b 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -186,21 +186,31 @@ def test_loop_fn(loader): test_device_loader = pl.MpDeviceLoader(test_loader, device) accuracy, max_accuracy = 0.0, 0.0 for epoch in range(1, flags.num_epochs + 1): - xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) + # xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) # train_loop_fn(train_device_loader, epoch) - xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) - + # xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) accuracy = test_loop_fn(test_device_loader) - xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format( - epoch, test_utils.now(), accuracy)) + # xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy)) max_accuracy = max(accuracy, max_accuracy) - test_utils.write_to_summary( - writer, - epoch, - dict_to_write={'Accuracy/test': accuracy}, - write_xla_metrics=True) - if flags.metrics_debug: - xm.master_print(met.metrics_report()) + # test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True) + # if flags.metrics_debug: xm.master_print(met.metrics_report()) + + # ### fori_loop + # # torch.set_grad_enabled(False) + # new_test_device_loader = pl.MpDeviceLoader(test_loader, device) + # upper = torch.tensor([flags.num_epochs + 1], dtype=torch.int32, device=device) # flags.num_epochs + 1 + # lower = torch.tensor([1], dtype=torch.int32, device=device) # 1 + # init_val = torch.tensor([1], dtype=torch.int32, device=device) + # # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader + # # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + # def body_fun(test_device_loader): + # accuracy = test_loop_fn(test_device_loader) + # max_accuracy = max(accuracy, max_accuracy) + # return max_accuracy + + # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( + # upper, lower, body_fun, init_val, new_test_device_loader) + test_utils.close_summary_writer(writer) xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy)) From 088355bbc6844f9781a937ab3efa7cd6fd94dd77 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:16:31 +0000 Subject: [PATCH 392/546] down into cpp --- test/test_train_mp_mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 7d1f83afe0b..f27e1f58ed6 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -170,8 +170,8 @@ def test_loop_fn(loader): total_samples = 0 correct = 0 model.eval() - print("loader: ", loader) - print("type loader: ", type(loader)) + # print("loader: ", loader) + # print("type loader: ", type(loader)) for data, target in loader: output = model(data) pred = output.max(1, keepdim=True)[1] From 67b5840d60c9d5a9cc44df84edee0ca0232f8653 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:22:44 +0000 Subject: [PATCH 393/546] down into cpp --- test/test_test_mnist.py | 316 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 316 insertions(+) create mode 100644 test/test_test_mnist.py diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py new file mode 100644 index 00000000000..e64c5e158dd --- /dev/null +++ b/test/test_test_mnist.py @@ -0,0 +1,316 @@ +import torch +import torchvision + +n_epochs = 3 +batch_size_train = 8 # 64 +batch_size_test = 10 # 1000 +learning_rate = 0.01 +momentum = 0.5 +log_interval = 10 + +random_seed = 1 +torch.backends.cudnn.enabled = False +torch.manual_seed(random_seed) + +### load data +test_loader = xu.SampleGenerator( + data=(torch.zeros(flags.batch_size, 1, 28, + 28), torch.zeros(flags.batch_size, + dtype=torch.int64)), + sample_count=10000 // flags.batch_size // xm.xrt_world_size()) + +examples = enumerate(test_loader) +batch_idx, (example_data, example_targets) = next(examples) + +example_data.shape + +### build model +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x) + + +network = Net() +optimizer = optim.SGD(network.parameters(), lr=learning_rate, + momentum=momentum) + + +train_losses = [] +train_counter = [] +test_losses = [] +test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)] + + +def test(): + network.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + output = network(data) + test_loss += F.nll_loss(output, target, size_average=False).item() + pred = output.data.max(1, keepdim=True)[1] + correct += pred.eq(target.data.view_as(pred)).sum() + test_loss /= len(test_loader.dataset) + test_losses.append(test_loss) + print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) + +# run test model +test() +for epoch in range(1, n_epochs + 1): + test() + + + +# import args_parse +# from torch_xla import runtime as xr + +# MODEL_OPTS = { +# '--ddp': { +# 'action': 'store_true', +# }, +# '--pjrt_distributed': { +# 'action': 'store_true', +# }, +# } + +# FLAGS = args_parse.parse_common_options( +# datadir='/tmp/mnist-data', +# batch_size=128, +# momentum=0.5, +# lr=0.01, +# target_accuracy=98.0, +# num_epochs=18, +# opts=MODEL_OPTS.items(), +# ) + +# import os +# import shutil +# import sys +# import numpy as np +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# import torch.optim as optim +# from torchvision import datasets, transforms +# import torch_xla +# import torch_xla.debug.metrics as met +# import torch_xla.distributed.parallel_loader as pl +# import torch_xla.utils.utils as xu +# import torch_xla.core.xla_model as xm +# import torch_xla.distributed.xla_multiprocessing as xmp +# import torch_xla.test.test_utils as test_utils + +# import torch.distributed as dist +# from torch.nn.parallel import DistributedDataParallel as DDP +# import torch_xla.distributed.xla_backend + + +# class MNIST(nn.Module): + +# def __init__(self): +# super(MNIST, self).__init__() +# self.conv1 = nn.Conv2d(1, 10, kernel_size=5) +# self.bn1 = nn.BatchNorm2d(10) +# self.conv2 = nn.Conv2d(10, 20, kernel_size=5) +# self.bn2 = nn.BatchNorm2d(20) +# self.fc1 = nn.Linear(320, 50) +# self.fc2 = nn.Linear(50, 10) + +# def forward(self, x): +# x = F.relu(F.max_pool2d(self.conv1(x), 2)) +# x = self.bn1(x) +# x = F.relu(F.max_pool2d(self.conv2(x), 2)) +# x = self.bn2(x) +# x = torch.flatten(x, 1) +# x = F.relu(self.fc1(x)) +# x = self.fc2(x) +# return F.log_softmax(x, dim=1) + + +# def _train_update(device, step, loss, tracker, epoch, writer): +# test_utils.print_training_update( +# device, +# step, +# loss.item(), +# tracker.rate(), +# tracker.global_rate(), +# epoch, +# summary_writer=writer) + + +# def train_mnist(flags, **kwargs): +# if flags.ddp or flags.pjrt_distributed: +# dist.init_process_group('xla', init_method='xla://') + +# torch.manual_seed(1) + +# if flags.fake_data: +# train_loader = xu.SampleGenerator( +# data=(torch.zeros(flags.batch_size, 1, 28, +# 28), torch.zeros(flags.batch_size, +# dtype=torch.int64)), +# sample_count=60000 // flags.batch_size // xm.xrt_world_size()) +# test_loader = xu.SampleGenerator( +# data=(torch.zeros(flags.batch_size, 1, 28, +# 28), torch.zeros(flags.batch_size, +# dtype=torch.int64)), +# sample_count=10000 // flags.batch_size // xm.xrt_world_size()) +# else: +# train_dataset = datasets.MNIST( +# os.path.join(flags.datadir, str(xm.get_ordinal())), +# train=True, +# download=True, +# transform=transforms.Compose( +# [transforms.ToTensor(), +# transforms.Normalize((0.1307,), (0.3081,))])) +# test_dataset = datasets.MNIST( +# os.path.join(flags.datadir, str(xm.get_ordinal())), +# train=False, +# download=True, +# transform=transforms.Compose( +# [transforms.ToTensor(), +# transforms.Normalize((0.1307,), (0.3081,))])) +# train_sampler = None +# if xm.xrt_world_size() > 1: +# train_sampler = torch.utils.data.distributed.DistributedSampler( +# train_dataset, +# num_replicas=xm.xrt_world_size(), +# rank=xm.get_ordinal(), +# shuffle=True) +# train_loader = torch.utils.data.DataLoader( +# train_dataset, +# batch_size=flags.batch_size, +# sampler=train_sampler, +# drop_last=flags.drop_last, +# shuffle=False if train_sampler else True, +# num_workers=flags.num_workers) +# test_loader = torch.utils.data.DataLoader( +# test_dataset, +# batch_size=flags.batch_size, +# drop_last=flags.drop_last, +# shuffle=False, +# num_workers=flags.num_workers) + +# # Scale learning rate to num cores +# lr = flags.lr * xm.xrt_world_size() + +# device = xm.xla_device() +# model = MNIST().to(device) + +# # Initialization is nondeterministic with multiple threads in PjRt. +# # Synchronize model parameters across replicas manually. +# if xr.using_pjrt(): +# xm.broadcast_master_param(model) + +# if flags.ddp: +# model = DDP(model, gradient_as_bucket_view=True) +# writer = None +# if xm.is_master_ordinal(): +# writer = test_utils.get_summary_writer(flags.logdir) +# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum) +# loss_fn = nn.NLLLoss() + +# def train_loop_fn(loader, epoch): +# tracker = xm.RateTracker() +# model.train() +# for step, (data, target) in enumerate(loader): +# optimizer.zero_grad() +# output = model(data) +# loss = loss_fn(output, target) +# loss.backward() +# if flags.ddp: +# optimizer.step() +# else: +# xm.optimizer_step(optimizer) +# tracker.add(flags.batch_size) +# if step % flags.log_steps == 0: +# xm.add_step_closure( +# _train_update, +# args=(device, step, loss, tracker, epoch, writer), +# run_async=flags.async_closures) + +# def test_loop_fn(loader): +# total_samples = 0 +# correct = 0 +# model.eval() +# # print("loader: ", loader) +# # print("type loader: ", type(loader)) +# for data, target in loader: +# output = model(data) +# pred = output.max(1, keepdim=True)[1] +# correct += pred.eq(target.view_as(pred)).sum() +# total_samples += data.size()[0] + +# accuracy = 100.0 * correct.item() / total_samples +# accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) +# return accuracy + +# # train_device_loader = pl.MpDeviceLoader(train_loader, device) +# test_device_loader = pl.MpDeviceLoader(test_loader, device) +# accuracy, max_accuracy = 0.0, 0.0 +# for epoch in range(1, flags.num_epochs + 1): +# # xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) +# # train_loop_fn(train_device_loader, epoch) +# # xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) +# accuracy = test_loop_fn(test_device_loader) +# # xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy)) +# max_accuracy = max(accuracy, max_accuracy) +# # test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True) +# # if flags.metrics_debug: xm.master_print(met.metrics_report()) + +# # ### fori_loop +# # # torch.set_grad_enabled(False) +# # new_test_device_loader = pl.MpDeviceLoader(test_loader, device) +# # upper = torch.tensor([flags.num_epochs + 1], dtype=torch.int32, device=device) # flags.num_epochs + 1 +# # lower = torch.tensor([1], dtype=torch.int32, device=device) # 1 +# # init_val = torch.tensor([1], dtype=torch.int32, device=device) +# # # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader +# # # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) +# # def body_fun(test_device_loader): +# # accuracy = test_loop_fn(test_device_loader) +# # max_accuracy = max(accuracy, max_accuracy) +# # return max_accuracy + +# # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( +# # upper, lower, body_fun, init_val, new_test_device_loader) + + +# test_utils.close_summary_writer(writer) +# xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy)) +# return max_accuracy + + +# def _mp_fn(index, flags): +# torch.set_default_dtype(torch.float32) +# accuracy = train_mnist(flags) +# if flags.tidy and os.path.isdir(flags.datadir): +# shutil.rmtree(flags.datadir) +# if accuracy < flags.target_accuracy: +# print('Accuracy {} is below target {}'.format(accuracy, +# flags.target_accuracy)) +# sys.exit(21) + + +# if __name__ == '__main__': +# xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores) From 1b81829588a26171f897188fd39a9c707bea9689 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:23:18 +0000 Subject: [PATCH 394/546] down into cpp --- test/test_test_mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index e64c5e158dd..6fdce77017a 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -13,6 +13,7 @@ torch.manual_seed(random_seed) ### load data +import torch_xla.utils.utils as xu test_loader = xu.SampleGenerator( data=(torch.zeros(flags.batch_size, 1, 28, 28), torch.zeros(flags.batch_size, From bee343f182da8bfeed83c6ff924c9ca3e816fff3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:24:36 +0000 Subject: [PATCH 395/546] down into cpp --- test/test_test_mnist.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 6fdce77017a..5bf3ed588dd 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -1,6 +1,27 @@ import torch import torchvision +import os +import shutil +import sys +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +import torch_xla +import torch_xla.debug.metrics as met +import torch_xla.distributed.parallel_loader as pl +import torch_xla.utils.utils as xu +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp +import torch_xla.test.test_utils as test_utils + +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import torch_xla.distributed.xla_backend + n_epochs = 3 batch_size_train = 8 # 64 batch_size_test = 10 # 1000 @@ -13,12 +34,11 @@ torch.manual_seed(random_seed) ### load data -import torch_xla.utils.utils as xu test_loader = xu.SampleGenerator( - data=(torch.zeros(flags.batch_size, 1, 28, - 28), torch.zeros(flags.batch_size, + data=(torch.zeros(8, 1, 28, + 28), torch.zeros(8, dtype=torch.int64)), - sample_count=10000 // flags.batch_size // xm.xrt_world_size()) + sample_count=1000 // 8 // xm.xrt_world_size()) examples = enumerate(test_loader) batch_idx, (example_data, example_targets) = next(examples) From 2de7464aa91053eac3fcbe105a722cb21984af7f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:25:39 +0000 Subject: [PATCH 396/546] down into cpp --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 5bf3ed588dd..5a19800fe39 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -77,7 +77,7 @@ def forward(self, x): train_losses = [] train_counter = [] test_losses = [] -test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)] +test_counter = [i*20 for i in range(n_epochs + 1)] def test(): From e25e9d2eabd58375d3fff1883ef702ad410d9a05 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:27:35 +0000 Subject: [PATCH 397/546] down into cpp --- test/test_test_mnist.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 5a19800fe39..7bc355ae7ba 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -35,15 +35,13 @@ ### load data test_loader = xu.SampleGenerator( - data=(torch.zeros(8, 1, 28, - 28), torch.zeros(8, - dtype=torch.int64)), + data=(torch.zeros(8, 1, 28,28), torch.zeros(8, dtype=torch.int64)), sample_count=1000 // 8 // xm.xrt_world_size()) examples = enumerate(test_loader) batch_idx, (example_data, example_targets) = next(examples) -example_data.shape +print("shape: ", example_data.shape) ### build model import torch.nn as nn @@ -90,7 +88,7 @@ def test(): test_loss += F.nll_loss(output, target, size_average=False).item() pred = output.data.max(1, keepdim=True)[1] correct += pred.eq(target.data.view_as(pred)).sum() - test_loss /= len(test_loader.dataset) + test_loss /= 20 test_losses.append(test_loss) print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, len(test_loader.dataset), From 58359ca803cce2a6d001fa1622e18d15af75d7c7 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:28:38 +0000 Subject: [PATCH 398/546] down into cpp --- test/test_test_mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 7bc355ae7ba..c3613045682 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -91,8 +91,8 @@ def test(): test_loss /= 20 test_losses.append(test_loss) print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + test_loss, correct, 20, + 100. * correct / 20)) # run test model test() From 6ee65d158575702b75410b65de437f727aead2e9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:39:04 +0000 Subject: [PATCH 399/546] down into cpp --- test/test_test_mnist.py | 273 ++++++---------------------------------- 1 file changed, 36 insertions(+), 237 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index c3613045682..9d3d09c0562 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -67,17 +67,39 @@ def forward(self, x): return F.log_softmax(x) -network = Net() +class MNIST(nn.Module): + + def __init__(self): + super(MNIST, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.bn1 = nn.BatchNorm2d(10) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.bn2 = nn.BatchNorm2d(20) + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = self.bn1(x) + x = F.relu(F.max_pool2d(self.conv2(x), 2)) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + +device = xm.xla_device() +# model = MNIST().to(device) +network = Net().to(device) optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum) - +# loss_fn = nn.NLLLoss() train_losses = [] train_counter = [] test_losses = [] test_counter = [i*20 for i in range(n_epochs + 1)] - def test(): network.eval() test_loss = 0 @@ -93,243 +115,20 @@ def test(): print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, 20, 100. * correct / 20)) + return test_loss # run test model -test() -for epoch in range(1, n_epochs + 1): - test() - - - -# import args_parse -# from torch_xla import runtime as xr - -# MODEL_OPTS = { -# '--ddp': { -# 'action': 'store_true', -# }, -# '--pjrt_distributed': { -# 'action': 'store_true', -# }, -# } - -# FLAGS = args_parse.parse_common_options( -# datadir='/tmp/mnist-data', -# batch_size=128, -# momentum=0.5, -# lr=0.01, -# target_accuracy=98.0, -# num_epochs=18, -# opts=MODEL_OPTS.items(), -# ) - -# import os -# import shutil -# import sys -# import numpy as np -# import torch -# import torch.nn as nn -# import torch.nn.functional as F -# import torch.optim as optim -# from torchvision import datasets, transforms -# import torch_xla -# import torch_xla.debug.metrics as met -# import torch_xla.distributed.parallel_loader as pl -# import torch_xla.utils.utils as xu -# import torch_xla.core.xla_model as xm -# import torch_xla.distributed.xla_multiprocessing as xmp -# import torch_xla.test.test_utils as test_utils - -# import torch.distributed as dist -# from torch.nn.parallel import DistributedDataParallel as DDP -# import torch_xla.distributed.xla_backend - - -# class MNIST(nn.Module): - -# def __init__(self): -# super(MNIST, self).__init__() -# self.conv1 = nn.Conv2d(1, 10, kernel_size=5) -# self.bn1 = nn.BatchNorm2d(10) -# self.conv2 = nn.Conv2d(10, 20, kernel_size=5) -# self.bn2 = nn.BatchNorm2d(20) -# self.fc1 = nn.Linear(320, 50) -# self.fc2 = nn.Linear(50, 10) - -# def forward(self, x): -# x = F.relu(F.max_pool2d(self.conv1(x), 2)) -# x = self.bn1(x) -# x = F.relu(F.max_pool2d(self.conv2(x), 2)) -# x = self.bn2(x) -# x = torch.flatten(x, 1) -# x = F.relu(self.fc1(x)) -# x = self.fc2(x) -# return F.log_softmax(x, dim=1) - - -# def _train_update(device, step, loss, tracker, epoch, writer): -# test_utils.print_training_update( -# device, -# step, -# loss.item(), -# tracker.rate(), -# tracker.global_rate(), -# epoch, -# summary_writer=writer) - +def test_mnist(): + if flags.ddp or flags.pjrt_distributed: + dist.init_process_group('xla', init_method='xla://') -# def train_mnist(flags, **kwargs): -# if flags.ddp or flags.pjrt_distributed: -# dist.init_process_group('xla', init_method='xla://') + torch.manual_seed(1) -# torch.manual_seed(1) - -# if flags.fake_data: -# train_loader = xu.SampleGenerator( -# data=(torch.zeros(flags.batch_size, 1, 28, -# 28), torch.zeros(flags.batch_size, -# dtype=torch.int64)), -# sample_count=60000 // flags.batch_size // xm.xrt_world_size()) -# test_loader = xu.SampleGenerator( -# data=(torch.zeros(flags.batch_size, 1, 28, -# 28), torch.zeros(flags.batch_size, -# dtype=torch.int64)), -# sample_count=10000 // flags.batch_size // xm.xrt_world_size()) -# else: -# train_dataset = datasets.MNIST( -# os.path.join(flags.datadir, str(xm.get_ordinal())), -# train=True, -# download=True, -# transform=transforms.Compose( -# [transforms.ToTensor(), -# transforms.Normalize((0.1307,), (0.3081,))])) -# test_dataset = datasets.MNIST( -# os.path.join(flags.datadir, str(xm.get_ordinal())), -# train=False, -# download=True, -# transform=transforms.Compose( -# [transforms.ToTensor(), -# transforms.Normalize((0.1307,), (0.3081,))])) -# train_sampler = None -# if xm.xrt_world_size() > 1: -# train_sampler = torch.utils.data.distributed.DistributedSampler( -# train_dataset, -# num_replicas=xm.xrt_world_size(), -# rank=xm.get_ordinal(), -# shuffle=True) -# train_loader = torch.utils.data.DataLoader( -# train_dataset, -# batch_size=flags.batch_size, -# sampler=train_sampler, -# drop_last=flags.drop_last, -# shuffle=False if train_sampler else True, -# num_workers=flags.num_workers) -# test_loader = torch.utils.data.DataLoader( -# test_dataset, -# batch_size=flags.batch_size, -# drop_last=flags.drop_last, -# shuffle=False, -# num_workers=flags.num_workers) - -# # Scale learning rate to num cores -# lr = flags.lr * xm.xrt_world_size() - -# device = xm.xla_device() -# model = MNIST().to(device) - -# # Initialization is nondeterministic with multiple threads in PjRt. -# # Synchronize model parameters across replicas manually. -# if xr.using_pjrt(): -# xm.broadcast_master_param(model) - -# if flags.ddp: -# model = DDP(model, gradient_as_bucket_view=True) -# writer = None -# if xm.is_master_ordinal(): -# writer = test_utils.get_summary_writer(flags.logdir) -# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum) -# loss_fn = nn.NLLLoss() - -# def train_loop_fn(loader, epoch): -# tracker = xm.RateTracker() -# model.train() -# for step, (data, target) in enumerate(loader): -# optimizer.zero_grad() -# output = model(data) -# loss = loss_fn(output, target) -# loss.backward() -# if flags.ddp: -# optimizer.step() -# else: -# xm.optimizer_step(optimizer) -# tracker.add(flags.batch_size) -# if step % flags.log_steps == 0: -# xm.add_step_closure( -# _train_update, -# args=(device, step, loss, tracker, epoch, writer), -# run_async=flags.async_closures) - -# def test_loop_fn(loader): -# total_samples = 0 -# correct = 0 -# model.eval() -# # print("loader: ", loader) -# # print("type loader: ", type(loader)) -# for data, target in loader: -# output = model(data) -# pred = output.max(1, keepdim=True)[1] -# correct += pred.eq(target.view_as(pred)).sum() -# total_samples += data.size()[0] - -# accuracy = 100.0 * correct.item() / total_samples -# accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) -# return accuracy - -# # train_device_loader = pl.MpDeviceLoader(train_loader, device) -# test_device_loader = pl.MpDeviceLoader(test_loader, device) -# accuracy, max_accuracy = 0.0, 0.0 -# for epoch in range(1, flags.num_epochs + 1): -# # xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) -# # train_loop_fn(train_device_loader, epoch) -# # xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) -# accuracy = test_loop_fn(test_device_loader) -# # xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy)) -# max_accuracy = max(accuracy, max_accuracy) -# # test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True) -# # if flags.metrics_debug: xm.master_print(met.metrics_report()) - -# # ### fori_loop -# # # torch.set_grad_enabled(False) -# # new_test_device_loader = pl.MpDeviceLoader(test_loader, device) -# # upper = torch.tensor([flags.num_epochs + 1], dtype=torch.int32, device=device) # flags.num_epochs + 1 -# # lower = torch.tensor([1], dtype=torch.int32, device=device) # 1 -# # init_val = torch.tensor([1], dtype=torch.int32, device=device) -# # # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader -# # # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) -# # def body_fun(test_device_loader): -# # accuracy = test_loop_fn(test_device_loader) -# # max_accuracy = max(accuracy, max_accuracy) -# # return max_accuracy - -# # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( -# # upper, lower, body_fun, init_val, new_test_device_loader) - - -# test_utils.close_summary_writer(writer) -# xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy)) -# return max_accuracy - - -# def _mp_fn(index, flags): -# torch.set_default_dtype(torch.float32) -# accuracy = train_mnist(flags) -# if flags.tidy and os.path.isdir(flags.datadir): -# shutil.rmtree(flags.datadir) -# if accuracy < flags.target_accuracy: -# print('Accuracy {} is below target {}'.format(accuracy, -# flags.target_accuracy)) -# sys.exit(21) + test() + # target fori_loop + for epoch in range(1, n_epochs + 1): + test() -# if __name__ == '__main__': -# xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores) +torch.set_default_dtype(torch.float32) +accuracy = test_mnist() From 72a8d073060f9152064d1a156a390fc510a3bd43 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:39:34 +0000 Subject: [PATCH 400/546] down into cpp --- test/test_test_mnist.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 9d3d09c0562..e5b18c08494 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -119,9 +119,6 @@ def test(): # run test model def test_mnist(): - if flags.ddp or flags.pjrt_distributed: - dist.init_process_group('xla', init_method='xla://') - torch.manual_seed(1) test() From 5734811d72c774e4b4f52ecff776a4f171b1331c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:40:03 +0000 Subject: [PATCH 401/546] down into cpp --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index e5b18c08494..894d1349858 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -127,5 +127,5 @@ def test_mnist(): test() -torch.set_default_dtype(torch.float32) +# torch.set_default_dtype(torch.float32) accuracy = test_mnist() From f244aab1a10618ce7de24f5297277f0df9f871f5 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:40:50 +0000 Subject: [PATCH 402/546] down into cpp --- test/test_test_mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 894d1349858..e482cd5ecdf 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -89,8 +89,8 @@ def forward(self, x): return F.log_softmax(x, dim=1) device = xm.xla_device() -# model = MNIST().to(device) -network = Net().to(device) +network = MNIST().to(device) +# network = Net().to(device) optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum) # loss_fn = nn.NLLLoss() From 85a7975f5bc8a896c7061199eedb5fb186d0e66a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 25 Apr 2024 00:02:35 +0000 Subject: [PATCH 403/546] down into cpp --- test/test_test_mnist.py | 324 ++++++++++++++++++++++++++++++---------- 1 file changed, 241 insertions(+), 83 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index e482cd5ecdf..df4cc32e687 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -1,5 +1,158 @@ -import torch -import torchvision +# import torch +# import torchvision + +# import os +# import shutil +# import sys +# import numpy as np +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# import torch.optim as optim +# from torchvision import datasets, transforms +# import torch_xla +# import torch_xla.debug.metrics as met +# import torch_xla.distributed.parallel_loader as pl +# import torch_xla.utils.utils as xu +# import torch_xla.core.xla_model as xm +# import torch_xla.distributed.xla_multiprocessing as xmp +# import torch_xla.test.test_utils as test_utils + +# import torch.distributed as dist +# from torch.nn.parallel import DistributedDataParallel as DDP +# import torch_xla.distributed.xla_backend + +# n_epochs = 3 +# batch_size_train = 8 # 64 +# batch_size_test = 10 # 1000 +# learning_rate = 0.01 +# momentum = 0.5 +# log_interval = 10 + +# random_seed = 1 +# torch.backends.cudnn.enabled = False +# torch.manual_seed(random_seed) + +# ### load data +# test_loader = xu.SampleGenerator( +# data=(torch.zeros(8, 1, 28,28), torch.zeros(8, dtype=torch.int64)), +# sample_count=1000 // 8 // xm.xrt_world_size()) + +# examples = enumerate(test_loader) +# batch_idx, (example_data, example_targets) = next(examples) + +# print("shape: ", example_data.shape) + +# ### build model +# import torch.nn as nn +# import torch.nn.functional as F +# import torch.optim as optim + +# class Net(nn.Module): +# def __init__(self): +# super(Net, self).__init__() +# self.conv1 = nn.Conv2d(1, 10, kernel_size=5) +# self.conv2 = nn.Conv2d(10, 20, kernel_size=5) +# self.conv2_drop = nn.Dropout2d() +# self.fc1 = nn.Linear(320, 50) +# self.fc2 = nn.Linear(50, 10) + +# def forward(self, x): +# x = F.relu(F.max_pool2d(self.conv1(x), 2)) +# x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) +# x = x.view(-1, 320) +# x = F.relu(self.fc1(x)) +# x = F.dropout(x, training=self.training) +# x = self.fc2(x) +# return F.log_softmax(x) + + +# class MNIST(nn.Module): + +# def __init__(self): +# super(MNIST, self).__init__() +# self.conv1 = nn.Conv2d(1, 10, kernel_size=5) +# self.bn1 = nn.BatchNorm2d(10) +# self.conv2 = nn.Conv2d(10, 20, kernel_size=5) +# self.bn2 = nn.BatchNorm2d(20) +# self.fc1 = nn.Linear(320, 50) +# self.fc2 = nn.Linear(50, 10) + +# def forward(self, x): +# x = F.relu(F.max_pool2d(self.conv1(x), 2)) +# x = self.bn1(x) +# x = F.relu(F.max_pool2d(self.conv2(x), 2)) +# x = self.bn2(x) +# x = torch.flatten(x, 1) +# x = F.relu(self.fc1(x)) +# x = self.fc2(x) +# return F.log_softmax(x, dim=1) + +# device = xm.xla_device() +# network = MNIST().to(device) +# # network = Net().to(device) +# optimizer = optim.SGD(network.parameters(), lr=learning_rate, +# momentum=momentum) +# # loss_fn = nn.NLLLoss() + +# train_losses = [] +# train_counter = [] +# test_losses = [] +# test_counter = [i*20 for i in range(n_epochs + 1)] + +# def test(): +# network.eval() +# test_loss = 0 +# correct = 0 +# with torch.no_grad(): +# for data, target in test_loader: +# output = network(data) +# test_loss += F.nll_loss(output, target, size_average=False).item() +# pred = output.data.max(1, keepdim=True)[1] +# correct += pred.eq(target.data.view_as(pred)).sum() +# test_loss /= 20 +# test_losses.append(test_loss) +# print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( +# test_loss, correct, 20, +# 100. * correct / 20)) +# return test_loss + +# # run test model +# def test_mnist(): +# torch.manual_seed(1) + +# test() +# # target fori_loop +# for epoch in range(1, n_epochs + 1): +# test() + + +# # torch.set_default_dtype(torch.float32) +# accuracy = test_mnist() + + + +import args_parse +from torch_xla import runtime as xr + +# MODEL_OPTS = { +# '--ddp': { +# 'action': 'store_true', +# }, +# '--pjrt_distributed': { +# 'action': 'store_true', +# }, +# } + +FLAGS = args_parse.parse_common_options( + datadir='/tmp/mnist-data', + batch_size=128, + momentum=0.5, + lr=0.01, + target_accuracy=98.0, + num_epochs=18, + opts=MODEL_OPTS.items(), +) import os import shutil @@ -22,50 +175,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP import torch_xla.distributed.xla_backend -n_epochs = 3 -batch_size_train = 8 # 64 -batch_size_test = 10 # 1000 -learning_rate = 0.01 -momentum = 0.5 -log_interval = 10 - -random_seed = 1 -torch.backends.cudnn.enabled = False -torch.manual_seed(random_seed) - -### load data -test_loader = xu.SampleGenerator( - data=(torch.zeros(8, 1, 28,28), torch.zeros(8, dtype=torch.int64)), - sample_count=1000 // 8 // xm.xrt_world_size()) - -examples = enumerate(test_loader) -batch_idx, (example_data, example_targets) = next(examples) - -print("shape: ", example_data.shape) - -### build model -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim - -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 10, kernel_size=5) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.conv2_drop = nn.Dropout2d() - self.fc1 = nn.Linear(320, 50) - self.fc2 = nn.Linear(50, 10) - - def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), 2)) - x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) - x = x.view(-1, 320) - x = F.relu(self.fc1(x)) - x = F.dropout(x, training=self.training) - x = self.fc2(x) - return F.log_softmax(x) - class MNIST(nn.Module): @@ -88,44 +197,93 @@ def forward(self, x): x = self.fc2(x) return F.log_softmax(x, dim=1) -device = xm.xla_device() -network = MNIST().to(device) -# network = Net().to(device) -optimizer = optim.SGD(network.parameters(), lr=learning_rate, - momentum=momentum) -# loss_fn = nn.NLLLoss() - -train_losses = [] -train_counter = [] -test_losses = [] -test_counter = [i*20 for i in range(n_epochs + 1)] - -def test(): - network.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - output = network(data) - test_loss += F.nll_loss(output, target, size_average=False).item() - pred = output.data.max(1, keepdim=True)[1] - correct += pred.eq(target.data.view_as(pred)).sum() - test_loss /= 20 - test_losses.append(test_loss) - print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, 20, - 100. * correct / 20)) - return test_loss - -# run test model -def test_mnist(): + +def train_mnist(flags, **kwargs): torch.manual_seed(1) - test() - # target fori_loop - for epoch in range(1, n_epochs + 1): - test() + test_loader = xu.SampleGenerator( + data=(torch.zeros(flags.batch_size, 1, 28, 28), torch.zeros(flags.batch_size, dtype=torch.int64)), + sample_count=10000 // flags.batch_size // xm.xrt_world_size()) + + # Scale learning rate to num cores + lr = flags.lr * xm.xrt_world_size() + device = xm.xla_device() + model = MNIST().to(device) + + # Initialization is nondeterministic with multiple threads in PjRt. + # Synchronize model parameters across replicas manually. + if xr.using_pjrt(): + xm.broadcast_master_param(model) + + writer = None + if xm.is_master_ordinal(): + writer = test_utils.get_summary_writer(flags.logdir) + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum) + loss_fn = nn.NLLLoss() + + def test_loop_fn(loader): + total_samples = 0 + correct = 0 + model.eval() + # print("loader: ", loader) + # print("type loader: ", type(loader)) + for data, target in loader: + output = model(data) + pred = output.max(1, keepdim=True)[1] + correct += pred.eq(target.view_as(pred)).sum() + total_samples += data.size()[0] + + accuracy = 100.0 * correct.item() / total_samples + accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) + return accuracy + + # train_device_loader = pl.MpDeviceLoader(train_loader, device) + test_device_loader = pl.MpDeviceLoader(test_loader, device) + accuracy, max_accuracy = 0.0, 0.0 + for epoch in range(1, flags.num_epochs + 1): + # xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) + # train_loop_fn(train_device_loader, epoch) + # xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) + accuracy = test_loop_fn(test_device_loader) + # xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy)) + max_accuracy = max(accuracy, max_accuracy) + # test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True) + # if flags.metrics_debug: xm.master_print(met.metrics_report()) + + # ### fori_loop + # # torch.set_grad_enabled(False) + # new_test_device_loader = pl.MpDeviceLoader(test_loader, device) + # upper = torch.tensor([flags.num_epochs + 1], dtype=torch.int32, device=device) # flags.num_epochs + 1 + # lower = torch.tensor([1], dtype=torch.int32, device=device) # 1 + # init_val = torch.tensor([1], dtype=torch.int32, device=device) + # # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader + # # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + # def body_fun(test_device_loader): + # accuracy = test_loop_fn(test_device_loader) + # max_accuracy = max(accuracy, max_accuracy) + # return max_accuracy + + # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( + # upper, lower, body_fun, init_val, new_test_device_loader) + + test_utils.close_summary_writer(writer) + xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy)) + return max_accuracy + + +# def _mp_fn(index, flags): +def main_fun(flags): + torch.set_default_dtype(torch.float32) + accuracy = train_mnist(flags) + if flags.tidy and os.path.isdir(flags.datadir): + shutil.rmtree(flags.datadir) + if accuracy < flags.target_accuracy: + print('Accuracy {} is below target {}'.format(accuracy, + flags.target_accuracy)) + sys.exit(21) -# torch.set_default_dtype(torch.float32) -accuracy = test_mnist() +if __name__ == '__main__': +# xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores) +# _mp_fn() + main_fun(FLAGS) From c7a2d9a90261751b4a7e7784df5bda2aa77efbf6 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 25 Apr 2024 00:04:02 +0000 Subject: [PATCH 404/546] down into cpp --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index df4cc32e687..dd6adce4515 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -151,7 +151,7 @@ lr=0.01, target_accuracy=98.0, num_epochs=18, - opts=MODEL_OPTS.items(), + # opts=MODEL_OPTS.items(), ) import os From 6afbc66a3a2c008b1d73284eec1648229b0904e4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 25 Apr 2024 00:04:38 +0000 Subject: [PATCH 405/546] down into cpp --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index dd6adce4515..92992f8e73b 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -245,7 +245,7 @@ def test_loop_fn(loader): # train_loop_fn(train_device_loader, epoch) # xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) accuracy = test_loop_fn(test_device_loader) - # xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy)) + xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy)) max_accuracy = max(accuracy, max_accuracy) # test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True) # if flags.metrics_debug: xm.master_print(met.metrics_report()) From 6c71661ae1d7efb2e4082085281ccbb2ecb19948 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 25 Apr 2024 00:12:50 +0000 Subject: [PATCH 406/546] down into cpp --- test/test_test_mnist.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 92992f8e73b..e7348b274c1 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -240,6 +240,7 @@ def test_loop_fn(loader): # train_device_loader = pl.MpDeviceLoader(train_loader, device) test_device_loader = pl.MpDeviceLoader(test_loader, device) accuracy, max_accuracy = 0.0, 0.0 + for epoch in range(1, flags.num_epochs + 1): # xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) # train_loop_fn(train_device_loader, epoch) @@ -250,21 +251,26 @@ def test_loop_fn(loader): # test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True) # if flags.metrics_debug: xm.master_print(met.metrics_report()) - # ### fori_loop - # # torch.set_grad_enabled(False) + ### fori_loop + # torch.set_grad_enabled(False) # new_test_device_loader = pl.MpDeviceLoader(test_loader, device) - # upper = torch.tensor([flags.num_epochs + 1], dtype=torch.int32, device=device) # flags.num_epochs + 1 - # lower = torch.tensor([1], dtype=torch.int32, device=device) # 1 - # init_val = torch.tensor([1], dtype=torch.int32, device=device) - # # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader - # # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - # def body_fun(test_device_loader): - # accuracy = test_loop_fn(test_device_loader) - # max_accuracy = max(accuracy, max_accuracy) - # return max_accuracy - - # upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( - # upper, lower, body_fun, init_val, new_test_device_loader) + upper = torch.tensor([flags.num_epochs + 1], dtype=torch.int32, device=device) # flags.num_epochs + 1 + lower = torch.tensor([1], dtype=torch.int32, device=device) # 1 + init_val = torch.tensor([1], dtype=torch.int32, device=device) + # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader + # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + def body_fun(test_device_loader): + res1 = torch.tensor([2], dtype=torch.int32, device=device) + res2 = torch.tensor([2], dtype=torch.int32, device=device) + res3 = res1 + res2 + return res3 +# def body_fun(test_device_loader): +# accuracy = test_loop_fn(test_device_loader) +# max_accuracy = max(accuracy, max_accuracy) +# return max_accuracy + + upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( + upper, lower, body_fun, init_val, test_device_loader) test_utils.close_summary_writer(writer) xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy)) From b492c01c0026951cba1d302498cbba30acbc6c4f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 25 Apr 2024 00:13:33 +0000 Subject: [PATCH 407/546] down into cpp --- test/test_test_mnist.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index e7348b274c1..3918fbdfaeb 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -175,6 +175,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP import torch_xla.distributed.xla_backend +import torch_xla.experimental.fori_loop +from torch_xla.experimental.fori_loop import fori_loop + class MNIST(nn.Module): From 8dc49d5203a63127120b6c7c9a496e6df7cfdeb7 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 25 Apr 2024 00:21:40 +0000 Subject: [PATCH 408/546] down into cpp --- test/test_test_mnist.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 3918fbdfaeb..b608298bec3 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -224,13 +224,14 @@ def train_mnist(flags, **kwargs): optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum) loss_fn = nn.NLLLoss() - def test_loop_fn(loader): + def test_loop_fn(): # loader): total_samples = 0 correct = 0 model.eval() # print("loader: ", loader) # print("type loader: ", type(loader)) - for data, target in loader: + # for data, target in loader: + for data, target in test_loader: output = model(data) pred = output.max(1, keepdim=True)[1] correct += pred.eq(target.view_as(pred)).sum() @@ -262,7 +263,7 @@ def test_loop_fn(loader): init_val = torch.tensor([1], dtype=torch.int32, device=device) # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - def body_fun(test_device_loader): + def body_fun(): res1 = torch.tensor([2], dtype=torch.int32, device=device) res2 = torch.tensor([2], dtype=torch.int32, device=device) res3 = res1 + res2 @@ -273,7 +274,7 @@ def body_fun(test_device_loader): # return max_accuracy upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( - upper, lower, body_fun, init_val, test_device_loader) + upper, lower, body_fun, ()) test_utils.close_summary_writer(writer) xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy)) From 32c0e1a5c3df6863622e172f878ca063e1a81ab5 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 25 Apr 2024 00:24:05 +0000 Subject: [PATCH 409/546] down into cpp --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index b608298bec3..80b21530e93 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -249,7 +249,7 @@ def test_loop_fn(): # loader): # xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) # train_loop_fn(train_device_loader, epoch) # xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) - accuracy = test_loop_fn(test_device_loader) + accuracy = test_loop_fn() # test_device_loader) xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy)) max_accuracy = max(accuracy, max_accuracy) # test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True) From b2b14a3ae913c65e0d242945308e8312c75dfd1d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 00:45:07 +0000 Subject: [PATCH 410/546] format --- test/test_test_mnist.py | 537 ++++++++++++++++++++++------------------ 1 file changed, 290 insertions(+), 247 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 80b21530e93..7da754eeba6 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -1,159 +1,5 @@ -# import torch -# import torchvision - -# import os -# import shutil -# import sys -# import numpy as np -# import torch -# import torch.nn as nn -# import torch.nn.functional as F -# import torch.optim as optim -# from torchvision import datasets, transforms -# import torch_xla -# import torch_xla.debug.metrics as met -# import torch_xla.distributed.parallel_loader as pl -# import torch_xla.utils.utils as xu -# import torch_xla.core.xla_model as xm -# import torch_xla.distributed.xla_multiprocessing as xmp -# import torch_xla.test.test_utils as test_utils - -# import torch.distributed as dist -# from torch.nn.parallel import DistributedDataParallel as DDP -# import torch_xla.distributed.xla_backend - -# n_epochs = 3 -# batch_size_train = 8 # 64 -# batch_size_test = 10 # 1000 -# learning_rate = 0.01 -# momentum = 0.5 -# log_interval = 10 - -# random_seed = 1 -# torch.backends.cudnn.enabled = False -# torch.manual_seed(random_seed) - -# ### load data -# test_loader = xu.SampleGenerator( -# data=(torch.zeros(8, 1, 28,28), torch.zeros(8, dtype=torch.int64)), -# sample_count=1000 // 8 // xm.xrt_world_size()) - -# examples = enumerate(test_loader) -# batch_idx, (example_data, example_targets) = next(examples) - -# print("shape: ", example_data.shape) - -# ### build model -# import torch.nn as nn -# import torch.nn.functional as F -# import torch.optim as optim - -# class Net(nn.Module): -# def __init__(self): -# super(Net, self).__init__() -# self.conv1 = nn.Conv2d(1, 10, kernel_size=5) -# self.conv2 = nn.Conv2d(10, 20, kernel_size=5) -# self.conv2_drop = nn.Dropout2d() -# self.fc1 = nn.Linear(320, 50) -# self.fc2 = nn.Linear(50, 10) - -# def forward(self, x): -# x = F.relu(F.max_pool2d(self.conv1(x), 2)) -# x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) -# x = x.view(-1, 320) -# x = F.relu(self.fc1(x)) -# x = F.dropout(x, training=self.training) -# x = self.fc2(x) -# return F.log_softmax(x) - - -# class MNIST(nn.Module): - -# def __init__(self): -# super(MNIST, self).__init__() -# self.conv1 = nn.Conv2d(1, 10, kernel_size=5) -# self.bn1 = nn.BatchNorm2d(10) -# self.conv2 = nn.Conv2d(10, 20, kernel_size=5) -# self.bn2 = nn.BatchNorm2d(20) -# self.fc1 = nn.Linear(320, 50) -# self.fc2 = nn.Linear(50, 10) - -# def forward(self, x): -# x = F.relu(F.max_pool2d(self.conv1(x), 2)) -# x = self.bn1(x) -# x = F.relu(F.max_pool2d(self.conv2(x), 2)) -# x = self.bn2(x) -# x = torch.flatten(x, 1) -# x = F.relu(self.fc1(x)) -# x = self.fc2(x) -# return F.log_softmax(x, dim=1) - -# device = xm.xla_device() -# network = MNIST().to(device) -# # network = Net().to(device) -# optimizer = optim.SGD(network.parameters(), lr=learning_rate, -# momentum=momentum) -# # loss_fn = nn.NLLLoss() - -# train_losses = [] -# train_counter = [] -# test_losses = [] -# test_counter = [i*20 for i in range(n_epochs + 1)] - -# def test(): -# network.eval() -# test_loss = 0 -# correct = 0 -# with torch.no_grad(): -# for data, target in test_loader: -# output = network(data) -# test_loss += F.nll_loss(output, target, size_average=False).item() -# pred = output.data.max(1, keepdim=True)[1] -# correct += pred.eq(target.data.view_as(pred)).sum() -# test_loss /= 20 -# test_losses.append(test_loss) -# print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( -# test_loss, correct, 20, -# 100. * correct / 20)) -# return test_loss - -# # run test model -# def test_mnist(): -# torch.manual_seed(1) - -# test() -# # target fori_loop -# for epoch in range(1, n_epochs + 1): -# test() - - -# # torch.set_default_dtype(torch.float32) -# accuracy = test_mnist() - - - -import args_parse -from torch_xla import runtime as xr - -# MODEL_OPTS = { -# '--ddp': { -# 'action': 'store_true', -# }, -# '--pjrt_distributed': { -# 'action': 'store_true', -# }, -# } - -FLAGS = args_parse.parse_common_options( - datadir='/tmp/mnist-data', - batch_size=128, - momentum=0.5, - lr=0.01, - target_accuracy=98.0, - num_epochs=18, - # opts=MODEL_OPTS.items(), -) - +import torch +import torchvision import os import shutil import sys @@ -170,13 +16,75 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.test.test_utils as test_utils - import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP import torch_xla.distributed.xla_backend -import torch_xla.experimental.fori_loop -from torch_xla.experimental.fori_loop import fori_loop +n_epochs = 3 +batch_size_train = 8 # 64 +batch_size_test = 10 # 1000 +learning_rate = 0.01 +momentum = 0.5 +log_interval = 10 +random_seed = 1 +torch.backends.cudnn.enabled = False +torch.manual_seed(random_seed) + +### load data +test_loader = xu.SampleGenerator( + data=(torch.zeros(8, 1, 28,28), torch.zeros(8, dtype=torch.int64)), + sample_count=1000 // 8 // xm.xrt_world_size()) + +examples = enumerate(test_loader) +batch_idx, (example_data, example_targets) = next(examples) + +print("shape: ", example_data.shape) + +### build model +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +class SimpleWithLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def forward(self, upper, lower, one_value, x, input_value, output_value): + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value_real = self.linear(input_value) + weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone( + ), output_value_real + + return while_loop( + cond_fn, body_fn, + (upper, lower, one_value, x, input_value, output_value)) + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x) class MNIST(nn.Module): @@ -200,100 +108,235 @@ def forward(self, x): x = self.fc2(x) return F.log_softmax(x, dim=1) +device = xm.xla_device() +network = MNIST().to(device) +# network = Net().to(device) +optimizer = optim.SGD(network.parameters(), lr=learning_rate, + momentum=momentum) +# loss_fn = nn.NLLLoss() + +train_losses = [] +train_counter = [] +test_losses = [] +test_counter = [i*20 for i in range(n_epochs + 1)] + +def test(): + network.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + output = network(data) + test_loss += F.nll_loss(output, target, size_average=False).item() + pred = output.data.max(1, keepdim=True)[1] + correct += pred.eq(target.data.view_as(pred)).sum() + test_loss /= 20 + test_losses.append(test_loss) + print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, 20, + 100. * correct / 20)) + return test_loss + +def new_test(): + simple_with_linear = SimpleWithLinear() + upper = torch.tensor([52], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + weight_0 = simple_with_linear.linear.weight + bias_0 = simple_with_linear.linear.bias -def train_mnist(flags, **kwargs): + aaa = { + "simple_with_linear": + (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, + output_value)) + } + + upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear( + upper, lower, one_value, init_val, l_in_0, output_value) + print("finish new_test") + + +# run test model +def test_mnist(): torch.manual_seed(1) - test_loader = xu.SampleGenerator( - data=(torch.zeros(flags.batch_size, 1, 28, 28), torch.zeros(flags.batch_size, dtype=torch.int64)), - sample_count=10000 // flags.batch_size // xm.xrt_world_size()) - - # Scale learning rate to num cores - lr = flags.lr * xm.xrt_world_size() - device = xm.xla_device() - model = MNIST().to(device) - - # Initialization is nondeterministic with multiple threads in PjRt. - # Synchronize model parameters across replicas manually. - if xr.using_pjrt(): - xm.broadcast_master_param(model) - - writer = None - if xm.is_master_ordinal(): - writer = test_utils.get_summary_writer(flags.logdir) - optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum) - loss_fn = nn.NLLLoss() - - def test_loop_fn(): # loader): - total_samples = 0 - correct = 0 - model.eval() - # print("loader: ", loader) - # print("type loader: ", type(loader)) - # for data, target in loader: - for data, target in test_loader: - output = model(data) - pred = output.max(1, keepdim=True)[1] - correct += pred.eq(target.view_as(pred)).sum() - total_samples += data.size()[0] - - accuracy = 100.0 * correct.item() / total_samples - accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) - return accuracy - - # train_device_loader = pl.MpDeviceLoader(train_loader, device) - test_device_loader = pl.MpDeviceLoader(test_loader, device) - accuracy, max_accuracy = 0.0, 0.0 - - for epoch in range(1, flags.num_epochs + 1): - # xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) - # train_loop_fn(train_device_loader, epoch) - # xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) - accuracy = test_loop_fn() # test_device_loader) - xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy)) - max_accuracy = max(accuracy, max_accuracy) - # test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True) - # if flags.metrics_debug: xm.master_print(met.metrics_report()) - - ### fori_loop - # torch.set_grad_enabled(False) - # new_test_device_loader = pl.MpDeviceLoader(test_loader, device) - upper = torch.tensor([flags.num_epochs + 1], dtype=torch.int32, device=device) # flags.num_epochs + 1 - lower = torch.tensor([1], dtype=torch.int32, device=device) # 1 - init_val = torch.tensor([1], dtype=torch.int32, device=device) - # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader - # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - def body_fun(): - res1 = torch.tensor([2], dtype=torch.int32, device=device) - res2 = torch.tensor([2], dtype=torch.int32, device=device) - res3 = res1 + res2 - return res3 -# def body_fun(test_device_loader): -# accuracy = test_loop_fn(test_device_loader) -# max_accuracy = max(accuracy, max_accuracy) -# return max_accuracy + test() + # target fori_loop + for epoch in range(1, n_epochs + 1): + test() - upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( - upper, lower, body_fun, ()) - test_utils.close_summary_writer(writer) - xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy)) - return max_accuracy +# torch.set_default_dtype(torch.float32) +# accuracy = test_mnist() +# ///////////////////////////////////////////////////////////////////////////////////////////////////////// + +# import args_parse +# from torch_xla import runtime as xr + +# # MODEL_OPTS = { +# # '--ddp': { +# # 'action': 'store_true', +# # }, +# # '--pjrt_distributed': { +# # 'action': 'store_true', +# # }, +# # } + +# FLAGS = args_parse.parse_common_options( +# datadir='/tmp/mnist-data', +# batch_size=128, +# momentum=0.5, +# lr=0.01, +# target_accuracy=98.0, +# num_epochs=18, +# # opts=MODEL_OPTS.items(), +# ) -# def _mp_fn(index, flags): -def main_fun(flags): - torch.set_default_dtype(torch.float32) - accuracy = train_mnist(flags) - if flags.tidy and os.path.isdir(flags.datadir): - shutil.rmtree(flags.datadir) - if accuracy < flags.target_accuracy: - print('Accuracy {} is below target {}'.format(accuracy, - flags.target_accuracy)) - sys.exit(21) +# import os +# import shutil +# import sys +# import numpy as np +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# import torch.optim as optim +# from torchvision import datasets, transforms +# import torch_xla +# import torch_xla.debug.metrics as met +# import torch_xla.distributed.parallel_loader as pl +# import torch_xla.utils.utils as xu +# import torch_xla.core.xla_model as xm +# import torch_xla.distributed.xla_multiprocessing as xmp +# import torch_xla.test.test_utils as test_utils +# import torch.distributed as dist +# from torch.nn.parallel import DistributedDataParallel as DDP +# import torch_xla.distributed.xla_backend + +# import torch_xla.experimental.fori_loop +# from torch_xla.experimental.fori_loop import fori_loop + + +# class MNIST(nn.Module): -if __name__ == '__main__': -# xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores) -# _mp_fn() - main_fun(FLAGS) +# def __init__(self): +# super(MNIST, self).__init__() +# self.conv1 = nn.Conv2d(1, 10, kernel_size=5) +# self.bn1 = nn.BatchNorm2d(10) +# self.conv2 = nn.Conv2d(10, 20, kernel_size=5) +# self.bn2 = nn.BatchNorm2d(20) +# self.fc1 = nn.Linear(320, 50) +# self.fc2 = nn.Linear(50, 10) + +# def forward(self, x): +# x = F.relu(F.max_pool2d(self.conv1(x), 2)) +# x = self.bn1(x) +# x = F.relu(F.max_pool2d(self.conv2(x), 2)) +# x = self.bn2(x) +# x = torch.flatten(x, 1) +# x = F.relu(self.fc1(x)) +# x = self.fc2(x) +# return F.log_softmax(x, dim=1) + + +# def train_mnist(flags, **kwargs): +# torch.manual_seed(1) + +# test_loader = xu.SampleGenerator( +# data=(torch.zeros(flags.batch_size, 1, 28, 28), torch.zeros(flags.batch_size, dtype=torch.int64)), +# sample_count=10000 // flags.batch_size // xm.xrt_world_size()) + +# # Scale learning rate to num cores +# lr = flags.lr * xm.xrt_world_size() +# device = xm.xla_device() +# model = MNIST().to(device) + +# # Initialization is nondeterministic with multiple threads in PjRt. +# # Synchronize model parameters across replicas manually. +# if xr.using_pjrt(): +# xm.broadcast_master_param(model) + +# writer = None +# if xm.is_master_ordinal(): +# writer = test_utils.get_summary_writer(flags.logdir) +# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum) +# loss_fn = nn.NLLLoss() + +# def test_loop_fn(): # loader): +# total_samples = 0 +# correct = 0 +# model.eval() +# # print("loader: ", loader) +# # print("type loader: ", type(loader)) +# # for data, target in loader: +# for data, target in test_loader: +# output = model(data) +# pred = output.max(1, keepdim=True)[1] +# correct += pred.eq(target.view_as(pred)).sum() +# total_samples += data.size()[0] + +# accuracy = 100.0 * correct.item() / total_samples +# accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) +# return accuracy + +# # train_device_loader = pl.MpDeviceLoader(train_loader, device) +# test_device_loader = pl.MpDeviceLoader(test_loader, device) +# accuracy, max_accuracy = 0.0, 0.0 + +# for epoch in range(1, flags.num_epochs + 1): +# # xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) +# # train_loop_fn(train_device_loader, epoch) +# # xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) +# accuracy = test_loop_fn() # test_device_loader) +# xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy)) +# max_accuracy = max(accuracy, max_accuracy) +# # test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True) +# # if flags.metrics_debug: xm.master_print(met.metrics_report()) + +# ### fori_loop +# # torch.set_grad_enabled(False) +# # new_test_device_loader = pl.MpDeviceLoader(test_loader, device) +# upper = torch.tensor([flags.num_epochs + 1], dtype=torch.int32, device=device) # flags.num_epochs + 1 +# lower = torch.tensor([1], dtype=torch.int32, device=device) # 1 +# init_val = torch.tensor([1], dtype=torch.int32, device=device) +# # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader +# # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) +# def body_fun(): +# res1 = torch.tensor([2], dtype=torch.int32, device=device) +# res2 = torch.tensor([2], dtype=torch.int32, device=device) +# res3 = res1 + res2 +# return res3 +# # def body_fun(test_device_loader): +# # accuracy = test_loop_fn(test_device_loader) +# # max_accuracy = max(accuracy, max_accuracy) +# # return max_accuracy + +# upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop( +# upper, lower, body_fun, ()) + +# test_utils.close_summary_writer(writer) +# xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy)) +# return max_accuracy + + +# # def _mp_fn(index, flags): +# def main_fun(flags): +# torch.set_default_dtype(torch.float32) +# accuracy = train_mnist(flags) +# if flags.tidy and os.path.isdir(flags.datadir): +# shutil.rmtree(flags.datadir) +# if accuracy < flags.target_accuracy: +# print('Accuracy {} is below target {}'.format(accuracy, +# flags.target_accuracy)) +# sys.exit(21) + + +# if __name__ == '__main__': +# # xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores) +# # _mp_fn() +# main_fun(FLAGS) From 187330a31977df7a85807daa53c10923cc501e21 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 00:48:15 +0000 Subject: [PATCH 411/546] format --- test/test_test_mnist.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 7da754eeba6..0be53a25eb3 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -164,11 +164,16 @@ def new_test(): def test_mnist(): torch.manual_seed(1) + print("before test_mnist") test() # target fori_loop for epoch in range(1, n_epochs + 1): test() + print("after test_mnist") + +if __name__ == '__main__': + test_mnist() # torch.set_default_dtype(torch.float32) # accuracy = test_mnist() From 53a1856b3c2e280ffbb0fdfff71c0c532770e04f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 00:49:19 +0000 Subject: [PATCH 412/546] format --- test/test_test_mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 0be53a25eb3..dec3b70673b 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -165,10 +165,10 @@ def test_mnist(): torch.manual_seed(1) print("before test_mnist") - test() + new_test() # test() # target fori_loop for epoch in range(1, n_epochs + 1): - test() + new_test() # test() print("after test_mnist") From fac3f1ab9a790de7de58d62deedfe7d450a98e98 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 00:50:42 +0000 Subject: [PATCH 413/546] format --- test/test_test_mnist.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index dec3b70673b..1e16f92c4b6 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -19,6 +19,8 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP import torch_xla.distributed.xla_backend +import torch_xla.experimental.fori_loop +from torch._higher_order_ops.while_loop import while_loop n_epochs = 3 batch_size_train = 8 # 64 @@ -149,11 +151,11 @@ def new_test(): weight_0 = simple_with_linear.linear.weight bias_0 = simple_with_linear.linear.bias - aaa = { - "simple_with_linear": - (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, - output_value)) - } + # aaa = { + # "simple_with_linear": + # (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, + # output_value)) + # } upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear( upper, lower, one_value, init_val, l_in_0, output_value) From 82dedf7c827c2295cd9a00fc5e0e52137c2c177f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 00:53:49 +0000 Subject: [PATCH 414/546] format --- torch_xla/csrc/init_python_bindings.cpp | 28 ++++++++++++------------- torch_xla/experimental/fori_loop.py | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 2112222ac94..e20e28fbb8f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -929,20 +929,20 @@ class PyLoweringContext { } } - // // hard-code modify body xlacomputation input arguments with unusedarguments - // // for xla::while requriement - // if (GetNameString() == "bodyctx") { - // xla::XlaBuilder* local_builder = lowering_ctx.builder(); - // // TODO(@manfei): treat hard code parameter_idx value - // int64_t parameter_idx = 7; - // for (auto& additional_input_tensor : additional_inputs_list) { - // XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); - // xla::Shape shape = xtensor->shape().get(); - // xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - // "UnusedArgumentsPlaceholder"); - // parameter_idx += 1; - // } - // } + // hard-code modify body xlacomputation input arguments with unusedarguments + // for xla::while requriement + if (GetNameString() == "bodyctx") { + xla::XlaBuilder* local_builder = lowering_ctx.builder(); + // TODO(@manfei): treat hard code parameter_idx value + int64_t parameter_idx = 7; + for (auto& additional_input_tensor : additional_inputs_list) { + XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); + xla::Shape shape = xtensor->shape().get(); + xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + "UnusedArgumentsPlaceholder"); + parameter_idx += 1; + } + } // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3bab0bbc665..1f4612ab5d3 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -131,8 +131,8 @@ def new_body_fn(*carried_inputs): # res.append(body_fn.weight) # return res return newres - return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) - # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) + # return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, additional_inputs) + return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) print("$$$ additional_inputs: ", additional_inputs) # return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs) From 289a7b0aad4445fd29df7bc7bb3721f4984ae89c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 00:57:53 +0000 Subject: [PATCH 415/546] format --- test/test_test_mnist.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 1e16f92c4b6..b6739d4b285 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -48,27 +48,26 @@ import torch.optim as optim class SimpleWithLinear(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) - - def forward(self, upper, lower, one_value, x, input_value, output_value): - - def cond_fn(upper, lower, one_value, x, input_value, output_value): - return lower[0] < upper[0] - - def body_fn(upper, lower, one_value, x, input_value, output_value): - new_lower = torch.add(one_value, lower) - output_value_real = self.linear(input_value) - weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( - one_value, x), input_value.clone( - ), output_value_real - - return while_loop( - cond_fn, body_fn, - (upper, lower, one_value, x, input_value, output_value)) + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def forward(self, upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value_real = self.linear(input_value) + weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone( + ), output_value_real, weight.clone(), bias.clone() + + return while_loop( + cond_fn, body_fn, + (upper, lower, one_value, x, input_value, output_value)) class Net(nn.Module): def __init__(self): From d2fe90795572c8b5f822fb668b55490b9f94754f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 00:59:04 +0000 Subject: [PATCH 416/546] format --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 1f4612ab5d3..7b40e91e12f 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -189,8 +189,8 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): additional_inputs_list_body = [] # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body - # body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) - body_ctx.buildforiloop(list(body_result), ()) + body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) + # body_ctx.buildforiloop(list(body_result), ()) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) From 1bd65bf1163f7a3bbdee4081540aeca60c4de347 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 01:33:35 +0000 Subject: [PATCH 417/546] test --- test/test_test_mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index b6739d4b285..fabfb5337a3 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -51,6 +51,7 @@ class SimpleWithLinear(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + self.linear2 = torch.nn.Linear(50, 10).to(xm.xla_device()) def forward(self, upper, lower, one_value, x, input_value, output_value): def cond_fn(upper, lower, one_value, x, input_value, output_value): From 08f06f6f1c6d2e509252657483f156bccf7e2871 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 01:36:02 +0000 Subject: [PATCH 418/546] test --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 7b40e91e12f..dba107e774e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -194,9 +194,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - # body_hlo_print = xb.get_computation_hlo(body_computation) - # print("body computation: !!!!!!!!!") - # print(body_hlo_print) + body_hlo_print = xb.get_computation_hlo(body_computation) + print("body computation: !!!!!!!!!") + print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs From e7fa8f8e5abac41e0e5d8c05179b2f2b618d3e54 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 01:39:28 +0000 Subject: [PATCH 419/546] test --- test/test_test_mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index fabfb5337a3..2c4b124e17a 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -51,7 +51,7 @@ class SimpleWithLinear(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) - self.linear2 = torch.nn.Linear(50, 10).to(xm.xla_device()) + self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) def forward(self, upper, lower, one_value, x, input_value, output_value): def cond_fn(upper, lower, one_value, x, input_value, output_value): @@ -60,6 +60,7 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) + output_value_real2 = self.linear2(output_value_real) weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( From 74bfd5fc93e4cc2b2ae17ca68bad957494fb349f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 01:42:53 +0000 Subject: [PATCH 420/546] test --- test/test_test_mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 2c4b124e17a..1f02100f50f 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -60,12 +60,12 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) - output_value_real2 = self.linear2(output_value_real) + output_value_real_final = self.linear2(output_value_real) weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( one_value, x), input_value.clone( - ), output_value_real, weight.clone(), bias.clone() + ), output_value_real_final, weight.clone(), bias.clone() return while_loop( cond_fn, body_fn, From 4fcbceb6b02087e40698df3bf0f6ffb91beec942 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 01:53:32 +0000 Subject: [PATCH 421/546] test --- torch_xla/csrc/init_python_bindings.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index e20e28fbb8f..d54340bf598 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -934,7 +934,8 @@ class PyLoweringContext { if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value - int64_t parameter_idx = 7; + int64_t parameter_idx = local_builder.parameter_numbers_; // GetProgramShape(); + // int64_t parameter_idx = 7; for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); xla::Shape shape = xtensor->shape().get(); From 4525f2e86e30e61078331e37d5b30d8b4803caa7 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 01:54:49 +0000 Subject: [PATCH 422/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index d54340bf598..1d6170c7425 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -934,7 +934,7 @@ class PyLoweringContext { if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value - int64_t parameter_idx = local_builder.parameter_numbers_; // GetProgramShape(); + int64_t parameter_idx = local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); From 0f7faedde371859d5d4b17926b488de0ec300385 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 01:57:44 +0000 Subject: [PATCH 423/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1d6170c7425..7713f693da8 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -934,7 +934,7 @@ class PyLoweringContext { if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value - int64_t parameter_idx = local_builder->parameter_numbers_; // GetProgramShape(); + int64_t parameter_idx = local_builder.GetProgramShape(/*root_id=*/0); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); From f9fdb2a988a0b751b01aacc43b78f57eef0d0c6c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 01:58:53 +0000 Subject: [PATCH 424/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 7713f693da8..7d20ede1039 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -934,7 +934,7 @@ class PyLoweringContext { if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value - int64_t parameter_idx = local_builder.GetProgramShape(/*root_id=*/0); // local_builder->parameter_numbers_; // GetProgramShape(); + int64_t parameter_idx = local_builder->GetProgramShape(/*root_id=*/0); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); From 443c28ab65771fd232e5a9b8a905be9bceb8b803 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 02:00:12 +0000 Subject: [PATCH 425/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 7d20ede1039..13c7e75b125 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -934,7 +934,7 @@ class PyLoweringContext { if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value - int64_t parameter_idx = local_builder->GetProgramShape(/*root_id=*/0); // local_builder->parameter_numbers_; // GetProgramShape(); + int64_t parameter_idx = local_builder->GetProgramShape(); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); From 5d8ea8005ac07b1c8f8dca45fa6270ed12a5e3a0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 02:02:12 +0000 Subject: [PATCH 426/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 13c7e75b125..6ed7cd8beb2 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -934,7 +934,7 @@ class PyLoweringContext { if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value - int64_t parameter_idx = local_builder->GetProgramShape(); // local_builder->parameter_numbers_; // GetProgramShape(); + int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); From 9e6e12e05637f08124ada1bac4237d8e0bb23c02 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 02:04:36 +0000 Subject: [PATCH 427/546] test --- torch_xla/csrc/init_python_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 6ed7cd8beb2..4030cb5834a 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -934,8 +934,8 @@ class PyLoweringContext { if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value - int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); - // int64_t parameter_idx = 7; + // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); + int64_t parameter_idx = 8; for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); xla::Shape shape = xtensor->shape().get(); From 41ddcfc9d4db87c5dda844650b148a9c94d14c6d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 02:10:45 +0000 Subject: [PATCH 428/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 4030cb5834a..677901fb5c6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -935,7 +935,7 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); - int64_t parameter_idx = 8; + int64_t parameter_idx = 9; for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); xla::Shape shape = xtensor->shape().get(); From c6acfe3b0e3344ab65980ad1fa170180b2b0fa43 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 02:16:29 +0000 Subject: [PATCH 429/546] test --- torch_xla/experimental/fori_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index dba107e774e..3e05d4a0ec1 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -189,8 +189,8 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): additional_inputs_list_body = [] # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body - body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) - # body_ctx.buildforiloop(list(body_result), ()) + # body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) + body_ctx.buildforiloop(list(body_result), ()) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) From 4a0303fd9d1a639c2f66353dea57d096b0974e6f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 18:39:00 +0000 Subject: [PATCH 430/546] test --- test/test_test_mnist.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 1f02100f50f..642553baf1d 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -61,11 +61,11 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value_real = self.linear(input_value) output_value_real_final = self.linear2(output_value_real) - weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + # weight = self.linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + # bias = self.linear.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( one_value, x), input_value.clone( - ), output_value_real_final, weight.clone(), bias.clone() + ), output_value_real_final # , weight.clone(), bias.clone() return while_loop( cond_fn, body_fn, From 9b4b93a0c27259f92ab981bc4b612d58ff5dcdc3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 18:44:34 +0000 Subject: [PATCH 431/546] test --- torch_xla/csrc/init_python_bindings.cpp | 3 +++ torch_xla/experimental/fori_loop.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 677901fb5c6..7395786f193 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -931,6 +931,9 @@ class PyLoweringContext { // hard-code modify body xlacomputation input arguments with unusedarguments // for xla::while requriement + // !!! actually weight/bias don't need to be added here as dummy arguments by additional_inputs_list, + // !!! they will be added automatically added here, we need to add dummy argument for output/return_value + // !!! if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3e05d4a0ec1..bfa163ff60a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -190,7 +190,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body # body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) - body_ctx.buildforiloop(list(body_result), ()) + body_ctx.buildforiloop(list(body_result), list(body_result)) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) From 9eaf3fd323c6ded35921306e84917c6f0120167e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 18:56:34 +0000 Subject: [PATCH 432/546] test --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index bfa163ff60a..f1c6412f5fc 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -190,7 +190,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body # body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) - body_ctx.buildforiloop(list(body_result), list(body_result)) + body_ctx.buildforiloop(list(body_result), list(body_result[-1])) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) From 7dca87af3778b90dccdd7f0bbf62df1f983683c9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 18:58:23 +0000 Subject: [PATCH 433/546] test --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index f1c6412f5fc..57c4be3adbb 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -190,7 +190,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body # body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) - body_ctx.buildforiloop(list(body_result), list(body_result[-1])) + body_ctx.buildforiloop(list(body_result), [body_result[-1],]) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) From 8c115b0c9317c41a10a77f206e7bf8abae44a485 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 19:27:32 +0000 Subject: [PATCH 434/546] test --- test/test_test_mnist.py | 86 ++++++++++++++++++++++++++--- torch_xla/experimental/fori_loop.py | 1 + 2 files changed, 78 insertions(+), 9 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 642553baf1d..01bbf3662ab 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -47,6 +47,20 @@ import torch.nn.functional as F import torch.optim as optim +# model.parameters() + +class SimpleWithLinearPure(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) + + def forward(self, upper, lower, one_value, x, input_value, output_value): + output_value_real = self.linear(input_value) + output_value_real_final = self.linear2(output_value_real) + return output_value_real_final + + class SimpleWithLinear(torch.nn.Module): def __init__(self): super().__init__() @@ -89,7 +103,6 @@ def forward(self, x): x = self.fc2(x) return F.log_softmax(x) - class MNIST(nn.Module): def __init__(self): @@ -152,26 +165,81 @@ def new_test(): weight_0 = simple_with_linear.linear.weight bias_0 = simple_with_linear.linear.bias - # aaa = { - # "simple_with_linear": - # (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, - # output_value)) - # } - upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear( upper, lower, one_value, init_val, l_in_0, output_value) print("finish new_test") +def newnew_test(): + device = xm.xla_device() + torch.set_grad_enabled(False) + + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + # simple_with_linear = SimpleWithLinear() + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value = linear_0(input_value) + weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone(), bias.clone(), weight.clone( + ), output_value.clone() + + upper = torch.tensor([1], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, l_in_0, output_value)) + print("finish newnew_test") + + # # simple_with_linear = SimpleWithLinear() + # simple_with_linear = SimpleWithLinear() + # upper = torch.tensor([52], dtype=torch.int32, device=device) + # lower = torch.tensor([0], dtype=torch.int32, device=device) + # one_value = torch.tensor([1], dtype=torch.int32, device=device) + # init_val = torch.tensor([1], dtype=torch.int32, device=device) + # l_in_0 = torch.rand(10, device=xm.xla_device()) + # output_value = torch.zeros([20], dtype=torch.float32, device=device) + # weight_0 = simple_with_linear.linear.weight + # bias_0 = simple_with_linear.linear.bias + + # upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear( + # upper, lower, one_value, init_val, l_in_0, output_value) + # print("finish new_test") + + # def cond_fn(upper, lower, one_value, x, input_value, output_value): + # return lower[0] < upper[0] + + # def body_fn(upper, lower, one_value, x, input_value, output_value): + # new_lower = torch.add(one_value, lower) + # output_value_real = self.linear(input_value) + # output_value_real_final = self.linear2(output_value_real) + # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + # one_value, x), input_value.clone( + # ), output_value_real_final # , weight.clone(), bias.clone() + + # return while_loop( + # cond_fn, body_fn, + # (upper, lower, one_value, x, input_value, output_value)) + # run test model def test_mnist(): torch.manual_seed(1) print("before test_mnist") - new_test() # test() + newnew_test() # new_test() # test() # target fori_loop for epoch in range(1, n_epochs + 1): - new_test() # test() + newnew_test() # new_test() # test() print("after test_mnist") diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 57c4be3adbb..f2b8a98c63d 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -190,6 +190,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): # TODO(@manfei): treat hard-code parameters: additional_inputs_list_body # body_ctx.buildforiloop(list(body_result), additional_inputs_list_body) + # TODO(@manfei): get index of output_value, then trasfer them into buildforiloop for body_ctx.buildforiloop(list(body_result), [body_result[-1],]) body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", From 76893a34650ce251395a0a4c14b970505b7dd874 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 19:28:25 +0000 Subject: [PATCH 435/546] test --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index f2b8a98c63d..9d35a037d4a 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -173,9 +173,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - # cond_hlo_print = xb.get_computation_hlo(cond_computation) - # print("cond computation: !!!!!!!!!") - # print(cond_hlo_print) + cond_hlo_print = xb.get_computation_hlo(cond_computation) + print("cond computation: !!!!!!!!!") + print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) From d6c00253206c78d29a05c4f20f4921a7f95a9db1 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 19:32:05 +0000 Subject: [PATCH 436/546] test --- torch_xla/csrc/init_python_bindings.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 7395786f193..5153fe2e05c 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -938,7 +938,8 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); - int64_t parameter_idx = 9; + // int64_t parameter_idx = 9; + int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); xla::Shape shape = xtensor->shape().get(); From 48a2b1795f5ac448cd73f1aa3201a47b906d274b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 19:36:14 +0000 Subject: [PATCH 437/546] test --- torch_xla/csrc/init_python_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 5153fe2e05c..37ad6bb1a32 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -938,8 +938,8 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); - // int64_t parameter_idx = 9; - int64_t parameter_idx = tensors.size(); + int64_t parameter_idx = 7; + // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); xla::Shape shape = xtensor->shape().get(); From 7b6551498a8fed831ca8698ffb9ebcd340871ff4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 20:42:04 +0000 Subject: [PATCH 438/546] test --- torch_xla/experimental/fori_loop.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 9d35a037d4a..7478355caa4 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -173,9 +173,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - cond_hlo_print = xb.get_computation_hlo(cond_computation) - print("cond computation: !!!!!!!!!") - print(cond_hlo_print) + # cond_hlo_print = xb.get_computation_hlo(cond_computation) + # print("cond computation: !!!!!!!!!") + # print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) @@ -195,9 +195,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - body_hlo_print = xb.get_computation_hlo(body_computation) - print("body computation: !!!!!!!!!") - print(body_hlo_print) + # body_hlo_print = xb.get_computation_hlo(body_computation) + # print("body computation: !!!!!!!!!") + # print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs From af6d44b64f408c869329281331b3c3d9e91ecc8c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 20:44:29 +0000 Subject: [PATCH 439/546] test --- test/test_test_mnist.py | 43 ++++++++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 01bbf3662ab..dc7a8d3f324 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -37,10 +37,9 @@ data=(torch.zeros(8, 1, 28,28), torch.zeros(8, dtype=torch.int64)), sample_count=1000 // 8 // xm.xrt_world_size()) -examples = enumerate(test_loader) -batch_idx, (example_data, example_targets) = next(examples) - -print("shape: ", example_data.shape) +# examples = enumerate(test_loader) +# batch_idx, (example_data, example_targets) = next(examples) +# print("shape: ", example_data.shape) ### build model import torch.nn as nn @@ -169,7 +168,6 @@ def new_test(): upper, lower, one_value, init_val, l_in_0, output_value) print("finish new_test") - def newnew_test(): device = xm.xla_device() torch.set_grad_enabled(False) @@ -231,15 +229,46 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # cond_fn, body_fn, # (upper, lower, one_value, x, input_value, output_value)) +def newnewnew_test(): + device = xm.xla_device() + torch.set_grad_enabled(False) + + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + # simple_with_linear = SimpleWithLinear() + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value = linear_0(input_value) + weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone(), bias.clone(), weight.clone( + ), output_value.clone() + + upper = torch.tensor([1], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, l_in_0, output_value)) + print("finish newnewnew_test") + # run test model def test_mnist(): torch.manual_seed(1) print("before test_mnist") - newnew_test() # new_test() # test() + newnewnew_test() # newnew_test() # new_test() # test() # target fori_loop for epoch in range(1, n_epochs + 1): - newnew_test() # new_test() # test() + newnewnew_test() # newnew_test() # new_test() # test() print("after test_mnist") From d3bea05fb489c8f5d8a736a88b3beaa9c5807108 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 20:46:26 +0000 Subject: [PATCH 440/546] test --- test/test_test_mnist.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index dc7a8d3f324..8f4659e6f71 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -233,17 +233,17 @@ def newnewnew_test(): device = xm.xla_device() torch.set_grad_enabled(False) - linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - # simple_with_linear = SimpleWithLinear() + # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + simple_with_linear = SimpleWithLinear() def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) - output_value = linear_0(input_value) - weight = linear_0.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = linear_0.bias # not be used actually, initialized as placeholder xlacomputation requirement + output_value = SimpleWithLinear(input_value) + weight = SimpleWithLinear.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = SimpleWithLinear.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( one_value, x), input_value.clone(), bias.clone(), weight.clone( ), output_value.clone() From afdac6c2e6bc5be2d3272ed0d86e62c3f0b01c87 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 20:56:38 +0000 Subject: [PATCH 441/546] test --- test/test_test_mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 8f4659e6f71..440e58d9725 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -20,6 +20,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP import torch_xla.distributed.xla_backend import torch_xla.experimental.fori_loop +from torch_xla.experimental.fori_loop import _xla_while_loop from torch._higher_order_ops.while_loop import while_loop n_epochs = 3 @@ -255,7 +256,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): l_in_0 = torch.rand(10, device=xm.xla_device()) output_value = torch.zeros([20], dtype=torch.float32, device=device) - upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop( + upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = _xla_while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value)) print("finish newnewnew_test") From 1f052a5d5582b63e9cd296c1c1e8dc03b918f7fb Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 20:58:21 +0000 Subject: [PATCH 442/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 440e58d9725..10fba95b226 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -258,7 +258,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = _xla_while_loop( cond_fn, body_fn, - (upper, lower, one_value, init_val, l_in_0, output_value)) + (upper, lower, one_value, init_val, l_in_0, output_value), ()) print("finish newnewnew_test") # run test model From 2c8a445cffe8907de92f4be4fc400efdc4f5a7d3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 21:00:57 +0000 Subject: [PATCH 443/546] test --- torch_xla/experimental/fori_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 7478355caa4..40d09f4cf93 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -151,6 +151,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): torch.randint( 10, additional_input.size(), dtype=additional_input.dtype).to(device)) + print("fake_carried_inputs: ", fake_carried_inputs) # TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) From 83d5e4ff96599b59e054da3d17cce110f96efc02 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 21:02:59 +0000 Subject: [PATCH 444/546] test --- test/test_test_mnist.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 10fba95b226..c5931fe9210 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -242,9 +242,9 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) - output_value = SimpleWithLinear(input_value) - weight = SimpleWithLinear.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = SimpleWithLinear.bias # not be used actually, initialized as placeholder xlacomputation requirement + output_value = simple_with_linear(input_value) + weight = simple_with_linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( one_value, x), input_value.clone(), bias.clone(), weight.clone( ), output_value.clone() From 19eaf2c7adc1dc22655c79e8b14a95a4fd50c543 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 26 Apr 2024 21:44:40 +0000 Subject: [PATCH 445/546] test --- test/test_test_mnist.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index c5931fe9210..29b54e02aac 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -243,11 +243,10 @@ def cond_fn(upper, lower, one_value, x, input_value, output_value): def body_fn(upper, lower, one_value, x, input_value, output_value): new_lower = torch.add(one_value, lower) output_value = simple_with_linear(input_value) - weight = simple_with_linear.weight # not be used actually, initialized as placeholder xlacomputation requirement - bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + # weight = simple_with_linear.weight # not be used actually, initialized as placeholder xlacomputation requirement + # bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( - one_value, x), input_value.clone(), bias.clone(), weight.clone( - ), output_value.clone() + one_value, x), input_value.clone(), output_value.clone() # bias.clone(), weight.clone(), output_value.clone() upper = torch.tensor([1], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From 58d92f85880752719f0d630124a542b3a284c534 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 22:53:08 +0000 Subject: [PATCH 446/546] test --- test/test_test_mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 29b54e02aac..801834031ef 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -55,7 +55,8 @@ def __init__(self): self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) - def forward(self, upper, lower, one_value, x, input_value, output_value): + # def forward(self, upper, lower, one_value, x, input_value, output_value): + def forward(self, input_value): output_value_real = self.linear(input_value) output_value_real_final = self.linear2(output_value_real) return output_value_real_final From 1265107b479835c54410cb3fad7f5b91a2d8fb19 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 22:54:13 +0000 Subject: [PATCH 447/546] test --- test/test_test_mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 801834031ef..159030bc3bb 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -236,7 +236,8 @@ def newnewnew_test(): torch.set_grad_enabled(False) # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) - simple_with_linear = SimpleWithLinear() + # simple_with_linear = SimpleWithLinear() + simple_with_linear = SimpleWithLinearPure() def cond_fn(upper, lower, one_value, x, input_value, output_value): return lower[0] < upper[0] From b15d72194b5fb73ef0ba1bd423b4c52c68a25bb3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 22:57:06 +0000 Subject: [PATCH 448/546] test --- test/test_test_mnist.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 159030bc3bb..1c0c14500e7 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -250,6 +250,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( one_value, x), input_value.clone(), output_value.clone() # bias.clone(), weight.clone(), output_value.clone() + print("simple_with_linear weight: ", simple_with_linear.weight) + print("simple_with_linear bias: ", simple_with_linear.bias) + upper = torch.tensor([1], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) From 000d5ca1f496738a81829d75aa373bf0e4e84e56 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 22:59:20 +0000 Subject: [PATCH 449/546] test --- test/test_test_mnist.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 1c0c14500e7..f676bb3d874 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -250,8 +250,12 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( one_value, x), input_value.clone(), output_value.clone() # bias.clone(), weight.clone(), output_value.clone() - print("simple_with_linear weight: ", simple_with_linear.weight) - print("simple_with_linear bias: ", simple_with_linear.bias) + # print("simple_with_linear weight: ", simple_with_linear.weight) + # print("simple_with_linear bias: ", simple_with_linear.bias) + print("prine all things!!!") + for name, param in simple_with_linear.named_parameters(): + if name in ['bias']: + print(param.size()) upper = torch.tensor([1], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From f0678f7d8d4311b1ad3a4d886a194cc4e5563f0d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:00:04 +0000 Subject: [PATCH 450/546] test --- test/test_test_mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index f676bb3d874..ca1e13bf723 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -253,6 +253,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # print("simple_with_linear weight: ", simple_with_linear.weight) # print("simple_with_linear bias: ", simple_with_linear.bias) print("prine all things!!!") + print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) for name, param in simple_with_linear.named_parameters(): if name in ['bias']: print(param.size()) From 66b6027f97193168c9020a67d7540a15fbcd213d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:02:37 +0000 Subject: [PATCH 451/546] test --- test/test_test_mnist.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index ca1e13bf723..a4562c0e011 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -252,8 +252,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # print("simple_with_linear weight: ", simple_with_linear.weight) # print("simple_with_linear bias: ", simple_with_linear.bias) - print("prine all things!!!") + print("print all things!!!") + print(type(simple_with_linear.parameters())) print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) + import pdb; pdb.set_trace() for name, param in simple_with_linear.named_parameters(): if name in ['bias']: print(param.size()) From 9953a7fc3d0860f6b9a651f17f72e018ac07c1f5 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:08:09 +0000 Subject: [PATCH 452/546] test --- test/test_test_mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index a4562c0e011..ed5fc14c839 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -255,8 +255,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): print("print all things!!!") print(type(simple_with_linear.parameters())) print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() for name, param in simple_with_linear.named_parameters(): + print("arrive the loop") if name in ['bias']: print(param.size()) From a6e28e0600244c5f74998b8e275b1abe0a177bff Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:08:57 +0000 Subject: [PATCH 453/546] test --- test/test_test_mnist.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index ed5fc14c839..80abc528721 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -258,6 +258,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # import pdb; pdb.set_trace() for name, param in simple_with_linear.named_parameters(): print("arrive the loop") + print("name: ", name) + print("param: ", param) if name in ['bias']: print(param.size()) From a30db8e1e805e215756b49fc22fde8e6035ef26f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:18:26 +0000 Subject: [PATCH 454/546] test --- test/test_test_mnist.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 80abc528721..147c9198592 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -247,21 +247,23 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value = simple_with_linear(input_value) # weight = simple_with_linear.weight # not be used actually, initialized as placeholder xlacomputation requirement # bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement + for name, param in simple_with_linear.named_parameters(): + asd return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( - one_value, x), input_value.clone(), output_value.clone() # bias.clone(), weight.clone(), output_value.clone() + one_value, x), input_value.clone(), output_value.clone(), simple_with_linear.linear.weight # bias.clone(), weight.clone(), output_value.clone() # print("simple_with_linear weight: ", simple_with_linear.weight) # print("simple_with_linear bias: ", simple_with_linear.bias) - print("print all things!!!") - print(type(simple_with_linear.parameters())) - print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) + # print("print all things!!!") + # print(type(simple_with_linear.parameters())) + # print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) # import pdb; pdb.set_trace() for name, param in simple_with_linear.named_parameters(): print("arrive the loop") print("name: ", name) print("param: ", param) - if name in ['bias']: - print(param.size()) + # if name in ['bias']: + # print(param.size()) upper = torch.tensor([1], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) From a54ba4cc917a6c3dc80d01ed295e1a28dd807ee7 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:18:50 +0000 Subject: [PATCH 455/546] test --- test/test_test_mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 147c9198592..48546837fe7 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -247,8 +247,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value = simple_with_linear(input_value) # weight = simple_with_linear.weight # not be used actually, initialized as placeholder xlacomputation requirement # bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement - for name, param in simple_with_linear.named_parameters(): - asd + # for name, param in simple_with_linear.named_parameters(): + # asd return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( one_value, x), input_value.clone(), output_value.clone(), simple_with_linear.linear.weight # bias.clone(), weight.clone(), output_value.clone() From b86a03ae5ceaf8775802f7f9f4bb2551654b0930 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:24:32 +0000 Subject: [PATCH 456/546] test --- test/test_test_mnist.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 48546837fe7..b0e8686afc9 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -249,8 +249,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement # for name, param in simple_with_linear.named_parameters(): # asd - return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( - one_value, x), input_value.clone(), output_value.clone(), simple_with_linear.linear.weight # bias.clone(), weight.clone(), output_value.clone() + return (upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone(), output_value.clone(), simple_with_linear.linear.weight) # bias.clone(), weight.clone(), output_value.clone() # print("simple_with_linear weight: ", simple_with_linear.weight) # print("simple_with_linear bias: ", simple_with_linear.bias) @@ -258,10 +258,12 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): # print(type(simple_with_linear.parameters())) # print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) # import pdb; pdb.set_trace() - for name, param in simple_with_linear.named_parameters(): - print("arrive the loop") - print("name: ", name) - print("param: ", param) + + # for name, param in simple_with_linear.named_parameters(): + # print("arrive the loop") + # print("name: ", name) + # print("param: ", param) + # if name in ['bias']: # print(param.size()) From 407adf891a3bda9efae315421c2345857bfab61d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:27:42 +0000 Subject: [PATCH 457/546] test --- test/test_test_mnist.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index b0e8686afc9..d146069ea70 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -247,10 +247,13 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): output_value = simple_with_linear(input_value) # weight = simple_with_linear.weight # not be used actually, initialized as placeholder xlacomputation requirement # bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement - # for name, param in simple_with_linear.named_parameters(): - # asd - return (upper.clone(), new_lower.clone(), one_value.clone(), torch.add( - one_value, x), input_value.clone(), output_value.clone(), simple_with_linear.linear.weight) # bias.clone(), weight.clone(), output_value.clone() + res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()] + for name, param in simple_with_linear.named_parameters(): + res.insert(-1, param) + # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) + return tuple(res) + # return (upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + # one_value, x), input_value.clone(), output_value.clone(), simple_with_linear.linear.weight) # bias.clone(), weight.clone(), output_value.clone() # print("simple_with_linear weight: ", simple_with_linear.weight) # print("simple_with_linear bias: ", simple_with_linear.bias) From 532b9cedcd54d3b77f817d9d0d6189c07cb4517b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:29:31 +0000 Subject: [PATCH 458/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 37ad6bb1a32..56f8f16b8e6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -938,7 +938,7 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); - int64_t parameter_idx = 7; + int64_t parameter_idx = 9; // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); From 205f2e3030d738dd237772e40715018141dbbac8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:38:31 +0000 Subject: [PATCH 459/546] test --- test/test_test_mnist.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index d146069ea70..5a7c10d8abe 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -277,9 +277,14 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): l_in_0 = torch.rand(10, device=xm.xla_device()) output_value = torch.zeros([20], dtype=torch.float32, device=device) + additional_inputs = [] + for name, param in simple_with_linear.named_parameters(): + additional_inputs.insert(-1, param) + upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = _xla_while_loop( cond_fn, body_fn, - (upper, lower, one_value, init_val, l_in_0, output_value), ()) + (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) + # (upper, lower, one_value, init_val, l_in_0, output_value), ()) print("finish newnewnew_test") # run test model From c4e547dba3efa4faf320d6b65cb3af52a521b775 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:41:50 +0000 Subject: [PATCH 460/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 5a7c10d8abe..449bea09400 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -239,7 +239,7 @@ def newnewnew_test(): # simple_with_linear = SimpleWithLinear() simple_with_linear = SimpleWithLinearPure() - def cond_fn(upper, lower, one_value, x, input_value, output_value): + def cond_fn(upper, lower, one_value, x, input_value, output_value, args*): return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): From d4f86c79551b6ce585bf5e2f9d951eaadeb07ffa Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:42:13 +0000 Subject: [PATCH 461/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 449bea09400..3ff1efaf017 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -239,7 +239,7 @@ def newnewnew_test(): # simple_with_linear = SimpleWithLinear() simple_with_linear = SimpleWithLinearPure() - def cond_fn(upper, lower, one_value, x, input_value, output_value, args*): + def cond_fn(upper, lower, one_value, x, input_value, output_value, *args): return lower[0] < upper[0] def body_fn(upper, lower, one_value, x, input_value, output_value): From 98a5b7b32199fa8639c44f7bb9458619e8a5fc78 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:42:51 +0000 Subject: [PATCH 462/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 3ff1efaf017..7f12fb003bb 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -242,7 +242,7 @@ def newnewnew_test(): def cond_fn(upper, lower, one_value, x, input_value, output_value, *args): return lower[0] < upper[0] - def body_fn(upper, lower, one_value, x, input_value, output_value): + def body_fn(upper, lower, one_value, x, input_value, output_value, *args): new_lower = torch.add(one_value, lower) output_value = simple_with_linear(input_value) # weight = simple_with_linear.weight # not be used actually, initialized as placeholder xlacomputation requirement From 93c5db01bee4e55aee3092f5c14098547d226d90 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:46:50 +0000 Subject: [PATCH 463/546] test --- torch_xla/experimental/fori_loop.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 40d09f4cf93..f958119f91e 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -162,13 +162,13 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): additional_inputs_list_cond = list( fake_carried_inputs[2:] ) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor - if additional_inputs: - tmp_bias = additional_inputs_list_cond[ - -3] # not used, change order doesn't affect logic - del additional_inputs_list_cond[ - -3] # not used, change order doesn't affect logic - additional_inputs_list_cond.append( - tmp_bias) # not used, change order doesn't affect logic + # if additional_inputs: + # tmp_bias = additional_inputs_list_cond[ + # -3] # not used, change order doesn't affect logic + # del additional_inputs_list_cond[ + # -3] # not used, change order doesn't affect logic + # additional_inputs_list_cond.append( + # tmp_bias) # not used, change order doesn't affect logic cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() From 9977250ddf35719689d8c5fab56874b21a398d1a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:48:16 +0000 Subject: [PATCH 464/546] test --- test/test_test_mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 7f12fb003bb..0b0b596d229 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -281,6 +281,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): for name, param in simple_with_linear.named_parameters(): additional_inputs.insert(-1, param) + print("in mnist additional_inputs: ", additional_inputs) upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = _xla_while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) From ad19fcad24c322b41d7687167895ea1a1b5a8394 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:55:43 +0000 Subject: [PATCH 465/546] test --- test/test_test_mnist.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 0b0b596d229..bf13dc1a1c9 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -275,11 +275,12 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) l_in_0 = torch.rand(10, device=xm.xla_device()) - output_value = torch.zeros([20], dtype=torch.float32, device=device) + output_value = torch.zeros([30], dtype=torch.float32, device=device) additional_inputs = [] for name, param in simple_with_linear.named_parameters(): - additional_inputs.insert(-1, param) + # additional_inputs.insert(-1, param) + additional_inputs.append(-1, param) print("in mnist additional_inputs: ", additional_inputs) upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = _xla_while_loop( From 47eb44f64a3b23ff596ac060e051dc6a92b379b3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Mon, 29 Apr 2024 23:56:09 +0000 Subject: [PATCH 466/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index bf13dc1a1c9..358704ec328 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -280,7 +280,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): additional_inputs = [] for name, param in simple_with_linear.named_parameters(): # additional_inputs.insert(-1, param) - additional_inputs.append(-1, param) + additional_inputs.append(param) print("in mnist additional_inputs: ", additional_inputs) upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = _xla_while_loop( From 7d66521e6cc5d3c4db98e8af2abcf2499aec8b1f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:00:28 +0000 Subject: [PATCH 467/546] test --- torch_xla/experimental/fori_loop.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index f958119f91e..45476285c27 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -213,11 +213,11 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): p = xb.mkparam(builder, len(params), shape) params.append(p) - # TODO(@manfei): treat hard-code input arguments, currently switch bias and output_value if additional_inputs(weight/bias) exists - if additional_inputs: - tmp_bias = params[-3] - del params[-3] - params.append(tmp_bias) + # # TODO(@manfei): treat hard-code input arguments, currently switch bias and output_value if additional_inputs(weight/bias) exists + # if additional_inputs: + # tmp_bias = params[-3] + # del params[-3] + # params.append(tmp_bias) # generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) From bcbed012402d8b98493403b80a8e6c4e0522950f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:04:28 +0000 Subject: [PATCH 468/546] test --- torch_xla/experimental/fori_loop.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 45476285c27..541a3f74a42 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -162,13 +162,10 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): additional_inputs_list_cond = list( fake_carried_inputs[2:] ) # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor - # if additional_inputs: - # tmp_bias = additional_inputs_list_cond[ - # -3] # not used, change order doesn't affect logic - # del additional_inputs_list_cond[ - # -3] # not used, change order doesn't affect logic - # additional_inputs_list_cond.append( - # tmp_bias) # not used, change order doesn't affect logic + if additional_inputs: + tmp_bias = additional_inputs_list_cond[3] # additional_inputs_list_cond[-3] # not used, change order doesn't affect logic + del additional_inputs_list_cond[3] # additional_inputs_list_cond[-3] # not used, change order doesn't affect logic + additional_inputs_list_cond.append(tmp_bias) # not used, change order doesn't affect logic cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond) cond_hlo = cond_ctx.hlo() @@ -213,11 +210,11 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): p = xb.mkparam(builder, len(params), shape) params.append(p) - # # TODO(@manfei): treat hard-code input arguments, currently switch bias and output_value if additional_inputs(weight/bias) exists - # if additional_inputs: - # tmp_bias = params[-3] - # del params[-3] - # params.append(tmp_bias) + # TODO(@manfei): treat hard-code input arguments, currently switch bias and output_value if additional_inputs(weight/bias) exists + if additional_inputs: + tmp_bias = params[5] # params[-3] + del params[5] # params[-3] + params.append(tmp_bias) # generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) From a3744e4e9e1019531b637a98a982ab5748750fa0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:06:38 +0000 Subject: [PATCH 469/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 358704ec328..f07fd73932c 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -283,7 +283,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): additional_inputs.append(param) print("in mnist additional_inputs: ", additional_inputs) - upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = _xla_while_loop( + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, w2_, b2_, output_value_real__, = _xla_while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) # (upper, lower, one_value, init_val, l_in_0, output_value), ()) From 4783e72e5f6253e6dc8046edf0ba6b341c0c0cea Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:07:07 +0000 Subject: [PATCH 470/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index f07fd73932c..96e468526dd 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -282,7 +282,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # additional_inputs.insert(-1, param) additional_inputs.append(param) - print("in mnist additional_inputs: ", additional_inputs) + # print("in mnist additional_inputs: ", additional_inputs) upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, w2_, b2_, output_value_real__, = _xla_while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) From e4e0066e846c27ccbd3ca1d73a2935d409585186 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:08:23 +0000 Subject: [PATCH 471/546] test --- test/test_test_mnist.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 96e468526dd..7ba76df3ebd 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -288,6 +288,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) # (upper, lower, one_value, init_val, l_in_0, output_value), ()) print("finish newnewnew_test") + print("actual res: ", output_value_real__) + expected_ = simple_with_linear(l_in_0) + print("expected res: ", expected_) # run test model def test_mnist(): From 906292a2996a9a2bf67213ebcc82fa0fb38c7d13 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:09:18 +0000 Subject: [PATCH 472/546] test --- torch_xla/experimental/fori_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 541a3f74a42..1424ce62853 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -151,7 +151,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): torch.randint( 10, additional_input.size(), dtype=additional_input.dtype).to(device)) - print("fake_carried_inputs: ", fake_carried_inputs) + # print("fake_carried_inputs: ", fake_carried_inputs) # TODO(@manfei): specify which element is for which argument like a,b,c cond_result = cond_fn(*fake_carried_inputs) From 1b7d3afed1c4b3e3bb17878df525e223d6e5d38d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:10:07 +0000 Subject: [PATCH 473/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 7ba76df3ebd..09d7ecdad5f 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -270,7 +270,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # if name in ['bias']: # print(param.size()) - upper = torch.tensor([1], dtype=torch.int32, device=device) + upper = torch.tensor([2], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) From 82ba3238341054b874912d8bc84d6e0028ed53c7 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:11:01 +0000 Subject: [PATCH 474/546] test --- test/test_test_mnist.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 09d7ecdad5f..81d1b68450a 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -298,9 +298,9 @@ def test_mnist(): print("before test_mnist") newnewnew_test() # newnew_test() # new_test() # test() - # target fori_loop - for epoch in range(1, n_epochs + 1): - newnewnew_test() # newnew_test() # new_test() # test() + # # target fori_loop + # for epoch in range(1, n_epochs + 1): + # newnewnew_test() # newnew_test() # new_test() # test() print("after test_mnist") From f2307aadbac0ee8d7905032a41467eafb2c34538 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:11:20 +0000 Subject: [PATCH 475/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 81d1b68450a..53cc5a6f674 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -270,7 +270,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # if name in ['bias']: # print(param.size()) - upper = torch.tensor([2], dtype=torch.int32, device=device) + upper = torch.tensor([5], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) From 74e69de99e3b073ec2d4d18fa6caa808d1cf119e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:12:25 +0000 Subject: [PATCH 476/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 53cc5a6f674..30d2a971349 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -270,7 +270,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # if name in ['bias']: # print(param.size()) - upper = torch.tensor([5], dtype=torch.int32, device=device) + upper = torch.tensor([15], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) From d596c60f023b8c9303c575d946cd48ed9588ee8b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:12:50 +0000 Subject: [PATCH 477/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 30d2a971349..a98e91ebc35 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -270,7 +270,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # if name in ['bias']: # print(param.size()) - upper = torch.tensor([15], dtype=torch.int32, device=device) + upper = torch.tensor([50], dtype=torch.int32, device=device) lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) From cea7a985fdbd6a2d25d25fe31cb5a08bb23f47b3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:26:39 +0000 Subject: [PATCH 478/546] test --- test/test_test_mnist.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index a98e91ebc35..b56a83e9978 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -52,8 +52,11 @@ class SimpleWithLinearPure(torch.nn.Module): def __init__(self): super().__init__() - self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) - self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) + self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5).to(xm.xla_device()) + # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) + # self.fc1 = nn.Linear(320, 50).to(xm.xla_device()) + # self.fc2 = nn.Linear(50, 10).to(xm.xla_device()) # def forward(self, upper, lower, one_value, x, input_value, output_value): def forward(self, input_value): From 1da2f8ea370cd72b81d81c8e3eec3dbfa82c9ab4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:28:42 +0000 Subject: [PATCH 479/546] test --- test/test_test_mnist.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index b56a83e9978..af9637034dd 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -60,8 +60,9 @@ def __init__(self): # def forward(self, upper, lower, one_value, x, input_value, output_value): def forward(self, input_value): - output_value_real = self.linear(input_value) - output_value_real_final = self.linear2(output_value_real) + # output_value_real = self.linear(input_value) + # output_value_real_final = self.linear2(output_value_real) + output_value_real_final = self.conv1(input_value) return output_value_real_final From 59152dd925a73a675f2fabd6512e68128240bcc3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:45:06 +0000 Subject: [PATCH 480/546] test --- test/test_test_mnist.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index af9637034dd..1502a2904d5 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -65,7 +65,6 @@ def forward(self, input_value): output_value_real_final = self.conv1(input_value) return output_value_real_final - class SimpleWithLinear(torch.nn.Module): def __init__(self): super().__init__() @@ -278,9 +277,15 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) - l_in_0 = torch.rand(10, device=xm.xla_device()) + # l_in_0 = torch.rand(10, device=xm.xla_device()) output_value = torch.zeros([30], dtype=torch.float32, device=device) + bs=16 + l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.int32, device=device) +# c = nn.Conv2d(3,10,kernel_size=5,stride=1,padding=2) +# out = c(x) +# print(out.nelement()) + additional_inputs = [] for name, param in simple_with_linear.named_parameters(): # additional_inputs.insert(-1, param) From 8d795b0906dbf484e7471e6accee664eeba9ae53 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:46:19 +0000 Subject: [PATCH 481/546] test --- test/test_test_mnist.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 1502a2904d5..f5bcf36302e 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -52,7 +52,7 @@ class SimpleWithLinearPure(torch.nn.Module): def __init__(self): super().__init__() - self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5).to(xm.xla_device()) + self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, stride=1, padding=2).to(xm.xla_device()) # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) # self.fc1 = nn.Linear(320, 50).to(xm.xla_device()) @@ -282,9 +282,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): bs=16 l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.int32, device=device) -# c = nn.Conv2d(3,10,kernel_size=5,stride=1,padding=2) -# out = c(x) -# print(out.nelement()) + # c = nn.Conv2d(3,10,kernel_size=5,stride=1,padding=2) + # out = c(x) + # print(out.nelement()) additional_inputs = [] for name, param in simple_with_linear.named_parameters(): From fa0632d34b4aeb7708c1e13545580c54497c0089 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:46:47 +0000 Subject: [PATCH 482/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index f5bcf36302e..ab7028ab4a6 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -281,7 +281,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): output_value = torch.zeros([30], dtype=torch.float32, device=device) bs=16 - l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.int32, device=device) + l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device) # c = nn.Conv2d(3,10,kernel_size=5,stride=1,padding=2) # out = c(x) # print(out.nelement()) From 8f4b9a33031052ad36c72e80bfae08505e022974 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 00:47:40 +0000 Subject: [PATCH 483/546] test --- torch_xla/csrc/init_python_bindings.cpp | 4 ++-- torch_xla/experimental/fori_loop.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 56f8f16b8e6..1c83cfcbbb0 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -937,8 +937,8 @@ class PyLoweringContext { if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value - // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); - int64_t parameter_idx = 9; + int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); + // int64_t parameter_idx = 9; // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 1424ce62853..3ba1cea26ee 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -171,9 +171,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - # cond_hlo_print = xb.get_computation_hlo(cond_computation) - # print("cond computation: !!!!!!!!!") - # print(cond_hlo_print) + cond_hlo_print = xb.get_computation_hlo(cond_computation) + print("cond computation: !!!!!!!!!") + print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) From ba14f8dc14a9f8e1b02b1f73e45cffd7ed05cecc Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 16:56:27 +0000 Subject: [PATCH 484/546] test --- torch_xla/experimental/fori_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3ba1cea26ee..0d2a5ee66cb 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -193,9 +193,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - # body_hlo_print = xb.get_computation_hlo(body_computation) - # print("body computation: !!!!!!!!!") - # print(body_hlo_print) + body_hlo_print = xb.get_computation_hlo(body_computation) + print("body computation: !!!!!!!!!") + print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs From 4e00273e7c3dba35c963d8e796f4b64b70fd63a0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 16:59:00 +0000 Subject: [PATCH 485/546] test --- torch_xla/csrc/init_python_bindings.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1c83cfcbbb0..4b9d0635557 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -937,7 +937,8 @@ class PyLoweringContext { if (GetNameString() == "bodyctx") { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value - int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); + // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); + int64_t parameter_idx = 8; # conv2d // int64_t parameter_idx = 9; // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From 499e304a660aed75df032ed7e93a0e5b1fe7f00c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:00:17 +0000 Subject: [PATCH 486/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 4b9d0635557..03694618ec6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -938,7 +938,7 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); - int64_t parameter_idx = 8; # conv2d + int64_t parameter_idx = 8; // conv2d // int64_t parameter_idx = 9; // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From 04009c498aa62316231620e4fa278814eb98b8a9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:03:08 +0000 Subject: [PATCH 487/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 03694618ec6..623d9231186 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -938,7 +938,7 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); - int64_t parameter_idx = 8; // conv2d + int64_t parameter_idx = 9; // conv2d // int64_t parameter_idx = 9; // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From eb4aeec3d8fbaa1ae7993725125def8ffb45079d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:03:21 +0000 Subject: [PATCH 488/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 623d9231186..f85013e50db 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -939,7 +939,7 @@ class PyLoweringContext { // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); int64_t parameter_idx = 9; // conv2d - // int64_t parameter_idx = 9; + // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { XLATensorPtr xtensor = bridge::GetXlaTensor(additional_input_tensor); From ee2a0d57a27552dca0d18c91423cee81f28c27f0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:05:49 +0000 Subject: [PATCH 489/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index f85013e50db..949f96d36bc 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -938,7 +938,7 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); - int64_t parameter_idx = 9; // conv2d + int64_t parameter_idx = 7; // conv2d // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From 5ef8deadb4e30b24b4571477745e946ff822d8a2 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:13:00 +0000 Subject: [PATCH 490/546] test --- test/test_test_mnist.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index ab7028ab4a6..12f8938a758 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -277,14 +277,17 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): lower = torch.tensor([0], dtype=torch.int32, device=device) one_value = torch.tensor([1], dtype=torch.int32, device=device) init_val = torch.tensor([1], dtype=torch.int32, device=device) + ### linear 10*20 + 20*30 input&output # l_in_0 = torch.rand(10, device=xm.xla_device()) - output_value = torch.zeros([30], dtype=torch.float32, device=device) - + # output_value = torch.zeros([30], dtype=torch.float32, device=device) + ### conv2d input&output bs=16 l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device) # c = nn.Conv2d(3,10,kernel_size=5,stride=1,padding=2) # out = c(x) # print(out.nelement()) + output_value = torch.zeros([16,10,28,28], dtype=torch.float32, device=device) + additional_inputs = [] for name, param in simple_with_linear.named_parameters(): From 70793018d5a1553a7f41cb4eb04a3b78858d95fe Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:14:13 +0000 Subject: [PATCH 491/546] test --- test/test_test_mnist.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 12f8938a758..d73c72d41fd 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -295,7 +295,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): additional_inputs.append(param) # print("in mnist additional_inputs: ", additional_inputs) - upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, w2_, b2_, output_value_real__, = _xla_while_loop( + ### linear 10*20 + 20*30 + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, w2_, b2_, output_value_real__, = _xla_while_loop( + ### conv2d + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, output_value_real__, = _xla_while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) # (upper, lower, one_value, init_val, l_in_0, output_value), ()) From 584d6ccc102248ab7badc787df3e2e45e25bc658 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:17:26 +0000 Subject: [PATCH 492/546] test --- test/test_test_mnist.py | 4 +++- torch_xla/experimental/fori_loop.py | 12 ++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index d73c72d41fd..ea6ca3bb83c 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -53,6 +53,7 @@ class SimpleWithLinearPure(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, stride=1, padding=2).to(xm.xla_device()) + # self.bn1 = nn.BatchNorm2d(10) # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) # self.fc1 = nn.Linear(320, 50).to(xm.xla_device()) @@ -62,7 +63,8 @@ def __init__(self): def forward(self, input_value): # output_value_real = self.linear(input_value) # output_value_real_final = self.linear2(output_value_real) - output_value_real_final = self.conv1(input_value) + # output_value_real_final = self.conv1(input_value) # conv2d + output_value_real_final = F.relu(F.max_pool2d(self.conv1(input_value), 2)) return output_value_real_final class SimpleWithLinear(torch.nn.Module): diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0d2a5ee66cb..1424ce62853 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -171,9 +171,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - cond_hlo_print = xb.get_computation_hlo(cond_computation) - print("cond computation: !!!!!!!!!") - print(cond_hlo_print) + # cond_hlo_print = xb.get_computation_hlo(cond_computation) + # print("cond computation: !!!!!!!!!") + # print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) @@ -193,9 +193,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - body_hlo_print = xb.get_computation_hlo(body_computation) - print("body computation: !!!!!!!!!") - print(body_hlo_print) + # body_hlo_print = xb.get_computation_hlo(body_computation) + # print("body computation: !!!!!!!!!") + # print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs From 0affbb4ee20ce56b7527fe43dd67d59797fa0552 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:19:51 +0000 Subject: [PATCH 493/546] test --- test/test_test_mnist.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index ea6ca3bb83c..ed42f2124c0 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -64,7 +64,7 @@ def forward(self, input_value): # output_value_real = self.linear(input_value) # output_value_real_final = self.linear2(output_value_real) # output_value_real_final = self.conv1(input_value) # conv2d - output_value_real_final = F.relu(F.max_pool2d(self.conv1(input_value), 2)) + output_value_real_final = F.relu(F.max_pool2d(self.conv1(input_value), 2)) # conv2d+mnist-treat return output_value_real_final class SimpleWithLinear(torch.nn.Module): @@ -288,7 +288,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # c = nn.Conv2d(3,10,kernel_size=5,stride=1,padding=2) # out = c(x) # print(out.nelement()) - output_value = torch.zeros([16,10,28,28], dtype=torch.float32, device=device) + # output_value = torch.zeros([16,10,28,28], dtype=torch.float32, device=device) # conv2d + output_value = torch.zeros([16,10,14,14], dtype=torch.float32, device=device)# conv2d+mnist-treat additional_inputs = [] From 4fe7a62000d2fa190683772fc30fc4ad88617d82 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:21:02 +0000 Subject: [PATCH 494/546] test --- test/test_test_mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index ed42f2124c0..eb7bc2b08cb 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -306,9 +306,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) # (upper, lower, one_value, init_val, l_in_0, output_value), ()) print("finish newnewnew_test") - print("actual res: ", output_value_real__) + print("actual res: ", output_value_real__[0][0][0]) expected_ = simple_with_linear(l_in_0) - print("expected res: ", expected_) + print("expected res: ", expected_[0][0][0]) # run test model def test_mnist(): From 7f9dfa372e7c343b5c52739d82936c1a7b36334c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:22:57 +0000 Subject: [PATCH 495/546] test --- test/test_test_mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index eb7bc2b08cb..3fd8085b0cd 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -53,7 +53,7 @@ class SimpleWithLinearPure(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, stride=1, padding=2).to(xm.xla_device()) - # self.bn1 = nn.BatchNorm2d(10) + self.bn1 = nn.BatchNorm2d(10) # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) # self.fc1 = nn.Linear(320, 50).to(xm.xla_device()) @@ -65,6 +65,7 @@ def forward(self, input_value): # output_value_real_final = self.linear2(output_value_real) # output_value_real_final = self.conv1(input_value) # conv2d output_value_real_final = F.relu(F.max_pool2d(self.conv1(input_value), 2)) # conv2d+mnist-treat + output_value_real_final = self.bn1(output_value_real_final) return output_value_real_final class SimpleWithLinear(torch.nn.Module): From d9bb401100da59f7d2b8b935ffc8cceafadae531 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:24:20 +0000 Subject: [PATCH 496/546] test --- test/test_test_mnist.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 3fd8085b0cd..a8f51a910e6 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -53,7 +53,7 @@ class SimpleWithLinearPure(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, stride=1, padding=2).to(xm.xla_device()) - self.bn1 = nn.BatchNorm2d(10) + self.bn1 = torch.nn.BatchNorm2d(10).to(xm.xla_device()) # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) # self.fc1 = nn.Linear(320, 50).to(xm.xla_device()) @@ -114,12 +114,12 @@ class MNIST(nn.Module): def __init__(self): super(MNIST, self).__init__() - self.conv1 = nn.Conv2d(1, 10, kernel_size=5) - self.bn1 = nn.BatchNorm2d(10) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.bn2 = nn.BatchNorm2d(20) - self.fc1 = nn.Linear(320, 50) - self.fc2 = nn.Linear(50, 10) + self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5) + self.bn1 = torch.nn.BatchNorm2d(10) + self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5) + self.bn2 = torch.nn.BatchNorm2d(20) + self.fc1 = torch.nn.Linear(320, 50) + self.fc2 = torch.nn.Linear(50, 10) def forward(self, x): x = F.relu(F.max_pool2d(self.conv1(x), 2)) From bebed0a9f5483d9570c9a85a45f88e2cb772d26b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:25:46 +0000 Subject: [PATCH 497/546] test --- torch_xla/csrc/init_python_bindings.cpp | 3 ++- torch_xla/experimental/fori_loop.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 949f96d36bc..543b3e93ce1 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -938,7 +938,8 @@ class PyLoweringContext { xla::XlaBuilder* local_builder = lowering_ctx.builder(); // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); - int64_t parameter_idx = 7; // conv2d + // int64_t parameter_idx = 7; // conv2d + int64_t parameter_idx = 8; // conv2d+mnist-treat // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 1424ce62853..3ba1cea26ee 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -171,9 +171,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - # cond_hlo_print = xb.get_computation_hlo(cond_computation) - # print("cond computation: !!!!!!!!!") - # print(cond_hlo_print) + cond_hlo_print = xb.get_computation_hlo(cond_computation) + print("cond computation: !!!!!!!!!") + print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) From 881dbcaa11e84b89fa8323bc12568772f08e32fb Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:26:40 +0000 Subject: [PATCH 498/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 543b3e93ce1..63c3abd5fbb 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -939,7 +939,7 @@ class PyLoweringContext { // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; // conv2d - int64_t parameter_idx = 8; // conv2d+mnist-treat + int64_t parameter_idx = 9; // conv2d+mnist-treat // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From bab82ee6d3fdb744ed7ada8f57cdfb9602ecf2c6 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:29:22 +0000 Subject: [PATCH 499/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 63c3abd5fbb..cc221772f02 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -939,7 +939,7 @@ class PyLoweringContext { // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; // conv2d - int64_t parameter_idx = 9; // conv2d+mnist-treat + int64_t parameter_idx = 10; // conv2d+mnist-treat // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From 76b4ba44cf1c5819b1da058eb690967bfa12afcf Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:33:34 +0000 Subject: [PATCH 500/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index cc221772f02..f89c4fca414 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -939,7 +939,7 @@ class PyLoweringContext { // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; // conv2d - int64_t parameter_idx = 10; // conv2d+mnist-treat + int64_t parameter_idx = 12; // conv2d+mnist-treat // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From 1e6cb5ba6cd543f41eda84a0ee41889ef4850b0c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:49:08 +0000 Subject: [PATCH 501/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index f89c4fca414..6ec0cb5c6fd 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -939,7 +939,7 @@ class PyLoweringContext { // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; // conv2d - int64_t parameter_idx = 12; // conv2d+mnist-treat + int64_t parameter_idx = 11; // conv2d+mnist-treat // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From a80e747444b90ca91c499e2e6554be1d94c0e443 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:54:43 +0000 Subject: [PATCH 502/546] test --- test/test_test_mnist.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index a8f51a910e6..7081f02fe43 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -268,10 +268,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) # import pdb; pdb.set_trace() - # for name, param in simple_with_linear.named_parameters(): - # print("arrive the loop") - # print("name: ", name) - # print("param: ", param) + for name, param in simple_with_linear.named_parameters(): + print("arrive the loop") + print("name: ", name) + print("param: ", param) # if name in ['bias']: # print(param.size()) From bb7f6822d7dd41330420c4d27cdd3887c7aa91e1 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 17:59:56 +0000 Subject: [PATCH 503/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 7081f02fe43..442e091ce06 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -23,7 +23,7 @@ from torch_xla.experimental.fori_loop import _xla_while_loop from torch._higher_order_ops.while_loop import while_loop -n_epochs = 3 +n_epochs = 1 # 3 batch_size_train = 8 # 64 batch_size_test = 10 # 1000 learning_rate = 0.01 From 9e257f6483560c0a3e33948eab2e7767fedb332b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 18:13:23 +0000 Subject: [PATCH 504/546] test --- test/test_test_mnist.py | 2 +- torch_xla/experimental/fori_loop.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 442e091ce06..7081f02fe43 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -23,7 +23,7 @@ from torch_xla.experimental.fori_loop import _xla_while_loop from torch._higher_order_ops.while_loop import while_loop -n_epochs = 1 # 3 +n_epochs = 3 batch_size_train = 8 # 64 batch_size_test = 10 # 1000 learning_rate = 0.01 diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3ba1cea26ee..0d2a5ee66cb 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -193,9 +193,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - # body_hlo_print = xb.get_computation_hlo(body_computation) - # print("body computation: !!!!!!!!!") - # print(body_hlo_print) + body_hlo_print = xb.get_computation_hlo(body_computation) + print("body computation: !!!!!!!!!") + print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs From 42da876805d81a12191a5f8d6b42bf8ba8d9fea6 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 19:54:00 +0000 Subject: [PATCH 505/546] test --- test/test_test_mnist.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 7081f02fe43..2785f0c4c59 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -255,6 +255,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()] for name, param in simple_with_linear.named_parameters(): + if name[:2]=='bn': + skip # skip bn res.insert(-1, param) # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) return tuple(res) From c31b3ecf0964a3b14285f183e02ac4cb4999f052 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 19:54:56 +0000 Subject: [PATCH 506/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 2785f0c4c59..d2fba0991c5 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -256,7 +256,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()] for name, param in simple_with_linear.named_parameters(): if name[:2]=='bn': - skip # skip bn + continue # skip bn res.insert(-1, param) # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) return tuple(res) From cf33820d0170f969b8c4ea7446ea06dc128b1f3b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:27:17 +0000 Subject: [PATCH 507/546] test --- test/test_test_mnist.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index d2fba0991c5..13ff3c50f36 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -256,7 +256,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()] for name, param in simple_with_linear.named_parameters(): if name[:2]=='bn': - continue # skip bn + res.insert(-1, param) # dumpicate # continue # skip bn res.insert(-1, param) # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) return tuple(res) @@ -297,8 +297,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): additional_inputs = [] for name, param in simple_with_linear.named_parameters(): - # additional_inputs.insert(-1, param) - additional_inputs.append(param) + if name[:2]=='bn': + additional_inputs.append(param) # dumplicate + # additional_inputs.insert(-1, param) + additional_inputs.append(param) # print("in mnist additional_inputs: ", additional_inputs) ### linear 10*20 + 20*30 From 65f2edb4b7739c8ee458d3c5243e2305562cdb32 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:30:23 +0000 Subject: [PATCH 508/546] test --- test/test_test_mnist.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 13ff3c50f36..200b05f3fd2 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -306,7 +306,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): ### linear 10*20 + 20*30 # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, w2_, b2_, output_value_real__, = _xla_while_loop( ### conv2d - upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, output_value_real__, = _xla_while_loop( + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, output_value_real__, = _xla_while_loop( + #### conv1+bn1 + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, output_value_real__, = _xla_while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) # (upper, lower, one_value, init_val, l_in_0, output_value), ()) From 2154f650580b5f914331f213310976f3721a75e4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:38:52 +0000 Subject: [PATCH 509/546] test --- test/test_test_mnist.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 200b05f3fd2..657d7ff426f 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -254,10 +254,20 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # weight = simple_with_linear.weight # not be used actually, initialized as placeholder xlacomputation requirement # bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()] + bn_list = [] + bn_flag = False for name, param in simple_with_linear.named_parameters(): if name[:2]=='bn': - res.insert(-1, param) # dumpicate # continue # skip bn + bn_flag = True + bn_list.insert(-1, param) # dumpicate # continue # skip bn + else: + bn_flag = False + res.insert(-1, param) + + if not bn_flag and (len(bn_list) !=0): # False + res = res[:-1] + bn_list + res[-1] + bn_list = [] # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) return tuple(res) # return (upper.clone(), new_lower.clone(), one_value.clone(), torch.add( @@ -296,12 +306,24 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): additional_inputs = [] + bn_list = [] + bn_flag = False for name, param in simple_with_linear.named_parameters(): + # if name[:2]=='bn': + # additional_inputs.append(param) # dumplicate if name[:2]=='bn': - additional_inputs.append(param) # dumplicate + bn_flag = True + bn_list.insert(-1, param) # dumpicate # continue # skip bn + else: + bn_flag = False + # additional_inputs.insert(-1, param) additional_inputs.append(param) + if not bn_flag and (len(bn_list) !=0): # False + additional_inputs =additional_inputs + bn_list + bn_list = [] + # print("in mnist additional_inputs: ", additional_inputs) ### linear 10*20 + 20*30 # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, w2_, b2_, output_value_real__, = _xla_while_loop( From 214ac1fb8a9a13dfd9f77faec4fd28759f6a9896 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:40:08 +0000 Subject: [PATCH 510/546] test --- test/test_test_mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 657d7ff426f..f5354ef439f 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -312,6 +312,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # if name[:2]=='bn': # additional_inputs.append(param) # dumplicate if name[:2]=='bn': + print("catch: ", name) bn_flag = True bn_list.insert(-1, param) # dumpicate # continue # skip bn else: From ddb69ee55381d68b540b40c1140f6c532f7c72d6 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:41:04 +0000 Subject: [PATCH 511/546] test --- test/test_test_mnist.py | 8 ++++---- torch_xla/experimental/fori_loop.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index f5354ef439f..6c58a6f931d 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -280,10 +280,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) # import pdb; pdb.set_trace() - for name, param in simple_with_linear.named_parameters(): - print("arrive the loop") - print("name: ", name) - print("param: ", param) + # for name, param in simple_with_linear.named_parameters(): + # print("arrive the loop") + # print("name: ", name) + # print("param: ", param) # if name in ['bias']: # print(param.size()) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0d2a5ee66cb..1424ce62853 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -171,9 +171,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - cond_hlo_print = xb.get_computation_hlo(cond_computation) - print("cond computation: !!!!!!!!!") - print(cond_hlo_print) + # cond_hlo_print = xb.get_computation_hlo(cond_computation) + # print("cond computation: !!!!!!!!!") + # print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) @@ -193,9 +193,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - body_hlo_print = xb.get_computation_hlo(body_computation) - print("body computation: !!!!!!!!!") - print(body_hlo_print) + # body_hlo_print = xb.get_computation_hlo(body_computation) + # print("body computation: !!!!!!!!!") + # print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs From 75dad47d02a272e6e8e9acef338f4956fa03205f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:43:16 +0000 Subject: [PATCH 512/546] test --- test/test_test_mnist.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 6c58a6f931d..1ae6af17615 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -265,7 +265,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): res.insert(-1, param) - if not bn_flag and (len(bn_list) !=0): # False + if (not bn_flag) and (len(bn_list) !=0): # False res = res[:-1] + bn_list + res[-1] bn_list = [] # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) @@ -315,14 +315,16 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): print("catch: ", name) bn_flag = True bn_list.insert(-1, param) # dumpicate # continue # skip bn + print("newest bn_list: ", bn_list) else: bn_flag = False # additional_inputs.insert(-1, param) additional_inputs.append(param) - if not bn_flag and (len(bn_list) !=0): # False + if (not bn_flag) and (len(bn_list) !=0): # False additional_inputs =additional_inputs + bn_list + print("added bn_list: ", bn_list) bn_list = [] # print("in mnist additional_inputs: ", additional_inputs) From 7225a3c46d74cfadfc4d8b07cb10d0071bcaf6ba Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:45:11 +0000 Subject: [PATCH 513/546] test --- test/test_test_mnist.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 1ae6af17615..7c80d1929a1 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -327,6 +327,12 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): print("added bn_list: ", bn_list) bn_list = [] + ### !!! add still exist bn_list if the last additional_inputs is bn- pre + if flag and (len(bn_list) !=0): + additional_inputs =additional_inputs + bn_list + print("added bn_list: ", bn_list) + bn_list = [] + # print("in mnist additional_inputs: ", additional_inputs) ### linear 10*20 + 20*30 # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, w2_, b2_, output_value_real__, = _xla_while_loop( From 77c0e42c542288e8ecac959645da5cff02019b44 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:45:41 +0000 Subject: [PATCH 514/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 7c80d1929a1..4a75b9dd0b9 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -328,7 +328,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): bn_list = [] ### !!! add still exist bn_list if the last additional_inputs is bn- pre - if flag and (len(bn_list) !=0): + if bn_flag and (len(bn_list) !=0): additional_inputs =additional_inputs + bn_list print("added bn_list: ", bn_list) bn_list = [] From 5454e205f41fac7208db207f6e011f8d0017fa36 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:48:54 +0000 Subject: [PATCH 515/546] test --- test/test_test_mnist.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 4a75b9dd0b9..6fb4a6ec71f 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -269,6 +269,13 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): res = res[:-1] + bn_list + res[-1] bn_list = [] # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) + + ### !!! add still exist bn_list if the last additional_inputs is bn- pre + if bn_flag and (len(bn_list) !=0): + res = res[:-1] + bn_list + res[-1] + bn_list = [] + bn_flag = False + return tuple(res) # return (upper.clone(), new_lower.clone(), one_value.clone(), torch.add( # one_value, x), input_value.clone(), output_value.clone(), simple_with_linear.linear.weight) # bias.clone(), weight.clone(), output_value.clone() From 843758f771ec95b13859f04b05797ad1fc8eb8a8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:50:34 +0000 Subject: [PATCH 516/546] test --- test/test_test_mnist.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 6fb4a6ec71f..76682c4bdfe 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -266,13 +266,15 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): res.insert(-1, param) if (not bn_flag) and (len(bn_list) !=0): # False - res = res[:-1] + bn_list + res[-1] + res = res[:-1] + bn_list # + res[-1] + res.append(res[-1]) bn_list = [] # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) ### !!! add still exist bn_list if the last additional_inputs is bn- pre if bn_flag and (len(bn_list) !=0): - res = res[:-1] + bn_list + res[-1] + res = res[:-1] + bn_list # + res[-1] + res.append(res[-1]) bn_list = [] bn_flag = False @@ -339,6 +341,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): additional_inputs =additional_inputs + bn_list print("added bn_list: ", bn_list) bn_list = [] + bn_flag = False # print("in mnist additional_inputs: ", additional_inputs) ### linear 10*20 + 20*30 From 558f80fad4d1b4c968752da3eeb7bcfa880bb3be Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:51:55 +0000 Subject: [PATCH 517/546] test --- test/test_test_mnist.py | 8 ++++---- torch_xla/experimental/fori_loop.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 76682c4bdfe..74a1a195c71 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -321,10 +321,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # if name[:2]=='bn': # additional_inputs.append(param) # dumplicate if name[:2]=='bn': - print("catch: ", name) + # print("catch: ", name) bn_flag = True bn_list.insert(-1, param) # dumpicate # continue # skip bn - print("newest bn_list: ", bn_list) + # print("newest bn_list: ", bn_list) else: bn_flag = False @@ -333,13 +333,13 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): if (not bn_flag) and (len(bn_list) !=0): # False additional_inputs =additional_inputs + bn_list - print("added bn_list: ", bn_list) + # print("added bn_list: ", bn_list) bn_list = [] ### !!! add still exist bn_list if the last additional_inputs is bn- pre if bn_flag and (len(bn_list) !=0): additional_inputs =additional_inputs + bn_list - print("added bn_list: ", bn_list) + # print("added bn_list: ", bn_list) bn_list = [] bn_flag = False diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 1424ce62853..133d4e06285 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -172,7 +172,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) # cond_hlo_print = xb.get_computation_hlo(cond_computation) - # print("cond computation: !!!!!!!!!") + print("cond computation: !!!!!!!!!") # print(cond_hlo_print) # generate body_fn xlacomputation @@ -194,7 +194,7 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) # body_hlo_print = xb.get_computation_hlo(body_computation) - # print("body computation: !!!!!!!!!") + print("body computation: !!!!!!!!!") # print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while From af06fa69d15d224f974b07303310fe72e1f3e4bc Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:53:17 +0000 Subject: [PATCH 518/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- torch_xla/experimental/fori_loop.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 6ec0cb5c6fd..8796fa8b0a5 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -939,7 +939,7 @@ class PyLoweringContext { // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; // conv2d - int64_t parameter_idx = 11; // conv2d+mnist-treat + int64_t parameter_idx = 10; // conv2d+mnist-treat // conv1 + bn1 // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 133d4e06285..0d2a5ee66cb 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -171,9 +171,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - # cond_hlo_print = xb.get_computation_hlo(cond_computation) + cond_hlo_print = xb.get_computation_hlo(cond_computation) print("cond computation: !!!!!!!!!") - # print(cond_hlo_print) + print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) @@ -193,9 +193,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - # body_hlo_print = xb.get_computation_hlo(body_computation) + body_hlo_print = xb.get_computation_hlo(body_computation) print("body computation: !!!!!!!!!") - # print(body_hlo_print) + print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs From e1b6c8a2d1cf04cb9a3b83132feb042356f5eef9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 20:56:38 +0000 Subject: [PATCH 519/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 8796fa8b0a5..71b12f8d3b9 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -939,7 +939,7 @@ class PyLoweringContext { // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; // conv2d - int64_t parameter_idx = 10; // conv2d+mnist-treat // conv1 + bn1 + int64_t parameter_idx = 9; // conv2d+mnist-treat // conv1 + bn1 // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From 8f5572fbf1c122162fe1d11c236738d1e6f86111 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:03:26 +0000 Subject: [PATCH 520/546] test --- test/test_test_mnist.py | 6 ++++-- torch_xla/csrc/init_python_bindings.cpp | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 74a1a195c71..7b26f00deb3 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -266,15 +266,17 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): res.insert(-1, param) if (not bn_flag) and (len(bn_list) !=0): # False + output_value = res[-1] res = res[:-1] + bn_list # + res[-1] - res.append(res[-1]) + res.append(output_value) bn_list = [] # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) ### !!! add still exist bn_list if the last additional_inputs is bn- pre if bn_flag and (len(bn_list) !=0): + output_value = res[-1] res = res[:-1] + bn_list # + res[-1] - res.append(res[-1]) + res.append(output_value) bn_list = [] bn_flag = False diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 71b12f8d3b9..270f9790654 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -939,7 +939,7 @@ class PyLoweringContext { // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; // conv2d - int64_t parameter_idx = 9; // conv2d+mnist-treat // conv1 + bn1 + int64_t parameter_idx = 11; // conv2d+mnist-treat // conv1 + bn1 // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From 00f499068c629a8ec66b4658d25521d4ab439a08 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:08:05 +0000 Subject: [PATCH 521/546] test --- test/test_test_mnist.py | 10 ++++++---- torch_xla/experimental/fori_loop.py | 12 ++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 7b26f00deb3..cb7efdfa1a1 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -54,19 +54,21 @@ def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, stride=1, padding=2).to(xm.xla_device()) self.bn1 = torch.nn.BatchNorm2d(10).to(xm.xla_device()) + self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5).to(xm.xla_device()) # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) # self.fc1 = nn.Linear(320, 50).to(xm.xla_device()) # self.fc2 = nn.Linear(50, 10).to(xm.xla_device()) # def forward(self, upper, lower, one_value, x, input_value, output_value): - def forward(self, input_value): + def forward(self, x): # output_value_real = self.linear(input_value) # output_value_real_final = self.linear2(output_value_real) # output_value_real_final = self.conv1(input_value) # conv2d - output_value_real_final = F.relu(F.max_pool2d(self.conv1(input_value), 2)) # conv2d+mnist-treat - output_value_real_final = self.bn1(output_value_real_final) - return output_value_real_final + x = F.relu(F.max_pool2d(self.conv1(x), 2)) # conv2d+mnist-treat + x = self.bn1(x) + x = F.relu(F.max_pool2d(self.conv2(x), 2)) + return x class SimpleWithLinear(torch.nn.Module): def __init__(self): diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 0d2a5ee66cb..1424ce62853 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -171,9 +171,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - cond_hlo_print = xb.get_computation_hlo(cond_computation) - print("cond computation: !!!!!!!!!") - print(cond_hlo_print) + # cond_hlo_print = xb.get_computation_hlo(cond_computation) + # print("cond computation: !!!!!!!!!") + # print(cond_hlo_print) # generate body_fn xlacomputation body_result = body_fn(*fake_carried_inputs) @@ -193,9 +193,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None): body_hlo = body_ctx.hlo() body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo) - body_hlo_print = xb.get_computation_hlo(body_computation) - print("body computation: !!!!!!!!!") - print(body_hlo_print) + # body_hlo_print = xb.get_computation_hlo(body_computation) + # print("body computation: !!!!!!!!!") + # print(body_hlo_print) # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while total_inputs = carried_inputs + additional_inputs From e10961990e12203c470321ea181e754a876ef770 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:09:20 +0000 Subject: [PATCH 522/546] test --- torch_xla/csrc/init_python_bindings.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 270f9790654..b83323372b6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -939,7 +939,8 @@ class PyLoweringContext { // TODO(@manfei): treat hard code parameter_idx value // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; // conv2d - int64_t parameter_idx = 11; // conv2d+mnist-treat // conv1 + bn1 + // int64_t parameter_idx = 11; // conv2d+mnist-treat // conv1 + bn1 + int64_t parameter_idx = 13; // conv1 + bn1 + conv2 // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From ea1ce95d490edc53b703485ceed69a22b1da1a9c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:13:19 +0000 Subject: [PATCH 523/546] test --- test/test_test_mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index cb7efdfa1a1..e55a586624b 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -315,7 +315,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # out = c(x) # print(out.nelement()) # output_value = torch.zeros([16,10,28,28], dtype=torch.float32, device=device) # conv2d - output_value = torch.zeros([16,10,14,14], dtype=torch.float32, device=device)# conv2d+mnist-treat + # output_value = torch.zeros([16,10,14,14], dtype=torch.float32, device=device) # conv2d+mnist-treat # conv1 + bn1 + output_value = torch.zeros([16,20,5,5], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 additional_inputs = [] From abafd6ab9a8e23c808fbc4580da0b02ae208ab80 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:16:05 +0000 Subject: [PATCH 524/546] test --- test/test_test_mnist.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index e55a586624b..0470ce8af29 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -293,10 +293,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) # import pdb; pdb.set_trace() - # for name, param in simple_with_linear.named_parameters(): - # print("arrive the loop") - # print("name: ", name) - # print("param: ", param) + for name, param in simple_with_linear.named_parameters(): + # print("arrive the loop") + print("name: ", name) + print("param: ", param) # if name in ['bias']: # print(param.size()) From 95227f041329a3f2b37b62257bcecef49f7674b3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:39:33 +0000 Subject: [PATCH 525/546] test --- test/test_test_mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 0470ce8af29..dd615b7705a 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -348,6 +348,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): bn_list = [] bn_flag = False + print("final additional_inputs: ", additional_inputs) # print("in mnist additional_inputs: ", additional_inputs) ### linear 10*20 + 20*30 # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, w2_, b2_, output_value_real__, = _xla_while_loop( From 06cb773c607fa7059f295c65d2037ba60bb4ab0a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:49:42 +0000 Subject: [PATCH 526/546] test --- test/test_test_mnist.py | 50 ++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index dd615b7705a..54036c86bca 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -257,30 +257,32 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()] bn_list = [] - bn_flag = False + # bn_flag = False for name, param in simple_with_linear.named_parameters(): if name[:2]=='bn': - bn_flag = True + # bn_flag = True bn_list.insert(-1, param) # dumpicate # continue # skip bn - else: - bn_flag = False + # else: + # bn_flag = False res.insert(-1, param) - if (not bn_flag) and (len(bn_list) !=0): # False - output_value = res[-1] - res = res[:-1] + bn_list # + res[-1] - res.append(output_value) - bn_list = [] - # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) + # if (not bn_flag) and (len(bn_list) !=0): # False + # output_value = res[-1] + # res = res[:-1] + bn_list # + res[-1] + # res.append(output_value) + # bn_list = [] + # # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) ### !!! add still exist bn_list if the last additional_inputs is bn- pre - if bn_flag and (len(bn_list) !=0): + # if bn_flag and (len(bn_list) !=0): + ### !!! add at the tile + if len(bn_list) !=0: output_value = res[-1] res = res[:-1] + bn_list # + res[-1] res.append(output_value) bn_list = [] - bn_flag = False + # bn_flag = False return tuple(res) # return (upper.clone(), new_lower.clone(), one_value.clone(), torch.add( @@ -321,32 +323,34 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): additional_inputs = [] bn_list = [] - bn_flag = False + # bn_flag = False for name, param in simple_with_linear.named_parameters(): # if name[:2]=='bn': # additional_inputs.append(param) # dumplicate if name[:2]=='bn': # print("catch: ", name) - bn_flag = True + # bn_flag = True bn_list.insert(-1, param) # dumpicate # continue # skip bn # print("newest bn_list: ", bn_list) - else: - bn_flag = False + # else: + # bn_flag = False # additional_inputs.insert(-1, param) additional_inputs.append(param) - if (not bn_flag) and (len(bn_list) !=0): # False - additional_inputs =additional_inputs + bn_list - # print("added bn_list: ", bn_list) - bn_list = [] + # if (not bn_flag) and (len(bn_list) !=0): # False + # additional_inputs =additional_inputs + bn_list + # # print("added bn_list: ", bn_list) + # bn_list = [] ### !!! add still exist bn_list if the last additional_inputs is bn- pre - if bn_flag and (len(bn_list) !=0): - additional_inputs =additional_inputs + bn_list + # if bn_flag and (len(bn_list) !=0): + ### !!! add duplicated bn argus as the tile of the list + if len(bn_list) !=0: + additional_inputs = additional_inputs + bn_list # print("added bn_list: ", bn_list) bn_list = [] - bn_flag = False + # bn_flag = False print("final additional_inputs: ", additional_inputs) # print("in mnist additional_inputs: ", additional_inputs) From f55e4b544bea20891916e20c8728ab563d534a9c Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:50:44 +0000 Subject: [PATCH 527/546] test --- test/test_test_mnist.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 54036c86bca..effdd76fc68 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -359,7 +359,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): ### conv2d # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, output_value_real__, = _xla_while_loop( #### conv1+bn1 - upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, output_value_real__, = _xla_while_loop( + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, output_value_real__, = _xla_while_loop( + ##### conv1 + bn1 + conv2 + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, output_value_real__, = _xla_while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) # (upper, lower, one_value, init_val, l_in_0, output_value), ()) From 47d06eb9a17c3052d28282012a70645d8a188d1a Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:51:15 +0000 Subject: [PATCH 528/546] test --- test/test_test_mnist.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index effdd76fc68..ba1d3fdf4e4 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -295,10 +295,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) # import pdb; pdb.set_trace() - for name, param in simple_with_linear.named_parameters(): - # print("arrive the loop") - print("name: ", name) - print("param: ", param) + # for name, param in simple_with_linear.named_parameters(): + # # print("arrive the loop") + # print("name: ", name) + # print("param: ", param) # if name in ['bias']: # print(param.size()) @@ -352,7 +352,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): bn_list = [] # bn_flag = False - print("final additional_inputs: ", additional_inputs) + # print("final additional_inputs: ", additional_inputs) # print("in mnist additional_inputs: ", additional_inputs) ### linear 10*20 + 20*30 # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, w2_, b2_, output_value_real__, = _xla_while_loop( From 68892b44a0443daea8d14b03f405c18cf94cfe9b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:52:07 +0000 Subject: [PATCH 529/546] test --- test/test_test_mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index ba1d3fdf4e4..505249ebca7 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -366,6 +366,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) # (upper, lower, one_value, init_val, l_in_0, output_value), ()) print("finish newnewnew_test") + print("torch_add_res__: run times: ", torch_add_res__) print("actual res: ", output_value_real__[0][0][0]) expected_ = simple_with_linear(l_in_0) print("expected res: ", expected_[0][0][0]) From fa12ead411657ceba3e6265290443292753fcc58 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:53:30 +0000 Subject: [PATCH 530/546] test --- test/test_test_mnist.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 505249ebca7..8512b2b6707 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -55,6 +55,7 @@ def __init__(self): self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, stride=1, padding=2).to(xm.xla_device()) self.bn1 = torch.nn.BatchNorm2d(10).to(xm.xla_device()) self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5).to(xm.xla_device()) + self.bn2 = torch.nn.BatchNorm2d(20).to(xm.xla_device()) # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) # self.fc1 = nn.Linear(320, 50).to(xm.xla_device()) @@ -68,6 +69,7 @@ def forward(self, x): x = F.relu(F.max_pool2d(self.conv1(x), 2)) # conv2d+mnist-treat x = self.bn1(x) x = F.relu(F.max_pool2d(self.conv2(x), 2)) + x = self.bn2(x) return x class SimpleWithLinear(torch.nn.Module): From f4b79fca6243b212e813139f2a07a2f77e002602 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:54:35 +0000 Subject: [PATCH 531/546] test --- torch_xla/csrc/init_python_bindings.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b83323372b6..c149a8b1fa8 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -940,7 +940,8 @@ class PyLoweringContext { // int64_t parameter_idx = local_builder->GetProgramShape()->parameters_size(); // local_builder->parameter_numbers_; // GetProgramShape(); // int64_t parameter_idx = 7; // conv2d // int64_t parameter_idx = 11; // conv2d+mnist-treat // conv1 + bn1 - int64_t parameter_idx = 13; // conv1 + bn1 + conv2 + // int64_t parameter_idx = 13; // conv1 + bn1 + conv2 + int64_t parameter_idx = 15; // conv1 + bn1 + conv2 + bn2 // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From 532e8c0697dffb16429d8aaee5d4380a2c4719a8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 21:57:11 +0000 Subject: [PATCH 532/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c149a8b1fa8..c1e9e7cbaca 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -941,7 +941,7 @@ class PyLoweringContext { // int64_t parameter_idx = 7; // conv2d // int64_t parameter_idx = 11; // conv2d+mnist-treat // conv1 + bn1 // int64_t parameter_idx = 13; // conv1 + bn1 + conv2 - int64_t parameter_idx = 15; // conv1 + bn1 + conv2 + bn2 + int64_t parameter_idx = 17; // conv1 + bn1 + conv2 + bn2 // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From 122270563625a0bdaa9125217f93009dc83f0eed Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:02:47 +0000 Subject: [PATCH 533/546] test --- test/test_test_mnist.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 8512b2b6707..8aafddacfd1 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -297,10 +297,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) # import pdb; pdb.set_trace() - # for name, param in simple_with_linear.named_parameters(): - # # print("arrive the loop") - # print("name: ", name) - # print("param: ", param) + for name, param in simple_with_linear.named_parameters(): + # print("arrive the loop") + print("name: ", name) + print("param: ", param) # if name in ['bias']: # print(param.size()) @@ -354,7 +354,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): bn_list = [] # bn_flag = False - # print("final additional_inputs: ", additional_inputs) + print("final additional_inputs: ", additional_inputs) + # print("in mnist additional_inputs: ", additional_inputs) ### linear 10*20 + 20*30 # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, w2_, b2_, output_value_real__, = _xla_while_loop( From 814f561a7017b2999d489f5b2087985d5b6207f9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:17:49 +0000 Subject: [PATCH 534/546] test --- test/test_test_mnist.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 8aafddacfd1..731e41de55e 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -263,7 +263,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): for name, param in simple_with_linear.named_parameters(): if name[:2]=='bn': # bn_flag = True - bn_list.insert(-1, param) # dumpicate # continue # skip bn + # bn_list.insert(-1, param) # dumpicate # continue # skip bn + bn_list.append(param) # else: # bn_flag = False @@ -281,7 +282,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): ### !!! add at the tile if len(bn_list) !=0: output_value = res[-1] - res = res[:-1] + bn_list # + res[-1] + # res = res[:-1] + bn_list # + res[-1] + res = res[:-1] + bn_list.reverse() res.append(output_value) bn_list = [] # bn_flag = False @@ -332,7 +334,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): if name[:2]=='bn': # print("catch: ", name) # bn_flag = True - bn_list.insert(-1, param) # dumpicate # continue # skip bn + # bn_list.insert(-1, param) # dumpicate # continue # skip bn + bn_list.append(param) # print("newest bn_list: ", bn_list) # else: # bn_flag = False @@ -349,7 +352,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # if bn_flag and (len(bn_list) !=0): ### !!! add duplicated bn argus as the tile of the list if len(bn_list) !=0: - additional_inputs = additional_inputs + bn_list + # additional_inputs = additional_inputs + bn_list + additional_inputs = additional_inputs + bn_list.reverse() # print("added bn_list: ", bn_list) bn_list = [] # bn_flag = False From 84f27916143981352bc3da953f96beac67db0c90 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:22:05 +0000 Subject: [PATCH 535/546] test --- test/test_test_mnist.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 731e41de55e..a3b69782343 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -283,7 +283,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): if len(bn_list) !=0: output_value = res[-1] # res = res[:-1] + bn_list # + res[-1] - res = res[:-1] + bn_list.reverse() + bn_list.reverse() + res = res[:-1] + bn_list res.append(output_value) bn_list = [] # bn_flag = False @@ -353,7 +354,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): ### !!! add duplicated bn argus as the tile of the list if len(bn_list) !=0: # additional_inputs = additional_inputs + bn_list - additional_inputs = additional_inputs + bn_list.reverse() + bn_list.reverse() + additional_inputs = additional_inputs + bn_list # print("added bn_list: ", bn_list) bn_list = [] # bn_flag = False From 113dacce3b56050cb35fab8f45235a0db4ab25a9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:23:48 +0000 Subject: [PATCH 536/546] test --- test/test_test_mnist.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index a3b69782343..0a2394e839c 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -354,7 +354,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): ### !!! add duplicated bn argus as the tile of the list if len(bn_list) !=0: # additional_inputs = additional_inputs + bn_list - bn_list.reverse() + bn_list.reverse() ### !!! reverse list for bn duplicate lists additional_inputs = additional_inputs + bn_list # print("added bn_list: ", bn_list) bn_list = [] @@ -370,7 +370,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): #### conv1+bn1 # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, output_value_real__, = _xla_while_loop( ##### conv1 + bn1 + conv2 - upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, output_value_real__, = _xla_while_loop( + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, output_value_real__, = _xla_while_loop( + ##### conv1 + bn1 + conv2 + bn2 + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, output_value_real__, = _xla_while_loop( + cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) # (upper, lower, one_value, init_val, l_in_0, output_value), ()) From 94e75993e647003084e4ee0e6033b678d8e1b843 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:24:19 +0000 Subject: [PATCH 537/546] test --- test/test_test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 0a2394e839c..66eae803671 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -372,7 +372,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): ##### conv1 + bn1 + conv2 # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, output_value_real__, = _xla_while_loop( ##### conv1 + bn1 + conv2 + bn2 - upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, output_value_real__, = _xla_while_loop( + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, output_value_real__, = _xla_while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) From d38b49436dcd18f72af8ab478d521b83885cc244 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:25:37 +0000 Subject: [PATCH 538/546] test --- test/test_test_mnist.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 66eae803671..60b041ad901 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -300,10 +300,10 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # print("simple_with_linear.named_parameters(): ", simple_with_linear.named_parameters()) # import pdb; pdb.set_trace() - for name, param in simple_with_linear.named_parameters(): - # print("arrive the loop") - print("name: ", name) - print("param: ", param) + # for name, param in simple_with_linear.named_parameters(): + # # print("arrive the loop") + # print("name: ", name) + # print("param: ", param) # if name in ['bias']: # print(param.size()) @@ -360,7 +360,7 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): bn_list = [] # bn_flag = False - print("final additional_inputs: ", additional_inputs) + # print("final additional_inputs: ", additional_inputs) # print("in mnist additional_inputs: ", additional_inputs) ### linear 10*20 + 20*30 From 180d7c5531661888e80be23d50f30728c0df8837 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:26:59 +0000 Subject: [PATCH 539/546] test --- test/test_test_mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 60b041ad901..93e15f034cd 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -70,6 +70,7 @@ def forward(self, x): x = self.bn1(x) x = F.relu(F.max_pool2d(self.conv2(x), 2)) x = self.bn2(x) + x = torch.flatten(x, 1) return x class SimpleWithLinear(torch.nn.Module): From a904700880c0565210e3a3f2f6b617c27e8b20b8 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:28:56 +0000 Subject: [PATCH 540/546] test --- test/test_test_mnist.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 93e15f034cd..667b056a838 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -56,6 +56,7 @@ def __init__(self): self.bn1 = torch.nn.BatchNorm2d(10).to(xm.xla_device()) self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5).to(xm.xla_device()) self.bn2 = torch.nn.BatchNorm2d(20).to(xm.xla_device()) + # self.fc1 = torch.nn.Linear(320, 50).to(xm.xla_device()) # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) # self.fc1 = nn.Linear(320, 50).to(xm.xla_device()) @@ -324,7 +325,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # print(out.nelement()) # output_value = torch.zeros([16,10,28,28], dtype=torch.float32, device=device) # conv2d # output_value = torch.zeros([16,10,14,14], dtype=torch.float32, device=device) # conv2d+mnist-treat # conv1 + bn1 - output_value = torch.zeros([16,20,5,5], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + # output_value = torch.zeros([16,20,5,5], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + output_value = torch.zeros([16,500], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 additional_inputs = [] From 1831d2c46d9b8d4323f0f5d19ff34786da138b12 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:29:36 +0000 Subject: [PATCH 541/546] test --- test/test_test_mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 667b056a838..06687ad8447 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -382,9 +382,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # (upper, lower, one_value, init_val, l_in_0, output_value), ()) print("finish newnewnew_test") print("torch_add_res__: run times: ", torch_add_res__) - print("actual res: ", output_value_real__[0][0][0]) + print("actual res: ", output_value_real__[0][0]) expected_ = simple_with_linear(l_in_0) - print("expected res: ", expected_[0][0][0]) + print("expected res: ", expected_[0][0]) # run test model def test_mnist(): From 2fad582bc289ae5ebf1d07e328826ab5238438f0 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:30:48 +0000 Subject: [PATCH 542/546] test --- test/test_test_mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 06687ad8447..8a131f97119 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -56,7 +56,7 @@ def __init__(self): self.bn1 = torch.nn.BatchNorm2d(10).to(xm.xla_device()) self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5).to(xm.xla_device()) self.bn2 = torch.nn.BatchNorm2d(20).to(xm.xla_device()) - # self.fc1 = torch.nn.Linear(320, 50).to(xm.xla_device()) + self.fc1 = torch.nn.Linear(320, 50).to(xm.xla_device()) # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) # self.fc1 = nn.Linear(320, 50).to(xm.xla_device()) @@ -72,6 +72,7 @@ def forward(self, x): x = F.relu(F.max_pool2d(self.conv2(x), 2)) x = self.bn2(x) x = torch.flatten(x, 1) + x = F.relu(self.fc1(x)) return x class SimpleWithLinear(torch.nn.Module): From 6e7ae0995bfc9d2272563a4b53fdb592ed29e31e Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:31:39 +0000 Subject: [PATCH 543/546] test --- test/test_test_mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 8a131f97119..a171942cf3d 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -56,7 +56,8 @@ def __init__(self): self.bn1 = torch.nn.BatchNorm2d(10).to(xm.xla_device()) self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5).to(xm.xla_device()) self.bn2 = torch.nn.BatchNorm2d(20).to(xm.xla_device()) - self.fc1 = torch.nn.Linear(320, 50).to(xm.xla_device()) + self.fc1 = torch.nn.Linear(500, 50).to(xm.xla_device()) + # self.fc1 = torch.nn.Linear(320, 50).to(xm.xla_device()) # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) # self.fc1 = nn.Linear(320, 50).to(xm.xla_device()) From 5004a745a818686f66a2faefcbea165178c3ff44 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Tue, 30 Apr 2024 22:33:09 +0000 Subject: [PATCH 544/546] test --- torch_xla/csrc/init_python_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c1e9e7cbaca..ddd7045669b 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -941,7 +941,7 @@ class PyLoweringContext { // int64_t parameter_idx = 7; // conv2d // int64_t parameter_idx = 11; // conv2d+mnist-treat // conv1 + bn1 // int64_t parameter_idx = 13; // conv1 + bn1 + conv2 - int64_t parameter_idx = 17; // conv1 + bn1 + conv2 + bn2 + int64_t parameter_idx = 19; // conv1 + bn1 + conv2 + bn2 // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From 5fb1769182183dc20edb85025a461e0b131caa7d Mon Sep 17 00:00:00 2001 From: root Date: Wed, 1 May 2024 18:41:45 +0000 Subject: [PATCH 545/546] mnist --- test/test_test_mnist.py | 16 ++++++++++++---- torch_xla/csrc/init_python_bindings.cpp | 3 ++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index a171942cf3d..c3da2020635 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -57,6 +57,7 @@ def __init__(self): self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5).to(xm.xla_device()) self.bn2 = torch.nn.BatchNorm2d(20).to(xm.xla_device()) self.fc1 = torch.nn.Linear(500, 50).to(xm.xla_device()) + self.fc2 = torch.nn.Linear(50, 10).to(xm.xla_device()) # self.fc1 = torch.nn.Linear(320, 50).to(xm.xla_device()) # self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) # self.linear2 = torch.nn.Linear(20, 30).to(xm.xla_device()) @@ -74,7 +75,9 @@ def forward(self, x): x = self.bn2(x) x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) - return x + x = self.fc2(x) + return F.log_softmax(x, dim=1) + # return x class SimpleWithLinear(torch.nn.Module): def __init__(self): @@ -328,7 +331,9 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # output_value = torch.zeros([16,10,28,28], dtype=torch.float32, device=device) # conv2d # output_value = torch.zeros([16,10,14,14], dtype=torch.float32, device=device) # conv2d+mnist-treat # conv1 + bn1 # output_value = torch.zeros([16,20,5,5], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 - output_value = torch.zeros([16,500], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + # output_value = torch.zeros([16,500], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + # output_value = torch.zeros([16,50], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + fc1 + output_value = torch.zeros([16,10], dtype=torch.float32, device=device) # conv1 + bn1 + conv2 + bn2 + flatten1 + fc1 additional_inputs = [] @@ -377,8 +382,11 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): ##### conv1 + bn1 + conv2 # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, output_value_real__, = _xla_while_loop( ##### conv1 + bn1 + conv2 + bn2 - upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, output_value_real__, = _xla_while_loop( - + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, output_value_real__, = _xla_while_loop( + ##### conv1 + bn1 + conv2 + bn2 + fc1 + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, p7, p8, output_value_real__, = _xla_while_loop( + ##### conv1 + bn1 + conv2 + bn2 + fc1 + fc2 + softmax + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, bw1, bw11, bb1, bb11, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, output_value_real__, = _xla_while_loop( cond_fn, body_fn, (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) # (upper, lower, one_value, init_val, l_in_0, output_value), ()) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ddd7045669b..a1adab814fb 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -941,7 +941,8 @@ class PyLoweringContext { // int64_t parameter_idx = 7; // conv2d // int64_t parameter_idx = 11; // conv2d+mnist-treat // conv1 + bn1 // int64_t parameter_idx = 13; // conv1 + bn1 + conv2 - int64_t parameter_idx = 19; // conv1 + bn1 + conv2 + bn2 + // int64_t parameter_idx = 19; // conv1 + bn1 + conv2 + bn2 + int64_t parameter_idx = 21; // conv1 + bn1 + conv2 + bn2 // int64_t parameter_idx = 9; // linear // int64_t parameter_idx = tensors.size(); for (auto& additional_input_tensor : additional_inputs_list) { From 94025fa13e47be9d1078434d3ecfae60af1ffb24 Mon Sep 17 00:00:00 2001 From: manfei Date: Mon, 6 May 2024 22:20:41 +0000 Subject: [PATCH 546/546] addd --- test/test_torch_xla_while_loop_test.py | 33 ++++++++++++++++++++++++++ torch_xla/csrc/xla_graph_executor.cpp | 3 +++ 2 files changed, 36 insertions(+) create mode 100644 test/test_torch_xla_while_loop_test.py diff --git a/test/test_torch_xla_while_loop_test.py b/test/test_torch_xla_while_loop_test.py new file mode 100644 index 00000000000..6573bcb3ede --- /dev/null +++ b/test/test_torch_xla_while_loop_test.py @@ -0,0 +1,33 @@ +import time +start_time = time.time() + +import torch +import torch_xla +import torch_xla.experimental.fori_loop +from torch._higher_order_ops.while_loop import while_loop +import torch_xla.core.xla_model as xm +import torch_xla.core.xla_builder as xb +import torch_xla.debug.profiler as xp + +# server = xp.start_server(9012) + +# xp.trace_detached( +# f'localhost:9012', +# '/root/profiles/', +# duration_ms=2000) + +device = xm.xla_device() + +def cond_fn(init, limit_value): + return limit_value[0] >= init[0] + +def body_fn(init, limit_value): + one_value = torch.ones(1, dtype=torch.int32, device=device) + return (torch.add(init, one_value), limit_value.clone()) + +init = torch.tensor([0], dtype=torch.int32, device=device) +limit_value = torch.tensor([1000], dtype=torch.int32, device=device) +res = while_loop(cond_fn, body_fn, (init, limit_value)) +print("res: ", res) + +print("--- %s seconds ---" % (time.time() - start_time)) diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index fe12e392ea4..8bb054f0fe1 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1284,6 +1284,9 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( {{"graph_hash", torch::lazy::HashToString(coll.hash)}}); }, tsl::profiler::TraceMeLevel::kInfo); + + TF_VLOG(3) << "We are running XLAGraphExecutor::Compile now"; + static const bool enable_aliasing = runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true); static const size_t parameter_wrapping_threadshold =