diff --git a/src/ert/config/gen_data_config.py b/src/ert/config/gen_data_config.py index 98b4d576b53..c9c4f39b277 100644 --- a/src/ert/config/gen_data_config.py +++ b/src/ert/config/gen_data_config.py @@ -5,7 +5,7 @@ from typing import List, Optional, Tuple import numpy as np -import xarray as xr +import polars from typing_extensions import Self from ert.validation import rangestring_to_list @@ -107,8 +107,8 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional[Self]: report_steps_list=report_steps, ) - def read_from_file(self, run_path: str, _: int) -> xr.Dataset: - def _read_file(filename: Path, report_step: int) -> xr.Dataset: + def read_from_file(self, run_path: str, _: int) -> polars.DataFrame: + def _read_file(filename: Path, report_step: int) -> polars.DataFrame: if not filename.exists(): raise ValueError(f"Missing output file: {filename}") data = np.loadtxt(_run_path / filename, ndmin=1) @@ -116,12 +116,14 @@ def _read_file(filename: Path, report_step: int) -> xr.Dataset: if active_information_file.exists(): active_list = np.loadtxt(active_information_file) data[active_list == 0] = np.nan - return xr.Dataset( - {"values": (["report_step", "index"], [data])}, - coords={ - "index": np.arange(len(data)), - "report_step": [report_step], - }, + return polars.DataFrame( + { + "report_step": polars.Series( + np.full(len(data), report_step), dtype=polars.UInt16 + ), + "index": polars.Series(np.arange(len(data)), dtype=polars.UInt16), + "values": polars.Series(data, dtype=polars.Float32), + } ) errors = [] @@ -150,16 +152,16 @@ def _read_file(filename: Path, report_step: int) -> xr.Dataset: except ValueError as err: errors.append(str(err)) - ds_all_report_steps = xr.concat( - datasets_per_report_step, dim="report_step" - ).expand_dims(name=[name]) + ds_all_report_steps = polars.concat(datasets_per_report_step) + ds_all_report_steps.insert_column( + 0, polars.Series("response_key", [name] * len(ds_all_report_steps)) + ) 
datasets_per_name.append(ds_all_report_steps) if errors: raise ValueError(f"Error reading GEN_DATA: {self.name}, errors: {errors}") - combined = xr.concat(datasets_per_name, dim="name") - combined.attrs["response"] = "gen_data" + combined = polars.concat(datasets_per_name) return combined def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]]]: @@ -173,5 +175,9 @@ def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]] def response_type(self) -> str: return "gen_data" + @property + def primary_key(self) -> List[str]: + return ["report_step", "index"] + responses_index.add_response_type(GenDataConfig) diff --git a/src/ert/config/observation_vector.py b/src/ert/config/observation_vector.py index f77364b0ade..eb8bbd4ab22 100644 --- a/src/ert/config/observation_vector.py +++ b/src/ert/config/observation_vector.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Iterable, List, Union -import xarray as xr +import numpy as np from .enkf_observation_implementation_type import EnkfObservationImplementationType from .general_observation import GenObservation @@ -12,6 +12,8 @@ if TYPE_CHECKING: from datetime import datetime +import polars + @dataclass class ObsVector: @@ -27,28 +29,39 @@ def __iter__(self) -> Iterable[Union[SummaryObservation, GenObservation]]: def __len__(self) -> int: return len(self.observations) - def to_dataset(self, active_list: List[int]) -> xr.Dataset: + def to_dataset(self, active_list: List[int]) -> polars.DataFrame: if self.observation_type == EnkfObservationImplementationType.GEN_OBS: - datasets = [] + actual_response_key = self.data_key + actual_observation_key = self.observation_key + dataframes = [] for time_step, node in self.observations.items(): if active_list and time_step not in active_list: continue assert isinstance(node, GenObservation) - datasets.append( - xr.Dataset( + dataframes.append( + polars.DataFrame( { - "observations": (["report_step", 
"index"], [node.values]), - "std": (["report_step", "index"], [node.stds]), - }, - coords={"index": node.indices, "report_step": [time_step]}, + "response_key": actual_response_key, + "observation_key": actual_observation_key, + "report_step": polars.Series( + np.full(len(node.indices), time_step), + dtype=polars.UInt16, + ), + "index": polars.Series(node.indices, dtype=polars.UInt16), + "observations": polars.Series( + node.values, dtype=polars.Float32 + ), + "std": polars.Series(node.stds, dtype=polars.Float32), + } ) ) - combined = xr.combine_by_coords(datasets) - combined.attrs["response"] = self.data_key + combined = polars.concat(dataframes) return combined # type: ignore elif self.observation_type == EnkfObservationImplementationType.SUMMARY_OBS: observations = [] + actual_response_key = self.observation_key + actual_observation_keys = [] errors = [] dates = list(self.observations.keys()) if active_list: @@ -57,15 +70,20 @@ def to_dataset(self, active_list: List[int]) -> xr.Dataset: for time_step in dates: n = self.observations[time_step] assert isinstance(n, SummaryObservation) + actual_observation_keys.append(n.observation_key) observations.append(n.value) errors.append(n.std) - return xr.Dataset( + + dates_series = polars.Series(dates).dt.cast_time_unit("ms") + + return polars.DataFrame( { - "observations": (["name", "time"], [observations]), - "std": (["name", "time"], [errors]), - }, - coords={"time": dates, "name": [self.observation_key]}, - attrs={"response": "summary"}, + "response_key": actual_response_key, + "observation_key": actual_observation_keys, + "time": dates_series, + "observations": polars.Series(observations, dtype=polars.Float32), + "std": polars.Series(errors, dtype=polars.Float32), + } ) else: raise ValueError(f"Unknown observation type {self.observation_type}") diff --git a/src/ert/config/response_config.py b/src/ert/config/response_config.py index d40f899ed87..55d176a6bd0 100644 --- a/src/ert/config/response_config.py +++ 
b/src/ert/config/response_config.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -import xarray as xr +import polars from typing_extensions import Self from ert.config.parameter_config import CustomDict @@ -16,7 +16,7 @@ class ResponseConfig(ABC): keys: List[str] = dataclasses.field(default_factory=list) @abstractmethod - def read_from_file(self, run_path: str, iens: int) -> xr.Dataset: ... + def read_from_file(self, run_path: str, iens: int) -> polars.DataFrame: ... def to_dict(self) -> Dict[str, Any]: data = dataclasses.asdict(self, dict_factory=CustomDict) diff --git a/src/ert/config/summary_config.py b/src/ert/config/summary_config.py index 49d1b46a302..3b0c0bfaa3e 100644 --- a/src/ert/config/summary_config.py +++ b/src/ert/config/summary_config.py @@ -5,8 +5,6 @@ from datetime import datetime from typing import TYPE_CHECKING, Optional, Set, Union -import xarray as xr - from ._read_summary import read_summary from .ensemble_config import Refcase from .parsing import ConfigDict, ConfigKeys @@ -18,6 +16,7 @@ from typing import List logger = logging.getLogger(__name__) +import polars @dataclass @@ -37,7 +36,7 @@ def expected_input_files(self) -> List[str]: base = self.input_files[0] return [f"{base}.UNSMRY", f"{base}.SMSPEC"] - def read_from_file(self, run_path: str, iens: int) -> xr.Dataset: + def read_from_file(self, run_path: str, iens: int) -> polars.DataFrame: filename = self.input_files[0].replace("<IENS>", str(iens)) _, keys, time_map, data = read_summary(f"{run_path}/{filename}", self.keys) if len(data) == 0 or len(keys) == 0: @@ -47,11 +46,19 @@ def read_from_file(self, run_path: str, iens: int) -> xr.Dataset: raise ValueError( f"Did not find any summary values matching {self.keys} in {filename}" ) - ds = xr.Dataset( - {"values": (["name", "time"], data)}, - coords={"time": time_map, "name": keys}, + + # Important: Pick lowest unit resolution to allow for using + # datetimes many years into the future + 
time_map_series = polars.Series(time_map).dt.cast_time_unit("ms") + df = polars.DataFrame( + { + "response_key": keys, + "time": [time_map_series for _ in data], + "values": [polars.Series(row, dtype=polars.Float32) for row in data], + } ) - return ds.drop_duplicates("time") + df = df.explode("values", "time") + return df @property def response_type(self) -> str: