From dd263fe21e15fbaed0f5d2ced906f24182ba562f Mon Sep 17 00:00:00 2001 From: Jonathan Karlsen Date: Wed, 11 Sep 2024 13:45:34 +0200 Subject: [PATCH] Replace simulation thread with simulation task in BatchContext --- src/ert/simulator/batch_simulator_context.py | 133 ++++++++---------- .../simulator/test_simulation_context.py | 29 +++- 2 files changed, 84 insertions(+), 78 deletions(-) diff --git a/src/ert/simulator/batch_simulator_context.py b/src/ert/simulator/batch_simulator_context.py index 1fa670acfd3..6b5bcda8e58 100644 --- a/src/ert/simulator/batch_simulator_context.py +++ b/src/ert/simulator/batch_simulator_context.py @@ -3,16 +3,13 @@ import asyncio import contextlib import logging -import time from collections import namedtuple -from dataclasses import dataclass from enum import Enum, auto -from threading import Thread from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import numpy as np -from _ert.threading import ErtThread +from _ert.async_utils import new_event_loop from ert.config import HookRuntime from ert.enkf_main import create_run_path from ert.ensemble_evaluator import Realization @@ -20,7 +17,7 @@ from ert.scheduler import JobState, Scheduler, create_driver from ert.workflow_runner import WorkflowRunner -from ..run_arg import RunArg, create_run_arguments +from ..run_arg import create_run_arguments from .forward_model_status import ForwardModelStatus if TYPE_CHECKING: @@ -70,59 +67,31 @@ def _slug(entity: str) -> str: return "".join([x if x.isalnum() else "_" for x in entity.strip()]) -def _run_forward_model( - ert_config: "ErtConfig", - scheduler: Scheduler, - run_args: List[RunArg], -) -> None: - # run simplestep - asyncio.run(_submit_and_run_jobqueue(ert_config, scheduler, run_args)) - - -async def _submit_and_run_jobqueue( - ert_config: "ErtConfig", - scheduler: Scheduler, - run_args: List[RunArg], -) -> None: - max_runtime: Optional[int] = ert_config.analysis_config.max_runtime - if max_runtime == 0: - max_runtime = None - for run_arg in run_args: - if not run_arg.active: - continue - realization = Realization( - iens=run_arg.iens, - fm_steps=[], - active=True, - max_runtime=max_runtime, - run_arg=run_arg, - num_cpu=ert_config.preferred_num_cpu, - job_script=ert_config.queue_config.job_script, - realization_memory=ert_config.queue_config.realization_memory, - ) - scheduler.set_realization(realization) - - required_realizations = 0 - if ert_config.queue_config.stop_long_running: - required_realizations = ert_config.analysis_config.minimum_required_realizations - with contextlib.suppress(asyncio.CancelledError): - await scheduler.execute(required_realizations) - - -@dataclass class BatchContext: - result_keys: "Iterable[str]" - ert_config: "ErtConfig" - ensemble: Ensemble - mask: npt.NDArray[np.bool_] - itr: int - case_data: List[Tuple[Any, Any]] - - def __post_init__(self) -> None: - """ - Handle which can be used to query status and results for batch simulation. - """ - ert_config = self.ert_config + """ + Handle which can be used to query status and results for batch simulation. + """ + + def __init__( + self, + result_keys: "Iterable[str]", + ert_config: "ErtConfig", + ensemble: Ensemble, + mask: npt.NDArray[np.bool_], + itr: int, + case_data: List[Tuple[Any, Any]], + _loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + self.result_keys = result_keys + self.ert_config = ert_config + self.ensemble = ensemble + self.mask = mask + self.itr = itr + self.case_data = case_data + self._loop = _loop or new_event_loop() + + asyncio.set_event_loop(self._loop) + self._sim_task_started = asyncio.Event() driver = create_driver(ert_config.queue_config) self._scheduler = Scheduler( driver, max_running=self.ert_config.queue_config.max_running @@ -160,12 +129,11 @@ def __post_init__(self) -> None: ) for workflow in ert_config.hooked_workflows[HookRuntime.PRE_SIMULATION]: WorkflowRunner(workflow, None, self.ensemble).run_blocking() - self._sim_thread = self._run_simulations_simple_step() + self._loop.run_until_complete(self.run_forward_model()) - # Wait until the queue is active before we finish the creation - # to ensure sane job status while running - while self.running() and not self._scheduler.is_active(): - time.sleep(0.1) + async def run_forward_model(self) -> None: + self._sim_task = self._loop.create_task(self._submit_and_run_jobqueue()) + await self._sim_task_started.wait() def __len__(self) -> int: return len(self.mask) @@ -173,24 +141,43 @@ def __len__(self) -> int: def get_ensemble(self) -> Ensemble: return self.ensemble - def _run_simulations_simple_step(self) -> Thread: - sim_thread = ErtThread( - target=lambda: _run_forward_model( - self.ert_config, self._scheduler, self.run_args + async def _submit_and_run_jobqueue(self) -> None: + self._sim_task_started.set() + max_runtime: Optional[int] = self.ert_config.analysis_config.max_runtime + if max_runtime == 0: + max_runtime = None + for run_arg in self.run_args: + if not run_arg.active: + continue + realization = Realization( + iens=run_arg.iens, + fm_steps=[], + active=True, + max_runtime=max_runtime, + run_arg=run_arg, + num_cpu=self.ert_config.preferred_num_cpu, + job_script=self.ert_config.queue_config.job_script, + realization_memory=self.ert_config.queue_config.realization_memory, ) - ) - sim_thread.start() - return sim_thread + self._scheduler.set_realization(realization) + required_realizations = 0 + if self.ert_config.queue_config.stop_long_running: + required_realizations = ( + self.ert_config.analysis_config.minimum_required_realizations + ) + with contextlib.suppress(asyncio.CancelledError): + await self._scheduler.execute(required_realizations) def join(self) -> None: """ Will block until the simulation is complete. """ - while self.running(): - time.sleep(1) + self._loop.run_until_complete(self._sim_task) def running(self) -> bool: - return self._sim_thread.is_alive() or self._scheduler.is_active() + is_running = not self._sim_task.done() or self._scheduler.is_active() + self._loop.run_until_complete(asyncio.sleep(0)) + return is_running @property def status(self) -> Status: @@ -320,7 +307,7 @@ def job_progress(self, iens: int) -> Optional[ForwardModelStatus]: def stop(self) -> None: self._scheduler.kill_all_jobs() - self._sim_thread.join() + self._loop.run_until_complete(self._sim_task) def run_path(self, iens: int) -> str: return self.run_args[iens].runpath diff --git a/tests/unit_tests/simulator/test_simulation_context.py b/tests/unit_tests/simulator/test_simulation_context.py index 6a870c163c0..a292bc90c5e 100644 --- a/tests/unit_tests/simulator/test_simulation_context.py +++ b/tests/unit_tests/simulator/test_simulation_context.py @@ -1,10 +1,14 @@ +import asyncio + import pytest +from _ert.async_utils import get_running_loop from ert import JobState, JobStatus from ert.simulator import BatchContext from tests.utils import wait_until +@pytest.mark.timeout(15) @pytest.mark.parametrize( "success_state, failure_state, status_check_method_name", [ @@ -36,17 +40,32 @@ def test_simulation_context( ) case_data = [(geo_id, {}) for geo_id in range(size)] - even_ctx = BatchContext([], ert_config, even_half, even_mask, 0, case_data) - odd_ctx = BatchContext([], ert_config, odd_half, odd_mask, 0, case_data) + event_loop = get_running_loop() # asyncio.get_running_loop() + even_ctx = BatchContext( + [], ert_config, even_half, even_mask, 0, case_data, _loop=event_loop + ) + odd_ctx = BatchContext( + [], ert_config, odd_half, odd_mask, 0, case_data, _loop=event_loop + ) for iens in range(size): if iens % 2 == 0: - assert getattr(even_ctx, status_check_method_name)(iens) != success_state + assert ( + getattr(even_ctx, status_check_method_name)(iens).name + != success_state.name + ) else: - assert getattr(odd_ctx, status_check_method_name)(iens) != success_state + assert ( + getattr(odd_ctx, status_check_method_name)(iens).name + != success_state.name + ) - wait_until(lambda: not even_ctx.running() and not odd_ctx.running(), timeout=90) + async def wait_for_stopped(): + while even_ctx.running() or odd_ctx.running(): # noqa: ASYNC110 + await asyncio.sleep(0.5) + # event_loop.run_until_complete(wait_for_stopped()) + wait_until(lambda: not even_ctx.running() and not odd_ctx.running(), timeout=90) for iens in range(size): if iens % 2 == 0: assert even_ctx.run_args[iens].runpath.endswith(