Skip to content

Commit

Permalink
[Flang] Extracting internal constants from scalar literals (llvm#73829)
Browse files Browse the repository at this point in the history
Constants actual arguments in function/subroutine calls are currently
lowered as allocas + store. This can sometimes inhibit LTO and the
constant will not be propagated to the called function. Particularly in
cases where the function/subroutine call happens inside a condition.

This patch changes the lowering of these constant actual arguments to a
global constant + fir.address_of_op. This lowering makes it easier for
LTO to propagate the constant.

The optimization must be enabled explicitly to run. Use -mmlir
--enable-constant-argument-globalisation to enable.

---------

Co-authored-by: Dmitriy Smirnov <[email protected]>
  • Loading branch information
Leporacanthicus and d-smirnov authored Jun 25, 2024
1 parent c6973ad commit de528ff
Show file tree
Hide file tree
Showing 9 changed files with 375 additions and 4 deletions.
2 changes: 2 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ namespace fir {
#define GEN_PASS_DECL_OMPFUNCTIONFILTERING
#define GEN_PASS_DECL_VSCALEATTR
#define GEN_PASS_DECL_FUNCTIONATTR
#define GEN_PASS_DECL_CONSTANTARGUMENTGLOBALISATIONOPT

#include "flang/Optimizer/Transforms/Passes.h.inc"

std::unique_ptr<mlir::Pass> createAffineDemotionPass();
Expand Down
9 changes: 9 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,15 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
];
}

// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
def ConstantArgumentGlobalisationOpt : Pass<"constant-argument-globalisation-opt", "mlir::ModuleOp"> {
let summary = "Convert constant function arguments to global constants.";
let description = [{
Convert scalar literals of function arguments to global constants.
}];
let dependentDialects = [ "fir::FIROpsDialect" ];
}

def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
let summary = "Move local array allocations from heap memory into stack memory";
let description = [{
Expand Down
8 changes: 8 additions & 0 deletions flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
llvm::cl::desc("disable " DODescription " pass"), llvm::cl::init(false), \
llvm::cl::Hidden)
#define EnableOption(EOName, EOOption, EODescription) \
static llvm::cl::opt<bool> enable##EOName("enable-" EOOption, \
llvm::cl::desc("enable " EODescription " pass"), llvm::cl::init(false), \
llvm::cl::Hidden)

/// Shared option in tools to control whether dynamically sized array
/// allocations should always be on the heap.
Expand Down Expand Up @@ -86,6 +90,8 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",

DisableOption(ExternalNameConversion, "external-name-interop",
"convert names with external convention");
EnableOption(ConstantArgumentGlobalisation, "constant-argument-globalisation",
"the local constant argument to global constant conversion");

using PassConstructor = std::unique_ptr<mlir::Pass>();

Expand Down Expand Up @@ -270,6 +276,8 @@ inline void createDefaultFIROptimizerPassPipeline(
// These passes may increase code size.
pm.addPass(fir::createSimplifyIntrinsics());
pm.addPass(fir::createAlgebraicSimplificationPass(config));
if (enableConstantArgumentGlobalisation)
pm.addPass(fir::createConstantArgumentGlobalisationOpt());
}

