Skip to content

Commit

Permalink
Make EnsembleConfig agnostic of response impls
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen committed Aug 27, 2024
1 parent 7935e45 commit 3edf135
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 82 deletions.
84 changes: 12 additions & 72 deletions src/ert/config/ensemble_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,24 @@
import os
from collections import Counter
from dataclasses import dataclass, field
from datetime import datetime
from typing import (
Any,
Dict,
List,
Optional,
Sequence,
Union,
no_type_check,
overload,
)

import numpy as np
import numpy.typing as npt

from ert.field_utils import get_shape

from ._read_summary import read_summary
from .field import Field
from .gen_data_config import GenDataConfig
from .gen_kw_config import GenKwConfig
from .parameter_config import ParameterConfig
from .parsing import ConfigDict, ConfigKeys, ConfigValidationError
from .refcase import Refcase
from .response_config import ResponseConfig
from .summary_config import SummaryConfig
from .responses_index import responses_index
from .surface_config import SurfaceConfig

logger = logging.getLogger(__name__)
Expand All @@ -50,28 +43,6 @@ def _get_abs_path(file: Optional[str]) -> Optional[str]:
return file


@dataclass(eq=False)
class Refcase:
    """Container for the summary data of a reference case.

    Instances are built from the four values returned by ``read_summary``:
    a start date, the list of keys, the sequence of dates and the array
    of values.
    """

    start_date: datetime
    keys: List[str]
    dates: Sequence[datetime]
    values: npt.NDArray[Any]

    def __eq__(self, other: object) -> bool:
        """Return True when ``other`` is a Refcase with equal fields."""
        if not isinstance(other, Refcase):
            return False
        return bool(
            self.start_date == other.start_date
            and self.keys == other.keys
            and self.dates == other.dates
            # array_equal compares shapes first; an elementwise ``==`` would
            # broadcast and report differently shaped arrays as equal.
            and np.array_equal(self.values, other.values)
        )

    @property
    def all_dates(self) -> List[datetime]:
        """The start date followed by every entry of ``dates``."""
        return [self.start_date, *self.dates]

@dataclass
class EnsembleConfig:
grid_file: Optional[str] = None
Expand All @@ -90,7 +61,7 @@ def __post_init__(self) -> None:

@staticmethod
def _check_for_duplicate_names(
parameter_list: List[ParameterConfig], gen_data_list: List[ResponseConfig]
parameter_list: List[str], gen_data_list: List[str]
) -> None:
names_counter = Counter(g for g in parameter_list + gen_data_list)
duplicate_names = [n for n, c in names_counter.items() if c > 1]
Expand All @@ -106,7 +77,6 @@ def _check_for_duplicate_names(
@classmethod
def from_dict(cls, config_dict: ConfigDict) -> EnsembleConfig:
grid_file_path = config_dict.get(ConfigKeys.GRID)
refcase_file_path = config_dict.get(ConfigKeys.REFCASE)
gen_kw_list = config_dict.get(ConfigKeys.GEN_KW, [])
surface_list = config_dict.get(ConfigKeys.SURFACE, [])
field_list = config_dict.get(ConfigKeys.FIELD, [])
Expand All @@ -133,19 +103,6 @@ def make_field(field_list: List[str]) -> Field:
)
return Field.from_config_list(grid_file_path, dims, field_list)

