diff --git a/tests/filecheck/dialects/arith/canonicalize.mlir b/tests/filecheck/dialects/arith/canonicalize.mlir index 01a2ea0c86..b87efd39d8 100644 --- a/tests/filecheck/dialects/arith/canonicalize.mlir +++ b/tests/filecheck/dialects/arith/canonicalize.mlir @@ -109,6 +109,22 @@ func.func @test_const_var_const() { %foldable_times = arith.muli %c2, %c2 : i32 "test.op"(%foldable_times) : (i32) -> () +%c0 = arith.constant 0 : i32 + +%zero_plus = arith.addi %c0, %a : i32 +%plus_zero = arith.addi %a, %c0 : i32 + +// CHECK: "test.op"(%a, %a) {"identity addition check"} : (i32, i32) -> () +"test.op"(%zero_plus, %plus_zero) {"identity addition check"} : (i32, i32) -> () + +// CHECK: %plus_const = arith.addi %a, %c2 : i32 +%plus_const = arith.addi %c2, %a : i32 +"test.op"(%plus_const) : (i32) -> () + +// CHECK: %foldable_plus = arith.constant 4 : i32 +%foldable_plus = arith.addi %c2, %c2 : i32 +"test.op"(%foldable_plus) : (i32) -> () + // CHECK: %int = "test.op"() : () -> i32 %int = "test.op"() : () -> i32 // CHECK-NEXT: %{{.*}} = arith.constant true diff --git a/tests/filecheck/dialects/csl/csl-canonicalize.mlir b/tests/filecheck/dialects/csl/csl-canonicalize.mlir index 5075afe5c9..713595bf39 100644 --- a/tests/filecheck/dialects/csl/csl-canonicalize.mlir +++ b/tests/filecheck/dialects/csl/csl-canonicalize.mlir @@ -41,14 +41,13 @@ builtin.module { "test.op"(%19) : (!csl) -> () // CHECK-NEXT: %3 = "test.op"() : () -> !csl -// CHECK-NEXT: %4 = arith.constant 2 : si16 -// CHECK-NEXT: %5 = arith.addi %4, %4 : si16 -// CHECK-NEXT: %6 = "csl.increment_dsd_offset"(%3, %5) <{"elem_type" = f32}> : (!csl, si16) -> !csl -// CHECK-NEXT: %7 = arith.constant 511 : ui16 -// CHECK-NEXT: %8 = "csl.set_dsd_length"(%6, %7) : (!csl, ui16) -> !csl -// CHECK-NEXT: %9 = arith.constant 3 : si8 -// CHECK-NEXT: %10 = "csl.set_dsd_stride"(%8, %9) : (!csl, si8) -> !csl -// CHECK-NEXT: "test.op"(%10) : (!csl) -> () +// CHECK-NEXT: %4 = arith.constant 4 : si16 +// CHECK-NEXT: %5 = "csl.increment_dsd_offset"(%3, %4) <{"elem_type" = f32}> : (!csl, si16) -> !csl +// CHECK-NEXT: %6 = arith.constant 511 : ui16 +// CHECK-NEXT: %7 = "csl.set_dsd_length"(%5, %6) : (!csl, ui16) -> !csl +// CHECK-NEXT: %8 = arith.constant 3 : si8 +// CHECK-NEXT: %9 = "csl.set_dsd_stride"(%7, %8) : (!csl, si8) -> !csl +// CHECK-NEXT: "test.op"(%9) : (!csl) -> () } // CHECK-NEXT: } diff --git a/tests/interactive/test_app.py b/tests/interactive/test_app.py index c8401e03b6..5f2d24290e 100644 --- a/tests/interactive/test_app.py +++ b/tests/interactive/test_app.py @@ -323,8 +323,8 @@ async def test_rewrites(): app.input_text_area.insert( """ func.func @hello(%n : i32) -> i32 { - %two = arith.constant 0 : i32 - %res = arith.addi %two, %n : i32 + %c0 = arith.constant 0 : i32 + %res = arith.addi %n, %c0 : i32 func.return %res : i32 } """ @@ -334,11 +334,11 @@ async def test_rewrites(): await pilot.click("#condense_button") addi_pass = AvailablePass( - display_name="AddiOp(%res = arith.addi %two, %n : i32):arith.addi:AddImmediateZero", + display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:AddiIdentityRight", module_pass=individual_rewrite.ApplyIndividualRewritePass, pass_spec=list( parse_pipeline( - 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddImmediateZero"}' + 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}' ) )[0], ) @@ -359,7 +359,7 @@ async def test_rewrites(): individual_rewrite.ApplyIndividualRewritePass, list( parse_pipeline( - 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddImmediateZero"}' + 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}' ) )[0], ), @@ -371,7 +371,7 @@ async def test_rewrites(): app.output_text_area.text == """builtin.module { func.func @hello(%n : i32) -> i32 { - %two = arith.constant 0 : i32 + %c0 = arith.constant 0 : i32 func.return %n : i32 } } diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index dc0fe4cf49..5f609abcc0 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -316,9 +316,12 @@ def print(self, printer: Printer): class AddiOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait): @classmethod def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: - from xdsl.transforms.canonicalization_patterns.arith import AddImmediateZero + from xdsl.transforms.canonicalization_patterns.arith import ( + AddiConstantProp, + AddiIdentityRight, + ) - return (AddImmediateZero(),) + return (AddiIdentityRight(), AddiConstantProp()) @irdl_op_definition diff --git a/xdsl/transforms/canonicalization_patterns/arith.py b/xdsl/transforms/canonicalization_patterns/arith.py index cf12cdefc9..85fb9d7093 100644 --- a/xdsl/transforms/canonicalization_patterns/arith.py +++ b/xdsl/transforms/canonicalization_patterns/arith.py @@ -9,15 +9,30 @@ from xdsl.utils.hints import isa -class AddImmediateZero(RewritePattern): +class AddiIdentityRight(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: arith.AddiOp, rewriter: PatternRewriter) -> None: - if ( - isinstance(op.lhs.owner, arith.ConstantOp) - and isinstance(value := op.lhs.owner.value, IntegerAttr) - and value.value.data == 0 - ): - rewriter.replace_matched_op([], [op.rhs]) + if (rhs := const_evaluate_operand(op.rhs)) is None: + return + if rhs != 0: + return + rewriter.replace_matched_op((), (op.lhs,)) + + +class AddiConstantProp(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: arith.AddiOp, rewriter: PatternRewriter): + if (lhs := const_evaluate_operand(op.lhs)) is None: + return + if (rhs := const_evaluate_operand(op.rhs)) is None: + # Swap inputs if lhs is constant and rhs is not + rewriter.replace_matched_op(arith.AddiOp(op.rhs, op.lhs)) + return + + assert isinstance(op.result.type, IntegerType | IndexType) + rewriter.replace_matched_op( + arith.ConstantOp.from_int_and_width(lhs + rhs, op.result.type) + ) def _fold_const_operation(