Skip to content

Commit

Permalink
Convert ErtConfig to dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Nov 14, 2024
1 parent d144207 commit 9a3c864
Show file tree
Hide file tree
Showing 17 changed files with 168 additions and 67 deletions.
10 changes: 6 additions & 4 deletions src/ert/config/design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ class DesignMatrix:
xls_filename: Path
design_sheet: str
default_sheet: str
num_realizations: Optional[int] = None
active_realizations: Optional[List[bool]] = None
design_matrix_df: Optional[pd.DataFrame] = None
parameter_configuration: Optional[Dict[str, ParameterConfig]] = None

def __post_init__(self) -> None:
self.num_realizations: Optional[int] = None
self.active_realizations: Optional[List[bool]] = None
self.design_matrix_df: Optional[pd.DataFrame] = None
self.parameter_configuration: Optional[Dict[str, ParameterConfig]] = None

@classmethod
def from_config_list(cls, config_list: List[str]) -> "DesignMatrix":
Expand Down
15 changes: 10 additions & 5 deletions src/ert/config/ensemble_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from ert.field_utils import get_shape

from .field import Field
from .ext_param_config import ExtParamConfig
from .field import Field as FieldConfig
from .gen_data_config import GenDataConfig
from .gen_kw_config import GenKwConfig
from .parameter_config import ParameterConfig
Expand Down Expand Up @@ -49,8 +50,12 @@ def _get_abs_path(file: Optional[str]) -> Optional[str]:
@dataclass
class EnsembleConfig:
grid_file: Optional[str] = None
response_configs: Dict[str, ResponseConfig] = field(default_factory=dict)
parameter_configs: Dict[str, ParameterConfig] = field(default_factory=dict)
response_configs: Dict[str, Union[SummaryConfig, GenDataConfig]] = field(
default_factory=dict
)
parameter_configs: Dict[
str, GenKwConfig | FieldConfig | SurfaceConfig | ExtParamConfig
] = field(default_factory=dict)
refcase: Optional[Refcase] = None

def __post_init__(self) -> None:
Expand Down Expand Up @@ -92,7 +97,7 @@ def from_dict(cls, config_dict: ConfigDict) -> EnsembleConfig:
grid_file_path,
) from err

def make_field(field_list: List[str]) -> Field:
def make_field(field_list: List[str]) -> FieldConfig:
if grid_file_path is None:
raise ConfigValidationError.with_context(
"In order to use the FIELD keyword, a GRID must be supplied.",
Expand All @@ -103,7 +108,7 @@ def make_field(field_list: List[str]) -> Field:
f"Grid file {grid_file_path} did not contain dimensions",
grid_file_path,
)
return Field.from_config_list(grid_file_path, dims, field_list)
return FieldConfig.from_config_list(grid_file_path, dims, field_list)

parameter_configs = (
[GenKwConfig.from_config_list(g) for g in gen_kw_list]
Expand Down
55 changes: 41 additions & 14 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import logging
import os
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import field
from datetime import datetime
from os import path
from pathlib import Path
from typing import (
Any,
ClassVar,
DefaultDict,
Dict,
List,
Optional,
Expand All @@ -24,6 +25,8 @@

import polars
from pydantic import ValidationError as PydanticValidationError
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from ert.plugins import ErtPluginManager
Expand All @@ -49,6 +52,7 @@
ConfigWarning,
ErrorInfo,
ForwardModelStepKeys,
HistorySource,
HookRuntime,
init_forward_model_schema,
init_site_config_schema,
Expand Down Expand Up @@ -256,7 +260,9 @@ class ErtConfig:
queue_config: QueueConfig = field(default_factory=QueueConfig)
workflow_jobs: Dict[str, WorkflowJob] = field(default_factory=dict)
workflows: Dict[str, Workflow] = field(default_factory=dict)
hooked_workflows: Dict[HookRuntime, List[Workflow]] = field(default_factory=dict)
hooked_workflows: DefaultDict[HookRuntime, List[Workflow]] = field(
default_factory=lambda: defaultdict(list)
)
runpath_file: Path = Path(DEFAULT_RUNPATH_FILE)
ert_templates: List[Tuple[str, str]] = field(default_factory=list)
installed_forward_model_steps: Dict[str, ForwardModelStep] = field(
Expand All @@ -269,6 +275,14 @@ class ErtConfig:
observation_config: List[
Tuple[str, Union[HistoryValues, SummaryValues, GenObsValues]]
] = field(default_factory=list)
enkf_obs: EnkfObs = field(default_factory=EnkfObs)

@field_validator("substitutions", mode="before")
@classmethod
def convert_to_substitutions(cls, v: Dict[str, str]) -> Substitutions:
if isinstance(v, Substitutions):
return v
return Substitutions(v)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ErtConfig):
Expand Down Expand Up @@ -298,8 +312,6 @@ def __post_init__(self) -> None:
if self.user_config_file
else os.getcwd()
)
self.enkf_obs: EnkfObs = self._create_observations(self.observation_config)

self.observations: Dict[str, polars.DataFrame] = self.enkf_obs.datasets

@staticmethod
Expand Down Expand Up @@ -456,7 +468,7 @@ def from_dict(cls, config_dict) -> Self:
errors.append(err)

obs_config_file = config_dict.get(ConfigKeys.OBS_CONFIG)
obs_config_content = None
obs_config_content = []
try:
if obs_config_file:
if path.isfile(obs_config_file) and path.getsize(obs_config_file) == 0:
Expand Down Expand Up @@ -487,6 +499,19 @@ def from_dict(cls, config_dict) -> Self:
[key] for key in summary_obs if key not in summary_keys
]
ensemble_config = EnsembleConfig.from_dict(config_dict=config_dict)
if model_config:
observations = cls._create_observations(
obs_config_content,
ensemble_config,
model_config.time_map,
model_config.history_source,
)
else:
errors.append(
ConfigValidationError(
"Not possible to validate observations without valid model config"
)
)
except ConfigValidationError as err:
errors.append(err)

