Skip to content

Commit

Permalink
Fixes issues related to kernels calling other kernels (#1246)
Browse files Browse the repository at this point in the history
Avoids entry points being the target of function calls by introducing a second function for any called kernel.
  • Loading branch information
Rekt3421 authored Oct 5, 2023
1 parent cd48230 commit 7412943
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 1 deletion.
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ add_library(clspv_passes OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/LongVectorLoweringPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/SetImageChannelMetadataPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ThreeElementVectorLoweringPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/WrapKernelPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/LowerAddrSpaceCastPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/LowerPrivatePointerPHIPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/MultiVersionUBOFunctionsPass.cpp
Expand Down
3 changes: 2 additions & 1 deletion lib/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,10 +508,11 @@ int RunPassPipeline(llvm::Module &M, llvm::raw_svector_ostream *binaryStream) {
// Run the following optimizations prior to the standard LLVM pass pipeline.
pb.registerPipelineStartEPCallback([](llvm::ModulePassManager &pm,
llvm::OptimizationLevel level) {
pm.addPass(clspv::AnnotationToMetadataPass());
pm.addPass(clspv::WrapKernelPass());
pm.addPass(clspv::NativeMathPass());
pm.addPass(clspv::ZeroInitializeAllocasPass());
pm.addPass(clspv::KernelArgNamesToMetadataPass());
pm.addPass(clspv::AnnotationToMetadataPass());
pm.addPass(clspv::AddFunctionAttributesPass());
pm.addPass(clspv::AutoPodArgsPass());
pm.addPass(clspv::DeclarePushConstantsPass());
Expand Down
1 change: 1 addition & 0 deletions lib/PassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ MODULE_PASS("set-image-channel-metadata", clspv::SetImageChannelMetadataPass)
MODULE_PASS("lower-addrspacecast", clspv::LowerAddrSpaceCastPass)
MODULE_PASS("lower-private-pointer-phi", clspv::LowerPrivatePointerPHIPass)
MODULE_PASS("multi-version-ubo-functions", clspv::MultiVersionUBOFunctionsPass)
MODULE_PASS("wrap-kernel", clspv::WrapKernelPass)
MODULE_PASS("native-math", clspv::NativeMathPass)
MODULE_PASS("opencl-inliner", clspv::OpenCLInlinerPass)
MODULE_PASS("physical-pointer-args", clspv::PhysicalPointerArgsPass)
Expand Down
1 change: 1 addition & 0 deletions lib/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#include "UndoSRetPass.h"
#include "UndoTranslateSamplerFoldPass.h"
#include "UndoTruncateToOddIntegerPass.h"
#include "WrapKernelPass.h"
#include "ZeroInitializeAllocasPass.h"

#endif
90 changes: 90 additions & 0 deletions lib/WrapKernelPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include "WrapKernelPass.h"

using namespace llvm;

void clspv::WrapKernelPass::runOnFunction(Module &M, llvm::Function *F) {
SmallVector<Type *, 8> NewParamTypes;
for (auto &Arg : F->args()) {
NewParamTypes.push_back(Arg.getType());
}

auto *NewFuncTy = FunctionType::get(F->getReturnType(), NewParamTypes, false);

auto NewFunc = Function::Create(NewFuncTy, F->getLinkage());
NewFunc->setName(F->getName().str());
F->setName(F->getName().str() + ".inner");
NewFunc->setCallingConv(F->getCallingConv());
NewFunc->copyAttributesFrom(F);

for (auto &arg : F->args()) {
NewFunc->getArg(arg.getArgNo())->setName(arg.getName());
}

F->setCallingConv(CallingConv::SPIR_FUNC);
for (auto &U : F->uses()) {
if (auto CI = dyn_cast<CallInst>(U.getUser())) {
CI->setCallingConv(CallingConv::SPIR_FUNC);
}
}
NewFunc->copyMetadata(F, 0);
F->clearMetadata();

IRBuilder<> Builder(BasicBlock::Create(M.getContext(), "entry", NewFunc));

// Copy args from src func to new func
// Get the arguments of the source function.
SmallVector<Value *, 8> WrappedArgs;
for (unsigned ArgNum = 0; ArgNum < F->arg_size(); ArgNum++) {
auto *NewArg = NewFunc->getArg(ArgNum);
WrappedArgs.push_back(NewArg);
}

auto *CallInst = Builder.CreateCall(F, WrappedArgs);
CallInst->setCallingConv(CallingConv::SPIR_FUNC);
Builder.CreateRetVoid();

// Insert the function after the original, to preserve ordering
// in the module as much as possible.
auto &FunctionList = M.getFunctionList();
for (auto Iter = FunctionList.begin(), IterEnd = FunctionList.end();
Iter != IterEnd; ++Iter) {
if (&*Iter == F) {
FunctionList.insertAfter(Iter, NewFunc);
break;
}
}
}

bool isCalled(llvm::Function &F) {
for (auto &U : F.uses()) {
if (dyn_cast<CallInst>(U.getUser())) {
return true;
}
}
return false;
}

PreservedAnalyses clspv::WrapKernelPass::run(llvm::Module &M,
ModuleAnalysisManager &) {
PreservedAnalyses PA;
SmallVector<Function *, 8> FuncsToWrap;

for (auto &F : M.functions()) {
if (F.getCallingConv() == CallingConv::SPIR_KERNEL && isCalled(F)) {
FuncsToWrap.emplace_back(&F);
}
}
for (unsigned i = 0; i < FuncsToWrap.size(); ++i) {
runOnFunction(M, FuncsToWrap[i]);
}
return PA;
}
30 changes: 30 additions & 0 deletions lib/WrapKernelPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2023 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"

#ifndef _CLSPV_LIB_WRAP_KERNEL_PASS_H
#define _CLSPV_LIB_WRAP_KERNEL_PASS_H

namespace clspv {
struct WrapKernelPass : llvm::PassInfoMixin<WrapKernelPass> {
llvm::PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &);

private:
void runOnFunction(llvm::Module &M, llvm::Function *F);
};
} // namespace clspv

#endif // _CLSPV_LIB_WRAP_KERNEL_PASS_H
36 changes: 36 additions & 0 deletions test/WrapKernel/wrap_kernel.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// ; RUN: clspv %s -o %t.spv
// ; RUN: spirv-dis -o %t2.spvasm %t.spv
// ; RUN: FileCheck %s < %t2.spvasm
// ; RUN: spirv-val %t.spv

__attribute__((work_group_size_hint(1,1,1)))
__attribute__((reqd_work_group_size(1,1,1)))
__kernel void add(global int *A, global int *B) {
// Add 1 to each element of the input buffer.
for (unsigned int i = 0; i < get_global_size(0); i++) {
B[i] = A[i] + 1;
}
}

// kernel1.clspv
kernel void main_kernel(global int* C, global int* D) {
// Call the function.
add(C, D);
}

// CHECK: [[extinst:%[a-zA-A0-9_]+]] = OpExtInstImport "NonSemantic.ClspvReflection.5"
// CHECK-DAG: [[kernel_add_name:%[a-zA-Z0-9_]+]] = OpString "add"
// CHECK: [[attributes:%[^ ]+]] = OpString " __attribute__((work_group_size_hint(1, 1, 1))) __attribute__((reqd_work_group_size(1, 1, 1))) __kernel"
// CHECK-DAG: [[k0arg0:%[a-zA-Z0-9_]+]] = OpString "A"
// CHECK-DAG: [[k0arg1:%[a-zA-Z0-9_]+]] = OpString "B"
// CHECK-DAG: [[kernel_main_name:%[a-zA-Z0-9_]+]] = OpString "main_kernel"
// CHECK-DAG: [[k1arg0:%[a-zA-Z0-9_]+]] = OpString "C"
// CHECK-DAG: [[k1arg1:%[a-zA-Z0-9_]+]] = OpString "D"

// CHECK-DAG: [[kernel_add_def:%[a-zA-A0-9_]+]] = OpExtInst %void [[extinst]] Kernel {{.*}} [[kernel_add_name]] {{.*}} [[attributes]]
// CHECK-NEXT: OpExtInst %void [[extinst]] PropertyRequiredWorkgroupSize [[kernel_add_def]] {{.*}} {{.*}} {{.*}}
// CHECK: OpExtInst %void [[extinst]] ArgumentInfo [[k0arg0]]
// CHECK: OpExtInst %void [[extinst]] ArgumentInfo [[k0arg1]]
// CHECK-DAG: OpExtInst %void [[extinst]] Kernel {{.*}} [[kernel_main_name]] {{.*}} {{.*}} {{.*}}
// CHECK: OpExtInst %void [[extinst]] ArgumentInfo [[k1arg0]]
// CHECK: OpExtInst %void [[extinst]] ArgumentInfo [[k1arg1]]
72 changes: 72 additions & 0 deletions test/WrapKernel/wrap_kernel.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
; RUN: clspv-opt %s -o %t.ll --passes=wrap-kernel
; RUN: FileCheck %s < %t.ll

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir-unknown-unknown"

define dso_local spir_kernel void @add(ptr addrspace(1) align 4 %A, ptr addrspace(1) align 4 %B) #2 !kernel_arg_addr_space !15 !kernel_arg_access_qual !16 !kernel_arg_type !17 !kernel_arg_base_type !17 !kernel_arg_type_qual !18 !work_group_size_hint !19 !reqd_work_group_size !19 {
entry:
%A.addr = alloca ptr addrspace(1), align 4
%B.addr = alloca ptr addrspace(1), align 4
%i = alloca i32, align 4
store ptr addrspace(1) %A, ptr %A.addr, align 4
store ptr addrspace(1) %B, ptr %B.addr, align 4
store i32 0, ptr %i, align 4
ret void
}

define dso_local spir_kernel void @main_kernel(ptr addrspace(1) align 4 %A, ptr addrspace(1) align 4 %B) #2 !kernel_arg_addr_space !15 !kernel_arg_access_qual !16 !kernel_arg_type !17 !kernel_arg_base_type !17 !kernel_arg_type_qual !18 {
entry:
%A.addr = alloca ptr addrspace(1), align 4
%B.addr = alloca ptr addrspace(1), align 4
store ptr addrspace(1) %A, ptr %A.addr, align 4
store ptr addrspace(1) %B, ptr %B.addr, align 4
%0 = load ptr addrspace(1), ptr %A.addr, align 4
%1 = load ptr addrspace(1), ptr %B.addr, align 4
call spir_kernel void @add(ptr addrspace(1) align 4 %0, ptr addrspace(1) align 4 %1) #5
ret void
}


!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 7, !"direct-access-external-data", i32 0}
!2 = !{i32 7, !"frame-pointer", i32 2}
!3 = !{i32 1, i32 2}
!4 = !{!"clang version 18.0.0 (https://github.com/llvm/llvm-project c7d65e4466eafe518937c59ef9a242234ed7a08a)"}
!5 = !{!"clang version 17.0.0 (https://github.com/llvm/llvm-project 1e6fc9626c0f49ce952a67aef47e86253d13f74a)"}
!6 = !{!"clang version 17.0.0 (https://github.com/llvm/llvm-project ab674234c440ed27302f58eeccc612c83b32c43f)"}
!7 = !{!"_Z4sqrtf", !" __attribute__((overloadable)) __attribute__((const))"}
!8 = !{!"_Z4sqrtDv2_f", !" __attribute__((overloadable)) __attribute__((const))"}
!9 = !{!"_Z4sqrtDv3_f", !" __attr((overloadable)) __attribute__((const))"}
!10 = !{!"_Z4sqrtDv4_f", !" __attribute__((overloadable)) __aibute__ttribute__((const))"}
!11 = !{!"_Z4sqrtDv8_f", !" __attribute__((overloadable)) __attribute__((const))"}
!12 = !{!"_Z4sqrtDv16_f", !" __attribute__((overloadable)) __attribute__((const))"}
!13 = !{!"add", !" __attribute__((work_group_size_hint(1, 1, 1))) __attribute__((reqd_work_group_size(1, 1, 1))) __kernel"}
!14 = !{!"main_kernel", !" kernel"}
!15 = !{i32 1, i32 1}
!16 = !{!"none", !"none"}
!17 = !{!"int*", !"int*"}
!18 = !{!"", !""}
!19 = !{i32 1, i32 1, i32 1}

; CHECK:define dso_local spir_func void @add.inner(ptr addrspace(1) align [[alignment:[0-9]*]] %A, ptr addrspace(1) align [[alignment]] %B)
; CHECK-NEXT: entry:
; CHECK: ret void

; CHECK: define dso_local spir_kernel void @add(ptr addrspace(1) align [[alignment]] %A, ptr addrspace(1) align [[alignment]] %B) !kernel_arg_addr_space [[args_type:![0-9]*]] !kernel_arg_access_qual [[none:![0-9]*]] !kernel_arg_type [[ptr:![0-9]*]] !kernel_arg_base_type [[ptr]] !kernel_arg_type_qual [[empty:![0-9]*]] !work_group_size_hint [[work_group_hint:![0-9]*]] !reqd_work_group_size [[work_group_hint]]
; CHECK-NEXT: entry:
; CHECK-NEXT: call spir_func void @add.inner(ptr addrspace(1) %A, ptr addrspace(1) %B)
; CHECK: ret void

; CHECK: define dso_local spir_kernel void @main_kernel(ptr addrspace(1) align [[alignment]] %A, ptr addrspace(1) align [[alignment]] %B) !kernel_arg_addr_space [[args_type]] !kernel_arg_access_qual [[none]] !kernel_arg_type [[ptr]] !kernel_arg_base_type [[ptr]] !kernel_arg_type_qual [[empty]]
; CHECK-NEXT: entry:
; CHECK-DAG: [[paramA:%[a-zA-A0-9_]+]] = load ptr addrspace(1), ptr %A.addr, align [[alignment]]
; CHECK-DAG: [[paramB:%[a-zA-A0-9_]+]] = load ptr addrspace(1), ptr %B.addr, align [[alignment]]
; CHECK: call spir_func void @add.inner(ptr addrspace(1) align [[alignment]] [[paramA]], ptr addrspace(1) align [[alignment]] [[paramB]])
; CHECK ret void

; CHECK: [[args_type]] = !{i32 1, i32 1}
; CHECK: [[none]] = !{!"none", !"none"}
; CHECK: [[ptr]] = !{!"int*", !"int*"}
; CHECK: [[empty]] = !{!"", !""}
; CHECK: [[work_group_hint]] = !{i32 1, i32 1, i32 1}

0 comments on commit 7412943

Please sign in to comment.