Skip to content

Commit

Permalink
Add PTX CompilationProvider for compiling PTX via the CUDA driver
Browse files Browse the repository at this point in the history
- Adds a new compilation provider for driver compilation.
- Moves compilation log output parser logic into ptx_compiler_helpers and use the shared logic in both driver and nvjitlink compilation.
- Adds test for the log output parser logic.
- Splits the compilation_provider_test into 2 test targets - one which requires a GPU and one which doesn't

PiperOrigin-RevId: 700226594
  • Loading branch information
beckerhe authored and tensorflower-gardener committed Nov 26, 2024
1 parent b67efc1 commit cbf3dae
Show file tree
Hide file tree
Showing 11 changed files with 647 additions and 80 deletions.
119 changes: 110 additions & 9 deletions third_party/xla/xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")
load(
"@local_tsl//tsl/platform:build_config_root.bzl",
"if_static",
"tf_cuda_tests_tags",
)
load(
"@local_tsl//tsl/platform:rules_cc.bzl",
Expand Down Expand Up @@ -659,7 +660,22 @@ cc_library(
srcs = ["ptx_compiler_helpers.cc"],
hdrs = ["ptx_compiler_helpers.h"],
deps = [
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
)

cc_test(
name = "ptx_compiler_helpers_test",
srcs = ["ptx_compiler_helpers_test.cc"],
deps = [
":ptx_compiler_helpers",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:test",
],
)

Expand Down Expand Up @@ -1489,26 +1505,32 @@ cc_library(
],
)

xla_cc_test(
name = "compilation_provider_test",
# compilation_provider_test is split into two targets since only a subset of the tests need a GPU to run.
cc_library(
name = "compilation_provider_test_lib",
testonly = True,
srcs = ["compilation_provider_test.cc"],
args = if_google([
# nvjitlink allocates memory and only keeps a pointer past the usual offset of 1024 bytes;
# so we need to increase the max pointer offset. -1 means no limit.
# This is only relevant for Google's HeapLeakChecker. The newer Leak sanitizer doesn't
# have this issue.
"--heap_check_max_pointer_offset=-1",
]),
hdrs = ["compilation_provider_test.h"],
tags = [
"cuda-only",
"gpu",
],
deps = [
":compilation_options",
":compilation_provider",
":cuda_platform", # buildcleaner: keep
":cuda_platform_id",
":driver_compilation_provider",
":nvjitlink_compilation_provider",
":nvjitlink_support",
":nvptxcompiler_compilation_provider",
":ptx_compiler_support",
":subprocess_compilation",
":subprocess_compilation_provider",
"//xla/stream_executor:device_description",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand All @@ -1520,6 +1542,51 @@ xla_cc_test(
],
)

xla_cc_test(
name = "compilation_provider_test_without_gpu",
srcs = ["compilation_provider_test_without_gpu.cc"],
args = if_google([
# nvjitlink allocates memory and only keeps a pointer past the usual offset of 1024 bytes;
# so we need to increase the max pointer offset. -1 means no limit.
# This is only relevant for Google's HeapLeakChecker. The newer Leak sanitizer doesn't
# have this issue.
"--heap_check_max_pointer_offset=-1",
]),
tags = [
"cuda-only",
"gpu",
],
deps = [
":compilation_provider_test_lib",
"@com_google_googletest//:gtest_main",
],
)

xla_cc_test(
name = "compilation_provider_test_with_gpu",
srcs = ["compilation_provider_test_with_gpu.cc"],
tags = [
"cuda-only",
"gpu",
] + tf_cuda_tests_tags(),
deps = [
":compilation_provider_test_lib",
"@com_google_googletest//:gtest_main",
],
)

test_suite(
name = "compilation_provider_test",
tags = [
"cuda-only",
"gpu",
],
tests = [
":compilation_provider_test_with_gpu",
":compilation_provider_test_without_gpu",
],
)

cc_library(
name = "nvjitlink_compilation_provider",
srcs = ["nvjitlink_compilation_provider.cc"],
Expand Down Expand Up @@ -1621,3 +1688,37 @@ xla_cc_test(
"@local_tsl//tsl/platform:test",
],
)

cc_library(
name = "driver_compilation_provider",
srcs = ["driver_compilation_provider.cc"],
hdrs = ["driver_compilation_provider.h"],
tags = [
"cuda-only",
"gpu",
],
deps = [
":compilation_options",
":compilation_provider",
":cuda_platform_id",
":cuda_status",
":ptx_compiler_helpers",
"//xla/stream_executor:activate_context",
"//xla/stream_executor:device_description",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream_executor_h",
"//xla/tsl/cuda", # buildcleaner: keep
"@com_google_absl//absl/base",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@local_config_cuda//cuda:cuda_headers",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "xla/stream_executor/cuda/compilation_options.h"
#include "xla/stream_executor/cuda/compilation_provider_test.h"
#include "xla/stream_executor/cuda/driver_compilation_provider.h"
#include "xla/stream_executor/cuda/nvjitlink_compilation_provider.h"
#include "xla/stream_executor/cuda/nvjitlink_support.h"
#include "xla/stream_executor/cuda/nvptxcompiler_compilation_provider.h"
Expand All @@ -41,7 +43,6 @@ limitations under the License.
#include "tsl/platform/threadpool.h"

namespace stream_executor::cuda {
namespace {
using ::testing::_;
using ::testing::AnyOf;
using ::testing::HasSubstr;
Expand All @@ -51,48 +52,30 @@ using ::tsl::testing::IsOk;
using ::tsl::testing::IsOkAndHolds;
using ::tsl::testing::StatusIs;

constexpr std::string_view kSubprocessCompilationProviderName = "subprocess";
constexpr std::string_view kNvJitLinkCompilationProviderName = "nvjitlink";
constexpr std::string_view kNvptxcompilerCompilationProviderName =
"nvptxcompiler";

class CompilationProviderTest
: public testing::TestWithParam<std::string_view> {
absl::StatusOr<std::unique_ptr<CompilationProvider>>
CreateCompilationProvider(std::string_view name);

void SetUp() override {
void CompilationProviderTest::SetUp() {
#ifdef ABSL_HAVE_MEMORY_SANITIZER
if (GetParam() == kNvJitLinkCompilationProviderName) {
GTEST_SKIP() << "nvjitlink is a precompiled and not instrumented binary "
"library, so it's not compatible with MSAN.";
}
if (GetParam() == kNvptxcompilerCompilationProviderName) {
GTEST_SKIP() << "nvptxcompiler is a precompiled and not instrumented "
"binary library, so it's not compatible with MSAN.";
}
if (GetParam() == kNvJitLinkCompilationProviderName) {
GTEST_SKIP() << "nvjitlink is a precompiled and not instrumented binary "
"library, so it's not compatible with MSAN.";
}
if (GetParam() == kNvptxcompilerCompilationProviderName) {
GTEST_SKIP() << "nvptxcompiler is a precompiled and not instrumented "
"binary library, so it's not compatible with MSAN.";
}
#endif

if (GetParam() == kNvJitLinkCompilationProviderName &&
!IsLibNvJitLinkSupported()) {
GTEST_SKIP() << "nvjitlink is not supported in this build.";
}
if (GetParam() == kNvptxcompilerCompilationProviderName &&
!IsLibNvPtxCompilerSupported()) {
GTEST_SKIP() << "nvptxcompiler is not supported in this build.";
}

TF_ASSERT_OK_AND_ASSIGN(compilation_provider_,
CreateCompilationProvider(GetParam()));
if (GetParam() == kNvJitLinkCompilationProviderName &&
!IsLibNvJitLinkSupported()) {
GTEST_SKIP() << "nvjitlink is not supported in this build.";
}

std::unique_ptr<CompilationProvider> compilation_provider_;

protected:
CompilationProvider* compilation_provider() {
return compilation_provider_.get();
if (GetParam() == kNvptxcompilerCompilationProviderName &&
!IsLibNvPtxCompilerSupported()) {
GTEST_SKIP() << "nvptxcompiler is not supported in this build.";
}
};

TF_ASSERT_OK_AND_ASSIGN(compilation_provider_,
CreateCompilationProvider(GetParam()));
}

absl::StatusOr<std::unique_ptr<CompilationProvider>>
CompilationProviderTest::CreateCompilationProvider(std::string_view name) {
Expand All @@ -112,6 +95,10 @@ CompilationProviderTest::CreateCompilationProvider(std::string_view name) {
return std::make_unique<NvptxcompilerCompilationProvider>();
}

if (name == kDriverCompilationProviderName) {
return std::make_unique<DriverCompilationProvider>();
}

return absl::NotFoundError(
absl::StrCat("Unknown compilation provider: ", name));
}
Expand Down Expand Up @@ -351,6 +338,9 @@ TEST_P(CompilationProviderTest, CancelsOnRegSpill) {
if (!compilation_provider()->SupportsCompileAndLink()) {
GTEST_SKIP() << "Compilation provider doesn't support CompileAndLink";
}
if (GetParam() == kDriverCompilationProviderName) {
GTEST_SKIP() << "Driver compilation doesn't support cancel_if_reg_spill";
}

std::string dependent_ptx = absl::StrReplaceAll(
kDependentPtx, {{"// Insert .maxnreg directive here!", ".maxnreg 16"}});
Expand Down Expand Up @@ -487,11 +477,6 @@ TEST_P(CompilationProviderTest, ParallelCompileAndLinkReturnsSameResult) {
}
}

INSTANTIATE_TEST_SUITE_P(
CompilationProviderTest, CompilationProviderTest,
testing::Values(kSubprocessCompilationProviderName,
kNvJitLinkCompilationProviderName,
kNvptxcompilerCompilationProviderName));
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationProviderTest);

} // namespace
} // namespace stream_executor::cuda
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_STREAM_EXECUTOR_CUDA_COMPILATION_PROVIDER_TEST_H_
#define XLA_STREAM_EXECUTOR_CUDA_COMPILATION_PROVIDER_TEST_H_

#include <memory>
#include <string>
#include <string_view>

#include <gtest/gtest.h>
#include "absl/status/statusor.h"
#include "xla/stream_executor/cuda/compilation_provider.h"

namespace stream_executor::cuda {

inline constexpr std::string_view kSubprocessCompilationProviderName =
"subprocess";
inline constexpr std::string_view kNvJitLinkCompilationProviderName =
"nvjitlink";
inline constexpr std::string_view kNvptxcompilerCompilationProviderName =
"nvptxcompiler";
inline constexpr std::string_view kDriverCompilationProviderName = "driver";

class CompilationProviderTest
: public testing::TestWithParam<std::string_view> {
absl::StatusOr<std::unique_ptr<CompilationProvider>>
CreateCompilationProvider(std::string_view name);

void SetUp() override;
std::unique_ptr<CompilationProvider> compilation_provider_;

protected:
CompilationProvider* compilation_provider() {
return compilation_provider_.get();
}
};

// Prints the test parameter name as is. Needed for gtest instantiation.
struct CompilationProviderTestParamNamePrinter {
std::string operator()(
const ::testing::TestParamInfo<std::string_view>& name) const {
return std::string(name.param);
}
};

} // namespace stream_executor::cuda

#endif // XLA_STREAM_EXECUTOR_CUDA_COMPILATION_PROVIDER_TEST_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <gtest/gtest.h>
#include "xla/stream_executor/cuda/compilation_provider_test.h"

namespace stream_executor::cuda {
namespace {

// The CUDA driver needs a GPU to be present, otherwise it will fail to
// initialize. And without initialization no compilation.
INSTANTIATE_TEST_SUITE_P(CompilationProviderTest, CompilationProviderTest,
testing::Values(kDriverCompilationProviderName),
CompilationProviderTestParamNamePrinter());

} // namespace
} // namespace stream_executor::cuda
Loading

0 comments on commit cbf3dae

Please sign in to comment.