diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 9dc3730299d..cbeea6abeb7 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -83,6 +83,7 @@ cc_library(
         ":debug_macros",
         ":env_vars",
         ":multi_wait",
+        ":profiler",
        ":stablehlo_helper",
         ":tensor_source",
         ":tf_logging",
@@ -97,6 +98,7 @@ cc_library(
         "@xla//xla/pjrt:pjrt_client",
         "@xla//xla/pjrt:tfrt_cpu_pjrt_client",
         "@xla//xla/pjrt:pjrt_c_api_client",
+        "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
         "@tsl//tsl/profiler/lib:traceme",
         "@tsl//tsl/platform/cloud:gcs_file_system",
         "@com_google_absl//absl/strings",
@@ -194,19 +196,35 @@ cc_library(
     ],
 )
 
+# Profiler silently fails unless we link these backends
+cc_library(
+    name = "profiler_backends",
+    visibility = ["//visibility:private"],
+    deps = [
+        "@xla//xla/backends/profiler/cpu:host_tracer",
+        "@xla//xla/backends/profiler/cpu:metadata_collector",
+    ] + if_cuda_is_configured([
+        "@xla//xla/backends/profiler/gpu:device_tracer",
+    ]),
+    alwayslink = True,
+)
+
 cc_library(
     name = "profiler",
     srcs = ["profiler.cc"],
     hdrs = ["profiler.h"],
     deps = [
+        ":debug_macros",
+        ":profiler_backends",
+        "@xla//xla/backends/profiler/plugin:profiler_c_api_hdrs",
+        "@xla//xla/backends/profiler/plugin:plugin_tracer",
+        "@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
         "@tsl//tsl/platform:status",
+        "@tsl//tsl/profiler/lib:profiler_factory",
         "@tsl//tsl/profiler/rpc:profiler_server_impl",
         "@tsl//tsl/profiler/rpc/client:capture_profile",
         "@com_google_absl//absl/container:flat_hash_map",
 
-        # Profiler silently fails unless we include this
-        "@xla//xla/backends/profiler:profiler_backends",
-
         # TODO: We get missing symbol errors without these deps. Why aren't they
         # included transitively from TensorFlow/TSL?
         "@tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc_impl",
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index fba50dcb63d..0fa3a790092 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -11,6 +11,7 @@
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/multi_wait.h"
+#include "torch_xla/csrc/runtime/profiler.h"
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
 #include "torch_xla/csrc/runtime/tensor_source.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
@@ -21,6 +22,7 @@
 #include "xla/client/xla_computation.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
+#include "xla/pjrt/c/pjrt_c_api.h"
 #include "xla/pjrt/distributed/distributed.h"
 #include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
 #include "xla/pjrt/pjrt_api.h"
@@ -117,8 +119,11 @@ PjRtComputationClient::PjRtComputationClient() {
             "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))
             .status());
     tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
-    XLA_CHECK(tpu_status.ok());
+    XLA_CHECK_OK(tpu_status);
     client_ = std::move(xla::GetCApiClient("TPU").value());
+    const PJRT_Api* c_api =
+        static_cast<xla::PjRtCApiClient*>(client_.get())->pjrt_c_api();
+    profiler::RegisterProfilerForPlugin(c_api);
   } else if (device_type == "TPU_LEGACY") {
     XLA_ERROR() << "TPU_LEGACY client is no longer available.";
   } else if (device_type == "GPU" || device_type == "CUDA" ||
diff --git a/torch_xla/csrc/runtime/profiler.cc b/torch_xla/csrc/runtime/profiler.cc
index 41de76ebd5e..a2ea89be16d 100644
--- a/torch_xla/csrc/runtime/profiler.cc
+++ b/torch_xla/csrc/runtime/profiler.cc
@@ -1,14 +1,36 @@
 #include "torch_xla/csrc/runtime/profiler.h"
 
 #include "absl/container/flat_hash_map.h"
+#include "torch_xla/csrc/runtime/debug_macros.h"
 #include "tsl/platform/status.h"
+#include "tsl/profiler/lib/profiler_factory.h"
 #include "tsl/profiler/rpc/client/capture_profile.h"
 #include "tsl/profiler/rpc/profiler_server.h"
+#include "xla/backends/profiler/plugin/plugin_tracer.h"
+#include "xla/backends/profiler/plugin/profiler_c_api.h"
+#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h"
 
 namespace torch_xla {
 namespace runtime {
 namespace profiler {
 
+namespace {
+
+const PLUGIN_Profiler_Api* FindProfilerApi(const PJRT_Api* pjrt_api) {
+  const PJRT_Structure_Base* next =
+      reinterpret_cast<const PJRT_Structure_Base*>(pjrt_api->extension_start);
+  while (next != nullptr &&
+         next->type != PJRT_Structure_Type::PJRT_Structure_Type_Profiler) {
+    next = next->next;
+  }
+  if (next == nullptr) {
+    return nullptr;
+  }
+  return reinterpret_cast<const PJRT_Profiler_Extension*>(next)->profiler_api;
+}
+
+}  // namespace
+
 struct ProfilerServer::Impl {
   Impl() : server(new tsl::profiler::ProfilerServer()) {}
 
@@ -33,6 +55,19 @@ tsl::Status Trace(
       /*include_dataset_ops=*/false, duration_ms, num_tracing_attempts,
       options);
 }
+
+void RegisterProfilerForPlugin(const PJRT_Api* c_api) {
+  const PLUGIN_Profiler_Api* profiler_api = FindProfilerApi(c_api);
+  XLA_CHECK(profiler_api);
+
+  tsl::profiler::ProfilerFactory create_func =
+      [profiler_api](const tensorflow::ProfileOptions& options) {
+        return std::make_unique<xla::profiler::PluginTracer>(profiler_api,
+                                                             options);
+      };
+  tsl::profiler::RegisterProfilerFactory(std::move(create_func));
+}
+
 }  // namespace profiler
 }  // namespace runtime
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/profiler.h b/torch_xla/csrc/runtime/profiler.h
index 639e6b2a6d1..d5d49540c24 100644
--- a/torch_xla/csrc/runtime/profiler.h
+++ b/torch_xla/csrc/runtime/profiler.h
@@ -5,6 +5,7 @@
 
 #include "absl/container/flat_hash_map.h"
 #include "tsl/platform/status.h"
+#include "xla/pjrt/c/pjrt_c_api.h"
 
 namespace torch_xla {
 namespace runtime {
@@ -28,6 +29,8 @@ tsl::Status Trace(
     const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
         options);
 
+void RegisterProfilerForPlugin(const PJRT_Api* c_api);
+
 }  // namespace profiler
 }  // namespace runtime
 }  // namespace torch_xla