[xla:collectives] NFC: Delete unused ScopedPlanAllocator
The latest NCCL release fixed the performance issues with CUDA graph tracing. Delete the workaround that we no longer need.

PiperOrigin-RevId: 703184603
ezhulenev authored and tensorflower-gardener committed Dec 5, 2024
1 parent 23c19d7 commit 97509fb
Showing 4 changed files with 6 additions and 413 deletions.
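
For context on what is being deleted: the workaround was an RAII guard that temporarily installed a custom persistent-plan allocator on the NCCL communicator while a collective was traced into a command buffer, and restored the previous allocator when the scope ended (see NcclApi::ScopedPersistentPlanAllocator and the ncclCommGetPersistentPlanAllocator / ncclCommSetPersistentPlanAllocator calls in the deletions below). A minimal sketch of that scoped override/restore pattern, using illustrative stand-in types rather than the real XLA or NCCL API:

#include <cstdio>

// Illustrative stand-in types for the sketch (not the real XLA/NCCL types).
struct PlanAllocator {
  const char* name;
};

struct Comm {
  PlanAllocator* plan_allocator = nullptr;
};

// RAII guard mirroring the deleted ScopedPersistentPlanAllocator: install a
// custom allocator on construction, restore the previous one on destruction.
class ScopedPlanAllocator {
 public:
  ScopedPlanAllocator(Comm* comm, PlanAllocator* allocator)
      : comm_(comm), recover_(comm->plan_allocator) {
    comm_->plan_allocator = allocator;
  }
  ~ScopedPlanAllocator() { comm_->plan_allocator = recover_; }

 private:
  Comm* comm_;
  PlanAllocator* recover_;  // previous allocator, restored at end of scope
};

int main() {
  Comm comm;
  PlanAllocator custom{"custom persistent-plan allocator"};
  {
    // While the collective was traced into a CUDA graph, persistent plans
    // were allocated through the custom allocator.
    ScopedPlanAllocator scoped(&comm, &custom);
    std::printf("during tracing: %s\n", comm.plan_allocator->name);
  }
  // After the scope ends the communicator sees its original allocator again.
  std::printf("after tracing: %s\n",
              comm.plan_allocator ? comm.plan_allocator->name : "(default)");
  return 0;
}

With the fix in NCCL itself, the recording path no longer needs this guard, so each Record method below reduces to the plain AddTracedCommandBuffer call and passes comm_handle.comm directly.
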
third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc (47 changes: 6 additions & 41 deletions)
@@ -1677,19 +1677,11 @@ absl::Status AllReduceCmd::Record(const Thunk::ExecuteParams& execute_params,
GetNcclComm(nccl_api(), *execute_params.collective_params,
*execute_params.collective_cliques, config().replica_groups,
config().group_mode, nccl_stream_id(), GetAsyncStreamKind()));
Communicator* comm = comm_handle.comm;

// Use custom allocator for persistent execution plans.
NcclApi::ScopedPersistentPlanAllocator scoped_allocator(
comm, tsl::MakeRef<NcclApi::PersistentPlanAllocator>(
execute_params.buffer_allocations->device_ordinal(),
execute_params.buffer_allocations->memory_allocator(),
execute_params.stream));

return AddTracedCommandBuffer(
execute_params, record_params, command_buffer, [&](se::Stream* stream) {
return RunAllReduce(nccl_api(), reduction_kind_, device_buffers,
*stream, comm);
*stream, comm_handle.comm);
});
}

@@ -1749,18 +1741,11 @@ absl::Status ReduceScatterCmd::Record(
GetNcclComm(nccl_api(), *execute_params.collective_params,
*execute_params.collective_cliques, config().replica_groups,
config().group_mode, nccl_stream_id(), GetAsyncStreamKind()));
Communicator* comm = comm_handle.comm;
// Use custom allocator for persistent execution plans.
NcclApi::ScopedPersistentPlanAllocator scoped_allocator(
comm, tsl::MakeRef<NcclApi::PersistentPlanAllocator>(
execute_params.buffer_allocations->device_ordinal(),
execute_params.buffer_allocations->memory_allocator(),
execute_params.stream));

return AddTracedCommandBuffer(
execute_params, record_params, command_buffer, [&](se::Stream* stream) {
return RunReduceScatter(nccl_api(), reduction_kind_, device_buffers,
*stream, comm);
*stream, comm_handle.comm);
});
}

@@ -1819,18 +1804,11 @@ absl::Status AllToAllCmd::Record(const Thunk::ExecuteParams& execute_params,
GetNcclComm(nccl_api(), *execute_params.collective_params,
*execute_params.collective_cliques, config().replica_groups,
config().group_mode, nccl_stream_id(), GetAsyncStreamKind()));
Communicator* comm = comm_handle.comm;
// Use custom allocator for persistent execution plans.
NcclApi::ScopedPersistentPlanAllocator scoped_allocator(
comm, tsl::MakeRef<NcclApi::PersistentPlanAllocator>(
execute_params.buffer_allocations->device_ordinal(),
execute_params.buffer_allocations->memory_allocator(),
execute_params.stream));

return AddTracedCommandBuffer(
execute_params, record_params, command_buffer, [&](se::Stream* stream) {
return RunAllToAll(nccl_api(), has_split_dimension_, device_buffers,
*stream, comm);
*stream, comm_handle.comm);
});
}

@@ -1887,17 +1865,11 @@ absl::Status AllGatherCmd::Record(const Thunk::ExecuteParams& execute_params,
GetNcclComm(nccl_api(), *execute_params.collective_params,
*execute_params.collective_cliques, config().replica_groups,
config().group_mode, nccl_stream_id(), GetAsyncStreamKind()));
Communicator* comm = comm_handle.comm;
// Use custom allocator for persistent execution plans.
NcclApi::ScopedPersistentPlanAllocator scoped_allocator(
comm, tsl::MakeRef<NcclApi::PersistentPlanAllocator>(
execute_params.buffer_allocations->device_ordinal(),
execute_params.buffer_allocations->memory_allocator(),
execute_params.stream));

return AddTracedCommandBuffer(
execute_params, record_params, command_buffer, [&](se::Stream* stream) {
return RunAllGather(nccl_api(), device_buffers, *stream, comm);
return RunAllGather(nccl_api(), device_buffers, *stream,
comm_handle.comm);
});
}

@@ -1956,17 +1928,10 @@ absl::Status CollectiveBroadcastCmd::Record(
GetNcclComm(nccl_api(), *execute_params.collective_params,
*execute_params.collective_cliques, config().replica_groups,
config().group_mode, nccl_stream_id(), GetAsyncStreamKind()));
Communicator* comm = comm_handle.comm;
// Use custom allocator for persistent execution plans.
NcclApi::ScopedPersistentPlanAllocator scoped_allocator(
comm, tsl::MakeRef<NcclApi::PersistentPlanAllocator>(
execute_params.buffer_allocations->device_ordinal(),
execute_params.buffer_allocations->memory_allocator(),
execute_params.stream));

return AddTracedCommandBuffer(
execute_params, record_params, command_buffer, [&](se::Stream* stream) {
return RunCollectiveBroadcast(device_buffers, *stream, comm,
return RunCollectiveBroadcast(device_buffers, *stream, comm_handle.comm,
nccl_api());
});
}
third_party/xla/xla/service/gpu/runtime/nccl_api.cc (248 changes: 0 additions & 248 deletions)
@@ -15,42 +15,8 @@ limitations under the License.

#include "xla/service/gpu/runtime/nccl_api.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/backends/gpu/collectives/nccl_collectives.h"
#include "xla/backends/gpu/collectives/nccl_communicator.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/primitive_util.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

#if TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"
@@ -71,220 +37,6 @@ limitations under the License.

namespace xla::gpu {

//==-----------------------------------------------------------------------===//
// Macros to return or warn on NCCL errors.
//==-----------------------------------------------------------------------===//

static absl::Status ToStatus(ncclResult_t s, const char* file, int64_t line,
const char* expr) {
if (s == ncclSuccess) return absl::OkStatus();

return absl::InternalError(absl::StrFormat(
"%s:%d: NCCL operation %s failed: %s."
" Last NCCL warning(error) log entry (may be unrelated) '%s'.",
file, line, expr, ncclGetErrorString(s), ncclGetLastError(nullptr)));
}

#define XLA_NCCL_STATUS(expr) \
xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr)

#define XLA_NCCL_RETURN_IF_ERROR(expr) \
do { \
absl::Status s = XLA_NCCL_STATUS(expr); \
if (!s.ok()) { \
return s; \
} \
} while (0)

#define XLA_NCCL_LOG_IF_ERROR(expr) \
do { \
absl::Status s = XLA_NCCL_STATUS(expr); \
if (!s.ok()) { \
LOG(ERROR) << s.ToString(); \
} \
} while (0)

#define XLA_NCCL_CHECK(expr) CHECK(XLA_NCCL_STATUS(expr).ok())

//==-----------------------------------------------------------------------===//
// Conversions between XLA and NCCL data types
//==-----------------------------------------------------------------------===//

static size_t ToNcclCount(PrimitiveType dtype, size_t count) {
return primitive_util::IsComplexType(dtype) ? count * 2 : count;
}

static absl::StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType dtype,
bool is_reduction_op) {
switch (dtype) {
case S8:
case F8E5M2:
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
return ncclInt8;
case PRED:
case U8:
return ncclUint8;
case S32:
return ncclInt32;
case U32:
return ncclUint32;
case S64:
return ncclInt64;
case U64:
return ncclUint64;
case F16:
return ncclFloat16;
case F32:
case C64:
return ncclFloat32;
case F64:
case C128:
return ncclFloat64;
case S16:
case U16:
// For reductions we expect 16 bit integer types to be promoted to 32-bit.
if (is_reduction_op) {
return absl::InvalidArgumentError(
absl::StrFormat("Unsupported data type for reduction operation: %s",
primitive_util::LowercasePrimitiveTypeName(dtype)));
}
// For collectives that just move data around, we can use ncclFloat16 for
// 16-bit integer data types.
return ncclFloat16;
case BF16:
return ncclBfloat16;
default:
return absl::InvalidArgumentError(
absl::StrFormat("Unsupported data type: %s",
primitive_util::LowercasePrimitiveTypeName(dtype)));
}
}

static ncclRedOp_t ToNcclReduction(ReductionKind kind) {
switch (kind) {
case ReductionKind::SUM:
return ncclSum;
case ReductionKind::PRODUCT:
return ncclProd;
case ReductionKind::MIN:
return ncclMin;
case ReductionKind::MAX:
return ncclMax;
}
}

//==-----------------------------------------------------------------------===//
// Casting between opaque API structs and NCCL types.
//==-----------------------------------------------------------------------===//

static ncclComm_t Cast(const Communicator* comm) {
auto* nccl_communicator = tsl::down_cast<const NcclCommunicator*>(comm);
CHECK(nccl_communicator != nullptr) << "Unsupported XLA communicator";
return nccl_communicator->comm();
}

#if WITH_PERSISTENT_PLAN_ALLOCATOR_SUPPORT
static ncclPersistentPlanAllocator* Cast(
NcclApi::NcclPersistentPlanAllocatorHandle handle) {
return reinterpret_cast<ncclPersistentPlanAllocator*>(handle);
}

static ncclPersistentPlanAllocator** Cast(
NcclApi::NcclPersistentPlanAllocatorHandle* handle) {
return reinterpret_cast<ncclPersistentPlanAllocator**>(handle);
}

static NcclApi::NcclPersistentPlanAllocatorHandle Cast(
ncclPersistentPlanAllocator* ptr) {
return reinterpret_cast<NcclApi::NcclPersistentPlanAllocatorHandle>(ptr);
}
#endif // WITH_PERSISTENT_PLAN_ALLOCATOR_SUPPORT

//==-----------------------------------------------------------------------===//
// NcclApi::PersistentPlanAllocator
//==-----------------------------------------------------------------------===//

using PersistentPlanAllocator = NcclApi::PersistentPlanAllocator;
using ScopedPersistentPlanAllocator = NcclApi::ScopedPersistentPlanAllocator;

PersistentPlanAllocator::PersistentPlanAllocator(
int64_t device_ordinal, se::DeviceMemoryAllocator* allocator,
se::Stream* stream)
: handle_(nullptr),
device_ordinal_(device_ordinal),
allocator_(allocator),
stream_(stream) {
// NCCL persistent plan allocator is implemented as NCCL patch that is not yet
// open sourced and can't be used from OSS XLA.
#if WITH_PERSISTENT_PLAN_ALLOCATOR_SUPPORT
auto* nccl_allocator = new ncclPersistentPlanAllocator;
nccl_allocator->ctl = this;

nccl_allocator->alloc = +[](void** ptr, void* src, size_t size, void* ctl) {
auto allocator = reinterpret_cast<PersistentPlanAllocator*>(ctl);
auto allocated = allocator->AllocateAndInitialize(src, size);
if (!allocated.ok()) return ncclInternalError;
*ptr = allocated->opaque();
allocator->AddRef();
return ncclSuccess;
};

nccl_allocator->free = +[](void* ptr, void* ctl) -> ncclResult_t {
auto allocator = reinterpret_cast<PersistentPlanAllocator*>(ctl);
auto status = allocator->Deallocate(se::DeviceMemoryBase(ptr));
allocator->DropRef();
return status.ok() ? ncclSuccess : ncclInternalError;
};

handle_ = Cast(nccl_allocator);
#endif // WITH_PERSISTENT_PLAN_ALLOCATOR_SUPPORT
}

PersistentPlanAllocator::~PersistentPlanAllocator() {
#if WITH_PERSISTENT_PLAN_ALLOCATOR_SUPPORT
delete Cast(handle_);
#endif // WITH_PERSISTENT_PLAN_ALLOCATOR_SUPPORT
}

absl::StatusOr<se::DeviceMemoryBase>
PersistentPlanAllocator::AllocateAndInitialize(void* src, size_t size) {
TF_ASSIGN_OR_RETURN(auto owned_mem,
allocator_->Allocate(device_ordinal_, size));
VLOG(5) << "Allocate and initialize NCCL persistent plan; mem="
<< owned_mem->opaque() << "; size=" << size;
se::DeviceMemoryBase mem = owned_mem.Release();
TF_RETURN_IF_ERROR(stream_->Memcpy(&mem, src, size));
return mem;
}

absl::Status PersistentPlanAllocator::Deallocate(se::DeviceMemoryBase mem) {
VLOG(5) << "Deallocate NCCL persistent plan; mem=" << mem.opaque();
return allocator_->Deallocate(device_ordinal_, mem);
}

ScopedPersistentPlanAllocator::ScopedPersistentPlanAllocator(
Communicator* comm, tsl::RCReference<PersistentPlanAllocator> allocator)
: comm_(comm), allocator_(std::move(allocator)) {
#if WITH_PERSISTENT_PLAN_ALLOCATOR_SUPPORT
XLA_NCCL_CHECK(
ncclCommGetPersistentPlanAllocator(Cast(comm_), Cast(&recover_)))
<< "Failed to get NCCL persistent plan allocator";
XLA_NCCL_CHECK(ncclCommSetPersistentPlanAllocator(Cast(comm_),
Cast(allocator_->handle())))
<< "Failed to set NCCL persistent plan allocator";
#endif // WITH_PERSISTENT_PLAN_ALLOCATOR_SUPPORT
}

ScopedPersistentPlanAllocator::~ScopedPersistentPlanAllocator() {
#if WITH_PERSISTENT_PLAN_ALLOCATOR_SUPPORT
XLA_NCCL_CHECK(
ncclCommSetPersistentPlanAllocator(Cast(comm_), Cast(recover_)))
<< "Failed to set NCCL persistent plan allocator";
#endif // WITH_PERSISTENT_PLAN_ALLOCATOR_SUPPORT
}

//==-----------------------------------------------------------------------===//
// NcclApi
//==-----------------------------------------------------------------------===//

(The diffs for the remaining 2 changed files are not shown.)