diff --git a/pyk/src/pyk/kast/outer.py b/pyk/src/pyk/kast/outer.py index 29a91ab50f9..1cd84118ab1 100644 --- a/pyk/src/pyk/kast/outer.py +++ b/pyk/src/pyk/kast/outer.py @@ -185,18 +185,13 @@ def label(self) -> str: """Return a (hopefully) unique label associated with the given `KSentence`. :return: Unique label for the given sentence, either (in order): - - User supplied `label` attribute (or supplied in rule label), - - Unique identifier computed and inserted by the frontend, or - - Source location for the sentence. + - User supplied `label` attribute (or supplied in rule label),or + - Unique identifier computed and inserted by the frontend. """ - if Atts.LABEL in self.att: - return self.att[Atts.LABEL] - elif self.unique_id is not None: - return self.unique_id - elif self.source is not None: - _LOGGER.warning(f'Found a sentence without label or UNIQUE_ID: {self}') - return self.source - raise ValueError(f'Found sentence without label, UNIQUE_ID, or SOURCE:LOCATION: {self}') + label = self.att.get(Atts.LABEL, self.unique_id) + if label is None: + raise ValueError(f'Found sentence without label or UNIQUE_ID: {self}') + return label @final diff --git a/pyk/src/pyk/ktool/kprove.py b/pyk/src/pyk/ktool/kprove.py index dfa9c6b9350..b9ff48dd328 100644 --- a/pyk/src/pyk/ktool/kprove.py +++ b/pyk/src/pyk/ktool/kprove.py @@ -4,8 +4,12 @@ import logging import os import re +from collections.abc import Mapping from contextlib import contextmanager +from dataclasses import dataclass from enum import Enum +from functools import cached_property, partial +from graphlib import TopologicalSorter from itertools import chain from pathlib import Path from subprocess import CalledProcessError @@ -16,21 +20,21 @@ from ..kast import Atts, kast_term from ..kast.inner import KInner from ..kast.manip import extract_lhs, flatten_label -from ..kast.outer import KApply, KDefinition, KFlatModule, KFlatModuleList, KImport, KRequire +from ..kast.outer import KApply, KClaim, KDefinition, KFlatModule, KFlatModuleList, KImport, KRequire from ..kore.rpc import KoreExecLogFormat from ..prelude.ml import is_top from ..proof import APRProof, APRProver, EqualityProof, ImpliesProver -from ..utils import gen_file_timestamp, run_process +from ..utils import FrozenDict, gen_file_timestamp, run_process, unique from . import TypeInferenceMode from .kprint import KPrint if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Iterator, Mapping + from collections.abc import Callable, Container, Iterable, Iterator from subprocess import CompletedProcess from typing import ContextManager, Final from ..cli.pyk import ProveOptions - from ..kast.outer import KClaim, KRule, KRuleLike + from ..kast.outer import KRule, KRuleLike from ..kast.pretty import SymbolTable from ..kcfg import KCFGExplore from ..proof import Proof, Prover @@ -283,7 +287,26 @@ def get_claim_modules( type_inference_mode=type_inference_mode, args=['--emit-json-spec', ntf.name], ) - return KFlatModuleList.from_dict(kast_term(json.loads(Path(ntf.name).read_text()))) + json_data = json.loads(Path(ntf.name).read_text()) + + return KFlatModuleList.from_dict(kast_term(json_data)) + + def get_claim_index( + self, + spec_file: Path, + spec_module_name: str | None = None, + include_dirs: Iterable[Path] = (), + md_selector: str | None = None, + type_inference_mode: TypeInferenceMode | None = None, + ) -> ClaimIndex: + module_list = self.get_claim_modules( + spec_file=spec_file, + spec_module_name=spec_module_name, + include_dirs=include_dirs, + md_selector=md_selector, + type_inference_mode=type_inference_mode, + ) + return ClaimIndex.from_module_list(module_list) def get_claims( self, @@ -296,7 +319,7 @@ def get_claims( include_dependencies: bool = True, type_inference_mode: TypeInferenceMode | None = None, ) -> list[KClaim]: - flat_module_list = self.get_claim_modules( + claim_index = self.get_claim_index( spec_file=spec_file, spec_module_name=spec_module_name, include_dirs=include_dirs, @@ -304,49 +327,13 @@ def get_claims( type_inference_mode=type_inference_mode, ) - _module_names = [module.name for module in flat_module_list.modules] - - def _get_claim_module(_label: str) -> str | None: - if _label.find('.') > 0 and _label.split('.')[0] in _module_names: - return _label.split('.')[0] - return None - - all_claims = { - claim.label: (claim, module.name) for module in flat_module_list.modules for claim in module.claims - } - - claim_labels = list(all_claims.keys()) if claim_labels is None else list(claim_labels) - exclude_claim_labels = [] if exclude_claim_labels is None else list(exclude_claim_labels) - - final_claims: dict[str, KClaim] = {} - unfound_labels: list[str] = [] - while len(claim_labels) > 0: - claim_label = claim_labels.pop(0) - if claim_label in final_claims or claim_label in exclude_claim_labels: - continue - if claim_label not in all_claims: - claim_label = f'{flat_module_list.main_module}.{claim_label}' - if claim_label not in all_claims: - unfound_labels.append(claim_label) - continue - - _claim, _module_name = all_claims[claim_label] - _updated_dependencies: list[str] = [] - for _dependency_label in _claim.dependencies: - if _get_claim_module(_dependency_label) is None: - _dependency_label = f'{_module_name}.{_dependency_label}' - _updated_dependencies.append(_dependency_label) - if len(_updated_dependencies) > 0: - if include_dependencies: - claim_labels.extend(_updated_dependencies) - _claim = _claim.let(att=_claim.att.update([Atts.DEPENDS(','.join(_updated_dependencies))])) - - final_claims[claim_label] = _claim - - if len(unfound_labels) > 0: - raise ValueError(f'Claim labels not found: {unfound_labels}') + labels = claim_index.labels( + include=claim_labels, + exclude=exclude_claim_labels, + with_depends=include_dependencies, + ) - return list(final_claims.values()) + return [claim_index[label] for label in labels] @contextmanager def _tmp_claim_definition( @@ -493,3 +480,138 @@ def _prove_claim_rpc( else: _LOGGER.info(f'Proof pending: {proof.id}') return proof + + +@dataclass(frozen=True) +class ClaimIndex(Mapping[str, KClaim]): + claims: FrozenDict[str, KClaim] + main_module_name: str | None + + def __init__( + self, + claims: Mapping[str, KClaim], + main_module_name: str | None = None, + ): + self._validate(claims) + object.__setattr__(self, 'claims', FrozenDict(claims)) + object.__setattr__(self, 'main_module_name', main_module_name) + + @staticmethod + def from_module_list(module_list: KFlatModuleList) -> ClaimIndex: + module_list = ClaimIndex._resolve_depends(module_list) + return ClaimIndex( + claims={claim.label: claim for module in module_list.modules for claim in module.claims}, + main_module_name=module_list.main_module, + ) + + @staticmethod + def _validate(claims: Mapping[str, KClaim]) -> None: + for label, claim in claims.items(): + if claim.label != label: + raise ValueError(f'Claim label mismatch, expected: {label}, found: {claim.label}') + + for depend in claim.dependencies: + if depend not in claims: + raise ValueError(f'Invalid dependency label: {depend}') + + @staticmethod + def _resolve_depends(module_list: KFlatModuleList) -> KFlatModuleList: + """Resolve each depends value relative to the module the claim belongs to. + + Example: + + module THIS-MODULE + claim ... [depends(foo,OTHER-MODULE.bar)] + endmodule + + becomes + + module THIS-MODULE + claim ... [depends(THIS-MODULE.foo,OTHER-MODULE.bar)] + endmodule + """ + + labels = {claim.label for module in module_list.modules for claim in module.claims} + + def resolve_claim_depends(module_name: str, claim: KClaim) -> KClaim: + depends = claim.dependencies + if not depends: + return claim + + resolve = partial(ClaimIndex._resolve_claim_label, labels, module_name) + resolved = [resolve(label) for label in depends] + return claim.let(att=claim.att.update([Atts.DEPENDS(','.join(resolved))])) + + modules: list[KFlatModule] = [] + for module in module_list.modules: + resolve_depends = partial(resolve_claim_depends, module.name) + module = module.map_sentences(resolve_depends, of_type=KClaim) + modules.append(module) + + return module_list.let(modules=modules) + + @staticmethod + def _resolve_claim_label(labels: Container[str], module_name: str | None, label: str) -> str: + """Resolve `label` to a valid label in `labels`, or raise. + + If a `label` is not found and `module_name` is set, the label is tried after qualifying. + """ + if label in labels: + return label + + if module_name is not None: + qualified = f'{module_name}.{label}' + if qualified in labels: + return qualified + + raise ValueError(f'Claim label not found: {label}') + + def __iter__(self) -> Iterator[str]: + return iter(self.claims) + + def __len__(self) -> int: + return len(self.claims) + + def __getitem__(self, label: str) -> KClaim: + try: + label = self.resolve(label) + except ValueError: + raise KeyError(f'Claim not found: {label}') from None + return self.claims[label] + + @cached_property + def topological(self) -> tuple[str, ...]: + graph = {label: claim.dependencies for label, claim in self.claims.items()} + return tuple(TopologicalSorter(graph).static_order()) + + def resolve(self, label: str) -> str: + return self._resolve_claim_label(self.claims, self.main_module_name, label) + + def resolve_all(self, labels: Iterable[str]) -> list[str]: + return [self.resolve(label) for label in unique(labels)] + + def labels( + self, + *, + include: Iterable[str] | None = None, + exclude: Iterable[str] | None = None, + with_depends: bool = True, + ) -> list[str]: + res: list[str] = [] + + pending = self.resolve_all(include) if include is not None else list(self.claims) + done = set(self.resolve_all(exclude)) if exclude is not None else set() + + while pending: + label = pending.pop(0) # BFS + + if label in done: + continue + + res.append(label) + done.add(label) + + if with_depends: + pending += self.claims[label].dependencies + + return res