diff --git a/pyk/src/pyk/proof/proof.py b/pyk/src/pyk/proof/proof.py index d7903fd99d3..f8600ed09ec 100644 --- a/pyk/src/pyk/proof/proof.py +++ b/pyk/src/pyk/proof/proof.py @@ -344,43 +344,42 @@ def parallel_advance_proof( explored: set[PS] = set() iterations = 0 - main_prover = create_prover() + with create_prover() as main_prover: + main_prover.init_proof(proof) - main_prover.init_proof(proof) + with _ProverPool[P, PS, SR](create_prover=create_prover, max_workers=max_workers) as pool: - with _ProverPool[P, PS, SR](create_prover=create_prover, max_workers=max_workers) as pool: + def submit_steps(_steps: Iterable[PS]) -> None: + for step in _steps: + if step in explored: + continue + explored.add(step) + future: Future[Any] = pool.submit(step) # <-- schedule steps for execution + pending.add(future) - def submit_steps(_steps: Iterable[PS]) -> None: - for step in _steps: - if step in explored: - continue - explored.add(step) - future: Future[Any] = pool.submit(step) # <-- schedule steps for execution - pending.add(future) + submit_steps(proof.get_steps()) - submit_steps(proof.get_steps()) + while True: + if not pending: + break + done, _ = wait(pending, return_when='FIRST_COMPLETED') + future = done.pop() + proof_results = future.result() + for result in proof_results: + proof.commit(result) + proof.write_proof_data() + iterations += 1 + if max_iterations is not None and max_iterations <= iterations: + break + if fail_fast and proof.failed: + _LOGGER.warning(f'Terminating proof early because fail_fast is set: {proof.id}') + break + submit_steps(proof.get_steps()) + pending.remove(future) - while True: - if not pending: - break - done, _ = wait(pending, return_when='FIRST_COMPLETED') - future = done.pop() - proof_results = future.result() - for result in proof_results: - proof.commit(result) + if proof.failed: + proof.failure_info = main_prover.failure_info(proof) proof.write_proof_data() - iterations += 1 - if max_iterations is not None and max_iterations <= iterations: - break - if fail_fast and proof.failed: - _LOGGER.warning(f'Terminating proof early because fail_fast is set: {proof.id}') - break - submit_steps(proof.get_steps()) - pending.remove(future) - - if proof.failed: - proof.failure_info = main_prover.failure_info(proof) - proof.write_proof_data() class _ProverPool(ContextManager['_ProverPool'], Generic[P, PS, SR]): diff --git a/pyk/src/tests/integration/proof/test_imp.py b/pyk/src/tests/integration/proof/test_imp.py index d2f40148626..77f08067c3c 100644 --- a/pyk/src/tests/integration/proof/test_imp.py +++ b/pyk/src/tests/integration/proof/test_imp.py @@ -1425,3 +1425,52 @@ def test_all_path_reachability_prove_parallel( assert proof.status == proof_status assert leaf_number(proof) == expected_leaf_number + + def test_all_path_reachability_prove_parallel_resources( + self, + kprove: KProve, + tmp_path_factory: TempPathFactory, + create_prover: Callable[[int, Iterable[str]], Prover], + ) -> None: + + test_id = 'imp-simple-addition-1' + spec_file = K_FILES / 'imp-simple-spec.k' + spec_module = 'IMP-SIMPLE-SPEC' + claim_id = 'addition-1' + + with tmp_path_factory.mktemp(f'apr_tmp_proofs-{test_id}') as proof_dir: + spec_modules = kprove.get_claim_modules(Path(spec_file), spec_module_name=spec_module) + spec_label = f'{spec_module}.{claim_id}' + proofs = APRProof.from_spec_modules( + kprove.definition, + spec_modules, + spec_labels=[spec_label], + logs={}, + proof_dir=proof_dir, + ) + proof = single([p for p in proofs if p.id == spec_label]) + + _create_prover = partial(create_prover, 1, []) + + provers_created = 0 + + class MyAPRProver(APRProver): + provers_closed: int = 0 + + def close(self) -> None: + MyAPRProver.provers_closed += 1 + super().close() + + def create_prover_res_counter() -> APRProver: + nonlocal provers_created + provers_created += 1 + prover = _create_prover() + prover.__class__ = MyAPRProver + assert type(prover) is MyAPRProver + return prover + + parallel_advance_proof( + proof=proof, max_iterations=2, create_prover=create_prover_res_counter, max_workers=2 + ) + + assert provers_created == MyAPRProver.provers_closed