diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 91f40fb9646..e7e2824eeaa 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -1517,106 +1517,106 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
   }
 }

-// // std::shared_ptr
-// // XLAGraphExecutor::CompilationResult
-// // runtime::ComputationClient::ComputationPtr
-// // XLAGraphExecutor::saveComputation*
-// // std::vector
-// XLAGraphExecutor::CachedComputation XLAGraphExecutor::GetXLAComputation(std::vector<XLATensorPtr>* tensors,
-//     absl::Span<const std::string> devices, bool warm_up_cache_only = false) {
-//   tsl::profiler::TraceMe activity("GetXLAComputation",
-//                                   tsl::profiler::TraceMeLevel::kInfo);
-//   SyncTensorsConfig config;
-//   config.force_ltc_data = false;
-//   SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
-//   if (coll.indices.empty()) {
-//     TensorCollectionBarrier(&coll);
-//     return nullptr;
-//   }
-//   DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", *tensors,
-//                                   &coll.indices);
-//   std::vector<torch::lazy::Value> ir_values;
-//   std::vector<torch::lazy::BackendDataPtr> tensor_data_vec;
-//   ExtractIRAndPrepareXlaData_(tensors, coll.config, coll.indices, ir_values,
-//                               tensor_data_vec);
-//   PostOrderData po_data = RunPostOrder(ir_values, &coll);
-//   coll.hash = torch::lazy::HashCombine(
-//       coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
-//   if (GetAliasWithBufferDonorConfig()) {
-//     std::vector<size_t> buffer_donor_index =
-//         GetBufferDonorIndexFromUserConfig(po_data.parameters_data);
-//     if (buffer_donor_index.size() > 0) {
-//       // Do not include hash on a empty vector.
-//       coll.hash = torch::lazy::HashCombine(
-//           coll.hash, torch::lazy::Hash(buffer_donor_index));
-//     }
-//   }
-//   {
-//     // Auto-sharding configs
-//     coll.hash = torch::lazy::HashCombine(
-//         coll.hash, torch::lazy::MHash(ShardingUtil::GetAutoSharding()));
-//     coll.hash = torch::lazy::HashCombine(
-//         coll.hash,
-//         torch::lazy::StringHash(
-//             runtime::sys_util::GetEnvString("XLA_AUTO_SPMD_MESH", "").c_str()));
-//   }
-
-//   DebugUtil::SaveGraphHash(coll.hash);
-//   TF_VLOG(4) << "Parameter sequence graph hash "
-//              << torch::lazy::HashToString(coll.hash);
-
-//   std::pair<bool, std::shared_ptr<XLAGraphExecutor::Async>> cache_res =
-//       TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec,
-//                        warm_up_cache_only);
-//   if (cache_res.first) {
-//     // we have a cache hit, execution has been scheduled by TryRunCachedSync.
-//     return cache_res.second;
-//   }
-
-//   // CompilationResult compile_result =
-//   //     Compile(*tensors, devices, coll, &po_data, ir_values);
-
-//   // runtime::ComputationClient::ComputationPtr
-//   // saveComputation* compile_result = std::move(
-//   //     Compile(*tensors, devices, coll, &po_data, ir_values).computation)
-//   XLAGraphExecutor::saveComputation* compile_result = Compile(*tensors, devices, coll, &po_data, ir_values).computation
-//   return compile_result
-
-//   // TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes);
-//   // TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes;
-//   // auto cached_computation = std::make_shared<CachedComputation>(
-//   //     std::move(compile_result.computation), compile_result.is_sharded);
-//   // GetComputationCache()->Add(coll.hash, cached_computation);
-
-//   // if (warm_up_cache_only) {
-//   //   return nullptr;
-//   // } else {
-//   //   return ScheduleSyncTensorsGraph(
-//   //       tensors, &coll, std::move(compile_result.parameters_data),
-//   //       compile_result.device.toString(), std::move(cached_computation),
-//   //       tensor_data_vec);
-//   // }
-// }
-
-runtime::ComputationClient::ComputationPtr XLAGraphExecutor::GetXLAComputation(
-    std::vector<XLATensorPtr>& tensors,
-    absl::Span<const std::string> devices, bool warm_up_cache_only) {
-  // coll
+// std::shared_ptr
+// XLAGraphExecutor::CompilationResult
+// runtime::ComputationClient::ComputationPtr
+// XLAGraphExecutor::saveComputation*
+// std::vector
+XLAGraphExecutor::CachedComputation XLAGraphExecutor::GetXLAComputation(std::vector<XLATensorPtr>* tensors,
+    absl::Span<const std::string> devices, bool warm_up_cache_only = false) {
+  tsl::profiler::TraceMe activity("GetXLAComputation",
+                                  tsl::profiler::TraceMeLevel::kInfo);
   SyncTensorsConfig config;
   config.force_ltc_data = false;
-  SyncTensorCollection coll = CollectSyncTensors(tensors, config);
+  SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
   if (coll.indices.empty()) {
     TensorCollectionBarrier(&coll);
     return nullptr;
   }
-  DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", tensors,
+  DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", *tensors,
                                   &coll.indices);
-
-  // ir_values
   std::vector<torch::lazy::Value> ir_values;
   std::vector<torch::lazy::BackendDataPtr> tensor_data_vec;
   ExtractIRAndPrepareXlaData_(tensors, coll.config, coll.indices, ir_values,
                               tensor_data_vec);
+  PostOrderData po_data = RunPostOrder(ir_values, &coll);
+  coll.hash = torch::lazy::HashCombine(
+      coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
+  if (GetAliasWithBufferDonorConfig()) {
+    std::vector<size_t> buffer_donor_index =
+        GetBufferDonorIndexFromUserConfig(po_data.parameters_data);
+    if (buffer_donor_index.size() > 0) {
+      // Do not include hash on a empty vector.
+      coll.hash = torch::lazy::HashCombine(
+          coll.hash, torch::lazy::Hash(buffer_donor_index));
+    }
+  }
+  {
+    // Auto-sharding configs
+    coll.hash = torch::lazy::HashCombine(
+        coll.hash, torch::lazy::MHash(ShardingUtil::GetAutoSharding()));
+    coll.hash = torch::lazy::HashCombine(
+        coll.hash,
+        torch::lazy::StringHash(
+            runtime::sys_util::GetEnvString("XLA_AUTO_SPMD_MESH", "").c_str()));
+  }
+
+  DebugUtil::SaveGraphHash(coll.hash);
+  TF_VLOG(4) << "Parameter sequence graph hash "
+             << torch::lazy::HashToString(coll.hash);
+
+  std::pair<bool, std::shared_ptr<XLAGraphExecutor::Async>> cache_res =
+      TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec,
+                       warm_up_cache_only);
+  if (cache_res.first) {
+    // we have a cache hit, execution has been scheduled by TryRunCachedSync.
+    return cache_res.second;
+  }
+
+  // CompilationResult compile_result =
+  //     Compile(*tensors, devices, coll, &po_data, ir_values);
+
+  // runtime::ComputationClient::ComputationPtr
+  // saveComputation* compile_result = std::move(
+  //     Compile(*tensors, devices, coll, &po_data, ir_values).computation)
+  XLAGraphExecutor::saveComputation* compile_result = Compile(*tensors, devices, coll, &po_data, ir_values).computation;
+  return compile_result;
+
+  // TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes);
+  // TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes;
+  // auto cached_computation = std::make_shared<CachedComputation>(
+  //     std::move(compile_result.computation), compile_result.is_sharded);
+  // GetComputationCache()->Add(coll.hash, cached_computation);
+
+  // if (warm_up_cache_only) {
+  //   return nullptr;
+  // } else {
+  //   return ScheduleSyncTensorsGraph(
+  //       tensors, &coll, std::move(compile_result.parameters_data),
+  //       compile_result.device.toString(), std::move(cached_computation),
+  //       tensor_data_vec);
+  // }
+}
+
+/// runtime::ComputationClient::ComputationPtr XLAGraphExecutor::GetXLAComputation(
+///     std::vector<XLATensorPtr>& tensors,
+///     absl::Span<const std::string> devices, bool warm_up_cache_only) {
+///   // coll
+///   SyncTensorsConfig config;
+///   config.force_ltc_data = false;
+///   SyncTensorCollection coll = CollectSyncTensors(tensors, config);
+///   if (coll.indices.empty()) {
+///     TensorCollectionBarrier(&coll);
+///     return nullptr;
+///   }
+///   DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", tensors,
+///                                   &coll.indices);
+
+///   // ir_values
+///   std::vector<torch::lazy::Value> ir_values;
+///   std::vector<torch::lazy::BackendDataPtr> tensor_data_vec;
+///   ExtractIRAndPrepareXlaData_(tensors, coll.config, coll.indices, ir_values,
+///                               tensor_data_vec);
 //   PostOrderData po_data = RunPostOrder(ir_values, &coll);
 //   coll.hash = torch::lazy::HashCombine(
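Note on the newly enabled definition above: as added it will not compile. The declared return type (XLAGraphExecutor::CachedComputation) does not match the values actually returned (nullptr, cache_res.second, and a saveComputation*), saveComputation is not a type introduced anywhere in this diff, and the default argument warm_up_cache_only = false normally belongs on the declaration rather than the out-of-line definition. Below is a minimal, hedged sketch of one type-consistent tail, not part of the patch: it assumes the function is meant to keep the runtime::ComputationClient::ComputationPtr return type used by the previous (now ///-commented) definition, and it only reuses calls that already appear in this diff or file (Compile, GetComputationCache, CachedComputation).

  // Sketch (assumption): tail of GetXLAComputation after TryRunCachedSync,
  // with the function declared to return
  // runtime::ComputationClient::ComputationPtr.
  if (cache_res.first) {
    // Cache hit: TryRunCachedSync has already scheduled execution, so hand
    // back the computation that is already stored under this graph hash.
    return GetComputationCache()->Get(coll.hash)->computation;
  }
  CompilationResult compile_result =
      Compile(*tensors, devices, coll, &po_data, ir_values);
  // Cache the freshly compiled program under the collection hash, mirroring
  // SyncTensorsGraphInternal, so later syncs and warm-up calls can reuse it.
  auto cached_computation = std::make_shared<CachedComputation>(
      std::move(compile_result.computation), compile_result.is_sharded);
  GetComputationCache()->Add(coll.hash, cached_computation);
  // Return the compiled computation itself instead of scheduling execution.
  return cached_computation->computation;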