Skip to content

Commit

Permalink
Test propagation of SETENV to validation code
Browse files Browse the repository at this point in the history
  • Loading branch information
berland committed Dec 6, 2024
1 parent dff8eb5 commit 0fc30a9
Showing 1 changed file with 37 additions and 3 deletions.
40 changes: 37 additions & 3 deletions tests/ert/unit_tests/config/test_forward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import stat
from pathlib import Path
from textwrap import dedent
from typing import Dict
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -669,7 +670,9 @@ def __init__(self):
command=["something", "<arg1>", "-f", "<arg2>", "<arg3>"],
)

def validate_pre_experiment(self, fm_step_json: ForwardModelStepJSON) -> None:
def validate_pre_experiment(
self, fm_step_json: ForwardModelStepJSON, env_vars: Dict[str, str]
) -> None:
if set(self.private_args.keys()) != {"<arg1>", "<arg2>", "<arg3>"}:
raise ForwardModelStepValidationError("Bad")

Expand Down Expand Up @@ -893,6 +896,33 @@ def validate_pre_realization_run(
)


def test_that_plugin_forward_model_validation_sees_setenv(tmp_path):
(tmp_path / "test.ert").write_text(
"""
NUM_REALIZATIONS 1
SETENV FOO bar
FORWARD_MODEL FM1()
"""
)

class ExceptionThatWeWant(ForwardModelStepValidationError):
pass

class FM1(ForwardModelStepPlugin):
def __init__(self):
super().__init__(name="FM1", command=["dummy.sh"])

def validate_pre_experiment(
self, _: ForwardModelStepJSON, env_vars: Dict[str, str]
) -> None:
raise ExceptionThatWeWant(f'Found FOO={env_vars["FOO"]}')

with pytest.raises(ConfigValidationError, match=".*Found FOO=bar.*"):
ErtConfig.with_plugins(forward_model_step_classes=[FM1]).from_file(
tmp_path / "test.ert"
)


def test_that_plugin_forward_model_raises_pre_experiment_validation_error_early(
tmp_path,
):
Expand All @@ -911,7 +941,9 @@ class FM1(ForwardModelStepPlugin):
def __init__(self):
super().__init__(name="FM1", command=["the_executable.sh"])

def validate_pre_experiment(self, fm_step_json: ForwardModelStepJSON) -> None:
def validate_pre_experiment(
self, fm_step_json: ForwardModelStepJSON, _: Dict[str, str]
) -> None:
if self.name != "FM1":
raise ForwardModelStepValidationError("Expected name to be FM1")

Expand All @@ -924,7 +956,9 @@ def __init__(self):
command=["the_executable.sh"],
)

def validate_pre_experiment(self, fm_step_json: ForwardModelStepJSON) -> None:
def validate_pre_experiment(
self, fm_step_json: ForwardModelStepJSON, _: Dict[str, str]
) -> None:
if self.name != "FM2":
raise ForwardModelStepValidationError("Expected name to be FM2")

Expand Down

0 comments on commit 0fc30a9

Please sign in to comment.