Skip to content

Commit

Permalink
Adds specialize
Browse files: browse the repository at this point in the history
  • Loading branch information
koparasy committed Jan 22, 2025
1 parent f88d62e commit 4be26c0
Showing 1 changed file with 65 additions and 51 deletions.
116 changes: 65 additions & 51 deletions lib/JitEngineDevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,12 @@ template <typename ImplT> class JitEngineDevice : public JitEngine {
// End Methods implemented in the derived device engine class.
//------------------------------------------------------------------

Function *pruneIR(Module &M, StringRef FnName);
void pruneIR(Module &M, StringRef FnName);

void specializeIR(Module &M, StringRef FnName, StringRef Suffix,
dim3 &BlockDim, dim3 &GridDim,
const SmallVector<int32_t> &RCIndices, RuntimeConstant *RC,
int NumRuntimeConstants);

void replaceGlobalVariablesWithPointers(
Module &M,
Expand Down Expand Up @@ -457,13 +462,65 @@ template <typename ImplT> class JitEngineDevice : public JitEngine {
};

template <typename ImplT>
Function *JitEngineDevice<ImplT>::pruneIR(Module &M, StringRef FnName) {
TIMESCOPE("pruneIR");
PROTEUS_DBG(Logger::logs("proteus") << "=== Parsed Module\n"
<< M << "=== End of Parsed Module\n");
/// Specializes the kernel named \p FnName inside \p M for one launch
/// configuration and renames it to FnName + Suffix.
///
/// Steps, in order:
///   1. If Config.ENV_PROTEUS_SPECIALIZE_ARGS is set, replace argument uses
///      with the runtime constants in \p RC (selected through \p RCIndices).
///   2. If Config.ENV_PROTEUS_SPECIALIZE_DIMS is set, replace uses of
///      blockDim.* and gridDim.* with constants taken from \p BlockDim and
///      \p GridDim.
///   3. Internalize every global value except the kernel itself.
///   4. Rename the kernel, optionally set its launch bounds
///      (Config.ENV_PROTEUS_SET_LAUNCH_BOUNDS), and run the cleanup pass
///      pipeline.
///   5. In PROTEUS_ENABLE_DEBUG builds, dump and verify the final module.
///
/// \param M                   Module that contains the kernel; modified in place.
/// \param FnName              Current name of the kernel function in \p M.
/// \param Suffix              Text appended to \p FnName to form the new name.
/// \param BlockDim            Launch block dimensions.
/// \param GridDim             Launch grid dimensions.
/// \param RCIndices           Indices forwarded to the argument specialization.
/// \param RC                  Pointer to the first runtime constant.
/// \param NumRuntimeConstants Number of elements readable through \p RC.
void JitEngineDevice<ImplT>::specializeIR(Module &M, StringRef FnName,
                                          StringRef Suffix, dim3 &BlockDim,
                                          dim3 &GridDim,
                                          const SmallVector<int32_t> &RCIndices,
                                          RuntimeConstant *RC,
                                          int NumRuntimeConstants) {
  TIMESCOPE("specializeIR");
  Function *F = M.getFunction(FnName);

  assert(F && "Expected non-null function!");
  // assert() is compiled out under NDEBUG; without this unconditional check a
  // missing kernel would be dereferenced as a null pointer in release builds.
  if (!F)
    FATAL_ERROR("Expected non-null function!");

  // Replace argument uses with runtime constants.
  if (Config.ENV_PROTEUS_SPECIALIZE_ARGS) {
    // TODO: change NumRuntimeConstants to size_t at interface.
    TransformArgumentSpecialization::transform(
        M, *F, RCIndices,
        ArrayRef<RuntimeConstant>{RC,
                                  static_cast<size_t>(NumRuntimeConstants)});
  }

  // Replace uses of blockDim.* and gridDim.* with constants.
  if (Config.ENV_PROTEUS_SPECIALIZE_DIMS) {
    setKernelDims(M, GridDim, BlockDim);
  }

  // Internalize everything except the kernel function: the predicate returns
  // true only for the value that must be preserved.
  internalizeModule(M, [&F](const GlobalValue &GV) { return &GV == F; });

  PROTEUS_DBG(Logger::logs("proteus") << "=== JIT Module\n"
                                      << M << "=== End of JIT Module\n");

  F->setName(FnName + Suffix);

  if (Config.ENV_PROTEUS_SET_LAUNCH_BOUNDS)
    setLaunchBoundsForKernel(M, *F, GridDim.x * GridDim.y * GridDim.z,
                             BlockDim.x * BlockDim.y * BlockDim.z);

  runCleanupPassPipeline(M);

#if PROTEUS_ENABLE_DEBUG
  Logger::logs("proteus") << "=== Final Module\n"
                          << M << "=== End Final Module\n";
  if (verifyModule(M, &errs()))
    FATAL_ERROR("Broken module found, JIT compilation aborted!");
  else
    Logger::logs("proteus") << "Module verified!\n";
#endif
}

template <typename ImplT>
void JitEngineDevice<ImplT>::pruneIR(Module &M, StringRef FnName) {
TIMESCOPE("pruneIR");
PROTEUS_DBG(Logger::logs("proteus") << "=== Parsed Module\n"
<< M << "=== End of Parsed Module\n");
// Remove llvm.global.annotations now that we have read them.
if (auto *GlobalAnnotations = M.getGlobalVariable("llvm.global.annotations"))
M.eraseGlobalVariable(GlobalAnnotations);
Expand Down Expand Up @@ -495,8 +552,6 @@ Function *JitEngineDevice<ImplT>::pruneIR(Module &M, StringRef FnName) {
for (auto &GV : M.globals())
if (GV.isExternallyInitialized())
GV.setExternallyInitialized(false);

return F;
}

template <typename ImplT>
Expand Down Expand Up @@ -615,51 +670,10 @@ JitEngineDevice<ImplT>::compileAndRun(
// in memory module, for every annotated kernel. If we have a case of 1000s of
// kernels, this can be an issue

auto *JitFunction = pruneIR(*JitModule, KernelName);

// Internalize others besides the kernel function.
internalizeModule(*JitModule, [&JitFunction](const GlobalValue &GV) {
// Do not internalize the kernel function.
if (&GV == JitFunction)
return true;

// Internalize everything else.
return false;
});
pruneIR(*JitModule, KernelName);

// Replace argument uses with runtime constants.
if (Config.ENV_PROTEUS_SPECIALIZE_ARGS)
// TODO: change NumRuntimeConstants to size_t at interface.
TransformArgumentSpecialization::transform(
*JitModule, *JitFunction, RCIndices,
ArrayRef<RuntimeConstant>{RCsVec.data(),
static_cast<size_t>(NumRuntimeConstants)});

// Replace uses of blockDim.* and gridDim.* with constants.
if (Config.ENV_PROTEUS_SPECIALIZE_DIMS)
setKernelDims(*JitModule, GridDim, BlockDim);

PROTEUS_DBG(Logger::logs("proteus")
<< "=== JIT Module\n"
<< *JitModule << "=== End of JIT Module\n");

JitFunction->setName(KernelName + Suffix);

if (Config.ENV_PROTEUS_SET_LAUNCH_BOUNDS)
setLaunchBoundsForKernel(*JitModule, *JitFunction,
GridDim.x * GridDim.y * GridDim.z,
BlockDim.x * BlockDim.y * BlockDim.z);

runCleanupPassPipeline(*JitModule);

#if PROTEUS_ENABLE_DEBUG
Logger::logs("proteus") << "=== Final Module\n"
<< *JitModule << "=== End Final Module\n";
if (verifyModule(*JitModule, &errs()))
FATAL_ERROR("Broken module found, JIT compilation aborted!");
else
Logger::logs("proteus") << "Module verified!\n";
#endif
specializeIR(*JitModule, KernelName, Suffix, BlockDim, GridDim, RCIndices,
RCsVec.data(), NumRuntimeConstants);

replaceGlobalVariablesWithPointers(*JitModule, VarNameToDevPtr);

Expand Down

0 comments on commit 4be26c0

Please sign in to comment.