if (pc.LoopVersioning)
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_flang_library(FIRTransforms
AnnotateConstant.cpp
AssumedRankOpConversion.cpp
CharacterConversion.cpp
ConstantArgumentGlobalisation.cpp
ControlFlowConverter.cpp
ArrayValueCopy.cpp
ExternalNameConversion.cpp
Expand Down
185 changes: 185 additions & 0 deletions flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
//===- ConstantArgumentGlobalisation.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace fir {
#define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "flang-constant-argument-globalisation-opt"

namespace {
unsigned uniqueLitId = 1;

class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
protected:
const mlir::DominanceInfo &di;

public:
using OpRewritePattern::OpRewritePattern;

CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
: OpRewritePattern(ctx), di(_di) {}

mlir::LogicalResult
matchAndRewrite(fir::CallOp callOp,
mlir::PatternRewriter &rewriter) const override {
LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
auto module = callOp->getParentOfType<mlir::ModuleOp>();
bool needUpdate = false;
fir::FirOpBuilder builder(rewriter, module);
llvm::SmallVector<mlir::Value> newOperands;
llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas;
for (const mlir::Value &a : callOp.getArgs()) {
auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
// We can convert arguments that are alloca, and that has
// the value by reference attribute. All else is just added
// to the argument list.
if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
newOperands.push_back(a);
continue;
}

mlir::Type varTy = alloca.getInType();
assert(!fir::hasDynamicSize(varTy) &&
"only expect statically sized scalars to be by value");

// Find immediate store with const argument
mlir::Operation *store = nullptr;
for (mlir::Operation *s : alloca->getUsers()) {
if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
// We can only deal with ONE store - if already found one,
// set to nullptr and exit the loop.
if (store) {
store = nullptr;
break;
}
store = s;
}
}

// If we didn't find any store, or multiple stores, add argument as is
// and move on.
if (!store) {
newOperands.push_back(a);
continue;
}

LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");

mlir::Operation *definingOp = store->getOperand(0).getDefiningOp();
// If not a constant, add to operands and move on.
if (!mlir::isa<mlir::arith::ConstantOp>(definingOp)) {
// Unable to remove alloca arg
newOperands.push_back(a);
continue;
}

LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n");

std::string globalName =
"_global_const_." + std::to_string(uniqueLitId++);
assert(!builder.getNamedGlobal(globalName) &&
"We should have a unique name here");

if (std::find_if(allocas.begin(), allocas.end(), [alloca](auto x) {
return x.first == alloca;
}) == allocas.end()) {
allocas.push_back(std::make_pair(alloca, store));
}

auto loc = callOp.getLoc();
fir::GlobalOp global = builder.createGlobalConstant(
loc, varTy, globalName,
[&](fir::FirOpBuilder &builder) {
mlir::Operation *cln = definingOp->clone();
builder.insert(cln);
mlir::Value val =
builder.createConvert(loc, varTy, cln->getResult(0));
builder.create<fir::HasValueOp>(loc, val);
},
builder.createInternalLinkage());
mlir::Value addr = builder.create<fir::AddrOfOp>(loc, global.resultType(),
global.getSymbol());
newOperands.push_back(addr);
needUpdate = true;
}

if (needUpdate) {
auto loc = callOp.getLoc();
llvm::SmallVector<mlir::Type> newResultTypes;
newResultTypes.append(callOp.getResultTypes().begin(),
callOp.getResultTypes().end());
fir::CallOp newOp = builder.create<fir::CallOp>(
loc, newResultTypes,
callOp.getCallee().has_value() ? callOp.getCallee().value()
: mlir::SymbolRefAttr{},
newOperands);
// Copy all the attributes from the old to new op.
newOp->setAttrs(callOp->getAttrs());
rewriter.replaceOp(callOp, newOp);

for (auto a : allocas) {
if (a.first->hasOneUse()) {
// If the alloca is only used for a store and the call operand, the
// store is no longer required.
rewriter.eraseOp(a.second);
rewriter.eraseOp(a.first);
}
}
LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as "
<< newOp << '\n');
return mlir::success();
}

// Failure here just means "we couldn't do the conversion", which is
// perfectly acceptable to the upper layers of this function.
return mlir::failure();
}
};

// this pass attempts to convert immediate scalar literals in function calls
// to global constants to allow transformations such as Dead Argument
// Elimination
class ConstantArgumentGlobalisationOpt
: public fir::impl::ConstantArgumentGlobalisationOptBase<
ConstantArgumentGlobalisationOpt> {
public:
ConstantArgumentGlobalisationOpt() = default;

void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;

patterns.insert<CallOpRewriter>(context, *di);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
mod, std::move(patterns), config))) {
mlir::emitError(mod.getLoc(),
"error in constant globalisation optimization\n");
signalPassFailure();
}
}
};
} // namespace
4 changes: 1 addition & 3 deletions flang/test/Fir/boxproc.fir
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@

// CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
// CHECK-SAME: %[[VAL_0:.*]])
// CHECK: %[[VAL_1:.*]] = alloca i32, i64 1, align 4
// CHECK: store i32 4, ptr %[[VAL_1]], align 4
// CHECK: call void %[[VAL_0]](ptr %[[VAL_1]])
// CHECK: call void %[[VAL_0]](ptr %{{.*}})

