diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py index 8a1f2bdb737..5488d11aa43 100644 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py @@ -7,6 +7,7 @@ # We need to import the underlying implementation function to register with the dispatcher import torch_xla.experimental.fori_loop from torch_xla.experimental.fori_loop import fori_loop +from torch_xla.experimental.fori_loop import _xla_while_loop, get_module_parameters from torch._higher_order_ops.while_loop import while_loop import torch_xla.core.xla_model as xm import torch_xla.core.xla_builder as xb @@ -121,6 +122,40 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): return self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + def test_while_loop_tpu_simple_linear_clean(self): + + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + + linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def cond_fn(upper, lower, one_value, x, input_value, output_value): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value): + new_lower = torch.add(one_value, lower) + output_value = linear_0(input_value) + return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + one_value, x), input_value.clone(), output_value.clone() + # return upper.clone(), new_lower.clone(), one_value.clone(), torch.add( + # one_value, x), input_value.clone(), output_value.clone() + + upper = torch.tensor([1], dtype=torch.int32, device=device) + lower = torch.tensor([0], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + upper__, lower__, one_value__, torch_add_res__, input_value__, bias__, weight__, output_value_real__, = while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, l_in_0, output_value)) + + expected = _fake_fori_loop(lower, upper, linear_0, l_in_0) + + return self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + def test_while_loop_tpu_simple_linear_class(self): xm.mark_step() @@ -180,6 +215,135 @@ def body_fn(upper, lower, one_value, x, input_value, output_value): self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) return aaa + def test_while_loop_tpu_simple_linear_class_clean(self): + + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + + # define simple linear model class + class SimpleWithLinear(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def forward(self, one_value, x, input_value): + output_value_real = self.linear(input_value) + torch_add_res = torch.add(one_value, x) + return torch_add_res, output_value_real + + simple_with_linear = SimpleWithLinear() + + # define cond and body + def cond_fn(upper, lower, one_value, x, input_value, output_value, *args): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value, *args): + new_lower = torch.add(one_value, lower) + output_value_real = simple_with_linear(one_value, x, input_value) + res = [upper.clone(), new_lower.clone(), one_value.clone(), output_value_real[0], 
input_value.clone(), output_value_real[1]] + res = get_module_parameters(res, simple_with_linear, contain_output_balue=True) + return tuple(res) + + # simple_with_linear = SimpleWithLinear() + upper = torch.tensor([52], dtype=torch.int32, device=device) + lower = torch.tensor([1], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, device=xm.xla_device()) + output_value = torch.zeros([20], dtype=torch.float32, device=device) + + aaa = { + "simple_with_linear": + (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, + output_value)) + } + + additional_inputs = get_module_parameters([], simple_with_linear, contain_output_balue=False) + + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, output_value_real__, = _xla_while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, l_in_0, output_value), tuple(additional_inputs)) + + print("while_loop run times: ", torch_add_res__) + print("output_value_real__: ", output_value_real__) + + expected = simple_with_linear(one_value, init_val, l_in_0) + expected = expected[-1] + print("expected: ", expected) + + self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + return aaa + + def test_while_loop_tpu_simple_linear_class_clean_only_linear(self): + + xm.mark_step() + device = xm.xla_device() + torch.set_grad_enabled(False) + + # define simple linear model class + class SimpleWithLinear(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 20).to(xm.xla_device()) + + def forward(self, x): + x = self.linear(x) + return x + + simple_with_linear = SimpleWithLinear() + + # define cond and body + def cond_fn(upper, lower, one_value, x, input_value, output_value, *args): + return lower[0] < upper[0] + + def body_fn(upper, lower, one_value, x, input_value, output_value, *args): + new_lower = torch.add(one_value, lower) + output_value_real = simple_with_linear(input_value) + res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x).clone(), input_value.clone(), output_value_real.clone()] + res = get_module_parameters(res, simple_with_linear, contain_output_balue=False) + return tuple(res) + + # simple_with_linear = SimpleWithLinear() + upper = torch.tensor([52], dtype=torch.int32, device=device) + lower = torch.tensor([1], dtype=torch.int32, device=device) + one_value = torch.tensor([1], dtype=torch.int32, device=device) + init_val = torch.tensor([1], dtype=torch.int32, device=device) + l_in_0 = torch.rand(10, dtype=torch.float32, device=device) # xm.xla_device()) # float + output_value = torch.zeros([20], dtype=torch.float32, device=device) + print("output_value: ", output_value) + + aaa = { + "simple_with_linear": + (simple_with_linear, (upper, lower, one_value, init_val, l_in_0, + output_value)) + } + + additional_inputs = get_module_parameters([], simple_with_linear, contain_output_balue=False) + # print("additional_inputs: ", additional_inputs) + + # while_loop + upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, output_value_real__, = while_loop( + cond_fn, body_fn, + (upper, lower, one_value, init_val, l_in_0, output_value)) + + # _xla_while_loop + # upper__, lower__, one_value__, torch_add_res__, input_value__, weight1__, bias1__, output_value_real__, = _xla_while_loop( + # cond_fn, body_fn, + # (upper, lower, one_value, init_val, l_in_0, 
output_value), tuple(additional_inputs)) + + print("while_loop run times: ", torch_add_res__) + print("output_value_real__: ", output_value_real__) + + expected = simple_with_linear(l_in_0) + print("expected: ", expected) + + self.assertTrue(torch.all(torch.eq(expected, output_value_real__))) + return aaa + + def test_fori_loop_tpu_addition(self): xm.mark_step() @@ -218,6 +382,12 @@ def test_fori_loop_tpu_simple_linear(self): self.assertTrue(torch.all(torch.eq(expected, l_out_))) + def test_get_xlacomputation(self): + xla_device = xm.xla_device() + t1 = torch.randn(20, 5).to(xla_device) + t2 = torch.randn(20, 5).to(xla_device) + # expected_data_handles = torch_xla._XLAC._get_xla_computation([t1], [], False) + print("finish test here") if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index e20e28fbb8f..023ec59372f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -706,6 +706,87 @@ runtime::ComputationClient::ComputationPtr CreateComputationFromProto( name, std::move(computation)); } +// // std::shared_ptr +// // XLAGraphExecutor::CompilationResult +// // runtime::ComputationClient::ComputationPtr +// // XLAGraphExecutor::saveComputation* +// // std::vector +// // XLAGraphExecutor::CachedComputation +// runtime::ComputationClient::ComputationPtr GetXLAComputation(std::vector tensors, +// absl::Span devices, bool warm_up_cache_only = false) { +// tsl::profiler::TraceMe activity("GetXLAComputation", tsl::profiler::TraceMeLevel::kInfo); +// SyncTensorsConfig config; +// config.force_ltc_data = false; +// SyncTensorCollection coll = CollectSyncTensors(*tensors, config); +// if (coll.indices.empty()) { +// TensorCollectionBarrier(&coll); +// return nullptr; +// } +// DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", *tensors, +// &coll.indices); +// std::vector ir_values; +// std::vector tensor_data_vec; +// ExtractIRAndPrepareXlaData_(tensors, coll.config, coll.indices, ir_values, +// tensor_data_vec); +// PostOrderData po_data = RunPostOrder(ir_values, &coll); +// coll.hash = torch::lazy::HashCombine( +// coll.hash, torch::lazy::Hash(po_data.parameter_sequence)); +// if (GetAliasWithBufferDonorConfig()) { +// std::vector buffer_donor_index = +// GetBufferDonorIndexFromUserConfig(po_data.parameters_data); +// if (buffer_donor_index.size() > 0) { +// // Do not include hash on a empty vector. +// coll.hash = torch::lazy::HashCombine( +// coll.hash, torch::lazy::Hash(buffer_donor_index)); +// } +// } +// { +// // Auto-sharding configs +// coll.hash = torch::lazy::HashCombine( +// coll.hash, torch::lazy::MHash(ShardingUtil::GetAutoSharding())); +// coll.hash = torch::lazy::HashCombine( +// coll.hash, +// torch::lazy::StringHash( +// runtime::sys_util::GetEnvString("XLA_AUTO_SPMD_MESH", "").c_str())); +// } + +// DebugUtil::SaveGraphHash(coll.hash); +// TF_VLOG(4) << "Parameter sequence graph hash " +// << torch::lazy::HashToString(coll.hash); + +// std::pair> cache_res = +// TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec, +// warm_up_cache_only); +// if (cache_res.first) { +// // we have a cache hit, execution has been scheduled by TryRunCachedSync. 
+// return cache_res.second; +// } + +// // CompilationResult compile_result = +// // Compile(*tensors, devices, coll, &po_data, ir_values); + +// // runtime::ComputationClient::ComputationPtr +// // saveComputation* compile_result = std::move( +// // Compile(*tensors, devices, coll, &po_data, ir_values).computation) +// XLAGraphExecutor::saveComputation* compile_result = Compile(*tensors, devices, coll, &po_data, ir_values).computation +// return compile_result + +// // TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes); +// // TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes; +// // auto cached_computation = std::make_shared( +// // std::move(compile_result.computation), compile_result.is_sharded); +// // GetComputationCache()->Add(coll.hash, cached_computation); + +// // if (warm_up_cache_only) { +// // return nullptr; +// // } else { +// // return ScheduleSyncTensorsGraph( +// // tensors, &coll, std::move(compile_result.parameters_data), +// // compile_result.device.toString(), std::move(cached_computation), +// // tensor_data_vec); +// // } +// } + xla::Shape GetTensorShape(const at::Tensor& tensor, const std::string& device_str) { auto xtensor = bridge::TryGetXlaTensor(tensor); @@ -885,6 +966,8 @@ class PyLoweringContext { public: PyLoweringContext() : PyLoweringContext(bridge::GetCurrentDevice()) {} + // PostOrderData po_data = RunPostOrder(ir_values, &coll); + PyLoweringContext(torch::lazy::BackendDevice device) : lowering_ctx("PyLoweringContext", device) {} @@ -2379,6 +2462,17 @@ void InitXlaModuleBindings(py::module m) { BuildProfilerSubmodule(&m); BuildLoweringContextSubmodule(&m); + m.def("_get_xla_computation", [](const std::vector& tensors, + const std::vector& devices, const bool warm_up_cache_only) { + std::vector xtensors; + xtensors.reserve(tensors.size()); + for (auto& tensor : tensors) { + xtensors.push_back(bridge::GetXlaTensor(tensor)); + } + + runtime::ComputationClient::ComputationPtr xla_computation = XLAGraphExecutor::Get()->GetXLAComputation(xtensors, {}, true); + return xla_computation; + }); m.def("_get_tensors_handle", [](const std::vector& tensors) -> std::vector { std::vector handles; diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index fe12e392ea4..91f40fb9646 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1284,6 +1284,9 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( {{"graph_hash", torch::lazy::HashToString(coll.hash)}}); }, tsl::profiler::TraceMeLevel::kInfo); + + TF_VLOG(3) << "We are running XLAGraphExecutor::Compile now"; + static const bool enable_aliasing = runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true); static const size_t parameter_wrapping_threadshold = @@ -1514,4 +1517,170 @@ XLAGraphExecutor::SyncTensorsGraphInternal( } } +// // std::shared_ptr +// // XLAGraphExecutor::CompilationResult +// // runtime::ComputationClient::ComputationPtr +// // XLAGraphExecutor::saveComputation* +// // std::vector +// XLAGraphExecutor::CachedComputation XLAGraphExecutor::GetXLAComputation(std::vector* tensors, +// absl::Span devices, bool warm_up_cache_only = false) { +// tsl::profiler::TraceMe activity("GetXLAComputation", +// tsl::profiler::TraceMeLevel::kInfo); +// SyncTensorsConfig config; +// config.force_ltc_data = false; +// SyncTensorCollection coll = CollectSyncTensors(*tensors, config); +// if (coll.indices.empty()) { +// TensorCollectionBarrier(&coll); +// return nullptr; +// } +// 
DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", *tensors, +// &coll.indices); +// std::vector ir_values; +// std::vector tensor_data_vec; +// ExtractIRAndPrepareXlaData_(tensors, coll.config, coll.indices, ir_values, +// tensor_data_vec); +// PostOrderData po_data = RunPostOrder(ir_values, &coll); +// coll.hash = torch::lazy::HashCombine( +// coll.hash, torch::lazy::Hash(po_data.parameter_sequence)); +// if (GetAliasWithBufferDonorConfig()) { +// std::vector buffer_donor_index = +// GetBufferDonorIndexFromUserConfig(po_data.parameters_data); +// if (buffer_donor_index.size() > 0) { +// // Do not include hash on a empty vector. +// coll.hash = torch::lazy::HashCombine( +// coll.hash, torch::lazy::Hash(buffer_donor_index)); +// } +// } +// { +// // Auto-sharding configs +// coll.hash = torch::lazy::HashCombine( +// coll.hash, torch::lazy::MHash(ShardingUtil::GetAutoSharding())); +// coll.hash = torch::lazy::HashCombine( +// coll.hash, +// torch::lazy::StringHash( +// runtime::sys_util::GetEnvString("XLA_AUTO_SPMD_MESH", "").c_str())); +// } + +// DebugUtil::SaveGraphHash(coll.hash); +// TF_VLOG(4) << "Parameter sequence graph hash " +// << torch::lazy::HashToString(coll.hash); + +// std::pair> cache_res = +// TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec, +// warm_up_cache_only); +// if (cache_res.first) { +// // we have a cache hit, execution has been scheduled by TryRunCachedSync. +// return cache_res.second; +// } + +// // CompilationResult compile_result = +// // Compile(*tensors, devices, coll, &po_data, ir_values); + +// // runtime::ComputationClient::ComputationPtr +// // saveComputation* compile_result = std::move( +// // Compile(*tensors, devices, coll, &po_data, ir_values).computation) +// XLAGraphExecutor::saveComputation* compile_result = Compile(*tensors, devices, coll, &po_data, ir_values).computation +// return compile_result + +// // TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes); +// // TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes; +// // auto cached_computation = std::make_shared( +// // std::move(compile_result.computation), compile_result.is_sharded); +// // GetComputationCache()->Add(coll.hash, cached_computation); + +// // if (warm_up_cache_only) { +// // return nullptr; +// // } else { +// // return ScheduleSyncTensorsGraph( +// // tensors, &coll, std::move(compile_result.parameters_data), +// // compile_result.device.toString(), std::move(cached_computation), +// // tensor_data_vec); +// // } +// } + +runtime::ComputationClient::ComputationPtr XLAGraphExecutor::GetXLAComputation( + std::vector& tensors, + absl::Span devices, bool warm_up_cache_only) { + // coll + SyncTensorsConfig config; + config.force_ltc_data = false; + SyncTensorCollection coll = CollectSyncTensors(tensors, config); + if (coll.indices.empty()) { + TensorCollectionBarrier(&coll); + return nullptr; + } + DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", tensors, + &coll.indices); + + // ir_values + std::vector ir_values; + std::vector tensor_data_vec; + ExtractIRAndPrepareXlaData_(tensors, coll.config, coll.indices, ir_values, + tensor_data_vec); + // PostOrderData po_data = RunPostOrder(ir_values, &coll); + + // coll.hash = torch::lazy::HashCombine( + // coll.hash, torch::lazy::Hash(po_data.parameter_sequence)); + + // if (GetAliasWithBufferDonorConfig()) { + // std::vector buffer_donor_index = + // GetBufferDonorIndexFromUserConfig(po_data.parameters_data); + // if (buffer_donor_index.size() > 0) { + // // Do not 
include hash on a empty vector. + // coll.hash = torch::lazy::HashCombine( + // coll.hash, torch::lazy::Hash(buffer_donor_index)); + // } + // } + + // { + // // Auto-sharding configs + // coll.hash = torch::lazy::HashCombine( + // coll.hash, torch::lazy::MHash(ShardingUtil::GetAutoSharding())); + // coll.hash = torch::lazy::HashCombine( + // coll.hash, + // torch::lazy::StringHash( + // runtime::sys_util::GetEnvString("XLA_AUTO_SPMD_MESH", "").c_str())); + // } + + // DebugUtil::SaveGraphHash(coll.hash); + // TF_VLOG(4) << "Parameter sequence graph hash " + // << torch::lazy::HashToString(coll.hash); + + // std::pair> cache_res = + // TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec, + // warm_up_cache_only); + // if (cache_res.first) { + // // we have a cache hit, execution has been scheduled by TryRunCachedSync. + // return cache_res.second; + // } + + // CompilationResult compile_result = + // Compile(*tensors, devices, coll, &po_data, ir_values); + + runtime::ComputationClient::ComputationPtr compile_result = + XLAGraphExecutor::Compile(tensors, devices, coll, nullptr, ir_values).computation; + + runtime::ComputationClient::ComputationPtr a = nullptr; + return a; // return nullptr; + +} + +XLAGraphExecutor::ComputationCache::TypePtr +XLAGraphExecutor::LookupCachedCompiletwo(const torch::lazy::hash_t& hash) { + ComputationCache::TypePtr cached_computation = + GetComputationCache()->Get(hash); + if (cached_computation == nullptr) { + TORCH_LAZY_COUNTER("UncachedCompile", 1); + return nullptr; + } + TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(hash) + << " is computation hash " + << torch::lazy::HashToString(torch::lazy::Hash( + cached_computation->computation->computation() + .proto() + .SerializeAsString())); + TORCH_LAZY_COUNTER("CachedCompile", 1); + return cached_computation; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index b2b76b8ae33..b745c9ffec2 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -187,6 +187,14 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { void ClearPendingIrs(std::vector tensors, const torch::lazy::BackendDevice& device); +// struct saveComputation { +// runtime::ComputationClient::ComputationPtr computation; +// }; +// std::vector + runtime::ComputationClient::ComputationPtr GetXLAComputation( + std::vector& tensors, + absl::Span devices, bool warm_up_cache_only = false); + private: // This is just to group results from compile(). Since our computation is // different, we don't reuse the upstream CompilationResult. @@ -331,6 +339,9 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { ComputationCache::TypePtr LookupCachedCompile( const torch::lazy::hash_t& hash); + ComputationCache::TypePtr LookupCachedCompiletwo( + const torch::lazy::hash_t& hash); + // We don't use the upstream TryRunCachedSync since // our CachedComputation is different from upstream. 
 std::pair> TryRunCachedSync(
@@ -360,6 +371,10 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
       std::vector* tensors, absl::Span devices,
       const SyncTensorsConfig& config, bool warm_up_cache_only = false);
 