Expand Down Expand Up @@ -519,6 +544,7 @@ def from_dict(cls, config_dict) -> Self:
model_config=model_config,
user_config_file=config_file_path,
observation_config=obs_config_content,
enkf_obs=observations,
)

@classmethod
Expand Down Expand Up @@ -970,24 +996,25 @@ def _installed_forward_model_steps_from_dict(
def preferred_num_cpu(self) -> int:
return int(self.substitutions.get(f"<{ConfigKeys.NUM_CPU}>", 1))

@staticmethod
def _create_observations(
self,
obs_config_content: Optional[
Dict[str, Union[HistoryValues, SummaryValues, GenObsValues]]
],
ensemble_config: EnsembleConfig,
time_map: Optional[List[datetime]],
history: HistorySource,
) -> EnkfObs:
if not obs_config_content:
return EnkfObs({}, [])
obs_vectors: Dict[str, ObsVector] = {}
obs_time_list: Sequence[datetime] = []
if self.ensemble_config.refcase is not None:
obs_time_list = self.ensemble_config.refcase.all_dates
elif self.model_config.time_map is not None:
obs_time_list = self.model_config.time_map
if ensemble_config.refcase is not None:
obs_time_list = ensemble_config.refcase.all_dates
elif time_map is not None:
obs_time_list = time_map

history = self.model_config.history_source
time_len = len(obs_time_list)
ensemble_config = self.ensemble_config
config_errors: List[ErrorInfo] = []
for obs_name, values in obs_config_content:
try:
Expand Down Expand Up @@ -1059,7 +1086,7 @@ def _get_files_in_directory(job_path, errors):


def _substitutions_from_dict(config_dict) -> Substitutions:
subst_list = Substitutions()
subst_list = {}

for key, val in config_dict.get("DEFINE", []):
subst_list[key] = val
Expand All @@ -1077,7 +1104,7 @@ def _substitutions_from_dict(config_dict) -> Substitutions:
for key, val in config_dict.get("DATA_KW", []):
subst_list[key] = val

return subst_list
return Substitutions(subst_list)


@no_type_check
Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import logging
import os
import time
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional, Union, overload

import numpy as np
import xarray as xr
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from ert.field_utils import FieldFileFormat, Shape, read_field, read_mask, save_field
Expand Down
9 changes: 9 additions & 0 deletions src/ert/config/forward_model_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from dataclasses import dataclass, field
from typing import (
ClassVar,
Dict,
Literal,
Optional,
TypedDict,
Union,
)

from pydantic import field_validator
from typing_extensions import NotRequired, Unpack

from ert.config.parsing.config_errors import ConfigWarning
Expand Down Expand Up @@ -172,6 +174,13 @@ class ForwardModelStep:
"_ERT_RUNPATH": "<RUNPATH>",
}

@field_validator("private_args", mode="before")
@classmethod
def convert_to_substitutions(cls, v: Dict[str, str]) -> Substitutions:
if isinstance(v, Substitutions):
return v
return Substitutions(v)

