diff --git a/pyproject.toml b/pyproject.toml index 109ac756c6..74acf0134c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dev = [ "nbval<0.12", "filecheck==1.0.1", "lit<19.0.0", - "marimo==0.9.31", + "marimo==0.9.32", "pre-commit==4.0.1", "ruff==0.8.2", "asv<0.7", @@ -34,8 +34,8 @@ dev = [ "pyright==1.1.390", ] gui = ["textual==0.89.1", "pyclip==0.7"] -jax = ["jax==0.4.36", "numpy==2.1.3"] -onnx = ["onnx==1.17.0", "numpy==2.1.3"] +jax = ["jax==0.4.36", "numpy==2.2.0"] +onnx = ["onnx==1.17.0", "numpy==2.2.0"] riscv = ["riscemu==2.2.7"] wgpu = ["wgpu==0.19.2"] diff --git a/tests/interactive/test_rewrites.py b/tests/interactive/test_rewrites.py index b9010b093f..04a650f80a 100644 --- a/tests/interactive/test_rewrites.py +++ b/tests/interactive/test_rewrites.py @@ -45,7 +45,7 @@ def test_get_all_possible_rewrite(): parser = Parser(ctx, prog) module = parser.parse_module() - expected_res = ( + expected_res = [ ( IndexedIndividualRewrite( 1, IndividualRewrite(operation="test.op", pattern="TestRewrite") @@ -57,7 +57,7 @@ def test_get_all_possible_rewrite(): rewrite=IndividualRewrite(operation="test.op", pattern="TestRewrite"), ) ), - ) + ] res = get_all_possible_rewrites(module, {"test.op": {"TestRewrite": Rewrite()}}) assert res == expected_res diff --git a/xdsl/interactive/rewrites.py b/xdsl/interactive/rewrites.py index 421df1eace..985cdbcb18 100644 --- a/xdsl/interactive/rewrites.py +++ b/xdsl/interactive/rewrites.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import NamedTuple from xdsl.dialects.builtin import ModuleOp @@ -26,7 +27,7 @@ class IndexedIndividualRewrite(NamedTuple): def convert_indexed_individual_rewrites_to_available_pass( - rewrites: tuple[IndexedIndividualRewrite, ...], current_module: ModuleOp + rewrites: Sequence[IndexedIndividualRewrite], current_module: ModuleOp ) -> tuple[AvailablePass, ...]: """ Function that takes a tuple of rewrites, converts each rewrite into an IndividualRewrite pass and returns the tuple of AvailablePass. @@ -51,40 +52,30 @@ def convert_indexed_individual_rewrites_to_available_pass( def get_all_possible_rewrites( - op: ModuleOp, + module: ModuleOp, rewrite_by_name: dict[str, dict[str, RewritePattern]], -) -> tuple[IndexedIndividualRewrite, ...]: +) -> Sequence[IndexedIndividualRewrite]: """ Function that takes a sequence of IndividualRewrite Patterns and a ModuleOp, and returns the possible rewrites. Issue filed: https://github.com/xdslproject/xdsl/issues/2162 """ - old_module = op.clone() - num_ops = len(list(old_module.walk())) - current_module = old_module.clone() + res: list[IndexedIndividualRewrite] = [] - res: tuple[IndexedIndividualRewrite, ...] = () - - for op_idx in range(num_ops): - matched_op = list(current_module.walk())[op_idx] + for op_idx, matched_op in enumerate(module.walk()): if matched_op.name not in rewrite_by_name: continue pattern_by_name = rewrite_by_name[matched_op.name] - for pattern_name, pattern in pattern_by_name.items(): - rewriter = PatternRewriter(matched_op) - pattern.match_and_rewrite(matched_op, rewriter) + cloned_op = tuple(module.clone().walk())[op_idx] + rewriter = PatternRewriter(cloned_op) + pattern.match_and_rewrite(cloned_op, rewriter) if rewriter.has_done_action: - res = ( - *res, - ( - IndexedIndividualRewrite( - op_idx, IndividualRewrite(matched_op.name, pattern_name) - ) - ), + res.append( + IndexedIndividualRewrite( + op_idx, IndividualRewrite(cloned_op.name, pattern_name) + ) ) - current_module = old_module.clone() - matched_op = list(current_module.walk())[op_idx] return res