Skip to content

Commit

Permalink
Store observations as parquet
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen committed Sep 13, 2024
1 parent f16b102 commit 14f9029
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 47 deletions.
90 changes: 50 additions & 40 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import iterative_ensemble_smoother as ies
import numpy as np
import pandas as pd
import polars
import psutil
from iterative_ensemble_smoother.experimental import (
AdaptiveESMDA,
Expand Down Expand Up @@ -153,46 +154,55 @@ def _get_observations_and_responses(
observation_values = []
observation_errors = []
indexes = []
observations = ensemble.experiment.observations
for obs in selected_observations:
observation = observations[obs]
group = observation.attrs["response"]
all_responses = ensemble.load_responses(group, tuple(iens_active_index))
if "time" in observation.coords:
all_responses = all_responses.reindex(
time=observation.time,
method="nearest",
tolerance="1s",
)
try:
observations_and_responses = observation.merge(all_responses, join="left")
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched index for: "
f"Observation: {obs} attached to response: {group}"
) from e

observation_keys.append([obs] * observations_and_responses["observations"].size)

if group == "summary":
indexes.append(
[
np.datetime_as_string(e, unit="s")
for e in observations_and_responses["time"].data
]
)
else:
indexes.append(
[
f"{e[0]}, {e[1]}"
for e in zip(
list(observations_and_responses["report_step"].data)
* len(observations_and_responses["index"].data),
observations_and_responses["index"].data,
)
]
)

observations_by_type = ensemble.experiment.observations
for response_type in ensemble.experiment.response_info:
observations_for_type = observations_by_type[response_type].filter(
polars.col("observation_key").is_in(selected_observations)
)
responses_for_type = ensemble.load_responses(
response_type, realizations=tuple(iens_active_index)
)
joined = observations_for_type.join(responses_for_type, how="left")

#
# observation = None
# # group = observation.attrs["response"]
# all_responses = ensemble.load_responses(group, tuple(iens_active_index))
# if "time" in observation.coords:
# all_responses = all_responses.reindex(
# time=observation.time,
# method="nearest",
# tolerance="1s",
# )
# try:
# observations_and_responses = observation.merge(all_responses, join="left")
# except KeyError as e:
# raise ErtAnalysisError(
# f"Mismatched index for: "
# f"Observation: {obs} attached to response: {group}"
# ) from e
#
# observation_keys.append([obs] * observations_and_responses["observations"].size)
#
# if group == "summary":
# indexes.append(
# [
# np.datetime_as_string(e, unit="s")
# for e in observations_and_responses["time"].data
# ]
# )
# else:
# indexes.append(
# [
# f"{e[0]}, {e[1]}"
# for e in zip(
# list(observations_and_responses["report_step"].data)
# * len(observations_and_responses["index"].data),
# observations_and_responses["index"].data,
# )
# ]
# )
observations_and_responses = None
observation_values.append(
observations_and_responses["observations"].data.ravel()
)
Expand Down
20 changes: 17 additions & 3 deletions src/ert/config/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import xarray as xr
import polars

from ert.validation import rangestring_to_list

Expand Down Expand Up @@ -42,8 +42,22 @@ class EnkfObs:
obs_time: List[datetime]

def __post_init__(self) -> None:
self.datasets: Dict[str, xr.Dataset] = {
name: obs.to_dataset([]) for name, obs in sorted(self.obs_vectors.items())
grouped = {}
for vec in self.obs_vectors.values():
if vec.observation_type == EnkfObservationImplementationType.SUMMARY_OBS:
if "summary" not in grouped:
grouped["summary"] = []

grouped["summary"].append(vec.to_dataset([]))

elif vec.observation_type == EnkfObservationImplementationType.GEN_OBS:
if "gen_data" not in grouped:
grouped["gen_data"] = []

grouped["gen_data"].append(vec.to_dataset([]))

self.datasets: Dict[str, polars.DataFrame] = {
name: polars.concat(dfs) for name, dfs in grouped.items()
}

def __len__(self) -> int:
Expand Down
19 changes: 16 additions & 3 deletions src/ert/storage/local_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from uuid import UUID

import numpy as np
import polars
import xarray as xr
import xtgeo
from pydantic import BaseModel
Expand Down Expand Up @@ -142,7 +143,7 @@ def create(
output_path = path / "observations"
output_path.mkdir()
for obs_name, dataset in observations.items():
dataset.to_netcdf(output_path / f"{obs_name}", engine="scipy")
dataset.write_parquet(output_path / f"{obs_name}")

with open(path / cls._metadata_file, "w", encoding="utf-8") as f:
simulation_data = simulation_arguments if simulation_arguments else {}
Expand Down Expand Up @@ -303,13 +304,25 @@ def update_parameters(self) -> List[str]:
return [p.name for p in self.parameter_configuration.values() if p.update]

@cached_property
def observations(self) -> Dict[str, xr.Dataset]:
def observations(self) -> Dict[str, polars.DataFrame]:
observations = sorted(self.mount_point.glob("observations/*"))
return {
observation.name: xr.open_dataset(observation, engine="scipy")
observation.name: polars.read_parquet(f"{observation}")
for observation in observations
}

@cached_property
def observation_keys(self) -> List[str]:
    """Return the sorted, de-duplicated observation keys of this experiment.

    Collects the values of the "observation_key" column (i.e. the summary
    keyword, the gen_data observation name etc.) from every frame in
    ``self.observations``.

    Returns:
        Every distinct observation key, sorted in ascending order. The
        list is empty when there are no observation frames.
    """
    # ``unique()`` only removes duplicates within a single frame; the set
    # also merges keys that occur in more than one frame.
    keys = set()
    for df in self.observations.values():
        keys.update(df["observation_key"].unique())

    return sorted(keys)

@cached_property
def response_key_to_response_type(self) -> Dict[str, str]:
mapping = {}
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_update_report(
smoother_update(
prior_ens,
posterior_ens,
list(ert_config.observations.keys()),
list(experiment.observation_keys),
ert_config.ensemble_config.parameters,
UpdateSettings(auto_scale_observations=misfit_preprocess),
ESSettings(inversion="subspace"),
Expand Down

0 comments on commit 14f9029

Please sign in to comment.