diff --git a/third_party/xla/xla/backends/cpu/codegen/BUILD b/third_party/xla/xla/backends/cpu/codegen/BUILD index 34b498d5a0fdac..d6f4a0c3f504d2 100644 --- a/third_party/xla/xla/backends/cpu/codegen/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/BUILD @@ -1,4 +1,4 @@ -load("//xla:xla.bzl", "xla_cc_test") +load("//xla:xla.bzl", "xla_cc_test", "xla_internal") load( "//xla/tsl/platform:build_config_root.bzl", "if_llvm_aarch64_available", @@ -131,7 +131,7 @@ cc_library( "@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep ]) + if_llvm_x86_available([ "@llvm-project//llvm:X86CodeGen", # fixdeps: keep - ]), + ]) + xla_internal(["service/cpu:named_orc_jit_memory_mapper"]), ) xla_cc_test( diff --git a/third_party/xla/xla/backends/cpu/codegen/jit_compiler.cc b/third_party/xla/xla/backends/cpu/codegen/jit_compiler.cc index 057324092f4efc..3ee60bdbff2c1a 100644 --- a/third_party/xla/xla/backends/cpu/codegen/jit_compiler.cc +++ b/third_party/xla/xla/backends/cpu/codegen/jit_compiler.cc @@ -121,11 +121,15 @@ absl::StatusOr JitCompiler::Create( options.max_cpu_feature); TF_ASSIGN_OR_RETURN(auto target_machine, target_machine_builder()); + // Dispatch compilation tasks using the provided task runner. + auto task_dispatcher = + std::make_unique(std::move(task_runner)); + TaskDispatcher* task_dispatcher_ptr = task_dispatcher.get(); + // LLVM execution session that holds jit-compiled functions. auto execution_session = std::make_unique( std::make_unique( - /*SSP=*/nullptr, - std::make_unique(std::move(task_runner)))); + /*SSP=*/nullptr, std::move(task_dispatcher))); execution_session->setErrorReporter([](llvm::Error err) { LOG(ERROR) << "LLVM compilation error: " << llvm::toString(std::move(err)); @@ -136,10 +140,10 @@ absl::StatusOr JitCompiler::Create( target_machine_builder, std::move(options.ir_compiler_options), std::move(options.ir_compiler_hooks)); - return JitCompiler(std::move(target_machine_builder), - std::move(target_machine), std::move(execution_session), - std::move(ir_compiler), options.num_dylibs, - std::move(options.definition_generator)); + return JitCompiler( + std::move(target_machine_builder), std::move(target_machine), + task_dispatcher_ptr, std::move(execution_session), std::move(ir_compiler), + options.num_dylibs, std::move(options.definition_generator)); } static std::unique_ptr @@ -162,11 +166,13 @@ static std::unique_ptr CreateCompileLayer( JitCompiler::JitCompiler( IrCompiler::TargetMachineBuilder target_machine_builder, std::shared_ptr target_machine, + TaskDispatcher* task_dispatcher, std::unique_ptr execution_session, std::unique_ptr ir_compiler, size_t num_dylibs, DefinitionGenerator definition_generator) : target_machine_builder_(std::move(target_machine_builder)), target_machine_(std::move(target_machine)), + task_dispatcher_(task_dispatcher), execution_session_(std::move(execution_session)), object_layer_(CreateObjectLinkingLayer(*execution_session_)), compile_layer_(CreateCompileLayer(*execution_session_, *object_layer_, @@ -267,6 +273,10 @@ absl::StatusOr> JitCompiler::Compile( // Look up all requested symbols in the execution session. auto symbol_map = execution_session_->lookup(std::move(search_order), std::move(lookup_set)); + + // Wait for all compilation tasks to finish. + task_dispatcher_->shutdown(); + if (auto err = symbol_map.takeError()) { return Internal("%s", llvm::toString(std::move(err))); } @@ -342,11 +352,6 @@ JitCompiler::CompiledFunctionLibrary::~CompiledFunctionLibrary() { if (auto err = execution_session_->endSession()) { execution_session_->reportError(std::move(err)); } - // Explicitly destroy the execution session to ensure that all tasks are - // finished, because otherwise object layer materialization running inside the - // task dispatched triggers use-after-free errors. This is super fishy, and we - // don't really understand why this is happening. - execution_session_.reset(); } absl::StatusOr JitCompiler::CompiledFunctionLibrary::ResolveFunction( diff --git a/third_party/xla/xla/backends/cpu/codegen/jit_compiler.h b/third_party/xla/xla/backends/cpu/codegen/jit_compiler.h index 54f0e0cdf469a4..b66f7dab7b762d 100644 --- a/third_party/xla/xla/backends/cpu/codegen/jit_compiler.h +++ b/third_party/xla/xla/backends/cpu/codegen/jit_compiler.h @@ -145,12 +145,6 @@ class JitCompiler { llvm::TargetMachine* target_machine() { return target_machine_.get(); } private: - JitCompiler(IrCompiler::TargetMachineBuilder target_machine_builder, - std::shared_ptr target_machine, - std::unique_ptr execution_session, - std::unique_ptr ir_compiler, size_t num_dylibs, - DefinitionGenerator definition_generator); - // LLVM ORC task dispatcher that uses `TaskRunner` to run compilation tasks. class TaskDispatcher : public llvm::orc::TaskDispatcher { public: @@ -192,11 +186,20 @@ class JitCompiler { absl::flat_hash_map symbols_map_; }; + JitCompiler(IrCompiler::TargetMachineBuilder target_machine_builder, + std::shared_ptr target_machine, + TaskDispatcher* task_dispatcher, + std::unique_ptr execution_session, + std::unique_ptr ir_compiler, size_t num_dylibs, + DefinitionGenerator definition_generator); + // Target machine builder that is used to construct target machines for this // instance of `JitCompiler` (when compiling LLVM modules in parallel). IrCompiler::TargetMachineBuilder target_machine_builder_; std::shared_ptr target_machine_; + TaskDispatcher* task_dispatcher_; // owned by `execution_session_` + std::unique_ptr execution_session_; std::unique_ptr object_layer_; std::unique_ptr compile_layer_; diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index b5bfdf8adb5372..68f349cedbc6fa 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -230,6 +230,7 @@ cc_library( ":onednn_contraction_rewriter", ":onednn_ops_rewriter", ":parallel_task_assignment", + ":runtime_symbol_generator", ":simple_orc_jit", ":thunk_emitter", ":xla_framework", @@ -244,6 +245,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/backends/cpu/codegen:cpu_features", + "//xla/backends/cpu/codegen:function_library", "//xla/backends/cpu/codegen:ir_compiler", "//xla/backends/cpu/codegen:jit_compiler", "//xla/backends/cpu/codegen:target_machine_features", @@ -623,6 +625,7 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/codegen:function_library", "//xla/backends/cpu/runtime:buffer_allocations", "//xla/backends/cpu/runtime:thread_pool_task_runner", "//xla/backends/cpu/runtime:thunk", @@ -632,6 +635,7 @@ cc_library( "//xla/service:custom_call_status", "//xla/service:custom_call_status_internal", "//xla/service:executable", + "//xla/service:hlo_execution_profile", "//xla/service:hlo_profile_printer_data_cc", "//xla/service:hlo_value", "//xla/service:maybe_owning_device_memory", @@ -653,6 +657,7 @@ cc_library( "@llvm-project//llvm:OrcShared", "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 2189e82bf3e4a6..3615bca8aafa9b 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -77,6 +78,7 @@ limitations under the License. #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/DialectConversion.h" #include "xla/backends/cpu/codegen/cpu_features.h" +#include "xla/backends/cpu/codegen/function_library.h" #include "xla/backends/cpu/codegen/ir_compiler.h" #include "xla/backends/cpu/codegen/jit_compiler.h" #include "xla/backends/cpu/codegen/target_machine_features.h" @@ -160,7 +162,7 @@ limitations under the License. #include "xla/service/cpu/ir_emitter2.h" #include "xla/service/cpu/metrics.h" #include "xla/service/cpu/parallel_task_assignment.h" -#include "xla/service/cpu/simple_orc_jit.h" +#include "xla/service/cpu/runtime_symbol_generator.h" #include "xla/service/cpu/thunk_emitter.h" #include "xla/service/cpu_gpu_shape_verifier.h" #include "xla/service/dump.h" @@ -254,13 +256,6 @@ static tsl::thread::ThreadPool* GetCompilationThreadPool() { return thread_pool; } -// Returns a global (per-process) async executor for XLA CPU compilation tasks. -static AsyncValue::Executor* GetCompilationAsyncExecutor() { - static auto* executor = - new tsl::thread::ThreadPoolAsyncExecutor(GetCompilationThreadPool()); - return executor; -} - // Returns task runner that uses the global compilation thread pool. static cpu::JitCompiler::TaskRunner GetCompilationTaskRunner() { return [](cpu::JitCompiler::Task task) { @@ -1379,12 +1374,19 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { CreateOrcJITPostCompilationHook(module.get(), &obj_files), }; + // Definition generator to link with XLA:CPU host runtime symbols. + JitCompiler::DefinitionGenerator definition_generator = + [](llvm::TargetMachine* target_machine) { + return std::make_unique( + target_machine->createDataLayout()); + }; + // Options for orchestrating the JIT compilation process. JitCompiler::Options jit_compiler_options{ std::move(ir_compiler_options), std::move(ir_compiler_hooks), /*num_dylibs=*/parallel_codegen_split_count, - /*definition_generator=*/nullptr, + /*definition_generator=*/std::move(definition_generator), /*max_cpu_isa=*/CpuFeatureFromString(debug_options.xla_cpu_max_isa()), }; @@ -1395,22 +1397,6 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { std::move(jit_compiler_options), GetCompilationTaskRunner())); - auto jit = SimpleOrcJIT::Create( - CompilerTargetOptions(module->config()), - CodeGenOptLevel(module->config()), - options::OptimizeForSizeRequested(module->config()), - debug_options.xla_llvm_disable_expensive_passes(), - options::SlpVectorizerDisabled(module->config()), - llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook, - post_optimization_ir_hook, - CreateOrcJITPostCompilationHook(module.get(), &obj_files), - parallel_codegen_split_count, debug_options.xla_cpu_max_isa()); - if (!jit) { - return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError())); - } - llvm_module->setDataLayout((*jit)->data_layout()); - llvm_module->setTargetTriple((*jit)->target_triple().getTriple()); - HloComputation* entry_computation = module->entry_computation(); absl::flat_hash_map instruction_to_profile_idx; @@ -1465,7 +1451,7 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { return cpu_executable; }; - TargetMachineFeatures target_machine_features((*jit)->target_machine()); + TargetMachineFeatures target_machine_features(jit_compiler.target_machine()); // TODO(ezhulenev): Once we fully migrate to Thunks current IrEmitter should // be renamed to NestedIrEmitter and be used only for emitting nested (aka @@ -1534,9 +1520,6 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { // module preserving locals, which should guarantee that all thread local // computations end up in the same module with the corresponding kernel. - // We rely on async executor to run compilation-related tasks in parallel. - auto* async = GetCompilationAsyncExecutor(); - // Collect all compiled symbols grouped by LLVM module part, so that we can // issue compile tasks in parallel without any interference. std::vector compiled_parts; @@ -1569,7 +1552,8 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { // Clone LLVM module part into its own thread safe context. auto tsm = CloneAsThreadSafeModule(n, std::move(llvm_module_part)); - cantFail((*jit)->AddModule(std::move(tsm), /*dylib_index=*/n++)); + TF_CHECK_OK( + jit_compiler.AddModule(std::move(tsm), /*dylib_index=*/n++)); }, /*PreserveLocals=*/true, /*RoundRobin=*/true); @@ -1582,79 +1566,34 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { << parallel_codegen_split_count << ")"; compiled_parts.push_back( CollectCompiledSymbolsPart(ir_emitter2, *llvm_module)); - cantFail((*jit)->AddModule(llvm::orc::ThreadSafeModule( + TF_CHECK_OK(jit_compiler.AddModule(llvm::orc::ThreadSafeModule( std::move(llvm_module), std::move(llvm_context)))); } - auto mangle = [&](std::string_view name) { - llvm::SmallVector mangled; - llvm::Mangler::getNameWithPrefix(mangled, name, (*jit)->data_layout()); - return std::string(mangled.begin(), mangled.end()); - }; - - // Compile all symbols in the given LLVM module part. - auto compile_part = [&](size_t part) -> absl::Status { - CompiledSymbolsPart& symbols = compiled_parts[part]; - - TraceMe trace([&] { - return TraceMeEncode("CpuCompiler::Codegen", - {{"part", part}, - {"num_kernels", symbols.kernels.size()}, - {"num_comparators", symbols.comparators.size()}}); - }); + // Collect compiled symbols from all LLVM module parts. + using Kernel = std::remove_pointer_t; + using Cmp = std::remove_pointer_t; + std::vector compiled_symbols; - for (const auto& kernel : symbols.kernels) { - TraceMe trace( - [&] { return TraceMeEncode("Kernel", {{"name", kernel.name}}); }); - if (auto s = (*jit)->FindCompiledSymbol(mangle(kernel.name)); !s) { - return Internal("Failed to find compiled symbol for kernel %s", - kernel.name); - } + for (const CompiledSymbolsPart& part : compiled_parts) { + for (const IrEmitter2::KernelInfo& kernel : part.kernels) { + compiled_symbols.push_back(FunctionLibrary::Sym(kernel.name)); } - - for (const auto& comparator : symbols.comparators) { - TraceMe trace([&] { - return TraceMeEncode("Comparator", {{"name", comparator.name}}); - }); - if (auto s = (*jit)->FindCompiledSymbol(mangle(comparator.name)); !s) { - return Internal("Failed to find compiled symbol for comparator %s", - comparator.name); - } + for (const IrEmitter2::ComparatorInfo& comparator : part.comparators) { + compiled_symbols.push_back(FunctionLibrary::Sym(comparator.name)); } - - return absl::OkStatus(); - }; - - // Mark kernel and comparator symbols as "kernel symbols" to suppress - // SimpleOrcJIT error logging when symbol is not found in module part. - for (const auto& kernel : ir_emitter2.kernels()) { - (*jit)->AddKernelSymbol(mangle(kernel.name)); - } - for (const auto& comparator : ir_emitter2.comparators()) { - (*jit)->AddKernelSymbol(mangle(comparator.name)); } - // Schedule compilation of LLVM module parts in parallel. - std::vector> compile_tasks(compiled_parts.size()); - for (size_t part = 0; part < compiled_parts.size(); ++part) { - compile_tasks[part] = tsl::TryMakeAsyncValueRef( - *async, [&, part]() -> absl::StatusOr { - TF_RETURN_IF_ERROR(compile_part(part)); - return Chain{}; - }); - } + VLOG(3) << "Collected " << compiled_symbols.size() << " compiled symbols"; - { // Wait for all compilation tasks to finish. - TraceMe trace_codegen([&] { - return TraceMeEncode("Codegen (Wait)", {{"num_parts", num_parts}, - {"num_compiled_functions", - num_compiled_functions}}); - }); - for (auto& task : compile_tasks) { - tsl::BlockUntilReady(task); - if (task.IsError()) return task.GetError(); - } - } + TraceMe trace_codegen([&] { + return TraceMeEncode( + "Codegen", {{"num_parts", num_parts}, + {"num_compiled_functions", num_compiled_functions}}); + }); + + TF_ASSIGN_OR_RETURN(std::unique_ptr function_library, + std::move(jit_compiler).Compile(compiled_symbols)); // Create constant allocations from the buffer assignment. TF_ASSIGN_OR_RETURN( @@ -1663,9 +1602,9 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { TF_ASSIGN_OR_RETURN( auto cpu_executable, - CpuExecutable::Create(std::move(*jit), std::move(assignment), - std::move(module), std::move(thunks), - std::move(constants), + CpuExecutable::Create(std::move(function_library), + std::move(assignment), std::move(module), + std::move(thunks), std::move(constants), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); @@ -1708,14 +1647,6 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { schedule.sequence(entry_computation).instructions(), /*allow_reassociation=*/false)); - std::string function_name = [&]() { - llvm::SmallVector function_name_vector; - llvm::Mangler::getNameWithPrefix( - function_name_vector, entry_function->getName(), (*jit)->data_layout()); - return std::string(function_name_vector.begin(), - function_name_vector.end()); - }(); - std::string ir_module_string; if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpToString(llvm_module.get()); @@ -1723,15 +1654,24 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); + // Save entry function name before destroying LLVM module. + std::string entry_function_name = entry_function->getName().str(); + // JIT compile the LLVM IR module to in-memory machine code. llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module), std::move(llvm_context)); - cantFail((*jit)->AddModule(std::move(thread_safe_module))); + TF_RETURN_IF_ERROR(jit_compiler.AddModule(std::move(thread_safe_module))); + + using ComputeFn = std::remove_pointer_t; + TF_ASSIGN_OR_RETURN( + std::unique_ptr function_library, + std::move(jit_compiler) + .Compile({FunctionLibrary::Sym(entry_function_name)})); TF_ASSIGN_OR_RETURN( auto cpu_executable, - CpuExecutable::Create(std::move(*jit), std::move(assignment), - std::move(module), function_name, + CpuExecutable::Create(std::move(function_library), std::move(assignment), + std::move(module), entry_function_name, std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); @@ -2126,12 +2066,22 @@ CpuExecutableAotCompilationResult::LoadExecutable( /*slp_vectorizer_disabled=*/options::SlpVectorizerDisabled(config), }; + // We don't need any hooks when loading AOT compilation result. + IrCompiler::CompilationHooks ir_compiler_hooks = {}; + + // Definition generator to link with XLA:CPU host runtime symbols. + JitCompiler::DefinitionGenerator definition_generator = + [](llvm::TargetMachine* target_machine) { + return std::make_unique( + target_machine->createDataLayout()); + }; + // Options for orchestrating the JIT compilation process. JitCompiler::Options jit_compiler_options{ std::move(ir_compiler_options), - IrCompiler::CompilationHooks{}, + std::move(ir_compiler_hooks), /*num_dylibs=*/1, - /*definition_generator=*/nullptr, + /*definition_generator=*/std::move(definition_generator), /*max_cpu_isa=*/CpuFeatureFromString(debug_options.xla_cpu_max_isa()), }; @@ -2142,20 +2092,6 @@ CpuExecutableAotCompilationResult::LoadExecutable( std::move(jit_compiler_options), /*task_runner=*/nullptr)); - auto jit = SimpleOrcJIT::Create( - CompilerTargetOptions(module->config()), - CodeGenOptLevel(module->config()), - options::OptimizeForSizeRequested(module->config()), - debug_options.xla_llvm_disable_expensive_passes(), - options::SlpVectorizerDisabled(module->config()), - llvm_ir::GetCpuFastMathFlags(module->config()), - /*pre_optimization_hook=*/nullptr, /*post_optimization_hook=*/nullptr, - /*post_codegen_hook=*/nullptr, /*num_jit_dylibs=*/1, - debug_options.xla_cpu_max_isa()); - if (!jit) { - return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError())); - } - // We might have an XLA:CPU executable that has only runtime thunks and // doesn't have any corresponding object files, and it's absolutely fine. VLOG(2) << "Load XLA:CPU executable from " << proto_.obj_files_size() @@ -2165,7 +2101,7 @@ CpuExecutableAotCompilationResult::LoadExecutable( size_t obj_file_index = 0; for (auto& obj_file : proto_.obj_files()) { llvm::StringRef data(obj_file.data(), obj_file.size()); - cantFail((*jit)->AddObjFile(llvm::MemoryBuffer::getMemBuffer( + TF_RETURN_IF_ERROR(jit_compiler.AddObjFile(llvm::MemoryBuffer::getMemBuffer( data, absl::StrCat(proto_.entry_function_name(), "_", obj_file_index++)))); } @@ -2186,7 +2122,8 @@ CpuExecutableAotCompilationResult::LoadExecutable( auto llvm_module = std::make_unique(kXlaModuleIdentifier, *llvm_context); - TargetMachineFeatures target_machine_features((*jit)->target_machine()); + TargetMachineFeatures target_machine_features( + jit_compiler.target_machine()); IrEmitter nested_ir_emitter( nullptr, *module, *buffer_assignment, llvm_module.get(), {}, {}, @@ -2200,35 +2137,21 @@ CpuExecutableAotCompilationResult::LoadExecutable( TF_ASSIGN_OR_RETURN(ThunkSequence thunks, thunk_emitter.EmitEntryComputation(*module)); - auto mangle = [&](std::string_view name) { - llvm::SmallVector mangled; - llvm::Mangler::getNameWithPrefix(mangled, name, (*jit)->data_layout()); - return std::string(mangled.begin(), mangled.end()); - }; + // Collect compiled symbols from IrEmitter2. + using Kernel = std::remove_pointer_t; + using Cmp = std::remove_pointer_t; + std::vector compiled_symbols; - // Mark kernel and comparator symbols as "kernel symbols" to suppress run - // time error logging when symbol is not found in module part. for (const auto& kernel : ir_emitter2.kernels()) { - (*jit)->AddKernelSymbol(mangle(kernel.name)); + compiled_symbols.push_back(FunctionLibrary::Sym(kernel.name)); } for (const auto& comparator : ir_emitter2.comparators()) { - (*jit)->AddKernelSymbol(mangle(comparator.name)); - } - - // Lookup all kernel functions by name in the loaded object file. - for (const auto& kernel : ir_emitter2.kernels()) { - if (auto s = (*jit)->FindCompiledSymbol(mangle(kernel.name)); !s) { - return Internal("Failed to find compiled symbol for kernel %s", - kernel.name); - } + compiled_symbols.push_back(FunctionLibrary::Sym(comparator.name)); } - for (const auto& comparator : ir_emitter2.comparators()) { - if (auto s = (*jit)->FindCompiledSymbol(mangle(comparator.name)); !s) { - return Internal("Failed to find compiled symbol for comparator %s", - comparator.name); - } - } + VLOG(3) << "Collected " << compiled_symbols.size() << " compiled symbols"; + TF_ASSIGN_OR_RETURN(std::unique_ptr function_library, + std::move(jit_compiler).Compile(compiled_symbols)); // Create constant allocations from the buffer assignment. TF_ASSIGN_OR_RETURN( @@ -2237,17 +2160,24 @@ CpuExecutableAotCompilationResult::LoadExecutable( TF_ASSIGN_OR_RETURN( cpu_executable, - CpuExecutable::Create(std::move(*jit), std::move(buffer_assignment), - std::move(module), std::move(thunks), - std::move(constants), nullptr, nullptr)); + CpuExecutable::Create(std::move(function_library), + std::move(buffer_assignment), std::move(module), + std::move(thunks), std::move(constants), nullptr, + nullptr)); } else if (proto_.obj_files_kind() == CompilationResultProto::CLASSIC) { // Create a "classic" CPU executable. + using ComputeFn = std::remove_pointer_t; + TF_ASSIGN_OR_RETURN(std::unique_ptr function_library, + std::move(jit_compiler) + .Compile({FunctionLibrary::Sym( + proto_.entry_function_name())})); + TF_ASSIGN_OR_RETURN( cpu_executable, - CpuExecutable::Create(std::move(*jit), std::move(buffer_assignment), - std::move(module), proto_.entry_function_name(), - nullptr, nullptr)); + CpuExecutable::Create(std::move(function_library), + std::move(buffer_assignment), std::move(module), + proto_.entry_function_name(), nullptr, nullptr)); } else { return Internal("Unknown obj file kind"); diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index ddc96e141dfa22..9bf0eefe873fdb 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/service/cpu/cpu_executable.h" -#include "xla/service/hlo_profile_printer_data.pb.h" - #define EIGEN_USE_THREADS #include @@ -26,6 +24,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -43,6 +42,7 @@ limitations under the License. #include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" #include "llvm/IR/Mangler.h" #include "llvm/Support/Error.h" +#include "xla/backends/cpu/codegen/function_library.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/thread_pool_task_runner.h" #include "xla/backends/cpu/runtime/thunk.h" @@ -58,6 +58,8 @@ limitations under the License. #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" #include "xla/service/executable.h" +#include "xla/service/hlo_execution_profile.h" +#include "xla/service/hlo_profile_printer_data.pb.h" #include "xla/service/hlo_value.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/service/service_executable_run_options.h" @@ -82,40 +84,21 @@ namespace cpu { using ConstantAllocation = CpuExecutable::ConstantAllocation; using FunctionRegistry = CpuExecutable::FunctionRegistry; -FunctionRegistry::FunctionRegistry(SimpleOrcJIT* jit) : jit_(jit) {} - -std::string FunctionRegistry::Mangle(std::string_view name) { - llvm::SmallVector mangled; - llvm::Mangler::getNameWithPrefix(mangled, name, jit_->data_layout()); - return std::string(mangled.begin(), mangled.end()); -} +FunctionRegistry::FunctionRegistry(FunctionLibrary* function_library) + : function_library_(function_library) {} absl::StatusOr FunctionRegistry::FindKernel( std::string_view name) { VLOG(3) << "Find host kernel with a name " << name; - - llvm::Expected sym = - jit_->FindCompiledSymbol(Mangle(name)); - if (!sym) { - return absl::InvalidArgumentError( - absl::StrCat("Can't resolve host kernel with a name ", name, - " in the jit compiled module.")); - } - return reinterpret_cast(sym->getAddress().getValue()); + using F = std::remove_pointer_t; + return function_library_->ResolveFunction(name); } absl::StatusOr FunctionRegistry::FindComparator( std::string_view name) { VLOG(3) << "Find comparator with a name " << name; - - llvm::Expected sym = - jit_->FindCompiledSymbol(Mangle(name)); - if (!sym) { - return absl::InvalidArgumentError( - absl::StrCat("Can't resolve comparator with a name ", name, - " in the jit compiled module.")); - } - return reinterpret_cast(sym->getAddress().getValue()); + using F = std::remove_pointer_t; + return function_library_->ResolveFunction(name); } se::DeviceMemoryBase ConstantAllocation::AsDeviceMemoryBase() const { @@ -135,7 +118,7 @@ se::DeviceMemoryBase ConstantAllocation::AsDeviceMemoryBase() const { } absl::StatusOr> CpuExecutable::Create( - std::unique_ptr jit, + std::unique_ptr function_library, std::unique_ptr assignment, std::unique_ptr hlo_module, const std::string& entry_function_name, @@ -147,31 +130,23 @@ absl::StatusOr> CpuExecutable::Create( std::unique_ptr executable(new CpuExecutable( std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map), std::move(assignment))); - executable->jit_ = std::move(jit); + executable->function_library_ = std::move(function_library); executable->module_name_ = entry_function_name; - // Resolve symbols in the constructor rather than at execution time to avoid - // races because FindSymbol is not thread safe. - llvm::Expected sym = - executable->jit_->FindCompiledSymbol(entry_function_name); - // We expect to find the symbol provided with entry_function_name; otherwise - // this is an internal error. - if (!sym) { - return absl::NotFoundError( - absl::StrCat("Symbol ", entry_function_name, " not found.")); - } - // getAddress can do work under the hood in the jit, so it needs to be - // guarded by the mutex. - executable->compute_function_ = - reinterpret_cast(sym->getAddress().getValue()); + TF_ASSIGN_OR_RETURN( + executable->compute_function_, + executable->function_library_ + ->ResolveFunction>( + entry_function_name)); + VLOG(1) << "compute_function_ at address " << reinterpret_cast(executable->compute_function_); - executable->jit_->DoneCompiling(); + return executable; } absl::StatusOr> CpuExecutable::Create( - std::unique_ptr jit, + std::unique_ptr function_library, std::unique_ptr assignment, std::unique_ptr hlo_module, ThunkSequence thunks, std::vector constants, @@ -184,9 +159,8 @@ absl::StatusOr> CpuExecutable::Create( std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map), std::move(assignment))); - executable->jit_ = std::move(jit); - executable->jit_->DoneCompiling(); - executable->function_registry_ = FunctionRegistry(executable->jit_.get()); + executable->function_registry_ = FunctionRegistry(function_library.get()); + executable->function_library_ = std::move(function_library); TF_ASSIGN_OR_RETURN(executable->thunks_, ThunkExecutor::Create(std::move(thunks))); @@ -591,7 +565,8 @@ const InstructionValueSet& CpuExecutable::GetRootValueSet() const { } int64_t CpuExecutable::SizeOfGeneratedCodeInBytes() const { - return jit_ ? jit_->SizeOfGeneratedCodeInBytes() : 0; + // TODO(ezhulenev): Delete this function, it's not really used anywhere. + return 0; } } // namespace cpu diff --git a/third_party/xla/xla/service/cpu/cpu_executable.h b/third_party/xla/xla/service/cpu/cpu_executable.h index 592b1af45395b9..fb21eacfda4b1e 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.h +++ b/third_party/xla/xla/service/cpu/cpu_executable.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/codegen/function_library.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/executable_run_options.h" @@ -68,7 +69,7 @@ class CpuExecutable : public Executable { // Creates a CpuExecutable from JIT compiled cpu function by resolving // `entry_function_name` in the `jit`. static absl::StatusOr> Create( - std::unique_ptr jit, + std::unique_ptr function_library, std::unique_ptr assignment, std::unique_ptr hlo_module, const std::string& entry_function_name, @@ -77,7 +78,7 @@ class CpuExecutable : public Executable { // Creates a CpuExecutable from a thunk sequence. static absl::StatusOr> Create( - std::unique_ptr jit, + std::unique_ptr function_library, std::unique_ptr assignment, std::unique_ptr hlo_module, ThunkSequence thunks, std::vector constants, @@ -139,18 +140,16 @@ class CpuExecutable : public Executable { return assignment_->Allocations(); } - // A Thunk::FunctionRegistry implementation that jit-compiles functions on - // demand using the SimpleOrcJIT instance owned by the CpuExecutable. + // A Thunk::FunctionRegistry implementation that looks up functions in the + // FunctionLibrary. class FunctionRegistry : public Thunk::FunctionRegistry { public: - explicit FunctionRegistry(SimpleOrcJIT* jit); + explicit FunctionRegistry(FunctionLibrary* function_library); absl::StatusOr FindKernel(std::string_view name) final; absl::StatusOr FindComparator(std::string_view name) final; private: - std::string Mangle(std::string_view name); - - SimpleOrcJIT* jit_; + FunctionLibrary* function_library_; }; Thunk::FunctionRegistry& function_registry() { return *function_registry_; } @@ -190,11 +189,11 @@ class CpuExecutable : public Executable { // computation. Uses dataflow analysis from buffer assignment. const InstructionValueSet& GetRootValueSet() const; - // The JIT containing compiled modules. - std::unique_ptr jit_; + // The FunctionLibrary containing compiled modules. + std::unique_ptr function_library_; // Object files (machine code) compiled from an HLO module by the JIT - // compiler. We capture all object files created by SimpleOrcJIT so we can + // compiler. We capture all object files created by JitCompiler so we can // export them to AOT compilation result. std::vector obj_files_; diff --git a/third_party/xla/xla/service/cpu/runtime_symbol_generator.cc b/third_party/xla/xla/service/cpu/runtime_symbol_generator.cc index 2d05717e625583..f0a0742a2e2998 100644 --- a/third_party/xla/xla/service/cpu/runtime_symbol_generator.cc +++ b/third_party/xla/xla/service/cpu/runtime_symbol_generator.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include #include +#include #include -#include #include #include "absl/functional/any_invocable.h" @@ -68,23 +68,19 @@ limitations under the License. namespace xla::cpu { -RuntimeSymbolGenerator::RuntimeSymbolGenerator( - llvm::DataLayout data_layout, - absl::AnyInvocable is_kernel_symbol) - : data_layout_(std::move(data_layout)), - is_kernel_symbol_(std::move(is_kernel_symbol)) {} +RuntimeSymbolGenerator::RuntimeSymbolGenerator(llvm::DataLayout data_layout) + : data_layout_(std::move(data_layout)) {} llvm::Error RuntimeSymbolGenerator::tryToGenerate( - llvm::orc::LookupState&, llvm::orc::LookupKind, + llvm::orc::LookupState&, llvm::orc::LookupKind kind, llvm::orc::JITDylib& jit_dylib, llvm::orc::JITDylibLookupFlags, const llvm::orc::SymbolLookupSet& names) { llvm::orc::SymbolMap new_defs; for (const auto& kv : names) { const auto& name = kv.first; - llvm::orc::ExecutorSymbolDef symbol = ResolveRuntimeSymbol(*name); - if (symbol.getAddress()) { - new_defs[name] = symbol; + if (auto symbol = ResolveRuntimeSymbol(*name)) { + new_defs[name] = *symbol; } } @@ -92,8 +88,8 @@ llvm::Error RuntimeSymbolGenerator::tryToGenerate( return llvm::Error::success(); } -llvm::orc::ExecutorSymbolDef RuntimeSymbolGenerator::ResolveRuntimeSymbol( - llvm::StringRef name) { +std::optional +RuntimeSymbolGenerator::ResolveRuntimeSymbol(llvm::StringRef name) { void* fn_addr = nullptr; if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) { // On Mac OS X, 'name' may have a leading underscore prefix, even though the @@ -104,20 +100,9 @@ llvm::orc::ExecutorSymbolDef RuntimeSymbolGenerator::ResolveRuntimeSymbol( fn_addr = CustomCallTargetRegistry::Global()->Lookup(name.str(), "Host"); } - if (fn_addr == nullptr) { - // If symbol corresponds to a kernel function, then it must be defined in - // another LLVM module part (another dylib). - if (is_kernel_symbol_ && !is_kernel_symbol_(name.str())) { - LOG(ERROR) - << "Unable to resolve runtime symbol: `" << name.str() - << "'. Hint: if the symbol a custom call target, make sure you've " - "registered it with the JIT using " - "XLA_CPU_REGISTER_CUSTOM_CALL_TARGET."; - } - return {}; - } - return {llvm::orc::ExecutorAddr(reinterpret_cast(fn_addr)), - llvm::JITSymbolFlags::None}; + return llvm::orc::ExecutorSymbolDef{ + llvm::orc::ExecutorAddr(reinterpret_cast(fn_addr)), + llvm::JITSymbolFlags::None}; } //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/cpu/runtime_symbol_generator.h b/third_party/xla/xla/service/cpu/runtime_symbol_generator.h index c7fc14079798a4..4173717a91163e 100644 --- a/third_party/xla/xla/service/cpu/runtime_symbol_generator.h +++ b/third_party/xla/xla/service/cpu/runtime_symbol_generator.h @@ -16,9 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_RUNTIME_SYMBOL_GENERATOR_H_ #define XLA_SERVICE_CPU_RUNTIME_SYMBOL_GENERATOR_H_ -#include +#include -#include "absl/functional/any_invocable.h" #include "llvm/ADT/StringRef.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" @@ -31,9 +30,7 @@ namespace xla::cpu { // the compiled XLA kernels. class RuntimeSymbolGenerator : public llvm::orc::DefinitionGenerator { public: - RuntimeSymbolGenerator( - llvm::DataLayout data_layout, - absl::AnyInvocable is_kernel_symbol); + explicit RuntimeSymbolGenerator(llvm::DataLayout data_layout); llvm::Error tryToGenerate(llvm::orc::LookupState&, llvm::orc::LookupKind, llvm::orc::JITDylib& jit_dylib, @@ -41,10 +38,10 @@ class RuntimeSymbolGenerator : public llvm::orc::DefinitionGenerator { const llvm::orc::SymbolLookupSet& names) final; private: - llvm::orc::ExecutorSymbolDef ResolveRuntimeSymbol(llvm::StringRef name); + std::optional ResolveRuntimeSymbol( + llvm::StringRef name); llvm::DataLayout data_layout_; - absl::AnyInvocable is_kernel_symbol_; }; } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index bdb8a4fec62832..99db17d5aba3fe 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -108,11 +108,8 @@ SimpleOrcJIT::SimpleOrcJIT( for (size_t i = 0; i < num_jit_dylibs; ++i) { jit_dylibs_[i] = &execution_session_->createBareJITDylib( absl::StrCat("")); - jit_dylibs_[i]->addGenerator(std::make_unique( - data_layout_, - /*is_kernel_symbol=*/[&](std::string_view name) { - return kernel_symbols_.contains(name); - })); + jit_dylibs_[i]->addGenerator( + std::make_unique(data_layout_)); } object_layer_.registerJITEventListener(*this); diff --git a/third_party/xla/xla/tests/local_client_execute_test.cc b/third_party/xla/xla/tests/local_client_execute_test.cc index 829bfc31fb3449..22c469aa992863 100644 --- a/third_party/xla/xla/tests/local_client_execute_test.cc +++ b/third_party/xla/xla/tests/local_client_execute_test.cc @@ -760,8 +760,8 @@ XLA_TEST_F(LocalClientExecuteTest, CompilePartitionedExecutable) { EXPECT_EQ(2, executables.size()); } -XLA_TEST_F(LocalClientExecuteTest, - DISABLED_ON_INTERPRETER(SizeOfGeneratedCodeInBytes)) { +XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU(DISABLED_ON_INTERPRETER( + SizeOfGeneratedCodeInBytes))) { if (IsMlirLoweringEnabled()) { // SizeOfGeneratedCodeInBytes is not supported by the MLIR pipeline. GTEST_SKIP();