Skip to content

Commit

Permalink
Refactor queue_options in QueueConfig to be Tuple[str, str]
Browse files Browse the repository at this point in the history
  • Loading branch information
larsevj authored Sep 27, 2023
1 parent 3ecb021 commit dbf128a
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 53 deletions.
68 changes: 30 additions & 38 deletions src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import shutil
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple, Union, no_type_check
from typing import Any, Dict, List, Tuple, no_type_check

from ert import _clib

Expand All @@ -27,7 +27,7 @@ class QueueConfig:
job_script: str = shutil.which("job_dispatch.py") or "job_dispatch.py"
max_submit: int = 2
queue_system: QueueSystem = QueueSystem.LOCAL # type: ignore
queue_options: Dict[QueueSystem, List[Union[Tuple[str, str], str]]] = field(
queue_options: Dict[QueueSystem, List[Tuple[str, str]]] = field(
default_factory=dict
)

Expand All @@ -37,7 +37,7 @@ def __post_init__(self) -> None:
setting
for settings in self.queue_options.values()
for setting in settings
if isinstance(setting, tuple) and setting[0] == "MAX_RUNNING"
if setting[0] == "MAX_RUNNING" and setting[1]
]:
err_msg = "QUEUE_OPTION MAX_RUNNING is"
try:
Expand Down Expand Up @@ -76,29 +76,25 @@ def from_dict(cls, config_dict: ConfigDict) -> QueueConfig:
)
job_script = job_script or "job_dispatch.py"
max_submit: int = config_dict.get("MAX_SUBMIT", 2)
queue_options: Dict[
QueueSystem, List[Union[Tuple[str, str], str]]
] = defaultdict(list)
queue_options: Dict[QueueSystem, List[Tuple[str, str]]] = defaultdict(list)
for system, option_name, *values in config_dict.get("QUEUE_OPTION", []):
queue_system = QueueSystem.from_string(system)
if option_name not in VALID_QUEUE_OPTIONS[queue_system]:
raise ConfigValidationError(
f"Invalid QUEUE_OPTION for {queue_system.name}: '{option_name}'. "
f"Valid choices are {sorted(VALID_QUEUE_OPTIONS[queue_system])}."
)
if values:
if option_name == "LSF_SERVER" and values[0].startswith("$"):
raise ConfigValidationError(
"Invalid server name specified for QUEUE_OPTION LSF"
f" LSF_SERVER: {values[0]}. Server name is currently an"
" undefined environment variable. The LSF_SERVER keyword is"
" usually provided by the site-configuration file, beware that"
" you are effectively replacing the default value provided."
)
queue_options[queue_system].append((option_name, values[0]))
else:
queue_options[queue_system].append(option_name)

