Skip to content

Commit

Permalink
Index ert config dictionary with Ert ConfigKeys instead of hardcoded …
Browse files Browse the repository at this point in the history
…strings
  • Loading branch information
StephanDeHoop committed Oct 23, 2024
1 parent 6aff8b8 commit ca44e00
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 89 deletions.
51 changes: 27 additions & 24 deletions src/everest/simulator/everest_to_ert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import everest
from ert.config import ErtConfig, ExtParamConfig
from ert.config.parsing import ConfigDict
from ert.config.parsing import ConfigKeys as ErtConfigKeys
from everest.config import EverestConfig
from everest.config.control_variable_config import (
ControlVariableConfig,
Expand Down Expand Up @@ -121,7 +122,7 @@ def _extract_summary_keys(ever_config: EverestConfig, ert_config):
+ user_specified_keys
)
all_keys = list(set(all_keys))
ert_config["SUMMARY"] = [all_keys]
ert_config[ErtConfigKeys.SUMMARY] = [all_keys]


def _extract_environment(ever_config: EverestConfig, ert_config):
Expand All @@ -137,9 +138,9 @@ def _extract_environment(ever_config: EverestConfig, ert_config):
default_runpath_file = os.path.join(ever_config.output_dir, ".res_runpath_list")
default_ens_path = os.path.join(ever_config.output_dir, STORAGE_DIR)

ert_config["RUNPATH"] = simulation_path
ert_config["ENSPATH"] = default_ens_path
ert_config["RUNPATH_FILE"] = default_runpath_file
ert_config[ErtConfigKeys.RUNPATH] = simulation_path
ert_config[ErtConfigKeys.ENSPATH] = default_ens_path
ert_config[ErtConfigKeys.RUNPATH_FILE] = default_runpath_file


def _inject_simulation_defaults(ert_config, ever_config: EverestConfig):
Expand Down Expand Up @@ -170,17 +171,17 @@ def _extract_simulator(ever_config: EverestConfig, ert_config):
# Resubmit number (number of submission retries)
resubmit = ever_simulation.resubmit_limit
if resubmit is not None:
ert_config["MAX_SUBMIT"] = resubmit + 1
ert_config[ErtConfigKeys.MAX_SUBMIT] = resubmit + 1

# Maximum number of seconds (MAX_RUNTIME) a forward model is allowed to run
max_runtime = ever_simulation.max_runtime
if max_runtime is not None:
ert_config["MAX_RUNTIME"] = max_runtime or 0
ert_config[ErtConfigKeys.MAX_RUNTIME] = max_runtime or 0

# Number of cores reserved on queue nodes (NUM_CPU)
num_fm_cpu = ever_simulation.cores_per_node
if num_fm_cpu is not None:
ert_config["NUM_CPU"] = num_fm_cpu
ert_config[ErtConfigKeys.NUM_CPU] = num_fm_cpu

_inject_simulation_defaults(ert_config, ever_config)

Expand Down Expand Up @@ -236,21 +237,21 @@ def _extract_jobs(ever_config, ert_config, path):
}
)

res_jobs = ert_config.get("INSTALL_JOB", [])
res_jobs = ert_config.get(ErtConfigKeys.INSTALL_JOB, [])
for job in ever_jobs:
new_job = (
job[ConfigKeys.NAME],
os.path.join(path, job[ConfigKeys.SOURCE]),
)
res_jobs.append(new_job)

ert_config["INSTALL_JOB"] = res_jobs
ert_config[ErtConfigKeys.INSTALL_JOB] = res_jobs


def _extract_workflow_jobs(ever_config, ert_config, path):
workflow_jobs = [_job_to_dict(j) for j in (ever_config.install_workflow_jobs or [])]

res_jobs = ert_config.get("LOAD_WORKFLOW_JOB", [])
res_jobs = ert_config.get(ErtConfigKeys.LOAD_WORKFLOW_JOB, [])
for job in workflow_jobs:
new_job = (
os.path.join(path, job[ConfigKeys.SOURCE]),
Expand All @@ -259,7 +260,7 @@ def _extract_workflow_jobs(ever_config, ert_config, path):
res_jobs.append(new_job)

if res_jobs:
ert_config["LOAD_WORKFLOW_JOB"] = res_jobs
ert_config[ErtConfigKeys.LOAD_WORKFLOW_JOB] = res_jobs


def _extract_workflows(ever_config, ert_config, path):
Expand All @@ -268,8 +269,8 @@ def _extract_workflows(ever_config, ert_config, path):
"post_simulation": "POST_SIMULATION",
}

