Skip to content

Commit

Permalink
recomment target version
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed May 6, 2024
1 parent 38fae84 commit dfb62e4
Showing 1 changed file with 89 additions and 89 deletions.
178 changes: 89 additions & 89 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1517,106 +1517,106 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
}
}

// // std::shared_ptr<XLAGraphExecutor::Async>
// // XLAGraphExecutor::CompilationResult
// // runtime::ComputationClient::ComputationPtr
// // XLAGraphExecutor::saveComputation*
// // std::vector<runtime::ComputationClient::ComputationPtr>
// XLAGraphExecutor::CachedComputation XLAGraphExecutor::GetXLAComputation(std::vector<XLATensorPtr>* tensors,
// absl::Span<const std::string> devices, bool warm_up_cache_only = false) {
// tsl::profiler::TraceMe activity("GetXLAComputation",
// tsl::profiler::TraceMeLevel::kInfo);
// SyncTensorsConfig config;
// config.force_ltc_data = false;
// SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
// if (coll.indices.empty()) {
// TensorCollectionBarrier(&coll);
// return nullptr;
// }
// DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", *tensors,
// &coll.indices);
// std::vector<torch::lazy::Value> ir_values;
// std::vector<torch::lazy::BackendDataPtr> tensor_data_vec;
// ExtractIRAndPrepareXlaData_(tensors, coll.config, coll.indices, ir_values,
// tensor_data_vec);
// PostOrderData po_data = RunPostOrder(ir_values, &coll);
// coll.hash = torch::lazy::HashCombine(
// coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
// if (GetAliasWithBufferDonorConfig()) {
// std::vector<size_t> buffer_donor_index =
// GetBufferDonorIndexFromUserConfig(po_data.parameters_data);
// if (buffer_donor_index.size() > 0) {
//         // Do not include hash on an empty vector.
// coll.hash = torch::lazy::HashCombine(
// coll.hash, torch::lazy::Hash(buffer_donor_index));
// }
// }
// {
// // Auto-sharding configs
// coll.hash = torch::lazy::HashCombine(
// coll.hash, torch::lazy::MHash(ShardingUtil::GetAutoSharding()));
// coll.hash = torch::lazy::HashCombine(
// coll.hash,
// torch::lazy::StringHash(
// runtime::sys_util::GetEnvString("XLA_AUTO_SPMD_MESH", "").c_str()));
// }

// DebugUtil::SaveGraphHash(coll.hash);
// TF_VLOG(4) << "Parameter sequence graph hash "
// << torch::lazy::HashToString(coll.hash);

// std::pair<bool, std::shared_ptr<XLAGraphExecutor::Async>> cache_res =
// TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec,
// warm_up_cache_only);
// if (cache_res.first) {
// // we have a cache hit, execution has been scheduled by TryRunCachedSync.
// return cache_res.second;
// }

// // CompilationResult compile_result =
// // Compile(*tensors, devices, coll, &po_data, ir_values);

// // runtime::ComputationClient::ComputationPtr
// // saveComputation* compile_result = std::move(
// // Compile(*tensors, devices, coll, &po_data, ir_values).computation)
// XLAGraphExecutor::saveComputation* compile_result = Compile(*tensors, devices, coll, &po_data, ir_values).computation
// return compile_result

// // TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes);
// // TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes;
// // auto cached_computation = std::make_shared<CachedComputation>(
// // std::move(compile_result.computation), compile_result.is_sharded);
// // GetComputationCache()->Add(coll.hash, cached_computation);

// // if (warm_up_cache_only) {
// // return nullptr;
// // } else {
// // return ScheduleSyncTensorsGraph(
// // tensors, &coll, std::move(compile_result.parameters_data),
// // compile_result.device.toString(), std::move(cached_computation),
// // tensor_data_vec);
// // }
// }

