From 40ab61d71b4b261212f625115a2b315221164532 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 28 Nov 2023 19:24:32 +0000 Subject: [PATCH] fix rebase issues --- torch_xla/csrc/runtime/BUILD | 31 ++-- .../csrc/runtime/ifrt_computation_client.cc | 138 +++++++----------- .../csrc/runtime/ifrt_computation_client.h | 35 +++-- torch_xla/csrc/runtime/initialize_pjrt.cc | 134 +++++++++++++++++ torch_xla/csrc/runtime/initialize_pjrt.h | 15 ++ .../csrc/runtime/pjrt_computation_client.cc | 112 +------------- 6 files changed, 248 insertions(+), 217 deletions(-) create mode 100644 torch_xla/csrc/runtime/initialize_pjrt.cc create mode 100644 torch_xla/csrc/runtime/initialize_pjrt.h diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index ed6f8aa56572..c15247fd16a1 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -83,19 +83,15 @@ cc_library( ":computation_client", ":debug_macros", ":env_vars", - ":multi_wait", + ":initialize_pjrt", + ":operation_manager", ":stablehlo_helper", ":tf_logging", - ":thread_pool", "@xla//xla:literal", "@xla//xla:shape_util", "@xla//xla/client:xla_computation", "@xla//xla/pjrt/distributed", - "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", - "@xla//xla/service:gpu_plugin", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:tfrt_cpu_pjrt_client", - "@xla//xla/pjrt:pjrt_c_api_client", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", "@tsl//tsl/profiler/lib:traceme", @@ -117,6 +113,7 @@ cc_library( ":computation_client", ":debug_macros", ":env_vars", + ":initialize_pjrt", ":operation_manager", ":profiler", ":stablehlo_helper", @@ -128,11 +125,7 @@ cc_library( "@xla//xla:shape_util", "@xla//xla/client:xla_computation", "@xla//xla/pjrt/distributed", - "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", - "@xla//xla/service:gpu_plugin", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:tfrt_cpu_pjrt_client", - "@xla//xla/pjrt:pjrt_c_api_client", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/platform/cloud:gcs_file_system", @@ -174,6 +167,24 @@ cc_library( hdrs = ["env_vars.h"], ) +cc_library( + name = "initialize_pjrt", + srcs = ["initialize_pjrt.cc"], + hdrs = ["initialize_pjrt.h"], + deps = [ + ":debug_macros", + ":env_vars", + ":profiler", + ":sys_util", + ":tf_logging", + ":xla_coordinator", + "@xla//xla/service:gpu_plugin", + "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", + "@xla//xla/pjrt:tfrt_cpu_pjrt_client", + "@xla//xla/pjrt:pjrt_c_api_client", + ], +) + cc_library( name = "metrics_analysis", srcs = ["metrics_analysis.cc"], diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 3e0c39627419..121ee80df1bf 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -9,22 +9,19 @@ #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_vars.h" -#include "torch_xla/csrc/runtime/multi_wait.h" +#include "torch_xla/csrc/runtime/initialize_pjrt.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/runtime/thread_pool.h" +#include "torch_xla/csrc/runtime/xla_coordinator.h" #include "tsl/profiler/lib/traceme.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/pjrt/distributed/distributed.h" -#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_api.h" -#include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/sharding.h" @@ -98,26 +95,7 @@ std::vector IfrtComputationClient::PjRtDevicesToString( IfrtComputationClient::IfrtComputationClient() { std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, ""); - if (device_type == "CPU") { - TF_VLOG(1) << "Initializing PjRt CPU client..."; - bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true); - int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1); - client_ = xla::ifrt::PjRtClient::Create( - std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value())); - } else if (device_type == "TPU" || device_type == "TPU_C_API") { - TF_VLOG(1) << "Initializing TFRT TPU client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))); - tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); - XLA_CHECK(tpu_status.ok()); - client_ = xla::ifrt::PjRtClient::Create( - std::move(xla::GetCApiClient("TPU").value())); - } else { - XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, - device_type); - } - - XLA_CHECK(client_.get() != nullptr); + client_ = xla::ifrt::PjRtClient::Create(std::move(InitializePjRt(device_type))); // PjRtDevice IDs are not guaranteed to be dense, so we need to track // a device's global ordinal separately from its device ID. Order the @@ -130,10 +108,38 @@ IfrtComputationClient::IfrtComputationClient() { global_ordinals_[device->id()] = global_ordinals_.size(); std::string device_str = PjRtDeviceToString(device); string_to_device_.emplace(device_str, device); - device_locks_.emplace(device_str, std::make_unique()); } - // manually create the device_locks for SPMD device - device_locks_.emplace(spmd_device_str, std::make_unique()); + + auto tracked_devices = GetLocalDevices(); + tracked_devices.emplace_back(spmd_device_str); + operation_manager_ = std::move(OperationManager(std::move(tracked_devices))); +} + +IfrtComputationClient::~IfrtComputationClient() { + // In the GPU case, the PjRtClient depends on the DistributedRuntimeClient + // tracked in XlaCoordinator, so the PjRtClient must be destroyed first. + client_ = nullptr; + coordinator_ = nullptr; +} + +bool IfrtComputationClient::CoordinatorInitialized() const { + return coordinator_ != nullptr; +} + +void IfrtComputationClient::InitializeCoordinator(int global_rank, + int world_size, + std::string master_addr, + std::string port) { + XLA_CHECK(coordinator_ == nullptr) + << "Can only initialize the XlaCoordinator once."; + coordinator_ = std::make_unique(global_rank, world_size, + master_addr, port); +} + +XlaCoordinator& IfrtComputationClient::GetCoordinator() { + XLA_CHECK(coordinator_ != nullptr) + << "XlaCoordinator has not been initialized"; + return *coordinator_; } void IfrtComputationClient::IfrtData::Assign( @@ -222,7 +228,7 @@ std::optional IfrtComputationClient::GetDataSharding( } std::vector IfrtComputationClient::TransferToServer( - absl::Span tensors) { + absl::Span> tensors) { metrics::TimedSection timed(TransferToServerMetric()); tsl::profiler::TraceMe activity("IfrtComputationClient::TransferToServer", tsl::profiler::TraceMeLevel::kInfo); @@ -230,35 +236,27 @@ std::vector IfrtComputationClient::TransferToServer( datas.reserve(tensors.size()); int64_t total_size = 0; for (auto& tensor : tensors) { - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor.device); + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device()); - auto literal = std::make_shared(tensor.shape); - tensor.populate_fn(tensor, literal->untyped_data(), literal->size_bytes()); - std::vector byte_strides(literal->shape().dimensions_size()); - XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal->shape(), - absl::MakeSpan(byte_strides))); - total_size += literal->size_bytes(); + total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape()); - // Avoid use-after-free on `literal` due to unsequenced move and use. - xla::Literal* literal_pointer = literal.get(); tsl::RCReference buffer = client_ ->MakeArrayFromHostBuffer( - literal_pointer->untyped_data(), - xla::ifrt::ToDType(literal_pointer->shape().element_type()) - .value(), - xla::ifrt::Shape(literal_pointer->shape().dimensions()), - byte_strides, + tensor->data(), + xla::ifrt::ToDType(tensor->primitive_type()).value(), + xla::ifrt::Shape(tensor->dimensions()), + tensor->byte_strides(), // TODO: what is MemoryKind? xla::ifrt::SingleDeviceSharding::Create( pjrt_device, xla::ifrt::MemoryKind()), xla::PjRtClient::HostBufferSemantics:: kImmutableUntilTransferCompletes, - [literal{std::move(literal)}]() { /* frees literal */ }) + [tensor]() { /* frees tensor */ }) .value(); ComputationClient::DataPtr data = - std::make_shared(tensor.device, tensor.shape, buffer); + std::make_shared(tensor->device(), tensor->shape(), buffer); datas.push_back(data); } OutboundDataMetric()->AddSample(total_size); @@ -268,8 +266,8 @@ std::vector IfrtComputationClient::TransferToServer( } ComputationClient::DataPtr IfrtComputationClient::TransferShardsToServer( - absl::Span tensor_shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) { + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) { tsl::profiler::TraceMe activity( "IfrtComputationClient::TransferShardsToServer", tsl::profiler::TraceMeLevel::kInfo); @@ -367,7 +365,7 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( auto sharded_results = ExecuteReplicated(*computations.front(), {{handle_but_not_const}}, GetLocalDevices(), execute_options); - auto replicated_output = std::dynamic_pointer_cast(sharded_results[0][0])->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy); + auto replicated_output = std::dynamic_pointer_cast(sharded_results[0])->buffer->FullyReplicatedShard(xla::ifrt::ArrayCopySemantics::kAlwaysCopy); // TODO: sanity check outputs return *replicated_output; } @@ -575,10 +573,10 @@ IfrtComputationClient::ExecuteComputation( // return datas; } -std::vector> +std::vector IfrtComputationClient::ExecuteReplicated( const ComputationClient::Computation& computation, - const std::vector>& arguments, + const absl::Span arguments, // TODO: devices isn't doing anything helpful here absl::Span devices, const ExecuteReplicatedOptions& options) { @@ -597,9 +595,9 @@ IfrtComputationClient::ExecuteReplicated( // << "ExecuteReplicated over " << devices.size() << " devices, but " // << arguments.size() << " arguments devices."; // TODO: parallelize again if necessary - std::vector> argument_handles(arguments[0].size()); - for (int32_t i = 0; i < arguments[0].size(); ++i) { - auto ifrt_data = std::dynamic_pointer_cast(arguments[0][i]); + std::vector> argument_handles(arguments.size()); + for (int32_t i = 0; i < arguments.size(); ++i) { + auto ifrt_data = std::dynamic_pointer_cast(arguments[i]); argument_handles[i] = ifrt_data->buffer; } @@ -683,37 +681,11 @@ xla::PjRtDevice* IfrtComputationClient::StringToPjRtDevice( return pjrt_device; } -std::shared_lock IfrtComputationClient::lock_device_shared( - const std::string& device) { - std::shared_lock lock(*device_locks_[device]); - return lock; -} - -std::unique_lock IfrtComputationClient::lock_device( - const std::string& device) { - std::unique_lock lock(*device_locks_[device]); - return lock; -} - void IfrtComputationClient::WaitDeviceOps( - const std::vector& devices) { - std::unordered_set wait_devices; - if (!devices.empty()) { - for (auto& device_str : devices) { - wait_devices.insert(device_str); - } - } else { - for (auto& device_str : GetLocalDevices()) { - wait_devices.insert(device_str); - } - } - for (const std::string& device_str : wait_devices) { - TF_VLOG(3) << "Waiting for device execution for " << device_str - << " to finish"; - lock_device(device_str); - TF_VLOG(3) << "Waiting for device execution for " << device_str - << " to finish.. Done"; - } + absl::Span devices) { + TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", "); + operation_manager_.WaitForDevices(devices.empty() ? GetLocalDevices() + : devices); } std::map IfrtComputationClient::GetMetrics() const { diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index f0ac786a9beb..e7d5b8900c5b 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -10,6 +10,7 @@ #include "absl/types/span.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/util.h" #include "xla/client/xla_computation.h" #include "xla/literal.h" @@ -26,6 +27,7 @@ namespace runtime { class IfrtComputationClient : public ComputationClient { public: IfrtComputationClient(); + ~IfrtComputationClient(); DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override; @@ -39,7 +41,7 @@ class IfrtComputationClient : public ComputationClient { std::optional GetDataSharding(DataPtr handle) override; std::vector TransferToServer( - absl::Span tensors) override; + absl::Span> tensors) override; // Use XLA replication to re-assemble the sharded data. // DataPtr ReplicateShardedData(const DataPtr& handle); @@ -47,9 +49,9 @@ class IfrtComputationClient : public ComputationClient { std::vector TransferFromServer( absl::Span handles) override; - DataPtr TransferShardsToServer(absl::Span tensor_shards, - std::string device, xla::Shape shape, - xla::OpSharding sharding) override; + DataPtr TransferShardsToServer( + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) override; DataPtr CopyToDevice(DataPtr data, std::string dst) override; @@ -61,9 +63,9 @@ class IfrtComputationClient : public ComputationClient { const std::string& device, const ExecuteComputationOptions& options) override; - std::vector> ExecuteReplicated( + std::vector ExecuteReplicated( const Computation& computation, - const std::vector>& arguments, + const absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) override; @@ -88,12 +90,18 @@ class IfrtComputationClient : public ComputationClient { std::shared_ptr> GetReplicationDevices() override; - void PrepareToExit() override { return; }; - - void WaitDeviceOps(const std::vector& devices) override; + void WaitDeviceOps(absl::Span devices) override; std::map GetMetrics() const override; + void InitializeCoordinator(int global_rank, int world_size, + std::string master_addr, + std::string port) override; + + XlaCoordinator& GetCoordinator() override; + + bool CoordinatorInitialized() const override; + // NOT IMPLEMENTED MemoryInfo GetMemoryInfo(const std::string& device) override { @@ -102,18 +110,17 @@ class IfrtComputationClient : public ComputationClient { private: std::shared_ptr client_; + std::unique_ptr coordinator_; // global_ordinals_ tracks a map from PjRtDeviceId to the device's // dense global ordinal. std::unordered_map global_ordinals_; std::unordered_map string_to_device_; std::shared_ptr> replication_devices_; - std::unordered_map> - device_locks_; + OperationManager operation_manager_; + tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool( + tsl::Env::Default(), "ifrt", std::thread::hardware_concurrency()); xla::PjRtDevice* StringToPjRtDevice(const std::string& device); - std::shared_lock lock_device_shared( - const std::string& device); - std::unique_lock lock_device(const std::string& device); std::string PjRtDeviceToString(xla::PjRtDevice* const device) const; std::vector PjRtDevicesToString( diff --git a/torch_xla/csrc/runtime/initialize_pjrt.cc b/torch_xla/csrc/runtime/initialize_pjrt.cc new file mode 100644 index 000000000000..f9d2166b2a75 --- /dev/null +++ b/torch_xla/csrc/runtime/initialize_pjrt.cc @@ -0,0 +1,134 @@ +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/env_vars.h" +#include "torch_xla/csrc/runtime/profiler.h" +#include "torch_xla/csrc/runtime/sys_util.h" +#include "torch_xla/csrc/runtime/xla_coordinator.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include "torch_xla/csrc/runtime/tf_logging.h" + +namespace torch_xla { +namespace runtime { + +namespace { + +xla::GpuAllocatorConfig GetGpuAllocatorConfig() { + auto allocator_config = xla::GpuAllocatorConfig{}; + if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() && + sys_util::GetEnvString(env::kEnvPjrtAllocatorPreallocate, "").empty() && + sys_util::GetEnvString(env::kEnvPjrtAllocatorFraction, "").empty()) { + return allocator_config; + } + if (sys_util::GetEnvBool(env::kEnvPjrtAllocatorCudaAsync, false)) { + allocator_config.kind = xla::GpuAllocatorConfig::Kind::kCudaAsync; + } + allocator_config.preallocate = + sys_util::GetEnvBool(env::kEnvPjrtAllocatorPreallocate, true); + allocator_config.memory_fraction = + sys_util::GetEnvDouble(env::kEnvPjrtAllocatorFraction, 0.75); + return allocator_config; +} + +} + +std::unique_ptr InitializePjRt(const std::string& device_type) { + std::unique_ptr client; + + if (device_type == "CPU") { + TF_VLOG(1) << "Initializing PjRt CPU client..."; + bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true); + int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1); + client = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()); + } else if (device_type == "TPU" || device_type == "TPU_C_API") { + TF_VLOG(1) << "Initializing TFRT TPU client..."; + // Prefer $TPU_LIBRARY_PATH if set + auto tpu_library_path = sys_util::GetEnvString( + env::kEnvTpuLibraryPath, + sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so")); + XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status()); + tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); + XLA_CHECK_OK(tpu_status); + client = std::move(xla::GetCApiClient("TPU").value()); + const PJRT_Api* c_api = + static_cast(client.get())->pjrt_c_api(); + profiler::RegisterProfilerForPlugin(c_api); + } else if (device_type == "TPU_LEGACY") { + XLA_ERROR() << "TPU_LEGACY client is no longer available."; + } else if (device_type == "GPU" || device_type == "CUDA" || + device_type == "ROCM") { + TF_VLOG(1) << "Initializing PjRt GPU client..."; + bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true); + int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0); + int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank); + int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1); + int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); + std::string master_addr = + runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); + std::string port = runtime::sys_util::GetEnvString( + "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); + + + // Use the XlaCoordinator as the distributed key-value store. + auto coordinator = std::make_unique( + global_process_rank, global_world_size, master_addr, port); + std::shared_ptr distributed_client = + coordinator->GetClient(); + auto allowed_devices = + std::make_optional>(std::set{local_process_rank}); + xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; + xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; + if (distributed_client != nullptr) { + std::string key_prefix = "gpu:"; + kv_get = [distributed_client, key_prefix](const std::string& k, + absl::Duration timeout) { + return distributed_client->BlockingKeyValueGet( + absl::StrCat(key_prefix, k), timeout); + }; + kv_put = [distributed_client, key_prefix](const std::string& k, + const std::string& v) { + return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); + }; + } + TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" + << global_process_rank << ", num_nodes=" << global_world_size; + client = std::move(xla::GetStreamExecutorGpuClient( + /*asynchronous=*/async, + /*allocator_config=*/GetGpuAllocatorConfig(), + /*node_id=*/global_process_rank, + /*num_nodes=*/global_world_size, + /*allowed_devices=*/allowed_devices, + /*platform_name=*/"gpu", + /*should_stage_host_to_device_transfers=*/true, + /*kv_get=*/kv_get, + /*kv_put=*/kv_put) + .value()); + } else if (device_type == "XPU") { + TF_VLOG(1) << "Initializing PjRt XPU client..."; + XLA_CHECK_OK( + pjrt::LoadPjrtPlugin( + "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")) + .status()); + client = std::move(xla::GetCApiClient("XPU").value()); + } else if (device_type == "NEURON") { + TF_VLOG(1) << "Initializing PjRt NEURON client..."; + XLA_CHECK_OK(pjrt::LoadPjrtPlugin("NEURON", sys_util::GetEnvString( + env::kEnvNeuronLibraryPath, + "libneuronpjrt.so")) + .status()); + client = std::move(xla::GetCApiClient("NEURON").value()); + } else { + XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, + device_type); + } + + XLA_CHECK(client.get() != nullptr); + + return std::move(client); +} + +} +} diff --git a/torch_xla/csrc/runtime/initialize_pjrt.h b/torch_xla/csrc/runtime/initialize_pjrt.h new file mode 100644 index 000000000000..395deac1182c --- /dev/null +++ b/torch_xla/csrc/runtime/initialize_pjrt.h @@ -0,0 +1,15 @@ +#ifndef XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_ +#define XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_ + + +#include "xla/pjrt/pjrt_client.h" + +namespace torch_xla { +namespace runtime { + +std::unique_ptr InitializePjRt(const std::string& device_type); + +} +} + +#endif // XLA_CLIENT_INITIALIZE_PJRT_H_ diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index f12cbfb40f76..7eecc0d5dbb1 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -12,6 +12,7 @@ #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_vars.h" +#include "torch_xla/csrc/runtime/initialize_pjrt.h" #include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" @@ -24,14 +25,8 @@ #include "xla/client/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/distributed/distributed.h" -#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" -#include "xla/pjrt/pjrt_api.h" -#include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/shape.h" using xla::internal::XlaBuilderFriend; @@ -67,23 +62,6 @@ xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { return xla::ShapeUtil::DeviceShapeToHostShape(shape); } -xla::GpuAllocatorConfig GetGpuAllocatorConfig() { - auto allocator_config = xla::GpuAllocatorConfig{}; - if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() && - sys_util::GetEnvString(env::kEnvPjrtAllocatorPreallocate, "").empty() && - sys_util::GetEnvString(env::kEnvPjrtAllocatorFraction, "").empty()) { - return allocator_config; - } - if (sys_util::GetEnvBool(env::kEnvPjrtAllocatorCudaAsync, false)) { - allocator_config.kind = xla::GpuAllocatorConfig::Kind::kCudaAsync; - } - allocator_config.preallocate = - sys_util::GetEnvBool(env::kEnvPjrtAllocatorPreallocate, true); - allocator_config.memory_fraction = - sys_util::GetEnvDouble(env::kEnvPjrtAllocatorFraction, 0.75); - return allocator_config; -} - } // namespace std::string PjRtComputationClient::PjRtDeviceToString( @@ -109,93 +87,7 @@ std::vector PjRtComputationClient::PjRtDevicesToString( PjRtComputationClient::PjRtComputationClient() { std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, ""); - if (device_type == "CPU") { - TF_VLOG(1) << "Initializing PjRt CPU client..."; - bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, true); - int cpu_device_count = sys_util::GetEnvInt(env::kEnvNumCpu, 1); - client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()); - } else if (device_type == "TPU" || device_type == "TPU_C_API") { - TF_VLOG(1) << "Initializing TFRT TPU client..."; - // Prefer $TPU_LIBRARY_PATH if set - auto tpu_library_path = sys_util::GetEnvString( - env::kEnvTpuLibraryPath, - sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so")); - XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status()); - tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); - XLA_CHECK_OK(tpu_status); - client_ = std::move(xla::GetCApiClient("TPU").value()); - const PJRT_Api* c_api = - static_cast(client_.get())->pjrt_c_api(); - profiler::RegisterProfilerForPlugin(c_api); - } else if (device_type == "TPU_LEGACY") { - XLA_ERROR() << "TPU_LEGACY client is no longer available."; - } else if (device_type == "GPU" || device_type == "CUDA" || - device_type == "ROCM") { - TF_VLOG(1) << "Initializing PjRt GPU client..."; - bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true); - int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0); - int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank); - int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1); - int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); - std::string master_addr = - runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); - std::string port = runtime::sys_util::GetEnvString( - "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); - - // Use the XlaCoordinator as the distributed key-value store. - coordinator_ = std::make_unique( - global_process_rank, global_world_size, master_addr, port); - std::shared_ptr distributed_client = - coordinator_->GetClient(); - auto allowed_devices = - std::make_optional>(std::set{local_process_rank}); - xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; - xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; - if (distributed_client != nullptr) { - std::string key_prefix = "gpu:"; - kv_get = [distributed_client, key_prefix](const std::string& k, - absl::Duration timeout) { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(key_prefix, k), timeout); - }; - kv_put = [distributed_client, key_prefix](const std::string& k, - const std::string& v) { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); - }; - } - TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" - << global_process_rank << ", num_nodes=" << global_world_size; - client_ = std::move(xla::GetStreamExecutorGpuClient( - /*asynchronous=*/async, - /*allocator_config=*/GetGpuAllocatorConfig(), - /*node_id=*/global_process_rank, - /*num_nodes=*/global_world_size, - /*allowed_devices=*/allowed_devices, - /*platform_name=*/"gpu", - /*should_stage_host_to_device_transfers=*/true, - /*kv_get=*/kv_get, - /*kv_put=*/kv_put) - .value()); - } else if (device_type == "XPU") { - TF_VLOG(1) << "Initializing PjRt XPU client..."; - XLA_CHECK_OK( - pjrt::LoadPjrtPlugin( - "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")) - .status()); - client_ = std::move(xla::GetCApiClient("XPU").value()); - } else if (device_type == "NEURON") { - TF_VLOG(1) << "Initializing PjRt NEURON client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin("NEURON", sys_util::GetEnvString( - env::kEnvNeuronLibraryPath, - "libneuronpjrt.so")) - .status()); - client_ = std::move(xla::GetCApiClient("NEURON").value()); - } else { - XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, - device_type); - } - - XLA_CHECK(client_.get() != nullptr); + client_ = std::move(InitializePjRt(device_type)); // PjRtDevice IDs are not guaranteed to be dense, so we need to track // a device's global ordinal separately from its device ID. Order the