Skip to content

Commit

Permalink
dialects: (riscv) Add rewrite pattern to optimize bitwise xor by zero (
Browse files Browse the repository at this point in the history
…#3197)

Add a rewrite pattern to the risv dialect for to optimise bitwise xor by
zero (x^0 -> x)

---------

Co-authored-by: emmau678 <eu233@Emma-laptop>
  • Loading branch information
emmau678 and emmau678 authored Sep 19, 2024
1 parent c0713ea commit 58b7625
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 3 deletions.
1 change: 0 additions & 1 deletion GETTING_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ You're welcome to come up with your own, or do one of the following:
- `x * 2ⁱ -> x << i`
- `x & 0 -> 0`
- `x | 0 -> x`
- `x ^ 0 -> x`

The patterns are defined in
[xdsl/transforms/canonicalization_patterns/riscv.py](xdsl/transforms/canonicalization_patterns/riscv.py).
Expand Down
12 changes: 12 additions & 0 deletions tests/filecheck/backend/riscv/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ builtin.module {
%xor_lhs_rhs = riscv.xor %i1, %i1 : (!riscv.reg<a1>, !riscv.reg<a1>) -> !riscv.reg<a0>
"test.op"(%xor_lhs_rhs) : (!riscv.reg<a0>) -> ()

%xor_bitwise_zero_l0 = riscv.xor %c1, %c0 : (!riscv.reg, !riscv.reg) -> !riscv.reg<a0>
"test.op"(%xor_bitwise_zero_l0) : (!riscv.reg<a0>) -> ()

%xor_bitwise_zero_r0 = riscv.xor %c0, %c1 : (!riscv.reg, !riscv.reg) -> !riscv.reg<a0>
"test.op"(%xor_bitwise_zero_r0) : (!riscv.reg<a0>) -> ()

// scfgw immediates
riscv_snitch.scfgw %i1, %c1 : (!riscv.reg<a1>, !riscv.reg) -> ()
}
Expand Down Expand Up @@ -221,6 +227,12 @@ builtin.module {
// CHECK-NEXT: %xor_lhs_rhs_1 = riscv.mv %xor_lhs_rhs : (!riscv.reg<zero>) -> !riscv.reg<a0>
// CHECK-NEXT: "test.op"(%xor_lhs_rhs_1) : (!riscv.reg<a0>) -> ()

// CHECK-NEXT: %xor_bitwise_zero_l0 = riscv.mv %c1 : (!riscv.reg) -> !riscv.reg<a0>
// CHECK-NEXT: "test.op"(%xor_bitwise_zero_l0) : (!riscv.reg<a0>) -> ()

// CHECK-NEXT: %xor_bitwise_zero_r0 = riscv.mv %c1 : (!riscv.reg) -> !riscv.reg<a0>
// CHECK-NEXT: "test.op"(%xor_bitwise_zero_r0) : (!riscv.reg<a0>) -> ()

// CHECK-NEXT: riscv_snitch.scfgwi %i1, 1 : (!riscv.reg<a1>) -> ()

// CHECK-NEXT: }
Expand Down
7 changes: 5 additions & 2 deletions xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,9 +1777,12 @@ class OrOp(RdRsRsOperation[IntRegisterType, IntRegisterType, IntRegisterType]):
class BitwiseXorHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.riscv import XorBySelf
from xdsl.transforms.canonicalization_patterns.riscv import (
BitwiseXorByZero,
XorBySelf,
)

return (XorBySelf(),)
return (XorBySelf(), BitwiseXorByZero())


@irdl_op_definition
Expand Down
15 changes: 15 additions & 0 deletions xdsl/transforms/canonicalization_patterns/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,21 @@ def match_and_rewrite(self, op: riscv.XorOp, rewriter: PatternRewriter):
)


class BitwiseXorByZero(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv.XorOp, rewriter: PatternRewriter):
"""
x ^ 0 = x
"""
if (rs1 := get_constant_value(op.rs1)) is not None and rs1.value.data == 0:
rd = cast(riscv.IntRegisterType, op.rd.type)
rewriter.replace_matched_op(riscv.MVOp(op.rs2, rd=rd))

if (rs2 := get_constant_value(op.rs2)) is not None and rs2.value.data == 0:
rd = cast(riscv.IntRegisterType, op.rd.type)
rewriter.replace_matched_op(riscv.MVOp(op.rs1, rd=rd))


class ScfgwOpUsingImmediate(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
Expand Down

0 comments on commit 58b7625

Please sign in to comment.