Skip to content

Commit

Permalink
interactive: small refactor (#3601)
Browse files Browse the repository at this point in the history
Slight refactor to `get_all_possible_rewrites` to not be quite so eager
in walking the module lots of times.
  • Loading branch information
alexarice authored Dec 9, 2024
1 parent 00affad commit 2c09db8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 24 deletions.
4 changes: 2 additions & 2 deletions tests/interactive/test_rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_get_all_possible_rewrite():
parser = Parser(ctx, prog)
module = parser.parse_module()

expected_res = (
expected_res = [
(
IndexedIndividualRewrite(
1, IndividualRewrite(operation="test.op", pattern="TestRewrite")
Expand All @@ -57,7 +57,7 @@ def test_get_all_possible_rewrite():
rewrite=IndividualRewrite(operation="test.op", pattern="TestRewrite"),
)
),
)
]

res = get_all_possible_rewrites(module, {"test.op": {"TestRewrite": Rewrite()}})
assert res == expected_res
Expand Down
35 changes: 13 additions & 22 deletions xdsl/interactive/rewrites.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import NamedTuple

from xdsl.dialects.builtin import ModuleOp
Expand Down Expand Up @@ -26,7 +27,7 @@ class IndexedIndividualRewrite(NamedTuple):


def convert_indexed_individual_rewrites_to_available_pass(
rewrites: tuple[IndexedIndividualRewrite, ...], current_module: ModuleOp
rewrites: Sequence[IndexedIndividualRewrite], current_module: ModuleOp
) -> tuple[AvailablePass, ...]:
"""
Function that takes a tuple of rewrites, converts each rewrite into an IndividualRewrite pass and returns the tuple of AvailablePass.
Expand All @@ -51,40 +52,30 @@ def convert_indexed_individual_rewrites_to_available_pass(


def get_all_possible_rewrites(
op: ModuleOp,
module: ModuleOp,
rewrite_by_name: dict[str, dict[str, RewritePattern]],
) -> tuple[IndexedIndividualRewrite, ...]:
) -> Sequence[IndexedIndividualRewrite]:
"""
Function that takes a sequence of IndividualRewrite Patterns and a ModuleOp, and
returns the possible rewrites.
Issue filed: https://github.com/xdslproject/xdsl/issues/2162
"""
old_module = op.clone()
num_ops = len(list(old_module.walk()))

current_module = old_module.clone()
res: list[IndexedIndividualRewrite] = []

res: tuple[IndexedIndividualRewrite, ...] = ()

for op_idx in range(num_ops):
matched_op = list(current_module.walk())[op_idx]
for op_idx, matched_op in enumerate(module.walk()):
if matched_op.name not in rewrite_by_name:
continue
pattern_by_name = rewrite_by_name[matched_op.name]

for pattern_name, pattern in pattern_by_name.items():
rewriter = PatternRewriter(matched_op)
pattern.match_and_rewrite(matched_op, rewriter)
cloned_op = tuple(module.clone().walk())[op_idx]
rewriter = PatternRewriter(cloned_op)
pattern.match_and_rewrite(cloned_op, rewriter)
if rewriter.has_done_action:
res = (
*res,
(
IndexedIndividualRewrite(
op_idx, IndividualRewrite(matched_op.name, pattern_name)
)
),
res.append(
IndexedIndividualRewrite(
op_idx, IndividualRewrite(cloned_op.name, pattern_name)
)
)
current_module = old_module.clone()
matched_op = list(current_module.walk())[op_idx]

return res

0 comments on commit 2c09db8

Please sign in to comment.