diff --git a/pyk/src/pyk/kast/manip.py b/pyk/src/pyk/kast/manip.py index dab6de52f8c..0d3ec94e165 100644 --- a/pyk/src/pyk/kast/manip.py +++ b/pyk/src/pyk/kast/manip.py @@ -25,6 +25,7 @@ var_occurrences, ) from .outer import KClaim, KDefinition, KFlatModule, KRule, KRuleLike +from .rewrite import indexed_rewrite if TYPE_CHECKING: from collections.abc import Callable, Collection, Iterable @@ -606,45 +607,6 @@ def apply_existential_substitutions(state: KInner, constraints: Iterable[KInner] return (Subst(subst)(state), [Subst(subst)(c) for c in new_constraints]) -def indexed_rewrite(kast: KInner, rewrites: Iterable[KRewrite]) -> KInner: - token_rewrites: list[KRewrite] = [] - apply_rewrites: dict[str, list[KRewrite]] = {} - other_rewrites: list[KRewrite] = [] - for r in rewrites: - if type(r.lhs) is KToken: - token_rewrites.append(r) - elif type(r.lhs) is KApply: - if r.lhs.label.name in apply_rewrites: - apply_rewrites[r.lhs.label.name].append(r) - else: - apply_rewrites[r.lhs.label.name] = [r] - else: - other_rewrites.append(r) - - def _apply_rewrites(_kast: KInner) -> KInner: - if type(_kast) is KToken: - for tr in token_rewrites: - _kast = tr.apply_top(_kast) - elif type(_kast) is KApply: - if _kast.label.name in apply_rewrites: - for ar in apply_rewrites[_kast.label.name]: - _kast = ar.apply_top(_kast) - else: - for _or in other_rewrites: - _kast = _or.apply_top(_kast) - return _kast - - orig_kast: KInner = kast - new_kast: KInner | None = None - while orig_kast != new_kast: - if new_kast is None: - new_kast = orig_kast - else: - orig_kast = new_kast - new_kast = bottom_up(_apply_rewrites, new_kast) - return new_kast - - def undo_aliases(definition: KDefinition, kast: KInner) -> KInner: aliases = [] for rule in definition.alias_rules: diff --git a/pyk/src/pyk/kast/outer.py b/pyk/src/pyk/kast/outer.py index 60f700049a7..435ee5436c9 100644 --- a/pyk/src/pyk/kast/outer.py +++ b/pyk/src/pyk/kast/outer.py @@ -32,6 +32,7 @@ top_down, ) from .kast import kast_term +from .rewrite import indexed_rewrite if TYPE_CHECKING: from collections.abc import Callable, Iterator, Mapping @@ -1586,14 +1587,10 @@ def _remove_config_var_lookups(_kast: KInner) -> KInner: else: raise ValueError(f'Cannot handle initializer for label: {prod_klabel}') - init_rewrites = [rule.body for rule in self.rules if Atts.INITIALIZER in rule.att] - old_init_config: KInner | None = None - while init_config != old_init_config: - old_init_config = init_config - for rew in init_rewrites: - assert type(rew) is KRewrite - init_config = rew(init_config) - + init_rewrites = [ + rule.body for rule in self.rules if Atts.INITIALIZER in rule.att and type(rule.body) is KRewrite + ] + init_config = indexed_rewrite(init_config, init_rewrites) init_config = top_down(_remove_config_var_lookups, init_config) return init_config diff --git a/pyk/src/pyk/kast/rewrite.py b/pyk/src/pyk/kast/rewrite.py new file mode 100644 index 00000000000..a6f33771758 --- /dev/null +++ b/pyk/src/pyk/kast/rewrite.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from .att import WithKAtt +from .inner import KApply, KToken, bottom_up + +if TYPE_CHECKING: + from collections.abc import Iterable + from typing import Final, TypeVar + + from .inner import KInner, KRewrite + + KI = TypeVar('KI', bound=KInner) + W = TypeVar('W', bound=WithKAtt) + +_LOGGER: Final = logging.getLogger(__name__) + + +def indexed_rewrite(kast: KInner, rewrites: Iterable[KRewrite]) -> KInner: + token_rewrites: list[KRewrite] = [] + apply_rewrites: dict[str, list[KRewrite]] = {} + other_rewrites: list[KRewrite] = [] + for r in rewrites: + if type(r.lhs) is KToken: + token_rewrites.append(r) + elif type(r.lhs) is KApply: + if r.lhs.label.name in apply_rewrites: + apply_rewrites[r.lhs.label.name].append(r) + else: + apply_rewrites[r.lhs.label.name] = [r] + else: + other_rewrites.append(r) + + def _apply_rewrites(_kast: KInner) -> KInner: + if type(_kast) is KToken: + for tr in token_rewrites: + _kast = tr.apply_top(_kast) + elif type(_kast) is KApply: + if _kast.label.name in apply_rewrites: + for ar in apply_rewrites[_kast.label.name]: + _kast = ar.apply_top(_kast) + else: + for _or in other_rewrites: + _kast = _or.apply_top(_kast) + return _kast + + orig_kast: KInner = kast + new_kast: KInner | None = None + while orig_kast != new_kast: + if new_kast is None: + new_kast = orig_kast + else: + orig_kast = new_kast + new_kast = bottom_up(_apply_rewrites, new_kast) + return new_kast diff --git a/pyk/src/pyk/ktool/kprove.py b/pyk/src/pyk/ktool/kprove.py index db350daa009..dfa9c6b9350 100644 --- a/pyk/src/pyk/ktool/kprove.py +++ b/pyk/src/pyk/ktool/kprove.py @@ -337,7 +337,8 @@ def _get_claim_module(_label: str) -> str | None: _dependency_label = f'{_module_name}.{_dependency_label}' _updated_dependencies.append(_dependency_label) if len(_updated_dependencies) > 0: - claim_labels.extend(_updated_dependencies) + if include_dependencies: + claim_labels.extend(_updated_dependencies) _claim = _claim.let(att=_claim.att.update([Atts.DEPENDS(','.join(_updated_dependencies))])) final_claims[claim_label] = _claim