From 2c09db86606789eea2814860abcf55d04663d773 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Mon, 9 Dec 2024 11:45:42 +0000 Subject: [PATCH] interactive: small refactor (#3601) Slight refactor to `get_all_possible_rewrites` to not be quite so eager in walking the module lots of times. --- tests/interactive/test_rewrites.py | 4 ++-- xdsl/interactive/rewrites.py | 35 +++++++++++------------------- 2 files changed, 15 insertions(+), 24 deletions(-) diff --git a/tests/interactive/test_rewrites.py b/tests/interactive/test_rewrites.py index b9010b093f..04a650f80a 100644 --- a/tests/interactive/test_rewrites.py +++ b/tests/interactive/test_rewrites.py @@ -45,7 +45,7 @@ def test_get_all_possible_rewrite(): parser = Parser(ctx, prog) module = parser.parse_module() - expected_res = ( + expected_res = [ ( IndexedIndividualRewrite( 1, IndividualRewrite(operation="test.op", pattern="TestRewrite") @@ -57,7 +57,7 @@ def test_get_all_possible_rewrite(): rewrite=IndividualRewrite(operation="test.op", pattern="TestRewrite"), ) ), - ) + ] res = get_all_possible_rewrites(module, {"test.op": {"TestRewrite": Rewrite()}}) assert res == expected_res diff --git a/xdsl/interactive/rewrites.py b/xdsl/interactive/rewrites.py index 421df1eace..985cdbcb18 100644 --- a/xdsl/interactive/rewrites.py +++ b/xdsl/interactive/rewrites.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import NamedTuple from xdsl.dialects.builtin import ModuleOp @@ -26,7 +27,7 @@ class IndexedIndividualRewrite(NamedTuple): def convert_indexed_individual_rewrites_to_available_pass( - rewrites: tuple[IndexedIndividualRewrite, ...], current_module: ModuleOp + rewrites: Sequence[IndexedIndividualRewrite], current_module: ModuleOp ) -> tuple[AvailablePass, ...]: """ Function that takes a tuple of rewrites, converts each rewrite into an IndividualRewrite pass and returns the tuple of AvailablePass. @@ -51,40 +52,30 @@ def convert_indexed_individual_rewrites_to_available_pass( def get_all_possible_rewrites( - op: ModuleOp, + module: ModuleOp, rewrite_by_name: dict[str, dict[str, RewritePattern]], -) -> tuple[IndexedIndividualRewrite, ...]: +) -> Sequence[IndexedIndividualRewrite]: """ Function that takes a sequence of IndividualRewrite Patterns and a ModuleOp, and returns the possible rewrites. Issue filed: https://github.com/xdslproject/xdsl/issues/2162 """ - old_module = op.clone() - num_ops = len(list(old_module.walk())) - current_module = old_module.clone() + res: list[IndexedIndividualRewrite] = [] - res: tuple[IndexedIndividualRewrite, ...] = () - - for op_idx in range(num_ops): - matched_op = list(current_module.walk())[op_idx] + for op_idx, matched_op in enumerate(module.walk()): if matched_op.name not in rewrite_by_name: continue pattern_by_name = rewrite_by_name[matched_op.name] - for pattern_name, pattern in pattern_by_name.items(): - rewriter = PatternRewriter(matched_op) - pattern.match_and_rewrite(matched_op, rewriter) + cloned_op = tuple(module.clone().walk())[op_idx] + rewriter = PatternRewriter(cloned_op) + pattern.match_and_rewrite(cloned_op, rewriter) if rewriter.has_done_action: - res = ( - *res, - ( - IndexedIndividualRewrite( - op_idx, IndividualRewrite(matched_op.name, pattern_name) - ) - ), + res.append( + IndexedIndividualRewrite( + op_idx, IndividualRewrite(cloned_op.name, pattern_name) + ) ) - current_module = old_module.clone() - matched_op = list(current_module.walk())[op_idx] return res