-
Notifications
You must be signed in to change notification settings - Fork 108
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge ert-storage code into local dark_storage
- Loading branch information
Showing
12 changed files
with
404 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .misfits import calculate_misfits_from_pandas | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Any, List, Mapping, Optional | ||
Check failure on line 1 in src/ert/dark_storage/compute/misfits.py GitHub Actions / annotate-python-linting
|
||
from uuid import UUID | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
def _calculate_misfit( | ||
obs_value: np.ndarray, response_value: np.ndarray, obs_std: np.ndarray | ||
) -> List[float]: | ||
difference = response_value - obs_value | ||
misfit = (difference / obs_std) ** 2 | ||
return (misfit * np.sign(difference)).tolist() | ||
|
||
|
||
def calculate_misfits_from_pandas( | ||
reponses_dict: Mapping[int, pd.DataFrame], | ||
observation: pd.DataFrame, | ||
summary_misfits: bool = False, | ||
) -> pd.DataFrame: | ||
""" | ||
Compute misfits from reponses_dict (real_id, values in dataframe) | ||
and observation | ||
""" | ||
misfits_dict = {} | ||
for realization_index in reponses_dict: | ||
misfits_dict[realization_index] = _calculate_misfit( | ||
observation["values"], | ||
reponses_dict[realization_index].loc[:, observation.index].values.flatten(), | ||
observation["errors"], | ||
) | ||
|
||
df = pd.DataFrame(data=misfits_dict, index=observation.index) | ||
if summary_misfits: | ||
df = pd.DataFrame([df.abs().sum(axis=0)], columns=df.columns, index=[0]) | ||
return df.T |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import Any | ||
|
||
from fastapi import status | ||
|
||
|
||
class ErtStorageError(RuntimeError): | ||
""" | ||
Base error class for all the rest of errors | ||
""" | ||
|
||
__status_code__ = status.HTTP_200_OK | ||
|
||
def __init__(self, message: str, **kwargs: Any): | ||
super().__init__(message, kwargs) | ||
|
||
|
||
class NotFoundError(ErtStorageError): | ||
__status_code__ = status.HTTP_404_NOT_FOUND | ||
|
||
|
||
class ConflictError(ErtStorageError): | ||
__status_code__ = status.HTTP_409_CONFLICT | ||
|
||
|
||
class ExpectationError(ErtStorageError): | ||
__status_code__ = status.HTTP_417_EXPECTATION_FAILED | ||
|
||
|
||
class UnprocessableError(ErtStorageError): | ||
__status_code__ = status.HTTP_422_UNPROCESSABLE_ENTITY |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .ensemble import EnsembleIn, EnsembleOut | ||
from .experiment import ExperimentIn, ExperimentOut | ||
from .observation import ( | ||
ObservationIn, | ||
ObservationOut, | ||
ObservationTransformationIn, | ||
ObservationTransformationOut, | ||
) | ||
from .prior import Prior | ||
from .record import RecordOut | ||
from .update import UpdateIn, UpdateOut |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Any, List, Mapping, Optional | ||
from uuid import UUID | ||
|
||
from pydantic import BaseModel, Field, root_validator | ||
|
||
|
||
class _Ensemble(BaseModel): | ||
size: int | ||
parameter_names: List[str] | ||
response_names: List[str] | ||
active_realizations: List[int] = [] | ||
|
||
|
||
class EnsembleIn(_Ensemble): | ||
update_id: Optional[UUID] = None | ||
userdata: Mapping[str, Any] = {} | ||
|
||
@root_validator | ||
def _check_names_no_overlap(cls, values: Mapping[str, Any]) -> Mapping[str, Any]: | ||
""" | ||
Verify that `parameter_names` and `response_names` don't overlap. Ie, no | ||
record can be both a parameter and a response. | ||
""" | ||
if not set(values["parameter_names"]).isdisjoint(set(values["response_names"])): | ||
raise ValueError("parameters and responses cannot have a name in common") | ||
return values | ||
|
||
|
||
class EnsembleOut(_Ensemble): | ||
id: UUID | ||
children: List[UUID] = Field(alias="child_ensemble_ids") | ||
parent: Optional[UUID] = Field(alias="parent_ensemble_id") | ||
experiment_id: Optional[UUID] = None | ||
userdata: Mapping[str, Any] | ||
|
||
class Config: | ||
orm_mode = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from typing import Any, List, Mapping | ||
from uuid import UUID | ||
|
||
from pydantic import BaseModel | ||
|
||
from .prior import Prior | ||
|
||
|
||
class _Experiment(BaseModel): | ||
name: str | ||
|
||
|
||
class ExperimentIn(_Experiment): | ||
priors: Mapping[str, Prior] = {} | ||
|
||
|
||
class ExperimentOut(_Experiment): | ||
id: UUID | ||
ensemble_ids: List[UUID] | ||
priors: Mapping[str, dict] | ||
userdata: Mapping[str, Any] | ||
|
||
class Config: | ||
orm_mode = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Any, List, Mapping, Optional | ||
from uuid import UUID | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class _ObservationTransformation(BaseModel): | ||
name: str | ||
active: List[bool] | ||
scale: List[float] | ||
observation_id: UUID | ||
|
||
|
||
class ObservationTransformationIn(_ObservationTransformation): | ||
pass | ||
|
||
|
||
class ObservationTransformationOut(_ObservationTransformation): | ||
id: UUID | ||
|
||
class Config: | ||
orm_mode = True | ||
|
||
|
||
class _Observation(BaseModel): | ||
name: str | ||
errors: List[float] | ||
values: List[float] | ||
x_axis: List[Any] | ||
records: Optional[List[UUID]] = None | ||
|
||
|
||
class ObservationIn(_Observation): | ||
pass | ||
|
||
|
||
class ObservationOut(_Observation): | ||
id: UUID | ||
transformation: Optional[ObservationTransformationOut] = None | ||
userdata: Mapping[str, Any] = {} | ||
|
||
class Config: | ||
orm_mode = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
import sys | ||
from typing import Union | ||
|
||
from pydantic import BaseModel | ||
|
||
if sys.version_info < (3, 8): | ||
from typing_extensions import Literal | ||
else: | ||
from typing import Literal | ||
|
||
|
||
class PriorConst(BaseModel): | ||
""" | ||
Constant parameter prior | ||
""" | ||
|
||
function: Literal["const"] = "const" | ||
value: float | ||
|
||
|
||
class PriorTrig(BaseModel): | ||
""" | ||
Triangular distribution parameter prior | ||
""" | ||
|
||
function: Literal["trig"] = "trig" | ||
min: float | ||
max: float | ||
mode: float | ||
|
||
|
||
class PriorNormal(BaseModel): | ||
""" | ||
Normal distribution parameter prior | ||
""" | ||
|
||
function: Literal["normal"] = "normal" | ||
mean: float | ||
std: float | ||
|
||
|
||
class PriorLogNormal(BaseModel): | ||
""" | ||
Log-normal distribution parameter prior | ||
""" | ||
|
||
function: Literal["lognormal"] = "lognormal" | ||
mean: float | ||
std: float | ||
|
||
|
||
class PriorErtTruncNormal(BaseModel): | ||
""" | ||
ERT Truncated normal distribution parameter prior | ||
ERT differs from the usual distribution by that it simply clamps on `min` | ||
and `max`, which gives a bias towards the extremes. | ||
""" | ||
|
||
function: Literal["ert_truncnormal"] = "ert_truncnormal" | ||
mean: float | ||
std: float | ||
min: float | ||
max: float | ||
|
||
|
||
class PriorStdNormal(BaseModel): | ||
""" | ||
Standard normal distribution parameter prior | ||
Normal distribution with mean of 0 and standard deviation of 1 | ||
""" | ||
|
||
function: Literal["stdnormal"] = "stdnormal" | ||
|
||
|
||
class PriorUniform(BaseModel): | ||
""" | ||
Uniform distribution parameter prior | ||
""" | ||
|
||
function: Literal["uniform"] = "uniform" | ||
min: float | ||
max: float | ||
|
||
|
||
class PriorErtDUniform(BaseModel): | ||
""" | ||
ERT Discrete uniform distribution parameter prior | ||
This discrete uniform distribution differs from the standard by using the | ||
`bins` parameter. Normally, `a`, and `b` are integers, and the sample space | ||
are the integers between. ERT allows `a` and `b` to be arbitrary floats, | ||
where the sample space is binned. | ||
""" | ||
|
||
function: Literal["ert_duniform"] = "ert_duniform" | ||
bins: int | ||
min: float | ||
max: float | ||
|
||
|
||
class PriorLogUniform(BaseModel): | ||
""" | ||
Logarithmic uniform distribution parameter prior | ||
""" | ||
|
||
function: Literal["loguniform"] = "loguniform" | ||
min: float | ||
max: float | ||
|
||
|
||
class PriorErtErf(BaseModel): | ||
""" | ||
ERT Error function distribution parameter prior | ||
""" | ||
|
||
function: Literal["ert_erf"] = "ert_erf" | ||
min: float | ||
max: float | ||
skewness: float | ||
width: float | ||
|
||
|
||
class PriorErtDErf(BaseModel): | ||
""" | ||
ERT Discrete error function distribution parameter prior | ||
""" | ||
|
||
function: Literal["ert_derf"] = "ert_derf" | ||
bins: int | ||
min: float | ||
max: float | ||
skewness: float | ||
width: float | ||
|
||
|
||
Prior = Union[ | ||
PriorConst, | ||
PriorTrig, | ||
PriorNormal, | ||
PriorLogNormal, | ||
PriorErtTruncNormal, | ||
PriorStdNormal, | ||
PriorUniform, | ||
PriorErtDUniform, | ||
PriorLogUniform, | ||
PriorErtErf, | ||
PriorErtDErf, | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from typing import Any, Mapping, Optional | ||
from uuid import UUID | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class _Record(BaseModel): | ||
pass | ||
|
||
|
||
class RecordOut(_Record): | ||
id: UUID | ||
name: str | ||
userdata: Mapping[str, Any] | ||
has_observations: Optional[bool] | ||
|
||
class Config: | ||
orm_mode = True |
Oops, something went wrong.