-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
transformations: (eqsat) add pass to convert non-eclass functions to …
…eclass (#3189) This PR addresses #3170: - [x] Added initial front end pass `eqsat-create-eclasses` for the minimal example - [x] Added an initial test case for the pass
- Loading branch information
1 parent
f099cce
commit 002fc52
Showing
3 changed files
with
103 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// RUN: xdsl-opt -p eqsat-create-eclasses %s | filecheck %s | ||
|
||
func.func @test(%x : index) -> (index) { | ||
%c2 = arith.constant 2 : index | ||
func.return %c2 : index | ||
} | ||
|
||
// CHECK: func.func @test(%x : index) -> index { | ||
// CHECK-NEXT: %x_1 = eqsat.eclass %x : index | ||
// CHECK-NEXT: %c2 = arith.constant 2 : index | ||
// CHECK-NEXT: %c2_1 = eqsat.eclass %c2 : index | ||
// CHECK-NEXT: func.return %c2_1 : index | ||
// CHECK-NEXT: } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from xdsl.context import MLContext | ||
from xdsl.dialects import builtin, eqsat, func | ||
from xdsl.ir import Block | ||
from xdsl.passes import ModulePass | ||
from xdsl.pattern_rewriter import ( | ||
GreedyRewritePatternApplier, | ||
PatternRewriter, | ||
PatternRewriteWalker, | ||
RewritePattern, | ||
op_type_rewrite_pattern, | ||
) | ||
from xdsl.rewriter import InsertPoint, Rewriter | ||
from xdsl.utils.exceptions import DiagnosticException | ||
|
||
|
||
def insert_eclass_ops(block: Block): | ||
# Insert eqsat.eclass for each operation | ||
for op in block.ops: | ||
results = op.results | ||
|
||
# Skip special ops such as return ops | ||
if isinstance(op, func.Return): | ||
continue | ||
|
||
if len(results) != 1: | ||
raise DiagnosticException("Ops with non-single results not handled") | ||
|
||
eclass_op = eqsat.EClassOp(results[0]) | ||
insertion_point = InsertPoint.after(op) | ||
Rewriter.insert_op(eclass_op, insertion_point) | ||
results[0].replace_by_if( | ||
eclass_op.results[0], lambda u: not isinstance(u.operation, eqsat.EClassOp) | ||
) | ||
|
||
# Insert eqsat.eclass for each arg | ||
for arg in block.args: | ||
eclass_op = eqsat.EClassOp(arg) | ||
insertion_point = InsertPoint.at_start(block) | ||
Rewriter.insert_op(eclass_op, insertion_point) | ||
arg.replace_by_if( | ||
eclass_op.results[0], lambda u: not isinstance(u.operation, eqsat.EClassOp) | ||
) | ||
|
||
|
||
class InsertEclassOps(RewritePattern): | ||
""" | ||
Inserts a `eqsat.eclass` after each operation except module op and function op. | ||
""" | ||
|
||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): | ||
insert_eclass_ops(op.body.block) | ||
|
||
|
||
class EqsatCreateEclasses(ModulePass): | ||
""" | ||
Create initial eclasses from an MLIR program. | ||
Input example: | ||
```mlir | ||
func.func @test(%a : index, %b : index) -> (index) { | ||
%c = arith.addi %a, %b : index | ||
func.return %c : index | ||
} | ||
``` | ||
Output example: | ||
```mlir | ||
func.func @test(%a : index, %b : index) -> (index) { | ||
%a_eq = eqsat.eclass %a : index | ||
%b_eq = eqsat.eclass %b : index | ||
%c = arith.addi %a_eq, %b_eq : index | ||
%c_eq = eqsat.eclass %c : index | ||
func.return %c_eq : index | ||
} | ||
``` | ||
""" | ||
|
||
name = "eqsat-create-eclasses" | ||
|
||
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: | ||
PatternRewriteWalker( | ||
GreedyRewritePatternApplier([InsertEclassOps()]), | ||
apply_recursively=False, | ||
).rewrite_module(op) |