res_workflows = ert_config.get("LOAD_WORKFLOW", [])
res_hooks = ert_config.get("HOOK_WORKFLOW", [])
res_workflows = ert_config.get(ErtConfigKeys.LOAD_WORKFLOW, [])
res_hooks = ert_config.get(ErtConfigKeys.HOOK_WORKFLOW, [])

for ever_trigger, res_trigger in trigger2res.items():
jobs = getattr(ever_config.workflows, ever_trigger, None)
Expand All @@ -281,8 +282,8 @@ def _extract_workflows(ever_config, ert_config, path):
res_hooks.append((ever_trigger, res_trigger))

if res_workflows:
ert_config["LOAD_WORKFLOW"] = res_workflows
ert_config["HOOK_WORKFLOW"] = res_hooks
ert_config[ErtConfigKeys.LOAD_WORKFLOW] = res_workflows
ert_config[ErtConfigKeys.HOOK_WORKFLOW] = res_hooks


def _internal_data_files(ever_config: EverestConfig):
Expand Down Expand Up @@ -438,30 +439,32 @@ def _extract_forward_model(ever_config: EverestConfig, ert_config):
forward_model += ever_config.forward_model or []
forward_model = _insert_strip_dates_job(ever_config, forward_model)

sim_job = ert_config.get("SIMULATION_JOB", [])
sim_job = ert_config.get(ErtConfigKeys.SIMULATION_JOB, [])
for job in forward_model:
tmp = job.split()
sim_job.append(tuple(tmp))

ert_config["SIMULATION_JOB"] = sim_job
ert_config[ErtConfigKeys.SIMULATION_JOB] = sim_job


def _extract_model(ever_config: EverestConfig, ert_config):
_extract_summary_keys(ever_config, ert_config)

if "NUM_REALIZATIONS" not in ert_config:
if ErtConfigKeys.NUM_REALIZATIONS not in ert_config:
if ever_config.model.realizations is not None:
ert_config["NUM_REALIZATIONS"] = len(ever_config.model.realizations)
ert_config[ErtConfigKeys.NUM_REALIZATIONS] = len(
ever_config.model.realizations
)
else:
ert_config["NUM_REALIZATIONS"] = 1
ert_config[ErtConfigKeys.NUM_REALIZATIONS] = 1


def _extract_seed(ever_config: EverestConfig, ert_config):
assert ever_config.environment is not None
random_seed = ever_config.environment.random_seed

if random_seed:
ert_config["RANDOM_SEED"] = random_seed
ert_config[ErtConfigKeys.RANDOM_SEED] = random_seed


def _extract_results(ever_config: EverestConfig, ert_config):
Expand All @@ -473,10 +476,10 @@ def _extract_results(ever_config: EverestConfig, ert_config):
constraint_names = [
constraint.name for constraint in (ever_config.output_constraints or [])
]
gen_data = ert_config.get("GEN_DATA", [])
gen_data = ert_config.get(ErtConfigKeys.GEN_DATA, [])
for name in objectives_names + constraint_names:
gen_data.append((name, f"RESULT_FILE:{name}"))
ert_config["GEN_DATA"] = gen_data
ert_config[ErtConfigKeys.GEN_DATA] = gen_data


