Skip to content

Commit

Permalink
dialects: (arith) add addi constant propagation (#3563)
Browse files Browse the repository at this point in the history
Moves constants to the right for an addi operation, similar to what
#3557 did for muli. Also replaces `AddImmediateZero` by
`AddiIdentityRight`, which removes constant zeros on the right of an
addition instead of the left of an addition like the previous pattern
did.

Updates interactive tests (and driveby fixes the constant zero being
given the SSA value name %two in this test).
  • Loading branch information
alexarice authored Dec 4, 2024
1 parent 560b52f commit 494a12f
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 23 deletions.
16 changes: 16 additions & 0 deletions tests/filecheck/dialects/arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,22 @@ func.func @test_const_var_const() {
%foldable_times = arith.muli %c2, %c2 : i32
"test.op"(%foldable_times) : (i32) -> ()

%c0 = arith.constant 0 : i32

%zero_plus = arith.addi %c0, %a : i32
%plus_zero = arith.addi %a, %c0 : i32

// CHECK: "test.op"(%a, %a) {"identity addition check"} : (i32, i32) -> ()
"test.op"(%zero_plus, %plus_zero) {"identity addition check"} : (i32, i32) -> ()

// CHECK: %plus_const = arith.addi %a, %c2 : i32
%plus_const = arith.addi %c2, %a : i32
"test.op"(%plus_const) : (i32) -> ()

// CHECK: %foldable_plus = arith.constant 4 : i32
%foldable_plus = arith.addi %c2, %c2 : i32
"test.op"(%foldable_plus) : (i32) -> ()

// CHECK: %int = "test.op"() : () -> i32
%int = "test.op"() : () -> i32
// CHECK-NEXT: %{{.*}} = arith.constant true
Expand Down
15 changes: 7 additions & 8 deletions tests/filecheck/dialects/csl/csl-canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,13 @@ builtin.module {
"test.op"(%19) : (!csl<dsd mem1d_dsd>) -> ()

// CHECK-NEXT: %3 = "test.op"() : () -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %4 = arith.constant 2 : si16
// CHECK-NEXT: %5 = arith.addi %4, %4 : si16
// CHECK-NEXT: %6 = "csl.increment_dsd_offset"(%3, %5) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %7 = arith.constant 511 : ui16
// CHECK-NEXT: %8 = "csl.set_dsd_length"(%6, %7) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %9 = arith.constant 3 : si8
// CHECK-NEXT: %10 = "csl.set_dsd_stride"(%8, %9) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "test.op"(%10) : (!csl<dsd mem1d_dsd>) -> ()
// CHECK-NEXT: %4 = arith.constant 4 : si16
// CHECK-NEXT: %5 = "csl.increment_dsd_offset"(%3, %4) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %6 = arith.constant 511 : ui16
// CHECK-NEXT: %7 = "csl.set_dsd_length"(%5, %6) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %8 = arith.constant 3 : si8
// CHECK-NEXT: %9 = "csl.set_dsd_stride"(%7, %8) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "test.op"(%9) : (!csl<dsd mem1d_dsd>) -> ()

}
// CHECK-NEXT: }
12 changes: 6 additions & 6 deletions tests/interactive/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,8 @@ async def test_rewrites():
app.input_text_area.insert(
"""
func.func @hello(%n : i32) -> i32 {
%two = arith.constant 0 : i32
%res = arith.addi %two, %n : i32
%c0 = arith.constant 0 : i32
%res = arith.addi %n, %c0 : i32
func.return %res : i32
}
"""
Expand All @@ -334,11 +334,11 @@ async def test_rewrites():
await pilot.click("#condense_button")

addi_pass = AvailablePass(
display_name="AddiOp(%res = arith.addi %two, %n : i32):arith.addi:AddImmediateZero",
display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:AddiIdentityRight",
module_pass=individual_rewrite.ApplyIndividualRewritePass,
pass_spec=list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddImmediateZero"}'
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}'
)
)[0],
)
Expand All @@ -359,7 +359,7 @@ async def test_rewrites():
individual_rewrite.ApplyIndividualRewritePass,
list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddImmediateZero"}'
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}'
)
)[0],
),
Expand All @@ -371,7 +371,7 @@ async def test_rewrites():
app.output_text_area.text
== """builtin.module {
func.func @hello(%n : i32) -> i32 {
%two = arith.constant 0 : i32
%c0 = arith.constant 0 : i32
func.return %n : i32
}
}
Expand Down
7 changes: 5 additions & 2 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,12 @@ def print(self, printer: Printer):
class AddiOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.arith import AddImmediateZero
from xdsl.transforms.canonicalization_patterns.arith import (
AddiConstantProp,
AddiIdentityRight,
)

return (AddImmediateZero(),)
return (AddiIdentityRight(), AddiConstantProp())


@irdl_op_definition
Expand Down
29 changes: 22 additions & 7 deletions xdsl/transforms/canonicalization_patterns/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,30 @@
from xdsl.utils.hints import isa


class AddImmediateZero(RewritePattern):
class AddiIdentityRight(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.AddiOp, rewriter: PatternRewriter) -> None:
if (
isinstance(op.lhs.owner, arith.ConstantOp)
and isinstance(value := op.lhs.owner.value, IntegerAttr)
and value.value.data == 0
):
rewriter.replace_matched_op([], [op.rhs])
if (rhs := const_evaluate_operand(op.rhs)) is None:
return
if rhs != 0:
return
rewriter.replace_matched_op((), (op.lhs,))


class AddiConstantProp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.AddiOp, rewriter: PatternRewriter):
if (lhs := const_evaluate_operand(op.lhs)) is None:
return
if (rhs := const_evaluate_operand(op.rhs)) is None:
# Swap inputs if lhs is constant and rhs is not
rewriter.replace_matched_op(arith.AddiOp(op.rhs, op.lhs))
return

assert isinstance(op.result.type, IntegerType | IndexType)
rewriter.replace_matched_op(
arith.ConstantOp.from_int_and_width(lhs + rhs, op.result.type)
)


def _fold_const_operation(
Expand Down

0 comments on commit 494a12f

Please sign in to comment.