From 34091ed8f3cc5a6eedad3de43824b4e283613abc Mon Sep 17 00:00:00 2001 From: "Yngve S. Kristiansen" Date: Thu, 12 Sep 2024 11:15:48 +0200 Subject: [PATCH] Store observations as parquet --- src/ert/config/observation_vector.py | 40 +++++++++++++++------------- src/ert/config/observations.py | 20 +++++++++++--- src/ert/storage/local_experiment.py | 5 ++-- 3 files changed, 42 insertions(+), 23 deletions(-) diff --git a/src/ert/config/observation_vector.py b/src/ert/config/observation_vector.py index f77364b0ade..59a3799a2bf 100644 --- a/src/ert/config/observation_vector.py +++ b/src/ert/config/observation_vector.py @@ -3,8 +3,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Iterable, List, Union -import xarray as xr - from .enkf_observation_implementation_type import EnkfObservationImplementationType from .general_observation import GenObservation from .summary_observation import SummaryObservation @@ -12,6 +10,8 @@ if TYPE_CHECKING: from datetime import datetime +import polars + @dataclass class ObsVector: @@ -27,25 +27,28 @@ def __iter__(self) -> Iterable[Union[SummaryObservation, GenObservation]]: def __len__(self) -> int: return len(self.observations) - def to_dataset(self, active_list: List[int]) -> xr.Dataset: + def to_dataset(self, active_list: List[int]) -> polars.DataFrame: if self.observation_type == EnkfObservationImplementationType.GEN_OBS: - datasets = [] + dataframes = [] for time_step, node in self.observations.items(): if active_list and time_step not in active_list: continue assert isinstance(node, GenObservation) - datasets.append( - xr.Dataset( + dataframes.append( + polars.DataFrame( { - "observations": (["report_step", "index"], [node.values]), - "std": (["report_step", "index"], [node.stds]), - }, - coords={"index": node.indices, "report_step": [time_step]}, + "name": self.data_key, + "index": node.indices, + "report_step": time_step, + "observations": polars.Series( + node.values, dtype=polars.Float32 + ), + "std": polars.Series(node.stds, dtype=polars.Float32), + } ) ) - combined = xr.combine_by_coords(datasets) - combined.attrs["response"] = self.data_key + combined = polars.concat(dataframes) return combined # type: ignore elif self.observation_type == EnkfObservationImplementationType.SUMMARY_OBS: observations = [] @@ -59,13 +62,14 @@ def to_dataset(self, active_list: List[int]) -> xr.Dataset: assert isinstance(n, SummaryObservation) observations.append(n.value) errors.append(n.std) - return xr.Dataset( + + return polars.DataFrame( { - "observations": (["name", "time"], [observations]), - "std": (["name", "time"], [errors]), - }, - coords={"time": dates, "name": [self.observation_key]}, - attrs={"response": "summary"}, + "name": self.observation_key, + "time": dates, + "observations": polars.Series(observations, dtype=polars.Float32), + "std": polars.Series(errors, dtype=polars.Float32), + } ) else: raise ValueError(f"Unknown observation type {self.observation_type}") diff --git a/src/ert/config/observations.py b/src/ert/config/observations.py index d1e8580c982..d1ee31744b8 100644 --- a/src/ert/config/observations.py +++ b/src/ert/config/observations.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union import numpy as np -import xarray as xr +import polars from ert.validation import rangestring_to_list @@ -42,8 +42,22 @@ class EnkfObs: obs_time: List[datetime] def __post_init__(self) -> None: - self.datasets: Dict[str, xr.Dataset] = { - name: obs.to_dataset([]) for name, obs in sorted(self.obs_vectors.items()) + grouped = {} + for vec in self.obs_vectors.values(): + if vec.observation_type == EnkfObservationImplementationType.SUMMARY_OBS: + if "summary" not in grouped: + grouped["summary"] = [] + + grouped["summary"].append(vec.to_dataset([])) + + elif vec.observation_type == EnkfObservationImplementationType.GEN_OBS: + if "gen_data" not in grouped: + grouped["gen_data"] = [] + + grouped["gen_data"].append(vec.to_dataset([])) + + self.datasets: Dict[str, polars.DataFrame] = { + name: polars.concat(dfs) for name, dfs in grouped.items() } def __len__(self) -> int: diff --git a/src/ert/storage/local_experiment.py b/src/ert/storage/local_experiment.py index ff00f939b04..07869063eb4 100644 --- a/src/ert/storage/local_experiment.py +++ b/src/ert/storage/local_experiment.py @@ -8,6 +8,7 @@ from uuid import UUID import numpy as np +import polars import xarray as xr import xtgeo from pydantic import BaseModel @@ -142,7 +143,7 @@ def create( output_path = path / "observations" output_path.mkdir() for obs_name, dataset in observations.items(): - dataset.to_netcdf(output_path / f"{obs_name}", engine="scipy") + dataset.write_parquet(output_path / f"{obs_name}") with open(path / cls._metadata_file, "w", encoding="utf-8") as f: simulation_data = simulation_arguments if simulation_arguments else {} @@ -306,7 +307,7 @@ def update_parameters(self) -> List[str]: def observations(self) -> Dict[str, xr.Dataset]: observations = sorted(self.mount_point.glob("observations/*")) return { - observation.name: xr.open_dataset(observation, engine="scipy") + observation.name: polars.read_parquet(f"{observation}") for observation in observations }