From e6a0d1dea286c7b9e95303183061620c70414b19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20T=C3=B3th?= Date: Tue, 3 Dec 2024 09:29:11 +0100 Subject: [PATCH] Refactor rule parsing (#4699) * Remove unnecessary cases * Make pattern matching logic more concise * Process three cases of simplification rules: `AppRule`, `CeilRule`, `EqualsRule` --- pyk/src/pyk/kore/rule.py | 270 +++++++++++--------- pyk/src/tests/integration/kore/test_rule.py | 6 +- 2 files changed, 160 insertions(+), 116 deletions(-) diff --git a/pyk/src/pyk/kore/rule.py b/pyk/src/pyk/kore/rule.py index 680d91082f..8bc7fb965e 100644 --- a/pyk/src/pyk/kore/rule.py +++ b/pyk/src/pyk/kore/rule.py @@ -1,25 +1,44 @@ -"""Parse KORE axioms into rewrite rules. - -Based on the [LLVM Backend's implementation](https://github.com/runtimeverification/llvm-backend/blob/d5eab4b0f0e610bc60843ebb482f79c043b92702/lib/ast/pattern_matching.cpp). -""" - from __future__ import annotations +import logging from abc import ABC from dataclasses import dataclass -from typing import TYPE_CHECKING, final +from typing import TYPE_CHECKING, Generic, TypeVar, final from .prelude import inj -from .syntax import And, App, Axiom, Ceil, Equals, EVar, Implies, In, Not, Rewrites, SortVar, String, Top +from .syntax import ( + DV, + And, + App, + Axiom, + Ceil, + Equals, + EVar, + Implies, + In, + Not, + Pattern, + Rewrites, + SortApp, + SortVar, + String, + Top, +) if TYPE_CHECKING: from typing import Final - from .syntax import Definition, Pattern + from .syntax import Definition Attrs = dict[str, tuple[Pattern, ...]] +P = TypeVar('P', bound=Pattern) + + +_LOGGER: Final = logging.getLogger(__name__) + + # There's a simplification rule with irregular form in the prelude module INJ. # This rule is skipped in Rule.extract_all. _S1, _S2, _S3, _R = (SortVar(name) for name in ['S1', 'S2', 'S3', 'R']) @@ -56,33 +75,35 @@ def from_axiom(axiom: Axiom) -> Rule: if isinstance(axiom.pattern, Rewrites): return RewriteRule.from_axiom(axiom) - if 'simplification' in axiom.attrs_by_key: - return SimpliRule.from_axiom(axiom) + if 'simplification' not in axiom.attrs_by_key: + return FunctionRule.from_axiom(axiom) - return FunctionRule.from_axiom(axiom) + match axiom.pattern: + case Implies(right=Equals(left=App())): + return AppRule.from_axiom(axiom) + case Implies(right=Equals(left=Ceil())): + return CeilRule.from_axiom(axiom) + case Implies(right=Equals(left=Equals())): + return EqualsRule.from_axiom(axiom) + case _: + raise ValueError(f'Cannot parse simplification rule: {axiom.text}') @staticmethod def extract_all(defn: Definition) -> list[Rule]: - return [Rule.from_axiom(axiom) for axiom in defn.axioms if Rule._is_rule(axiom)] - - @staticmethod - def _is_rule(axiom: Axiom) -> bool: - if axiom == _INJ_AXIOM: - return False - - if any(attr in axiom.attrs_by_key for attr in _SKIPPED_ATTRS): - return False + def is_rule(axiom: Axiom) -> bool: + if axiom == _INJ_AXIOM: + return False - match axiom.pattern: - case Implies(right=Equals(left=Ceil())): - # Ceil rule + if any(attr in axiom.attrs_by_key for attr in _SKIPPED_ATTRS): return False - return True + return True + + return [Rule.from_axiom(axiom) for axiom in defn.axioms if is_rule(axiom)] @final -@dataclass +@dataclass(frozen=True) class RewriteRule(Rule): lhs: App rhs: App @@ -95,8 +116,7 @@ class RewriteRule(Rule): @staticmethod def from_axiom(axiom: Axiom) -> RewriteRule: - lhs, req, ctx = RewriteRule._extract_lhs(axiom) - rhs, ens = RewriteRule._extract_rhs(axiom) + lhs, rhs, req, ens, ctx = RewriteRule._extract(axiom) priority = _extract_priority(axiom) uid = _extract_uid(axiom) label = _extract_label(axiom) @@ -112,51 +132,35 @@ def from_axiom(axiom: Axiom) -> RewriteRule: ) @staticmethod - def _extract_lhs(axiom: Axiom) -> tuple[App, Pattern | None, EVar | None]: - req: Pattern | None = None - # Cases 0-5 of get_left_hand_side - # Cases 5-10 of get_requires + def _extract(axiom: Axiom) -> tuple[App, App, Pattern | None, Pattern | None, EVar | None]: match axiom.pattern: - case Rewrites(left=And(ops=(Top(), lhs))): - pass - case Rewrites(left=And(ops=(Equals(left=req), lhs))): - pass - case Rewrites(left=And(ops=(lhs, Top()))): - pass - case Rewrites(left=And(ops=(lhs, Equals(left=req)))): - pass - case Rewrites(left=And(ops=(Not(), And(ops=(Top(), lhs))))): - pass - case Rewrites(left=And(ops=(Not(), And(ops=(Equals(left=req), lhs))))): + case Rewrites(left=And(ops=(_lhs, _req)), right=_rhs): pass case _: - raise ValueError(f'Cannot extract LHS from axiom: {axiom.text}') + raise ValueError(f'Cannot extract rewrite rule from axiom: {axiom.text}') ctx: EVar | None = None - match lhs: - case App("Lbl'-LT-'generatedTop'-GT-'") as app: + match _lhs: + case App("Lbl'-LT-'generatedTop'-GT-'") as lhs: pass - case And(_, (App("Lbl'-LT-'generatedTop'-GT-'") as app, EVar("Var'Hash'Configuration") as ctx)): + case And(_, (App("Lbl'-LT-'generatedTop'-GT-'") as lhs, EVar("Var'Hash'Configuration") as ctx)): pass + case _: + raise ValueError(f'Cannot extract LHS configuration from axiom: {axiom.text}') - return app, req, ctx - - @staticmethod - def _extract_rhs(axiom: Axiom) -> tuple[App, Pattern | None]: - # Case 2 of get_right_hand_side: - # 2: rhs(\rewrites(_, \and(X, Y))) = get_builtin(\and(X, Y)) - # Below is a special case without get_builtin - match axiom.pattern: - case Rewrites(right=And(ops=(App("Lbl'-LT-'generatedTop'-GT-'") as rhs, Top() | Equals() as _ens))): + req = _extract_condition(_req) + rhs, ens = _extract_rhs(_rhs) + match rhs: + case App("Lbl'-LT-'generatedTop'-GT-'"): pass case _: - raise ValueError(f'Cannot extract RHS from axiom: {axiom.text}') - ens = _extract_ensures(_ens) - return rhs, ens + raise ValueError(f'Cannot extract RHS configuration from axiom: {axiom.text}') + + return lhs, rhs, req, ens, ctx @final -@dataclass +@dataclass(frozen=True) class FunctionRule(Rule): lhs: App rhs: Pattern @@ -166,9 +170,7 @@ class FunctionRule(Rule): @staticmethod def from_axiom(axiom: Axiom) -> FunctionRule: - args, req = FunctionRule._extract_args(axiom) - app, rhs, ens = FunctionRule._extract_rhs(axiom) - lhs = app.let(args=args) + lhs, rhs, req, ens = FunctionRule._extract(axiom) priority = _extract_priority(axiom) return FunctionRule( lhs=lhs, @@ -179,60 +181,85 @@ def from_axiom(axiom: Axiom) -> FunctionRule: ) @staticmethod - def _extract_args(axiom: Axiom) -> tuple[tuple[Pattern, ...], Pattern | None]: - req: Pattern | None = None - # Cases 7-10 of get_left_hand_side - # Cases 0-3 of get_requires + def _extract(axiom: Axiom) -> tuple[App, Pattern, Pattern | None, Pattern | None]: match axiom.pattern: - case Implies(left=And(ops=(Top(), pat))): - return FunctionRule._get_patterns(pat), req - case Implies(left=And(ops=(Equals(left=req), pat))): - return FunctionRule._get_patterns(pat), req - case Implies(left=And(ops=(Not(), And(ops=(Top(), pat))))): - return FunctionRule._get_patterns(pat), req - case Implies(left=And(ops=(Not(), And(ops=(Equals(left=req), pat))))): - return FunctionRule._get_patterns(pat), req + case Implies( + left=And( + ops=(Not(), And(ops=(_req, _args))) | (_req, _args), + ), + right=Equals(left=App() as app, right=_rhs), + ): + args = FunctionRule._extract_args(_args) + lhs = app.let(args=args) + req = _extract_condition(_req) + rhs, ens = _extract_rhs(_rhs) + return lhs, rhs, req, ens case _: - raise ValueError(f'Cannot extract LHS from axiom: {axiom.text}') + raise ValueError(f'Cannot extract function rule from axiom: {axiom.text}') @staticmethod - def _get_patterns(pattern: Pattern) -> tuple[Pattern, ...]: - # get_patterns(\top()) = [] - # get_patterns(\and(\in(_, X), Y)) = X : get_patterns(Y) + def _extract_args(pattern: Pattern) -> tuple[Pattern, ...]: match pattern: case Top(): return () - case And(ops=(In(right=x), y)): - return (x,) + FunctionRule._get_patterns(y) + case And(ops=(In(left=EVar(), right=arg), rest)): + return (arg,) + FunctionRule._extract_args(rest) case _: - raise AssertionError() + raise ValueError(f'Cannot extract argument list from pattern: {pattern.text}') + + +class SimpliRule(Rule, Generic[P], ABC): + lhs: P @staticmethod - def _extract_rhs(axiom: Axiom) -> tuple[App, Pattern, Pattern | None]: - # Case 0 of get_right_hand_side + def _extract(axiom: Axiom, lhs_type: type[P]) -> tuple[P, Pattern, Pattern | None, Pattern | None]: match axiom.pattern: - case Implies(right=Equals(left=App() as app, right=And(ops=(rhs, Top() | Equals() as _ens)))): - pass + case Implies(left=_req, right=Equals(left=lhs, right=_rhs)): + req = _extract_condition(_req) + rhs, ens = _extract_rhs(_rhs) + if not isinstance(lhs, lhs_type): + raise ValueError(f'Invalid LHS type from simplification axiom: {axiom.text}') + return lhs, rhs, req, ens case _: - raise ValueError(f'Cannot extract RHS from axiom: {axiom.text}') - ens = _extract_ensures(_ens) - return app, rhs, ens + raise ValueError(f'Cannot extract simplification rule from axiom: {axiom.text}') @final -@dataclass -class SimpliRule(Rule): - lhs: Pattern +@dataclass(frozen=True) +class AppRule(SimpliRule[App]): + lhs: App + rhs: Pattern + req: Pattern | None + ens: Pattern | None + priority: int + + @staticmethod + def from_axiom(axiom: Axiom) -> AppRule: + lhs, rhs, req, ens = SimpliRule._extract(axiom, App) + priority = _extract_priority(axiom) + return AppRule( + lhs=lhs, + rhs=rhs, + req=req, + ens=ens, + priority=priority, + ) + + +@final +@dataclass(frozen=True) +class CeilRule(SimpliRule): + lhs: Ceil rhs: Pattern req: Pattern | None ens: Pattern | None priority: int @staticmethod - def from_axiom(axiom: Axiom) -> SimpliRule: - lhs, rhs, req, ens = SimpliRule._extract(axiom) + def from_axiom(axiom: Axiom) -> CeilRule: + lhs, rhs, req, ens = SimpliRule._extract(axiom, Ceil) priority = _extract_priority(axiom) - return SimpliRule( + return CeilRule( lhs=lhs, rhs=rhs, req=req, @@ -240,32 +267,47 @@ def from_axiom(axiom: Axiom) -> SimpliRule: priority=priority, ) + +@final +@dataclass(frozen=True) +class EqualsRule(SimpliRule): + lhs: Equals + rhs: Pattern + req: Pattern | None + ens: Pattern | None + priority: int + @staticmethod - def _extract(axiom: Axiom) -> tuple[Pattern, Pattern, Pattern | None, Pattern | None]: - req: Pattern | None = None - # Cases 11-12 of get_left_hand_side - # Case 0 of get_right_hand_side - match axiom.pattern: - case Implies(left=Top(), right=Equals(left=lhs, right=And(ops=(rhs, Top() | Equals() as _ens)))): - pass - case Implies(left=Equals(left=req), right=Equals(left=lhs, right=And(ops=(rhs, Top() | Equals() as _ens)))): - pass - case Implies(right=Equals(left=Ceil())): - raise ValueError(f'Axiom is a ceil rule: {axiom.text}') - case _: - raise ValueError(f'Cannot extract simplification rule from axiom: {axiom.text}') - ens = _extract_ensures(_ens) - return lhs, rhs, req, ens + def from_axiom(axiom: Axiom) -> EqualsRule: + lhs, rhs, req, ens = SimpliRule._extract(axiom, Equals) + if not isinstance(lhs, Equals): + raise ValueError(f'Cannot extract LHS as Equals from axiom: {axiom.text}') + priority = _extract_priority(axiom) + return EqualsRule( + lhs=lhs, + rhs=rhs, + req=req, + ens=ens, + priority=priority, + ) + + +def _extract_rhs(pattern: Pattern) -> tuple[Pattern, Pattern | None]: + match pattern: + case And(ops=(rhs, _ens)): + return rhs, _extract_condition(_ens) + case _: + raise ValueError(f'Cannot extract RHS from pattern: {pattern.text}') -def _extract_ensures(ens: Top | Equals | None) -> Pattern | None: - match ens: +def _extract_condition(pattern: Pattern) -> Pattern | None: + match pattern: case Top(): return None - case Equals(left=res): - return res + case Equals(left=cond, right=DV(SortApp('SortBool'), String('true'))): + return cond case _: - raise AssertionError() + raise ValueError(f'Cannot extract condition from pattern: {pattern.text}') def _extract_uid(axiom: Axiom) -> str: diff --git a/pyk/src/tests/integration/kore/test_rule.py b/pyk/src/tests/integration/kore/test_rule.py index 5f723146d8..fadb530548 100644 --- a/pyk/src/tests/integration/kore/test_rule.py +++ b/pyk/src/tests/integration/kore/test_rule.py @@ -18,7 +18,7 @@ @pytest.fixture(scope='module') def definition(kompile: Kompiler) -> Definition: main_file = K_FILES / 'imp.k' - definition_dir = kompile(main_file=main_file) + definition_dir = kompile(main_file=main_file, backend='haskell') kore_file = definition_dir / 'definition.kore' kore_text = kore_file.read_text() definition = KoreParser(kore_text).definition() @@ -33,4 +33,6 @@ def test_extract_all(definition: Definition) -> None: cnt = Counter(type(rule).__name__ for rule in rules) assert cnt['RewriteRule'] assert cnt['FunctionRule'] - assert cnt['SimpliRule'] + assert cnt['AppRule'] + assert cnt['CeilRule'] + assert cnt['EqualsRule']