queue_options[queue_system].append(
(option_name, values[0] if values else "")
)
if values and option_name == "LSF_SERVER" and values[0].startswith("$"):
raise ConfigValidationError(
"Invalid server name specified for QUEUE_OPTION LSF"
f" LSF_SERVER: {values[0]}. Server name is currently an"
" undefined environment variable. The LSF_SERVER keyword is"
" usually provided by the site-configuration file, beware that"
" you are effectively replacing the default value provided."
)
if (
selected_queue_system == QueueSystem.TORQUE
and queue_options[QueueSystem.TORQUE]
Expand Down Expand Up @@ -127,15 +123,12 @@ def create_local_copy(self) -> QueueConfig:

def _check_for_overwritten_queue_system_options(
selected_queue_system: QueueSystem,
queue_system_options: List[Union[Tuple[str, str], str]],
queue_system_options: List[Tuple[str, str]],
) -> None:
def generate_dict(
option_list: List[Union[Tuple[str, str], str]]
) -> Dict[str, List[str]]:
def generate_dict(option_list: List[Tuple[str, str]]) -> Dict[str, List[str]]:
temp_dict: Dict[str, List[str]] = defaultdict(list)
for option_string in option_list:
if isinstance(option_string, tuple) and (option_string[0] != "MAX_RUNNING"):
temp_dict.setdefault(option_string[0], []).append(option_string[1])
temp_dict.setdefault(option_string[0], []).append(option_string[1])
return temp_dict

for option_name, option_values in generate_dict(queue_system_options).items():
Expand All @@ -148,16 +141,15 @@ def generate_dict(

def _validate_torque_options(torque_options: List[Tuple[str, str]]) -> None:
for option_strings in torque_options:
if isinstance(option_strings, tuple):
option_name = option_strings[0]
option_value = option_strings[1]
if (
option_value != "" # This is equivalent to the option not being set
and option_name == "MEMORY_PER_JOB"
and re.match("[0-9]+[mg]b", option_value) is None
):
raise ConfigValidationError(
f"The value '{option_value}' is not valid for the Torque option "
"MEMORY_PER_JOB, it must be of "
"the format '<integer>mb' or '<integer>gb'."
)
option_name = option_strings[0]
option_value = option_strings[1]
if (
option_value != "" # This is equivalent to the option not being set
and option_name == "MEMORY_PER_JOB"
and re.match("[0-9]+[mg]b", option_value) is None
):
raise ConfigValidationError(
f"The value '{option_value}' is not valid for the Torque option "
"MEMORY_PER_JOB, it must be of "
"the format '<integer>mb' or '<integer>gb'."
)
13 changes: 2 additions & 11 deletions src/ert/job_queue/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,11 @@ def __init__(

def set_option(self, option: str, value: str) -> bool:
if option == "MAX_RUNNING":
self.set_max_running(int(value))
self.set_max_running(int(value) if value else 0)
return True
else:
return self._set_option(option, str(value))

def unset_option(self, option: str) -> None:
if option == "MAX_RUNNING":
self.set_max_running(0)
else:
self._set_option(option, "")

def get_option(self, option_key: str) -> str:
if option_key == "MAX_RUNNING":
return str(self.get_max_running())
Expand All @@ -58,10 +52,7 @@ def create_driver(cls, queue_config: QueueConfig) -> "Driver":
driver = Driver(queue_config.queue_system)
if queue_config.queue_system in queue_config.queue_options:
for setting in queue_config.queue_options[queue_config.queue_system]:
if isinstance(setting, tuple):
driver.set_option(*setting)
else:
driver.unset_option(setting)
driver.set_option(*setting)
return driver

@property
Expand Down
21 changes: 20 additions & 1 deletion tests/unit_tests/config/test_queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def test_overwriting_QUEUE_OPTIONS_warning(
f.write("NUM_REALIZATIONS 1\n")
f.write(f"QUEUE_SYSTEM {queue_system}\n")
f.write(f"QUEUE_OPTION {queue_system} {queue_system_option} test_1\n")
f.write(f"QUEUE_OPTION {queue_system} {queue_system_option} \n")
test_site_config = tmp_path / "test_site_config.ert"
test_site_config.write_text(
"JOB_SCRIPT job_dispatch.py\n"
Expand Down Expand Up @@ -185,3 +184,23 @@ def test_undefined_LSF_SERVER_environment_variable():
),
):
ErtConfig.from_file(filename)


@pytest.mark.usefixtures("use_tmpdir")
@pytest.mark.parametrize(
"queue_system, queue_system_option",
[("LSF", "LSF_SERVER"), ("SLURM", "SQUEUE"), ("TORQUE", "QUEUE")],
)
def test_initializing_empty_config_values(queue_system, queue_system_option):
filename = "config.ert"
with open(filename, "w", encoding="utf-8") as f:
f.write("NUM_REALIZATIONS 1\n")
f.write(f"QUEUE_SYSTEM {queue_system}\n")
f.write(f"QUEUE_OPTION {queue_system} {queue_system_option}\n")
f.write(f"QUEUE_OPTION {queue_system} MAX_RUNNING\n")
config_object = ErtConfig.from_file(filename)
driver = Driver.create_driver(config_object.queue_config)
assert driver.get_option(queue_system_option) == ""
assert driver.get_option("MAX_RUNNING") == "0"
for options in config_object.queue_config.queue_options[queue_system]:
assert isinstance(options, tuple)
10 changes: 7 additions & 3 deletions tests/unit_tests/job_queue/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ def test_set_and_unset_option():
queue_options={
QueueSystem.LOCAL: [
("MAX_RUNNING", "50"),
"MAX_RUNNING",
("MAX_RUNNING", ""),
]
},
)
driver = Driver.create_driver(queue_config)
assert driver.get_option("MAX_RUNNING") == "0"
assert driver.set_option("MAX_RUNNING", "42")
assert driver.get_option("MAX_RUNNING") == "42"
driver.unset_option("MAX_RUNNING")
driver.set_option("MAX_RUNNING", "")
assert driver.get_option("MAX_RUNNING") == "0"
driver.set_option("MAX_RUNNING", "100")
assert driver.get_option("MAX_RUNNING") == "100"
driver.set_option("MAX_RUNNING", "0")
assert driver.get_option("MAX_RUNNING") == "0"


Expand Down Expand Up @@ -54,6 +58,6 @@ def test_get_slurm_queue_config():

assert driver.get_option("SBATCH") == "/path/to/sbatch"
assert driver.get_option("SCONTROL") == "scontrol"
driver.unset_option("SCONTROL")
driver.set_option("SCONTROL", "")
assert driver.get_option("SCONTROL") == ""
assert driver.name == "SLURM"

0 comments on commit dbf128a

Please sign in to comment.