diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 1bc1c30e283fe53..3e11867d0ec85a4 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -33,7 +33,6 @@ #include "flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/tools.h" -#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Transforms/RegionUtils.h" @@ -45,26 +44,6 @@ using namespace Fortran::lower::omp; //===----------------------------------------------------------------------===// // Code generation helper functions //===----------------------------------------------------------------------===// - -/// Add to the given target operation a host_eval argument, which must be -/// defined outside. -/// -/// \return the entry block argument to represent \c hostVar inside of the -/// target region. -static mlir::Value addHostEvalVar(mlir::omp::TargetOp targetOp, - mlir::Value hostVar) { - assert(!targetOp.getRegion().isAncestor(hostVar.getParentRegion()) && - "variable must be defined outside of the target region"); - - auto argIface = llvm::cast(*targetOp); - unsigned insertIndex = - argIface.getHostEvalBlockArgsStart() + argIface.numHostEvalBlockArgs(); - - targetOp.getHostEvalVarsMutable().append(hostVar); - return targetOp.getRegion().insertArgument(insertIndex, hostVar.getType(), - hostVar.getLoc()); -} - namespace { /// Structure holding the information needed to create and bind entry block /// arguments associated to a single clause. @@ -83,7 +62,7 @@ struct EntryBlockArgsEntry { /// Structure holding the information needed to create and bind entry block /// arguments associated to all clauses that can define them. struct EntryBlockArgs { - EntryBlockArgsEntry hostEval; + llvm::ArrayRef hostEval; EntryBlockArgsEntry inReduction; EntryBlockArgsEntry map; EntryBlockArgsEntry priv; @@ -93,8 +72,8 @@ struct EntryBlockArgs { EntryBlockArgsEntry useDevicePtr; bool isValid() const { - return hostEval.isValid() && inReduction.isValid() && map.isValid() && - priv.isValid() && reduction.isValid() && taskReduction.isValid() && + return inReduction.isValid() && map.isValid() && priv.isValid() && + reduction.isValid() && taskReduction.isValid() && useDeviceAddr.isValid() && useDevicePtr.isValid(); } @@ -110,8 +89,81 @@ struct EntryBlockArgs { taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars); } }; + +/// Structure holding information that is needed to pass host-evaluated +/// information to later lowering stages. +struct HostEvalInfo { + mlir::omp::HostEvaluatedOperands ops; + llvm::SmallVector iv; + + /// Fill \c vars with values stored in \c ops. + /// + /// The order in which values are stored matches the one expected by \see + /// bindOperands(). + void collectValues(llvm::SmallVectorImpl &vars) const { + vars.append(ops.loopLowerBounds); + vars.append(ops.loopUpperBounds); + vars.append(ops.loopSteps); + + if (ops.numTeamsLower) + vars.push_back(ops.numTeamsLower); + + if (ops.numTeamsUpper) + vars.push_back(ops.numTeamsUpper); + + if (ops.numThreads) + vars.push_back(ops.numThreads); + + if (ops.threadLimit) + vars.push_back(ops.threadLimit); + } + + /// Update \c ops, replacing all values with the corresponding block argument + /// in \c args. + /// + /// The order in which values are stored in \c args is the same as the one + /// used by \see collectValues(). + void bindOperands(llvm::ArrayRef args) { + assert(args.size() == + ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + + (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) + + (ops.threadLimit ? 1 : 0) && + "invalid block argument list"); + int index = 0; + for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i) + ops.loopLowerBounds[i] = args[index++]; + + for (size_t i = 0; i < ops.loopUpperBounds.size(); ++i) + ops.loopUpperBounds[i] = args[index++]; + + for (size_t i = 0; i < ops.loopSteps.size(); ++i) + ops.loopSteps[i] = args[index++]; + + if (ops.numTeamsLower) + ops.numTeamsLower = args[index++]; + + if (ops.numTeamsUpper) + ops.numTeamsUpper = args[index++]; + + if (ops.numThreads) + ops.numThreads = args[index++]; + + if (ops.threadLimit) + ops.threadLimit = args[index++]; + } +}; } // namespace +/// Stack of \see HostEvalInfo to represent the current nest of \c omp.target +/// operations being created. +/// +/// The current implementation prevents nested 'target' regions from breaking +/// the handling of the outer region by keeping a stack of information +/// structures, but it will probably still require some further work to support +/// reverse offloading. +static llvm::SmallVector hostEvalInfo; + /// Get the directive enumeration value corresponding to the given OpenMP /// construct PFT node. llvm::omp::Directive @@ -175,124 +227,179 @@ extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) { ompConstruct.u); } -/// Check whether the parent of the given evaluation contains other evaluations. -static bool evalHasSiblings(const lower::pft::Evaluation &eval) { - auto checkSiblings = [&eval](const lower::pft::EvaluationList &siblings) { - for (auto &sibling : siblings) - if (&sibling != &eval && !sibling.isEndStmt()) - return true; +/// Populate the global \see hostEvalInfo after processing clauses for the given +/// \p eval OpenMP target construct, or nested constructs, if these must be +/// evaluated outside of the target region per the spec. +/// +/// In particular, this will ensure that in 'target teams' and equivalent nested +/// constructs, the \c thread_limit and \c num_teams clauses will be evaluated +/// in the host. Additionally, loop bounds, steps and the \c num_threads clause +/// will also be evaluated in the host if a target SPMD construct is detected +/// (i.e. 'target teams distribute parallel do [simd]' or equivalent nesting). +/// +/// The result, stored as a global, is intended to be used to populate the \c +/// host_eval operands of the associated \c omp.target operation, and also to be +/// checked and used by later lowering steps to populate the corresponding +/// operands of the \c omp.teams, \c omp.parallel or \c omp.loop_nest +/// operations. +static void processHostEvalClauses(lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, + lower::pft::Evaluation &eval, + mlir::Location loc) { + // Obtain the list of clauses of the given OpenMP block or loop construct + // evaluation. Other evaluations passed to this lambda keep `clauses` + // unchanged. + auto extractClauses = [&semaCtx](lower::pft::Evaluation &eval, + List &clauses) { + const auto *ompEval = eval.getIf(); + if (!ompEval) + return; - return false; + const parser::OmpClauseList *beginClauseList = nullptr; + const parser::OmpClauseList *endClauseList = nullptr; + common::visit( + common::visitors{ + [&](const parser::OpenMPBlockConstruct &ompConstruct) { + const auto &beginDirective = + std::get(ompConstruct.t); + beginClauseList = + &std::get(beginDirective.t); + endClauseList = &std::get( + std::get(ompConstruct.t).t); + }, + [&](const parser::OpenMPLoopConstruct &ompConstruct) { + const auto &beginDirective = + std::get(ompConstruct.t); + beginClauseList = + &std::get(beginDirective.t); + + if (auto &endDirective = + std::get>( + ompConstruct.t)) + endClauseList = + &std::get(endDirective->t); + }, + [&](const auto &) {}}, + ompEval->u); + + assert(beginClauseList && "expected begin directive"); + clauses.append(makeClauses(*beginClauseList, semaCtx)); + + if (endClauseList) + clauses.append(makeClauses(*endClauseList, semaCtx)); }; - return eval.parent.visit(common::visitors{ - [&](const lower::pft::Program &parent) { - return parent.getUnits().size() + parent.getCommonBlocks().size() > 1; - }, - [&](const lower::pft::Evaluation &parent) { - return checkSiblings(*parent.evaluationList); - }, - [&](const auto &parent) { - return checkSiblings(parent.evaluationList); - }}); -} + // Return the directive that is immediately nested inside of the given + // `parent` evaluation, if it is its only non-end-statement nested evaluation + // and it represents an OpenMP construct. + auto extractOnlyOmpNestedDir = [](lower::pft::Evaluation &parent) + -> std::optional { + if (!parent.hasNestedEvaluations()) + return std::nullopt; + + llvm::omp::Directive dir; + auto &nested = parent.getFirstNestedEvaluation(); + if (const auto *ompEval = nested.getIf()) + dir = extractOmpDirective(*ompEval); + else + return std::nullopt; -/// Check whether the given omp.target operation exists and we're compiling for -/// the host device. -static bool isHostTarget(mlir::omp::TargetOp targetOp) { - if (!targetOp) - return false; + for (auto &sibling : parent.getNestedEvaluations()) + if (&sibling != &nested && !sibling.isEndStmt()) + return std::nullopt; - auto offloadModOp = llvm::cast( - *targetOp->getParentOfType()); + return dir; + }; - return !offloadModOp.getIsTargetDevice(); -} + // Process the given evaluation assuming it's part of a 'target' construct or + // captured by one, and store results in the global `hostEvalInfo`. + std::function &)> + processEval; + processEval = [&](lower::pft::Evaluation &eval, const List &clauses) { + using namespace llvm::omp; + ClauseProcessor cp(converter, semaCtx, clauses); -/// Check whether a given evaluation points to an OpenMP loop construct that -/// represents a target SPMD kernel. For this to be true, it must be a `target -/// teams distribute parallel do [simd]` or equivalent construct. -/// -/// Currently, this is limited to cases where all relevant OpenMP constructs are -/// either combined or directly nested within the same function. Also, the -/// composite `distribute parallel do` is not identified if split into two -/// explicit nested loops (i.e. a `distribute` loop and a `parallel do` loop). -static bool isTargetSPMDLoop(const lower::pft::Evaluation &eval) { - using namespace llvm::omp; + // Call `processEval` recursively with the immediately nested evaluation and + // its corresponding clauses if there is a single nested evaluation + // representing an OpenMP directive that passes the given test. + auto processSingleNestedIf = [&](llvm::function_ref test) { + std::optional nestedDir = extractOnlyOmpNestedDir(eval); + if (!nestedDir || !test(*nestedDir)) + return; - const auto *ompEval = eval.getIf(); - if (!ompEval) - return false; + lower::pft::Evaluation &nestedEval = eval.getFirstNestedEvaluation(); + List nestedClauses; + extractClauses(nestedEval, nestedClauses); + processEval(nestedEval, nestedClauses); + }; - switch (extractOmpDirective(*ompEval)) { - case OMPD_distribute_parallel_do: - case OMPD_distribute_parallel_do_simd: { - // It will return true only if one of these are true: - // - It has a 'target teams' parent and no siblings. - // - It has a 'teams' parent and no siblings, and the 'teams' has a - // 'target' parent and no siblings. - if (evalHasSiblings(eval)) - return false; - - const auto *parentEval = eval.parent.getIf(); - if (!parentEval) - return false; - - const auto *parentOmpEval = parentEval->getIf(); - if (!parentOmpEval) - return false; - - auto parentDir = extractOmpDirective(*parentOmpEval); - if (parentDir == OMPD_target_teams) - return true; - - if (parentDir != OMPD_teams) - return false; - - if (evalHasSiblings(*parentEval)) - return false; - - const auto *parentOfParentEval = - parentEval->parent.getIf(); - if (!parentOfParentEval) - return false; - - const auto *parentOfParentOmpEval = - parentOfParentEval->getIf(); - return parentOfParentOmpEval && - extractOmpDirective(*parentOfParentOmpEval) == OMPD_target; - } - case OMPD_teams_distribute_parallel_do: - case OMPD_teams_distribute_parallel_do_simd: { - // Check there's a 'target' parent and no siblings. - if (evalHasSiblings(eval)) - return false; - - const auto *parentEval = eval.parent.getIf(); - if (!parentEval) - return false; - - const auto *parentOmpEval = parentEval->getIf(); - return parentOmpEval && extractOmpDirective(*parentOmpEval) == OMPD_target; - } - case OMPD_target_teams_distribute_parallel_do: - case OMPD_target_teams_distribute_parallel_do_simd: - return true; - default: - return false; - } -} + const auto *ompEval = eval.getIf(); + if (!ompEval) + return; -static mlir::omp::TargetOp findParentTargetOp(mlir::OpBuilder &builder) { - mlir::Operation *parentOp = builder.getBlock()->getParentOp(); - if (!parentOp) - return nullptr; + HostEvalInfo &hostInfo = hostEvalInfo.back(); + + switch (extractOmpDirective(*ompEval)) { + // Cases where 'teams' and 'target' SPMD clauses might be present. + case OMPD_teams_distribute_parallel_do: + case OMPD_teams_distribute_parallel_do_simd: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams_distribute_parallel_do: + case OMPD_target_teams_distribute_parallel_do_simd: + cp.processNumTeams(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_distribute_parallel_do: + case OMPD_distribute_parallel_do_simd: + cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); + cp.processNumThreads(stmtCtx, hostInfo.ops); + break; + + // Cases where 'teams' clauses might be present, and 'target' SPMD is + // possible by looking at nested evaluations. + case OMPD_teams: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams: + cp.processNumTeams(stmtCtx, hostInfo.ops); + processSingleNestedIf([](Directive nestedDir) { + return nestedDir == OMPD_distribute_parallel_do || + nestedDir == OMPD_distribute_parallel_do_simd; + }); + break; + + // Cases where only 'teams' host-evaluated clauses might be present. + case OMPD_teams_distribute: + case OMPD_teams_distribute_simd: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams_distribute: + case OMPD_target_teams_distribute_simd: + cp.processNumTeams(stmtCtx, hostInfo.ops); + break; + + // Standalone 'target' case. + case OMPD_target: { + processSingleNestedIf( + [](Directive nestedDir) { return topTeamsSet.test(nestedDir); }); + break; + } + default: + break; + } + }; - auto targetOp = llvm::dyn_cast(parentOp); - if (!targetOp) - targetOp = parentOp->getParentOfType(); + const auto *ompEval = eval.getIf(); + assert(ompEval && + llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) && + "expected TARGET construct evaluation"); - return targetOp; + // Use the whole list of clauses passed to the construct here, rather than the + // ones applicable to omp.target. + List clauses; + extractClauses(eval, clauses); + processEval(eval, clauses); } static void genOMPDispatch(lower::AbstractConverter &converter, @@ -423,8 +530,8 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter, }; // Process in clause name alphabetical order to match block arguments order. - bindPrivateLike(args.hostEval.syms, args.hostEval.vars, - op.getHostEvalBlockArgs()); + // Do not bind host_eval variables because they cannot be used inside of the + // corresponding region except for very specific cases. bindPrivateLike(args.inReduction.syms, args.inReduction.vars, op.getInReductionBlockArgs()); bindMapLike(args.map.syms, op.getMapBlockArgs()); @@ -489,213 +596,6 @@ static void genNestedEvaluations(lower::AbstractConverter &converter, converter.genEval(e); } -static bool mustEvalTeamsOutsideTarget(const lower::pft::Evaluation &eval, - mlir::omp::TargetOp targetOp) { - if (!isHostTarget(targetOp)) - return false; - - llvm::omp::Directive dir = - extractOmpDirective(eval.get()); - return llvm::omp::allTeamsSet.test(dir) && - (llvm::omp::allTargetSet.test(dir) || !evalHasSiblings(eval)); -} - -static bool mustEvalTargetSPMDOutsideTarget(const lower::pft::Evaluation &eval, - mlir::omp::TargetOp targetOp) { - if (!isHostTarget(targetOp)) - return false; - - return isTargetSPMDLoop(eval); -} - -//===----------------------------------------------------------------------===// -// HostClausesInsertionGuard -//===----------------------------------------------------------------------===// - -/// If the insertion point of the builder is located inside of an omp.target -/// region, this RAII guard moves the insertion point to just before that -/// omp.target operation and then restores the original insertion point when -/// destroyed. If not currently inserting inside an omp.target, it remains -/// unchanged. -class HostClausesInsertionGuard { -public: - HostClausesInsertionGuard(mlir::OpBuilder &builder) : builder(builder) { - targetOp = findParentTargetOp(builder); - if (targetOp) { - ip = builder.saveInsertionPoint(); - builder.setInsertionPoint(targetOp); - } - } - - ~HostClausesInsertionGuard() { - if (ip.isSet()) { - fixupExtractedHostOps(); - builder.restoreInsertionPoint(ip); - } - } - - mlir::omp::TargetOp getTargetOp() const { return targetOp; } - -private: - mlir::OpBuilder &builder; - mlir::OpBuilder::InsertPoint ip; - mlir::omp::TargetOp targetOp; - - // Finds the list of op operands that escape the target op's region; that is: - // the operands that are used outside the target op but defined inside it. - void - findEscapingOpOperands(llvm::DenseSet &escapingOperands) { - if (!targetOp) - return; - - mlir::Region *targetParentRegion = targetOp->getParentRegion(); - assert(targetParentRegion != nullptr && - "Expected omp.target op to be nested in a parent region."); - - llvm::DenseSet visitedOps; - - // Walk the parent region in pre-order to make sure we visit `targetOp` - // before its nested ops. - targetParentRegion->walk( - [&](mlir::Operation *op) { - // Once we come across `targetOp`, we interrupt the walk since we - // already visited all the ops that come before it in the region. - if (op == targetOp) - return mlir::WalkResult::interrupt(); - - for (mlir::OpOperand &operand : op->getOpOperands()) { - mlir::Operation *operandDefiningOp = operand.get().getDefiningOp(); - - if (operandDefiningOp == nullptr) - continue; - - if (visitedOps.contains(operandDefiningOp)) - continue; - - visitedOps.insert(operandDefiningOp); - auto parentTargetOp = - operandDefiningOp->getParentOfType(); - - if (parentTargetOp != targetOp) - continue; - - escapingOperands.insert(&operand); - } - - return mlir::WalkResult::advance(); - }); - } - - // For an escaping operand, clone its use-def chain (i.e. its backward slice) - // outside the target region. - // - // \return the last op in the chain (this is the op that defines the escaping - // operand). - mlir::Operation * - cloneOperandSliceOutsideTargetOp(mlir::OpOperand *escapingOperand) { - mlir::Operation *operandDefiningOp = escapingOperand->get().getDefiningOp(); - llvm::SetVector backwardSlice; - mlir::BackwardSliceOptions sliceOptions; - sliceOptions.inclusive = true; - mlir::getBackwardSlice(operandDefiningOp, &backwardSlice, sliceOptions); - - auto ip = builder.saveInsertionPoint(); - - mlir::IRMapping mapper; - builder.setInsertionPoint(escapingOperand->getOwner()); - - mlir::Operation *lastSliceOp = nullptr; - llvm::SetVector opsToClone; - - for (auto *op : backwardSlice) { - // DeclareOps need special handling by searching for the corresponding ops - // in the host. Therefore, do not clone them since this special handling - // is done later in the fix-up process. - // - // TODO this might need a more elaborate handling in the future but for - // now this seems sufficient for our purposes. - if (llvm::isa(op)) { - opsToClone.clear(); - break; - } - - opsToClone.insert(op); - } - - for (mlir::Operation *op : opsToClone) - lastSliceOp = builder.clone(*op, mapper); - - builder.restoreInsertionPoint(ip); - return lastSliceOp; - } - - /// Fixup any uses of target region block arguments that we have just created - /// outside of the target region, and replace them by their host values. - void fixupExtractedHostOps() { - llvm::DenseSet escapingOperands; - findEscapingOpOperands(escapingOperands); - - for (mlir::OpOperand *operand : escapingOperands) { - mlir::Operation *operandDefiningOp = operand->get().getDefiningOp(); - assert(operandDefiningOp != nullptr && - "Expected escaping operand to have a defining op (i.e. not to be " - "a block argument)"); - mlir::Operation *lastSliceOp = cloneOperandSliceOutsideTargetOp(operand); - - if (lastSliceOp == nullptr) - continue; - - // Find the index of the operand in the list of results produced by its - // defining op. - unsigned operandResultIdx = 0; - for (auto [idx, res] : llvm::enumerate(operandDefiningOp->getResults())) { - if (res == operand->get()) { - operandResultIdx = idx; - break; - } - } - - // Replace the escaping operand with the corresponding value from the - // op that we cloned outside the target op. - operand->getOwner()->setOperand(operand->getOperandNumber(), - lastSliceOp->getResult(operandResultIdx)); - } - - auto useOutsideTargetRegion = [](mlir::OpOperand &operand) { - if (mlir::Operation *owner = operand.getOwner()) - return !owner->getParentOfType(); - return false; - }; - - auto argIface = llvm::cast(*targetOp); - for (auto [map, arg] : - llvm::zip_equal(targetOp.getMapVars(), argIface.getMapBlockArgs())) { - mlir::Value hostVal = - map.getDefiningOp().getVarPtr(); - - // Replace instances of omp.target block arguments used outside with their - // corresponding host value. - arg.replaceUsesWithIf(hostVal, [&](mlir::OpOperand &operand) -> bool { - // If the use is an hlfir.declare, we need to search for the matching - // one within host code. - if (auto declareOp = llvm::dyn_cast_if_present( - operand.getOwner())) { - if (auto hostDeclareOp = hostVal.getDefiningOp()) { - declareOp->replaceUsesWithIf(hostDeclareOp.getResults(), - useOutsideTargetRegion); - } else if (auto hostBoxOp = hostVal.getDefiningOp()) { - declareOp->replaceUsesWithIf(hostBoxOp.getVal() - .getDefiningOp() - .getResults(), - useOutsideTargetRegion); - } - } - return useOutsideTargetRegion(operand); - }); - } - } -}; - static fir::GlobalOp globalInitialization(lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, const semantics::Symbol &sym, @@ -1052,7 +952,7 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter, llvm::SmallVector types; llvm::SmallVector locs; unsigned numVars = - args.hostEval.vars.size() + args.inReduction.vars.size() + + args.hostEval.size() + args.inReduction.vars.size() + args.map.vars.size() + args.priv.vars.size() + args.reduction.vars.size() + args.taskReduction.vars.size() + args.useDeviceAddr.vars.size() + args.useDevicePtr.vars.size(); @@ -1068,7 +968,7 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter, // Populate block arguments in clause name alphabetical order to match // expected order by the BlockArgOpenMPOpInterface. - extractTypeLoc(args.hostEval.vars); + extractTypeLoc(args.hostEval); extractTypeLoc(args.inReduction.vars); extractTypeLoc(args.map.vars); extractTypeLoc(args.priv.vars); @@ -1417,6 +1317,8 @@ static void genBodyOfTargetOp( dsp.processStep2(); bindEntryBlockArgs(converter, targetOp, args); + if (!hostEvalInfo.empty()) + hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs()); // Check if cloning the bounds introduced any dependency on the outer region. // If so, then either clone them as well if they are MemoryEffectFree, or else @@ -1591,27 +1493,23 @@ static void genLoopNestClauses(lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const List &clauses, - mlir::Location loc, bool evalOutsideTarget, - mlir::omp::LoopNestOperands &clauseOps, + mlir::Location loc, mlir::omp::LoopNestOperands &clauseOps, llvm::SmallVectorImpl &iv) { ClauseProcessor cp(converter, semaCtx, clauses); - // Evaluate loop bounds on the host device, if the operation is defining part - // of a target SPMD kernel. - if (evalOutsideTarget) { - HostClausesInsertionGuard guard(converter.getFirOpBuilder()); + if (hostEvalInfo.empty()) { cp.processCollapse(loc, eval, clauseOps, iv); - - for (unsigned i = 0; i < clauseOps.loopLowerBounds.size(); ++i) { - clauseOps.loopLowerBounds[i] = - addHostEvalVar(guard.getTargetOp(), clauseOps.loopLowerBounds[i]); - clauseOps.loopUpperBounds[i] = - addHostEvalVar(guard.getTargetOp(), clauseOps.loopUpperBounds[i]); - clauseOps.loopSteps[i] = - addHostEvalVar(guard.getTargetOp(), clauseOps.loopSteps[i]); - } } else { - cp.processCollapse(loc, eval, clauseOps, iv); + HostEvalInfo &hostInfo = hostEvalInfo.back(); + if (!hostInfo.iv.empty()) { + HostEvalInfo &hostInfo = hostEvalInfo.back(); + clauseOps.loopLowerBounds = hostInfo.ops.loopLowerBounds; + clauseOps.loopUpperBounds = hostInfo.ops.loopUpperBounds; + clauseOps.loopSteps = hostInfo.ops.loopSteps; + iv.append(hostInfo.iv); + } else { + cp.processCollapse(loc, eval, clauseOps, iv); + } } clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr(); @@ -1638,23 +1536,16 @@ genOrderedRegionClauses(lower::AbstractConverter &converter, static void genParallelClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, const List &clauses, - mlir::Location loc, bool evalOutsideTarget, - mlir::omp::ParallelOperands &clauseOps, + mlir::Location loc, mlir::omp::ParallelOperands &clauseOps, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); - // Evaluate NUM_THREADS on the host device, if the operation is defining part - // of a target SPMD kernel. - if (evalOutsideTarget) { - HostClausesInsertionGuard guard(converter.getFirOpBuilder()); - if (cp.processNumThreads(stmtCtx, clauseOps)) - clauseOps.numThreads = - addHostEvalVar(guard.getTargetOp(), clauseOps.numThreads); - } else { + if (!hostEvalInfo.empty() && hostEvalInfo.back().ops.numThreads) + clauseOps.numThreads = hostEvalInfo.back().ops.numThreads; + else cp.processNumThreads(stmtCtx, clauseOps); - } cp.processProcBind(clauseOps); cp.processReduction(loc, clauseOps, reductionSyms); @@ -1701,9 +1592,9 @@ static void genSingleClauses(lower::AbstractConverter &converter, static void genTargetClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, - lower::StatementContext &stmtCtx, const List &clauses, - mlir::Location loc, bool processHostOnlyClauses, - mlir::omp::TargetOperands &clauseOps, + lower::StatementContext &stmtCtx, lower::pft::Evaluation &eval, + const List &clauses, mlir::Location loc, + bool processHostOnlyClauses, mlir::omp::TargetOperands &clauseOps, llvm::SmallVectorImpl &hasDeviceAddrSyms, llvm::SmallVectorImpl &isDevicePtrSyms, llvm::SmallVectorImpl &mapSyms) { @@ -1711,6 +1602,15 @@ static void genTargetClauses( cp.processDepend(clauseOps); cp.processDevice(stmtCtx, clauseOps); cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms); + + // Introduce a new host_eval information structure and populate it if lowering + // for the host device. + if (processHostOnlyClauses) { + HostEvalInfo &hostInfo = hostEvalInfo.emplace_back(); + processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc); + hostInfo.collectValues(clauseOps.hostEvalVars); + } + cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); cp.processMap(loc, stmtCtx, clauseOps, &mapSyms); @@ -1816,28 +1716,30 @@ static void genTaskwaitClauses(lower::AbstractConverter &converter, static void genTeamsClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, const List &clauses, - mlir::Location loc, bool evalOutsideTarget, - mlir::omp::TeamsOperands &clauseOps, + mlir::Location loc, mlir::omp::TeamsOperands &clauseOps, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); - // Evaluate NUM_TEAMS and THREAD_LIMIT on the host device, if currently inside - // of an omp.target operation. - if (evalOutsideTarget) { - HostClausesInsertionGuard guard(converter.getFirOpBuilder()); - if (cp.processNumTeams(stmtCtx, clauseOps)) - clauseOps.numTeamsUpper = - addHostEvalVar(guard.getTargetOp(), clauseOps.numTeamsUpper); - - if (cp.processThreadLimit(stmtCtx, clauseOps)) - clauseOps.threadLimit = - addHostEvalVar(guard.getTargetOp(), clauseOps.threadLimit); - } else { + if (hostEvalInfo.empty()) { cp.processNumTeams(stmtCtx, clauseOps); cp.processThreadLimit(stmtCtx, clauseOps); + } else { + HostEvalInfo &hostInfo = hostEvalInfo.back(); + if (hostInfo.ops.numTeamsLower || hostInfo.ops.numTeamsUpper) { + clauseOps.numTeamsLower = hostInfo.ops.numTeamsLower; + clauseOps.numTeamsUpper = hostInfo.ops.numTeamsUpper; + } else { + cp.processNumTeams(stmtCtx, clauseOps); + } + + if (hostInfo.ops.threadLimit) + clauseOps.threadLimit = hostInfo.ops.threadLimit; + else + cp.processThreadLimit(stmtCtx, clauseOps); } + cp.processReduction(loc, clauseOps, reductionSyms); } @@ -1993,13 +1895,14 @@ genOrderedRegionOp(lower::AbstractConverter &converter, lower::SymMap &symTable, queue, item, clauseOps); } -static mlir::omp::ParallelOp genParallelOp( - lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - mlir::Location loc, const ConstructQueue &queue, - ConstructQueue::const_iterator item, mlir::omp::ParallelOperands &clauseOps, - const EntryBlockArgs &args, DataSharingProcessor *dsp, - bool isComposite = false, mlir::omp::TargetOp parentTarget = nullptr) { +static mlir::omp::ParallelOp +genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + lower::pft::Evaluation &eval, mlir::Location loc, + const ConstructQueue &queue, ConstructQueue::const_iterator item, + mlir::omp::ParallelOperands &clauseOps, + const EntryBlockArgs &args, DataSharingProcessor *dsp, + bool isComposite = false) { auto genRegionEntryCB = [&](mlir::Operation *op) { genEntryBlock(converter, args, op->getRegion(0)); bindEntryBlockArgs( @@ -2188,7 +2091,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, mlir::omp::TargetOperands clauseOps; llvm::SmallVector mapSyms, isDevicePtrSyms, hasDeviceAddrSyms; - genTargetClauses(converter, semaCtx, stmtCtx, item->clauses, loc, + genTargetClauses(converter, semaCtx, stmtCtx, eval, item->clauses, loc, processHostOnlyClauses, clauseOps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms); @@ -2308,7 +2211,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, extractMappedBaseValues(clauseOps.mapVars, mapBaseValues); EntryBlockArgs args; - // TODO: Fill hostEval in advance rather than adding to it later on. + args.hostEval = clauseOps.hostEvalVars; // TODO: Add in_reduction syms and vars. args.map.syms = mapSyms; args.map.vars = mapBaseValues; @@ -2317,6 +2220,12 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, args, loc, queue, item, dsp); + + // Remove the host_eval info created for this target region to avoid impacting + // the lowering of unrelated operations. + if (processHostOnlyClauses) + hostEvalInfo.pop_back(); + return targetOp; } @@ -2440,14 +2349,10 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue::const_iterator item) { lower::StatementContext stmtCtx; - mlir::omp::TargetOp targetOp = - findParentTargetOp(converter.getFirOpBuilder()); - bool evalOutsideTarget = mustEvalTeamsOutsideTarget(eval, targetOp); - mlir::omp::TeamsOperands clauseOps; llvm::SmallVector reductionSyms; - genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - evalOutsideTarget, clauseOps, reductionSyms); + genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps, + reductionSyms); EntryBlockArgs args; // TODO: Add private syms and vars. @@ -2499,7 +2404,7 @@ static void genStandaloneDistribute(lower::AbstractConverter &converter, mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, - /*evalOutsideTarget=*/false, loopNestClauseOps, iv); + loopNestClauseOps, iv); EntryBlockArgs distributeArgs; distributeArgs.priv.syms = dsp.getDelayedPrivSymbols(); @@ -2534,7 +2439,7 @@ static void genStandaloneDo(lower::AbstractConverter &converter, mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, - /*evalOutsideTarget=*/false, loopNestClauseOps, iv); + loopNestClauseOps, iv); EntryBlockArgs wsloopArgs; // TODO: Add private syms and vars. @@ -2560,8 +2465,7 @@ static void genStandaloneParallel(lower::AbstractConverter &converter, mlir::omp::ParallelOperands parallelClauseOps; llvm::SmallVector parallelReductionSyms; genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - /*evalOutsideTarget=*/false, parallelClauseOps, - parallelReductionSyms); + parallelClauseOps, parallelReductionSyms); std::optional dsp; if (enableDelayedPrivatization) { @@ -2605,7 +2509,7 @@ static void genStandaloneSimd(lower::AbstractConverter &converter, mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, - /*evalOutsideTarget=*/false, loopNestClauseOps, iv); + loopNestClauseOps, iv); EntryBlockArgs simdArgs; // TODO: Add private syms and vars. @@ -2645,16 +2549,11 @@ static void genCompositeDistributeParallelDo( ConstructQueue::const_iterator parallelItem = std::next(distributeItem); ConstructQueue::const_iterator doItem = std::next(parallelItem); - mlir::omp::TargetOp targetOp = - findParentTargetOp(converter.getFirOpBuilder()); - bool evalOutsideTarget = mustEvalTargetSPMDOutsideTarget(eval, targetOp); - // Create parent omp.parallel first. mlir::omp::ParallelOperands parallelClauseOps; llvm::SmallVector parallelReductionSyms; genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc, - evalOutsideTarget, parallelClauseOps, - parallelReductionSyms); + parallelClauseOps, parallelReductionSyms); DataSharingProcessor dsp(converter, semaCtx, doItem->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/true, @@ -2668,8 +2567,7 @@ static void genCompositeDistributeParallelDo( parallelArgs.reduction.syms = parallelReductionSyms; parallelArgs.reduction.vars = parallelClauseOps.reductionVars; genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem, - parallelClauseOps, parallelArgs, &dsp, - /*isComposite=*/true, evalOutsideTarget ? targetOp : nullptr); + parallelClauseOps, parallelArgs, &dsp, /*isComposite=*/true); // Clause processing. mlir::omp::DistributeOperands distributeClauseOps; @@ -2684,7 +2582,7 @@ static void genCompositeDistributeParallelDo( mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, doItem->clauses, loc, - evalOutsideTarget, loopNestClauseOps, iv); + loopNestClauseOps, iv); // Operation creation. EntryBlockArgs distributeArgs; @@ -2720,16 +2618,11 @@ static void genCompositeDistributeParallelDoSimd( ConstructQueue::const_iterator doItem = std::next(parallelItem); ConstructQueue::const_iterator simdItem = std::next(doItem); - mlir::omp::TargetOp targetOp = - findParentTargetOp(converter.getFirOpBuilder()); - bool evalOutsideTarget = mustEvalTargetSPMDOutsideTarget(eval, targetOp); - // Create parent omp.parallel first. mlir::omp::ParallelOperands parallelClauseOps; llvm::SmallVector parallelReductionSyms; genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc, - evalOutsideTarget, parallelClauseOps, - parallelReductionSyms); + parallelClauseOps, parallelReductionSyms); DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/true, @@ -2743,8 +2636,7 @@ static void genCompositeDistributeParallelDoSimd( parallelArgs.reduction.syms = parallelReductionSyms; parallelArgs.reduction.vars = parallelClauseOps.reductionVars; genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem, - parallelClauseOps, parallelArgs, &dsp, - /*isComposite=*/true, evalOutsideTarget ? targetOp : nullptr); + parallelClauseOps, parallelArgs, &dsp, /*isComposite=*/true); // Clause processing. mlir::omp::DistributeOperands distributeClauseOps; @@ -2764,7 +2656,7 @@ static void genCompositeDistributeParallelDoSimd( mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc, - evalOutsideTarget, loopNestClauseOps, iv); + loopNestClauseOps, iv); // Operation creation. EntryBlockArgs distributeArgs; @@ -2832,7 +2724,7 @@ static void genCompositeDistributeSimd(lower::AbstractConverter &converter, mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc, - /*evalOutsideTarget=*/false, loopNestClauseOps, iv); + loopNestClauseOps, iv); // Operation creation. EntryBlockArgs distributeArgs; @@ -2890,7 +2782,7 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter, mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc, - /*evalOutsideTarget=*/false, loopNestClauseOps, iv); + loopNestClauseOps, iv); // Operation creation. EntryBlockArgs wsloopArgs; diff --git a/flang/test/Lower/OpenMP/eval-outside-target.f90 b/flang/test/Lower/OpenMP/eval-outside-target.f90 index d0925971e4b2bca..32c52462b86a760 100644 --- a/flang/test/Lower/OpenMP/eval-outside-target.f90 +++ b/flang/test/Lower/OpenMP/eval-outside-target.f90 @@ -33,7 +33,7 @@ end subroutine teams subroutine distribute_parallel_do() ! BOTH: omp.target - ! HOST-SAME: host_eval(%{{.*}} -> %[[NUM_THREADS:.*]], %{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]] : i32, i32, i32, i32) + ! HOST-SAME: host_eval(%{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]], %{{.*}} -> %[[NUM_THREADS:.*]] : i32, i32, i32, i32) ! DEVICE-NOT: host_eval({{.*}}) ! DEVICE-SAME: { @@ -95,7 +95,7 @@ end subroutine distribute_parallel_do subroutine distribute_parallel_do_simd() ! BOTH: omp.target - ! HOST-SAME: host_eval(%{{.*}} -> %[[NUM_THREADS:.*]], %{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]] : i32, i32, i32, i32) + ! HOST-SAME: host_eval(%{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]], %{{.*}} -> %[[NUM_THREADS:.*]] : i32, i32, i32, i32) ! DEVICE-NOT: host_eval({{.*}}) ! DEVICE-SAME: { diff --git a/flang/test/Lower/OpenMP/target-do-loop-control-exprs.f90 b/flang/test/Lower/OpenMP/target-do-loop-control-exprs.f90 deleted file mode 100644 index b4d5cffffac1d69..000000000000000 --- a/flang/test/Lower/OpenMP/target-do-loop-control-exprs.f90 +++ /dev/null @@ -1,95 +0,0 @@ -! Verifies that if expressions are used to compute a target parallel loop, that -! no values escape the target region when flang emits the ops corresponding to -! these expressions (for example the compute the trip count for the target region). - -! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s - -subroutine foo(upper_bound) - implicit none - integer :: upper_bound - integer :: nodes(1 : upper_bound) - integer :: i - - !$omp target teams distribute parallel do - do i = 1, ubound(nodes,1) - nodes(i) = i - end do - !$omp end target teams distribute parallel do -end subroutine - -! CHECK: func.func @_QPfoo(%[[FUNC_ARG:.*]]: !fir.ref {fir.bindc_name = "upper_bound"}) { -! CHECK: %[[UB_ALLOC:.*]] = fir.alloca i32 -! CHECK: fir.dummy_scope : !fir.dscope -! CHECK: %[[UB_DECL:.*]]:2 = hlfir.declare %[[FUNC_ARG]] {{.*}} {uniq_name = "_QFfooEupper_bound"} - -! CHECK: omp.map.info -! CHECK: omp.map.info -! CHECK: omp.map.info - -! Verify that we load from the original/host allocation of the `upper_bound` -! variable rather than the corresponding target region arg. - -! CHECK: fir.load %[[UB_ALLOC]] : !fir.ref -! CHECK: omp.target - -! CHECK: } - -subroutine foo_with_dummy_arg(nodes) - implicit none - integer, intent(inout) :: nodes( : ) - integer :: i - - !$omp target teams distribute parallel do - do i = 1, ubound(nodes, 1) - nodes(i) = i - end do - !$omp end target teams distribute parallel do -end subroutine - -! CHECK: func.func @_QPfoo_with_dummy_arg(%[[FUNC_ARG:.*]]: !fir.box> {fir.bindc_name = "nodes"}) { - -! CHECK: %[[ARR_DECL:.*]]:2 = hlfir.declare %[[FUNC_ARG]] dummy_scope - -! CHECK: omp.map.info -! CHECK: omp.map.info -! CHECK: omp.map.info - -! Verify that we get the box dims of the host array declaration not the target -! one. - -! CHECK: fir.box_dims %[[ARR_DECL]] - -! CHECK: omp.target - -! CHECK: } - - -subroutine bounds_expr_in_loop_control(array) - real, intent(out) :: array(:,:) - integer :: bounds(2), i, j - bounds = shape(array) - - !$omp target teams distribute parallel do simd collapse(2) - do j = 1,bounds(2) - do i = 1,bounds(1) - array(i,j) = 0. - enddo - enddo -end subroutine bounds_expr_in_loop_control - - -! CHECK: func.func @_QPbounds_expr_in_loop_control(%[[FUNC_ARG:.*]]: {{.*}}) { - -! CHECK: %[[BOUNDS_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "{{.*}}Ebounds"} : (!fir.ref>, !fir.shape<1>) -> ({{.*}}) - -! Verify that the host declaration of `bounds` (i.e. not the target/mapped one) -! is used for the trip count calculation. Trip count is calculation ops are emitted -! directly before the `omp.target` op and after all `omp.map.info` op; hence the -! `CHECK-NOT: ...` line. - -! CHECK: hlfir.designate %[[BOUNDS_DECL:.*]]#0 (%c2{{.*}}) -! CHECK: hlfir.designate %[[BOUNDS_DECL:.*]]#0 (%c1{{.*}}) -! CHECK-NOT: omp.map.info -! CHECK: omp.target - -! CHECK: } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h index 1247a871f93c6dc..881e6f16299c024 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h @@ -41,6 +41,12 @@ struct DeviceTypeClauseOps { // Extra operation operand structures. //===----------------------------------------------------------------------===// +/// Clauses that correspond to operations other than omp.target, but might have +/// to be evaluated outside of a parent target region in some cases. +using HostEvaluatedOperands = + detail::Clauses; + // TODO: Add `indirect` clause. using DeclareTargetOperands = detail::Clauses;