From b4d83e43c331c7a55abd023e1c30d449bbc648dc Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Mon, 20 May 2024 22:56:29 +0000
Subject: [PATCH] Dump HLO HBM usage info

---
 torch_xla/csrc/runtime/pjrt_computation_client.cc | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 3d5fbaf1f8e..a1cb5713f19 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -591,6 +591,24 @@ std::vector PjRtComputationClient::Compile(
           client_->Compile(instance.computation, compile_options).value();
     }
 
+    // Log the memory statistics reported for the compiled executable at VLOG
+    // level 3. Bind by reference so the stats object is not copied.
+    auto memory_stats_status_or = executable->GetCompiledMemoryStats();
+    if (memory_stats_status_or.ok()) {
+      const xla::CompiledMemoryStats& memory_stats =
+          memory_stats_status_or.value();
+      TF_VLOG(3) << "memory usage detail = " << memory_stats.DebugString();
+      // `>> 20` converts the byte total to units of 2^20 bytes.
+      TF_VLOG(3)
+          << "total runtime device memory required to run this program = "
+          << ((memory_stats.output_size_in_bytes +
+               memory_stats.temp_size_in_bytes) >>
+              20)
+          << " MB";
+    } else {
+      TF_VLOG(3) << "memory usage is not available";
+    }
+
     const auto& hlo_modules = ConsumeValue(executable->GetHloModules());
     xla::HloComputation* hlo_computation = hlo_modules[0]->entry_computation();
     std::shared_ptr pjrt_computation =