diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py
index 4fd29651dec..f4c798e5bdf 100644
--- a/src/ert/analysis/_es_update.py
+++ b/src/ert/analysis/_es_update.py
@@ -19,6 +19,7 @@
 import iterative_ensemble_smoother as ies
 import numpy as np
 import pandas as pd
+import polars
 import psutil
 from iterative_ensemble_smoother.experimental import (
     AdaptiveESMDA,
@@ -153,46 +154,55 @@ def _get_observations_and_responses(
     observation_values = []
     observation_errors = []
     indexes = []
-    observations = ensemble.experiment.observations
-    for obs in selected_observations:
-        observation = observations[obs]
-        group = observation.attrs["response"]
-        all_responses = ensemble.load_responses(group, tuple(iens_active_index))
-        if "time" in observation.coords:
-            all_responses = all_responses.reindex(
-                time=observation.time,
-                method="nearest",
-                tolerance="1s",
-            )
-        try:
-            observations_and_responses = observation.merge(all_responses, join="left")
-        except KeyError as e:
-            raise ErtAnalysisError(
-                f"Mismatched index for: "
-                f"Observation: {obs} attached to response: {group}"
-            ) from e
-
-        observation_keys.append([obs] * observations_and_responses["observations"].size)
-
-        if group == "summary":
-            indexes.append(
-                [
-                    np.datetime_as_string(e, unit="s")
-                    for e in observations_and_responses["time"].data
-                ]
-            )
-        else:
-            indexes.append(
-                [
-                    f"{e[0]}, {e[1]}"
-                    for e in zip(
-                        list(observations_and_responses["report_step"].data)
-                        * len(observations_and_responses["index"].data),
-                        observations_and_responses["index"].data,
-                    )
-                ]
-            )
-
+    observations_by_type = ensemble.experiment.observations
+    for response_type in ensemble.experiment.response_info:
+        observations_for_type = observations_by_type[response_type].filter(
+            polars.col("observation_key").is_in(selected_observations)
+        )
+        responses_for_type = ensemble.load_responses(
+            response_type, realizations=tuple(iens_active_index)
+        )
+        joined = observations_for_type.join(responses_for_type, how="left")
+
+        #
+        # observation = None
+        # # group = observation.attrs["response"]
+        # all_responses = ensemble.load_responses(group, tuple(iens_active_index))
+        # if "time" in observation.coords:
+        #     all_responses = all_responses.reindex(
+        #         time=observation.time,
+        #         method="nearest",
+        #         tolerance="1s",
+        #     )
+        # try:
+        #     observations_and_responses = observation.merge(all_responses, join="left")
+        # except KeyError as e:
+        #     raise ErtAnalysisError(
+        #         f"Mismatched index for: "
+        #         f"Observation: {obs} attached to response: {group}"
+        #     ) from e
+        #
+        # observation_keys.append([obs] * observations_and_responses["observations"].size)
+        #
+        # if group == "summary":
+        #     indexes.append(
+        #         [
+        #             np.datetime_as_string(e, unit="s")
+        #             for e in observations_and_responses["time"].data
+        #         ]
+        #     )
+        # else:
+        #     indexes.append(
+        #         [
+        #             f"{e[0]}, {e[1]}"
+        #             for e in zip(
+        #                 list(observations_and_responses["report_step"].data)
+        #                 * len(observations_and_responses["index"].data),
+        #                 observations_and_responses["index"].data,
+        #             )
+        #         ]
+        #     )
+        observations_and_responses = None
         observation_values.append(
             observations_and_responses["observations"].data.ravel()
         )
diff --git a/src/ert/config/observations.py b/src/ert/config/observations.py
index d1e8580c982..d1ee31744b8 100644
--- a/src/ert/config/observations.py
+++ b/src/ert/config/observations.py
@@ -4,7 +4,7 @@
 from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
-import xarray as xr
+import polars
 
 from ert.validation import rangestring_to_list
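The `__post_init__` hunk just below groups observation vectors into per-response-type lists and concatenates each group with `polars.concat`. A minimal standalone sketch of that grouping pattern, written with `collections.defaultdict` to avoid the repeated membership checks; the toy frames and their columns are made up here and only stand in for whatever `ObsVector.to_dataset([])` returns:

```python
from collections import defaultdict
from typing import DefaultDict, Dict, List

import polars

# Toy per-observation frames standing in for ObsVector.to_dataset([]) output;
# the columns are illustrative, not ERT's actual schema.
summary_obs = polars.DataFrame(
    {"observation_key": ["FOPR_1"], "observations": [0.9], "std": [0.05]}
)
gen_obs = polars.DataFrame(
    {"observation_key": ["WPR_DIFF_1"], "observations": [0.1], "std": [0.2]}
)

# Same idea as the patched __post_init__: collect frames per response type,
# then concatenate each group into a single DataFrame.
grouped: DefaultDict[str, List[polars.DataFrame]] = defaultdict(list)
grouped["summary"].append(summary_obs)
grouped["gen_data"].append(gen_obs)

datasets: Dict[str, polars.DataFrame] = {
    name: polars.concat(dfs) for name, dfs in grouped.items()
}
print(datasets["summary"].columns)  # ['observation_key', 'observations', 'std']
```

Using `defaultdict(list)` would also let the patch drop the two `if ... not in grouped` branches while producing the same `self.datasets`.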
@@ -42,8 +42,22 @@ class EnkfObs:
     obs_time: List[datetime]
 
     def __post_init__(self) -> None:
-        self.datasets: Dict[str, xr.Dataset] = {
-            name: obs.to_dataset([]) for name, obs in sorted(self.obs_vectors.items())
+        grouped = {}
+        for vec in self.obs_vectors.values():
+            if vec.observation_type == EnkfObservationImplementationType.SUMMARY_OBS:
+                if "summary" not in grouped:
+                    grouped["summary"] = []
+
+                grouped["summary"].append(vec.to_dataset([]))
+
+            elif vec.observation_type == EnkfObservationImplementationType.GEN_OBS:
+                if "gen_data" not in grouped:
+                    grouped["gen_data"] = []
+
+                grouped["gen_data"].append(vec.to_dataset([]))
+
+        self.datasets: Dict[str, polars.DataFrame] = {
+            name: polars.concat(dfs) for name, dfs in grouped.items()
         }
 
     def __len__(self) -> int:
diff --git a/src/ert/storage/local_experiment.py b/src/ert/storage/local_experiment.py
index ff00f939b04..2bc05a95d04 100644
--- a/src/ert/storage/local_experiment.py
+++ b/src/ert/storage/local_experiment.py
@@ -8,6 +8,7 @@
 from uuid import UUID
 
 import numpy as np
+import polars
 import xarray as xr
 import xtgeo
 from pydantic import BaseModel
@@ -142,7 +143,7 @@ def create(
         output_path = path / "observations"
         output_path.mkdir()
         for obs_name, dataset in observations.items():
-            dataset.to_netcdf(output_path / f"{obs_name}", engine="scipy")
+            dataset.write_parquet(output_path / f"{obs_name}")
 
         with open(path / cls._metadata_file, "w", encoding="utf-8") as f:
             simulation_data = simulation_arguments if simulation_arguments else {}
@@ -303,13 +304,25 @@ def update_parameters(self) -> List[str]:
         return [p.name for p in self.parameter_configuration.values() if p.update]
 
     @cached_property
-    def observations(self) -> Dict[str, xr.Dataset]:
+    def observations(self) -> Dict[str, polars.DataFrame]:
         observations = sorted(self.mount_point.glob("observations/*"))
         return {
-            observation.name: xr.open_dataset(observation, engine="scipy")
+            observation.name: polars.read_parquet(f"{observation}")
             for observation in observations
         }
 
+    @cached_property
+    def observation_keys(self) -> List[str]:
+        """
+        Gets all observation names ("observation_key" values), i.e.
+        the summary keyword, the gen_data observation name, etc.
+        """
+        keys = []
+        for df in self.observations.values():
+            keys.extend(df["observation_key"].unique())
+
+        return sorted(keys)
+
     @cached_property
     def response_key_to_response_type(self) -> Dict[str, str]:
         mapping = {}
diff --git a/tests/unit_tests/analysis/test_es_update.py b/tests/unit_tests/analysis/test_es_update.py
index 299f156c2b0..39f97248147 100644
--- a/tests/unit_tests/analysis/test_es_update.py
+++ b/tests/unit_tests/analysis/test_es_update.py
@@ -98,7 +98,7 @@ def test_update_report(
     smoother_update(
         prior_ens,
         posterior_ens,
-        list(ert_config.observations.keys()),
+        list(experiment.observation_keys),
         ert_config.ensemble_config.parameters,
         UpdateSettings(auto_scale_observations=misfit_preprocess),
         ESSettings(inversion="subspace"),
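In `_es_update.py`, the extraction below the new polars join is still stubbed out (`observations_and_responses = None` feeding the old xarray-style `.data.ravel()` calls). A sketch of how the values could instead be pulled straight from the `joined` polars frame; the column names used here (`observation_key`, `observations`, `std`, and one numeric column per realization) are assumptions about the eventual schema, not what `load_responses` is guaranteed to return:

```python
import polars

# Hypothetical joined frame: observation columns plus one response column per
# active realization, roughly what the left join in
# _get_observations_and_responses could produce.
joined = polars.DataFrame(
    {
        "observation_key": ["FOPR_1", "FOPR_2"],
        "observations": [0.9, 1.1],
        "std": [0.05, 0.05],
        "0": [0.88, 1.08],  # responses from realization 0
        "1": [0.93, 1.12],  # responses from realization 1
    }
)

observation_keys = joined["observation_key"].to_list()
observation_values = joined["observations"].to_numpy()
observation_errors = joined["std"].to_numpy()

# Realization columns -> (num_observations, num_realizations) response matrix.
realization_columns = [c for c in joined.columns if c.isdigit()]
responses = joined.select(realization_columns).to_numpy()

assert responses.shape == (len(observation_values), len(realization_columns))
```

With something along these lines the commented-out xarray block and the `None` placeholder could be dropped, since the per-type filter and join already line observations up with their responses.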