Skip to content

Commit

Permalink
Refactor from_spec_modules in APRProof (#4447)
Browse files Browse the repository at this point in the history
-  Move `ProveRpc`, `ClaimIndex` into separate modules
- Add parameter `ordered` to `ClaimIndex.labels` to enable topological
sorting of the result
- Implement `APRProof.from_spec_modules` using `ClaimIndex`
  • Loading branch information
tothtamas28 authored Jun 24, 2024
1 parent c14fbb4 commit 67eff5d
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 299 deletions.
3 changes: 2 additions & 1 deletion pyk/src/pyk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@
from .kore.syntax import Pattern, kore_term
from .ktool.kompile import Kompile, KompileBackend
from .ktool.kprint import KPrint
from .ktool.kprove import KProve, ProveRpc
from .ktool.kprove import KProve
from .ktool.krun import KRun
from .ktool.prove_rpc import ProveRpc
from .prelude.k import GENERATED_TOP_CELL
from .prelude.ml import is_top, mlAnd, mlOr
from .proof.reachability import APRFailureInfo, APRProof
Expand Down
188 changes: 188 additions & 0 deletions pyk/src/pyk/ktool/claim_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass
from functools import partial
from graphlib import TopologicalSorter
from typing import TYPE_CHECKING

from ..kast import Atts
from ..kast.outer import KClaim
from ..utils import FrozenDict, unique

if TYPE_CHECKING:
from collections.abc import Container, Iterable, Iterator

from ..kast.outer import KFlatModule, KFlatModuleList


@dataclass(frozen=True)
class ClaimIndex(Mapping[str, KClaim]):
claims: FrozenDict[str, KClaim]
main_module_name: str | None

def __init__(
self,
claims: Mapping[str, KClaim],
main_module_name: str | None = None,
):
self._validate(claims)
object.__setattr__(self, 'claims', FrozenDict(claims))
object.__setattr__(self, 'main_module_name', main_module_name)

@staticmethod
def from_module_list(module_list: KFlatModuleList) -> ClaimIndex:
module_list = ClaimIndex._resolve_depends(module_list)
return ClaimIndex(
claims={claim.label: claim for module in module_list.modules for claim in module.claims},
main_module_name=module_list.main_module,
)

@staticmethod
def _validate(claims: Mapping[str, KClaim]) -> None:
for label, claim in claims.items():
if claim.label != label:
raise ValueError(f'Claim label mismatch, expected: {label}, found: {claim.label}')

for depend in claim.dependencies:
if depend not in claims:
raise ValueError(f'Invalid dependency label: {depend}')

@staticmethod
def _resolve_depends(module_list: KFlatModuleList) -> KFlatModuleList:
"""Resolve each depends value relative to the module the claim belongs to.
Example:
```
module THIS-MODULE
claim ... [depends(foo,OTHER-MODULE.bar)]
endmodule
```
becomes
```
module THIS-MODULE
claim ... [depends(THIS-MODULE.foo,OTHER-MODULE.bar)]
endmodule
```
"""
labels = {claim.label for module in module_list.modules for claim in module.claims}

def resolve_claim_depends(module_name: str, claim: KClaim) -> KClaim:
depends = claim.dependencies
if not depends:
return claim

resolve = partial(ClaimIndex._resolve_claim_label, labels, module_name)
resolved = [resolve(label) for label in depends]
return claim.let(att=claim.att.update([Atts.DEPENDS(','.join(resolved))]))

modules: list[KFlatModule] = []
for module in module_list.modules:
resolve_depends = partial(resolve_claim_depends, module.name)
module = module.map_sentences(resolve_depends, of_type=KClaim)
modules.append(module)

return module_list.let(modules=modules)

@staticmethod
def _resolve_claim_label(labels: Container[str], module_name: str | None, label: str) -> str:
"""Resolve `label` to a valid label in `labels`, or raise.
If a `label` is not found and `module_name` is set, the label is tried after qualifying.
"""
if label in labels:
return label

if module_name is not None:
qualified = f'{module_name}.{label}'
if qualified in labels:
return qualified

raise ValueError(f'Claim label not found: {label}')

def __iter__(self) -> Iterator[str]:
return iter(self.claims)

def __len__(self) -> int:
return len(self.claims)

def __getitem__(self, label: str) -> KClaim:
try:
label = self.resolve(label)
except ValueError:
raise KeyError(f'Claim not found: {label}') from None
return self.claims[label]

def resolve(self, label: str) -> str:
return self._resolve_claim_label(self.claims, self.main_module_name, label)

def resolve_all(self, labels: Iterable[str]) -> list[str]:
return [self.resolve(label) for label in unique(labels)]

def labels(
self,
*,
include: Iterable[str] | None = None,
exclude: Iterable[str] | None = None,
with_depends: bool = True,
ordered: bool = False,
) -> list[str]:
"""Return a list of labels from the index.
Args:
include: Labels to include in the result. If `None`, all labels are included.
exclude: Labels to exclude from the result. If `None`, no labels are excluded.
Takes precedence over `include`.
with_depends: If `True`, the result is transitively closed w.r.t. the dependency relation.
Labels in `exclude` are pruned, and their dependencies are not considered on the given path.
ordered: If `True`, the result is topologically sorted w.r.t. the dependency relation.
Returns:
A list of labels from the index.
Raises:
ValueError: If an item in `include` or `exclude` cannot be resolved to a valid label.
"""
include = self.resolve_all(include) if include is not None else self.claims
exclude = self.resolve_all(exclude) if exclude is not None else []

labels: list[str]

if with_depends:
labels = self._close_dependencies(labels=include, prune=exclude)
else:
labels = [label for label in include if label not in set(exclude)]

if ordered:
return self._sort_topologically(labels)

return labels

def _close_dependencies(self, labels: Iterable[str], prune: Iterable[str]) -> list[str]:
res: list[str] = []

pending = list(labels)
done = set(prune)

while pending:
label = pending.pop(0) # BFS

if label in done:
continue

res.append(label)
pending += self.claims[label].dependencies
done.add(label)

return res

def _sort_topologically(self, labels: list[str]) -> list[str]:
label_set = set(labels)
graph = {
label: [dep for dep in claim.dependencies if dep in label_set]
for label, claim in self.claims.items()
if label in labels
}
return list(TopologicalSorter(graph).static_order())
Loading

0 comments on commit 67eff5d

Please sign in to comment.