Skip to content

Commit

Permalink
[InstCombine] Fold (x < y) ? -1 : zext(x != y) into u/scmp(x,y) (l…
Browse files Browse the repository at this point in the history
…lvm#101049)

This patch adds the aforementioned fold to InstCombine. This pattern is
produced after naive implementations of 3-way comparison in high-level
languages are transformed into LLVM IR and then optimized.

Proofs: https://alive2.llvm.org/ce/z/w4QLq_
  • Loading branch information
Poseydon42 authored Aug 19, 2024
1 parent b8dccb7 commit abf69a1
Show file tree
Hide file tree
Showing 6 changed files with 426 additions and 37 deletions.
1 change: 1 addition & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final

// Helpers of visitSelectInst().
Instruction *foldSelectOfBools(SelectInst &SI);
Instruction *foldSelectToCmp(SelectInst &SI);
Instruction *foldSelectExtConst(SelectInst &Sel);
Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI);
Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *);
Expand Down
52 changes: 52 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3558,6 +3558,55 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
Masked);
}

// This function tries to fold the following operations:
// (x < y) ? -1 : zext(x != y)
// (x > y) ? 1 : sext(x != y)
// Into ucmp/scmp(x, y), where signedness is determined by the signedness
// of the comparison in the original sequence.
Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
Value *TV = SI.getTrueValue();
Value *FV = SI.getFalseValue();

ICmpInst::Predicate Pred;
Value *LHS, *RHS;
if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
return nullptr;

if (!LHS->getType()->isIntOrIntVectorTy())
return nullptr;

// Try to swap operands and the predicate. We need to be careful when doing
// so because two of the patterns have opposite predicates, so use the
// constant inside select to determine if swapping operands would be
// beneficial to us.
if ((ICmpInst::isGT(Pred) && match(TV, m_AllOnes())) ||
(ICmpInst::isLT(Pred) && match(TV, m_One()))) {
Pred = ICmpInst::getSwappedPredicate(Pred);
std::swap(LHS, RHS);
}

Intrinsic::ID IID =
ICmpInst::isSigned(Pred) ? Intrinsic::scmp : Intrinsic::ucmp;

bool Replace = false;
// (x < y) ? -1 : zext(x != y)
if (ICmpInst::isLT(Pred) && match(TV, m_AllOnes()) &&
match(FV, m_ZExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
m_Specific(RHS)))))
Replace = true;

// (x > y) ? 1 : sext(x != y)
if (ICmpInst::isGT(Pred) && match(TV, m_One()) &&
match(FV, m_SExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
m_Specific(RHS)))))
Replace = true;

if (Replace)
return replaceInstUsesWith(
SI, Builder.CreateIntrinsic(SI.getType(), IID, {LHS, RHS}));
return nullptr;
}

bool InstCombinerImpl::fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
const Instruction *CtxI) const {
KnownFPClass Known = computeKnownFPClass(MulVal, FMF, fcNegative, CtxI);
Expand Down Expand Up @@ -4061,6 +4110,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Instruction *I = foldBitCeil(SI, Builder))
return I;

if (Instruction *I = foldSelectToCmp(SI))
return I;

// Fold:
// (select A && B, T, F) -> (select A, (select B, T, F), F)
// (select A || B, T, F) -> (select A, T, (select B, T, F))
Expand Down
56 changes: 56 additions & 0 deletions llvm/test/Transforms/InstCombine/scmp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,59 @@ define i8 @scmp_negated_multiuse(i32 %x, i32 %y) {
%2 = sub i8 0, %1
ret i8 %2
}

