Skip to content

Commit

Permalink
Read from runpath to polars
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen committed Sep 13, 2024
1 parent 9474217 commit 71d9ba8
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 40 deletions.
34 changes: 20 additions & 14 deletions src/ert/config/gen_data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List, Optional, Tuple

import numpy as np
import xarray as xr
import polars
from typing_extensions import Self

from ert.validation import rangestring_to_list
Expand Down Expand Up @@ -107,21 +107,23 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional[Self]:
report_steps_list=report_steps,
)

def read_from_file(self, run_path: str, _: int) -> xr.Dataset:
def _read_file(filename: Path, report_step: int) -> xr.Dataset:
def read_from_file(self, run_path: str, _: int) -> polars.DataFrame:
def _read_file(filename: Path, report_step: int) -> polars.DataFrame:
if not filename.exists():
raise ValueError(f"Missing output file: {filename}")
data = np.loadtxt(_run_path / filename, ndmin=1)
active_information_file = _run_path / (str(filename) + "_active")
if active_information_file.exists():
active_list = np.loadtxt(active_information_file)
data[active_list == 0] = np.nan
return xr.Dataset(
{"values": (["report_step", "index"], [data])},
coords={
"index": np.arange(len(data)),
"report_step": [report_step],
},
return polars.DataFrame(
{
"report_step": polars.Series(
np.full(len(data), report_step), dtype=polars.UInt16
),
"index": polars.Series(np.arange(len(data)), dtype=polars.UInt16),
"values": polars.Series(data, dtype=polars.Float32),
}
)

errors = []
Expand Down Expand Up @@ -150,16 +152,16 @@ def _read_file(filename: Path, report_step: int) -> xr.Dataset:
except ValueError as err:
errors.append(str(err))

ds_all_report_steps = xr.concat(
datasets_per_report_step, dim="report_step"
).expand_dims(name=[name])
ds_all_report_steps = polars.concat(datasets_per_report_step)
ds_all_report_steps.insert_column(
0, polars.Series("response_key", [name] * len(ds_all_report_steps))
)
datasets_per_name.append(ds_all_report_steps)

if errors:
raise ValueError(f"Error reading GEN_DATA: {self.name}, errors: {errors}")

combined = xr.concat(datasets_per_name, dim="name")
combined.attrs["response"] = "gen_data"
combined = polars.concat(datasets_per_name)
return combined

def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]]]:
Expand All @@ -173,5 +175,9 @@ def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]]
def response_type(self) -> str:
    """Return the registry identifier for GEN_DATA responses."""
    kind = "gen_data"
    return kind

@property
def primary_key(self) -> List[str]:
    """Columns that together identify one row of a gen_data response frame."""
    key_columns = ("report_step", "index")
    return list(key_columns)


responses_index.add_response_type(GenDataConfig)
52 changes: 35 additions & 17 deletions src/ert/config/observation_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Iterable, List, Union

import xarray as xr
import numpy as np

from .enkf_observation_implementation_type import EnkfObservationImplementationType
from .general_observation import GenObservation
Expand All @@ -12,6 +12,8 @@
if TYPE_CHECKING:
from datetime import datetime

import polars


@dataclass
class ObsVector:
Expand All @@ -27,28 +29,39 @@ def __iter__(self) -> Iterable[Union[SummaryObservation, GenObservation]]:
def __len__(self) -> int:
return len(self.observations)

def to_dataset(self, active_list: List[int]) -> xr.Dataset:
def to_dataset(self, active_list: List[int]) -> polars.DataFrame:
if self.observation_type == EnkfObservationImplementationType.GEN_OBS:
datasets = []
actual_response_key = self.data_key
actual_observation_key = self.observation_key
dataframes = []
for time_step, node in self.observations.items():
if active_list and time_step not in active_list:
continue

