Skip to content

Commit

Permalink
[OMPIRBuilder] Improve 'target if' implementation (#222)
Browse files Browse the repository at this point in the history
This patch cleans up support for the 'if' clause on 'target' directives by
reusing the existing `emitIfClause()` function rather than duplicating code.
One side effect of this change is that constant 'if' values will no longer
result in the generation of branches. Instead, code is generated only for the
applicable case.

It also adds a unit test to ensure it works as expected.
  • Loading branch information
skatrak authored Dec 12, 2024
1 parent 73deda0 commit c17a3dd
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 50 deletions.
83 changes: 33 additions & 50 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7491,7 +7491,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
auto TaskBodyCB =
[&](Value *DeviceID, Value *RTLoc,
IRBuilderBase::InsertPoint TargetTaskAllocaIP) -> Error {
llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = [&]() {
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
// produce any.
llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
// emitKernelLaunch makes the necessary runtime call to offload the
// kernel. We then outline all that code into a separate function
// ('kernel_launch_function' in the pseudo code above). This function is
Expand All @@ -7506,17 +7508,18 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
// When OutlinedFnID is set to nullptr, then it's not an offloading call.
// In this case, we execute the host implementation directly.
return EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP());
}();
}());

if (!AfterIP)
return AfterIP.takeError();

OMPBuilder.Builder.restoreIP(*AfterIP);
OMPBuilder.Builder.restoreIP(AfterIP);
return Error::success();
};

auto &&EmitTargetCallElse = [&]() {
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = [&]() {
auto &&EmitTargetCallElse =
[&](OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
// produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
if (RequiresOuterTargetTask) {
// Arguments that are intended to be directly forwarded to an
// emitKernelLaunch call are passed as nullptr, since
Expand All @@ -7525,24 +7528,24 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
/*RTLoc=*/nullptr, AllocaIP,
Dependencies, HasNoWait);
}
return EmitTargetCallFallbackCB(Builder.saveIP());
}();
return EmitTargetCallFallbackCB(CodeGenIP);
}());

// Assume no error was returned because EmitTargetCallFallbackCB doesn't
// produce any. The 'if' check enables accessing the returned value.
if (AfterIP)
Builder.restoreIP(*AfterIP);
Builder.restoreIP(AfterIP);
return Error::success();
};

auto &&EmitTargetCallThen = [&]() {
auto &&EmitTargetCallThen =
[&](OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);

OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(CodeGenIP);
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
RTArgs, MapInfo,
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, CodeGenIP, Info, RTArgs,
MapInfo,
/*IsNonContiguous=*/true,
/*ForEndCall=*/false);

Expand Down Expand Up @@ -7607,59 +7610,39 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
NumTeamsC, NumThreadsC,
DynCGGroupMem, HasNoWait);

// The presence of certain clauses on the target directive requires the
// explicit generation of the target task.
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = [&]() {
// Assume no error was returned because TaskBodyCB and
// EmitTargetCallFallbackCB don't produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
// The presence of certain clauses on the target directive requires the
// explicit generation of the target task.
if (RequiresOuterTargetTask)
return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
Dependencies, HasNoWait);

return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
EmitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP);
}();
}());

// Assume no error was returned because TaskBodyCB and
// EmitTargetCallFallbackCB don't produce any. The 'if' check enables
// accessing the returned value.
if (AfterIP)
Builder.restoreIP(*AfterIP);
Builder.restoreIP(AfterIP);
return Error::success();
};

// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
EmitTargetCallElse();
cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
return;
}

// If there's no IF clause, only generate the kernel launch code path.
if (!IfCond) {
EmitTargetCallThen();
cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
return;
}

// Create if-else to handle IF clause.
llvm::BasicBlock *ThenBlock =
BasicBlock::Create(Builder.getContext(), "omp_if.then");
llvm::BasicBlock *ElseBlock =
BasicBlock::Create(Builder.getContext(), "omp_if.else");
llvm::BasicBlock *ContBlock =
BasicBlock::Create(Builder.getContext(), "omp_if.end");
Builder.CreateCondBr(IfCond, ThenBlock, ElseBlock);

Function *CurFn = Builder.GetInsertBlock()->getParent();

// Emit the 'then' code.
OMPBuilder.emitBlock(ThenBlock, CurFn);
EmitTargetCallThen();
OMPBuilder.emitBranch(ContBlock);
// Emit the 'else' code.
OMPBuilder.emitBlock(ElseBlock, CurFn);
EmitTargetCallElse();
OMPBuilder.emitBranch(ContBlock);
// Emit the continuation block.
OMPBuilder.emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
EmitTargetCallElse, AllocaIP));
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
Expand Down
34 changes: 34 additions & 0 deletions mlir/test/Target/LLVMIR/omptarget-if.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: mlir-translate -mlir-to-llvmir %s 2>&1 | FileCheck %s

// Set a dummy target triple to enable target region outlining.
module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
// Input under test: a 'target' region guarded by an 'if' clause whose
// condition is the function's i1 argument, so its value is only known at run
// time. The region body is empty apart from its terminator.
llvm.func @foo(%0 : i1) {
omp.target if(%0) {
omp.terminator
}
llvm.return
}

// The translated function receives the condition as its single i1 parameter;
// it is captured as COND for use in the branch assertion below.
// CHECK: define void @foo(i1 %[[COND:.*]]) {

// The condition selects between two code paths through a conditional branch.
// The two successor labels are captured as THEN_LABEL and ELSE_LABEL.
// CHECK: br i1 %[[COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]]

// 'then' path: a call to __tgt_target_kernel is emitted, its i32 result is
// compared against zero, and a non-zero result branches to the first
// successor (captured as OFF_FAIL_LABEL; presumably non-zero means the
// offload did not succeed -- confirm against the runtime documentation).
// CHECK: [[THEN_LABEL]]:
// CHECK: %[[RES:.*]] = call i32 @__tgt_target_kernel({{.*}})
// CHECK-NEXT: %[[OFFLOAD_CHECK:.*]] = icmp ne i32 %[[RES]], 0
// CHECK-NEXT: br i1 %[[OFFLOAD_CHECK]], label %[[OFF_FAIL_LABEL:.*]], label %[[OFF_SUCC_LABEL:.*]]

// Non-zero result path: a zero-argument function is called (its name is
// captured as FALLBACK_FN), followed by a branch to a continuation block.
// CHECK: [[OFF_FAIL_LABEL]]:
// CHECK-NEXT: call void @[[FALLBACK_FN:.*]]()
// CHECK-NEXT: br label %[[OFF_CONT_LABEL:.*]]

// The continuation block of the 'then' path jumps to the common end block.
// CHECK: [[OFF_CONT_LABEL]]:
// CHECK-NEXT: br label %[[END_LABEL:.*]]

// 'else' path: the very same function captured as FALLBACK_FN is called
// directly, with no kernel launch, before jumping to the common end block.
// CHECK: [[ELSE_LABEL]]:
// CHECK-NEXT: call void @[[FALLBACK_FN]]()
// CHECK-NEXT: br label %[[END_LABEL]]

// Both paths rejoin at the end block, which returns from the function.
// CHECK: [[END_LABEL]]:
// CHECK-NEXT: ret void
}

0 comments on commit c17a3dd

Please sign in to comment.