diff --git a/pyk/.gitignore b/pyk/.gitignore index e8fd5033e0..67864bdfbf 100644 --- a/pyk/.gitignore +++ b/pyk/.gitignore @@ -12,4 +12,5 @@ __pycache__/ *.debug-log .idea/ +.vscode/ .DS_Store \ No newline at end of file diff --git a/pyk/Makefile b/pyk/Makefile index f41d7c33b8..684ce7801f 100644 --- a/pyk/Makefile +++ b/pyk/Makefile @@ -105,6 +105,9 @@ pyupgrade: poetry-install $(POETRY_RUN) pyupgrade --py310-plus $(SRC_FILES) +pr: format check pyupgrade + git rebase -i develop + # Documentation DOCS_API_DIR := docs/api diff --git a/pyk/src/pyk/cterm/cterm.py b/pyk/src/pyk/cterm/cterm.py index 0c04e0ff20..9ff3bc9737 100644 --- a/pyk/src/pyk/cterm/cterm.py +++ b/pyk/src/pyk/cterm/cterm.py @@ -24,7 +24,7 @@ from ..prelude.k import GENERATED_TOP_CELL, K from ..prelude.kbool import andBool, orBool from ..prelude.ml import is_bottom, is_top, mlAnd, mlBottom, mlEquals, mlEqualsTrue, mlImplies, mlTop -from ..utils import unique +from ..utils import not_none, unique if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -203,6 +203,8 @@ def anti_unify( - ``csubst2``: Constrained substitution to apply to `cterm` to obtain `other`. """ new_config, self_subst, other_subst = anti_unify(self.config, other.config, kdef=kdef) + # todo: It's not able to distinguish between constraints in different cterms, + # because variable names may be used inconsistently in different cterms. common_constraints = [constraint for constraint in self.constraints if constraint in other.constraints] self_unique_constraints = [ ml_pred_to_bool(constraint) for constraint in self.constraints if constraint not in other.constraints @@ -434,3 +436,30 @@ def cterm_build_rule( keep_vars, defunc_with=defunc_with, ) + + +def cterms_anti_unify( + cterms: Iterable[CTerm], keep_values: bool = False, kdef: KDefinition | None = None +) -> tuple[CTerm, list[CSubst]]: + """Given many `CTerm` instances, find a more general `CTerm` which can instantiate to all. + + Args: + cterms: `CTerm`s to consider for finding a more general `CTerm` with this one. + keep_values: do not discard information about abstracted variables in returned result. + kdef (optional): `KDefinition` to make analysis more precise. + + Returns: + A tuple ``(cterm, csubsts)`` where + + - ``cterm``: More general `CTerm` than any of the input `CTerm`s. + - ``csubsts``: List of `CSubst` which, when applied to `cterm`, yield the input `CTerm`s. + """ + # TODO: optimize this function, reduce useless auto-generated variables. + cterms = list(cterms) + if not cterms: + raise ValueError('Anti-unification failed, no CTerms provided') + merged_cterm = cterms[0] + for cterm in cterms[1:]: + merged_cterm = merged_cterm.anti_unify(cterm, keep_values, kdef)[0] + csubsts = [not_none(cterm_match(merged_cterm, cterm)) for cterm in cterms] + return merged_cterm, csubsts diff --git a/pyk/src/pyk/kcfg/kcfg.py b/pyk/src/pyk/kcfg/kcfg.py index 5937f02cee..5289bf210a 100644 --- a/pyk/src/pyk/kcfg/kcfg.py +++ b/pyk/src/pyk/kcfg/kcfg.py @@ -942,12 +942,20 @@ def contains_merged_edge(self, edge: MergedEdge) -> bool: return edge == other return False - def create_merged_edge(self, source_id: NodeIdLike, target_id: NodeIdLike, edges: Iterable[Edge]) -> MergedEdge: + def create_merged_edge( + self, source_id: NodeIdLike, target_id: NodeIdLike, edges: Iterable[Edge | MergedEdge] + ) -> MergedEdge: if len(list(edges)) == 0: raise ValueError(f'Cannot build KCFG MergedEdge with no edges: {edges}') source = self.node(source_id) target = self.node(target_id) - merged_edge = KCFG.MergedEdge(source, target, tuple(edges)) + flatten_edges: list[KCFG.Edge] = [] + for edge in edges: + if isinstance(edge, KCFG.MergedEdge): + flatten_edges.extend(edge.edges) + else: + flatten_edges.append(edge) + merged_edge = KCFG.MergedEdge(source, target, tuple(flatten_edges)) self.add_successor(merged_edge) return merged_edge @@ -959,6 +967,13 @@ def remove_merged_edge(self, source_id: NodeIdLike, target_id: NodeIdLike) -> No raise ValueError(f'MergedEdge does not exist: {source_id} -> {target_id}') self._merged_edges.pop(source_id) + def general_edges( + self, *, source_id: NodeIdLike | None = None, target_id: NodeIdLike | None = None + ) -> list[Edge | MergedEdge]: + return self.edges(source_id=source_id, target_id=target_id) + self.merged_edges( + source_id=source_id, target_id=target_id + ) + def cover(self, source_id: NodeIdLike, target_id: NodeIdLike) -> Cover | None: source_id = self._resolve(source_id) target_id = self._resolve(target_id) diff --git a/pyk/src/pyk/kcfg/minimize.py b/pyk/src/pyk/kcfg/minimize.py index 32945a7269..5bd91a802a 100644 --- a/pyk/src/pyk/kcfg/minimize.py +++ b/pyk/src/pyk/kcfg/minimize.py @@ -3,20 +3,32 @@ from functools import reduce from typing import TYPE_CHECKING -from ..cterm import CTerm -from ..utils import not_none, single +from pyk.cterm import CTerm +from pyk.cterm.cterm import cterms_anti_unify +from pyk.utils import not_none, partition, single + +from .semantics import DefaultSemantics if TYPE_CHECKING: from collections.abc import Callable + from pyk.kast.outer import KDefinition + from .kcfg import KCFG, NodeIdLike + from .semantics import KCFGSemantics class KCFGMinimizer: kcfg: KCFG + semantics: KCFGSemantics + kdef: KDefinition | None - def __init__(self, kcfg: KCFG) -> None: + def __init__(self, kcfg: KCFG, heuristics: KCFGSemantics | None = None, kdef: KDefinition | None = None) -> None: + if heuristics is None: + heuristics = DefaultSemantics() self.kcfg = kcfg + self.semantics = heuristics + self.kdef = kdef def lift_edge(self, b_id: NodeIdLike) -> None: """Lift an edge up another edge directly preceding it. @@ -188,6 +200,94 @@ def _fold_lift(result: bool, finder_lifter: tuple[Callable, Callable]) -> bool: _fold_lift, [(self.kcfg.edges, self.lift_split_edge), (self.kcfg.splits, self.lift_split_split)], False ) + def merge_nodes(self) -> bool: + """Merge targets of Split for cutting down the number of branches, using heuristics KCFGSemantics.is_mergeable. + + Side Effect: The KCFG is rewritten by the following rewrite pattern, + - Match: A -|Split|-> A_i -|Edge|-> B_i + - Rewrite: + - if `B_x, B_y, ..., B_z are not mergeable` then unchanged + - if `B_x, B_y, ..., B_z are mergeable`, then + - A -|Split|-> A_x or A_y or ... or A_z + - A_x or A_y or ... or A_z -|Edge|-> B_x or B_y or ... or B_z + - B_x or B_y or ... or B_z -|Split|-> B_x, B_y, ..., B_z + + Specifically, when `B_merge = B_x or B_y or ... or B_z` + - `or`: fresh variables in places where the configurations differ + - `Edge` in A_merged -|Edge|-> B_merge: list of merged edges is from A_i -|Edge|-> B_i + - `Split` in B_merge -|Split|-> B_x, B_y, ..., B_z: subst for it is from A -|Split|-> A_1, A_2, ..., A_n + :param semantics: provides the is_mergeable heuristic + :return: whether any merge was performed + """ + + def _is_mergeable(x: KCFG.Edge | KCFG.MergedEdge, y: KCFG.Edge | KCFG.MergedEdge) -> bool: + return self.semantics.is_mergeable(x.target.cterm, y.target.cterm) + + # ---- Match ---- + + # A -|Split|> Ai -|Edge/MergedEdge|> Mergeable Bi + sub_graphs: list[tuple[KCFG.Split, list[list[KCFG.Edge | KCFG.MergedEdge]]]] = [] + + for split in self.kcfg.splits(): + _edges = [ + single(self.kcfg.general_edges(source_id=ai)) + for ai in split.target_ids + if self.kcfg.general_edges(source_id=ai) + ] + _partitions = partition(_edges, _is_mergeable) + if len(_partitions) < len(_edges): + sub_graphs.append((split, _partitions)) + + if not sub_graphs: + return False + + # ---- Rewrite ---- + + for split, edge_partitions in sub_graphs: + + # Remove the original sub-graphs + for p in edge_partitions: + if len(p) == 1: + continue + for e in p: + # TODO: remove the split and edges, then safely remove the nodes. + self.kcfg.remove_node(e.source.id) + + # Create A -|MergedEdge|-> Merged_Bi -|Split|-> Bi, if one edge partition covers all the splits + if len(edge_partitions) == 1: + merged_bi_cterm, merged_bi_subst = cterms_anti_unify( + [edge.target.cterm for edge in edge_partitions[0]], keep_values=True, kdef=self.kdef + ) + merged_bi = self.kcfg.create_node(merged_bi_cterm) + self.kcfg.create_merged_edge(split.source.id, merged_bi.id, edge_partitions[0]) + self.kcfg.create_split( + merged_bi.id, zip([e.target.id for e in edge_partitions[0]], merged_bi_subst, strict=True) + ) + continue + + # Create A -|Split|-> Others & Merged_Ai -|MergedEdge|-> Merged_Bi -|Split|-> Bi + _split_nodes: list[NodeIdLike] = [] + for edge_partition in edge_partitions: + if len(edge_partition) == 1: + _split_nodes.append(edge_partition[0].source.id) + continue + merged_ai_cterm, _ = cterms_anti_unify( + [ai2bi.source.cterm for ai2bi in edge_partition], keep_values=True, kdef=self.kdef + ) + merged_bi_cterm, merged_bi_subst = cterms_anti_unify( + [ai2bi.target.cterm for ai2bi in edge_partition], keep_values=True, kdef=self.kdef + ) + merged_ai = self.kcfg.create_node(merged_ai_cterm) + _split_nodes.append(merged_ai.id) + merged_bi = self.kcfg.create_node(merged_bi_cterm) + self.kcfg.create_merged_edge(merged_ai.id, merged_bi.id, edge_partition) + self.kcfg.create_split( + merged_bi.id, zip([ai2bi.target.id for ai2bi in edge_partition], merged_bi_subst, strict=True) + ) + self.kcfg.create_split_by_nodes(split.source.id, _split_nodes) + + return True + def minimize(self) -> None: """Minimize KCFG by repeatedly performing the lifting transformations. @@ -198,3 +298,7 @@ def minimize(self) -> None: while repeat: repeat = self.lift_edges() repeat = self.lift_splits() or repeat + + repeat = True + while repeat: + repeat = self.merge_nodes() diff --git a/pyk/src/pyk/kcfg/semantics.py b/pyk/src/pyk/kcfg/semantics.py index fc98409e60..629f259b79 100644 --- a/pyk/src/pyk/kcfg/semantics.py +++ b/pyk/src/pyk/kcfg/semantics.py @@ -34,6 +34,11 @@ def custom_step(self, c: CTerm) -> KCFGExtendResult | None: ... """Implement a custom semantic step.""" + @abstractmethod + def is_mergeable(self, c1: CTerm, c2: CTerm) -> bool: ... + + """Check whether or not the two given ``CTerm``s are mergeable. Must be transitive, commutative, and reflexive.""" + class DefaultSemantics(KCFGSemantics): def is_terminal(self, c: CTerm) -> bool: @@ -50,3 +55,6 @@ def can_make_custom_step(self, c: CTerm) -> bool: def custom_step(self, c: CTerm) -> KCFGExtendResult | None: return None + + def is_mergeable(self, c1: CTerm, c2: CTerm) -> bool: + return False diff --git a/pyk/src/pyk/utils.py b/pyk/src/pyk/utils.py index 243acda628..9221c31ce3 100644 --- a/pyk/src/pyk/utils.py +++ b/pyk/src/pyk/utils.py @@ -319,6 +319,36 @@ def repeat_last(iterable: Iterable[T]) -> Iterator[T]: yield last +def partition(iterable: Iterable[T], pred: Callable[[T, T], bool]) -> list[list[T]]: + """Partition the iterable into sublists based on the given predicate. + + predicate pred(_, _) should satisfy: + - pred(x, x) + - if pred(x, y) and pred(y, z) then pred(x, z); + - if pred(x, y) then pred(y, x); + """ + groups: list[list[T]] = [] + for item in iterable: + found = False + for group in groups: + group_matches = [] + for group_item in group: + group_match = pred(group_item, item) + if group_match != pred(item, group_item): + raise ValueError(f'Partitioning failed, predicate commutativity failed on: {(item, group_item)}') + group_matches.append(group_match) + if found and any(group_matches): + raise ValueError(f'Partitioning failed, item matched multiple groups: {item}') + if all(group_matches): + found = True + group.append(item) + elif any(group_matches): + raise ValueError(f'Partitioning failed, item matched only some elements of group: {(item, group)}') + if not found: + groups.append([item]) + return groups + + def nonempty_str(x: Any) -> str: if x is None: raise ValueError('Expected nonempty string, found: null.') diff --git a/pyk/src/tests/integration/ktool/test_imp.py b/pyk/src/tests/integration/ktool/test_imp.py index bf61b80867..0593db69eb 100644 --- a/pyk/src/tests/integration/ktool/test_imp.py +++ b/pyk/src/tests/integration/ktool/test_imp.py @@ -9,7 +9,7 @@ from pyk.cli.pyk import ProveOptions from pyk.kast.inner import KApply, KSequence, KVariable -from pyk.kcfg.semantics import KCFGSemantics +from pyk.kcfg.semantics import DefaultSemantics from pyk.ktool.prove_rpc import ProveRpc from pyk.proof import ProofStatus from pyk.testing import KCFGExploreTest, KProveTest @@ -25,12 +25,13 @@ from pyk.kast.outer import KDefinition from pyk.kcfg import KCFGExplore from pyk.kcfg.kcfg import KCFGExtendResult + from pyk.kcfg.semantics import KCFGSemantics from pyk.ktool.kprove import KProve _LOGGER: Final = logging.getLogger(__name__) -class ImpSemantics(KCFGSemantics): +class ImpSemantics(DefaultSemantics): definition: KDefinition | None def __init__(self, definition: KDefinition | None = None): diff --git a/pyk/src/tests/integration/proof/test_custom_step.py b/pyk/src/tests/integration/proof/test_custom_step.py index ed73b4b7b8..c8d1f1ad90 100644 --- a/pyk/src/tests/integration/proof/test_custom_step.py +++ b/pyk/src/tests/integration/proof/test_custom_step.py @@ -10,7 +10,7 @@ from pyk.kast.manip import set_cell from pyk.kcfg import KCFGExplore from pyk.kcfg.kcfg import Step -from pyk.kcfg.semantics import KCFGSemantics +from pyk.kcfg.semantics import DefaultSemantics from pyk.kcfg.show import KCFGShow from pyk.proof import APRProof, APRProver, ProofStatus from pyk.proof.show import APRProofNodePrinter @@ -26,6 +26,7 @@ from pyk.cterm import CTermSymbolic from pyk.kast.outer import KClaim from pyk.kcfg.kcfg import KCFGExtendResult + from pyk.kcfg.semantics import KCFGSemantics from pyk.ktool.kprove import KProve from pyk.utils import BugReport @@ -46,7 +47,7 @@ ) -class CustomStepSemanticsWithoutStep(KCFGSemantics): +class CustomStepSemanticsWithoutStep(DefaultSemantics): def is_terminal(self, c: CTerm) -> bool: k_cell = c.cell('K_CELL') if ( diff --git a/pyk/src/tests/integration/proof/test_goto.py b/pyk/src/tests/integration/proof/test_goto.py index 3a8b74aaff..25e7ce55c4 100644 --- a/pyk/src/tests/integration/proof/test_goto.py +++ b/pyk/src/tests/integration/proof/test_goto.py @@ -7,7 +7,7 @@ import pytest from pyk.kast.inner import KApply, KSequence -from pyk.kcfg.semantics import KCFGSemantics +from pyk.kcfg.semantics import DefaultSemantics from pyk.kcfg.show import KCFGShow from pyk.proof import APRProof, APRProver, ProofStatus from pyk.proof.show import APRProofNodePrinter @@ -24,12 +24,13 @@ from pyk.kast.outer import KDefinition from pyk.kcfg import KCFGExplore from pyk.kcfg.kcfg import KCFGExtendResult + from pyk.kcfg.semantics import KCFGSemantics from pyk.ktool.kprove import KProve _LOGGER: Final = logging.getLogger(__name__) -class GotoSemantics(KCFGSemantics): +class GotoSemantics(DefaultSemantics): def is_terminal(self, c: CTerm) -> bool: return False diff --git a/pyk/src/tests/integration/proof/test_imp.py b/pyk/src/tests/integration/proof/test_imp.py index da4a808200..2b63e3b069 100644 --- a/pyk/src/tests/integration/proof/test_imp.py +++ b/pyk/src/tests/integration/proof/test_imp.py @@ -10,7 +10,7 @@ from pyk.cterm import CSubst, CTerm from pyk.kast.inner import KApply, KSequence, KSort, KToken, KVariable, Subst from pyk.kast.manip import minimize_term, sort_ac_collections -from pyk.kcfg.semantics import KCFGSemantics +from pyk.kcfg.semantics import DefaultSemantics from pyk.kcfg.show import KCFGShow from pyk.prelude.kbool import BOOL, FALSE, andBool, orBool from pyk.prelude.kint import intToken @@ -34,6 +34,7 @@ from pyk.kast.outer import KDefinition from pyk.kcfg import KCFGExplore from pyk.kcfg.kcfg import KCFGExtendResult + from pyk.kcfg.semantics import KCFGSemantics from pyk.ktool.kprint import KPrint, SymbolTable from pyk.ktool.kprove import KProve from pyk.proof import Prover @@ -46,7 +47,7 @@ def proof_dir(tmp_path_factory: TempPathFactory) -> Path: return tmp_path_factory.mktemp('proofs') -class ImpSemantics(KCFGSemantics): +class ImpSemantics(DefaultSemantics): definition: KDefinition | None def __init__(self, definition: KDefinition | None = None): diff --git a/pyk/src/tests/integration/proof/test_refute_node.py b/pyk/src/tests/integration/proof/test_refute_node.py index 5203f12a79..c5d10f5634 100644 --- a/pyk/src/tests/integration/proof/test_refute_node.py +++ b/pyk/src/tests/integration/proof/test_refute_node.py @@ -12,7 +12,7 @@ from pyk.kast.outer import KClaim from pyk.kcfg import KCFG from pyk.kcfg.minimize import KCFGMinimizer -from pyk.kcfg.semantics import KCFGSemantics +from pyk.kcfg.semantics import DefaultSemantics from pyk.prelude.kint import gtInt, intToken, leInt from pyk.prelude.ml import is_top, mlEqualsTrue from pyk.proof import APRProof, APRProver, ImpliesProver, ProofStatus, RefutationProof @@ -31,12 +31,13 @@ from pyk.kast.outer import KDefinition from pyk.kcfg import KCFGExplore from pyk.kcfg.kcfg import KCFGExtendResult + from pyk.kcfg.semantics import KCFGSemantics from pyk.ktool.kprove import KProve STATE = Union[tuple[str, str], tuple[str, str, str]] -class RefuteSemantics(KCFGSemantics): +class RefuteSemantics(DefaultSemantics): def is_terminal(self, c: CTerm) -> bool: k_cell = c.cell('K_CELL') if type(k_cell) is KSequence: diff --git a/pyk/src/tests/integration/proof/test_simple.py b/pyk/src/tests/integration/proof/test_simple.py index 4ce4d8e7ef..96a72e7153 100644 --- a/pyk/src/tests/integration/proof/test_simple.py +++ b/pyk/src/tests/integration/proof/test_simple.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from pyk.kast.inner import KApply, KSequence -from pyk.kcfg.semantics import KCFGSemantics +from pyk.kcfg.semantics import DefaultSemantics from pyk.proof import APRProof, APRProver from pyk.testing import KCFGExploreTest, KProveTest from pyk.utils import single @@ -18,12 +18,13 @@ from pyk.kast.outer import KDefinition from pyk.kcfg import KCFGExplore from pyk.kcfg.kcfg import KCFGExtendResult + from pyk.kcfg.semantics import KCFGSemantics from pyk.ktool.kprove import KProve _LOGGER: Final = logging.getLogger(__name__) -class SimpleSemantics(KCFGSemantics): +class SimpleSemantics(DefaultSemantics): def is_terminal(self, c: CTerm) -> bool: k_cell = c.cell('K_CELL') if type(k_cell) is KSequence and type(k_cell[0]) is KApply and k_cell[0].label.name == 'f_SIMPLE-PROOFS_Step': diff --git a/pyk/src/tests/unit/kcfg/merge_node_data.py b/pyk/src/tests/unit/kcfg/merge_node_data.py new file mode 100644 index 0000000000..c2c3854895 --- /dev/null +++ b/pyk/src/tests/unit/kcfg/merge_node_data.py @@ -0,0 +1,472 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Final + +from pyk.cterm import CTerm +from pyk.kast.inner import KApply, KLabel, KSort, KToken, KVariable +from pyk.kast.manip import ml_pred_to_bool +from pyk.kcfg import KCFG +from pyk.kcfg.semantics import DefaultSemantics +from pyk.prelude.kbool import andBool +from pyk.prelude.kint import intToken +from pyk.utils import single + +from ..utils import ge_ml, k, lt_ml + +if TYPE_CHECKING: + from collections.abc import Iterable + + from pyk.kast.inner import KInner + from pyk.kcfg.minimize import KCFGMinimizer + + +def merge_node_test_kcfg() -> KCFG: + """Define a KCFG with all possible scenarios for merging nodes. + + Here are some specifications for the KCFG: + 1. Unable to continue other pattern-rewriting, e.g., lift_edge_edge, lift_split_split, lift_edge_split, ... + 2. Able to test the merged CTerms and the merged CSubsts. + 3. Able to propagate all possible result structures through different heuristics, including + merged-into-one, merged-into-two, partially-merged-into-one, partially-merged-into-two, and not-merged. + 4. Contains Split, Edge, and MergedEdge, because the merging process is targeted at these types of edges. + """ + cfg = KCFG() + # Split Source: A + # 1 -10 <= X < 100 + cfg.create_node( + CTerm(k(KVariable('X')), [ge_ml('X', -10), lt_ml('X', 100)]), + ) + + # Split Targets & Edge Sources: Ai + # 2 -10 <= X < 0 + cfg.create_node(CTerm(k(KVariable('X')), [ge_ml('X', -10), lt_ml('X', 0)])) + # 3 0 <= X < 2 + cfg.create_node(CTerm(k(KVariable('X')), [ge_ml('X', 0), lt_ml('X', 2)])) + # 4 2 <= A < 6 + cfg.create_node(CTerm(k(KVariable('A')), [ge_ml('A', 2), lt_ml('A', 6)])) + # 5 6 <= B < 10 + cfg.create_node(CTerm(k(KVariable('B')), [ge_ml('B', 6), lt_ml('B', 10)])) + # 6 <10> + cfg.create_node(CTerm(k(intToken(10)))) + # 7 <11> + cfg.create_node(CTerm(k(intToken(11)))) + # 8 12 <= Z < 100 + cfg.create_node(CTerm(k(KVariable('Z')), [ge_ml('Z', 12), lt_ml('Z', 100)])) + + # Edge Targets: Bi + # 9 <1> + cfg.create_node(CTerm(k(intToken(1)))) + # 10 <2> + cfg.create_node(CTerm(k(intToken(2)))) + # 11 <3> + cfg.create_node(CTerm(k(intToken(3)))) + # 12 <4> + cfg.create_node(CTerm(k(intToken(4)))) + # 13 <5> + cfg.create_node(CTerm(k(intToken(5)))) + # 14 <6> + cfg.create_node(CTerm(k(intToken(6)))) + # 15 <7> + cfg.create_node(CTerm(k(intToken(7)))) + + # MergedEdge Sources + # 16 2 <= X < 4 + cfg.create_node(CTerm(k(KVariable('X')), [ge_ml('X', 2), lt_ml('X', 4)])) + # 17 4 <= Y < 6 + cfg.create_node(CTerm(k(KVariable('Y')), [ge_ml('Y', 4), lt_ml('Y', 6)])) + # 18 6 <= X < 8 + cfg.create_node(CTerm(k(KVariable('X')), [ge_ml('X', 6), lt_ml('X', 8)])) + # 19 8 <= Y < 10 + cfg.create_node(CTerm(k(KVariable('Y')), [ge_ml('Y', 8), lt_ml('Y', 10)])) + + # MergedEdge Targets + # 20 <8> + cfg.create_node(CTerm(k(intToken(8)))) + # 21 <9> + cfg.create_node(CTerm(k(intToken(9)))) + # 22 <10> + cfg.create_node(CTerm(k(intToken(10)))) + # 23 <11> + cfg.create_node(CTerm(k(intToken(11)))) + + # MergedEdge + e1 = cfg.create_edge(16, 20, 5, ['r1']) + e2 = cfg.create_edge(17, 21, 6, ['r2', 'r3']) + e3 = cfg.create_edge(18, 22, 7, ['r4', 'r5']) + e4 = cfg.create_edge(19, 23, 8, ['r6', 'r7', 'r8']) + cfg.remove_node(16) + cfg.remove_node(17) + cfg.remove_node(18) + cfg.remove_node(19) + cfg.remove_node(20) + cfg.remove_node(21) + cfg.remove_node(22) + cfg.remove_node(23) + + # Split + cfg.create_split_by_nodes(1, [2, 3, 4, 5, 6, 7, 8]) + + # Edge + cfg.create_edge(2, 9, 10, ['r9']) + cfg.create_edge(3, 10, 11, ['r10', 'r11']) + cfg.create_merged_edge(4, 11, [e1, e2]) + cfg.create_merged_edge(5, 12, [e3, e4]) + cfg.create_edge(6, 13, 14, ['r12', 'r13', 'r14']) + cfg.create_edge(7, 14, 15, ['r15']) + cfg.create_edge(8, 15, 16, ['r16']) + + return cfg + + +class MergedNo(DefaultSemantics): + def is_mergeable(self, c1: CTerm, c2: CTerm) -> bool: + return False + + +class MergedOne(DefaultSemantics): + def is_mergeable(self, c1: CTerm, c2: CTerm) -> bool: + return True + + +def util_get_token(c: CTerm) -> int: + assert isinstance(c.config, KApply) + x = c.config.args[0] + assert isinstance(x, KToken) + return int(x.token) + + +class MergedPartialOne0(DefaultSemantics): + def is_mergeable(self, c1: CTerm, c2: CTerm) -> bool: + x = util_get_token(c1) + y = util_get_token(c2) + if x < 3 and y < 3: + return True + return False + + +class MergedPartialOne1(DefaultSemantics): + def is_mergeable(self, c1: CTerm, c2: CTerm) -> bool: + x = util_get_token(c1) + y = util_get_token(c2) + if x < 4 and y < 4: + return True + return False + + +class MergedPartialOne2(DefaultSemantics): + def is_mergeable(self, c1: CTerm, c2: CTerm) -> bool: + x = util_get_token(c1) + y = util_get_token(c2) + if x < 5 and y < 5: + return True + return False + + +class MergedTwo0(DefaultSemantics): + def is_mergeable(self, c1: CTerm, c2: CTerm) -> bool: + x = util_get_token(c1) + y = util_get_token(c2) + if x < 3 and y < 3: + return True + if x >= 3 and y >= 3: + return True + return False + + +class MergedTwo1(DefaultSemantics): + def is_mergeable(self, c1: CTerm, c2: CTerm) -> bool: + x = util_get_token(c1) + y = util_get_token(c2) + if x < 4 and y < 4: + return True + if x >= 4 and y >= 4: + return True + return False + + +class MergedPartialTwo(DefaultSemantics): + def is_mergeable(self, c1: CTerm, c2: CTerm) -> bool: + x = util_get_token(c1) + y = util_get_token(c2) + if x < 4 and y < 4: + return True + if 4 <= x < 6 and 4 <= y < 6: + return True + return False + + +class MergedFail(DefaultSemantics): + def is_mergeable(self, c1: CTerm, c2: CTerm) -> bool: + x = util_get_token(c1) + y = util_get_token(c2) + if x < 5 and y < 5: + return True + if x >= 4 and y >= 4: + return True + return False + + +def util_check_constraint_element(constraint: KInner, merged_var: KInner, under_check: Iterable[int]) -> None: + # orBool + assert isinstance(constraint, KApply) and constraint.label == KLabel('_orBool_') + assert isinstance(merged_var, KVariable) + idx = 0 + count = 0 + eq_idx = 0 + for arg in constraint.args: + # _==K_ + if isinstance(arg, KApply) and arg.label == KLabel('_==K_'): + var = arg.args[0] + token = arg.args[1] + assert isinstance(var, KVariable) and isinstance(token, KToken) + assert var == merged_var + assert int(token.token) in under_check + under_check = [x for x in under_check if x != int(token.token)] + count += 1 + eq_idx = idx + idx += 1 + if count == 2: + assert not under_check + return + + idx = 0 if eq_idx == 1 else 0 + # andBool + ab = constraint.args[idx] + assert isinstance(ab, KApply) and ab.label == KLabel('_andBool_') + idx = 0 + for arg in ab.args: + # x ==K y + if isinstance(arg, KApply) and arg.label == KLabel('_==K_'): + assert isinstance(arg.args[0], KVariable) and isinstance(arg.args[1], KVariable) + m_idx = arg.args.index(merged_var) + other_idx = 0 if m_idx == 1 else 1 + left_idx = 0 if idx == 1 else 1 + util_check_constraint_element(ab.args[left_idx], arg.args[other_idx], under_check) + idx += 1 + + +def util_check_constraint(constraint: KInner, merged_var: KVariable, under_check: Iterable[int]) -> None: + # mlEqualsTrue + assert isinstance(constraint, KApply) and constraint.label == KLabel( + '#Equals', [KSort('Bool'), KSort('GeneratedTopCell')] + ) + assert constraint.args[0] == KToken('true', KSort('Bool')) + util_check_constraint_element(constraint.args[1], merged_var, under_check) + + +def check_merge_no(minimizer: KCFGMinimizer) -> None: + minimizer.minimize() + assert minimizer.kcfg.to_dict() == merge_node_test_kcfg().to_dict() + + +def check_merged_one(minimizer: KCFGMinimizer) -> None: + minimizer.minimize() + # 1 --> merged bi: Merged Edge + merged_edge = single(minimizer.kcfg.merged_edges(source_id=1)) + edges = {2: 9, 3: 10, 16: 20, 17: 21, 18: 22, 19: 23, 6: 13, 7: 14, 8: 15} + for edge in merged_edge.edges: + assert edge.source.id in edges + assert edge.target.id == edges[edge.source.id] + merged_bi = merged_edge.target + assert isinstance(merged_bi.cterm.config, KApply) + merged_var = merged_bi.cterm.config.args[0] + assert isinstance(merged_var, KVariable) + merged_constraint = single(merged_bi.cterm.constraints) + util_check_constraint(merged_constraint, merged_var, [1, 2, 3, 4, 5, 6, 7]) + # merged bi --> 9 - 15: Split + split = single(minimizer.kcfg.splits(source_id=merged_edge.target.id)) + splits = split.splits + expected_splits = {9: 1, 10: 2, 11: 3, 12: 4, 13: 5, 14: 6, 15: 7} + for s in expected_splits: + assert s in splits + x = list(splits[s].subst.values())[0] + assert isinstance(x, KToken) and int(x.token) == expected_splits[s] + + +def check_merged_partial_one0(minimizer: KCFGMinimizer) -> None: + minimizer.merge_nodes() + # merged 9 - 10, else unchanged + # 1 --> merged ai (2,3) & 4 - 8: Split + split = single(minimizer.kcfg.splits(source_id=1)) + splits = split.splits + expected_splits = [4, 5, 6, 7, 8, 24] + assert all(s in splits for s in expected_splits) + merged_ai = minimizer.kcfg.node(24) + merged_ai_c = single(merged_ai.cterm.constraints) + expected_ai_c = KLabel('#Equals', [KSort('Bool'), KSort('GeneratedTopCell')])( + KToken('true', KSort('Bool')), + KLabel('_orBool_')( + andBool([ml_pred_to_bool(c) for c in merge_node_test_kcfg().node(2).cterm.constraints]), + andBool([ml_pred_to_bool(c) for c in merge_node_test_kcfg().node(3).cterm.constraints]), + ), + ) + assert merged_ai_c == expected_ai_c + # merged ai (2,3) --> merged bi (9,10): MergedEdge + merged_edge = single(minimizer.kcfg.merged_edges(source_id=24)) + edges = {2: 9, 3: 10} + assert all(e.source.id in edges and e.target.id == edges[e.source.id] for e in merged_edge.edges) + # merged bi (9,10) --> 9 - 10: Split + split = single(minimizer.kcfg.splits(source_id=merged_edge.target.id)) + splits = split.splits + expected_splits = [9, 10] + assert all(s in splits for s in expected_splits) + # 4 - 8 --> 11 - 15: Edge (unchanged) + for i in range(4, 6): + medge = single(minimizer.kcfg.merged_edges(source_id=i)) + assert medge.target.id == i + 7 + for i in range(6, 9): + edge = single(minimizer.kcfg.edges(source_id=i)) + assert edge.target.id == i + 7 + + +def check_merged_partial_one1(minimizer: KCFGMinimizer) -> None: + minimizer.merge_nodes() + # merged 9 - 11, else unchanged + # 1 --> merged ai (2,3,4) & 5 - 8: Split + split = single(minimizer.kcfg.splits(source_id=1)) + splits = split.splits + expected_splits = [5, 6, 7, 8, 24] + assert all(s in splits for s in expected_splits) + # merged ai (2,3,4) --> merged bi (9,10,11): MergedEdge + merged_edge = single(minimizer.kcfg.merged_edges(source_id=24)) + # edges = {2: 9, 3: 10, 4: 11} + edges = {2: 9, 3: 10, 16: 20, 17: 21} + assert all(e.source.id in edges and e.target.id == edges[e.source.id] for e in merged_edge.edges) + # merged bi (9,10,11) --> 9 - 11: Split + split = single(minimizer.kcfg.splits(source_id=merged_edge.target.id)) + splits = split.splits + expected_splits = [9, 10, 11] + assert all(s in splits for s in expected_splits) + # 5 - 8 --> 12 - 15: Edge (unchanged) + medge = single(minimizer.kcfg.merged_edges(source_id=5)) + assert medge.target.id == 12 + for i in range(6, 9): + edge = single(minimizer.kcfg.edges(source_id=i)) + assert edge.target.id == i + 7 + + +def check_merged_partial_one2(minimizer: KCFGMinimizer) -> None: + minimizer.merge_nodes() + # merged 9 - 12, else unchanged + # 1 --> merged ai (2,3,4,5) & 6 - 8: Split + split = single(minimizer.kcfg.splits(source_id=1)) + splits = split.splits + expected_splits = [6, 7, 8, 24] + assert all(s in splits for s in expected_splits) + # merged ai (2,3,4,5) --> merged bi (9,10,11,12): MergedEdge + merged_edge = single(minimizer.kcfg.merged_edges(source_id=24)) + edges = {2: 9, 3: 10, 16: 20, 17: 21, 18: 22, 19: 23} + assert all(e.source.id in edges and e.target.id == edges[e.source.id] for e in merged_edge.edges) + # merged bi (9,10,11,12) --> 9 - 12: Split + split = single(minimizer.kcfg.splits(source_id=merged_edge.target.id)) + splits = split.splits + expected_splits = [9, 10, 11, 12] + assert all(s in splits for s in expected_splits) + # 6 - 8 --> 13 - 15: Edge (unchanged) + for i in range(6, 9): + edge = single(minimizer.kcfg.edges(source_id=i)) + assert edge.target.id == i + 7 + + +def check_merged_two0(minimizer: KCFGMinimizer) -> None: + minimizer.merge_nodes() + # merged 9 - 10, 11 - 15, else unchanged + # 1 --> merged ai (2,3) & (4-8): Split + split = single(minimizer.kcfg.splits(source_id=1)) + splits = split.splits + expected_splits = [24, 26] + assert all(s in splits for s in expected_splits) + # merged ai (4-8) --> merged bi (11-15): MergedEdge + merged_edge = single(minimizer.kcfg.merged_edges(source_id=24)) + edges = {2: 9, 3: 10} + assert all(e.source.id in edges and e.target.id == edges[e.source.id] for e in merged_edge.edges) + # merged bi (11-15) --> 11 - 15: Split + split = single(minimizer.kcfg.splits(source_id=merged_edge.target.id)) + splits = split.splits + expected_splits = [9, 10] + assert all(s in splits for s in expected_splits) + # merged ai (2,3) --> merged bi (9,10): MergedEdge + merged_edge = single(minimizer.kcfg.merged_edges(source_id=26)) + edges = {16: 20, 17: 21, 18: 22, 19: 23, 6: 13, 7: 14, 8: 15} + assert all(e.source.id in edges and e.target.id == edges[e.source.id] for e in merged_edge.edges) + # merged bi (9,10) --> 9 - 10: Split + split = single(minimizer.kcfg.splits(source_id=merged_edge.target.id)) + splits = split.splits + expected_splits = [11, 12, 13, 14, 15] + assert all(s in splits for s in expected_splits) + + +def check_merged_two1(minimizer: KCFGMinimizer) -> None: + minimizer.merge_nodes() + # merged 9 - 11, 12 - 15, else unchanged + # 1 --> merged ai (2,3,4) & (5-8): Split + split = single(minimizer.kcfg.splits(source_id=1)) + splits = split.splits + expected_splits = [24, 26] + assert all(s in splits for s in expected_splits) + # merged ai (5-8) --> merged bi (12-15): MergedEdge + merged_edge = single(minimizer.kcfg.merged_edges(source_id=24)) + edges = {2: 9, 3: 10, 16: 20, 17: 21} + assert all(e.source.id in edges and e.target.id == edges[e.source.id] for e in merged_edge.edges) + # merged bi (12-15) --> 12 - 15: Split + split = single(minimizer.kcfg.splits(source_id=merged_edge.target.id)) + splits = split.splits + expected_splits = [9, 10, 11] + assert all(s in splits for s in expected_splits) + # merged ai (2,3,4) --> merged bi (9,10,11): MergedEdge + merged_edge = single(minimizer.kcfg.merged_edges(source_id=26)) + edges = {18: 22, 19: 23, 6: 13, 7: 14, 8: 15} + assert all(e.source.id in edges and e.target.id == edges[e.source.id] for e in merged_edge.edges) + # merged bi (9,10,11) --> 9 - 11: Split + split = single(minimizer.kcfg.splits(source_id=merged_edge.target.id)) + splits = split.splits + expected_splits = [12, 13, 14, 15] + assert all(s in splits for s in expected_splits) + + +def check_merged_partial_two(minimizer: KCFGMinimizer) -> None: + minimizer.merge_nodes() + # 1 --> merged ai (2,3,4) & (5,6) & 7 - 8: Split + split = single(minimizer.kcfg.splits(source_id=1)) + splits = split.splits + expected_splits = [7, 8, 24, 26] + assert all(s in splits for s in expected_splits) + # merged ai (5,6) --> merged bi (13,14): MergedEdge + merged_edge = single(minimizer.kcfg.merged_edges(source_id=24)) + edges = {2: 9, 3: 10, 16: 20, 17: 21} + assert all(e.source.id in edges and e.target.id == edges[e.source.id] for e in merged_edge.edges) + # merged bi (13,14) --> 13 - 14: Split + split = single(minimizer.kcfg.splits(source_id=merged_edge.target.id)) + splits = split.splits + expected_splits = [9, 10, 11] + assert all(s in splits for s in expected_splits) + # merged ai (2,3,4) --> merged bi (9,10,11): MergedEdge + merged_edge = single(minimizer.kcfg.merged_edges(source_id=26)) + edges = {18: 22, 19: 23, 6: 13} + assert all(e.source.id in edges and e.target.id == edges[e.source.id] for e in merged_edge.edges) + # merged bi (9,10,11) --> 9 - 11: Split + split = single(minimizer.kcfg.splits(source_id=merged_edge.target.id)) + splits = split.splits + expected_splits = [12, 13] + assert all(s in splits for s in expected_splits) + + +def check_merged_fail(minimizer: KCFGMinimizer) -> None: + try: + minimizer.merge_nodes() + raise AssertionError + except ValueError: + pass + + +KCFG_MERGE_NODE_TEST_DATA: Final = ( + (MergedNo(), merge_node_test_kcfg(), check_merge_no), + (MergedOne(), merge_node_test_kcfg(), check_merged_one), + (MergedPartialOne0(), merge_node_test_kcfg(), check_merged_partial_one0), + (MergedPartialOne1(), merge_node_test_kcfg(), check_merged_partial_one1), + (MergedPartialOne2(), merge_node_test_kcfg(), check_merged_partial_one2), + (MergedTwo0(), merge_node_test_kcfg(), check_merged_two0), + (MergedTwo1(), merge_node_test_kcfg(), check_merged_two1), + (MergedPartialTwo(), merge_node_test_kcfg(), check_merged_partial_two), + (MergedFail(), merge_node_test_kcfg(), check_merged_fail), +) diff --git a/pyk/src/tests/unit/kcfg/test_minimize.py b/pyk/src/tests/unit/kcfg/test_minimize.py index 4221f3eb65..b4f0d91a9b 100644 --- a/pyk/src/tests/unit/kcfg/test_minimize.py +++ b/pyk/src/tests/unit/kcfg/test_minimize.py @@ -3,8 +3,6 @@ from typing import TYPE_CHECKING import pytest -from unit.mock_kprint import MockKPrint -from unit.test_kcfg import edge_dicts, node_dicts, split_dicts, to_csubst_node, x_config, x_node, x_subst from pyk.kast.inner import KVariable from pyk.kcfg import KCFG, KCFGShow @@ -14,9 +12,16 @@ from pyk.prelude.ml import mlEqualsTrue, mlTop from pyk.utils import single +from ..kcfg.merge_node_data import KCFG_MERGE_NODE_TEST_DATA +from ..mock_kprint import MockKPrint +from ..test_kcfg import edge_dicts, node_dicts, split_dicts, to_csubst_node, x_config, x_node, x_subst + if TYPE_CHECKING: + from collections.abc import Callable + from pyk.kast.inner import KApply from pyk.kcfg.kcfg import NodeIdLike + from pyk.kcfg.semantics import KCFGSemantics def contains_edge(cfg: KCFG, source: NodeIdLike, target: NodeIdLike, depth: int, rules: tuple[str, ...]) -> bool: @@ -407,3 +412,12 @@ def x_lt(n: int) -> KApply: ) assert actual == expected + + +@pytest.mark.parametrize('heuristics,kcfg,check', KCFG_MERGE_NODE_TEST_DATA) +def test_merge_nodes(heuristics: KCFGSemantics, kcfg: KCFG, check: Callable[[KCFGMinimizer], None]) -> None: + # When + minimizer = KCFGMinimizer(kcfg, heuristics) + + # Then + check(minimizer) diff --git a/pyk/src/tests/unit/test_utils.py b/pyk/src/tests/unit/test_utils.py index 55f5d14470..43bf5e0335 100644 --- a/pyk/src/tests/unit/test_utils.py +++ b/pyk/src/tests/unit/test_utils.py @@ -5,10 +5,10 @@ import pytest -from pyk.utils import POSet, deconstruct_short_hash +from pyk.utils import POSet, deconstruct_short_hash, partition if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Callable, Iterable FULL_HASH: Final = '0001000200030004000500060007000800010002000300040005000600070008' @@ -58,3 +58,29 @@ def test_poset(relation: Iterable[tuple[int, int]], expected: dict[int, set[int] # Then assert actual == expected + + +@pytest.mark.parametrize( + 'iterable,pred,expected', + ( + ([1, 2, 3, 4], lambda x, y: x % 2 == y % 2, [[1, 3], [2, 4]]), + ([1, 2, 3, 4], lambda x, y: x % 2 == 0 and y % 2 == 0, [[1], [2, 4], [3]]), + ([1, 2, 3, 4], lambda x, y: x % 2 == 1 and y % 2 == 1, [[1, 3], [2], [4]]), + ([1, 2, 3, 4], lambda x, y: x % 2 == 0, None), + ([1, 2, 3, 4], lambda x, y: x % 2 == 0 and y % 2 == 1, None), + ([1, 2, 3, 4], lambda x, y: x % 2 == 1 and y % 2 == 0, None), + ), +) +def test_partition(iterable: Iterable[int], pred: Callable[[int, int], bool], expected: list[list[int]] | None) -> None: + # When + try: + actual = partition(iterable, pred) + + # Then + except ValueError as e: + if not expected: + assert str(e).startswith('Partitioning failed') + return + raise + + assert actual == expected