Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plugin hook for configuring an activate script on the compute side #9273

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
ForwardModelStepKeys,
HistorySource,
HookRuntime,
QueueSystemWithGeneric,
init_forward_model_schema,
init_site_config_schema,
init_user_config_schema,
Expand Down Expand Up @@ -260,6 +261,7 @@ class ErtConfig:
DEFAULT_RUNPATH_FILE: ClassVar[str] = ".ert_runpath_list"
PREINSTALLED_FORWARD_MODEL_STEPS: ClassVar[Dict[str, ForwardModelStep]] = {}
ENV_PR_FM_STEP: ClassVar[Dict[str, Dict[str, Any]]] = {}
ACTIVATE_SCRIPT: Optional[str] = None

substitutions: Substitutions = field(default_factory=Substitutions)
ensemble_config: EnsembleConfig = field(default_factory=EnsembleConfig)
Expand Down Expand Up @@ -347,6 +349,7 @@ class ErtConfigWithPlugins(ErtConfig):
Dict[str, ForwardModelStepPlugin]
] = preinstalled_fm_steps
ENV_PR_FM_STEP: ClassVar[Dict[str, Dict[str, Any]]] = env_pr_fm_step
ACTIVATE_SCRIPT = ErtPluginManager().activate_script()

assert issubclass(ErtConfigWithPlugins, ErtConfig)
return ErtConfigWithPlugins
Expand Down Expand Up @@ -675,6 +678,12 @@ def _merge_user_and_site_config(
user_config_dict[keyword] = value + original_entries
elif keyword not in user_config_dict:
user_config_dict[keyword] = value
if cls.ACTIVATE_SCRIPT:
if "QUEUE_OPTION" not in user_config_dict:
user_config_dict["QUEUE_OPTION"] = []
user_config_dict["QUEUE_OPTION"].append(
[QueueSystemWithGeneric.GENERIC, "ACTIVATE_SCRIPT", cls.ACTIVATE_SCRIPT]
)
return user_config_dict

@classmethod
Expand Down
10 changes: 9 additions & 1 deletion src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import os
import re
import shutil
from abc import abstractmethod
Expand All @@ -26,12 +27,20 @@
NonEmptyString = Annotated[str, pydantic.StringConstraints(min_length=1)]


def activate_script() -> str:
venv = os.environ.get("VIRTUAL_ENV")
if not venv:
return ""
return f"source {venv}/bin/activate"


@pydantic.dataclasses.dataclass(config={"extra": "forbid", "validate_assignment": True})
class QueueOptions:
name: str
max_running: pydantic.NonNegativeInt = 0
submit_sleep: pydantic.NonNegativeFloat = 0.0
project_code: Optional[str] = None
activate_script: str = field(default_factory=activate_script)

@staticmethod
def create_queue_options(
Expand Down Expand Up @@ -292,7 +301,6 @@ def from_dict(cls, config_dict: ConfigDict) -> QueueConfig:
_grouped_queue_options = _group_queue_options_by_queue_system(
_raw_queue_options
)

_log_duplicated_queue_options(_raw_queue_options)
_raise_for_defaulted_invalid_options(_raw_queue_options)

Expand Down
1 change: 1 addition & 0 deletions src/ert/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> Any:
"flow_config_path",
"help_links",
"site_config_lines",
"activate_script",
]
and res is not None
):
Expand Down
2 changes: 2 additions & 0 deletions src/ert/plugins/hook_specifications/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .activate_script import activate_script
from .ecl_config import (
ecl100_config_path,
ecl300_config_path,
Expand All @@ -21,6 +22,7 @@
from .site_config import site_config_lines

__all__ = [
"activate_script",
"add_log_handle_to_root",
"add_span_processor",
"ecl100_config_path",
Expand Down
19 changes: 19 additions & 0 deletions src/ert/plugins/hook_specifications/activate_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from ert.plugins.plugin_manager import hook_specification


@hook_specification
def activate_script() -> str: # type: ignore
"""
Allows the plugin to provide a script that will be run when
the driver submits to the cluster. The script will run in
bash.

Example:
import ert

@ert.plugin(name="my_plugin")
def activate_script():
return "source /private/venv/my_env/bin/activate

:return: Activate script
oyvindeide marked this conversation as resolved.
Show resolved Hide resolved
"""
11 changes: 11 additions & 0 deletions src/ert/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,17 @@ def _site_config_lines(self) -> List[str]:
]
return list(chain.from_iterable(reversed(plugin_site_config_lines)))

def activate_script(self) -> str:
plugin_responses = self.hook.activate_script()
if not plugin_responses:
return ""
if len(plugin_responses) > 1:
raise ValueError(
f"Only one activate script is allowed, got {[plugin.plugin_metadata.plugin_name for plugin in plugin_responses]}"
)
else:
return plugin_responses[0].data

def get_installable_workflow_jobs(self) -> Dict[str, str]:
config_workflow_jobs = self._get_config_workflow_jobs()
hooked_workflow_jobs = self.get_ertscript_workflows().get_workflows()
Expand Down
14 changes: 13 additions & 1 deletion src/ert/scheduler/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,28 @@
"""Bash and other shells add an offset of 128 to the signal value when a process exited due to a signal"""


def create_submit_script(
runpath: Path, executable: str, args: tuple[str, ...], activate_script: str
) -> str:
return (
"#!/usr/bin/env bash\n"
f"cd {shlex.quote(str(runpath))}\n"
f"{activate_script}\n"
f"exec -a {shlex.quote(executable)} {executable} {shlex.join(args)}\n"
)


class FailedSubmit(RuntimeError):
pass


class Driver(ABC):
"""Adapter for the HPC cluster."""

def __init__(self, **kwargs: Dict[str, str]) -> None:
def __init__(self, activate_script: str = "") -> None:
self._event_queue: Optional[asyncio.Queue[Event]] = None
self._job_error_message_by_iens: Dict[int, str] = {}
self.activate_script = activate_script

@property
def event_queue(self) -> asyncio.Queue[Event]:
Expand Down
1 change: 1 addition & 0 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ async def submit(
runpath: Optional[Path] = None,
num_cpu: Optional[int] = 1,
realization_memory: Optional[int] = 0,
activate_script: str = "",
) -> None:
self._tasks[iens] = asyncio.create_task(self._run(iens, executable, *args))
with suppress(KeyError):
Expand Down
12 changes: 4 additions & 8 deletions src/ert/scheduler/lsf_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
get_args,
)

