Skip to content

Commit

Permalink
Replace simulation thread with simulation task in BatchContext
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-eq committed Sep 11, 2024
1 parent 22e95fc commit dd263fe
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 78 deletions.
133 changes: 60 additions & 73 deletions src/ert/simulator/batch_simulator_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,21 @@
import asyncio
import contextlib
import logging
import time
from collections import namedtuple
from dataclasses import dataclass
from enum import Enum, auto
from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import numpy as np

from _ert.threading import ErtThread
from _ert.async_utils import new_event_loop
from ert.config import HookRuntime
from ert.enkf_main import create_run_path
from ert.ensemble_evaluator import Realization
from ert.runpaths import Runpaths
from ert.scheduler import JobState, Scheduler, create_driver
from ert.workflow_runner import WorkflowRunner

from ..run_arg import RunArg, create_run_arguments
from ..run_arg import create_run_arguments
from .forward_model_status import ForwardModelStatus

if TYPE_CHECKING:
Expand Down Expand Up @@ -70,59 +67,31 @@ def _slug(entity: str) -> str:
return "".join([x if x.isalnum() else "_" for x in entity.strip()])


def _run_forward_model(
ert_config: "ErtConfig",
scheduler: Scheduler,
run_args: List[RunArg],
) -> None:
# run simplestep
asyncio.run(_submit_and_run_jobqueue(ert_config, scheduler, run_args))


async def _submit_and_run_jobqueue(
ert_config: "ErtConfig",
scheduler: Scheduler,
run_args: List[RunArg],
) -> None:
max_runtime: Optional[int] = ert_config.analysis_config.max_runtime
if max_runtime == 0:
max_runtime = None
for run_arg in run_args:
if not run_arg.active:
continue
realization = Realization(
iens=run_arg.iens,
fm_steps=[],
active=True,
max_runtime=max_runtime,
run_arg=run_arg,
num_cpu=ert_config.preferred_num_cpu,
job_script=ert_config.queue_config.job_script,
realization_memory=ert_config.queue_config.realization_memory,
)
scheduler.set_realization(realization)

required_realizations = 0
if ert_config.queue_config.stop_long_running:
required_realizations = ert_config.analysis_config.minimum_required_realizations
with contextlib.suppress(asyncio.CancelledError):
await scheduler.execute(required_realizations)


@dataclass
class BatchContext:
result_keys: "Iterable[str]"
ert_config: "ErtConfig"
ensemble: Ensemble
mask: npt.NDArray[np.bool_]
itr: int
case_data: List[Tuple[Any, Any]]

def __post_init__(self) -> None:
"""
Handle which can be used to query status and results for batch simulation.
"""
ert_config = self.ert_config
"""
Handle which can be used to query status and results for batch simulation.
"""

