diff --git a/pyk/src/pyk/kcfg/kcfg.py b/pyk/src/pyk/kcfg/kcfg.py index 825dd7c328..57c0e95dc9 100644 --- a/pyk/src/pyk/kcfg/kcfg.py +++ b/pyk/src/pyk/kcfg/kcfg.py @@ -559,6 +559,8 @@ def extend( extend_result: KCFGExtendResult, node: KCFG.Node, logs: dict[int, tuple[LogEntry, ...]], + *, + optimize_kcfg: bool, ) -> None: def log(message: str, *, warning: bool = False) -> None: @@ -584,10 +586,25 @@ def log(message: str, *, warning: bool = False) -> None: log(f'abstraction node: {node.id}') case Step(cterm, depth, next_node_logs, rule_labels, _): + node_id = node.id next_node = self.create_node(cterm) + # Optimization for steps consists of on-the-fly merging of consecutive edges and can + # be performed only if the current node has a single predecessor connected by an Edge + if ( + optimize_kcfg + and (len(predecessors := self.predecessors(target_id=node.id)) == 1) + and isinstance(in_edge := predecessors[0], KCFG.Edge) + ): + # The existing edge is removed and the step parameters are updated accordingly + self.remove_edge(in_edge.source.id, node.id) + node_id = in_edge.source.id + depth += in_edge.depth + rule_labels = list(in_edge.rules) + rule_labels + next_node_logs = logs[node.id] + next_node_logs if node.id in logs else next_node_logs + self.remove_node(node.id) + self.create_edge(node_id, next_node.id, depth, rule_labels) logs[next_node.id] = next_node_logs - self.create_edge(node.id, next_node.id, depth, rules=rule_labels) - log(f'basic block at depth {depth}: {node.id} --> {next_node.id}') + log(f'basic block at depth {depth}: {node_id} --> {next_node.id}') case Branch(branches, _): branch_node_ids = self.split_on_constraints(node.id, branches) diff --git a/pyk/src/pyk/proof/reachability.py b/pyk/src/pyk/proof/reachability.py index d7e706ca93..1ce8cd5658 100644 --- a/pyk/src/pyk/proof/reachability.py +++ b/pyk/src/pyk/proof/reachability.py @@ -40,6 +40,7 @@ class APRProofResult: node_id: int prior_loops_cache_update: tuple[int, ...] + optimize_kcfg: bool @dataclass @@ -220,6 +221,7 @@ def commit(self, result: APRProofResult) -> None: assert result.cached_node_id in self._next_steps self.kcfg.extend( extend_result=self._next_steps.pop(result.cached_node_id), + optimize_kcfg=result.optimize_kcfg, node=self.kcfg.node(result.node_id), logs=self.logs, ) @@ -230,6 +232,7 @@ def commit(self, result: APRProofResult) -> None: self._next_steps[result.node_id] = result.extension_to_cache self.kcfg.extend( extend_result=result.extension_to_apply, + optimize_kcfg=result.optimize_kcfg, node=self.kcfg.node(result.node_id), logs=self.logs, ) @@ -715,6 +718,7 @@ class APRProver(Prover[APRProof, APRProofStep, APRProofResult]): assume_defined: bool kcfg_explore: KCFGExplore extra_module: KFlatModule | None + optimize_kcfg: bool def __init__( self, @@ -727,6 +731,7 @@ def __init__( direct_subproof_rules: bool = False, assume_defined: bool = False, extra_module: KFlatModule | None = None, + optimize_kcfg: bool = False, ) -> None: self.kcfg_explore = kcfg_explore @@ -739,6 +744,7 @@ def __init__( self.direct_subproof_rules = direct_subproof_rules self.assume_defined = assume_defined self.extra_module = extra_module + self.optimize_kcfg = optimize_kcfg def close(self) -> None: self.kcfg_explore.cterm_symbolic._kore_client.close() @@ -808,14 +814,24 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]: _LOGGER.info(f'Prior loop heads for node {step.node.id}: {(step.node.id, prior_loops)}') if len(prior_loops) > step.bmc_depth: _LOGGER.warning(f'Bounded node {step.proof_id}: {step.node.id} at bmc depth {step.bmc_depth}') - return [APRProofBoundedResult(node_id=step.node.id, prior_loops_cache_update=prior_loops)] + return [ + APRProofBoundedResult( + node_id=step.node.id, optimize_kcfg=self.optimize_kcfg, prior_loops_cache_update=prior_loops + ) + ] # Check if the current node and target are terminal is_terminal = self.kcfg_explore.kcfg_semantics.is_terminal(step.node.cterm) target_is_terminal = self.kcfg_explore.kcfg_semantics.is_terminal(step.target.cterm) terminal_result: list[APRProofResult] = ( - [APRProofTerminalResult(node_id=step.node.id, prior_loops_cache_update=prior_loops)] if is_terminal else [] + [ + APRProofTerminalResult( + node_id=step.node.id, optimize_kcfg=self.optimize_kcfg, prior_loops_cache_update=prior_loops + ) + ] + if is_terminal + else [] ) # Subsumption is checked if and only if the target node @@ -826,7 +842,12 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]: # Information about the subsumed node being terminal must be returned # so that the set of terminal nodes is correctly updated return terminal_result + [ - APRProofSubsumeResult(csubst=csubst, node_id=step.node.id, prior_loops_cache_update=prior_loops) + APRProofSubsumeResult( + csubst=csubst, + optimize_kcfg=self.optimize_kcfg, + node_id=step.node.id, + prior_loops_cache_update=prior_loops, + ) ] if is_terminal: @@ -849,6 +870,7 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]: APRProofUseCacheResult( node_id=step.node.id, cached_node_id=step.use_cache, + optimize_kcfg=self.optimize_kcfg, prior_loops_cache_update=prior_loops, ) ] @@ -876,6 +898,7 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]: extension_to_apply=extend_results[0], extension_to_cache=extend_results[1], prior_loops_cache_update=prior_loops, + optimize_kcfg=self.optimize_kcfg, ) ] @@ -885,6 +908,7 @@ def step_proof(self, step: APRProofStep) -> list[APRProofResult]: node_id=step.node.id, extension_to_apply=extend_results[0], prior_loops_cache_update=prior_loops, + optimize_kcfg=self.optimize_kcfg, ) ] diff --git a/pyk/src/tests/integration/proof/test_imp.py b/pyk/src/tests/integration/proof/test_imp.py index 3ad17c599c..517051b001 100644 --- a/pyk/src/tests/integration/proof/test_imp.py +++ b/pyk/src/tests/integration/proof/test_imp.py @@ -566,6 +566,35 @@ def same_loop(self, c1: CTerm, c2: CTerm) -> bool: ), ) +APR_PROVE_WITH_KCFG_OPTIMS_TEST_DATA: Iterable[ + tuple[str, Path, str, str, int | None, int | None, Iterable[str], bool, ProofStatus, int] +] = ( + ( + 'imp-simple-sum-100', + K_FILES / 'imp-simple-spec.k', + 'IMP-SIMPLE-SPEC', + 'sum-100', + None, + None, + [], + True, + ProofStatus.PASSED, + 3, + ), + ( + 'imp-simple-long-branches', + K_FILES / 'imp-simple-spec.k', + 'IMP-SIMPLE-SPEC', + 'long-branches', + None, + None, + [], + True, + ProofStatus.PASSED, + 7, + ), +) + PATH_CONSTRAINTS_TEST_DATA: Iterable[ tuple[str, Path, str, str, int | None, int | None, Iterable[str], Iterable[str], str] ] = ( @@ -918,6 +947,55 @@ def test_all_path_reachability_prove( assert proof.status == proof_status assert leaf_number(proof) == expected_leaf_number + @pytest.mark.parametrize( + 'test_id,spec_file,spec_module,claim_id,max_iterations,max_depth,cut_rules,admit_deps,proof_status,expected_nodes', + APR_PROVE_WITH_KCFG_OPTIMS_TEST_DATA, + ids=[test_id for test_id, *_ in APR_PROVE_WITH_KCFG_OPTIMS_TEST_DATA], + ) + def test_all_path_reachability_prove_with_kcfg_optims( + self, + kprove: KProve, + kcfg_explore: KCFGExplore, + test_id: str, + spec_file: str, + spec_module: str, + claim_id: str, + max_iterations: int | None, + max_depth: int | None, + cut_rules: Iterable[str], + admit_deps: bool, + proof_status: ProofStatus, + expected_nodes: int, + tmp_path_factory: TempPathFactory, + ) -> None: + proof_dir = tmp_path_factory.mktemp(f'apr_tmp_proofs-{test_id}') + spec_modules = kprove.parse_modules(Path(spec_file), module_name=spec_module) + spec_label = f'{spec_module}.{claim_id}' + proofs = APRProof.from_spec_modules( + kprove.definition, + spec_modules, + spec_labels=[spec_label], + logs={}, + proof_dir=proof_dir, + ) + proof = single([p for p in proofs if p.id == spec_label]) + if admit_deps: + for subproof in proof.subproofs: + subproof.admit() + subproof.write_proof_data() + + prover = APRProver( + kcfg_explore=kcfg_explore, execute_depth=max_depth, cut_point_rules=cut_rules, optimize_kcfg=True + ) + prover.advance_proof(proof, max_iterations=max_iterations) + + kcfg_show = KCFGShow(kprove, node_printer=APRProofNodePrinter(proof, kprove, full_printer=True)) + cfg_lines = kcfg_show.show(proof.kcfg) + _LOGGER.info('\n'.join(cfg_lines)) + + assert proof.status == proof_status + assert len(proof.kcfg._nodes) == expected_nodes + def test_terminal_node_subsumption( self, kprove: KProve,