diff --git a/.github/workflows/tpu_ci.yml b/.github/workflows/tpu_ci.yml
index d9269dc1a7ae..fc2cf06fec32 100644
--- a/.github/workflows/tpu_ci.yml
+++ b/.github/workflows/tpu_ci.yml
@@ -46,3 +46,4 @@ jobs:
         python3 -u test/test_autocast.py
         python3 -u test/dynamo/test_dynamo.py
         python3 -u test/spmd/test_spmd_debugging.py
+        python3 -u test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
diff --git a/test/run_tests.sh b/test/run_tests.sh
index be98575c45e6..3b32fca2fc53 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -188,6 +188,7 @@ function run_xla_op_tests1 {
 
 function run_xla_op_tests2 {
   run_downcast_bf16 "$CDIR/test_data_type.py"
   run_test "$CDIR/pjrt/test_dtypes.py"
+  run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py"
   run_test "$CDIR/test_autocast.py"  # TODO(yeounoh) this is expensive on GPU
 }
diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
new file mode 100644
index 000000000000..30f51374c8fd
--- /dev/null
+++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
@@ -0,0 +1,56 @@
+import os
+import sys
+import unittest
+from typing import Callable, Dict, List
+
+import torch
+import torch_xla
+import torch_xla.experimental.fori_loop
+from torch._higher_order_ops.while_loop import while_loop
+import torch_xla.core.xla_model as xm
+import torch_xla.core.xla_builder as xb
+
+
+def _fake_while_loop(cond_fn, body_fn, operands):
+  # print("fake func operands: ", operands)
+  # print("fake func *operands: ", *operands)
+  # while cond_fn(*operands):
+  #   operands = body_fn(*operands)
+  # return operands
+  while cond_fn(operands):
+    operands = body_fn(operands)
+  return operands
+
+
+class WhileLoopTest(unittest.TestCase):
+
+  def test_while_loop_tpu(self):
+
+    device = xm.xla_device()
+    # ten = torch.ones(1, dtype=torch.int32, device=device)
+    # ten[0] = 10
+
+    def cond_fn(x):  # x = (xi,)
+      ten = torch.ones(1, dtype=torch.int32, device=device)
+      # ten = torch.tensor([5], dtype=torch.int32, device=device)
+      return x[0] >= ten[0]  # == x[0] # torch.equal(x[0], ten) # x[0] <= ten # 30
+
+    def body_fn(x):  # x = (xi,)
+      # onei = torch.tensor(10, dtype=torch.int32, device=device)
+      return (torch.sub(x[0], 1),)  # onei,)
+
+    # device = xm.xla_device()
+    # xi = torch.ones(1, dtype=torch.int32, device=device)
+    xi = torch.tensor([5], dtype=torch.int32, device=device)
+    # xi[0] = 5
+    # xi.to(device)
+    # yi = torch.ones(1, dtype=torch.int32, device=device)
+    # yi = torch.tensor([1], dtype=torch.int32, device=device)
+    res = while_loop(cond_fn, body_fn, (xi,))
+    expected = _fake_while_loop(cond_fn, body_fn, xi)
+    self.assertEqual(expected, res)
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml
index 5d12fc9db0ef..56db9f5c131d 100644
--- a/test/tpu/xla_test_job.yaml
+++ b/test/tpu/xla_test_job.yaml
@@ -61,6 +61,7 @@ spec:
                 python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py
                 python3 /src/pytorch/xla/test/pjrt/test_dtypes.py
                 python3 /src/pytorch/xla/test/pjrt/test_dynamic_plugin_tpu.py
+                python3 /src/pytorch/xla/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
          volumeMounts:
          - mountPath: /dev/shm
            name: dshm
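Note: the new test above routes torch's while_loop higher-order op through the XLA dispatch key and compares the result against the eager _fake_while_loop reference. For orientation, the behavior being asserted boils down to the following sketch (it assumes an XLA/TPU device is available and reuses the test's own values; it is not an additional test):

import torch
import torch_xla.core.xla_model as xm
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()

def cond_fn(x):
  one = torch.ones(1, dtype=torch.int32, device=device)
  return x[0] >= one[0]  # keep iterating while the carried value is >= 1

def body_fn(x):
  return (torch.sub(x[0], 1),)  # decrement the carried value each step

xi = torch.tensor([5], dtype=torch.int32, device=device)
res = while_loop(cond_fn, body_fn, (xi,))  # expected to end at tensor([0])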
diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py
index 126b0e889d98..1fafd4132c12 100644
--- a/torch_xla/core/xla_builder.py
+++ b/torch_xla/core/xla_builder.py
@@ -762,6 +762,9 @@ def mkop(name, ops, **kwargs):
   builder = kwargs.get('builder', None)
   if builder is None:
     assert ops
+    # if not isinstance(ops, (list, tuple)):
+    #   builder = torch_xla._XLAC._xla_op_builder(ops)
+    # else:
     builder = torch_xla._XLAC._xla_op_builder(ops[0])
   return Op(torch_xla._XLAC._xla_op_create(builder, name, ops, kwargs))
 
diff --git a/torch_xla/core/xla_op_registry.py b/torch_xla/core/xla_op_registry.py
index aba1c7076c39..8b14d590851d 100644
--- a/torch_xla/core/xla_op_registry.py
+++ b/torch_xla/core/xla_op_registry.py
@@ -41,6 +41,8 @@ def __call__(self, *args, **kwargs):
       self._computations[key] = computation
       if xu.getenv_as('XLA_OP_PRINT_COMPUTATIONS', bool, False):
         print(xb.get_computation_hlo(computation), file=sys.stderr)
+    print("777777777 args: ", args)
+    print("777777777 type args: ", type(args))
     result = torch_xla._XLAC._xla_user_computation(self._opname, args,
                                                    computation)
     return result[0] if len(result) == 1 else result
diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp
index 8e3403a923c2..617b8552df16 100644
--- a/torch_xla/csrc/helpers.cpp
+++ b/torch_xla/csrc/helpers.cpp
@@ -937,6 +937,8 @@ xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
         /*param_index=*/xla::ShapeIndex({input_index}));
   }
 
+  // xla::XlaOp a = xla::GetTupleElement(orig_result, 0);
+
   return builder.Build(orig_result);
 }
 
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index c023cd3a0dc4..cb75c4848c9e 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include <iostream>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_cat.h"
@@ -671,7 +672,17 @@ py::object XlaNms(const at::Tensor& boxes, const at::Tensor& scores,
 std::vector<at::Tensor> XlaUserComputation(
     const std::string& opname, const std::vector<at::Tensor>& inputs,
     runtime::ComputationClient::ComputationPtr computation) {
+  std::cout << " !!!$$$: " << std::endl;
+  for (int i = 0; i < inputs.size(); i++) {
+    std::cout << inputs[i] << "; " << std::endl;
+    std::cout << inputs[i].type() << "; type !!!" << std::endl;
+    // inputs[i] = (inputs[i]);
+  }
   std::vector<XLATensorPtr> xinputs = GetXlaTensors(inputs, /*want_all=*/true);
+  // std::cout << " !!!$$$###: " << std::endl;
+  // for (int i = 0; i < xinputs.size(); i++) {
+  //   std::cout << DumpUtil::ToText(xinputs[i]->CurrentIrValue().node.get()) << "; ";
+  // }
   std::vector<XLATensorPtr> xresults =
       tensor_methods::user_computation(opname, xinputs, std::move(computation));
   std::vector<at::Tensor> results;
@@ -685,7 +696,14 @@ std::vector<at::Tensor> XlaUserComputation(
 runtime::ComputationClient::ComputationPtr CreateComputation(
     const std::string& name, xla::XlaOp root) {
+  std::cout << "w's build func name: " << name << std::endl;
+  // std::cout << "w's build builder name: " << root.builder().name_ << std::endl;
+  // https://github.com/openxla/xla/blob/762bde36adf22792e91c38fe87cabe5af05bfadc/xla/client/xla_builder.cc#L710
   xla::XlaComputation computation = ConsumeValue(root.builder()->Build(root));
+  // std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
+  // xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
+  // computation = ConsumeValue(XlaHelpers::WrapXlaComputation(
+  //     computation, program_shape.parameters(), input_output_alias_pair));
   return std::make_shared<runtime::ComputationClient::Computation>(
       name, std::move(computation));
 }
 
@@ -876,6 +894,11 @@ void BuildProfilerSubmodule(py::module* m) {
 
 class PyLoweringContext {
  public:
+  PyLoweringContext(const std::string& name) : PyLoweringContext(name, bridge::GetCurrentDevice()) {}
+
+  PyLoweringContext(const std::string& name, torch::lazy::BackendDevice device)
+      : lowering_ctx(name, device) {}
+
   PyLoweringContext() : PyLoweringContext(bridge::GetCurrentDevice()) {}
 
   PyLoweringContext(torch::lazy::BackendDevice device)
@@ -883,6 +906,7 @@ class PyLoweringContext {
 
   // Builds a HLO graph given a set of output tensors.
   void Build(std::vector<at::Tensor> tensors) {
+    // std::cout << "let's see how many times this was called? !!!" << GetNameString() << std::endl;
     // Get the backing XLA tensors from the output torch tensor handles
     std::vector<XLATensorPtr> xtensors =
         GetXlaTensors(tensors, /*want_all=*/true);
@@ -894,13 +918,31 @@ class PyLoweringContext {
       ir_values.push_back(value);
     }
 
+    // // check computation name
+    // XLA_ERROR() << computation.proto().name();
+
     // Lower the graph using the output IR values
     for (auto& ir_value : ir_values) {
       xla::XlaOp root = lowering_ctx.GetOutputOp(
           torch::lazy::Output(ir_value.node.get(), ir_value.index));
+      // if (computation.proto().name() == 'condctx') {
+      //   xla::XlaOp a = xla::GetTupleElement(root, 0);  // they are not tupled here
+      // }
       lowering_ctx.AddResult(root);
     }
     computation = ConsumeValue(lowering_ctx.BuildXla());
+
+    std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
+    xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
+    bool should_wrap_parameter = (program_shape.parameters_size() >= 2);  // true;
+    if (should_wrap_parameter) {
+      computation = ConsumeValue(XlaHelpers::WrapXlaComputation(
+          computation, program_shape.parameters(), input_output_alias_pair));
+    }
+
+    // // unwrap (pred[])
+    // xla::XlaBuilder builder(computation.proto().name());
+    // xla::XlaOp orig_result = xla::Call(&builder, computation, inner_params);
   }
 
   // Get a mapping from the HLO input parameters to the backing Tensor values.
@@ -983,6 +1025,22 @@ class PyLoweringContext {
     return result;
   }
 
+  void SetNameString(const std::string& name) {
+    lowering_ctx.setnamestring(name);
+  }
+
+  std::string GetNameString() {
+    return lowering_ctx.getnamestring();
+  }
+
+  // LoweringContext GetLoweringCtx() {
+  //   return lowering_ctx;
+  // }
+
+  // LoweringContext SetLoweringCtxName(const std::string name) {
+  //   lowering_ctx.builder().name_ = name;
+  // }
+
  private:
   LoweringContext lowering_ctx;
   xla::XlaComputation computation;
@@ -1027,7 +1085,9 @@ void BuildLoweringContextSubmodule(py::module* m) {
       .def("hlo_json", &PyLoweringContext::GetHloJsonText)
       .def("parameter_id_tensor_mapping",
           &PyLoweringContext::GetParameterIdTensorMapping)
-      .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId);
+      .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
+      .def("setnamestring", &PyLoweringContext::SetNameString)
+      .def("getnamestring", &PyLoweringContext::GetNameString);
 }
 
 void InitXlaModuleBindings(py::module m) {
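The setnamestring/getnamestring bindings registered above are how the Python side tags a lowering context as the While condition or body, which LoweringContext::BuildXla() (next file) then special-cases. A small usage sketch, mirroring torch_xla/experimental/fori_loop.py further down and assuming this patch is built (the tensor value is arbitrary):

import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.tensor([5], dtype=torch.int32, device=device)

cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
cond_ctx.setnamestring("condctx")  # tag the context so BuildXla() emits an untupled root
cond_ctx.build([x[0] >= 1])        # lower the loop predicate
assert cond_ctx.getnamestring() == "condctx"
cond_computation = xb.computation_from_module_proto("condcomputation",
                                                    cond_ctx.hlo())
print(xb.get_computation_hlo(cond_computation))  # inspect the lowered condition HLO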
diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp
index 644d4ee7ca9d..6f4297c12cf9 100644
--- a/torch_xla/csrc/lowering_context.cpp
+++ b/torch_xla/csrc/lowering_context.cpp
@@ -154,9 +154,34 @@ void LoweringContext::SetResult(size_t index, xla::XlaOp op) {
 
 xla::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
   xla::StatusOr<xla::XlaComputation> xla;
-  if (!root_tuple_.empty()) {
+  // if (builder_.name() == 'bodyctx') {
+  //   XLA_ERROR() << builder_.name();
+  // }
+  std::cout << "???" << getnamestring();
+  if (!root_tuple_.empty() && (root_tuple_.size() > 1)) {
     xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
+    // xla::XlaOp a = xla::GetTupleElement(root, 0);
     xla = builder()->Build(root);
+  } else if (!root_tuple_.empty() && (root_tuple_.size() == 1)) {
+    // xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
+    // xla::XlaOp a = xla::GetTupleElement(root, 0);
+    // const xla::Shape& root_shape = ShapeHelper::ShapeOfXlaOp(root_tuple_.at(0));
+    // xla::XlaBuilder cb("predone");
+    // xla::Shape xla_scalar_shape = xla::ShapeUtil::MakeShape(element_type, {});
+    // xla::XlaOp p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0");
+    // Below are to untuple the parameters
+    // xla = builder()->Build(root_tuple_.at(0));  // root);
+    const std::string condctx = "condctx";
+    const std::string bodyctx = "bodyctx";
+    // std::cout << "???" << builder()->name();
+    const std::string currentname = getnamestring();
+    if ((currentname == condctx) || (currentname == bodyctx)) {  // == "condctx") {
+      xla = builder()->Build(root_tuple_.at(0));  // root);
+    }
+    else {
+      xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
+      xla = builder()->Build(root);
+    }
   } else {
     xla = builder()->Build();
   }
diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h
index b8751673fb68..acd346af8de1 100644
--- a/torch_xla/csrc/lowering_context.h
+++ b/torch_xla/csrc/lowering_context.h
@@ -5,6 +5,7 @@
 #include
 #include
+#include <iostream>
 #include
 #include
 #include
 
@@ -34,6 +35,10 @@ class LoweringContext : public torch::lazy::LoweringContext {
 
   xla::XlaBuilder* builder() { return &builder_; }
 
+  void setnamestring(const std::string& name) { name_ = name; std::cout << "LoweringContext~~~??>>: " << name_ << std::endl; }
+
+  const std::string& getnamestring() { return name_; }
+
   StackFrameIndexBuilder* stack_frame_index_builder() {
     return stack_frame_index_builder_.get();
   }
@@ -121,6 +126,7 @@ class LoweringContext : public torch::lazy::LoweringContext {
       parameters_map_;
   std::vector<xla::XlaOp> root_tuple_;
   OutputMap<xla::XlaOp> emitted_outputs_;
+  std::string name_;
 
   std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
 };  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index e0f2c6f47b6f..5041b37c540c 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -629,9 +629,14 @@ std::vector<XLATensorPtr> user_computation(
     runtime::ComputationClient::ComputationPtr computation) {
   XLA_CHECK(!inputs.empty());
   std::vector<torch::lazy::Value> input_values;
+  std::vector<const torch::lazy::Node*> root_nodes;
   for (auto& input : inputs) {
-    input_values.push_back(input->GetIrValue());
+    torch::lazy::Value ir_value = input->GetIrValue();
+    input_values.push_back(ir_value);
+    root_nodes.push_back(ir_value.node.get());
   }
+  std::string graph_str = DumpUtil::ToText(root_nodes);
+  std::cout << "inputs' torch::lazy::Node are ##@#@#@#@#: " << graph_str << std::endl;
   torch::lazy::NodePtr node = torch::lazy::MakeNode<UserComputation>(
       torch::lazy::OpKind::Get(opname), input_values, std::move(computation));
   // Cast can be one of the user computation and we don't want to inherit the
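The DumpUtil::ToText logging added to user_computation above prints the lazy IR that feeds the user computation. For iterating on this patch, a roughly equivalent and less intrusive dump is already reachable from Python via _get_xla_tensors_text (a debugging suggestion, not part of the change):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.tensor([5], dtype=torch.int32, device=xm.xla_device())
print(torch_xla._XLAC._get_xla_tensors_text([t - 1]))  # text dump of the pending lazy IR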
diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py
new file mode 100644
index 000000000000..f47c5ab314bd
--- /dev/null
+++ b/torch_xla/experimental/fori_loop.py
@@ -0,0 +1,247 @@
+import numpy as np
+import torch
+import torch_xla
+import torch_xla.core.xla_builder as xb
+import torch_xla.core.xla_model as xm
+import torch_xla.utils.utils as xu
+import torch_xla.core.xla_op_registry as xor
+
+from torch._C import DispatchKey
+from torch._ops import HigherOrderOperator
+import torch._higher_order_ops.while_loop
+from torch._higher_order_ops.while_loop import while_loop_op
+
+
+@while_loop_op.py_impl(DispatchKey.XLA)
+def while_loop(cond_fn, body_fn, operands):
+  # cond_fn & body_fn: callable
+  # operands: (Tuple of possibly nested dict/list/tuple of tensors)
+  return _xla_while_loop(cond_fn, body_fn, operands)
+
+
+def _xla_while_loop(cond_fn, body_fn, operands):
+
+  # def op_fn(operands):  # internal_x):
+  #   # TODO(manfei): replace cond_fn_placeholder and body_fn_placeholder after confirm xlacomputation could be in xla::while
+  #   ## print body/cond type
+  #   print("cond_fn type: ", type(cond_fn))
+  #   print("body_fn type: ", type(body_fn))
+  #   print("operands type: ", type(operands))
+
+  #   ## trans body/cond to xlacomputation
+  #   xm.mark_step()
+  #   cond_result = cond_fn(operands)
+  #   cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
+  #   cond_ctx_builder = cond_ctx.builder()
+  #   cond_ctx_builder.name_ = 'condctx'
+  #   cond_ctx.build([cond_result])
+  #   cond_hlo = cond_ctx.hlo()
+  #   cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo)
+
+  #   xm.mark_step()
+  #   body_result = body_fn(operands)
+  #   body_ctx = torch_xla._XLAC.lowering.LoweringContext()
+  #   # body_ctx_builder = ctx.builder()
+  #   # body_ctx_builder.name_ = 'bodyctx'
+  #   body_ctx.build([body_result])
+  #   body_hlo = body_ctx.hlo()
+  #   body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo)
+
+  #   # def cond_fn_placeholder(counter, operands):
+  #   #   return counter < xb.Op.scalar((operands[0]).builder(), 10, dtype=xb.Type.S32)
+  #   #   return counter < xb.Op.scalar((internal_x).builder(), 10, dtype=xb.Type.S32)
+
+  #   # def body_fn_placeholder(counter, internal_x):
+  #   #   next_counter = counter + xb.Op.scalar(
+  #   #       counter.builder(), 1, dtype=xb.Type.S32)
+  #   #   internal_x = internal_x + xb.Op.scalar(
+  #   #       internal_x.builder(), 1, dtype=xb.Type.S32)
+  #   #   return xb.Op.tuple((next_counter, internal_x))
+
+  #   # zero = xb.Op.scalar(internal_x.builder(), 0, dtype=xb.Type.S32)
+  #   # w = xb.Op.mkwhile((zero, internal_x), cond_fn_placeholder,
+  #   #                   body_computation)
+
+  #   ## treat operands
+  #   input_tuple = Op.tuple(operands)
+  #   w = input_tuple.while_loop(
+  #       condition_computation=cond_computation, body_computation=body_computation)
+
+  #   return w.get_tuple_element(1)
+
+  # op = xor.register('test_while', op_fn)
+  print("type operands: ", type(operands))
+  print("operands: ", operands)
+  kwargs = {}
+  shapes = xb.tensor_shape(operands)
+  builder = xb.create_builder('test_while')
+  params = []
+  secondparams = []
+  for shape in shapes:
+    p = xb.mkparam(builder, len(params), shape)
+    params.append(p)  # single_tuple)
+    # single_tuple = xb.Op.tuple([p])
+    secondparams.append(xb.Op.tuple([p]))
+
+  # secondparams = []
+  # for shape in shapes:
+  #   p = xb.mkparam(builder, len(secondparams), shape)
+  #   single_tuple = xb.Op.tuple([p])
+  #   secondparams.append(single_tuple)  # p) # single_tuple)
+
+  xm.mark_step()
+  cond_result = cond_fn(operands)
+  cond_ctx = torch_xla._XLAC.lowering.LoweringContext()  # "condctx")
+  # print("type cond_ctx: ", type(cond_ctx))
+  cond_ctx.setnamestring("condctx")
+  # cond_builder = xb.create_builder('condctx')
+  # cond_ctx_builder = cond_ctx.GetLoweringCtx().builder()
+  # cond_ctx_builder.name_ = 'condctx'
+  cond_ctx.build(list(cond_result))
+  cond_hlo = cond_ctx.hlo()
+  cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo)
+  # cond_hlo_print = xb.get_computation_hlo(cond_computation)
+  # print("cond_hlo: !!!!!!!!!")
+  # print(cond_hlo_print)
+
+  xm.mark_step()
+  body_result = body_fn(operands)
+  body_ctx = torch_xla._XLAC.lowering.LoweringContext()  # "bodyctx")
+  body_ctx.setnamestring("bodyctx")
+  # body_ctx_builder = body_ctx.builder()
+  # body_ctx_builder.name_ = 'bodyctx'
+  body_ctx.build(list(body_result))
+  body_hlo = body_ctx.hlo()
+  body_computation = xb.computation_from_module_proto("bodycomputation", body_hlo)
+  # body_hlo_print = xb.get_computation_hlo(body_computation)
+  # print("body_hlo: !!!!!!!!!")
+  # print(body_hlo_print)
+
+  input_tuple = xb.Op.tuple(params)
+  aaa_tuple = xb.Op.get_tuple_element(input_tuple, 0)  # maybe move it to the cycle?
+  print("aaa_tuple: ", aaa_tuple)
+  print("[aaa_tuple.op]: ", [aaa_tuple.op])
+  print("type [aaa_tuple.op]: ", type([aaa_tuple.op]))
+  print("(aaa_tuple.op): ", (aaa_tuple.op))
+  print("type (aaa_tuple.op): ", type((aaa_tuple.op)))
+  w = xb.mkop('While', [aaa_tuple.op], condition_computation=cond_computation, body_computation=body_computation)
+  # w = xb.mkop('While', (aaa_tuple.op,), condition_computation=cond_computation, body_computation=body_computation)
+  # w = xb.mkop('While', aaa_tuple.op, condition_computation=cond_computation, body_computation=body_computation)
+  # w  #
+  print("pass this line")
+  name = 'fori_loop_ed_torch_func'
+  computation = w.build(name)
+  print("pass this line second @@@@@@@@@@@")
+
+  while_loop_hlo_print = xb.get_computation_hlo(computation)
+  print("while_loop_hlo: !!!!!!!!!")
+  print(while_loop_hlo_print)
+
+  # root = fn(*params, **kwargs)
+  # computation = root.build(name)
+
+  # print("operands type: ", type(operands))  #
+  # print("operands: ", operands)  # (tensor([1], device='xla:0', dtype=torch.int32),)
+  # print("operands[0]: ", operands[0])  # tensor([1], device='xla:0', dtype=torch.int32)
+  # print("type operands[0]: ", type(operands[0]))  #
+  # print("type [operands[0],]: ", type([operands[0],]))  #
+  # print("type (operands[0],): ", type((operands[0],)))  #
+  # print("type [(operands[0],),]: ", type([(operands[0],),]))  #
+  # localoperands = torch.tensor(1, dtype=torch.int32, device=xm.xla_device())
+  # localoperands = torch.tensor([1], dtype=torch.int32, device=xm.xla_device())
+  # print("localoperands: ", localoperands)  # tensor(1, device='xla:0', dtype=torch.int32)
+  # print("999999 secondparams: ", secondparams)
+  # print("999999 type secondparams: ", type(secondparams))
+  # result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', secondparams, computation)
+  result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', operands,
+                                                 computation)
+  # result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', [localoperands,],
+  #                                                computation)
+  # _xla_user_computation:
+  #   [](const std::string& opname, const std::vector<at::Tensor>& inputs,
+  #      const runtime::ComputationClient::ComputationPtr& computation) {
+  print("done the result!!!")
+  print("result: ", result)
+  # op = result[0] if len(result) == 1 else result
+
+
+  return result  # xu.as_list(op(operands))
+
+# ---------------------------------------
+# import torch
+# import torch_xla
+# import torch_xla.core.xla_builder as xb
+# import torch_xla.core.xla_model as xm
+
+# device = xm.xla_device()
+# a = torch.rand(1, device=device)
+# b = torch.rand(1, device=device)
+# c = torch.ones(1, device=device)
+
+# name = 'fori_loop_ed_torch_func'
+# opname = 'xla::_op_' + name
+# kwargs = {}
+
+# inputss = (a, b, c)
+# shapess = xb.tensor_shape(inputss)
+
+# builder = xb.create_builder(name)
+# params = []
+# p = xb.mkparam(builder, len(params), shapess[0])  # TODO: change to for...in
+# params.append(p)
+# p = xb.mkparam(builder, len(params), shapess[1])
+# params.append(p)
+# p = xb.mkparam(builder, len(params), shapess[2])
+# params.append(p)
+
+# def body_func(a, b, c):
+#   return torch.add(a, b)
+
+# xm.mark_step()
+# result = body_func(a, b, c)
+# ctx = torch_xla._XLAC.lowering.LoweringContext()
+# # body_ctx_builder = ctx.builder()
+# # body_ctx_builder.name_ = 'bodyctx'
+# ctx.build([result])
+# hlo = ctx.hlo()
+# # hlo_text = ctx.hlo_text()
+
+# def cond_func(a, b, c):
+#   return c < xb.Op.scalar(c.builder(), 10, dtype=xb.Type.F32)
+# #   c = c + 1
+# #   return c < 10
+
+# input_tuple = xb.Op.tuple(params)  # shapess
+# cond_root = cond_func(*params, **kwargs)
+# cond_computation = cond_root.build(name)
+# print("finish cond computation creation")
+
+# xm.mark_step()
+# cond_result = cond_func(a, b, c)
+# cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
+# # cond_ctx_builder = cond_ctx.builder()
+# # cond_ctx_builder.name_ = 'condctx'
+# cond_ctx.build([cond_result])
+# cond_hlo = cond_ctx.hlo()
+# cond_hlo_text = cond_ctx.hlo_text()
+
+# body_computation = xb.computation_from_module_proto("bodycomputation", hlo)
+# cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo)
+
+# body_hlo_print = xb.get_computation_hlo(body_computation)
+# cond_hlo_print = xb.get_computation_hlo(cond_computation)
+
+# print("body_hlo: !!!!!!!!!")
+# print(body_hlo_print)
+# print("cond_hlo: !!!!!!!!!")
+# print(cond_hlo_print)
+
+# input_tuple = xb.Op.tuple(params)  # shapess
+# 1:
+# w = input_tuple.while_loop(condition_computation=cond_computation, body_computation=body_computation)  # w:
+# 2:
+# condition_computation = Op.make_computation('Condition', condition_computation, (input_tuple,))
+# body_computation = Op.make_computation('Body', body_computation, (input_tuple,))
+# w = xb.mkop('While', (input_tuple.op,), condition_computation=cond_computation, body_computation=body_computation)
+# w  #
+# w.build(name)
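For readers of the patch, the live code path of _xla_while_loop above, with the commented-out experiments and debug prints stripped, reduces to roughly the following (_xla_while_loop_sketch is a hypothetical name used only for illustration; this is a condensed restatement of the code above, not a verified recipe, and it depends on the C++ changes earlier in this PR):

import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_model as xm


def _xla_while_loop_sketch(cond_fn, body_fn, operands):
  # Lower cond_fn/body_fn into named XLA computations; the "condctx"/"bodyctx"
  # tags make LoweringContext::BuildXla() emit an untupled root for them.
  xm.mark_step()
  cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
  cond_ctx.setnamestring("condctx")
  cond_ctx.build([cond_fn(operands)])
  cond_computation = xb.computation_from_module_proto("condcomputation",
                                                      cond_ctx.hlo())

  xm.mark_step()
  body_ctx = torch_xla._XLAC.lowering.LoweringContext()
  body_ctx.setnamestring("bodyctx")
  body_ctx.build(list(body_fn(operands)))
  body_computation = xb.computation_from_module_proto("bodycomputation",
                                                      body_ctx.hlo())

  # Build an xla::While over the loop carry (a single tensor for now).
  builder = xb.create_builder('test_while')
  params = [
      xb.mkparam(builder, i, shape)
      for i, shape in enumerate(xb.tensor_shape(operands))
  ]
  carry = xb.Op.tuple(params).get_tuple_element(0)
  w = xb.mkop(
      'While', [carry.op],
      condition_computation=cond_computation,
      body_computation=body_computation)
  computation = w.build('fori_loop_ed_torch_func')

  # Run the wrapped computation on the original operands.
  return torch_xla._XLAC._xla_user_computation('xla::_op_test_while', operands,
                                               computation)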