Skip to content

Commit

Permalink
dialects: (arith) add SignlessIntegerBinaryOperation canonicalization (
Browse files Browse the repository at this point in the history
…#3583)

Generically implements various arith canonicalizations on SignlessIntegerBinaryOperation
  • Loading branch information
alexarice authored Dec 17, 2024
1 parent 3725f40 commit b8611e0
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 81 deletions.
10 changes: 10 additions & 0 deletions tests/filecheck/dialects/arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,13 @@ func.func @test_const_var_const() {
%9 = arith.cmpi uge, %int, %int : i32

"test.op"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %int) : (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i32) -> ()

// Subtraction is not commutative so should not have the constant swapped to the right
// CHECK: arith.subi %c2, %a : i32
%10 = arith.subi %c2, %a : i32
"test.op"(%10) : (i32) -> ()

// CHECK: %{{.*}} = arith.constant false
%11 = arith.constant true
%12 = arith.addi %11, %11 : i1
"test.op"(%12) : (i1) -> ()
10 changes: 5 additions & 5 deletions tests/interactive/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,11 @@ async def test_rewrites():
await pilot.click("#condense_button")

addi_pass = AvailablePass(
display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:AddiIdentityRight",
display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:SignlessIntegerBinaryOperationZeroOrUnitRight",
module_pass=individual_rewrite.ApplyIndividualRewritePass,
pass_spec=list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}'
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}'
)
)[0],
)
Expand All @@ -354,7 +354,7 @@ async def test_rewrites():
individual_rewrite.ApplyIndividualRewritePass,
list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}'
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}'
)
)[0],
),
Expand Down Expand Up @@ -563,7 +563,7 @@ async def test_apply_individual_rewrite():
n.data is not None
and n.data[1] is not None
and str(n.data[1])
== 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiConstantProp"}'
== 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationConstantProp"}'
):
node = n

Expand Down Expand Up @@ -593,7 +593,7 @@ async def test_apply_individual_rewrite():
n.data is not None
and n.data[1] is not None
and str(n.data[1])
== 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}'
== 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}'
):
node = n

Expand Down
220 changes: 187 additions & 33 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from xdsl.pattern_rewriter import RewritePattern
from xdsl.printer import Printer
from xdsl.traits import (
Commutative,
ConditionallySpeculatable,
ConstantLike,
HasCanonicalizationPatternsTrait,
Expand Down Expand Up @@ -195,6 +196,36 @@ class SignlessIntegerBinaryOperation(IRDLOperation, abc.ABC):

assembly_format = "$lhs `,` $rhs attr-dict `:` type($result)"

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
"""
Performs a python function corresponding to this operation.
If `i := py_operation(lhs, rhs)` is an int, then this operation can be
canonicalized to a constant with value `i` when the inputs are constants
with values `lhs` and `rhs`.
"""
return None

@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
"""
Returns True only when 'attr' is a right zero for the operation
https://en.wikipedia.org/wiki/Absorbing_element
Note that this depends on the operation and does *not* imply that
attr.value.data == 0
"""
return False

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
"""
Return True only when 'attr' is a right unit/identity for the operation
https://en.wikipedia.org/wiki/Identity_element
"""
return False

def __init__(
self,
operand1: Operation | SSAValue,
Expand All @@ -209,6 +240,22 @@ def __hash__(self) -> int:
return id(self)


class SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(
HasCanonicalizationPatternsTrait
):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.arith import (
SignlessIntegerBinaryOperationConstantProp,
SignlessIntegerBinaryOperationZeroOrUnitRight,
)

return (
SignlessIntegerBinaryOperationConstantProp(),
SignlessIntegerBinaryOperationZeroOrUnitRight(),
)


class SignlessIntegerBinaryOperationWithOverflow(
SignlessIntegerBinaryOperation, abc.ABC
):
Expand Down Expand Up @@ -318,22 +365,23 @@ def print(self, printer: Printer):
printer.print_attribute(self.result.type)


class AddiOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.arith import (
AddiConstantProp,
AddiIdentityRight,
)

return (AddiIdentityRight(), AddiConstantProp())


@irdl_op_definition
class AddiOp(SignlessIntegerBinaryOperationWithOverflow):
name = "arith.addi"

traits = traits_def(Pure(), AddiOpHasCanonicalizationPatternsTrait())
traits = traits_def(
Pure(),
Commutative(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs + rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
Expand Down Expand Up @@ -400,19 +448,27 @@ def infer_overflow_type(input_type: Attribute) -> Attribute:
)


class MuliHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns import arith

return (arith.MuliIdentityRight(), arith.MuliConstantProp())


@irdl_op_definition
class MuliOp(SignlessIntegerBinaryOperationWithOverflow):
name = "arith.muli"

traits = traits_def(Pure(), MuliHasCanonicalizationPatterns())
traits = traits_def(
Pure(),
Commutative(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs * rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)

@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


class MulExtendedBase(IRDLOperation):
Expand Down Expand Up @@ -460,7 +516,17 @@ class MulSIExtendedOp(MulExtendedBase):
class SubiOp(SignlessIntegerBinaryOperationWithOverflow):
name = "arith.subi"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs - rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


class DivUISpeculatable(ConditionallySpeculatable):
Expand All @@ -483,7 +549,15 @@ class DivUIOp(SignlessIntegerBinaryOperation):

name = "arith.divui"

traits = traits_def(NoMemoryEffect(), DivUISpeculatable())
traits = traits_def(
NoMemoryEffect(),
DivUISpeculatable(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


@irdl_op_definition
Expand All @@ -495,7 +569,14 @@ class DivSIOp(SignlessIntegerBinaryOperation):

name = "arith.divsi"

traits = traits_def(NoMemoryEffect())
traits = traits_def(
NoMemoryEffect(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


@irdl_op_definition
Expand All @@ -506,21 +587,40 @@ class FloorDivSIOp(SignlessIntegerBinaryOperation):

name = "arith.floordivsi"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


@irdl_op_definition
class CeilDivSIOp(SignlessIntegerBinaryOperation):
name = "arith.ceildivsi"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


@irdl_op_definition
class CeilDivUIOp(SignlessIntegerBinaryOperation):
name = "arith.ceildivui"

traits = traits_def(NoMemoryEffect())
traits = traits_def(
NoMemoryEffect(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


@irdl_op_definition
Expand Down Expand Up @@ -567,21 +667,57 @@ class MaxSIOp(SignlessIntegerBinaryOperation):
class AndIOp(SignlessIntegerBinaryOperation):
name = "arith.andi"

traits = traits_def(Pure())
traits = traits_def(
Pure(),
Commutative(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs & rhs

@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
class OrIOp(SignlessIntegerBinaryOperation):
name = "arith.ori"

traits = traits_def(Pure())
traits = traits_def(
Pure(),
Commutative(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs | rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
class XOrIOp(SignlessIntegerBinaryOperation):
name = "arith.xori"

traits = traits_def(Pure())
traits = traits_def(
Pure(),
Commutative(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs ^ rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
Expand All @@ -593,7 +729,13 @@ class ShLIOp(SignlessIntegerBinaryOperationWithOverflow):

name = "arith.shli"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
Expand All @@ -606,7 +748,13 @@ class ShRUIOp(SignlessIntegerBinaryOperation):

name = "arith.shrui"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
Expand All @@ -620,7 +768,13 @@ class ShRSIOp(SignlessIntegerBinaryOperation):

name = "arith.shrsi"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


class ComparisonOperation(IRDLOperation):
Expand Down
Loading

0 comments on commit b8611e0

Please sign in to comment.