Skip to content

Commit

Permalink
[MLIR][OpenMP] Support lowering of host_eval to LLVM IR (#179)
Browse files Browse the repository at this point in the history
This patch updates the MLIR to LLVM IR lowering of `omp.target` to support
passing `num_teams`, `num_threads`, `thread_limit` and SPMD loop bounds through
the `host_eval` argument of `omp.target`.

This replaces the previous implementation where this information was directly
attached to the `omp.target` operation rather than captured to be used by the
corresponding nested operation.
  • Loading branch information
skatrak authored Nov 7, 2024
1 parent 2f3acbd commit 5654efd
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 102 deletions.
38 changes: 31 additions & 7 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -728,13 +728,12 @@ class OpenMPIRBuilder {
LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
const Twine &Name = "loop");

/// Generator for the control flow structure of an OpenMP canonical loop.
/// Calculate the trip count of a canonical loop.
///
/// Instead of a logical iteration space, this allows specifying user-defined
/// loop counter values using increment, upper- and lower bounds. To
/// disambiguate the terminology when counting downwards, instead of lower
/// bounds we use \p Start for the loop counter value in the first body
/// iteration.
/// This allows specifying user-defined loop counter values using increment,
/// upper- and lower bounds. To disambiguate the terminology when counting
/// downwards, instead of lower bounds we use \p Start for the loop counter
/// value in the first body iteration.
///
/// Consider the following limitations:
///
Expand All @@ -758,7 +757,32 @@ class OpenMPIRBuilder {
///
/// for (int i = 0; i < 42; i -= 1u)
///
//
/// \param Loc The insert and source location description.
/// \param Start Value of the loop counter for the first iterations.
/// \param Stop Loop counter values past this will stop the loop.
/// \param Step Loop counter increment after each iteration; negative
/// means counting down.
/// \param IsSigned Whether Start, Stop and Step are signed integers.
/// \param InclusiveStop Whether \p Stop itself is a valid value for the loop
/// counter.
/// \param Name Base name used to derive instruction names.
///
/// \returns The value holding the calculated trip count.
Value *calculateCanonicalLoopTripCount(const LocationDescription &Loc,
Value *Start, Value *Stop, Value *Step,
bool IsSigned, bool InclusiveStop,
const Twine &Name = "loop");

/// Generator for the control flow structure of an OpenMP canonical loop.
///
/// Instead of a logical iteration space, this allows specifying user-defined
/// loop counter values using increment, upper- and lower bounds. To
/// disambiguate the terminology when counting downwards, instead of lower
/// bounds we use \p Start for the loop counter value in the first body
///
/// It calls \see calculateCanonicalLoopTripCount for trip count calculations,
/// so limitations of that method apply here as well.
///
/// \param Loc The insert and source location description.
/// \param BodyGenCB Callback that will generate the loop body code.
/// \param Start Value of the loop counter for the first iterations.
Expand Down
28 changes: 18 additions & 10 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4032,11 +4032,9 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
return CL;
}

Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
InsertPointTy ComputeIP, const Twine &Name) {

Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
bool IsSigned, bool InclusiveStop, const Twine &Name) {
// Consider the following difficulties (assuming 8-bit signed integers):
// * Adding \p Step to the loop counter which passes \p Stop may overflow:
// DO I = 1, 100, 50
Expand All @@ -4048,9 +4046,7 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
assert(IndVarTy == Stop->getType() && "Stop type mismatch");
assert(IndVarTy == Step->getType() && "Step type mismatch");

LocationDescription ComputeLoc =
ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
updateToLocation(ComputeLoc);
updateToLocation(Loc);

ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
ConstantInt *One = ConstantInt::get(IndVarTy, 1);
Expand Down Expand Up @@ -4090,8 +4086,20 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
}
Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
"omp_" + Name + ".tripcount");

return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
"omp_" + Name + ".tripcount");
}

Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
InsertPointTy ComputeIP, const Twine &Name) {
LocationDescription ComputeLoc =
ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;

Value *TripCount = calculateCanonicalLoopTripCount(
ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);

auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
Builder.restoreIP(CodeGenIP);
Expand Down
16 changes: 3 additions & 13 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1427,8 +1427,7 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
EXPECT_EQ(&Loop->getAfter()->front(), RetInst);
}

TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) {
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
IRBuilder<> Builder(BB);
Expand All @@ -1444,17 +1443,8 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
Value *StartVal = ConstantInt::get(LCTy, Start);
Value *StopVal = ConstantInt::get(LCTy, Stop);
Value *StepVal = ConstantInt::get(LCTy, Step);
auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
return Error::success();
};
Expected<CanonicalLoopInfo *> LoopResult =
OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, StartVal, StopVal,
StepVal, IsSigned, InclusiveStop);
assert(LoopResult && "unexpected error");
CanonicalLoopInfo *Loop = *LoopResult;
Loop->assertOK();
Builder.restoreIP(Loop->getAfterIP());
Value *TripCount = Loop->getTripCount();
Value *TripCount = OMPBuilder.calculateCanonicalLoopTripCount(
Loc, StartVal, StopVal, StepVal, IsSigned, InclusiveStop);
return cast<ConstantInt>(TripCount)->getValue().getZExtValue();
};

Expand Down
67 changes: 23 additions & 44 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1772,55 +1772,34 @@ LogicalResult TargetOp::verify() {
Operation *TargetOp::getInnermostCapturedOmpOp() {
Dialect *ompDialect = (*this)->getDialect();
Operation *capturedOp = nullptr;
Region *capturedParentRegion = nullptr;

walk<WalkOrder::PostOrder>([&](Operation *op) {
// Process in pre-order to check operations from outermost to innermost,
// ensuring we only enter the region of an operation if it meets the criteria
// for being captured. We stop the exploration of nested operations as soon as
// we process a region with no operation to be captured.
walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op == *this)
return;

// Reset captured op if crossing through an omp.loop_nest, so that the top
// level one will be the one captured.
if (llvm::isa<LoopNestOp>(op)) {
capturedOp = nullptr;
capturedParentRegion = nullptr;
}
return WalkResult::advance();

// Ignore operations of other dialects or omp operations with no regions,
// because these will only be checked if they are siblings of an omp
// operation that can potentially be captured.
bool isOmpDialect = op->getDialect() == ompDialect;
bool hasRegions = op->getNumRegions() > 0;

if (capturedOp) {
bool isImmediateParent = false;
for (Region &region : op->getRegions()) {
if (&region == capturedParentRegion) {
isImmediateParent = true;
capturedParentRegion = op->getParentRegion();
break;
}
}

// Make sure the captured op is part of a (possibly multi-level) nest of
// OpenMP-only operations containing no unsupported siblings at any level.
if ((hasRegions && isOmpDialect != isImmediateParent) ||
(!isImmediateParent && !siblingAllowedInCapture(op))) {
capturedOp = nullptr;
capturedParentRegion = nullptr;
}
} else {
// The first OpenMP dialect op containing a region found while visiting
// in post-order should be the innermost captured OpenMP operation.
if (isOmpDialect && hasRegions) {
capturedOp = op;
capturedParentRegion = op->getParentRegion();

// Don't capture this op if it has a not-allowed sibling.
for (Operation &sibling : op->getParentRegion()->getOps()) {
if (&sibling != op && !siblingAllowedInCapture(&sibling)) {
capturedOp = nullptr;
capturedParentRegion = nullptr;
}
}
}
}
if (!isOmpDialect || !hasRegions)
return WalkResult::skip();

// Don't capture this op if it has a not-allowed sibling, and stop recursing
// into nested operations.
for (Operation &sibling : op->getParentRegion()->getOps())
if (&sibling != op && !siblingAllowedInCapture(&sibling))
return WalkResult::interrupt();

// Don't continue capturing nested operations if we reach an omp.loop_nest.
// Otherwise, process the contents of this operation.
capturedOp = op;
return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
: WalkResult::advance();
});

return capturedOp;
Expand Down
Loading

0 comments on commit 5654efd

Please sign in to comment.