forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add PTX CompilationProvider for compiling PTX via the CUDA driver
- Adds a new compilation provider for driver compilation. - Moves compilation log output parser logic into ptx_compiler_helpers and use the shared logic in both driver and nvjitlink compilation. - Adds test for the log output parser logic. - Splits the compilation_provider_test into 2 test targets - one which requires a GPU and one which doesn't PiperOrigin-RevId: 700226594
- Loading branch information
1 parent
b67efc1
commit cbf3dae
Showing
11 changed files
with
647 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
third_party/xla/xla/stream_executor/cuda/compilation_provider_test.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef XLA_STREAM_EXECUTOR_CUDA_COMPILATION_PROVIDER_TEST_H_ | ||
#define XLA_STREAM_EXECUTOR_CUDA_COMPILATION_PROVIDER_TEST_H_ | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <string_view> | ||
|
||
#include <gtest/gtest.h> | ||
#include "absl/status/statusor.h" | ||
#include "xla/stream_executor/cuda/compilation_provider.h" | ||
|
||
namespace stream_executor::cuda { | ||
|
||
inline constexpr std::string_view kSubprocessCompilationProviderName = | ||
"subprocess"; | ||
inline constexpr std::string_view kNvJitLinkCompilationProviderName = | ||
"nvjitlink"; | ||
inline constexpr std::string_view kNvptxcompilerCompilationProviderName = | ||
"nvptxcompiler"; | ||
inline constexpr std::string_view kDriverCompilationProviderName = "driver"; | ||
|
||
class CompilationProviderTest | ||
: public testing::TestWithParam<std::string_view> { | ||
absl::StatusOr<std::unique_ptr<CompilationProvider>> | ||
CreateCompilationProvider(std::string_view name); | ||
|
||
void SetUp() override; | ||
std::unique_ptr<CompilationProvider> compilation_provider_; | ||
|
||
protected: | ||
CompilationProvider* compilation_provider() { | ||
return compilation_provider_.get(); | ||
} | ||
}; | ||
|
||
// Prints the test parameter name as is. Needed for gtest instantiation. | ||
struct CompilationProviderTestParamNamePrinter { | ||
std::string operator()( | ||
const ::testing::TestParamInfo<std::string_view>& name) const { | ||
return std::string(name.param); | ||
} | ||
}; | ||
|
||
} // namespace stream_executor::cuda | ||
|
||
#endif // XLA_STREAM_EXECUTOR_CUDA_COMPILATION_PROVIDER_TEST_H_ |
29 changes: 29 additions & 0 deletions
29
third_party/xla/xla/stream_executor/cuda/compilation_provider_test_with_gpu.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include <gtest/gtest.h> | ||
#include "xla/stream_executor/cuda/compilation_provider_test.h" | ||
|
||
namespace stream_executor::cuda { | ||
namespace { | ||
|
||
// The CUDA driver needs a GPU to be present, otherwise it will fail to | ||
// initialize. And without initialization no compilation. | ||
INSTANTIATE_TEST_SUITE_P(CompilationProviderTest, CompilationProviderTest, | ||
testing::Values(kDriverCompilationProviderName), | ||
CompilationProviderTestParamNamePrinter()); | ||
|
||
} // namespace | ||
} // namespace stream_executor::cuda |
Oops, something went wrong.