Skip to content

Commit

Permalink
dialects: (arith) add Cmpi canonicalization (#3564)
Browse files Browse the repository at this point in the history
Adds cmpi canonicalization patterns for when the two operands are the
same ssa value.

---------

Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
alexarice and superlopuh authored Dec 4, 2024
1 parent 44d321a commit ad28789
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
25 changes: 25 additions & 0 deletions tests/filecheck/dialects/arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,28 @@ func.func @test_const_var_const() {
// CHECK: %foldable_times = arith.constant 4 : i32
%foldable_times = arith.muli %c2, %c2 : i32
"test.op"(%foldable_times) : (i32) -> ()

// CHECK: %int = "test.op"() : () -> i32
%int = "test.op"() : () -> i32
// CHECK-NEXT: %{{.*}} = arith.constant true
%0 = arith.cmpi eq, %int, %int : i32
// CHECK-NEXT: %{{.*}} = arith.constant false
%1 = arith.cmpi ne, %int, %int : i32
// CHECK-NEXT: %{{.*}} = arith.constant false
%2 = arith.cmpi slt, %int, %int : i32
// CHECK-NEXT: %{{.*}} = arith.constant true
%3 = arith.cmpi sle, %int, %int : i32
// CHECK-NEXT: %{{.*}} = arith.constant false
%4 = arith.cmpi sgt, %int, %int : i32
// CHECK-NEXT: %{{.*}} = arith.constant true
%5 = arith.cmpi sge, %int, %int : i32
// CHECK-NEXT: %{{.*}} = arith.constant false
%6 = arith.cmpi ult, %int, %int : i32
// CHECK-NEXT: %{{.*}} = arith.constant true
%7 = arith.cmpi ule, %int, %int : i32
// CHECK-NEXT: %{{.*}} = arith.constant false
%8 = arith.cmpi ugt, %int, %int : i32
// CHECK-NEXT: %{{.*}} = arith.constant true
%9 = arith.cmpi uge, %int, %int : i32

"test.op"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %int) : (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i32) -> ()
10 changes: 10 additions & 0 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,14 @@ def _validate_operand_types(operand1: SSAValue, operand2: SSAValue):
traits = traits_def(Pure())


class CmpiHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns import arith

return (arith.ApplyCmpiPredicateToEqualOperands(),)


@irdl_op_definition
class CmpiOp(ComparisonOperation):
"""
Expand Down Expand Up @@ -689,6 +697,8 @@ class CmpiOp(ComparisonOperation):
rhs = operand_def(signlessIntegerLike)
result = result_def(IntegerType(1))

traits = traits_def(CmpiHasCanonicalizationPatterns(), Pure())

def __init__(
self,
operand1: Operation | SSAValue,
Expand Down
11 changes: 10 additions & 1 deletion xdsl/transforms/canonicalization_patterns/arith.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from xdsl.dialects import arith, builtin
from xdsl.dialects.builtin import IndexType, IntegerAttr, IntegerType
from xdsl.dialects.builtin import BoolAttr, IndexType, IntegerAttr, IntegerType
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
Expand Down Expand Up @@ -191,3 +191,12 @@ def match_and_rewrite(self, op: arith.MuliOp, rewriter: PatternRewriter):
rewriter.replace_matched_op(
arith.ConstantOp.from_int_and_width(lhs * rhs, op.result.type)
)


class ApplyCmpiPredicateToEqualOperands(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.CmpiOp, rewriter: PatternRewriter):
if op.lhs != op.rhs:
return
val = op.predicate.value.data in (0, 3, 5, 7, 9)
rewriter.replace_matched_op(arith.ConstantOp(BoolAttr.from_bool(val)))

0 comments on commit ad28789

Please sign in to comment.