Skip to content

Commit

Permalink
[xla:cpu] Implement JitCompiler on top of LLVM ORC stack
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700848281
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Nov 28, 2024
1 parent c19ca57 commit 035ff9b
Show file tree
Hide file tree
Showing 7 changed files with 360 additions and 58 deletions.
34 changes: 32 additions & 2 deletions third_party/xla/xla/backends/cpu/codegen/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
load("//xla:xla.bzl", "xla_cc_test")
load(
"//xla/tsl/platform:build_config_root.bzl",
"if_llvm_aarch64_available",
"if_llvm_powerpc_available",
"if_llvm_system_z_available",
"if_llvm_x86_available",
)
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

package(
Expand Down Expand Up @@ -46,11 +53,13 @@ cc_library(

cc_library(
name = "function_library",
srcs = ["function_library.cc"],
hdrs = ["function_library.h"],
deps = [
"//xla:util",
"//xla/backends/cpu/runtime:kernel_c_api",
"//xla/tsl/lib/gtl:int_type",
"@com_google_absl//absl/status:statusor",
"@local_tsl//tsl/platform:statusor",
],
)

Expand Down Expand Up @@ -85,26 +94,47 @@ cc_library(
srcs = ["jit_compiler.cc"],
hdrs = ["jit_compiler.h"],
deps = [
":contiguous_section_memory_manager",
":cpu_features",
":function_library",
":ir_compiler",
"//xla:util",
"//xla/service/cpu:orc_jit_memory_mapper",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:ExecutionEngine",
"@llvm-project//llvm:MC",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:OrcShared",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:TargetParser",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:platform_port",
],
"@local_tsl//tsl/platform:statusor",
] + if_llvm_aarch64_available([
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
]) + if_llvm_powerpc_available([
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
]) + if_llvm_system_z_available([
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
]) + if_llvm_x86_available([
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
]),
)

xla_cc_test(
name = "jit_compiler_test",
srcs = ["jit_compiler_test.cc"],
deps = [
":jit_compiler",
"//xla/tsl/lib/core:status_test_util",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:OrcJIT",
Expand Down
28 changes: 28 additions & 0 deletions third_party/xla/xla/backends/cpu/codegen/function_library.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/codegen/function_library.h"

#include <atomic>
#include <cstdint>

namespace xla::cpu {

FunctionLibrary::TypeId FunctionLibrary::GetNextTypeId() {
static auto* counter = new std::atomic<int64_t>(1);
return TypeId(counter->fetch_add(1));
}

} // namespace xla::cpu
44 changes: 37 additions & 7 deletions third_party/xla/xla/backends/cpu/codegen/function_library.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,55 @@ limitations under the License.
#ifndef XLA_BACKENDS_CPU_CODEGEN_FUNCTION_LIBRARY_H_
#define XLA_BACKENDS_CPU_CODEGEN_FUNCTION_LIBRARY_H_

#include <cstdint>
#include <string_view>
#include <type_traits>

#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/kernel_c_api.h"
#include "xla/util.h"
#include "xla/tsl/lib/gtl/int_type.h"
#include "tsl/platform/statusor.h"

namespace xla::cpu {

// A library of compiled functions required by the XLA:CPU runtime to execute
// an XLA program.
// A library of functions required by the XLA:CPU runtime to execute an XLA
// program.
//
// XLA:CPU program compiles to a collection of functions that are dispatched by
// the runtime. The most common type of compiled function is an XLA CPU Kernel,
// however some operations can be compiled to auxiliary functions that are
// invoked by operation-specific Thunks, e.g. `sort` operation comparator
// compiles to a separate function used by a SortThunk in combination with an
// `std::sort` library call.
class FunctionLibrary {
public:
// We use a `TypeId` to distinguish functions of different type at run time.
TSL_LIB_GTL_DEFINE_INT_TYPE(TypeId, int64_t);
static constexpr TypeId kUnknownTypeId = TypeId(0);

virtual ~FunctionLibrary() = default;

using Kernel = XLA_CPU_Kernel*;
template <typename F, std::enable_if_t<std::is_function_v<F>>* = nullptr>
absl::StatusOr<F*> ResolveFunction(std::string_view name) {
TF_ASSIGN_OR_RETURN(void* ptr, ResolveFunction(GetTypeId<F>(), name));
return reinterpret_cast<F*>(ptr);
}

protected:
// Returns a type-erased pointer to the function with the given name and type
// id. Implementation might choose not to verify the type id and then it is up
// to the caller to ensure the resolved function is of the correct type.
virtual absl::StatusOr<void*> ResolveFunction(TypeId type_id,
std::string_view name) = 0;

virtual absl::StatusOr<Kernel> FindKernel(std::string_view name) const {
return Unimplemented("Kernel %s not found", name);
private:
// Returns a type id for a given function type.
template <typename F, std::enable_if_t<std::is_function_v<F>>* = nullptr>
static TypeId GetTypeId() {
static const TypeId id = GetNextTypeId();
return id;
}

static TypeId GetNextTypeId();
};

} // namespace xla::cpu
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/backends/cpu/codegen/ir_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> IrCompiler::operator()(
llvm::raw_svector_ostream ostream(mc_stream_buffer);

VLOG(2) << "IR after optimizations";
XLA_VLOG_LINES(2, llvm_ir::DumpToString(&module));

{ // Synchronize access to user-defined hooks.
absl::MutexLock lock(&mutex_);
Expand Down
Loading

0 comments on commit 035ff9b

Please sign in to comment.