From 01548502eebb6207eed6c518e0cfbdaf2d66e9a3 Mon Sep 17 00:00:00 2001
From: mcuiaws
Date: Tue, 17 Dec 2024 12:53:01 -0800
Subject: [PATCH] Compute and hash buffer_donor_indices for step marker (#8467)

---
 test/test_input_output_aliases.py     |  19 ++++
 torch_xla/csrc/xla_graph_executor.cpp | 121 +++++++++++++-------------
 torch_xla/csrc/xla_graph_executor.h   |  14 +--
 3 files changed, 89 insertions(+), 65 deletions(-)

diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py
index 4c76651ec0e..8e7b0a7ed6e 100644
--- a/test/test_input_output_aliases.py
+++ b/test/test_input_output_aliases.py
@@ -143,6 +143,25 @@ def try_grad_accum(model, device, train_x, train_label, accum_steps):
         alias_count == 1.0
     ), f"Expect 1 input-output alias pair for gradient accumulation, got {alias_count}"
 
+  def test_separate_graphs(self):
+    """
+    Test that differences in parameter aliasing produce different graphs.
+    """
+    xla_device = xm.xla_device()
+    t0 = torch.tensor([1], device=xla_device)
+    t1 = torch.tensor([2], device=xla_device)
+    xm.mark_step()
+
+    t1.add_(t0)
+    xm.mark_step()
+
+    # This needs to be a separate graph, otherwise t1 can be corrupted
+    # or result in a PJRT error.
+    t2 = t1 + t0
+    xm.mark_step()
+
+    self.assertEqual(t1.item(), 3)
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 63ebc121a84..81cf0207029 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -451,20 +451,6 @@ std::vector<size_t> GetBufferDonorIndexFromUserConfig(
   return buffer_donor_indexs;
 }
 
-std::vector<size_t> XLAGraphExecutor::SetBufferDonorsFromUserConfig(
-    LoweringContext* lowering_ctx) {
-  const std::vector<torch::lazy::BackendDataPtr>& parameters_data =
-      lowering_ctx->GetParametersData();
-  std::vector<size_t> buffer_donor_indexs =
-      GetBufferDonorIndexFromUserConfig(parameters_data);
-  for (size_t i : buffer_donor_indexs) {
-    lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
-                                            /*param_index=*/{});
-  }
-  TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
-  return buffer_donor_indexs;
-}
-
 void XLAGraphExecutor::WaitDeviceOps(absl::Span<const std::string> devices) {
   std::set<std::string> wait_devices;
   if (!devices.empty()) {
@@ -1262,9 +1248,9 @@ XLAGraphExecutor::TryRunCachedSync(
                 std::move(cached_computation), tensor_data_vec));
 }
 
-std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
+std::vector<size_t> GetBufferDonorIndexForStepMarker(
     const std::vector<XLATensorPtr>& tensors, absl::Span<const size_t> indices,
-    LoweringContext* lowering_ctx) {
+    const std::vector<torch::lazy::BackendDataPtr>& parameters_data) {
   std::unordered_map<int64_t, size_t> output_tensor_id_map;
   std::vector<size_t> buffer_donor_indexs;
   // tensors[indices] represent all tensors that needs to be updated after
@@ -1275,7 +1261,6 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
     int64_t tensor_id = tensors[tensor_index]->data()->alias_id;
     output_tensor_id_map[tensor_id] = i;
   }
-  const auto& parameters_data = lowering_ctx->GetParametersData();
   std::vector<ssize_t> alias_map(indices.size(), -1);
   for (size_t i = 0; i < parameters_data.size(); ++i) {
     auto* data_info =
@@ -1287,45 +1272,19 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
       // this buffer is not needed after execution since XLATensor will get a
      // new buffer.
       if (it != output_tensor_id_map.end()) {
-        lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
-                                                /*param_index=*/{});
         buffer_donor_indexs.push_back(i);
       }
     }
   }
-  TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
   return buffer_donor_indexs;
 }
 
-XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
-    std::vector<XLATensorPtr>& tensors, absl::Span<const std::string> devices,
-    const SyncTensorCollection& coll, PostOrderData* po_data,
-    const std::vector<torch::lazy::Value>& ir_values) {
-  tsl::profiler::TraceMe activity(
-      [&] {
-        return tsl::profiler::TraceMeEncode(
-            "XLAGraphExecutor::Compile",
-            {{"graph_hash", torch::lazy::HashToString(coll.hash)}});
-      },
-      tsl::profiler::TraceMeLevel::kInfo);
+std::vector<size_t> XLAGraphExecutor::GetBufferDonors(
+    const std::vector<XLATensorPtr>& tensors, const SyncTensorCollection& coll,
+    const std::vector<torch::lazy::BackendDataPtr>& parameters_data) {
   static const bool enable_aliasing =
       runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true);
-  static const size_t parameter_wrapping_threadshold =
-      runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
   static const bool use_autosharding = ShardingUtil::GetAutoSharding();
-  std::string graph_name =
-      (CurrentGraphName() != "") ? CurrentGraphName() : "SyncTensorsGraph";
-  LoweringContext lowering_ctx(graph_name, coll.device, po_data->post_order,
-                               std::move(po_data->emission_map));
-  for (auto ir_value : ir_values) {
-    xla::XlaOp root = lowering_ctx.GetOutputOp(
-        torch::lazy::Output(ir_value.node.get(), ir_value.index));
-    lowering_ctx.AddResult(root);
-  }
-  // Always execute sharded when running in SPMD mode
-  bool is_sharded = (coll.device == GetVirtualDevice()) || UseVirtualDevice();
-  // Annotate HLO sharding selectively in the compuation.
-  ShardingUtil::SetHloSharding(&lowering_ctx);
 
   std::vector<size_t> buffer_donor_indices;
   // TODO(yeounoh) enable aliasing is disabled for partitioned computation,
@@ -1357,14 +1316,59 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
       // will later fetch the new value of A, which is incorrect.
       // But, when we issue a step barrier (force_ltc_data == true) we have to
       // turn everything into DEVICE_DATA, so we can activate aliasing.
-      buffer_donor_indices =
-          SetBufferDonors(tensors, coll.indices, &lowering_ctx);
+      buffer_donor_indices = GetBufferDonorIndexForStepMarker(
+          tensors, coll.indices, parameters_data);
     } else if (GetAliasWithBufferDonorConfig()) {
       // only alias based on buffer donor if LTC can't auto infer the input
      // output aliasing.
-      buffer_donor_indices = SetBufferDonorsFromUserConfig(&lowering_ctx);
+      buffer_donor_indices = GetBufferDonorIndexFromUserConfig(parameters_data);
     }
   }
+  return buffer_donor_indices;
+}
+
+void XLAGraphExecutor::SetBufferDonors(
+    LoweringContext* lowering_ctx,
+    const std::vector<size_t>& buffer_donor_indexs) {
+  const std::vector<torch::lazy::BackendDataPtr>& parameters_data =
+      lowering_ctx->GetParametersData();
+  for (size_t i : buffer_donor_indexs) {
+    lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
+                                            /*param_index=*/{});
+  }
+  TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
+}
+
+XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
+    std::vector<XLATensorPtr>& tensors, absl::Span<const std::string> devices,
+    const SyncTensorCollection& coll, PostOrderData* po_data,
+    const std::vector<torch::lazy::Value>& ir_values,
+    const std::vector<size_t>& buffer_donor_indices) {
+  tsl::profiler::TraceMe activity(
+      [&] {
+        return tsl::profiler::TraceMeEncode(
+            "XLAGraphExecutor::Compile",
+            {{"graph_hash", torch::lazy::HashToString(coll.hash)}});
+      },
+      tsl::profiler::TraceMeLevel::kInfo);
+  static const size_t parameter_wrapping_threadshold =
+      runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
+  static const bool use_autosharding = ShardingUtil::GetAutoSharding();
+  std::string graph_name =
+      (CurrentGraphName() != "") ? CurrentGraphName() : "SyncTensorsGraph";
+  LoweringContext lowering_ctx(graph_name, coll.device, po_data->post_order,
+                               std::move(po_data->emission_map));
+  for (auto ir_value : ir_values) {
+    xla::XlaOp root = lowering_ctx.GetOutputOp(
+        torch::lazy::Output(ir_value.node.get(), ir_value.index));
+    lowering_ctx.AddResult(root);
+  }
+  // Always execute sharded when running in SPMD mode
+  bool is_sharded = (coll.device == GetVirtualDevice()) || UseVirtualDevice();
+  // Annotate HLO sharding selectively in the computation.
+  ShardingUtil::SetHloSharding(&lowering_ctx);
+
+  SetBufferDonors(&lowering_ctx, buffer_donor_indices);
 
   xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
   xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
@@ -1499,14 +1503,13 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
   PostOrderData po_data = RunPostOrder(ir_values, &coll);
   coll.hash = torch::lazy::HashCombine(
       coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
-  if (GetAliasWithBufferDonorConfig()) {
-    std::vector<size_t> buffer_donor_index =
-        GetBufferDonorIndexFromUserConfig(po_data.parameters_data);
-    if (buffer_donor_index.size() > 0) {
-      // Do not include hash on a empty vector.
-      coll.hash = torch::lazy::HashCombine(
-          coll.hash, torch::lazy::Hash(buffer_donor_index));
-    }
+
+  std::vector<size_t> buffer_donor_indices =
+      GetBufferDonors(*tensors, coll, po_data.parameters_data);
+  if (buffer_donor_indices.size() > 0) {
+    // Do not include hash on an empty vector.
+    coll.hash = torch::lazy::HashCombine(
+        coll.hash, torch::lazy::Hash(buffer_donor_indices));
   }
   {
     // Auto-sharding configs
@@ -1529,8 +1532,8 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
    // we have a cache hit, execution has been scheduled by TryRunCachedSync.
     return cache_res.second;
   }
-  CompilationResult compile_result =
-      Compile(*tensors, devices, coll, &po_data, ir_values);
+  CompilationResult compile_result = Compile(*tensors, devices, coll, &po_data,
+                                             ir_values, buffer_donor_indices);
 
   TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes);
   TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes;
diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h
index 33848767748..71fe5a83243 100644
--- a/torch_xla/csrc/xla_graph_executor.h
+++ b/torch_xla/csrc/xla_graph_executor.h
@@ -356,12 +356,13 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
       const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec,
       bool warm_up_cache_only);
 
-  std::vector<size_t> SetBufferDonors(const std::vector<XLATensorPtr>& tensors,
-                                      absl::Span<const size_t> indices,
-                                      LoweringContext* lowering_ctx);
+  std::vector<size_t> GetBufferDonors(
+      const std::vector<XLATensorPtr>& tensors,
+      const SyncTensorCollection& coll,
+      const std::vector<torch::lazy::BackendDataPtr>& parameters_data);
 
-  std::vector<size_t> SetBufferDonorsFromUserConfig(
-      LoweringContext* lowering_ctx);
+  void SetBufferDonors(LoweringContext* lowering_ctx,
+                       const std::vector<size_t>& buffer_donor_indices);
 
   // TODO(yeounoh) auto-sharding can change tensors shardings, which needs to be
   // accounted for in Dynamo integration.
@@ -369,7 +370,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
                             absl::Span<const std::string> devices,
                             const SyncTensorCollection& coll,
                             PostOrderData* po_data,
-                            const std::vector<torch::lazy::Value>& ir_values);
+                            const std::vector<torch::lazy::Value>& ir_values,
+                            const std::vector<size_t>& buffer_donor_indices);
 
   // We don't use the upstream SyncTensorsGraphInternal since
   // our CachedComputation is different from upstream.
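
Usage sketch (illustrative, not part of the diff above): the snippet below mirrors the new test_separate_graphs case and shows how the behavior can be observed from user code. It is a minimal sketch that assumes the debug helpers torch_xla already exposes (met.clear_all(), the UncachedCompile counter, and the InputOutputAliasCount metric used elsewhere in test_input_output_aliases.py); the expected counter values are inferred from the intent of this patch, not captured from a run.

# Same scenario as test_separate_graphs: both step graphs compute an add, but
# only the first one may donate t1's input buffer (t1 is overwritten in place).
# With buffer_donor_indices folded into the graph hash, the second step should
# trigger its own compilation instead of reusing the aliased executable.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
t0 = torch.tensor([1], device=device)
t1 = torch.tensor([2], device=device)
xm.mark_step()
met.clear_all()

t1.add_(t0)  # t1's old buffer is dead after this step, so it can be donated.
xm.mark_step()

t2 = t1 + t0  # t1 stays live here, so its buffer must not be donated.
xm.mark_step()

print(t1.item())  # expected 3; a shared stale cache entry could corrupt this
print(met.counter_value('UncachedCompile'))  # expected: 2 distinct compilations
print(met.metric_data('InputOutputAliasCount'))  # donation recorded for step 1 only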