Skip to content

Commit

Permalink
[xla:collectives] NFC: Move stubs from NcclApiStub to GpuCollectivesStub
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702901797
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Dec 5, 2024
1 parent 9f45db6 commit 43f75e7
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 16 deletions.
13 changes: 13 additions & 0 deletions third_party/xla/xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,19 @@ cc_library(
],
)

cc_library(
name = "gpu_collectives_stub",
hdrs = ["gpu_collectives_stub.h"],
deps = [
":gpu_collectives",
"//xla:util",
"//xla/core/collectives:clique_id",
"//xla/service/gpu/runtime:nccl_api_header",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

cc_library(
name = "nccl_errors",
hdrs = if_gpu_is_configured(["nccl_errors.h"]),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_STUB_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_STUB_H_

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/service/gpu/runtime/nccl_api.h"
#include "xla/util.h"

namespace xla::gpu {

// A stub for GPU collectives when XLA:GPU compiled without collectives support.
class GpuCollectivesStub : public NcclApi {
public:
bool IsGlobalConfig() const final { return false; }

absl::StatusOr<CliqueId> CreateUniqueCliqueId() const final {
return UnimplementedError();
}

absl::StatusOr<const CliqueIdCallback*> GetCliqueIdCallback(
const CliqueIdCallback*, bool) final {
return UnimplementedError();
}

protected:
static absl::Status UnimplementedError() {
return Unimplemented("XLA compiled without GPU collectives support");
}
};

} // namespace xla::gpu

#endif // XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_STUB_H_
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/backends/gpu/collectives:gpu_collectives",
"//xla/backends/gpu/collectives:gpu_collectives_stub",
"//xla/backends/gpu/collectives:nccl_collectives",
"//xla/backends/gpu/collectives:nccl_communicator",
"//xla/core/collectives:clique_id",
Expand Down Expand Up @@ -286,6 +287,7 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/backends/gpu/collectives:gpu_collectives",
"//xla/backends/gpu/collectives:gpu_collectives_stub",
"//xla/backends/gpu/collectives:nccl_collectives",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:communicator",
Expand Down
18 changes: 2 additions & 16 deletions third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/collectives/gpu_collectives_stub.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
Expand Down Expand Up @@ -80,11 +81,7 @@ ScopedPersistentPlanAllocator::~ScopedPersistentPlanAllocator() = default;
// NcclApiStub
//===----------------------------------------------------------------------===//

static absl::Status UnimplementedError() {
return absl::UnimplementedError("XLA compiled without NCCL support");
}

class NcclApiStub final : public NcclApi {
class NcclApiStub final : public GpuCollectivesStub {
public:
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> CommInitRanks(
int32_t, const CliqueId&, absl::Span<const DeviceRank>,
Expand All @@ -98,17 +95,6 @@ class NcclApiStub final : public NcclApi {
return UnimplementedError();
}

absl::StatusOr<CliqueId> CreateUniqueCliqueId() const final {
return UnimplementedError();
}

bool IsGlobalConfig() const final { return false; }

absl::StatusOr<const CliqueIdCallback*> GetCliqueIdCallback(
const CliqueIdCallback*, bool) final {
return UnimplementedError();
}

absl::Status GroupStart() final { return UnimplementedError(); }
absl::Status GroupEnd() final { return UnimplementedError(); }

Expand Down

0 comments on commit 43f75e7

Please sign in to comment.