-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes issues related to kernels calling other kernels (#1246)
Avoids entry points being the target of function calls by introducing a second function for any called kernel.
- Loading branch information
Showing
8 changed files
with
233 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
#include "llvm/IR/Constants.h" | ||
#include "llvm/IR/IRBuilder.h" | ||
#include "llvm/IR/InstIterator.h" | ||
#include "llvm/IR/InstVisitor.h" | ||
#include "llvm/IR/Instructions.h" | ||
#include "llvm/IR/Intrinsics.h" | ||
#include "llvm/Pass.h" | ||
#include "llvm/Transforms/Utils/Cloning.h" | ||
|
||
#include "WrapKernelPass.h" | ||
|
||
using namespace llvm; | ||
|
||
void clspv::WrapKernelPass::runOnFunction(Module &M, llvm::Function *F) { | ||
SmallVector<Type *, 8> NewParamTypes; | ||
for (auto &Arg : F->args()) { | ||
NewParamTypes.push_back(Arg.getType()); | ||
} | ||
|
||
auto *NewFuncTy = FunctionType::get(F->getReturnType(), NewParamTypes, false); | ||
|
||
auto NewFunc = Function::Create(NewFuncTy, F->getLinkage()); | ||
NewFunc->setName(F->getName().str()); | ||
F->setName(F->getName().str() + ".inner"); | ||
NewFunc->setCallingConv(F->getCallingConv()); | ||
NewFunc->copyAttributesFrom(F); | ||
|
||
for (auto &arg : F->args()) { | ||
NewFunc->getArg(arg.getArgNo())->setName(arg.getName()); | ||
} | ||
|
||
F->setCallingConv(CallingConv::SPIR_FUNC); | ||
for (auto &U : F->uses()) { | ||
if (auto CI = dyn_cast<CallInst>(U.getUser())) { | ||
CI->setCallingConv(CallingConv::SPIR_FUNC); | ||
} | ||
} | ||
NewFunc->copyMetadata(F, 0); | ||
F->clearMetadata(); | ||
|
||
IRBuilder<> Builder(BasicBlock::Create(M.getContext(), "entry", NewFunc)); | ||
|
||
// Copy args from src func to new func | ||
// Get the arguments of the source function. | ||
SmallVector<Value *, 8> WrappedArgs; | ||
for (unsigned ArgNum = 0; ArgNum < F->arg_size(); ArgNum++) { | ||
auto *NewArg = NewFunc->getArg(ArgNum); | ||
WrappedArgs.push_back(NewArg); | ||
} | ||
|
||
auto *CallInst = Builder.CreateCall(F, WrappedArgs); | ||
CallInst->setCallingConv(CallingConv::SPIR_FUNC); | ||
Builder.CreateRetVoid(); | ||
|
||
// Insert the function after the original, to preserve ordering | ||
// in the module as much as possible. | ||
auto &FunctionList = M.getFunctionList(); | ||
for (auto Iter = FunctionList.begin(), IterEnd = FunctionList.end(); | ||
Iter != IterEnd; ++Iter) { | ||
if (&*Iter == F) { | ||
FunctionList.insertAfter(Iter, NewFunc); | ||
break; | ||
} | ||
} | ||
} | ||
|
||
bool isCalled(llvm::Function &F) { | ||
for (auto &U : F.uses()) { | ||
if (dyn_cast<CallInst>(U.getUser())) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
PreservedAnalyses clspv::WrapKernelPass::run(llvm::Module &M, | ||
ModuleAnalysisManager &) { | ||
PreservedAnalyses PA; | ||
SmallVector<Function *, 8> FuncsToWrap; | ||
|
||
for (auto &F : M.functions()) { | ||
if (F.getCallingConv() == CallingConv::SPIR_KERNEL && isCalled(F)) { | ||
FuncsToWrap.emplace_back(&F); | ||
} | ||
} | ||
for (unsigned i = 0; i < FuncsToWrap.size(); ++i) { | ||
runOnFunction(M, FuncsToWrap[i]); | ||
} | ||
return PA; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
// Copyright 2023 The Clspv Authors. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "llvm/IR/Module.h" | ||
#include "llvm/IR/PassManager.h" | ||
|
||
#ifndef _CLSPV_LIB_WRAP_KERNEL_PASS_H | ||
#define _CLSPV_LIB_WRAP_KERNEL_PASS_H | ||
|
||
namespace clspv { | ||
struct WrapKernelPass : llvm::PassInfoMixin<WrapKernelPass> { | ||
llvm::PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &); | ||
|
||
private: | ||
void runOnFunction(llvm::Module &M, llvm::Function *F); | ||
}; | ||
} // namespace clspv | ||
|
||
#endif // _CLSPV_LIB_WRAP_KERNEL_PASS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// ; RUN: clspv %s -o %t.spv | ||
// ; RUN: spirv-dis -o %t2.spvasm %t.spv | ||
// ; RUN: FileCheck %s < %t2.spvasm | ||
// ; RUN: spirv-val %t.spv | ||
|
||
__attribute__((work_group_size_hint(1,1,1))) | ||
__attribute__((reqd_work_group_size(1,1,1))) | ||
__kernel void add(global int *A, global int *B) { | ||
// Add 1 to each element of the input buffer. | ||
for (unsigned int i = 0; i < get_global_size(0); i++) { | ||
B[i] = A[i] + 1; | ||
} | ||
} | ||
|
||
// kernel1.clspv | ||
kernel void main_kernel(global int* C, global int* D) { | ||
// Call the function. | ||
add(C, D); | ||
} | ||
|
||
// CHECK: [[extinst:%[a-zA-A0-9_]+]] = OpExtInstImport "NonSemantic.ClspvReflection.5" | ||
// CHECK-DAG: [[kernel_add_name:%[a-zA-Z0-9_]+]] = OpString "add" | ||
// CHECK: [[attributes:%[^ ]+]] = OpString " __attribute__((work_group_size_hint(1, 1, 1))) __attribute__((reqd_work_group_size(1, 1, 1))) __kernel" | ||
// CHECK-DAG: [[k0arg0:%[a-zA-Z0-9_]+]] = OpString "A" | ||
// CHECK-DAG: [[k0arg1:%[a-zA-Z0-9_]+]] = OpString "B" | ||
// CHECK-DAG: [[kernel_main_name:%[a-zA-Z0-9_]+]] = OpString "main_kernel" | ||
// CHECK-DAG: [[k1arg0:%[a-zA-Z0-9_]+]] = OpString "C" | ||
// CHECK-DAG: [[k1arg1:%[a-zA-Z0-9_]+]] = OpString "D" | ||
|
||
// CHECK-DAG: [[kernel_add_def:%[a-zA-A0-9_]+]] = OpExtInst %void [[extinst]] Kernel {{.*}} [[kernel_add_name]] {{.*}} [[attributes]] | ||
// CHECK-NEXT: OpExtInst %void [[extinst]] PropertyRequiredWorkgroupSize [[kernel_add_def]] {{.*}} {{.*}} {{.*}} | ||
// CHECK: OpExtInst %void [[extinst]] ArgumentInfo [[k0arg0]] | ||
// CHECK: OpExtInst %void [[extinst]] ArgumentInfo [[k0arg1]] | ||
// CHECK-DAG: OpExtInst %void [[extinst]] Kernel {{.*}} [[kernel_main_name]] {{.*}} {{.*}} {{.*}} | ||
// CHECK: OpExtInst %void [[extinst]] ArgumentInfo [[k1arg0]] | ||
// CHECK: OpExtInst %void [[extinst]] ArgumentInfo [[k1arg1]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
; RUN: clspv-opt %s -o %t.ll --passes=wrap-kernel | ||
; RUN: FileCheck %s < %t.ll | ||
|
||
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" | ||
target triple = "spir-unknown-unknown" | ||
|
||
define dso_local spir_kernel void @add(ptr addrspace(1) align 4 %A, ptr addrspace(1) align 4 %B) #2 !kernel_arg_addr_space !15 !kernel_arg_access_qual !16 !kernel_arg_type !17 !kernel_arg_base_type !17 !kernel_arg_type_qual !18 !work_group_size_hint !19 !reqd_work_group_size !19 { | ||
entry: | ||
%A.addr = alloca ptr addrspace(1), align 4 | ||
%B.addr = alloca ptr addrspace(1), align 4 | ||
%i = alloca i32, align 4 | ||
store ptr addrspace(1) %A, ptr %A.addr, align 4 | ||
store ptr addrspace(1) %B, ptr %B.addr, align 4 | ||
store i32 0, ptr %i, align 4 | ||
ret void | ||
} | ||
|
||
define dso_local spir_kernel void @main_kernel(ptr addrspace(1) align 4 %A, ptr addrspace(1) align 4 %B) #2 !kernel_arg_addr_space !15 !kernel_arg_access_qual !16 !kernel_arg_type !17 !kernel_arg_base_type !17 !kernel_arg_type_qual !18 { | ||
entry: | ||
%A.addr = alloca ptr addrspace(1), align 4 | ||
%B.addr = alloca ptr addrspace(1), align 4 | ||
store ptr addrspace(1) %A, ptr %A.addr, align 4 | ||
store ptr addrspace(1) %B, ptr %B.addr, align 4 | ||
%0 = load ptr addrspace(1), ptr %A.addr, align 4 | ||
%1 = load ptr addrspace(1), ptr %B.addr, align 4 | ||
call spir_kernel void @add(ptr addrspace(1) align 4 %0, ptr addrspace(1) align 4 %1) #5 | ||
ret void | ||
} | ||
|
||
|
||
!0 = !{i32 1, !"wchar_size", i32 4} | ||
!1 = !{i32 7, !"direct-access-external-data", i32 0} | ||
!2 = !{i32 7, !"frame-pointer", i32 2} | ||
!3 = !{i32 1, i32 2} | ||
!4 = !{!"clang version 18.0.0 (https://github.com/llvm/llvm-project c7d65e4466eafe518937c59ef9a242234ed7a08a)"} | ||
!5 = !{!"clang version 17.0.0 (https://github.com/llvm/llvm-project 1e6fc9626c0f49ce952a67aef47e86253d13f74a)"} | ||
!6 = !{!"clang version 17.0.0 (https://github.com/llvm/llvm-project ab674234c440ed27302f58eeccc612c83b32c43f)"} | ||
!7 = !{!"_Z4sqrtf", !" __attribute__((overloadable)) __attribute__((const))"} | ||
!8 = !{!"_Z4sqrtDv2_f", !" __attribute__((overloadable)) __attribute__((const))"} | ||
!9 = !{!"_Z4sqrtDv3_f", !" __attr((overloadable)) __attribute__((const))"} | ||
!10 = !{!"_Z4sqrtDv4_f", !" __attribute__((overloadable)) __aibute__ttribute__((const))"} | ||
!11 = !{!"_Z4sqrtDv8_f", !" __attribute__((overloadable)) __attribute__((const))"} | ||
!12 = !{!"_Z4sqrtDv16_f", !" __attribute__((overloadable)) __attribute__((const))"} | ||
!13 = !{!"add", !" __attribute__((work_group_size_hint(1, 1, 1))) __attribute__((reqd_work_group_size(1, 1, 1))) __kernel"} | ||
!14 = !{!"main_kernel", !" kernel"} | ||
!15 = !{i32 1, i32 1} | ||
!16 = !{!"none", !"none"} | ||
!17 = !{!"int*", !"int*"} | ||
!18 = !{!"", !""} | ||
!19 = !{i32 1, i32 1, i32 1} | ||
|
||
; CHECK:define dso_local spir_func void @add.inner(ptr addrspace(1) align [[alignment:[0-9]*]] %A, ptr addrspace(1) align [[alignment]] %B) | ||
; CHECK-NEXT: entry: | ||
; CHECK: ret void | ||
|
||
; CHECK: define dso_local spir_kernel void @add(ptr addrspace(1) align [[alignment]] %A, ptr addrspace(1) align [[alignment]] %B) !kernel_arg_addr_space [[args_type:![0-9]*]] !kernel_arg_access_qual [[none:![0-9]*]] !kernel_arg_type [[ptr:![0-9]*]] !kernel_arg_base_type [[ptr]] !kernel_arg_type_qual [[empty:![0-9]*]] !work_group_size_hint [[work_group_hint:![0-9]*]] !reqd_work_group_size [[work_group_hint]] | ||
; CHECK-NEXT: entry: | ||
; CHECK-NEXT: call spir_func void @add.inner(ptr addrspace(1) %A, ptr addrspace(1) %B) | ||
; CHECK: ret void | ||
|
||
; CHECK: define dso_local spir_kernel void @main_kernel(ptr addrspace(1) align [[alignment]] %A, ptr addrspace(1) align [[alignment]] %B) !kernel_arg_addr_space [[args_type]] !kernel_arg_access_qual [[none]] !kernel_arg_type [[ptr]] !kernel_arg_base_type [[ptr]] !kernel_arg_type_qual [[empty]] | ||
; CHECK-NEXT: entry: | ||
; CHECK-DAG: [[paramA:%[a-zA-A0-9_]+]] = load ptr addrspace(1), ptr %A.addr, align [[alignment]] | ||
; CHECK-DAG: [[paramB:%[a-zA-A0-9_]+]] = load ptr addrspace(1), ptr %B.addr, align [[alignment]] | ||
; CHECK: call spir_func void @add.inner(ptr addrspace(1) align [[alignment]] [[paramA]], ptr addrspace(1) align [[alignment]] [[paramB]]) | ||
; CHECK ret void | ||
|
||
; CHECK: [[args_type]] = !{i32 1, i32 1} | ||
; CHECK: [[none]] = !{!"none", !"none"} | ||
; CHECK: [[ptr]] = !{!"int*", !"int*"} | ||
; CHECK: [[empty]] = !{!"", !""} | ||
; CHECK: [[work_group_hint]] = !{i32 1, i32 1, i32 1} |