Skip to content

Commit

Permalink
Extract class ClaimIndex (#4444)
Browse files Browse the repository at this point in the history
Preparatory refactoring for:
* #4406

Extracts a class `ClaimIndex` that enables decoupling the two main steps
of `KProve.get_claims`:
- `KProve.get_claim_index`: create a data structure from a K spec file,
expensive. We eventually want to save this part by caching the data
structure.
- `ClaimIndex.labels`: filter the data structure, cheap.

`ClaimIndex` itself is a mapping from labels to `KClaim`-s. The
invariant it enforces (`ClaimIndex._validate`):
- For each item `(label, claim)` in the mapping, `claim.label == label`.
- For each claim, its `depend` attribute values are in the mapping.

In addition, when attribute `main_module_name` is set, it allows looking
up claims from the main module without qualifying the label (i.e. to
look up `MAIN_MODULE.foo` with key `foo`).

As a next step, `APRProof.from_spec_modules` can be simplified.
  • Loading branch information
tothtamas28 authored Jun 15, 2024
1 parent 2af5d76 commit ec63f14
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 59 deletions.
17 changes: 6 additions & 11 deletions pyk/src/pyk/kast/outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,18 +185,13 @@ def label(self) -> str:
"""Return a (hopefully) unique label associated with the given `KSentence`.
:return: Unique label for the given sentence, either (in order):
- User supplied `label` attribute (or supplied in rule label),
- Unique identifier computed and inserted by the frontend, or
- Source location for the sentence.
- User supplied `label` attribute (or supplied in rule label),or
- Unique identifier computed and inserted by the frontend.
"""
if Atts.LABEL in self.att:
return self.att[Atts.LABEL]
elif self.unique_id is not None:
return self.unique_id
elif self.source is not None:
_LOGGER.warning(f'Found a sentence without label or UNIQUE_ID: {self}')
return self.source
raise ValueError(f'Found sentence without label, UNIQUE_ID, or SOURCE:LOCATION: {self}')
label = self.att.get(Atts.LABEL, self.unique_id)
if label is None:
raise ValueError(f'Found sentence without label or UNIQUE_ID: {self}')
return label


@final
Expand Down
218 changes: 170 additions & 48 deletions pyk/src/pyk/ktool/kprove.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import logging
import os
import re
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from functools import cached_property, partial
from graphlib import TopologicalSorter
from itertools import chain
from pathlib import Path
from subprocess import CalledProcessError
Expand All @@ -16,21 +20,21 @@
from ..kast import Atts, kast_term
from ..kast.inner import KInner
from ..kast.manip import extract_lhs, flatten_label
from ..kast.outer import KApply, KDefinition, KFlatModule, KFlatModuleList, KImport, KRequire
from ..kast.outer import KApply, KClaim, KDefinition, KFlatModule, KFlatModuleList, KImport, KRequire
from ..kore.rpc import KoreExecLogFormat
from ..prelude.ml import is_top
from ..proof import APRProof, APRProver, EqualityProof, ImpliesProver
from ..utils import gen_file_timestamp, run_process
from ..utils import FrozenDict, gen_file_timestamp, run_process, unique
from . import TypeInferenceMode
from .kprint import KPrint

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator, Mapping
from collections.abc import Callable, Container, Iterable, Iterator
from subprocess import CompletedProcess
from typing import ContextManager, Final

from ..cli.pyk import ProveOptions
from ..kast.outer import KClaim, KRule, KRuleLike
from ..kast.outer import KRule, KRuleLike
from ..kast.pretty import SymbolTable
from ..kcfg import KCFGExplore
from ..proof import Proof, Prover
Expand Down Expand Up @@ -283,7 +287,26 @@ def get_claim_modules(
type_inference_mode=type_inference_mode,
args=['--emit-json-spec', ntf.name],
)
return KFlatModuleList.from_dict(kast_term(json.loads(Path(ntf.name).read_text())))
json_data = json.loads(Path(ntf.name).read_text())

return KFlatModuleList.from_dict(kast_term(json_data))

def get_claim_index(
self,
spec_file: Path,
spec_module_name: str | None = None,
include_dirs: Iterable[Path] = (),
md_selector: str | None = None,
type_inference_mode: TypeInferenceMode | None = None,
) -> ClaimIndex:
module_list = self.get_claim_modules(
spec_file=spec_file,
spec_module_name=spec_module_name,
include_dirs=include_dirs,
md_selector=md_selector,
type_inference_mode=type_inference_mode,
)
return ClaimIndex.from_module_list(module_list)

def get_claims(
self,
Expand All @@ -296,57 +319,21 @@ def get_claims(
include_dependencies: bool = True,
type_inference_mode: TypeInferenceMode | None = None,
) -> list[KClaim]:
flat_module_list = self.get_claim_modules(
claim_index = self.get_claim_index(
spec_file=spec_file,
spec_module_name=spec_module_name,
include_dirs=include_dirs,
md_selector=md_selector,
type_inference_mode=type_inference_mode,
)

_module_names = [module.name for module in flat_module_list.modules]

def _get_claim_module(_label: str) -> str | None:
if _label.find('.') > 0 and _label.split('.')[0] in _module_names:
return _label.split('.')[0]
return None

all_claims = {
claim.label: (claim, module.name) for module in flat_module_list.modules for claim in module.claims
}

claim_labels = list(all_claims.keys()) if claim_labels is None else list(claim_labels)
exclude_claim_labels = [] if exclude_claim_labels is None else list(exclude_claim_labels)

final_claims: dict[str, KClaim] = {}
unfound_labels: list[str] = []
while len(claim_labels) > 0:
claim_label = claim_labels.pop(0)
if claim_label in final_claims or claim_label in exclude_claim_labels:
continue
if claim_label not in all_claims:
claim_label = f'{flat_module_list.main_module}.{claim_label}'
if claim_label not in all_claims:
unfound_labels.append(claim_label)
continue

_claim, _module_name = all_claims[claim_label]
_updated_dependencies: list[str] = []
for _dependency_label in _claim.dependencies:
if _get_claim_module(_dependency_label) is None:
_dependency_label = f'{_module_name}.{_dependency_label}'
_updated_dependencies.append(_dependency_label)
if len(_updated_dependencies) > 0:
if include_dependencies:
claim_labels.extend(_updated_dependencies)
_claim = _claim.let(att=_claim.att.update([Atts.DEPENDS(','.join(_updated_dependencies))]))

final_claims[claim_label] = _claim

if len(unfound_labels) > 0:
raise ValueError(f'Claim labels not found: {unfound_labels}')
labels = claim_index.labels(
include=claim_labels,
exclude=exclude_claim_labels,
with_depends=include_dependencies,
)

return list(final_claims.values())
return [claim_index[label] for label in labels]

@contextmanager
def _tmp_claim_definition(
Expand Down Expand Up @@ -493,3 +480,138 @@ def _prove_claim_rpc(
else:
_LOGGER.info(f'Proof pending: {proof.id}')
return proof


@dataclass(frozen=True)
class ClaimIndex(Mapping[str, KClaim]):
claims: FrozenDict[str, KClaim]
main_module_name: str | None

def __init__(
self,
claims: Mapping[str, KClaim],
main_module_name: str | None = None,
):
self._validate(claims)
object.__setattr__(self, 'claims', FrozenDict(claims))
object.__setattr__(self, 'main_module_name', main_module_name)

@staticmethod
def from_module_list(module_list: KFlatModuleList) -> ClaimIndex:
module_list = ClaimIndex._resolve_depends(module_list)
return ClaimIndex(
claims={claim.label: claim for module in module_list.modules for claim in module.claims},
main_module_name=module_list.main_module,
)

@staticmethod
def _validate(claims: Mapping[str, KClaim]) -> None:
for label, claim in claims.items():
if claim.label != label:
raise ValueError(f'Claim label mismatch, expected: {label}, found: {claim.label}')

for depend in claim.dependencies:
if depend not in claims:
raise ValueError(f'Invalid dependency label: {depend}')

@staticmethod
def _resolve_depends(module_list: KFlatModuleList) -> KFlatModuleList:
"""Resolve each depends value relative to the module the claim belongs to.
Example:
module THIS-MODULE
claim ... [depends(foo,OTHER-MODULE.bar)]
endmodule
becomes
module THIS-MODULE
claim ... [depends(THIS-MODULE.foo,OTHER-MODULE.bar)]
endmodule
"""

labels = {claim.label for module in module_list.modules for claim in module.claims}

def resolve_claim_depends(module_name: str, claim: KClaim) -> KClaim:
depends = claim.dependencies
if not depends:
return claim

resolve = partial(ClaimIndex._resolve_claim_label, labels, module_name)
resolved = [resolve(label) for label in depends]
return claim.let(att=claim.att.update([Atts.DEPENDS(','.join(resolved))]))

modules: list[KFlatModule] = []
for module in module_list.modules:
resolve_depends = partial(resolve_claim_depends, module.name)
module = module.map_sentences(resolve_depends, of_type=KClaim)
modules.append(module)

return module_list.let(modules=modules)

@staticmethod
def _resolve_claim_label(labels: Container[str], module_name: str | None, label: str) -> str:
"""Resolve `label` to a valid label in `labels`, or raise.
If a `label` is not found and `module_name` is set, the label is tried after qualifying.
"""
if label in labels:
return label

if module_name is not None:
qualified = f'{module_name}.{label}'
if qualified in labels:
return qualified

raise ValueError(f'Claim label not found: {label}')

def __iter__(self) -> Iterator[str]:
return iter(self.claims)

def __len__(self) -> int:
return len(self.claims)

def __getitem__(self, label: str) -> KClaim:
try:
label = self.resolve(label)
except ValueError:
raise KeyError(f'Claim not found: {label}') from None
return self.claims[label]

@cached_property
def topological(self) -> tuple[str, ...]:
graph = {label: claim.dependencies for label, claim in self.claims.items()}
return tuple(TopologicalSorter(graph).static_order())

def resolve(self, label: str) -> str:
return self._resolve_claim_label(self.claims, self.main_module_name, label)

def resolve_all(self, labels: Iterable[str]) -> list[str]:
return [self.resolve(label) for label in unique(labels)]

def labels(
self,
*,
include: Iterable[str] | None = None,
exclude: Iterable[str] | None = None,
with_depends: bool = True,
) -> list[str]:
res: list[str] = []

pending = self.resolve_all(include) if include is not None else list(self.claims)
done = set(self.resolve_all(exclude)) if exclude is not None else set()

while pending:
label = pending.pop(0) # BFS

if label in done:
continue

res.append(label)
done.add(label)

if with_depends:
pending += self.claims[label].dependencies

return res

0 comments on commit ec63f14

Please sign in to comment.