assert isinstance(node, GenObservation)
datasets.append(
xr.Dataset(
dataframes.append(
polars.DataFrame(
{
"observations": (["report_step", "index"], [node.values]),
"std": (["report_step", "index"], [node.stds]),
},
coords={"index": node.indices, "report_step": [time_step]},
"response_key": actual_response_key,
"observation_key": actual_observation_key,
"report_step": polars.Series(
np.full(len(node.indices), time_step),
dtype=polars.UInt16,
),
"index": polars.Series(node.indices, dtype=polars.UInt16),
"observations": polars.Series(
node.values, dtype=polars.Float32
),
"std": polars.Series(node.stds, dtype=polars.Float32),
}
)
)
combined = xr.combine_by_coords(datasets)
combined.attrs["response"] = self.data_key
combined = polars.concat(dataframes)
return combined # type: ignore
elif self.observation_type == EnkfObservationImplementationType.SUMMARY_OBS:
observations = []
actual_response_key = self.observation_key
actual_observation_keys = []
errors = []
dates = list(self.observations.keys())
if active_list:
Expand All @@ -57,15 +70,20 @@ def to_dataset(self, active_list: List[int]) -> xr.Dataset:
for time_step in dates:
n = self.observations[time_step]
assert isinstance(n, SummaryObservation)
actual_observation_keys.append(n.observation_key)
observations.append(n.value)
errors.append(n.std)
return xr.Dataset(

dates_series = polars.Series(dates).dt.cast_time_unit("ms")

return polars.DataFrame(
{
"observations": (["name", "time"], [observations]),
"std": (["name", "time"], [errors]),
},
coords={"time": dates, "name": [self.observation_key]},
attrs={"response": "summary"},
"response_key": actual_response_key,
"observation_key": actual_observation_keys,
"time": dates_series,
"observations": polars.Series(observations, dtype=polars.Float32),
"std": polars.Series(errors, dtype=polars.Float32),
}
)
else:
raise ValueError(f"Unknown observation type {self.observation_type}")
4 changes: 2 additions & 2 deletions src/ert/config/response_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import xarray as xr
import polars
from typing_extensions import Self

from ert.config.parameter_config import CustomDict
Expand All @@ -16,7 +16,7 @@ class ResponseConfig(ABC):
keys: List[str] = dataclasses.field(default_factory=list)

@abstractmethod
def read_from_file(self, run_path: str, iens: int) -> polars.DataFrame:
    """Read this response type from *run_path* for realization *iens*.

    Concrete response configs (e.g. gen_data, summary) implement this to
    load forward-model output into a polars DataFrame.
    """
    ...

def to_dict(self) -> Dict[str, Any]:
data = dataclasses.asdict(self, dict_factory=CustomDict)
Expand Down
21 changes: 14 additions & 7 deletions src/ert/config/summary_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from datetime import datetime
from typing import TYPE_CHECKING, Optional, Set, Union

import xarray as xr

from ._read_summary import read_summary
from .ensemble_config import Refcase
from .parsing import ConfigDict, ConfigKeys
Expand All @@ -18,6 +16,7 @@
from typing import List

logger = logging.getLogger(__name__)
import polars


@dataclass
def expected_input_files(self) -> List[str]:
    """Return the summary files the forward model must have produced.

    Derived from the first configured input file base name, e.g. ``BASE``
    yields ``BASE.UNSMRY`` and ``BASE.SMSPEC``.
    """
    base = self.input_files[0]
    return [f"{base}.UNSMRY", f"{base}.SMSPEC"]

def read_from_file(self, run_path: str, iens: int) -> xr.Dataset:
def read_from_file(self, run_path: str, iens: int) -> polars.DataFrame:
filename = self.input_files[0].replace("<IENS>", str(iens))
_, keys, time_map, data = read_summary(f"{run_path}/{filename}", self.keys)
if len(data) == 0 or len(keys) == 0:
Expand All @@ -47,11 +46,19 @@ def read_from_file(self, run_path: str, iens: int) -> xr.Dataset:
raise ValueError(
f"Did not find any summary values matching {self.keys} in {filename}"
)
ds = xr.Dataset(
{"values": (["name", "time"], data)},
coords={"time": time_map, "name": keys},

# Important: Pick lowest unit resolution to allow for using
# datetimes many years into the future
time_map_series = polars.Series(time_map).dt.cast_time_unit("ms")
df = polars.DataFrame(
{
"response_key": keys,
"time": [time_map_series for _ in data],
"values": [polars.Series(row, dtype=polars.Float32) for row in data],
}
)
return ds.drop_duplicates("time")
df = df.explode("values", "time")
return df

@property
def response_type(self) -> str:
Expand Down

0 comments on commit 71d9ba8

Please sign in to comment.