Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC][CIR] Lower cir.bool to i1 #1158

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,8 @@ class CIRConditionOpLowering
auto *parentOp = op->getParentOp();
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
.Case<mlir::scf::WhileOp>([&](auto) {
auto condition = adaptor.getCondition();
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
op.getLoc(), rewriter.getI1Type(), condition);
rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(
op, i1Condition, parentOp->getOperands());
op, adaptor.getCondition(), parentOp->getOperands());
return mlir::success();
})
.Default([](auto) { return mlir::failure(); });
Expand Down
114 changes: 71 additions & 43 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
Expand Down Expand Up @@ -105,15 +106,55 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
}
};

static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
mlir::Type type) {
// TODO(cir): Handle other types similarly to clang's codegen
// convertTypeForMemory
if (isa<cir::BoolType>(type)) {
// TODO: Use datalayout to get the size of bool
return mlir::IntegerType::get(type.getContext(), 8);
}

return converter.convertType(type);
}

static mlir::Value emitFromMemory(mlir::ConversionPatternRewriter &rewriter,
cir::LoadOp op, mlir::Value value) {

// TODO(cir): Handle other types similarly to clang's codegen EmitFromMemory
if (isa<cir::BoolType>(op.getResult().getType())) {
// Create trunc of value from i8 to i1
// TODO: Use datalayout to get the size of bool
assert(value.getType().isInteger(8));
return createIntCast(rewriter, value, rewriter.getI1Type());
}

return value;
}

static mlir::Value emitToMemory(mlir::ConversionPatternRewriter &rewriter,
cir::StoreOp op, mlir::Value value) {

// TODO(cir): Handle other types similarly to clang's codegen EmitToMemory
if (isa<cir::BoolType>(op.getValue().getType())) {
// Create zext of value from i1 to i8
// TODO: Use datalayout to get the size of bool
return createIntCast(rewriter, value, rewriter.getI8Type());
}

return value;
}

class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
public:
using OpConversionPattern<cir::AllocaOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::AllocaOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto type = adaptor.getAllocaType();
auto mlirType = getTypeConverter()->convertType(type);

mlir::Type mlirType =
convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType());

// FIXME: Some types can not be converted yet (e.g. struct)
if (!mlirType)
Expand Down Expand Up @@ -174,12 +215,20 @@ class CIRLoadOpLowering : public mlir::OpConversionPattern<cir::LoadOp> {
mlir::Value base;
SmallVector<mlir::Value> indices;
SmallVector<mlir::Operation *> eraseList;
mlir::memref::LoadOp newLoad;
if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
rewriter)) {
rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices);
newLoad =
rewriter.create<mlir::memref::LoadOp>(op.getLoc(), base, indices);
// rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices);
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
} else
rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, adaptor.getAddr());
newLoad =
rewriter.create<mlir::memref::LoadOp>(op.getLoc(), adaptor.getAddr());

// Convert adapted result to its original type if needed.
mlir::Value result = emitFromMemory(rewriter, op, newLoad.getResult());
rewriter.replaceOp(op, result);
return mlir::LogicalResult::success();
}
};
Expand All @@ -194,13 +243,16 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
mlir::Value base;
SmallVector<mlir::Value> indices;
SmallVector<mlir::Operation *> eraseList;

// Convert adapted value to its memory type if needed.
mlir::Value value = emitToMemory(rewriter, op, adaptor.getValue());
if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
rewriter)) {
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, adaptor.getValue(),
base, indices);
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, value, base,
indices);
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
} else
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, adaptor.getValue(),
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, value,
adaptor.getAddr());
return mlir::LogicalResult::success();
}
Expand Down Expand Up @@ -741,29 +793,20 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
auto type = op.getLhs().getType();

mlir::Value mlirResult;

if (auto ty = mlir::dyn_cast<cir::IntType>(type)) {
auto kind = convertCmpKindToCmpIPredicate(op.getKind(), ty.isSigned());
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOpWithNewOp<mlir::arith::CmpIOp>(
op, kind, adaptor.getLhs(), adaptor.getRhs());
} else if (auto ty = mlir::dyn_cast<cir::CIRFPTypeInterface>(type)) {
auto kind = convertCmpKindToCmpFPredicate(op.getKind());
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOpWithNewOp<mlir::arith::CmpFOp>(
op, kind, adaptor.getLhs(), adaptor.getRhs());
} else if (auto ty = mlir::dyn_cast<cir::PointerType>(type)) {
llvm_unreachable("pointer comparison not supported yet");
} else {
return op.emitError() << "unsupported type for CmpOp: " << type;
}

// MLIR comparison ops return i1, but cir::CmpOp returns the same type as
// the LHS value. Since this return value can be used later, we need to
// restore the type with the extension below.
auto mlirResultTy = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, mlirResultTy,
mlirResult);

return mlir::LogicalResult::success();
}
};
Expand Down Expand Up @@ -823,12 +866,8 @@ struct CIRBrCondOpLowering : public mlir::OpConversionPattern<cir::BrCondOp> {
mlir::LogicalResult
matchAndRewrite(cir::BrCondOp brOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

auto condition = adaptor.getCond();
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
brOp.getLoc(), rewriter.getI1Type(), condition);
rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
brOp, i1Condition.getResult(), brOp.getDestTrue(),
brOp, adaptor.getCond(), brOp.getDestTrue(),
adaptor.getDestOperandsTrue(), brOp.getDestFalse(),
adaptor.getDestOperandsFalse());

Expand All @@ -844,16 +883,13 @@ class CIRTernaryOpLowering : public mlir::OpConversionPattern<cir::TernaryOp> {
matchAndRewrite(cir::TernaryOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.setInsertionPoint(op);
auto condition = adaptor.getCond();
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
op.getLoc(), rewriter.getI1Type(), condition);
SmallVector<mlir::Type> resultTypes;
if (mlir::failed(getTypeConverter()->convertTypes(op->getResultTypes(),
resultTypes)))
return mlir::failure();

auto ifOp = rewriter.create<mlir::scf::IfOp>(op.getLoc(), resultTypes,
i1Condition.getResult(), true);
adaptor.getCond(), true);
auto *thenBlock = &ifOp.getThenRegion().front();
auto *elseBlock = &ifOp.getElseRegion().front();
rewriter.inlineBlockBefore(&op.getTrueRegion().front(), thenBlock,
Expand Down Expand Up @@ -890,11 +926,8 @@ class CIRIfOpLowering : public mlir::OpConversionPattern<cir::IfOp> {
mlir::LogicalResult
matchAndRewrite(cir::IfOp ifop, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto condition = adaptor.getCondition();
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
ifop->getLoc(), rewriter.getI1Type(), condition);
auto newIfOp = rewriter.create<mlir::scf::IfOp>(
ifop->getLoc(), ifop->getResultTypes(), i1Condition);
ifop->getLoc(), ifop->getResultTypes(), adaptor.getCondition());
auto *thenBlock = rewriter.createBlock(&newIfOp.getThenRegion());
rewriter.inlineBlockBefore(&ifop.getThenRegion().front(), thenBlock,
thenBlock->end());
Expand All @@ -921,7 +954,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
mlir::OpBuilder b(moduleOp.getContext());

const auto CIRSymType = op.getSymType();
auto convertedType = getTypeConverter()->convertType(CIRSymType);
auto convertedType = convertTypeForMemory(*getTypeConverter(), CIRSymType);
if (!convertedType)
return mlir::failure();
auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
Expand Down Expand Up @@ -1167,19 +1200,14 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
return mlir::success();
}
case CIR::float_to_bool: {
auto dstTy = mlir::cast<cir::BoolType>(op.getType());
auto newDstType = convertTy(dstTy);
auto kind = mlir::arith::CmpFPredicate::UNE;

// Check if float is not equal to zero.
auto zeroFloat = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), src.getType(), mlir::FloatAttr::get(src.getType(), 0.0));

