Skip to content

Commit

Permalink
[CIR][ThroughMLIR] Support lowering SwitchOp without fallthrough to scf
Browse files Browse the repository at this point in the history
  • Loading branch information
Mochthon committed Oct 17, 2024
1 parent 8311717 commit 198ce40
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 16 deletions.
107 changes: 91 additions & 16 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,24 @@ class CIRYieldOpLowering
}
};

class CIRBreakOpLowering
: public mlir::OpConversionPattern<mlir::cir::BreakOp> {
public:
using OpConversionPattern<mlir::cir::BreakOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(mlir::cir::BreakOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto *parentOp = op->getParentOp();
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
.Case<mlir::scf::IndexSwitchOp>([&](auto) {
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(
op, adaptor.getOperands());
return mlir::success();
})
.Default([](auto) { return mlir::failure(); });
}
};

class CIRIfOpLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
public:
using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
Expand All @@ -909,6 +927,62 @@ class CIRIfOpLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
}
};

class CIRSwitchOpLowering
: public mlir::OpConversionPattern<mlir::cir::SwitchOp> {
public:
using mlir::OpConversionPattern<mlir::cir::SwitchOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::SwitchOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<mlir::Type> resultTypes;
if (mlir::failed(getTypeConverter()->convertTypes(op->getResultTypes(),
resultTypes)))
return mlir::failure();

auto caseValue = rewriter.create<mlir::arith::IndexCastOp>(
adaptor.getCondition().getLoc(), rewriter.getIndexType(),
adaptor.getCondition());

llvm::SmallVector<int64_t, 3> cases;
auto caseAttrList = op.getCasesAttr();
for (auto &caseAttr : caseAttrList) {
mlir::Attribute caseAttrValue;
caseAttr.walkImmediateSubElements(
[&caseAttrValue](mlir::Attribute subAttr) {
if (!caseAttrValue)
caseAttrValue = subAttr;
},
[](mlir::Type type) {});

mlir::cir::IntAttr cirIntAttr;
caseAttrValue.walkImmediateSubElements(
[&cirIntAttr](mlir::Attribute subAttr) {
if (!cirIntAttr)
cirIntAttr = mlir::dyn_cast_or_null<mlir::cir::IntAttr>(subAttr);
},
[](mlir::Type type) {});

if (cirIntAttr != nullptr)
cases.push_back(cirIntAttr.getSInt());
}

auto casesRegionCount = cases.size();

auto indexSwitchOp = rewriter.create<mlir::scf::IndexSwitchOp>(
op.getLoc(), mlir::TypeRange(resultTypes), caseValue, cases,
casesRegionCount);

for (unsigned int i = 0; i < op.getNumRegions(); i++) {
rewriter.inlineRegionBefore(op->getRegion(i), indexSwitchOp.getRegion(i),
indexSwitchOp.getRegion(i).end());
}

rewriter.replaceOp(op, indexSwitchOp);
return mlir::success();
}
};

class CIRGlobalOpLowering
: public mlir::OpConversionPattern<mlir::cir::GlobalOp> {
public:
Expand Down Expand Up @@ -1316,22 +1390,23 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());

patterns.add<
CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
CIRSinOpLowering, CIRShiftOpLowering, CIRBitClzOpLowering,
CIRBitCtzOpLowering, CIRBitPopcountOpLowering, CIRBitClrsbOpLowering,
CIRBitFfsOpLowering, CIRBitParityOpLowering, CIRIfOpLowering,
CIRVectorCreateLowering, CIRVectorInsertLowering,
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
patterns.getContext());
patterns.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRBreakOpLowering, CIRCosOpLowering,
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
CIRPtrStrideOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
CIRSinOpLowering, CIRShiftOpLowering, CIRBitClzOpLowering,
CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
CIRBitParityOpLowering, CIRIfOpLowering, CIRSwitchOpLowering,
CIRVectorCreateLowering, CIRVectorInsertLowering,
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(
converter, patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand Down
56 changes: 56 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/switch.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

int switch_test(int cond) {

// CHECK: %alloca = memref.alloca() {alignment = 4 : i64} : memref<i32>
// CHECK: %alloca_0 = memref.alloca() {alignment = 4 : i64} : memref<i32>
// CHECK: %alloca_1 = memref.alloca() {alignment = 4 : i64} : memref<i32>

// CHECK: memref.store %arg0, %alloca[] : memref<i32>

int ret;

// CHECK: memref.alloca_scope {

// CHECK: %2 = memref.load %alloca[] : memref<i32>
// CHECK: %3 = arith.index_cast %2 : i32 to index

switch (cond) {

// CHECK: scf.index_switch %3

case 0: ret = 10; break;

// CHECK: case 0 {
// CHECK: %c100_i32 = arith.constant 100 : i32
// CHECK: memref.store %c100_i32, %alloca_1[] : memref<i32>
// CHECK: scf.yield
// CHECK: }

case 1: ret = 100; break;

// CHECK: case 1 {
// CHECK: %c1000_i32 = arith.constant 1000 : i32
// CHECK: memref.store %c1000_i32, %alloca_1[] : memref<i32>
// CHECK: scf.yield
// CHECK: }

default: ret = 1000; break;

// CHECK: default {
// CHECK: %c10_i32 = arith.constant 10 : i32
// CHECK: memref.store %c10_i32, %alloca_1[] : memref<i32>
// CHECK: }

}

return ret;

// CHECK: }

// CHECK: %0 = memref.load %alloca_1[] : memref<i32>
// CHECK: memref.store %0, %alloca_0[] : memref<i32>
// CHECK: %1 = memref.load %alloca_0[] : memref<i32>
// CHECK: return %1 : i32
}

0 comments on commit 198ce40

Please sign in to comment.