+// runtime::ComputationClient::ComputationPtr GetXLAComputation(std::vector& tensors,
+//                                                              absl::Span devices,
+//                                                              bool warm_up_cache_only = false);
+
   ComputationCache* computation_cache_;
 };
 
diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py
index 8ed3a783200..a79df1dc72f 100644
--- a/torch_xla/experimental/fori_loop.py
+++ b/torch_xla/experimental/fori_loop.py
@@ -12,6 +12,32 @@
 from torch._higher_order_ops.while_loop import while_loop_op
 from torch._higher_order_ops.while_loop import while_loop as torch_while_loop
 
+# def insert_into_res_with_output_value():
+
+def get_module_parameters(res, module, contain_output_balue=False):
+  # Append the parameters of `module` to `res` so they can be carried through the
+  # while_loop. Parameters of BatchNorm layers (names starting with 'bn') are also
+  # collected separately and re-appended in reverse order to work around the
+  # duplicated BatchNorm2d parameters. When `contain_output_balue` is True, the last
+  # entry of `res` is treated as output_value and kept as the final element.
+  bn_list = []
+  for name, param in module.named_parameters():
+    if name[:2] == 'bn':
+      bn_list.append(param.clone())
+
+    if contain_output_balue:
+      res.insert(-1, param.clone())
+    else:
+      res.append(param.clone())
+
+  # hard-code BatchNorm2d duplicated parameters
+  if len(bn_list) != 0:
+    bn_list.reverse()
+    if contain_output_balue:
+      output_value = res[-1]
+      res = res[:-1] + bn_list
+      res.append(output_value)
+    else:
+      res = res + bn_list
+  # print("bn_list: ", bn_list)
+  # if not contain_output_balue:
+  #   print("res: ", res)
+  return res
 
 # TODO(@manfei): treat *input_value
 def fori_loop(upper, lower, body_fun, init_val, input_value):