// Builds (or fetches from the computation cache) the compiled XLA computation
// for the pending lazy graph rooted at `tensors`, without forcing tensor data.
//
// Args:
//   tensors: live tensors whose pending IR forms the graph to compile.
//   devices: devices to compile for.
//   warm_up_cache_only: when true, callers only want the cache populated;
//     execution scheduling (if any) is suppressed downstream.
//
// Returns nullptr when there is nothing to sync (no pending IR).
runtime::ComputationClient::ComputationPtr XLAGraphExecutor::GetXLAComputation(
    std::vector<XLATensorPtr>& tensors,
    absl::Span<const std::string> devices, bool warm_up_cache_only) {
  tsl::profiler::TraceMe activity("GetXLAComputation",
                                  tsl::profiler::TraceMeLevel::kInfo);
  SyncTensorsConfig config;
  // We only want the computation, not a forced materialization of tensor data.
  config.force_ltc_data = false;
  SyncTensorCollection coll = CollectSyncTensors(tensors, config);
  if (coll.indices.empty()) {
    // Nothing to compile; release the device locks taken during collection.
    TensorCollectionBarrier(&coll);
    return nullptr;
  }
  DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", tensors,
                                  &coll.indices);

  // Extract the root IR values and prepare placeholder backend data.
  std::vector<torch::lazy::Value> ir_values;
  std::vector<torch::lazy::BackendDataPtr> tensor_data_vec;
  ExtractIRAndPrepareXlaData_(&tensors, coll.config, coll.indices, ir_values,
                              tensor_data_vec);
  PostOrderData po_data = RunPostOrder(ir_values, &coll);
  coll.hash = torch::lazy::HashCombine(
      coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
  if (GetAliasWithBufferDonorConfig()) {
    std::vector<size_t> buffer_donor_index =
        GetBufferDonorIndexFromUserConfig(po_data.parameters_data);
    if (buffer_donor_index.size() > 0) {
      // Do not include the hash on an empty vector.
      coll.hash = torch::lazy::HashCombine(
          coll.hash, torch::lazy::Hash(buffer_donor_index));
    }
  }
  {
    // Auto-sharding configs also contribute to the graph hash, so differently
    // sharded graphs never collide in the computation cache.
    coll.hash = torch::lazy::HashCombine(
        coll.hash, torch::lazy::MHash(ShardingUtil::GetAutoSharding()));
    coll.hash = torch::lazy::HashCombine(
        coll.hash,
        torch::lazy::StringHash(
            runtime::sys_util::GetEnvString("XLA_AUTO_SPMD_MESH", "").c_str()));
  }

  DebugUtil::SaveGraphHash(coll.hash);
  TF_VLOG(4) << "Parameter sequence graph hash "
             << torch::lazy::HashToString(coll.hash);

  std::pair<bool, std::shared_ptr<XLAGraphExecutor::Async>> cache_res =
      TryRunCachedSync(&tensors, &coll, &po_data, tensor_data_vec,
                       warm_up_cache_only);
  if (cache_res.first) {
    // Cache hit: TryRunCachedSync handled scheduling; return the computation
    // already stored in the cache under this graph hash.
    auto cached = GetComputationCache()->Get(coll.hash);
    return cached != nullptr ? cached->computation : nullptr;
  }

  CompilationResult compile_result =
      Compile(tensors, devices, coll, &po_data, ir_values);

  TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes);
  TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes;

  // Cache the freshly compiled computation so subsequent syncs of the same
  // graph hash can reuse it, then hand the computation back to the caller.
  auto cached_computation = std::make_shared<CachedComputation>(
      std::move(compile_result.computation), compile_result.is_sharded);
  GetComputationCache()->Add(coll.hash, cached_computation);
  return cached_computation->computation;
}

/// runtime::ComputationClient::ComputationPtr XLAGraphExecutor::GetXLAComputation(
/// std::vector<XLATensorPtr>& tensors,
/// absl::Span<const std::string> devices, bool warm_up_cache_only) {
/// // coll
/// SyncTensorsConfig config;
/// config.force_ltc_data = false;
/// SyncTensorCollection coll = CollectSyncTensors(tensors, config);
/// if (coll.indices.empty()) {
/// TensorCollectionBarrier(&coll);
/// return nullptr;
/// }
/// DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", tensors,
/// &coll.indices);

/// // ir_values
/// std::vector<torch::lazy::Value> ir_values;
/// std::vector<torch::lazy::BackendDataPtr> tensor_data_vec;
/// ExtractIRAndPrepareXlaData_(tensors, coll.config, coll.indices, ir_values,
/// tensor_data_vec);
// PostOrderData po_data = RunPostOrder(ir_values, &coll);

// coll.hash = torch::lazy::HashCombine(
Expand Down

0 comments on commit dfb62e4

Please sign in to comment.