def validate_pre_experiment(self, fm_step_json: ForwardModelStepJSON) -> None:
"""
Raise errors pertaining to the environment not being
Expand Down
10 changes: 5 additions & 5 deletions src/ert/config/general_observation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List

import numpy as np
import numpy.typing as npt


@dataclass(eq=False)
class GenObservation:
values: npt.NDArray[np.double]
stds: npt.NDArray[np.double]
indices: npt.NDArray[np.int32]
std_scaling: npt.NDArray[np.double]
values: List[float]
stds: List[float]
indices: List[int]
std_scaling: List[float]

def __post_init__(self) -> None:
for val in self.stds:
Expand Down
10 changes: 6 additions & 4 deletions src/ert/config/observations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union
Expand Down Expand Up @@ -39,8 +39,8 @@ def history_key(key: str) -> str:

@dataclass
class EnkfObs:
obs_vectors: Dict[str, ObsVector]
obs_time: List[datetime]
obs_vectors: Dict[str, ObsVector] = field(default_factory=dict)
obs_time: List[datetime] = field(default_factory=list)

def __post_init__(self) -> None:
grouped: Dict[str, List[polars.DataFrame]] = {}
Expand Down Expand Up @@ -394,7 +394,9 @@ def _create_gen_obs(
f"index list ({indices}) must be of equal length",
obs_file if obs_file is not None else "",
)
return GenObservation(values, stds, indices, std_scaling)
return GenObservation(
values.tolist(), stds.tolist(), indices.tolist(), std_scaling.tolist()
)

@classmethod
def _handle_general_observation(
Expand Down
3 changes: 1 addition & 2 deletions src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Any, Dict, List, Literal, Mapping, Optional, Union, no_type_check

import pydantic
from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import Annotated

Expand Down Expand Up @@ -270,7 +269,7 @@ class QueueConfig:
queue_system: QueueSystem = QueueSystem.LOCAL
queue_options: Union[
LsfQueueOptions, TorqueQueueOptions, SlurmQueueOptions, LocalQueueOptions
] = Field(default_factory=LocalQueueOptions, discriminator="name")
] = pydantic.Field(default_factory=LocalQueueOptions, discriminator="name")
queue_options_test_run: LocalQueueOptions = field(default_factory=LocalQueueOptions)
stop_long_running: bool = False

Expand Down
8 changes: 4 additions & 4 deletions src/ert/config/refcase.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from dataclasses import dataclass
from datetime import datetime
from typing import (
Any,
List,
Optional,
Sequence,
)

import numpy as np
import numpy.typing as npt

from ._read_summary import read_summary
from .parsing.config_dict import ConfigDict
Expand All @@ -21,7 +19,7 @@ class Refcase:
start_date: datetime
keys: List[str]
dates: Sequence[datetime]
values: npt.NDArray[Any]
values: List[List[float]]

def __eq__(self, other: object) -> bool:
if not isinstance(other, Refcase):
Expand Down Expand Up @@ -50,5 +48,7 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional["Refcase"]:
raise ConfigValidationError(f"Could not read refcase: {err}") from err

return (
cls(start_date, refcase_keys, time_map, data) if data is not None else None
cls(start_date, refcase_keys, time_map, data.tolist())
if data is not None
else None
)
11 changes: 4 additions & 7 deletions src/ert/config/workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple

from .parsing import ConfigValidationError, ErrorInfo, init_workflow_schema, parse
Expand All @@ -11,14 +12,10 @@
from .workflow_job import WorkflowJob


@dataclass
class Workflow:
def __init__(
self,
src_file: str,
cmd_list: List[Tuple[WorkflowJob, Any]],
):
self.src_file = src_file
self.cmd_list = cmd_list
src_file: str
cmd_list: List[Tuple[WorkflowJob, Any]]

def __len__(self) -> int:
return len(self.cmd_list)
Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/workflow_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def from_file(cls, config_file: str, name: Optional[str] = None) -> "WorkflowJob
arg_types_list = cls._make_arg_types_list(content_dict)
return cls(
name=name,
internal=content_dict.get("INTERNAL"), # type: ignore
internal=bool(content_dict.get("INTERNAL", False)), # type: ignore
min_args=content_dict.get("MIN_ARG"), # type: ignore
max_args=content_dict.get("MAX_ARG"), # type: ignore
arg_types=arg_types_list,
Expand Down
Loading

0 comments on commit 9a3c864

Please sign in to comment.