// Extend comparison result to either bool (C++) or int (C).
mlir::Value cmpResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), kind, src, zeroFloat);
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, newDstType,
cmpResult);
rewriter.replaceOpWithNewOp<mlir::arith::CmpFOp>(op, kind, src,
zeroFloat);
return mlir::success();
}
case CIR::bool_to_int: {
Expand Down Expand Up @@ -1327,7 +1355,7 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
static mlir::TypeConverter prepareTypeConverter() {
mlir::TypeConverter converter;
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
auto ty = converter.convertType(type.getPointee());
auto ty = convertTypeForMemory(converter, type.getPointee());
// FIXME: The pointee type might not be converted (e.g. struct)
if (!ty)
return nullptr;
Expand All @@ -1347,7 +1375,7 @@ static mlir::TypeConverter prepareTypeConverter() {
mlir::IntegerType::SignednessSemantics::Signless);
});
converter.addConversion([&](cir::BoolType type) -> mlir::Type {
return mlir::IntegerType::get(type.getContext(), 8);
return mlir::IntegerType::get(type.getContext(), 1);
});
converter.addConversion([&](cir::SingleType type) -> mlir::Type {
return mlir::FloatType::getF32(type.getContext());
Expand Down
4 changes: 4 additions & 0 deletions clang/test/CIR/CodeGen/globals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ void use_global() {
int li = a;
}

bool bool_global() {
return e;
}

void use_global_string() {
unsigned char c = s2[0];
}
Expand Down
5 changes: 3 additions & 2 deletions clang/test/CIR/Lowering/ThroughMLIR/bool.cir
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ module {

// MLIR: func @foo() {
// MLIR: [[Value:%[a-z0-9]+]] = memref.alloca() {alignment = 1 : i64} : memref<i8>
// MLIR: = arith.constant 1 : i8
// MLIR: memref.store {{.*}}, [[Value]][] : memref<i8>
// MLIR: %[[CONST:.*]] = arith.constant true
// MLIR: %[[BOOL_TO_MEM:.*]] = arith.extui %[[CONST]] : i1 to i8
// MLIR-NEXT: memref.store %[[BOOL_TO_MEM]], [[Value]][] : memref<i8>
// return

// LLVM: = alloca i8, i64
Expand Down
14 changes: 6 additions & 8 deletions clang/test/CIR/Lowering/ThroughMLIR/branch.cir
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ cir.func @foo(%arg0: !cir.bool) -> !s32i {
}

// MLIR: module {
// MLIR-NEXT: func.func @foo(%arg0: i8) -> i32
// MLIR-NEXT: %0 = arith.trunci %arg0 : i8 to i1
// MLIR-NEXT: cf.cond_br %0, ^bb1, ^bb2
// MLIR-NEXT: func.func @foo(%arg0: i1) -> i32
// MLIR-NEXT: cf.cond_br %arg0, ^bb1, ^bb2
// MLIR-NEXT: ^bb1: // pred: ^bb0
// MLIR-NEXT: %c1_i32 = arith.constant 1 : i32
// MLIR-NEXT: return %c1_i32 : i32
Expand All @@ -25,13 +24,12 @@ cir.func @foo(%arg0: !cir.bool) -> !s32i {
// MLIR-NEXT: }
// MLIR-NEXT: }

// LLVM: define i32 @foo(i8 %0)
// LLVM-NEXT: %2 = trunc i8 %0 to i1
// LLVM-NEXT: br i1 %2, label %3, label %4
// LLVM: define i32 @foo(i1 %0)
// LLVM-NEXT: br i1 %0, label %[[TRUE:.*]], label %[[FALSE:.*]]
// LLVM-EMPTY:
// LLVM-NEXT: 3: ; preds = %1
// LLVM-NEXT: [[TRUE]]:
// LLVM-NEXT: ret i32 1
// LLVM-EMPTY:
// LLVM-NEXT: 4: ; preds = %1
// LLVM-NEXT: [[FALSE]]:
// LLVM-NEXT: ret i32 0
// LLVM-NEXT: }
32 changes: 16 additions & 16 deletions clang/test/CIR/Lowering/ThroughMLIR/cast.cir
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
!u16i = !cir.int<u, 16>
!u8i = !cir.int<u, 8>
module {
// MLIR-LABEL: func.func @cast_int_to_bool(%arg0: i32) -> i8
// LLVM-LABEL: define i8 @cast_int_to_bool(i32 %0)
// MLIR-LABEL: func.func @cast_int_to_bool(%arg0: i32) -> i1
// LLVM-LABEL: define i1 @cast_int_to_bool(i32 %0)
cir.func @cast_int_to_bool(%i : !u32i) -> !cir.bool {
// MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0 : i32
// MLIR-NEXT: arith.cmpi ne, %arg0, %[[ZERO]]
Expand Down Expand Up @@ -71,8 +71,8 @@ module {
%1 = cir.cast(floating, %f : !cir.float), !cir.double
cir.return %1 : !cir.double
}
// MLIR-LABEL: func.func @cast_float_to_bool(%arg0: f32) -> i8
// LLVM-LABEL: define i8 @cast_float_to_bool(float %0)
// MLIR-LABEL: func.func @cast_float_to_bool(%arg0: f32) -> i1
// LLVM-LABEL: define i1 @cast_float_to_bool(float %0)
cir.func @cast_float_to_bool(%f : !cir.float) -> !cir.bool {
// MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// MLIR-NEXT: arith.cmpf une, %arg0, %[[ZERO]] : f32
Expand All @@ -81,29 +81,29 @@ module {
%1 = cir.cast(float_to_bool, %f : !cir.float), !cir.bool
cir.return %1 : !cir.bool
}
// MLIR-LABEL: func.func @cast_bool_to_int8(%arg0: i8) -> i8
// LLVM-LABEL: define i8 @cast_bool_to_int8(i8 %0)
// MLIR-LABEL: func.func @cast_bool_to_int8(%arg0: i1) -> i8
// LLVM-LABEL: define i8 @cast_bool_to_int8(i1 %0)
cir.func @cast_bool_to_int8(%b : !cir.bool) -> !u8i {
// MLIR-NEXT: arith.bitcast %arg0 : i8 to i8
// LLVM-NEXT: ret i8 %0
// MLIR-NEXT: arith.extui %arg0 : i1 to i8
// LLVM-NEXT: zext i1 %0 to i8

%1 = cir.cast(bool_to_int, %b : !cir.bool), !u8i
cir.return %1 : !u8i
}
// MLIR-LABEL: func.func @cast_bool_to_int(%arg0: i8) -> i32
// LLVM-LABEL: define i32 @cast_bool_to_int(i8 %0)
// MLIR-LABEL: func.func @cast_bool_to_int(%arg0: i1) -> i32
// LLVM-LABEL: define i32 @cast_bool_to_int(i1 %0)
cir.func @cast_bool_to_int(%b : !cir.bool) -> !u32i {
// MLIR-NEXT: arith.extui %arg0 : i8 to i32
// LLVM-NEXT: zext i8 %0 to i32
// MLIR-NEXT: arith.extui %arg0 : i1 to i32
// LLVM-NEXT: zext i1 %0 to i32

%1 = cir.cast(bool_to_int, %b : !cir.bool), !u32i
cir.return %1 : !u32i
}
// MLIR-LABEL: func.func @cast_bool_to_float(%arg0: i8) -> f32
// LLVM-LABEL: define float @cast_bool_to_float(i8 %0)
// MLIR-LABEL: func.func @cast_bool_to_float(%arg0: i1) -> f32
// LLVM-LABEL: define float @cast_bool_to_float(i1 %0)
cir.func @cast_bool_to_float(%b : !cir.bool) -> !cir.float {
// MLIR-NEXT: arith.uitofp %arg0 : i8 to f32
// LLVM-NEXT: uitofp i8 %0 to float
// MLIR-NEXT: arith.uitofp %arg0 : i1 to f32
// LLVM-NEXT: uitofp i1 %0 to float

%1 = cir.cast(bool_to_float, %b : !cir.bool), !cir.float
cir.return %1 : !cir.float
Expand Down
Loading