Skip to content

Commit

Permalink
Merge branch 'main' into sasha/interactive/better-message
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh authored Dec 9, 2024
2 parents 9516cdb + 2c09db8 commit 4b4d4da
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 27 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dev = [
"nbval<0.12",
"filecheck==1.0.1",
"lit<19.0.0",
"marimo==0.9.31",
"marimo==0.9.32",
"pre-commit==4.0.1",
"ruff==0.8.2",
"asv<0.7",
Expand All @@ -34,8 +34,8 @@ dev = [
"pyright==1.1.390",
]
gui = ["textual==0.89.1", "pyclip==0.7"]
jax = ["jax==0.4.36", "numpy==2.1.3"]
onnx = ["onnx==1.17.0", "numpy==2.1.3"]
jax = ["jax==0.4.36", "numpy==2.2.0"]
onnx = ["onnx==1.17.0", "numpy==2.2.0"]
riscv = ["riscemu==2.2.7"]
wgpu = ["wgpu==0.19.2"]

Expand Down
4 changes: 2 additions & 2 deletions tests/interactive/test_rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_get_all_possible_rewrite():
parser = Parser(ctx, prog)
module = parser.parse_module()

expected_res = (
expected_res = [
(
IndexedIndividualRewrite(
1, IndividualRewrite(operation="test.op", pattern="TestRewrite")
Expand All @@ -57,7 +57,7 @@ def test_get_all_possible_rewrite():
rewrite=IndividualRewrite(operation="test.op", pattern="TestRewrite"),
)
),
)
]

res = get_all_possible_rewrites(module, {"test.op": {"TestRewrite": Rewrite()}})
assert res == expected_res
Expand Down
35 changes: 13 additions & 22 deletions xdsl/interactive/rewrites.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import NamedTuple

from xdsl.dialects.builtin import ModuleOp
Expand Down Expand Up @@ -26,7 +27,7 @@ class IndexedIndividualRewrite(NamedTuple):


def convert_indexed_individual_rewrites_to_available_pass(
rewrites: tuple[IndexedIndividualRewrite, ...], current_module: ModuleOp
rewrites: Sequence[IndexedIndividualRewrite], current_module: ModuleOp
) -> tuple[AvailablePass, ...]:
"""
Function that takes a tuple of rewrites, converts each rewrite into an IndividualRewrite pass and returns the tuple of AvailablePass.
Expand All @@ -51,40 +52,30 @@ def convert_indexed_individual_rewrites_to_available_pass(


def get_all_possible_rewrites(
op: ModuleOp,
module: ModuleOp,
rewrite_by_name: dict[str, dict[str, RewritePattern]],
) -> tuple[IndexedIndividualRewrite, ...]:
) -> Sequence[IndexedIndividualRewrite]:
"""
Function that takes a sequence of IndividualRewrite Patterns and a ModuleOp, and
returns the possible rewrites.
Issue filed: https://github.com/xdslproject/xdsl/issues/2162
"""
old_module = op.clone()
num_ops = len(list(old_module.walk()))

current_module = old_module.clone()
res: list[IndexedIndividualRewrite] = []

res: tuple[IndexedIndividualRewrite, ...] = ()

for op_idx in range(num_ops):
matched_op = list(current_module.walk())[op_idx]
for op_idx, matched_op in enumerate(module.walk()):
if matched_op.name not in rewrite_by_name:
continue
pattern_by_name = rewrite_by_name[matched_op.name]

for pattern_name, pattern in pattern_by_name.items():
rewriter = PatternRewriter(matched_op)
pattern.match_and_rewrite(matched_op, rewriter)
cloned_op = tuple(module.clone().walk())[op_idx]
rewriter = PatternRewriter(cloned_op)
pattern.match_and_rewrite(cloned_op, rewriter)
if rewriter.has_done_action:
res = (
*res,
(
IndexedIndividualRewrite(
op_idx, IndividualRewrite(matched_op.name, pattern_name)
)
),
res.append(
IndexedIndividualRewrite(
op_idx, IndividualRewrite(cloned_op.name, pattern_name)
)
)
current_module = old_module.clone()
matched_op = list(current_module.walk())[op_idx]

return res

0 comments on commit 4b4d4da

Please sign in to comment.