func.func @_QPtest_proc_dummy() {
%c0_i32 = arith.constant 0 : i32
Expand Down
5 changes: 4 additions & 1 deletion flang/test/Lower/character-local-variables.f90
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
! RUN: bbc -hlfir=false %s -o - | FileCheck %s
! RUN: bbc -hlfir=false --enable-constant-argument-globalisation %s -o - \
! RUN: | FileCheck %s --check-prefix=CHECK-CONST

! Test lowering of local character variables

Expand Down Expand Up @@ -118,7 +120,8 @@ subroutine assumed_length_param(n)
integer :: n
! CHECK: %[[c4:.*]] = arith.constant 4 : i64
! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
! CHECK-CONST: %[[tmp:.*]] = fir.address_of(@_global_const_.{{.*}}) : !fir.ref<i64>
! CHECK-CONST: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
call take_int(len(c(n), kind=8))
end

Expand Down
98 changes: 98 additions & 0 deletions flang/test/Transforms/constant-argument-globalisation-2.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// RUN: fir-opt --split-input-file --constant-argument-globalisation-opt < %s | FileCheck %s

module {
// Test for "two conditional writes to the same alloca doesn't get replaced."
func.func @func(%arg0: i32, %arg1: i1) {
%c2_i32 = arith.constant 2 : i32
%addr = fir.alloca i32 {adapt.valuebyref}
fir.if %arg1 {
fir.store %c2_i32 to %addr : !fir.ref<i32>
} else {
fir.store %arg0 to %addr : !fir.ref<i32>
}
fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
return
}
func.func private @sub2(!fir.ref<i32>)

// CHECK-LABEL: func.func @func
// CHECK-SAME: [[ARG0:%.*]]: i32
// CHECK-SAME: [[ARG1:%.*]]: i1)
// CHECK: [[CONST:%.*]] = arith.constant
// CHECK: [[ADDR:%.*]] = fir.alloca i32
// CHECK: fir.if [[ARG1]]
// CHECK: fir.store [[CONST]] to [[ADDR]]
// CHECK: } else {
// CHECK: fir.store [[ARG0]] to [[ADDR]]
// CHECK: fir.call @sub2([[ADDR]])
// CHECK: return

}

// -----

module {
// Test for "two writes to the same alloca doesn't get replaced."
func.func @func() {
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%addr = fir.alloca i32 {adapt.valuebyref}
fir.store %c1_i32 to %addr : !fir.ref<i32>
fir.store %c2_i32 to %addr : !fir.ref<i32>
fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
return
}
func.func private @sub2(!fir.ref<i32>)

// CHECK-LABEL: func.func @func
// CHECK: [[CONST1:%.*]] = arith.constant
// CHECK: [[CONST2:%.*]] = arith.constant
// CHECK: [[ADDR:%.*]] = fir.alloca i32
// CHECK: fir.store [[CONST1]] to [[ADDR]]
// CHECK: fir.store [[CONST2]] to [[ADDR]]
// CHECK: fir.call @sub2([[ADDR]])
// CHECK: return

}

// -----

module {
// Test for "one write to the the alloca gets replaced."
func.func @func() {
%c1_i32 = arith.constant 1 : i32
%addr = fir.alloca i32 {adapt.valuebyref}
fir.store %c1_i32 to %addr : !fir.ref<i32>
fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
return
}
func.func private @sub2(!fir.ref<i32>)

// CHECK-LABEL: func.func @func
// CHECK: [[ADDR:%.*]] = fir.address_of([[EXTR:@.*]]) : !fir.ref<i32>
// CHECK: fir.call @sub2([[ADDR]])
// CHECK: return
// CHECK: fir.global internal [[EXTR]] constant : i32 {
// CHECK: %{{.*}} = arith.constant 1 : i32
// CHECK: fir.has_value %{{.*}} : i32
// CHECK: }

}

// -----
// Check that same argument used twice is converted.
module {
func.func @func(%arg0: !fir.ref<i32>, %arg1: i1) {
%c2_i32 = arith.constant 2 : i32
%addr1 = fir.alloca i32 {adapt.valuebyref}
fir.store %c2_i32 to %addr1 : !fir.ref<i32>
fir.call @sub1(%addr1, %addr1) : (!fir.ref<i32>, !fir.ref<i32>) -> ()
return
}
}

// CHECK-LABEL: func.func @func
// CHECK-NEXT: %[[ARG1:.*]] = fir.address_of([[CONST1:@.*]]) : !fir.ref<i32>
// CHECK-NEXT: %[[ARG2:.*]] = fir.address_of([[CONST2:@.*]]) : !fir.ref<i32>
// CHECK-NEXT: fir.call @sub1(%[[ARG1]], %[[ARG2]])
// CHECK-NEXT: return
Loading

0 comments on commit de528ff

Please sign in to comment.