From 9184c1258715be3dfa7972598ece4298ebcfb810 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Wed, 4 Dec 2024 17:35:38 -0800
Subject: [PATCH] [xla:collectives] NFC: Move Config and DeviceRank from NcclApi to Collectives API

PiperOrigin-RevId: 702919277
---
 .../xla/xla/backends/gpu/collectives/BUILD    |  8 +++
 .../gpu/collectives/gpu_clique_locking.cc     | 18 ++++---
 .../gpu/collectives/gpu_collectives.cc        | 52 +++++++++++++++++++
 .../gpu/collectives/gpu_collectives.h         | 34 ++++++++++++
 third_party/xla/xla/core/collectives/BUILD    |  1 +
 .../xla/xla/core/collectives/collectives.h    | 16 ++++++
 .../gpu/runtime/nccl_all_to_all_thunk.cc      | 16 +++---
 .../xla/xla/service/gpu/runtime/nccl_api.cc   |  3 +-
 .../xla/xla/service/gpu/runtime/nccl_api.h    | 25 ---------
 9 files changed, 132 insertions(+), 41 deletions(-)
 create mode 100644 third_party/xla/xla/backends/gpu/collectives/gpu_collectives.cc

diff --git a/third_party/xla/xla/backends/gpu/collectives/BUILD b/third_party/xla/xla/backends/gpu/collectives/BUILD
index 7ff026c071e35d..c03cdbf9f77573 100644
--- a/third_party/xla/xla/backends/gpu/collectives/BUILD
+++ b/third_party/xla/xla/backends/gpu/collectives/BUILD
@@ -85,6 +85,7 @@ cc_library(
         "//xla:status_macros",
         "//xla:types",
         "//xla:util",
+        "//xla/core/collectives",
         "//xla/core/collectives:clique_id",
         "//xla/core/collectives:communicator",
         "//xla/core/collectives:rank_id",
@@ -115,12 +116,19 @@ cc_library(
 
 cc_library(
     name = "gpu_collectives",
+    srcs = ["gpu_collectives.cc"],
     hdrs = ["gpu_collectives.h"],
     deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
        "//xla/core/collectives",
         "//xla/core/collectives:clique_id",
         "//xla/core/collectives:clique_key",
+        "//xla/stream_executor:device_memory",
+        "//xla/stream_executor:stream_executor_h",
         "@com_google_absl//absl/status:statusor",
+        "@local_tsl//tsl/platform:casts",
     ],
 )
 
diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc
index cc3dcb3dcab448..2404f8e989db98 100644
--- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc
+++ b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc
@@ -42,6 +42,7 @@ limitations under the License.
 #include "xla/backends/gpu/collectives/gpu_clique_key.h"
 #include "xla/backends/gpu/collectives/gpu_collectives.h"
 #include "xla/core/collectives/clique_id.h"
+#include "xla/core/collectives/collectives.h"
 #include "xla/core/collectives/communicator.h"
 #include "xla/core/collectives/rank_id.h"
 #include "xla/debug_options_flags.h"
@@ -60,6 +61,8 @@ limitations under the License.
 
 namespace xla::gpu {
 
+using DeviceRank = Collectives::DeviceRank;
+
 //===----------------------------------------------------------------------===//
 // GpuClique Acquire and Initialization Timeouts
 //===----------------------------------------------------------------------===//
@@ -178,7 +181,7 @@ static void StartGpuCliqueHeartBeatMonitor() {
 
 // defined order and do not deadlock inside underlying collective communication
 // library.
-static auto DeviceRanksToString(absl::Span<const NcclApi::DeviceRank> ranks) {
+static auto DeviceRanksToString(absl::Span<const DeviceRank> ranks) {
   return absl::StrJoin(ranks, ",", [](std::string* str, auto& rank) {
     str->append(std::to_string(rank.rank.value()));
   });
@@ -192,7 +195,7 @@ InitializeGpuClique(NcclApi* nccl_api, se::StreamExecutor* device, RunId run_id,
                     const GpuCliqueKey& clique_key,
                     const GpuCollectives::CliqueIdCallback& clique_id_callback,
                     int32_t num_local_participants, RankId rank,
-                    NcclApi::Config& config) {
+                    const GpuCollectives::Config& config) {
   int nranks = clique_key.devices().size();
   VLOG(3) << "Initialize GPU clique " << clique_key.ToString() << " rank #"
           << rank << "; num_local_participants=" << num_local_participants;
@@ -200,7 +203,7 @@ InitializeGpuClique(NcclApi* nccl_api, se::StreamExecutor* device, RunId run_id,
   // Start GPU clique heart beat monitor when create a first clique.
   StartGpuCliqueHeartBeatMonitor();
 
-  using RendezvousArg = std::pair<NcclApi::DeviceRank, /*synchronized=*/bool>;
+  using RendezvousArg = std::pair<DeviceRank, /*synchronized=*/bool>;
 
   // Initializes a GpuClique for given device ranks and returns a lock that
   // gives access to clique communicators.
@@ -219,7 +222,7 @@ InitializeGpuClique(NcclApi* nccl_api, se::StreamExecutor* device, RunId run_id,
       }
     }
 
-    std::vector<NcclApi::DeviceRank> ranks;
+    std::vector<DeviceRank> ranks;
     ranks.reserve(args.size());
     for (auto* arg : args) ranks.emplace_back(arg->first);
 
@@ -274,7 +277,8 @@ InitializeGpuClique(NcclApi* nccl_api, se::StreamExecutor* device, RunId run_id,
       absl::StrFormat("initialize clique for rank %d; clique=%s; run_id=%d",
                       rank.value(), clique_key.ToString(), run_id.ToInt());
 
-  NcclApi::DeviceRank device_rank = {device, rank};
+  GpuCollectives::Device gpu_device(device);
+  GpuCollectives::DeviceRank device_rank = {&gpu_device, rank};
   bool synchronized = device->SynchronizeAllActivity();
 
   // We choose not to exit early on failed synchronization, because it will lead
@@ -317,7 +321,7 @@ InitializeGpuClique(NcclApi* nccl_api, se::StreamExecutor* device, RunId run_id,
                     const GpuCliqueKey& clique_key,
                     std::shared_ptr<LockableGpuClique::Lock> parent_clique,
                     int32_t num_local_participants, RankId rank,
-                    NcclApi::Config& config) {
+                    const GpuCollectives::Config& config) {
   // Find our rank in the parent clique.
   const GpuCliqueKey& parent_clique_key = (*parent_clique)->key();
   RankId parent_rank =
@@ -471,7 +475,7 @@ absl::StatusOr<std::shared_ptr<LockableGpuClique::Lock>> AcquireGpuClique(
 
   // We enable resource sharing between parent and split communicators by
   // default because that's the only reason why we use comm splitting.
-  NcclApi::Config config;
+  GpuCollectives::Config config;
   config.split_share = true;
   config.max_nchannels = max_nchannels;
 
diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.cc
new file mode 100644
index 00000000000000..aa61a645ee3c15
--- /dev/null
+++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.cc
@@ -0,0 +1,52 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "xla/backends/gpu/collectives/gpu_collectives.h" + +#include + +#include "absl/status/statusor.h" +#include "xla/core/collectives/collectives.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/util.h" +#include "tsl/platform/casts.h" + +namespace xla::gpu { + +GpuCollectives::Device::Device(se::StreamExecutor* stream_executor) + : stream_executor_(stream_executor) {} + +se::StreamExecutor* GpuCollectives::Device::stream_executor() const { + return stream_executor_; +} + +absl::StatusOr GpuCollectives::TryCast( + Collectives::Device* device) { + if (auto* gpu_device = tsl::down_cast(device)) { + return gpu_device; + } + return InvalidArgument("Collectvies device is not a GPU device"); +} + +se::DeviceMemoryBase GpuCollectives::Slice(se::DeviceMemoryBase buff, + PrimitiveType dtype, size_t offset, + size_t count) { + size_t multiplier = ShapeUtil::ByteSizeOfPrimitiveType(dtype); + return buff.GetByteSlice(offset * multiplier, count * multiplier); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h index 15c394dec39957..fb5eb3a1a9d50f 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h @@ -16,12 +16,17 @@ limitations under the License. #ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_H_ #define XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_H_ +#include +#include #include #include "absl/status/statusor.h" #include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/clique_key.h" #include "xla/core/collectives/collectives.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/xla_data.pb.h" namespace xla::gpu { @@ -32,6 +37,29 @@ class GpuCollectives : public Collectives { using CliqueIdCallback = // NOLINT std::function(const CliqueKey&)>; + // GPU collectives device is just a wrapper around the StreamExecutor. + class Device : public Collectives::Device { + public: + explicit Device(stream_executor::StreamExecutor* stream_executor); + stream_executor::StreamExecutor* stream_executor() const; + + private: + stream_executor::StreamExecutor* stream_executor_; + }; + + // Casts a Collectives::Device to a GPU device and returns an error if it's + // not a GPU device. + static absl::StatusOr TryCast(Collectives::Device* device); + + // GPU communicator configuration. + // + // For NCCL backend see configuration options documentation at: + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig + struct Config { + bool split_share = false; + int64_t max_nchannels = 0; + }; + // Returns true if collectives backend uses global config. virtual bool IsGlobalConfig() const = 0; @@ -39,6 +67,12 @@ class GpuCollectives : public Collectives { // default callback to get create a clique id if we are running in local mode. virtual absl::StatusOr GetCliqueIdCallback( const CliqueIdCallback* clique_id_callback, bool is_local) = 0; + + // Returns a slice of device memory `buff` containing `count` values of data + // type `dtype` starting from `offset`. 
+  static stream_executor::DeviceMemoryBase Slice(
+      stream_executor::DeviceMemoryBase buff, PrimitiveType dtype,
+      size_t offset, size_t count);
 };
 
 }  // namespace xla::gpu
diff --git a/third_party/xla/xla/core/collectives/BUILD b/third_party/xla/xla/core/collectives/BUILD
index c429f1fc492a60..8e431494db5c1b 100644
--- a/third_party/xla/xla/core/collectives/BUILD
+++ b/third_party/xla/xla/core/collectives/BUILD
@@ -35,6 +35,7 @@ cc_library(
     deps = [
         ":clique_id",
         ":communicator",
+        ":rank_id",
         "@com_google_absl//absl/status:statusor",
     ],
 )
diff --git a/third_party/xla/xla/core/collectives/collectives.h b/third_party/xla/xla/core/collectives/collectives.h
index 92bd5463b92385..5624f3978198e7 100644
--- a/third_party/xla/xla/core/collectives/collectives.h
+++ b/third_party/xla/xla/core/collectives/collectives.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "absl/status/statusor.h"
 #include "xla/core/collectives/clique_id.h"
+#include "xla/core/collectives/rank_id.h"
 
 namespace xla {
 
@@ -35,6 +36,21 @@ class Collectives {
  public:
   virtual ~Collectives() = default;
 
+  // A base class for the device that the collectives are running on, i.e. in
+  // XLA:GPU this is the GPU device (StreamExecutor).
+  class Device {
+   public:
+    virtual ~Device() = default;
+  };
+
+  // A collective device together with its rank in the collective clique.
+  struct DeviceRank {
+    DeviceRank(Device* device, RankId rank) : device(device), rank(rank) {}
+
+    Device* device;
+    RankId rank;
+  };
+
   // Creates a unique CliqueId.
   virtual absl::StatusOr<CliqueId> CreateUniqueCliqueId() const = 0;
 };
diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc
index 2de6e956e25937..24c6240f442451 100644
--- a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc
+++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc
@@ -236,12 +236,12 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,
 
     for (int peer = 0; peer < num_ranks; ++peer) {
       se::DeviceMemoryBase send_slice =
-          NcclApi::Slice(buffer.source_buffer, buffer.element_type,
-                         peer * chunk_elements, chunk_elements);
+          nccl_api->Slice(buffer.source_buffer, buffer.element_type,
+                          peer * chunk_elements, chunk_elements);
 
       se::DeviceMemoryBase recv_slice =
-          NcclApi::Slice(buffer.destination_buffer, buffer.element_type,
-                         peer * chunk_elements, chunk_elements);
+          nccl_api->Slice(buffer.destination_buffer, buffer.element_type,
+                          peer * chunk_elements, chunk_elements);
 
       TF_RETURN_IF_ERROR(nccl_api->Send(send_slice, buffer.element_type,
                                         chunk_elements, peer, comm, &stream));
@@ -298,8 +298,8 @@ absl::Status RunMemCpyAllToAll(
     TF_RETURN_IF_ERROR(nccl_api->GroupStart());
     for (int peer = 0; peer < num_ranks; ++peer) {
       se::DeviceMemoryBase recv_slice =
-          NcclApi::Slice(buffer.destination_buffer, buffer.element_type,
-                         peer * chunk_elements, chunk_elements);
+          nccl_api->Slice(buffer.destination_buffer, buffer.element_type,
+                          peer * chunk_elements, chunk_elements);
       send_pointer_map[peer] = (uint64_t)recv_slice.opaque();
 
       TF_RETURN_IF_ERROR(nccl_api->SendPtrToPeer(&send_pointer_map[peer],
@@ -312,8 +312,8 @@ absl::Status RunMemCpyAllToAll(
 
     for (int peer = 0; peer < num_ranks; ++peer) {
       se::DeviceMemoryBase send_slice =
-          NcclApi::Slice(buffer.source_buffer, buffer.element_type,
-                         peer * chunk_elements, chunk_elements);
+          nccl_api->Slice(buffer.source_buffer, buffer.element_type,
+                          peer * chunk_elements, chunk_elements);
       se::DeviceMemoryBase dst_addr =
          se::DeviceMemoryBase((void*)receive_pointer_map[peer]);
       TF_RETURN_IF_ERROR(
diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc
index 9680e05e9ff8a4..fb675dd5b4a611 100644
--- a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc
+++ b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc
@@ -390,7 +390,8 @@ DefaultNcclApi::CommInitRanks(int32_t nranks, const CliqueId& clique_id,
     VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank
             << " of " << nranks
             << "; fingerprint(id)=" << clique_id.fingerprint();
-    auto activate_context = ranks[i].device->Activate();
+    TF_ASSIGN_OR_RETURN(auto* device, TryCast(ranks[i].device));
+    auto activate_context = device->stream_executor()->Activate();
 
     TF_ASSIGN_OR_RETURN(auto nccl_unique_id, AsNcclUniqueId(clique_id));
     XLA_NCCL_RETURN_IF_ERROR(
diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api.h b/third_party/xla/xla/service/gpu/runtime/nccl_api.h
index 30b33d05acd543..92be208b6554b5 100644
--- a/third_party/xla/xla/service/gpu/runtime/nccl_api.h
+++ b/third_party/xla/xla/service/gpu/runtime/nccl_api.h
@@ -52,14 +52,6 @@ class NcclApi : public GpuCollectives {
  public:
   virtual ~NcclApi() = default;
 
-  // Communicator configuration.
-  //
-  // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig
-  struct Config {
-    bool split_share = false;
-    int64_t max_nchannels = 0;
-  };
-
   // Returns a default NcclApi for a current process. Can be a real one based on
   // NCCL or a stub if XLA compiled without NCCL or CUDA support.
   static NcclApi* Default();
@@ -119,23 +111,6 @@ class NcclApi : public GpuCollectives {
     tsl::RCReference<PersistentPlanAllocator> allocator_;
   };
 
-  struct DeviceRank {
-    DeviceRank(se::StreamExecutor* device, RankId rank)
-        : device(device), rank(rank) {}
-
-    se::StreamExecutor* device;
-    RankId rank;
-  };
-
-  // Returns a slice of device memory `buff` containing `count` values of data
-  // type `dtype` starting from `offset`.
-  static se::DeviceMemoryBase Slice(se::DeviceMemoryBase buff,
-                                    PrimitiveType dtype, size_t offset,
-                                    size_t count) {
-    size_t multiplier = ShapeUtil::ByteSizeOfPrimitiveType(dtype);
-    return buff.GetByteSlice(offset * multiplier, count * multiplier);
-  }
-
   // Creates new communicators for given devices.
   //
   // This API doesn't have a corresponding API in NCCL and implemented as
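
As a rough usage sketch of the relocated API (illustrative only, not part of the
patch): after this change a call site builds the communicator config and device
ranks roughly as below, mirroring InitializeGpuClique and AcquireGpuClique
above. The function name ExampleBuildCliqueArgs and its parameters are
assumptions made for the example.

#include <cstdint>

#include "xla/backends/gpu/collectives/gpu_collectives.h"
#include "xla/core/collectives/collectives.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/stream_executor/stream_executor.h"

namespace xla::gpu {

// Hypothetical helper: shows how the types moved by this patch fit together.
void ExampleBuildCliqueArgs(stream_executor::StreamExecutor* executor,
                            RankId rank, int64_t max_nchannels) {
  // Communicator configuration now lives on GpuCollectives (formerly
  // NcclApi::Config); field values mirror the defaults set in AcquireGpuClique.
  GpuCollectives::Config config;
  config.split_share = true;
  config.max_nchannels = max_nchannels;

  // DeviceRank is now the backend-agnostic Collectives::DeviceRank; the GPU
  // backend wraps the StreamExecutor in a GpuCollectives::Device.
  GpuCollectives::Device gpu_device(executor);
  Collectives::DeviceRank device_rank(&gpu_device, rank);

  // A backend (e.g. DefaultNcclApi::CommInitRanks) recovers the StreamExecutor
  // via GpuCollectives::TryCast(device_rank.device).
  (void)config;
  (void)device_rank;
}

}  // namespace xla::gpu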