@@ -55,7 +81,6 @@ def body_fn(upper, lower, one_value, x, input_value):
 
   return res
 
-
 @while_loop_op.py_impl(DispatchKey.XLA)
 def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
   # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '')
 
   # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors)
   if additional_inputs is None:
     additional_inputs = tuple()
+  else:
+    print("while additional_inputs: ", additional_inputs)
+    ### !!! we could use additional_inputs from PyTorch Dynamo to build dummy params in cond_fn, and maybe in the inputs too?
+    # additional_inputs = tuple(reversed(additional_inputs))
+
+    # additional_inputs_list = list(additional_inputs)
+    # additional_inputs_list.reverse()
+    # additional_inputs = tuple(additional_inputs_list)
+
+  def wrapped_body_fn_with_paras(*carried_inputs):
+    res = body_fn(*carried_inputs)
+    return tuple(res)
+
+  # return _xla_while_loop(cond_fn, wrapped_body_fn_with_paras, carried_inputs, additional_inputs)
   return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs)
 
 
@@ -86,31 +126,50 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
   cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
   cond_ctx.set_name_string("condctx")
 
+  # # TODO(@manfei): treat hard-code cond xlacomputation change: currently switch output_value and weight position if additional_inputs(weight/bias) exists
+  # additional_inputs_list_cond = list(
+  #     fake_carried_inputs[2:]
+  # )  # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor
+  # if additional_inputs:
+  #   tmp_bias = additional_inputs_list_cond[
+  #       -3]  # not used, change order doesn't affect logic
+  #   del additional_inputs_list_cond[
+  #       -3]  # not used, change order doesn't affect logic
+  #   additional_inputs_list_cond.append(
+  #       tmp_bias)  # not used, change order doesn't affect logic
+
   # TODO(@manfei): treat hard-code cond xlacomputation change: currently switch output_value and weight position if additional_inputs(weight/bias) exists
   additional_inputs_list_cond = list(
       fake_carried_inputs[2:]
   )  # all missed arguments except upper/lower due to PyTorch/XLA trace from output tensor
   if additional_inputs:
-    tmp_bias = additional_inputs_list_cond[
-        -3]  # not used, change order doesn't affect logic
-    del additional_inputs_list_cond[
-        -3]  # not used, change order doesn't affect logic
-    additional_inputs_list_cond.append(
-        tmp_bias)  # not used, change order doesn't affect logic
+    ### this slot is actually output_value
+    tmp_bias = additional_inputs_list_cond[3]  # was additional_inputs_list_cond[-3]; not used, change order doesn't affect logic
+    del additional_inputs_list_cond[3]  # was additional_inputs_list_cond[-3]; not used, change order doesn't affect logic
+    additional_inputs_list_cond.append(tmp_bias)  # not used, change order doesn't affect logic
 
   cond_ctx.buildforiloop([cond_result], additional_inputs_list_cond)
   cond_hlo = cond_ctx.hlo()
   cond_computation = xb.computation_from_module_proto("condcomputation",
                                                       cond_hlo)
+  cond_hlo_print = xb.get_computation_hlo(cond_computation)
+  print("cond computation: !!!!!!!!!")
+  print(cond_hlo_print)
+  print("fake_carried_inputs: ", fake_carried_inputs)
 
   # generate body_fn xlacomputation
   body_result = body_fn(*fake_carried_inputs)
-  body_ctx = torch_xla._XLAC.lowering.LoweringContext()
+  body_ctx = torch_xla._XLAC.lowering.LoweringContext()  # PyLoweringContext
   body_ctx.set_name_string("bodyctx")
 
+  ### !!! new solution to create xla_computation (exploratory, left disabled for now:
+  ### `outputs` is not defined at this point, and the call would overwrite the
+  ### LoweringContext created above)
+  # body_ctx = torch_xla._XLAC._get_xla_computation(outputs)
+
   # TODO(@manfei): treat hard-code body xlacomputation change: currently add non-changed output_value argument if additional_inputs(weight/bias) exists
   if additional_inputs:
-    additional_inputs_list_body = [fake_carried_inputs[-3]]
+    # add the non-changed output_value when additional params (weight/bias) exist
+    additional_inputs_list_body = [fake_carried_inputs[5]]  # was fake_carried_inputs[-3]
   else:
     additional_inputs_list_body = []
 
@@ -119,6 +178,9 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
   body_hlo = body_ctx.hlo()
   body_computation = xb.computation_from_module_proto("bodycomputation",
                                                       body_hlo)
+  body_hlo_print = xb.get_computation_hlo(body_computation)
+  print("body computation: !!!!!!!!!")
+  print(body_hlo_print)
 
   # trans fake_carried_inputs from list(tensor) to list(xla::op), which part could change init of xla::while
   total_inputs = carried_inputs + additional_inputs