Skip to content

Commit

Permalink
Add line numbers when validating everest config files
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Dec 18, 2024
1 parent 1ed8fc3 commit bc246ad
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 15 deletions.
47 changes: 43 additions & 4 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,20 @@ def __setattr__(self, name, value):
super().__setattr__(name, value)


class EverestValidationError(ValueError):
def __init__(self):
super().__init__()
# self.errors: dict[tuple[int, int], ErrorDetails] = {}
self.errors: list[tuple[ErrorDetails, tuple[int, int] | None]] = []

@property
def error(self):
return self.errors

def __str__(self):
return f"{self.errors!s}"


class HasName(Protocol):
name: str

Expand Down Expand Up @@ -753,14 +767,39 @@ def lint_config_dict_with_raise(config: dict):
EverestConfig.model_validate(config)

@staticmethod
def load_file(config_path: str) -> "EverestConfig":
config_path = os.path.realpath(config_path)
def load_file(config_file: str) -> "EverestConfig":
config_path = os.path.realpath(config_file)

if not os.path.isfile(config_path):
raise FileNotFoundError(f"File not found: {config_path}")

config_dict = yaml_file_to_substituted_config_dict(config_path)
return EverestConfig.model_validate(config_dict)

try:
return EverestConfig.model_validate(config_dict)
except ValidationError as error:
exp = EverestValidationError()
file_content = []
with open(config_path, encoding="utf-8") as f:
file_content = f.readlines()

for e in error.errors(
include_context=True, include_input=True, include_url=False
):
if e["type"] == "literal_error":
for index, line in enumerate(file_content):
if (pos := line.find(e["input"])) != -1:
exp.errors.append((e, (index + 1, pos + 1)))
break
# elif e["type"] == "missing":
# exp.errors.append((e, None))

# elif e["type"] == "value_error":
# exp.errors.append((e, None))
else:
exp.errors.append((e, None))

raise exp from error