from .driver import SIGNAL_OFFSET, Driver, FailedSubmit
from .driver import SIGNAL_OFFSET, Driver, FailedSubmit, create_submit_script
from .event import Event, FinishedEvent, StartedEvent

_POLL_PERIOD = 2.0 # seconds
Expand Down Expand Up @@ -264,8 +264,9 @@ def __init__(
bjobs_cmd: Optional[str] = None,
bkill_cmd: Optional[str] = None,
bhist_cmd: Optional[str] = None,
activate_script: str = "",
) -> None:
super().__init__()
super().__init__(activate_script)
self._queue_name = queue_name
self._project_code = project_code
self._resource_requirement = resource_requirement
Expand Down Expand Up @@ -309,12 +310,7 @@ async def submit(

arg_queue_name = ["-q", self._queue_name] if self._queue_name else []
arg_project_code = ["-P", self._project_code] if self._project_code else []

script = (
"#!/usr/bin/env bash\n"
f"cd {shlex.quote(str(runpath))}\n"
f"exec -a {shlex.quote(executable)} {executable} {shlex.join(args)}\n"
)
script = create_submit_script(runpath, executable, args, self.activate_script)
script_path: Optional[Path] = None
try:
with NamedTemporaryFile(
Expand Down
11 changes: 4 additions & 7 deletions src/ert/scheduler/openpbs_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
get_type_hints,
)

from .driver import Driver, FailedSubmit
from .driver import Driver, FailedSubmit, create_submit_script
from .event import Event, FinishedEvent, StartedEvent

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -145,8 +145,9 @@ def __init__(
qsub_cmd: Optional[str] = None,
qstat_cmd: Optional[str] = None,
qdel_cmd: Optional[str] = None,
activate_script: str = "",
) -> None:
super().__init__()
super().__init__(activate_script)

self._queue_name = queue_name
self._project_code = project_code
Expand Down Expand Up @@ -241,11 +242,7 @@ async def submit(
[] if self._keep_qsub_output else ["-o", "/dev/null", "-e", "/dev/null"]
)

script = (
"#!/usr/bin/env bash\n"
f"cd {shlex.quote(str(runpath))}\n"
f"exec -a {shlex.quote(executable)} {executable} {shlex.join(args)}\n"
)
script = create_submit_script(runpath, executable, args, self.activate_script)
name_prefix = self._job_prefix or ""
qsub_with_args: List[str] = [
str(self._qsub_cmd),
Expand Down
11 changes: 4 additions & 7 deletions src/ert/scheduler/slurm_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Tuple,
)

from .driver import SIGNAL_OFFSET, Driver, FailedSubmit
from .driver import SIGNAL_OFFSET, Driver, FailedSubmit, create_submit_script
from .event import Event, FinishedEvent, StartedEvent

SLURM_FAILED_EXIT_CODE_FETCH = SIGNAL_OFFSET + 66
Expand Down Expand Up @@ -77,6 +77,7 @@ def __init__(
max_runtime: Optional[float] = None,
squeue_timeout: float = 2,
project_code: Optional[str] = None,
activate_script: str = "",
) -> None:
"""
The arguments "memory" and "realization_memory" are currently both
Expand All @@ -90,7 +91,7 @@ def __init__(
zero "realization memory" is the default and means no intended
memory allocation.
"""
super().__init__()
super().__init__(activate_script)
self._submit_locks: dict[int, asyncio.Lock] = {}
self._iens2jobid: dict[int, str] = {}
self._jobs: dict[str, JobData] = {}
Expand Down Expand Up @@ -181,11 +182,7 @@ async def submit(
if runpath is None:
runpath = Path.cwd()

script = (
"#!/usr/bin/env bash\n"
f"cd {shlex.quote(str(runpath))}\n"
f"exec -a {shlex.quote(executable)} {executable} {shlex.join(args)}\n"
)
script = create_submit_script(runpath, executable, args, self.activate_script)
script_path: Optional[Path] = None
try:
with NamedTemporaryFile(
Expand Down
11 changes: 9 additions & 2 deletions src/everest/detached/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import importlib
import json
import logging
Expand All @@ -21,6 +22,7 @@
SlurmQueueOptions,
TorqueQueueOptions,
)
from ert.plugins import ErtPluginManager
from ert.scheduler import create_driver
from ert.scheduler.driver import Driver, FailedSubmit
from ert.scheduler.event import StartedEvent
Expand Down Expand Up @@ -57,12 +59,15 @@ async def start_server(config: EverestConfig, debug: bool = False) -> Driver:
args = ["--config-file", str(config.config_path)]
if debug:
args.append("--debug")
poll_task = asyncio.create_task(driver.poll(), name="poll_task")
await driver.submit(0, "everserver", *args)
except FailedSubmit as err:
raise ValueError(f"Failed to submit Everserver with error: {err}") from err
status = await driver.event_queue.get()
if not isinstance(status, StartedEvent):
poll_task.cancel()
raise ValueError(f"Everserver not started as expected, got status: {status}")
poll_task.cancel()
return driver


Expand Down Expand Up @@ -279,22 +284,24 @@ def get_server_queue_options(
simulator: Optional[SimulatorConfig],
server: Optional[ServerConfig],
) -> QueueOptions:
activate_script = ErtPluginManager().activate_script()
queue_system = _find_res_queue_system(simulator, server)
ever_queue_config = server if server is not None else simulator

if queue_system == QueueSystem.LSF:
queue = LsfQueueOptions(
activate_script=activate_script,
lsf_queue=ever_queue_config.name,
lsf_resource=ever_queue_config.options,
)
elif queue_system == QueueSystem.SLURM:
queue = SlurmQueueOptions(
activate_script=activate_script,
exclude_host=ever_queue_config.exclude_host,
include_host=ever_queue_config.include_host,
partition=ever_queue_config.name,
)
elif queue_system == QueueSystem.TORQUE:
queue = TorqueQueueOptions()
queue = TorqueQueueOptions(activate_script=activate_script)
elif queue_system == QueueSystem.LOCAL:
queue = LocalQueueOptions()
else:
Expand Down
13 changes: 13 additions & 0 deletions tests/ert/unit_tests/config/test_queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
QueueOptions,
SlurmQueueOptions,
TorqueQueueOptions,
)
Expand Down Expand Up @@ -509,3 +510,15 @@ def test_driver_initialization_from_defaults(queue_system):
LocalDriver(**LocalQueueOptions().driver_options)
if queue_system == QueueSystem.SLURM:
SlurmDriver(**SlurmQueueOptions().driver_options)


@pytest.mark.parametrize(
"venv, expected", [("my_env", "source my_env/bin/activate"), (None, "")]
)
def test_default_activate_script_generation(expected, monkeypatch, venv):
if venv:
monkeypatch.setenv("VIRTUAL_ENV", venv)
else:
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
options = QueueOptions(name="local")
assert options.activate_script == expected
Loading