Skip to content

Commit

Permalink
Store observations as parquet
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen committed Sep 13, 2024
1 parent 71d9ba8 commit 3ed8872
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 51 deletions.
99 changes: 55 additions & 44 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import iterative_ensemble_smoother as ies
import numpy as np
import pandas as pd
import polars
import psutil
from iterative_ensemble_smoother.experimental import (
AdaptiveESMDA,
Expand Down Expand Up @@ -153,56 +154,66 @@ def _get_observations_and_responses(
observation_values = []
observation_errors = []
indexes = []
observations = ensemble.experiment.observations
for obs in selected_observations:
observation = observations[obs]
group = observation.attrs["response"]
all_responses = ensemble.load_responses(group, tuple(iens_active_index))
if "time" in observation.coords:
all_responses = all_responses.reindex(
time=observation.time,
method="nearest",
observations_by_type = ensemble.experiment.observations
for (
response_type,
response_cls,
) in ensemble.experiment.response_configuration.items():
observations_for_type = observations_by_type[response_type].filter(
polars.col("observation_key").is_in(selected_observations)
)
responses_for_type = ensemble.load_responses(
response_type, realizations=tuple(iens_active_index)
)
pivoted = responses_for_type.pivot(
on="realization", index=["response_key", *response_cls.primary_key]
)

# Note2reviewer:
# We need to either assume that if there is a time column
# we will approx-join that, or we could specify in response configs
# that there is a column that requires an approx "asof" join.
# Suggest we simplify and assume that there is always only
# one "time" column, which we will reindex towards the response dataset
# with a given resolution
if "time" in pivoted:
joined = observations_for_type.join_asof(
pivoted,
by=["response_key", *response_cls.primary_key],
on="time",
tolerance="1s",
)
try:
observations_and_responses = observation.merge(all_responses, join="left")
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched index for: "
f"Observation: {obs} attached to response: {group}"
) from e

observation_keys.append([obs] * observations_and_responses["observations"].size)

if group == "summary":
indexes.append(
[
np.datetime_as_string(e, unit="s")
for e in observations_and_responses["time"].data
]
)
else:
indexes.append(
[
f"{e[0]}, {e[1]}"
for e in zip(
list(observations_and_responses["report_step"].data)
* len(observations_and_responses["index"].data),
observations_and_responses["index"].data,
)
]
joined = observations_for_type.join(
pivoted,
how="left",
on=["response_key", *response_cls.primary_key],
)

observation_values.append(
observations_and_responses["observations"].data.ravel()
)
observation_errors.append(observations_and_responses["std"].data.ravel())
joined = joined.sort(by="observation_key")

index_1d = joined.with_columns(
polars.concat_str(response_cls.primary_key, separator=", ").alias("index")
)["index"].to_numpy()

obs_keys_1d = joined["observation_key"].to_numpy()
obs_values_1d = joined["observations"].to_numpy()
obs_errors_1d = joined["std"].to_numpy()

# 4 columns are always there:
# [ response_key, observation_key, observations, std ]
# + one column per "primary key" column
num_non_response_value_columns = 4 + len(response_cls.primary_key)
responses = joined.select(
joined.columns[num_non_response_value_columns:]
).to_numpy()

filtered_responses.append(responses)
observation_keys.append(obs_keys_1d)
observation_values.append(obs_values_1d)
observation_errors.append(obs_errors_1d)
indexes.append(index_1d)

filtered_responses.append(
observations_and_responses["values"]
.transpose(..., "realization")
.values.reshape((-1, len(observations_and_responses.realization)))
)
ensemble.load_responses.cache_clear()
return (
np.concatenate(filtered_responses),
Expand Down
20 changes: 17 additions & 3 deletions src/ert/config/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import xarray as xr
import polars

from ert.validation import rangestring_to_list

Expand Down Expand Up @@ -42,8 +42,22 @@ class EnkfObs:
obs_time: List[datetime]

def __post_init__(self) -> None:
self.datasets: Dict[str, xr.Dataset] = {
name: obs.to_dataset([]) for name, obs in sorted(self.obs_vectors.items())
grouped = {}
for vec in self.obs_vectors.values():
if vec.observation_type == EnkfObservationImplementationType.SUMMARY_OBS:
if "summary" not in grouped:
grouped["summary"] = []

grouped["summary"].append(vec.to_dataset([]))

elif vec.observation_type == EnkfObservationImplementationType.GEN_OBS:
if "gen_data" not in grouped:
grouped["gen_data"] = []

grouped["gen_data"].append(vec.to_dataset([]))

self.datasets: Dict[str, polars.DataFrame] = {
name: polars.concat(dfs) for name, dfs in grouped.items()
}

def __len__(self) -> int:
Expand Down
6 changes: 6 additions & 0 deletions src/ert/config/response_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def response_type(self) -> str:
Must not overlap with that of other response configs."""
...

@property
@abstractmethod
def primary_key(self) -> List[str]:
"""Primary key of this response data.
For example 'time' for summary and ['index','report_step'] for gen data"""

@classmethod
@abstractmethod
def from_config_dict(cls, config_dict: ConfigDict) -> Optional[Self]:
Expand Down
4 changes: 4 additions & 0 deletions src/ert/config/summary_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def read_from_file(self, run_path: str, iens: int) -> polars.DataFrame:
def response_type(self) -> str:
return "summary"

@property
def primary_key(self) -> List[str]:
return ["time"]

@classmethod
def from_config_dict(self, config_dict: ConfigDict) -> Optional[SummaryConfig]:
refcase = Refcase.from_config_dict(config_dict)
Expand Down
19 changes: 16 additions & 3 deletions src/ert/storage/local_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from uuid import UUID

import numpy as np
import polars
import xarray as xr
import xtgeo
from pydantic import BaseModel
Expand Down Expand Up @@ -142,7 +143,7 @@ def create(
output_path = path / "observations"
output_path.mkdir()
for obs_name, dataset in observations.items():
dataset.to_netcdf(output_path / f"{obs_name}", engine="scipy")
dataset.write_parquet(output_path / f"{obs_name}")

with open(path / cls._metadata_file, "w", encoding="utf-8") as f:
simulation_data = simulation_arguments if simulation_arguments else {}
Expand Down Expand Up @@ -303,13 +304,25 @@ def update_parameters(self) -> List[str]:
return [p.name for p in self.parameter_configuration.values() if p.update]

@cached_property
def observations(self) -> Dict[str, xr.Dataset]:
def observations(self) -> Dict[str, polars.DataFrame]:
observations = sorted(self.mount_point.glob("observations/*"))
return {
observation.name: xr.open_dataset(observation, engine="scipy")
observation.name: polars.read_parquet(f"{observation}")
for observation in observations
}

@cached_property
def observation_keys(self) -> List[str]:
"""
Gets all \"name\" values for all observations. I.e.,
the summary keyword, the gen_data observation name etc.
"""
keys = []
for df in self.observations.values():
keys.extend(df["observation_key"].unique())

return sorted(keys)

@cached_property
def response_key_to_response_type(self) -> Dict[str, str]:
mapping = {}
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_update_report(
smoother_update(
prior_ens,
posterior_ens,
list(ert_config.observations.keys()),
list(experiment.observation_keys),
ert_config.ensemble_config.parameters,
UpdateSettings(auto_scale_observations=misfit_preprocess),
ESSettings(inversion="subspace"),
Expand Down

0 comments on commit 3ed8872

Please sign in to comment.