Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove callbacks from job_queue_node #6118

Merged
merged 1 commit into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions src/ert/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
import logging
import time
from pathlib import Path
from typing import Callable, Iterable, Tuple
from typing import Iterable

from ert.config import EnsembleConfig, ParameterConfig, SummaryConfig
from ert.run_arg import RunArg

from .load_status import LoadResult, LoadStatus
from .realization_state import RealizationState

CallbackArgs = Tuple[RunArg, EnsembleConfig]
Callback = Callable[[RunArg, EnsembleConfig], LoadResult]

logger = logging.getLogger(__name__)


Expand All @@ -28,8 +25,8 @@
try:
start_time = time.perf_counter()
logger.info(f"Starting to load parameter: {config_node.name}")
ds = config_node.read_from_runpath(Path(run_arg.runpath), run_arg.iens)

Check failure on line 28 in src/ert/callbacks.py

View workflow job for this annotation

GitHub Actions / annotate-python-linting

line too long (83 > 79 characters)
run_arg.ensemble_storage.save_parameters(config_node.name, run_arg.iens, ds)

Check failure on line 29 in src/ert/callbacks.py

View workflow job for this annotation

GitHub Actions / annotate-python-linting

line too long (88 > 79 characters)
logger.info(
f"Saved {config_node.name} to storage",
extra={"Time": f"{(time.perf_counter() - start_time):.4f}s"},
Expand All @@ -53,7 +50,7 @@
start_time = time.perf_counter()
logger.info(f"Starting to load response: {config.name}")
ds = config.read_from_file(run_arg.runpath, run_arg.iens)
run_arg.ensemble_storage.save_response(config.name, ds, run_arg.iens)

Check failure on line 53 in src/ert/callbacks.py

View workflow job for this annotation

GitHub Actions / annotate-python-linting

line too long (81 > 79 characters)
logger.info(
f"Saved {config.name} to storage",
extra={"Time": f"{(time.perf_counter() - start_time):.4f}s"},
Expand All @@ -77,14 +74,14 @@
if run_arg.itr == 0:
parameters_result = _read_parameters(
run_arg,
run_arg.ensemble_storage.experiment.parameter_configuration.values(),

Check failure on line 77 in src/ert/callbacks.py

View workflow job for this annotation

GitHub Actions / annotate-python-linting

line too long (85 > 79 characters)
)

if parameters_result.status == LoadStatus.LOAD_SUCCESSFUL:
response_result = _write_responses_to_storage(ens_conf, run_arg)

except Exception as err:
logging.exception(f"Failed to load results for realization {run_arg.iens}")

Check failure on line 84 in src/ert/callbacks.py

View workflow job for this annotation

GitHub Actions / annotate-python-linting

line too long (83 > 79 characters)
parameters_result = LoadResult(
LoadStatus.LOAD_FAILURE,
"Failed to load results for realization "
Expand All @@ -102,8 +99,3 @@
)

return final_result


def forward_model_exit(run_arg: RunArg, _: EnsembleConfig) -> LoadResult:
run_arg.ensemble_storage.state_map[run_arg.iens] = RealizationState.LOAD_FAILURE
return LoadResult(None, "")
12 changes: 4 additions & 8 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,12 @@
from ert.async_utils import get_event_loop
from ert.ensemble_evaluator import identifiers
from ert.job_queue import Driver, JobQueue
from ert.load_status import LoadResult

from .._wait_for_evaluator import wait_for_evaluator
from ._ensemble import Ensemble

if TYPE_CHECKING:
from ert.callbacks import Callback
from ert.config import AnalysisConfig, EnsembleConfig, QueueConfig
from ert.run_arg import RunArg
from ert.config import AnalysisConfig, QueueConfig

from ..config import EvaluatorServerConfig
from ._realization import Realization
Expand Down Expand Up @@ -62,7 +59,7 @@
if not analysis_config:
raise ValueError(f"{self} needs analysis_config")
self._job_queue = JobQueue(
Driver.create_driver(queue_config), max_submit=queue_config.max_submit

Check failure on line 62 in src/ert/ensemble_evaluator/_builder/_legacy.py

View workflow job for this annotation

GitHub Actions / annotate-python-linting

line too long (82 > 79 characters)
)
self._analysis_config = analysis_config
self._config: Optional[EvaluatorServerConfig] = None
Expand Down Expand Up @@ -98,7 +95,7 @@
status_tab = {
identifiers.EVTYPE_ENSEMBLE_STARTED: "ENSEMBLE_STARTED",
identifiers.EVTYPE_ENSEMBLE_FAILED: "ENSEMBLE_FAILED",
identifiers.EVTYPE_ENSEMBLE_CANCELLED: "ENSEMBLE_CANCELLED",

Check failure on line 98 in src/ert/ensemble_evaluator/_builder/_legacy.py

View workflow job for this annotation

GitHub Actions / annotate-python-linting

line too long (80 > 79 characters)
identifiers.EVTYPE_ENSEMBLE_STOPPED: "ENSEMBLE_STOPPED",
identifiers.EVTYPE_FM_STEP_TIMEOUT: "STEP_TIMEOUT",
}
Expand All @@ -118,7 +115,7 @@

else:

def event_builder(status: str, real_id: Optional[int] = None) -> CloudEvent:

Check failure on line 118 in src/ert/ensemble_evaluator/_builder/_legacy.py

View workflow job for this annotation

GitHub Actions / annotate-python-linting

line too long (88 > 79 characters)
source = f"/ert/ensemble/{self.id_}"
if real_id is not None:
source += f"/real/{real_id}/step/0"
Expand All @@ -137,12 +134,11 @@
timeout_queue: asyncio.Queue[MsgType],
cloudevent_unary_send: Callable[[MsgType], Awaitable[None]],
event_generator: Callable[[str, Optional[int]], MsgType],
) -> Tuple[Callback, asyncio.Task[None]]:
def on_timeout(run_args: RunArg, _: EnsembleConfig) -> LoadResult:
) -> Tuple[Callable[[int], None], asyncio.Task[None]]:
def on_timeout(iens: int) -> None:
timeout_queue.put_nowait(
event_generator(identifiers.EVTYPE_FM_STEP_TIMEOUT, run_args.iens)
event_generator(identifiers.EVTYPE_FM_STEP_TIMEOUT, iens)
)
return LoadResult(None, "timed out")

