Skip to content

Commit

Permalink
Use TPU profiler plugin (#5793)
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar authored Nov 14, 2023
1 parent 5710a83 commit 3b34cb2
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 4 deletions.
24 changes: 21 additions & 3 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ cc_library(
":debug_macros",
":env_vars",
":multi_wait",
":profiler",
":stablehlo_helper",
":tensor_source",
":tf_logging",
Expand All @@ -97,6 +98,7 @@ cc_library(
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt:tfrt_cpu_pjrt_client",
"@xla//xla/pjrt:pjrt_c_api_client",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/platform/cloud:gcs_file_system",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -194,19 +196,35 @@ cc_library(
],
)

# Profiler silently fails unless we link these backends
cc_library(
name = "profiler_backends",
visibility = ["//visibility:private"],
deps = [
"@xla//xla/backends/profiler/cpu:host_tracer",
"@xla//xla/backends/profiler/cpu:metadata_collector",
] + if_cuda_is_configured([
"@xla//xla/backends/profiler/gpu:device_tracer",
]),
alwayslink = True,
)

cc_library(
name = "profiler",
srcs = ["profiler.cc"],
hdrs = ["profiler.h"],
deps = [
":debug_macros",
":profiler_backends",
"@xla//xla/backends/profiler/plugin:profiler_c_api_hdrs",
"@xla//xla/backends/profiler/plugin:plugin_tracer",
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
"@tsl//tsl/platform:status",
"@tsl//tsl/profiler/lib:profiler_factory",
"@tsl//tsl/profiler/rpc:profiler_server_impl",
"@tsl//tsl/profiler/rpc/client:capture_profile",
"@com_google_absl//absl/container:flat_hash_map",

# Profiler silently fails unless we include this
"@xla//xla/backends/profiler:profiler_backends",

# TODO: We get missing symbol errors without these deps. Why aren't they
# included transitively from TensorFlow/TSL?
"@tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc_impl",
Expand Down
7 changes: 6 additions & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
Expand All @@ -21,6 +22,7 @@
#include "xla/client/xla_computation.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_api.h"
Expand Down Expand Up @@ -117,8 +119,11 @@ PjRtComputationClient::PjRtComputationClient() {
"tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))
.status());
tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
XLA_CHECK(tpu_status.ok());
XLA_CHECK_OK(tpu_status);
client_ = std::move(xla::GetCApiClient("TPU").value());
const PJRT_Api* c_api =
static_cast<xla::PjRtCApiClient*>(client_.get())->pjrt_c_api();
profiler::RegisterProfilerForPlugin(c_api);
} else if (device_type == "TPU_LEGACY") {
XLA_ERROR() << "TPU_LEGACY client is no longer available.";
} else if (device_type == "GPU" || device_type == "CUDA" ||
Expand Down
35 changes: 35 additions & 0 deletions torch_xla/csrc/runtime/profiler.cc
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
#include "torch_xla/csrc/runtime/profiler.h"

#include "absl/container/flat_hash_map.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "tsl/platform/status.h"
#include "tsl/profiler/lib/profiler_factory.h"
#include "tsl/profiler/rpc/client/capture_profile.h"
#include "tsl/profiler/rpc/profiler_server.h"
#include "xla/backends/profiler/plugin/plugin_tracer.h"
#include "xla/backends/profiler/plugin/profiler_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h"

namespace torch_xla {
namespace runtime {
namespace profiler {

namespace {

const PLUGIN_Profiler_Api* FindProfilerApi(const PJRT_Api* pjrt_api) {
const PJRT_Structure_Base* next =
reinterpret_cast<const PJRT_Structure_Base*>(pjrt_api->extension_start);
while (next != nullptr &&
next->type != PJRT_Structure_Type::PJRT_Structure_Type_Profiler) {
next = next->next;
}
if (next == nullptr) {
return nullptr;
}
return reinterpret_cast<const PJRT_Profiler_Extension*>(next)->profiler_api;
}

} // namespace

struct ProfilerServer::Impl {
Impl() : server(new tsl::profiler::ProfilerServer()) {}

Expand All @@ -33,6 +55,19 @@ tsl::Status Trace(
/*include_dataset_ops=*/false, duration_ms, num_tracing_attempts,
options);
}

void RegisterProfilerForPlugin(const PJRT_Api* c_api) {
const PLUGIN_Profiler_Api* profiler_api = FindProfilerApi(c_api);
XLA_CHECK(profiler_api);

tsl::profiler::ProfilerFactory create_func =
[profiler_api](const tensorflow::ProfileOptions& options) {
return std::make_unique<xla::profiler::PluginTracer>(profiler_api,
options);
};
tsl::profiler::RegisterProfilerFactory(std::move(create_func));
}

} // namespace profiler
} // namespace runtime
} // namespace torch_xla
3 changes: 3 additions & 0 deletions torch_xla/csrc/runtime/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "absl/container/flat_hash_map.h"
#include "tsl/platform/status.h"
#include "xla/pjrt/c/pjrt_c_api.h"

namespace torch_xla {
namespace runtime {
Expand All @@ -28,6 +29,8 @@ tsl::Status Trace(
const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
options);

void RegisterProfilerForPlugin(const PJRT_Api* c_api);

} // namespace profiler
} // namespace runtime
} // namespace torch_xla
Expand Down

0 comments on commit 3b34cb2

Please sign in to comment.