eclbase = config_dict.get("ECLBASE")
if eclbase is not None:
eclbase = eclbase.replace("%d", "<IENS>")
refcase_keys = []
time_map = []
data = None
if refcase_file_path is not None:
try:
start_date, refcase_keys, time_map, data = read_summary(
refcase_file_path, ["*"]
)
except Exception as err:
raise ConfigValidationError(f"Could not read refcase: {err}") from err
parameter_configs = (
[GenKwConfig.from_config_list(g) for g in gen_kw_list]
+ [SurfaceConfig.from_config_list(s) for s in surface_list]
Expand All @@ -154,32 +111,16 @@ def make_field(field_list: List[str]) -> Field:

response_configs: List[ResponseConfig] = []

if "GEN_DATA" in config_dict:
gen_data_config = GenDataConfig.from_config_dict(config_dict)
if len(gen_data_config.keys) > 0:
response_configs.append(gen_data_config)
for config_cls in responses_index.values():
instance = config_cls.from_config_dict(config_dict)

refcase = (
Refcase(start_date, refcase_keys, time_map, data)
if data is not None
else None
)
summary_keys = config_dict.get(ConfigKeys.SUMMARY, [])
if summary_keys:
if eclbase is None:
raise ConfigValidationError(
"In order to use summary responses, ECLBASE has to be set."
)
time_map = set(refcase.dates) if refcase is not None else None

response_configs.append(
SummaryConfig(
name="summary",
input_files=[eclbase],
keys=[key for keys in summary_keys for key in keys],
refcase=time_map,
)
)
if instance is not None and instance.keys:
response_configs.append(instance)

refcase = Refcase.from_config_dict(config_dict)
eclbase = config_dict.get("ECLBASE")
if eclbase is not None:
eclbase = eclbase.replace("%d", "<IENS>")

return cls(
grid_file=grid_file_path,
Expand Down Expand Up @@ -211,7 +152,6 @@ def hasNodeGenData(self, key: str) -> bool:
return False

config = self.response_configs["gen_data"]
assert isinstance(config, GenDataConfig)
return key in config.keys

def addNode(self, config_node: Union[ParameterConfig, ResponseConfig]) -> None:
Expand Down
1 change: 1 addition & 0 deletions src/ert/config/gen_data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ._option_dict import option_dict
from .parsing import ConfigDict, ConfigValidationError, ErrorInfo
from .response_config import ResponseConfig
from .responses_index import responses_index


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import xarray as xr

from ert.validation import rangestring_to_list
from . import GenDataConfig

from .enkf_observation_implementation_type import EnkfObservationImplementationType
from .gen_data_config import GenDataConfig
from .general_observation import GenObservation
from .observation_vector import ObsVector
from .parsing import ConfigWarning, HistorySource
Expand Down
54 changes: 54 additions & 0 deletions src/ert/config/refcase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from dataclasses import dataclass
from datetime import datetime
from typing import (
Any,
List,
Optional,
Sequence,
)

import numpy as np
import numpy.typing as npt

from ._read_summary import read_summary
from .parsing.config_dict import ConfigDict
from .parsing.config_errors import ConfigValidationError
from .parsing.config_keywords import ConfigKeys


@dataclass(eq=False)
class Refcase:
    """Container for the summary data of a reference case.

    Holds the four values returned by ``read_summary``: a start date, the
    list of keys, the sequence of dates and the array of values.
    """

    start_date: datetime
    keys: List[str]
    dates: Sequence[datetime]
    values: npt.NDArray[Any]

    def __eq__(self, other: object) -> bool:
        """Return True when ``other`` is a Refcase with equal fields."""
        if not isinstance(other, Refcase):
            return False
        return bool(
            self.start_date == other.start_date
            and self.keys == other.keys
            and self.dates == other.dates
            # array_equal compares shapes first; an elementwise ``==`` would
            # broadcast and report differently shaped arrays as equal.
            and np.array_equal(self.values, other.values)
        )

    @property
    def all_dates(self) -> List[datetime]:
        """The start date followed by every entry of ``dates``."""
        return [self.start_date, *self.dates]

    @classmethod
    def from_config_dict(cls, config_dict: "ConfigDict") -> Optional["Refcase"]:
        """Create a Refcase from the REFCASE entry of an ert config dict.

        Returns None when the config dict has no REFCASE entry, or when
        ``read_summary`` yields no data.

        Raises:
            ConfigValidationError: if reading the refcase fails for any
                reason; the original exception is chained as the cause.
        """
        refcase_file_path = config_dict.get(ConfigKeys.REFCASE)  # type: ignore
        if refcase_file_path is None:
            return None

        try:
            # "*" requests every key available in the refcase.
            start_date, refcase_keys, time_map, data = read_summary(
                refcase_file_path, ["*"]
            )
        except Exception as err:
            raise ConfigValidationError(f"Could not read refcase: {err}") from err

        if data is None:
            return None
        return cls(start_date, refcase_keys, time_map, data)
6 changes: 4 additions & 2 deletions src/ert/config/response_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import dataclasses
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import xarray as xr
from typing_extensions import Self

from ert.config.parameter_config import CustomDict
from ert.config.parsing import ConfigDict
Expand All @@ -29,8 +30,9 @@ def response_type(self) -> str:
Must not overlap with that of other response configs."""
...

@classmethod
@abstractmethod
def from_config_dict(self, config_dict: ConfigDict) -> "ResponseConfig":
def from_config_dict(cls, config_dict: ConfigDict) -> Optional[Self]:
"""Creates a config, given an ert config dict.
A response config may depend on several config kws, such as REFCASE
for summary."""
35 changes: 31 additions & 4 deletions src/ert/config/summary_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Set, Union
from typing import TYPE_CHECKING, Optional, Set, Union

import xarray as xr

from ._read_summary import read_summary
from .parsing import ConfigDict
from .ensemble_config import Refcase
from .parsing import ConfigDict, ConfigKeys
from .parsing.config_errors import ConfigValidationError
from .response_config import ResponseConfig
from .responses_index import responses_index

if TYPE_CHECKING:
from typing import List
Expand Down Expand Up @@ -49,5 +52,29 @@ def read_from_file(self, run_path: str, iens: int) -> xr.Dataset:
def response_type(self) -> str:
return "summary"

def from_config_dict(self, config_dict: ConfigDict) -> "ResponseConfig":
pass
@classmethod
def from_config_dict(cls, config_dict: ConfigDict) -> Optional[SummaryConfig]:
    """Create the summary response config described by an ert config dict.

    Returns None when the config dict holds no SUMMARY keys.

    Raises:
        ConfigValidationError: if SUMMARY keys are given but ECLBASE is
            not set. Errors raised by ``Refcase.from_config_dict`` also
            propagate from here.
    """
    # Read the refcase before anything else so that its errors surface
    # even when no summary keys are configured.
    refcase = Refcase.from_config_dict(config_dict)

    summary_keys = config_dict.get(ConfigKeys.SUMMARY, [])  # type: ignore
    if not summary_keys:
        return None

    eclbase = config_dict.get("ECLBASE")  # type: ignore
    if eclbase is None:
        raise ConfigValidationError(
            "In order to use summary responses, ECLBASE has to be set."
        )

    # Only the refcase dates are handed to the config, as a set.
    time_map = set(refcase.dates) if refcase is not None else None

    return cls(
        name="summary",
        input_files=[eclbase.replace("%d", "<IENS>")],
        # summary_keys is a list of key lists; flatten it into one list.
        keys=[key for keys in summary_keys for key in keys],
        refcase=time_map,
    )


responses_index.add_response_type(SummaryConfig)
9 changes: 6 additions & 3 deletions tests/unit_tests/config/test_ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1442,11 +1442,14 @@ def test_validate_no_logs_when_overwriting_with_same_value(caplog):
ert_conf.forward_model_data_to_json("0", "0", 0)
assert (
"Private arg '<VAR3>':'5' chosen over global '55' in forward model step "
"step_name" in caplog.text
"step_name"
in caplog.text
and "Private arg '<VAR1>':'10' chosen over global '10' in forward model "
"step step_name" not in caplog.text
"step step_name"
not in caplog.text
and "Private arg '<VAR2>':'20' chosen over global '20' in forward model "
"step step_name" not in caplog.text
"step step_name"
not in caplog.text
)


Expand Down
51 changes: 51 additions & 0 deletions tests/unit_tests/config/test_responses_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest

from ert.config.gen_data_config import GenDataConfig
from ert.config.responses_index import responses_index
from ert.config.summary_config import SummaryConfig


def test_adding_gendata_and_summary():
    """Response types register once, in insertion order, keyed by class name."""
    ri = responses_index

    # The index is a shared module-level object: start from an empty
    # registry, and put the original one back afterwards so this test
    # does not leak state into others.
    registered_before = ri._items
    ri._items = {}
    try:
        assert [*ri.keys()] == []
        assert [*ri.values()] == []
        assert [*ri.items()] == []

        ri.add_response_type(GenDataConfig)
        assert [*ri.keys()] == ["GenDataConfig"]
        assert [*ri.values()] == [GenDataConfig]
        assert [*ri.items()] == [("GenDataConfig", GenDataConfig)]

        with pytest.raises(
            KeyError,
            match="Response type with name GenDataConfig is already registered",
        ):
            ri.add_response_type(GenDataConfig)

        ri.add_response_type(SummaryConfig)
        assert [*ri.keys()] == ["GenDataConfig", "SummaryConfig"]
        assert [*ri.values()] == [GenDataConfig, SummaryConfig]
        assert [*ri.items()] == [
            ("GenDataConfig", GenDataConfig),
            ("SummaryConfig", SummaryConfig),
        ]

        with pytest.raises(
            KeyError,
            match="Response type with name SummaryConfig is already registered",
        ):
            ri.add_response_type(SummaryConfig)
    finally:
        ri._items = registered_before


def test_adding_non_response_config():
    """Registering a class that is not a ResponseConfig raises ValueError."""

    class NotAResponseConfig:
        pass

    expected_message = "Response type must be subclass of ResponseConfig"
    with pytest.raises(ValueError, match=expected_message):
        responses_index.add_response_type(NotAResponseConfig)
ri.add_response_type(NotAResponseConfig)

0 comments on commit 3edf135

Please sign in to comment.