def _everest_to_ert_config_dict(
Expand All @@ -489,7 +492,7 @@ def _everest_to_ert_config_dict(
ert_config = site_config if site_config is not None else {}

config_dir = ever_config.config_directory
ert_config["DEFINE"] = [("<CONFIG_PATH>", config_dir)]
ert_config[ErtConfigKeys.DEFINE] = [("<CONFIG_PATH>", config_dir)]

# Extract simulator and simulation related configs
_extract_simulator(ever_config, ert_config)
Expand Down
39 changes: 20 additions & 19 deletions tests/everest/test_egg_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import everest
from ert.config import ErtConfig, QueueSystem
from ert.config.parsing import ConfigKeys as ErtConfigKeys
from everest.config import EverestConfig
from everest.config.export_config import ExportConfig
from everest.config_keys import ConfigKeys
Expand Down Expand Up @@ -462,25 +463,25 @@


def sort_res_summary(ert_config):
ert_config["SUMMARY"][0] = sorted(ert_config["SUMMARY"][0])
ert_config[ErtConfigKeys.SUMMARY][0] = sorted(ert_config[ErtConfigKeys.SUMMARY][0])


def _generate_exp_ert_config(config_path, output_dir):
return {
"DEFINE": [("<CONFIG_PATH>", config_path)],
"INSTALL_JOB": everest_default_jobs(output_dir),
"QUEUE_OPTION": [(QueueSystem.LOCAL, "MAX_RUNNING", 3)],
"QUEUE_SYSTEM": QueueSystem.LOCAL,
"NUM_REALIZATIONS": NUM_REALIZATIONS,
"RUNPATH": os.path.join(
ErtConfigKeys.DEFINE: [("<CONFIG_PATH>", config_path)],
ErtConfigKeys.INSTALL_JOB: everest_default_jobs(output_dir),
ErtConfigKeys.QUEUE_OPTION: [(QueueSystem.LOCAL, "MAX_RUNNING", 3)],
ErtConfigKeys.QUEUE_SYSTEM: QueueSystem.LOCAL,
ErtConfigKeys.NUM_REALIZATIONS: NUM_REALIZATIONS,
ErtConfigKeys.RUNPATH: os.path.join(
output_dir,
"egg_simulations/<CASE_NAME>/geo_realization_<GEO_ID>/simulation_<IENS>",
),
"RUNPATH_FILE": os.path.join(
ErtConfigKeys.RUNPATH_FILE: os.path.join(
os.path.realpath("everest/model"),
"everest_output/.res_runpath_list",
),
"SIMULATION_JOB": [
ErtConfigKeys.SIMULATION_JOB: [
(
"copy_directory",
f"{config_path}/../../eclipse/include/"
Expand Down Expand Up @@ -551,14 +552,14 @@ def _generate_exp_ert_config(config_path, output_dir):
),
("rf", "-s", "eclipse/model/EGG", "-o", "rf"),
],
"ENSPATH": os.path.join(
ErtConfigKeys.ENSPATH: os.path.join(
os.path.realpath("everest/model"),
"everest_output/simulation_results",
),
"ECLBASE": "eclipse/model/EGG",
"RANDOM_SEED": 123456,
"SUMMARY": SUM_KEYS,
"GEN_DATA": [("rf", "RESULT_FILE:rf")],
ErtConfigKeys.ECLBASE: "eclipse/model/EGG",
ErtConfigKeys.RANDOM_SEED: 123456,
ErtConfigKeys.SUMMARY: SUM_KEYS,
ErtConfigKeys.GEN_DATA: [("rf", "RESULT_FILE:rf")],
}


Expand Down Expand Up @@ -589,7 +590,7 @@ def test_egg_model_convert_no_opm(copy_egg_test_data_to_tmp):
output_dir = config.output_dir
config_path = os.path.dirname(os.path.abspath(CONFIG_FILE))
exp_ert_config = _generate_exp_ert_config(config_path, output_dir)
exp_ert_config["SUMMARY"][0] = SUM_KEYS_NO_OPM
exp_ert_config[ErtConfigKeys.SUMMARY][0] = SUM_KEYS_NO_OPM
sort_res_summary(exp_ert_config)
sort_res_summary(ert_config)
assert exp_ert_config == ert_config
Expand All @@ -612,8 +613,8 @@ def test_opm_fail_default_summary_keys(copy_egg_test_data_to_tmp):
output_dir = config.output_dir
config_path = os.path.dirname(os.path.abspath(CONFIG_FILE))
exp_ert_config = _generate_exp_ert_config(config_path, output_dir)
exp_ert_config["SUMMARY"][0] = filter(
lambda key: not key.startswith("G"), exp_ert_config["SUMMARY"][0]
exp_ert_config[ErtConfigKeys.SUMMARY][0] = filter(
lambda key: not key.startswith("G"), exp_ert_config[ErtConfigKeys.SUMMARY][0]
)
sort_res_summary(exp_ert_config)
sort_res_summary(ert_config)
Expand Down Expand Up @@ -651,11 +652,11 @@ def test_opm_fail_explicit_summary_keys(copy_egg_test_data_to_tmp):
output_dir = config.output_dir
config_path = os.path.dirname(os.path.abspath(CONFIG_FILE))
exp_ert_config = _generate_exp_ert_config(config_path, output_dir)
exp_ert_config["SUMMARY"] = [
exp_ert_config[ErtConfigKeys.SUMMARY] = [
list(
filter(
lambda key: not key.startswith("G"),
exp_ert_config["SUMMARY"][0],
exp_ert_config[ErtConfigKeys.SUMMARY][0],
)
)
+ extra_sum_keys
Expand Down
Loading

0 comments on commit ca44e00

Please sign in to comment.