def __init__(
self,
result_keys: "Iterable[str]",
ert_config: "ErtConfig",
ensemble: Ensemble,
mask: npt.NDArray[np.bool_],
itr: int,
case_data: List[Tuple[Any, Any]],
_loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
self.result_keys = result_keys
self.ert_config = ert_config
self.ensemble = ensemble
self.mask = mask
self.itr = itr
self.case_data = case_data
self._loop = _loop or new_event_loop()

asyncio.set_event_loop(self._loop)
self._sim_task_started = asyncio.Event()
driver = create_driver(ert_config.queue_config)
self._scheduler = Scheduler(
driver, max_running=self.ert_config.queue_config.max_running
Expand Down Expand Up @@ -160,37 +129,55 @@ def __post_init__(self) -> None:
)
for workflow in ert_config.hooked_workflows[HookRuntime.PRE_SIMULATION]:
WorkflowRunner(workflow, None, self.ensemble).run_blocking()
self._sim_thread = self._run_simulations_simple_step()
self._loop.run_until_complete(self.run_forward_model())

# Wait until the queue is active before we finish the creation
# to ensure sane job status while running
while self.running() and not self._scheduler.is_active():
time.sleep(0.1)
async def run_forward_model(self) -> None:
self._sim_task = self._loop.create_task(self._submit_and_run_jobqueue())
await self._sim_task_started.wait()

def __len__(self) -> int:
return len(self.mask)

def get_ensemble(self) -> Ensemble:
return self.ensemble

def _run_simulations_simple_step(self) -> Thread:
sim_thread = ErtThread(
target=lambda: _run_forward_model(
self.ert_config, self._scheduler, self.run_args
async def _submit_and_run_jobqueue(self) -> None:
self._sim_task_started.set()
max_runtime: Optional[int] = self.ert_config.analysis_config.max_runtime
if max_runtime == 0:
max_runtime = None
for run_arg in self.run_args:
if not run_arg.active:
continue
realization = Realization(
iens=run_arg.iens,
fm_steps=[],
active=True,
max_runtime=max_runtime,
run_arg=run_arg,
num_cpu=self.ert_config.preferred_num_cpu,
job_script=self.ert_config.queue_config.job_script,
realization_memory=self.ert_config.queue_config.realization_memory,
)
)
sim_thread.start()
return sim_thread
self._scheduler.set_realization(realization)
required_realizations = 0
if self.ert_config.queue_config.stop_long_running:
required_realizations = (
self.ert_config.analysis_config.minimum_required_realizations
)
with contextlib.suppress(asyncio.CancelledError):
await self._scheduler.execute(required_realizations)

def join(self) -> None:
"""
Will block until the simulation is complete.
"""
while self.running():
time.sleep(1)
self._loop.run_until_complete(self._sim_task)

def running(self) -> bool:
return self._sim_thread.is_alive() or self._scheduler.is_active()
is_running = not self._sim_task.done() or self._scheduler.is_active()
self._loop.run_until_complete(asyncio.sleep(0))
return is_running

@property
def status(self) -> Status:
Expand Down Expand Up @@ -320,7 +307,7 @@ def job_progress(self, iens: int) -> Optional[ForwardModelStatus]:

def stop(self) -> None:
self._scheduler.kill_all_jobs()
self._sim_thread.join()
self._loop.run_until_complete(self._sim_task)

def run_path(self, iens: int) -> str:
return self.run_args[iens].runpath
29 changes: 24 additions & 5 deletions tests/unit_tests/simulator/test_simulation_context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import asyncio

import pytest

from _ert.async_utils import get_running_loop
from ert import JobState, JobStatus
from ert.simulator import BatchContext
from tests.utils import wait_until


@pytest.mark.timeout(15)
@pytest.mark.parametrize(
"success_state, failure_state, status_check_method_name",
[
Expand Down Expand Up @@ -36,17 +40,32 @@ def test_simulation_context(
)

case_data = [(geo_id, {}) for geo_id in range(size)]
even_ctx = BatchContext([], ert_config, even_half, even_mask, 0, case_data)
odd_ctx = BatchContext([], ert_config, odd_half, odd_mask, 0, case_data)
event_loop = get_running_loop() # asyncio.get_running_loop()
even_ctx = BatchContext(
[], ert_config, even_half, even_mask, 0, case_data, _loop=event_loop
)
odd_ctx = BatchContext(
[], ert_config, odd_half, odd_mask, 0, case_data, _loop=event_loop
)

for iens in range(size):
if iens % 2 == 0:
assert getattr(even_ctx, status_check_method_name)(iens) != success_state
assert (
getattr(even_ctx, status_check_method_name)(iens).name
!= success_state.name
)
else:
assert getattr(odd_ctx, status_check_method_name)(iens) != success_state
assert (
getattr(odd_ctx, status_check_method_name)(iens).name
!= success_state.name
)

wait_until(lambda: not even_ctx.running() and not odd_ctx.running(), timeout=90)
async def wait_for_stopped():
while even_ctx.running() or odd_ctx.running(): # noqa: ASYNC110
await asyncio.sleep(0.5)

# event_loop.run_until_complete(wait_for_stopped())
wait_until(lambda: not even_ctx.running() and not odd_ctx.running(), timeout=90)
for iens in range(size):
if iens % 2 == 0:
assert even_ctx.run_args[iens].runpath.endswith(
Expand Down

0 comments on commit dd263fe

Please sign in to comment.