@staticmethod
def load_file_with_argparser(
Expand All @@ -775,7 +814,7 @@ def load_file_with_argparser(
f"The config file: <{config_path}> contains"
f" invalid YAML syntax: {e!s}"
)
except ValidationError as e:
except EverestValidationError as e:
parser.error(
f"Loading config file <{config_path}> failed with:\n"
f"{format_errors(e)}"
Expand Down
27 changes: 21 additions & 6 deletions src/everest/config/validation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar

from pydantic import BaseModel, ValidationError
from pydantic import BaseModel

from everest.config.install_data_config import InstallDataConfig
from everest.util.forward_models import (
Expand All @@ -16,6 +16,7 @@
parse_forward_model_file,
)

from .everest_config import EverestValidationError
from .install_job_config import InstallJobConfig

if TYPE_CHECKING:
Expand Down Expand Up @@ -253,12 +254,26 @@ def _error_loc(error_dict: "ErrorDetails") -> str:
)


def format_errors(error: ValidationError) -> str:
errors = error.errors()
msg = f"Found {len(errors)} validation error{'s' if len(errors) > 1 else ''}:\n\n"
def format_errors(validation_error: EverestValidationError) -> str:
msg = f"Found {len(validation_error.errors)} validation error{'s' if len(validation_error.errors) > 1 else ''}:\n\n"
error_map = {}
for err in error.errors():
key = _error_loc(err)

for error, pos in validation_error.errors:
print(error, pos)

if pos:
row, col = pos
key = f"line: {row}, column: {col}. {_error_loc(error)}"
else:
key = f"{_error_loc(error)}"
if key not in error_map:
error_map[key] = [key]
error_map[key].append(f" * {error['msg']} (type={error['type']})")
return msg + "\n".join(list(chain.from_iterable(error_map.values())))

for (row, col), err in error.errors.items():
# msg +=f"line: {row}, column: {col}. "
key = f"line: {row}, column: {col}. {_error_loc(err)}"
if key not in error_map:
error_map[key] = [key]
error_map[key].append(f" * {err['msg']} (type={err['type']})")
Expand Down
11 changes: 10 additions & 1 deletion tests/everest/functional/test_main_everest_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
pytestmark = pytest.mark.xdist_group(name="starts_everest")


@pytest.mark.integration_test
def test_everest_entry_docs():
"""Test calling everest with --docs
Expand All @@ -40,6 +41,7 @@ def test_everest_entry_docs():
assert not err.getvalue()


@pytest.mark.integration_test
def test_everest_entry_manual():
"""Test calling everest with --manual"""
with capture_streams() as (out, err), pytest.raises(SystemExit):
Expand All @@ -55,6 +57,7 @@ def test_everest_entry_manual():
assert not err.getvalue()


@pytest.mark.integration_test
def test_everest_entry_version():
"""Test calling everest with --version"""
with capture_streams() as (out, err), pytest.raises(SystemExit):
Expand All @@ -64,6 +67,7 @@ def test_everest_entry_version():
assert any(everest_version in channel for channel in channels)


@pytest.mark.integration_test
def test_everest_main_entry_bad_command():
# Setup command line arguments for the test
with capture_streams() as (_, err), pytest.raises(SystemExit):
Expand All @@ -76,6 +80,7 @@ def test_everest_main_entry_bad_command():

@pytest.mark.flaky(reruns=5)
@pytest.mark.fails_on_macos_github_workflow
@pytest.mark.integration_test
def test_everest_entry_run(copy_math_func_test_data_to_tmp):
# Setup command line arguments
with capture_streams():
Expand Down Expand Up @@ -108,6 +113,7 @@ def test_everest_entry_run(copy_math_func_test_data_to_tmp):
assert status["status"] == ServerStatus.completed


@pytest.mark.integration_test
def test_everest_entry_monitor_no_run(copy_math_func_test_data_to_tmp):
with capture_streams():
start_everest(["everest", "monitor", CONFIG_FILE_MINIMAL])
Expand All @@ -120,13 +126,15 @@ def test_everest_entry_monitor_no_run(copy_math_func_test_data_to_tmp):
assert status["status"] == ServerStatus.never_run


@pytest.mark.integration_test
def test_everest_main_export_entry(copy_math_func_test_data_to_tmp):
# Setup command line arguments
with capture_streams():
start_everest(["everest", "export", CONFIG_FILE_MINIMAL])
assert os.path.exists(os.path.join("everest_output", "config_minimal.csv"))


@pytest.mark.integration_test
def test_everest_main_lint_entry(copy_math_func_test_data_to_tmp):
# Setup command line arguments
with capture_streams() as (out, err):
Expand All @@ -149,7 +157,7 @@ def test_everest_main_lint_entry(copy_math_func_test_data_to_tmp):
type_ = "(type=float_parsing)"
validation_msg = dedent(
f"""Loading config file <config_minimal.yml> failed with:
Found 1 validation error:
Found 1 validation error:
controls -> 0 -> initial_guess
* Input should be a valid number, unable to parse string as a number {type_}
Expand All @@ -161,6 +169,7 @@ def test_everest_main_lint_entry(copy_math_func_test_data_to_tmp):
@pytest.mark.fails_on_macos_github_workflow
@skipif_no_everest_models
@pytest.mark.everest_models_test
@pytest.mark.integration_test
def test_everest_main_configdump_entry(copy_egg_test_data_to_tmp):
# Setup command line arguments
with capture_streams() as (out, _):
Expand Down
6 changes: 3 additions & 3 deletions tests/everest/test_config_file_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from unittest.mock import patch

import pytest
from pydantic_core import ValidationError
from ruamel.yaml import YAML

from everest import ConfigKeys as CK
from everest import config_file_loader as loader
from everest.config import EverestConfig
from everest.config.everest_config import EverestValidationError
from tests.everest.utils import relpath

mocked_root = relpath(os.path.join("test_data", "mocked_test_case"))
Expand Down Expand Up @@ -122,12 +122,12 @@ def test_dependent_definitions_value_error(copy_mocked_test_data_to_tmp):
def test_load_empty_configuration(copy_mocked_test_data_to_tmp):
with open("empty_config.yml", mode="w", encoding="utf-8") as fh:
fh.writelines("")
with pytest.raises(ValidationError, match="missing"):
with pytest.raises(EverestValidationError, match="missing"):
EverestConfig.load_file("empty_config.yml")


def test_load_invalid_configuration(copy_mocked_test_data_to_tmp):
with open("invalid_config.yml", mode="w", encoding="utf-8") as fh:
fh.writelines("asdf")
with pytest.raises(ValidationError, match="missing"):
with pytest.raises(EverestValidationError, match="missing"):
EverestConfig.load_file("invalid_config.yml")
24 changes: 24 additions & 0 deletions tests/everest/test_config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib
import re
import warnings
from argparse import ArgumentParser
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -973,3 +974,26 @@ def test_warning_forward_model_write_objectives(objective, forward_model, warnin
def test_deprecated_keyword():
with pytest.warns(ConfigWarning, match="report_steps .* can be removed"):
ModelConfig(**{"report_steps": []})


def test_load_file_non_existing():
with pytest.raises(FileNotFoundError):
EverestConfig.load_file("non_existing.yml")


def test_load_file_with_errors(copy_math_func_test_data_to_tmp, capsys):
with open("config_minimal.yml", encoding="utf-8") as file:
content = file.read()

with open("config_minimal_error.yml", "w", encoding="utf-8") as file:
file.write(content.replace("generic_control", "yolo_control"))

with pytest.raises(SystemExit):
parser = ArgumentParser(prog="test")
EverestConfig.load_file_with_argparser("config_minimal_error.yml", parser)

captured = capsys.readouterr()

assert "Found 1 validation error" in captured.err
assert "line: 4, column: 11" in captured.err
assert "Input should be 'well_control' or 'generic_control'" in captured.err
3 changes: 2 additions & 1 deletion tests/everest/test_yaml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@


@pytest.mark.parametrize("random_seed", [None, 1234])
def test_random_seed(random_seed):
def test_random_seed(tmp_path, monkeypatch, random_seed):
monkeypatch.chdir(tmp_path)
config = {"model": {"realizations": [0]}}
if random_seed:
config["environment"] = {"random_seed": random_seed}
Expand Down

0 comments on commit bc246ad

Please sign in to comment.