; Fold ((x s< y) ? -1 : (x != y)) into scmp(x, y)
define i8 @scmp_from_select_lt(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @scmp_from_select_lt(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: ret i8 [[R]]
;
%ne_bool = icmp ne i32 %x, %y
%ne = zext i1 %ne_bool to i8
%lt = icmp slt i32 %x, %y
%r = select i1 %lt, i8 -1, i8 %ne
ret i8 %r
}

; Vector version
define <4 x i8> @scmp_from_select_vec_lt(<4 x i32> %x, <4 x i32> %y) {
; CHECK-LABEL: define <4 x i8> @scmp_from_select_vec_lt(
; CHECK-SAME: <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]]) {
; CHECK-NEXT: [[R:%.*]] = call <4 x i8> @llvm.scmp.v4i8.v4i32(<4 x i32> [[X]], <4 x i32> [[Y]])
; CHECK-NEXT: ret <4 x i8> [[R]]
;
%ne_bool = icmp ne <4 x i32> %x, %y
%ne = zext <4 x i1> %ne_bool to <4 x i8>
%lt = icmp slt <4 x i32> %x, %y
%r = select <4 x i1> %lt, <4 x i8> splat(i8 -1), <4 x i8> %ne
ret <4 x i8> %r
}

; Fold (x s<= y) ? sext(x != y) : 1 into scmp(x, y)
define i8 @scmp_from_select_le(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @scmp_from_select_le(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: ret i8 [[R]]
;
%ne_bool = icmp ne i32 %x, %y
%ne = sext i1 %ne_bool to i8
%le = icmp sle i32 %x, %y
%r = select i1 %le, i8 %ne, i8 1
ret i8 %r
}

; Fold (x s>= y) ? zext(x != y) : -1 into scmp(x, y)
define i8 @scmp_from_select_ge(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @scmp_from_select_ge(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: ret i8 [[R]]
;
%ne_bool = icmp ne i32 %x, %y
%ne = zext i1 %ne_bool to i8
%ge = icmp sge i32 %x, %y
%r = select i1 %ge, i8 %ne, i8 -1
ret i8 %r
}
44 changes: 9 additions & 35 deletions llvm/test/Transforms/InstCombine/select-select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,7 @@ define i8 @strong_order_cmp_ugt_eq(i32 %a, i32 %b) {

define i8 @strong_order_cmp_eq_slt(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_eq_slt(
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEL_EQ:%.*]] = zext i1 [[CMP_EQ]] to i8
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp slt i32 [[A]], [[B]]
; CHECK-NEXT: [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[SEL_EQ]]
; CHECK-NEXT: [[SEL_LT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_LT]]
;
%cmp.eq = icmp eq i32 %a, %b
Expand All @@ -297,10 +294,7 @@ define i8 @strong_order_cmp_eq_slt(i32 %a, i32 %b) {

define i8 @strong_order_cmp_eq_sgt(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_eq_sgt(
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext i1 [[CMP_EQ]] to i8
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp sgt i32 [[A]], [[B]]
; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEL_EQ]]
; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_GT]]
;
%cmp.eq = icmp eq i32 %a, %b
Expand All @@ -312,10 +306,7 @@ define i8 @strong_order_cmp_eq_sgt(i32 %a, i32 %b) {

define i8 @strong_order_cmp_eq_ult(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_eq_ult(
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEL_EQ:%.*]] = zext i1 [[CMP_EQ]] to i8
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult i32 [[A]], [[B]]
; CHECK-NEXT: [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[SEL_EQ]]
; CHECK-NEXT: [[SEL_LT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_LT]]
;
%cmp.eq = icmp eq i32 %a, %b
Expand All @@ -327,10 +318,7 @@ define i8 @strong_order_cmp_eq_ult(i32 %a, i32 %b) {

define i8 @strong_order_cmp_eq_ugt(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_eq_ugt(
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext i1 [[CMP_EQ]] to i8
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt i32 [[A]], [[B]]
; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEL_EQ]]
; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_GT]]
;
%cmp.eq = icmp eq i32 %a, %b
Expand Down Expand Up @@ -404,9 +392,7 @@ define i8 @strong_order_cmp_ne_ugt_ne_not_one_use(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_ne_ugt_ne_not_one_use(
; CHECK-NEXT: [[CMP_NE:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: call void @use1(i1 [[CMP_NE]])
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext i1 [[CMP_NE]] to i8
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt i32 [[A]], [[B]]
; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEL_EQ]]
; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
; CHECK-NEXT: ret i8 [[SEL_GT]]
;
%cmp.ne = icmp ne i32 %a, %b
Expand Down Expand Up @@ -535,10 +521,7 @@ define <2 x i8> @strong_order_cmp_ugt_ult_vector_poison(<2 x i32> %a, <2 x i32>

define <2 x i8> @strong_order_cmp_eq_ugt_vector(<2 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: @strong_order_cmp_eq_ugt_vector(
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne <2 x i32> [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext <2 x i1> [[CMP_EQ]] to <2 x i8>
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A]], [[B]]
; CHECK-NEXT: [[SEL_GT:%.*]] = select <2 x i1> [[CMP_GT]], <2 x i8> <i8 1, i8 1>, <2 x i8> [[SEL_EQ]]
; CHECK-NEXT: [[SEL_GT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
; CHECK-NEXT: ret <2 x i8> [[SEL_GT]]
;
%cmp.eq = icmp eq <2 x i32> %a, %b
Expand All @@ -550,10 +533,7 @@ define <2 x i8> @strong_order_cmp_eq_ugt_vector(<2 x i32> %a, <2 x i32> %b) {

define <2 x i8> @strong_order_cmp_eq_ugt_vector_poison1(<2 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: @strong_order_cmp_eq_ugt_vector_poison1(
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne <2 x i32> [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext <2 x i1> [[CMP_EQ]] to <2 x i8>
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A]], [[B]]
; CHECK-NEXT: [[SEL_GT:%.*]] = select <2 x i1> [[CMP_GT]], <2 x i8> <i8 1, i8 1>, <2 x i8> [[SEL_EQ]]
; CHECK-NEXT: [[SEL_GT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
; CHECK-NEXT: ret <2 x i8> [[SEL_GT]]
;
%cmp.eq = icmp eq <2 x i32> %a, %b
Expand All @@ -565,10 +545,7 @@ define <2 x i8> @strong_order_cmp_eq_ugt_vector_poison1(<2 x i32> %a, <2 x i32>

define <2 x i8> @strong_order_cmp_eq_ugt_vector_poison2(<2 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: @strong_order_cmp_eq_ugt_vector_poison2(
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne <2 x i32> [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext <2 x i1> [[CMP_EQ]] to <2 x i8>
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A]], [[B]]
; CHECK-NEXT: [[SEL_GT:%.*]] = select <2 x i1> [[CMP_GT]], <2 x i8> <i8 1, i8 1>, <2 x i8> [[SEL_EQ]]
; CHECK-NEXT: [[SEL_GT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
; CHECK-NEXT: ret <2 x i8> [[SEL_GT]]
;
%cmp.eq = icmp eq <2 x i32> %a, %b
Expand All @@ -580,10 +557,7 @@ define <2 x i8> @strong_order_cmp_eq_ugt_vector_poison2(<2 x i32> %a, <2 x i32>

define <2 x i8> @strong_order_cmp_eq_ugt_vector_poison3(<2 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: @strong_order_cmp_eq_ugt_vector_poison3(
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne <2 x i32> [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext <2 x i1> [[CMP_EQ]] to <2 x i8>
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A]], [[B]]
; CHECK-NEXT: [[SEL_GT:%.*]] = select <2 x i1> [[CMP_GT]], <2 x i8> <i8 1, i8 poison>, <2 x i8> [[SEL_EQ]]
; CHECK-NEXT: [[SEL_GT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
; CHECK-NEXT: ret <2 x i8> [[SEL_GT]]
;
%cmp.eq = icmp eq <2 x i32> %a, %b
Expand Down
Loading

0 comments on commit abf69a1

Please sign in to comment.