async def send_timeout_message() -> None:
while True:
Expand All @@ -152,7 +148,7 @@
assert self._config # mypy
await cloudevent_unary_send(timeout_cloudevent)

send_timeout_future = get_event_loop().create_task(send_timeout_message())

Check failure on line 151 in src/ert/ensemble_evaluator/_builder/_legacy.py

View workflow job for this annotation

GitHub Actions / annotate-python-linting

line too long (82 > 79 characters)

return on_timeout, send_timeout_future

Expand Down Expand Up @@ -183,7 +179,7 @@
raise ValueError("no config")

# The cloudevent_unary_send only accepts a cloud event, but in order to
# send cloud events over the network, we need token, URI and cert. These are

Check failure on line 182 in src/ert/ensemble_evaluator/_builder/_legacy.py

View workflow job for this annotation

GitHub Actions / annotate-python-linting

line too long (84 > 79 characters)
# not known until evaluate() is called and _config is set. So in a hacky
# fashion, we create the partialmethod (bound partial) here, after evaluate().
# Note that this is the "sync" version of evaluate(), and that the "async"
Expand Down
8 changes: 3 additions & 5 deletions src/ert/ensemble_evaluator/_builder/_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

SOURCE_TEMPLATE_STEP = "/step/{step_id}"
if TYPE_CHECKING:
from ert.callbacks import Callback, CallbackArgs
from ert.config.ensemble_config import EnsembleConfig
from ert.run_arg import RunArg


Expand All @@ -26,11 +26,9 @@ class LegacyStep:
jobs: Sequence[LegacyJob]
name: str
max_runtime: Optional[int]
callback_arguments: CallbackArgs
done_callback: Callback
exit_callback: Callback
run_arg: "RunArg"
ensemble_config: "EnsembleConfig"
num_cpu: int
run_path: Path
job_script: str
job_name: str
run_arg: RunArg
38 changes: 20 additions & 18 deletions src/ert/job_queue/job_queue_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,24 @@
import random
import time
from threading import Lock, Semaphore, Thread
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Optional

from cwrap import BaseCClass
from ecl.util.util import StringList

from ert._clib.queue import _refresh_status # pylint: disable=import-error
from ert.callbacks import forward_model_ok
from ert.load_status import LoadStatus

from ..realization_state import RealizationState
from . import ResPrototype
from .job_status import JobStatus
from .submit_status import SubmitStatus
from .thread_status import ThreadStatus

if TYPE_CHECKING:
from ert.callbacks import Callback, CallbackArgs

from ..config import EnsembleConfig
from ..run_arg import RunArg
from .driver import Driver

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -94,16 +96,14 @@ def __init__(
num_cpu: int,
status_file: str,
exit_file: str,
done_callback_function: Callback,
exit_callback_function: Callback,
callback_arguments: CallbackArgs,
run_arg: "RunArg",
ensemble_config: "EnsembleConfig",
max_runtime: Optional[int] = None,
callback_timeout: Optional[Callback] = None,
callback_timeout: Optional[Callable[[int], None]] = None,
):
self.done_callback_function = done_callback_function
self.exit_callback_function = exit_callback_function
self.callback_timeout = callback_timeout
self.callback_arguments = callback_arguments
self.run_arg = run_arg
self.ensemble_config = ensemble_config
argc = 1
argv = StringList()
argv.append(run_path)
Expand Down Expand Up @@ -177,8 +177,8 @@ def submit(self, driver: "Driver") -> SubmitStatus:
return self._submit(driver)

