Skip to content

Commit

Permalink
Remove superfluous layer in ert_config
Browse files Browse the repository at this point in the history
  • Loading branch information
AugustoMagalhaes authored and berland committed Dec 16, 2024
1 parent d4e7b72 commit ab58e60
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 95 deletions.
27 changes: 0 additions & 27 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,33 +221,6 @@ def handle_default(fm_step: ForwardModelStep, arg: str) -> str:
}


def forward_model_data_to_json(
substitutions: Substitutions,
forward_model_steps: list[ForwardModelStep],
env_vars: dict[str, str],
env_pr_fm_step: dict[str, dict[str, Any]] | None = None,
user_config_file: str | None = "",
run_id: str | None = None,
iens: int = 0,
itr: int = 0,
context_env: dict[str, str] | None = None,
):
if context_env is None:
context_env = {}
if env_pr_fm_step is None:
env_pr_fm_step = {}
return create_forward_model_json(
context=substitutions,
forward_model_steps=forward_model_steps,
user_config_file=user_config_file,
env_vars={**env_vars, **context_env},
env_pr_fm_step=env_pr_fm_step,
run_id=run_id,
iens=iens,
itr=itr,
)


@dataclass
class ErtConfig:
DEFAULT_ENSPATH: ClassVar[str] = "storage"
Expand Down
9 changes: 4 additions & 5 deletions src/ert/enkf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import xarray as xr
from numpy.random import SeedSequence

from ert.config.ert_config import forward_model_data_to_json
from ert.config.ert_config import create_forward_model_json
from ert.config.forward_model_step import ForwardModelStep
from ert.config.model_config import ModelConfig
from ert.substitutions import Substitutions, substitute_runpath_name
Expand Down Expand Up @@ -272,16 +272,15 @@ def create_run_path(
path = run_path / "jobs.json"
_backup_if_existing(path)

forward_model_output = forward_model_data_to_json(
substitutions=substitutions,
forward_model_output: dict[str, Any] = create_forward_model_json(
context=substitutions,
forward_model_steps=forward_model_steps,
user_config_file=user_config_file,
env_vars=env_vars,
env_vars={**env_vars, **context_env},
env_pr_fm_step=env_pr_fm_step,
run_id=run_arg.run_id,
iens=run_arg.iens,
itr=ensemble.iteration,
context_env=context_env,
)
with open(run_path / "jobs.json", mode="wb") as fptr:
fptr.write(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ert.config import ErtConfig, ForwardModelStep
from ert.config.ert_config import (
_forward_model_step_from_config_file,
forward_model_data_to_json,
create_forward_model_json,
)
from ert.substitutions import Substitutions

Expand Down Expand Up @@ -295,8 +295,8 @@ def test_config_path_and_file(context):
substitutions=context,
user_config_file="path_to_config_file/config.ert",
)
steps_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
steps_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand All @@ -318,8 +318,8 @@ def test_no_steps(context):
user_config_file="path_to_config_file/config.ert",
)

data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand All @@ -339,8 +339,8 @@ def test_one_step(fm_step_list, context):
substitutions=context,
)

data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand All @@ -357,8 +357,8 @@ def run_all(fm_steplist, context):
substitutions=context,
)

data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -400,8 +400,8 @@ def test_that_values_with_brackets_are_ommitted(caplog, fm_step_list, context):
forward_model_steps=forward_model_list, substitutions=context
)

data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -558,8 +558,8 @@ def test_forward_model_job(job, forward_model, expected_args):

forward_model = ert_config.forward_model_steps

data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -589,8 +589,8 @@ def test_that_config_path_is_the_directory_of_the_main_ert_config():
fout.write("FORWARD_MODEL job_name")

ert_config = ErtConfig.from_file("config_file.ert")
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -661,8 +661,8 @@ def test_simulation_job(job, forward_model, expected_args):
fout.write(forward_model)

ert_config = ErtConfig.from_file("config_file.ert")
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -696,8 +696,8 @@ def test_that_private_over_global_args_gives_logging_message(caplog):
fout.write("FORWARD_MODEL job_name(<ARG>=B)")

ert_config = ErtConfig.from_file("config_file.ert")
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -735,8 +735,8 @@ def test_that_private_over_global_args_does_not_give_logging_message_for_argpass
fout.write("FORWARD_MODEL job_name(<ARG>=<ARG>)")

ert_config = ErtConfig.from_file("config_file.ert")
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -786,8 +786,8 @@ def test_that_environment_variables_are_set_in_forward_model(
fout.write(forward_model)

ert_config = ErtConfig.from_file("config_file.ert")
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -817,8 +817,8 @@ def test_that_executables_in_path_are_not_made_realpath(tmp_path):
)

ert_config = ErtConfig.from_file(str(config_file))
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down
50 changes: 30 additions & 20 deletions tests/ert/unit_tests/config/test_ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from ert.config import AnalysisConfig, ConfigValidationError, ErtConfig, HookRuntime
from ert.config.ert_config import (
forward_model_data_to_json,
create_forward_model_json,
site_config_location,
)
from ert.config.parsing import ConfigKeys, ConfigWarning
Expand Down Expand Up @@ -857,11 +857,12 @@ def test_fm_step_config_via_plugin_ends_up_json_data(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert step_json["jobList"][0]["environment"]["FOO"] == "bar"

Expand All @@ -885,12 +886,14 @@ def test_fm_step_config_via_plugin_does_not_leak_to_other_step(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)

assert "FOO" not in step_json["jobList"][0]["environment"]


Expand All @@ -913,12 +916,14 @@ def test_fm_step_config_via_plugin_has_key_names_uppercased(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)

assert step_json["jobList"][0]["environment"]["FOO"] == "bar"


Expand All @@ -941,11 +946,12 @@ def test_fm_step_config_via_plugin_stringifies_python_objects(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert step_json["jobList"][0]["environment"]["FOO"] == "{'a_dict_as_value': 1}"

Expand All @@ -972,11 +978,12 @@ def test_fm_step_config_via_plugin_ignores_conflict_with_setenv(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert step_json["global_environment"]["FOO"] == "bar_from_setenv"
assert step_json["jobList"][0]["environment"]["FOO"] == "bar_from_plugin"
Expand All @@ -1002,11 +1009,12 @@ def test_fm_step_config_via_plugin_does_not_override_default_env(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert (
step_json["jobList"][0]["environment"]["_ERT_RUNPATH"]
Expand Down Expand Up @@ -1034,11 +1042,12 @@ def test_fm_step_config_via_plugin_is_substituted_for_defines(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert step_json["jobList"][0]["environment"]["FOO"] == "define_works"

Expand All @@ -1062,11 +1071,12 @@ def test_fm_step_config_via_plugin_is_dropped_if_not_define_exists(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert "FOO" not in step_json["jobList"][0]["environment"]

Expand Down Expand Up @@ -1533,13 +1543,13 @@ def test_validate_no_logs_when_overwriting_with_same_value(caplog):

with caplog.at_level(logging.INFO):
ert_config = ErtConfig.from_file("config_file.ert")
forward_model_data_to_json(
substitutions=ert_config.substitutions,
create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
run_id="0",
iens="0",
iens=0,
itr=0,
)

Expand Down
Loading

0 comments on commit ab58e60

Please sign in to comment.