Skip to content

Commit

Permalink
Fix site-config options overwriting user config
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-eq committed Sep 10, 2024
1 parent de9d63a commit 5df708f
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 12 deletions.
39 changes: 32 additions & 7 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def from_file(cls, user_config_file: str) -> Self:
Warnings will be issued with :python:`warnings.warn(category=ConfigWarning)`
when the user should be notified with non-fatal configuration problems.
"""
user_config_dict = cls.read_user_config(user_config_file)
user_config_dict = cls.read_user_config_and_apply_site_config(user_config_file)
config_dir = path.abspath(path.dirname(user_config_file))
cls._log_config_file(user_config_file)
cls._log_config_dict(user_config_dict)
Expand Down Expand Up @@ -372,12 +372,37 @@ def read_site_config(cls) -> ConfigDict:

@classmethod
def read_user_config(cls, user_config_file: str) -> ConfigDict:
site_config = cls.read_site_config()
return lark_parse(
file=user_config_file,
schema=init_user_config_schema(),
site_config=site_config,
)
return lark_parse(user_config_file, schema=init_user_config_schema())

@classmethod
def read_user_config_and_apply_site_config(
cls, user_config_file: str
) -> ConfigDict:
site_config_dict = cls.read_site_config()
user_config_dict = cls.read_user_config(user_config_file)

for keyword, value in site_config_dict.items():
if keyword == "QUEUE_OPTION":
filtered_queue_options = []
for queue_option in value:
if (
"MAX_RUNNING" in queue_option
and "MAX_RUNNING" in user_config_dict
) or (
"SUBMIT_SLEEP" in queue_option
and "SUBMIT_SLEEP" in user_config_dict
):
continue
filtered_queue_options.append(queue_option)
user_config_dict["QUEUE_OPTION"] = (
filtered_queue_options + user_config_dict.get("QUEUE_OPTION", [])
)
elif isinstance(value, list):
original_entries: list = user_config_dict.get(keyword, [])
user_config_dict[keyword] = value + original_entries
elif keyword not in user_config_dict:
user_config_dict[keyword] = value
return user_config_dict

@staticmethod
def check_non_utf_chars(file_path: str) -> None:
Expand Down
5 changes: 1 addition & 4 deletions src/ert/config/parsing/lark_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,8 @@ def _tree_to_dict(
pre_defines: Defines,
tree: Tree[Instruction],
schema: SchemaItemDict,
site_config: Optional[ConfigDict] = None,
) -> ConfigDict:
config_dict = site_config if site_config else {}
config_dict = {}
defines = pre_defines.copy()
config_dict["DEFINE"] = defines # type: ignore

Expand Down Expand Up @@ -483,7 +482,6 @@ def _parse_file(file: str) -> Tree[Instruction]:
def parse(
file: str,
schema: SchemaItemDict,
site_config: Optional[ConfigDict] = None,
pre_defines: Optional[List[Tuple[str, str]]] = None,
) -> ConfigDict:
filepath = os.path.normpath(os.path.abspath(file))
Expand All @@ -509,7 +507,6 @@ def parse(
config_file=file,
pre_defines=pre_defines,
tree=tree,
site_config=site_config,
schema=schema,
)

Expand Down
72 changes: 72 additions & 0 deletions tests/unit_tests/config/test_ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ContextString,
)
from ert.config.parsing.observations_parser import ObservationConfigError
from ert.config.parsing.queue_system import QueueSystem

from .config_dict_generator import config_generators

Expand Down Expand Up @@ -1644,3 +1645,74 @@ def test_that_empty_params_file_gives_reasonable_error(tmpdir, param_config):

with pytest.raises(ConfigValidationError, match="No parameters specified in"):
ErtConfig.from_file("config.ert")


@pytest.mark.usefixtures("use_tmpdir")
@pytest.mark.parametrize(
"max_running_queue_config_entry",
[
pytest.param(
"""
MAX_RUNNING 6
QUEUE_OPTION LSF MAX_RUNNING 2
QUEUE_OPTION SLURM MAX_RUNNING 2
""",
id="general_keyword_max_running",
),
pytest.param(
"""
QUEUE_OPTION TORQUE MAX_RUNNING 6
MAX_RUNNING 2
""",
id="queue_option_max_running",
),
],
)
def test_queue_config_general_max_running_takes_precedence_over_queue_option(
max_running_queue_config_entry,
):
test_config_file = Path("test.ert")
test_config_file.write_text(
dedent(
f"""
NUM_REALIZATIONS 100
DEFINE <STORAGE> storage/<CONFIG_FILE_BASE>-<DATE>
RUNPATH <STORAGE>/runpath/realization-<IENS>/iter-<ITER>
ENSPATH <STORAGE>/ensemble
QUEUE_SYSTEM TORQUE
{max_running_queue_config_entry}
"""
)
)

config = ErtConfig.from_file(test_config_file)
assert config.queue_config.max_running == 6


def test_general_option_in_local_config_takes_precedence_over_site_config(
tmp_path, monkeypatch
):
test_site_config = tmp_path / "test_site_config.ert"
test_site_config.write_text(
"QUEUE_OPTION TORQUE MAX_RUNNING 6\nQUEUE_SYSTEM LOCAL\nQUEUE_OPTION TORQUE SUBMIT_SLEEP 7"
)
monkeypatch.setenv("ERT_SITE_CONFIG", str(test_site_config))

test_config_file = tmp_path / "test.ert"
test_config_file.write_text(
dedent(
"""
NUM_REALIZATIONS 100
DEFINE <STORAGE> storage/<CONFIG_FILE_BASE>-<DATE>
RUNPATH <STORAGE>/runpath/realization-<IENS>/iter-<ITER>
ENSPATH <STORAGE>/ensemble
QUEUE_SYSTEM TORQUE
MAX_RUNNING 13
SUBMIT_SLEEP 14
"""
)
)
config = ErtConfig.from_file(test_config_file)
assert config.queue_config.max_running == 13
assert config.queue_config.submit_sleep == 14
assert config.queue_config.queue_system == QueueSystem.TORQUE
2 changes: 1 addition & 1 deletion tests/unit_tests/config/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def make_field(field_line):
"NUM_REALIZATIONS 1\nGRID " + egrid_file + "\n" + field_line + "\n",
encoding="utf-8",
)
parsed = lark_parse(str(config_file), init_user_config_schema(), None, None)
parsed = lark_parse(str(config_file), init_user_config_schema(), None)

return Field.from_config_list(parsed["GRID"], grid_shape, parsed["FIELD"][0])

Expand Down

0 comments on commit 5df708f

Please sign in to comment.