def run_done_callback(self) -> Optional[LoadStatus]:
callback_status, status_msg = self.done_callback_function(
*self.callback_arguments
callback_status, status_msg = forward_model_ok(
self.run_arg, self.ensemble_config
)
if callback_status == LoadStatus.LOAD_SUCCESSFUL:
self._set_queue_status(JobStatus.SUCCESS)
Expand All @@ -194,10 +194,12 @@ def run_done_callback(self) -> Optional[LoadStatus]:

def run_timeout_callback(self) -> None:
if self.callback_timeout:
self.callback_timeout(*self.callback_arguments)
self.callback_timeout(self.run_arg.iens)

def run_exit_callback(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we avoid the string "callback" altogether?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is just a first pass. we will do more cleanup in later PRs

self.exit_callback_function(*self.callback_arguments)
self.run_arg.ensemble_storage.state_map[
self.run_arg.iens
] = RealizationState.LOAD_FAILURE

def is_running(self, given_status: Optional[JobStatus] = None) -> bool:
status = given_status or self.status
Expand Down Expand Up @@ -297,7 +299,7 @@ def _handle_end_status(
if end_status == JobStatus.DONE:
with pool_sema:
logger.info(
f"Realization: {self.callback_arguments[0].iens} complete, "
f"Realization: {self.run_arg.iens} complete, "
"starting to load results"
)
self.run_done_callback()
Expand All @@ -310,19 +312,19 @@ def _handle_end_status(
elif current_status in self.RESUBMIT_STATES:
if self.submit_attempt < max_submit:
logger.warning(
f"Realization: {self.callback_arguments[0].iens} "
f"Realization: {self.run_arg.iens} "
f"failed with: {self._status_msg}, resubmitting"
)
self._transition_status(ThreadStatus.READY, current_status)
else:
self._transition_to_failure(
message=f"Realization: {self.callback_arguments[0].iens} "
message=f"Realization: {self.run_arg.iens} "
"failed after reaching max submit"
f" ({max_submit}):\n\t{self._status_msg}"
)
elif current_status in self.FAILURE_STATES:
self._transition_to_failure(
message=f"Realization: {self.callback_arguments[0].iens} "
message=f"Realization: {self.run_arg.iens} "
f"failed with: {self._status_msg}"
)
else:
Expand Down
15 changes: 5 additions & 10 deletions src/ert/job_queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from . import ResPrototype

if TYPE_CHECKING:
from ert.callbacks import Callback
from ert.config import ErtConfig
from ert.ensemble_evaluator import LegacyStep
from ert.run_arg import RunArg
Expand Down Expand Up @@ -484,8 +483,6 @@ def add_job_from_run_arg(
run_arg: "RunArg",
ert_config: "ErtConfig",
max_runtime: Optional[int],
ok_cb: Callback,
exit_cb: Callback,
num_cpu: int,
) -> None:
job_name = run_arg.job_name
Expand All @@ -499,9 +496,8 @@ def add_job_from_run_arg(
num_cpu=num_cpu,
status_file=self.status_file,
exit_file=self.exit_file,
done_callback_function=ok_cb,
exit_callback_function=exit_cb,
callback_arguments=(run_arg, ert_config.ensemble_config),
run_arg=run_arg,
ensemble_config=ert_config.ensemble_config,
max_runtime=max_runtime,
)

Expand All @@ -512,7 +508,7 @@ def add_job_from_run_arg(
def add_ee_stage(
self,
stage: "LegacyStep",
callback_timeout: Optional[Callback] = None,
callback_timeout: Optional[Callable[[int], None]] = None,
) -> None:
job = JobQueueNode(
job_script=stage.job_script,
Expand All @@ -521,9 +517,8 @@ def add_ee_stage(
num_cpu=stage.num_cpu,
status_file=self.status_file,
exit_file=self.exit_file,
done_callback_function=stage.done_callback,
exit_callback_function=stage.exit_callback,
callback_arguments=stage.callback_arguments,
run_arg=stage.run_arg,
ensemble_config=stage.ensemble_config,
max_runtime=stage.max_runtime,
callback_timeout=callback_timeout,
)
Expand Down
12 changes: 3 additions & 9 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from cloudevents.http import CloudEvent

import _ert_com_protocol
from ert.callbacks import forward_model_exit, forward_model_ok
from ert.cli import MODULE_MODE
from ert.config import HookRuntime
from ert.enkf_main import EnKFMain
Expand Down Expand Up @@ -389,23 +388,18 @@ def _build_ensemble(
self.ert().resConfig().forward_model_list
)
]
real.active(True).add_step(
real.add_step(
LegacyStep(
id_="0",
jobs=jobs,
name="legacy step",
max_runtime=self.ert().analysisConfig().max_runtime,
callback_arguments=(
run_arg,
self.ert().resConfig().ensemble_config,
),
done_callback=forward_model_ok,
exit_callback=forward_model_exit,
run_arg=run_arg,
ensemble_config=self.ert().resConfig().ensemble_config,
num_cpu=self.ert().get_num_cpu(),
run_path=Path(run_arg.runpath),
job_script=self.ert().resConfig().queue_config.job_script,
job_name=run_arg.job_name,
run_arg=run_arg,
)
)
builder.add_realization(real)
Expand Down
3 changes: 0 additions & 3 deletions src/ert/simulator/simulation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from time import sleep
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

from ert.callbacks import forward_model_exit, forward_model_ok
from ert.config import HookRuntime
from ert.job_queue import Driver, JobQueue, JobQueueManager, RunStatus
from ert.realization_state import RealizationState
Expand Down Expand Up @@ -49,8 +48,6 @@ def _run_forward_model(
run_arg,
ert.resConfig(),
max_runtime,
forward_model_ok,
forward_model_exit,
ert.get_num_cpu(),
)

Expand Down
16 changes: 12 additions & 4 deletions tests/unit_tests/ensemble_evaluator/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ert.ensemble_evaluator.config import EvaluatorServerConfig
from ert.ensemble_evaluator.evaluator import EnsembleEvaluator
from ert.ensemble_evaluator.snapshot import SnapshotBuilder
from ert.job_queue import JobQueueNode
from ert.load_status import LoadStatus

from .ensemble_evaluator_utils import TestEnsemble
Expand Down Expand Up @@ -66,7 +67,16 @@ def queue_config_fixture():

@pytest.fixture
def make_ensemble_builder(queue_config):
def _make_ensemble_builder(tmpdir, num_reals, num_jobs, job_sleep=0):
def _make_ensemble_builder(monkeypatch, tmpdir, num_reals, num_jobs, job_sleep=0):
monkeypatch.setattr(
ert.job_queue.job_queue_node,
"forward_model_ok",
lambda _, _b: (LoadStatus.LOAD_SUCCESSFUL, ""),
)
monkeypatch.setattr(
JobQueueNode, "run_exit_callback", lambda _: (LoadStatus.LOAD_FAILURE, "")
)

builder = ert.ensemble_evaluator.EnsembleBuilder()
with tmpdir.as_cwd():
ext_job_list = []
Expand Down Expand Up @@ -121,11 +131,9 @@ class RunArg:
job_script="job_dispatch.py",
max_runtime=10,
run_arg=Mock(iens=iens),
done_callback=lambda _, _b: (LoadStatus.LOAD_SUCCESSFUL, ""),
exit_callback=lambda _, _b: (LoadStatus.LOAD_FAILURE, ""),
# the first callback_argument is expected to be a run_arg
# from the run_arg, the queue wants to access the iens prop
callback_arguments=(RunArg(iens), None),
ensemble_config=None,
run_path=run_path,
num_cpu=1,
name="dummy step",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def __init__(self, _iter, reals, steps, jobs, id_):
],
name=f"step-{step_no}",
max_runtime=0,
callback_arguments=(),
done_callback=None,
exit_callback=None,
ensemble_config=None,
num_cpu=0,
run_path=None,
run_arg=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ async def _handler(websocket, path):
@pytest.mark.asyncio
@pytest.mark.timeout(60)
async def test_happy_path(
tmpdir, unused_tcp_port, event_loop, make_ensemble_builder, queue_config, caplog
tmpdir,
unused_tcp_port,
event_loop,
make_ensemble_builder,
queue_config,
caplog,
monkeypatch,
):
asyncio.set_event_loop(event_loop)
host = "localhost"
Expand All @@ -44,7 +50,7 @@ async def test_happy_path(
mock_ws_task = get_event_loop().create_task(mock_ws(host, unused_tcp_port, done))
await wait_for_evaluator(base_url=url, timeout=5)

ensemble = make_ensemble_builder(tmpdir, 1, 1).build()
ensemble = make_ensemble_builder(monkeypatch, tmpdir, 1, 1).build()
queue = JobQueue(
Driver.create_driver(queue_config), max_submit=queue_config.max_submit
)
Expand Down
4 changes: 1 addition & 3 deletions tests/unit_tests/ensemble_evaluator/test_ensemble_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def test_build_ensemble(active_real):
job_script="job_script",
job_name="job_name",
num_cpu=1,
callback_arguments=MagicMock(),
done_callback=MagicMock(),
exit_callback=MagicMock(),
ensemble_config=MagicMock(),
jobs=[
LegacyJob(
ext_job=MagicMock(),
Expand Down
Loading
Loading