Serialization testing for all StatusEvents from the base runner
DanSava committed Jan 22, 2025
1 parent 8e9996f commit 741e418
Showing 9 changed files with 283 additions and 125 deletions.
18 changes: 10 additions & 8 deletions src/ert/analysis/event.py
@@ -4,28 +4,30 @@
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import pandas as pd
from pydantic import BaseModel, ConfigDict


@dataclass
class AnalysisEvent:
class AnalysisEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
pass


@dataclass
class AnalysisStatusEvent(AnalysisEvent):
event_type: Literal["AnalysisStatusEvent"] = "AnalysisStatusEvent"
msg: str


@dataclass
class AnalysisTimeEvent(AnalysisEvent):
event_type: Literal["AnalysisTimeEvent"] = "AnalysisTimeEvent"
remaining_time: float
elapsed_time: float


@dataclass
class AnalysisReportEvent(AnalysisEvent):
event_type: Literal["AnalysisReportEvent"] = "AnalysisReportEvent"
report: str


@@ -56,18 +58,18 @@ def to_csv(self, name: str, output_path: Path) -> None:
df.to_csv(f_path.with_suffix(".csv"))


@dataclass
class AnalysisDataEvent(AnalysisEvent):
event_type: Literal["AnalysisDataEvent"] = "AnalysisDataEvent"
name: str
data: DataSection


@dataclass
class AnalysisErrorEvent(AnalysisEvent):
event_type: Literal["AnalysisErrorEvent"] = "AnalysisErrorEvent"
error_msg: str
data: DataSection | None = None


@dataclass
class AnalysisCompleteEvent(AnalysisEvent):
event_type: Literal["AnalysisCompleteEvent"] = "AnalysisCompleteEvent"
data: DataSection
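
The Literal event_type defaults added above give every analysis event a string tag that pydantic can use as a discriminator. A minimal, self-contained sketch of that mechanism (assuming pydantic v2; the two classes below are simplified stand-ins, not the repository's models):

# Sketch: a tagged union discriminated on event_type, round-tripped through JSON.
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class StatusEvent(BaseModel):
    event_type: Literal["StatusEvent"] = "StatusEvent"
    msg: str


class TimeEvent(BaseModel):
    event_type: Literal["TimeEvent"] = "TimeEvent"
    remaining_time: float
    elapsed_time: float


# The discriminator tells pydantic which class to instantiate when validating JSON.
adapter: TypeAdapter[StatusEvent | TimeEvent] = TypeAdapter(
    Annotated[StatusEvent | TimeEvent, Field(discriminator="event_type")]
)

raw = StatusEvent(msg="Running update step").model_dump_json()
assert isinstance(adapter.validate_json(raw), StatusEvent)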
70 changes: 30 additions & 40 deletions src/ert/ensemble_evaluator/event.py
@@ -1,59 +1,49 @@
import json
from dataclasses import dataclass
from collections.abc import Mapping
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, field_serializer, field_validator

from .snapshot import EnsembleSnapshot


@dataclass
class _UpdateEvent:
class _UpdateEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
iteration_label: str
total_iterations: int
progress: float
realization_count: int
status_count: dict[str, int]
iteration: int
snapshot: EnsembleSnapshot | None = None

@field_serializer("snapshot")
def serialize_snapshot(
self, value: EnsembleSnapshot | None
) -> dict[str, Any] | None:
if value is None:
return None
return value.to_dict()

@field_validator("snapshot", mode="before")
@classmethod
def validate_snapshot(
cls, value: EnsembleSnapshot | Mapping[Any, Any]
) -> EnsembleSnapshot:
if isinstance(value, EnsembleSnapshot):
return value
return EnsembleSnapshot.from_nested_dict(value)


@dataclass
class FullSnapshotEvent(_UpdateEvent):
snapshot: EnsembleSnapshot | None = None
event_type: Literal["FullSnapshotEvent"] = "FullSnapshotEvent"


@dataclass
class SnapshotUpdateEvent(_UpdateEvent):
snapshot: EnsembleSnapshot | None = None
event_type: Literal["SnapshotUpdateEvent"] = "SnapshotUpdateEvent"


@dataclass
class EndEvent:
class EndEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
event_type: Literal["EndEvent"] = "EndEvent"
failed: bool
msg: str | None = None


def snapshot_event_from_json(json_str: str) -> FullSnapshotEvent | SnapshotUpdateEvent:
json_dict = json.loads(json_str)
snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"])
json_dict["snapshot"] = snapshot
match json_dict.pop("type"):
case "FullSnapshotEvent":
return FullSnapshotEvent(**json_dict)
case "SnapshotUpdateEvent":
return SnapshotUpdateEvent(**json_dict)
case unknown:
raise TypeError(f"Unknown snapshot update event type {unknown}")


def snapshot_event_to_json(event: FullSnapshotEvent | SnapshotUpdateEvent) -> str:
assert event.snapshot is not None
return json.dumps(
{
"iteration_label": event.iteration_label,
"total_iterations": event.total_iterations,
"progress": event.progress,
"realization_count": event.realization_count,
"status_count": event.status_count,
"iteration": event.iteration,
"snapshot": event.snapshot.to_dict(),
"type": event.__class__.__name__,
}
)
msg: str
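
Because EnsembleSnapshot is not itself a pydantic model, _UpdateEvent relies on arbitrary_types_allowed plus a field_serializer/field_validator pair to move snapshots in and out of JSON. A self-contained sketch of that pattern with a stand-in class (the names here are illustrative, not from the repository):

# Sketch: serializing/validating a non-pydantic attribute on a pydantic model.
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel, ConfigDict, field_serializer, field_validator


class FakeSnapshot:
    # Stand-in for EnsembleSnapshot: a plain class convertible to/from a dict.
    def __init__(self, data: dict[str, Any]) -> None:
        self.data = data

    def to_dict(self) -> dict[str, Any]:
        return self.data


class UpdateEvent(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    snapshot: FakeSnapshot | None = None

    @field_serializer("snapshot")
    def serialize_snapshot(self, value: FakeSnapshot | None) -> dict[str, Any] | None:
        return None if value is None else value.to_dict()

    @field_validator("snapshot", mode="before")
    @classmethod
    def validate_snapshot(cls, value: FakeSnapshot | Mapping[Any, Any]) -> FakeSnapshot:
        return value if isinstance(value, FakeSnapshot) else FakeSnapshot(dict(value))


raw = UpdateEvent(snapshot=FakeSnapshot({"status": "Finished"})).model_dump_json()
restored = UpdateEvent.model_validate_json(raw)
assert restored.snapshot is not None and restored.snapshot.data == {"status": "Finished"}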
23 changes: 17 additions & 6 deletions src/ert/ensemble_evaluator/snapshot.py
@@ -65,8 +65,10 @@ class UnsupportedOperationException(ValueError):


def convert_iso8601_to_datetime(
timestamp: datetime | str,
) -> datetime:
timestamp: datetime | str | None,
) -> datetime | None:
if timestamp is None:
return None
if isinstance(timestamp, datetime):
return timestamp

@@ -161,13 +163,17 @@ def to_dict(self) -> dict[str, Any]:
if self._ensemble_state:
dict_["status"] = self._ensemble_state
if self._realization_snapshots:
dict_["reals"] = self._realization_snapshots
dict_["reals"] = {
k: _filter_nones(v) for k, v in self._realization_snapshots.items()
}

for (real_id, fm_id), fm_values_dict in self._fm_step_snapshots.items():
if "reals" not in dict_:
dict_["reals"] = {}
if real_id not in dict_["reals"]:
dict_["reals"][real_id] = RealizationSnapshot(fm_steps={})
dict_["reals"][real_id] = _filter_nones(
RealizationSnapshot(fm_steps={})
)
if "fm_steps" not in dict_["reals"][real_id]:
dict_["reals"][real_id]["fm_steps"] = {}

@@ -400,12 +406,17 @@ def _realization_dict_to_realization_snapshot(
realization = RealizationSnapshot(
status=source.get("status"),
active=source.get("active"),
start_time=source.get("start_time"),
end_time=source.get("end_time"),
start_time=convert_iso8601_to_datetime(source.get("start_time")),
end_time=convert_iso8601_to_datetime(source.get("end_time")),
exec_hosts=source.get("exec_hosts"),
message=source.get("message"),
fm_steps=source.get("fm_steps", {}),
)
for step in realization["fm_steps"].values():
if step.get("start_time"):
step["start_time"] = convert_iso8601_to_datetime(step["start_time"])
if step.get("end_time"):
step["end_time"] = convert_iso8601_to_datetime(step["end_time"])
return _filter_nones(realization)
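
The two hunks above make timestamp handling None-tolerant and apply the same ISO-8601 conversion to nested fm_steps. A rough sketch of the converter's contract (the parsing body is an assumption; the repository may parse timestamps differently):

# Sketch: None-tolerant ISO-8601 conversion, inferred from the signature above.
from datetime import datetime


def convert_iso8601_to_datetime(timestamp: datetime | str | None) -> datetime | None:
    if timestamp is None:
        return None
    if isinstance(timestamp, datetime):
        return timestamp
    return datetime.fromisoformat(timestamp)  # assumed parsing; stand-in body


assert convert_iso8601_to_datetime(None) is None
assert convert_iso8601_to_datetime("2025-01-22T12:00:00") == datetime(2025, 1, 22, 12)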


10 changes: 1 addition & 9 deletions src/ert/run_models/base_run_model.py
@@ -20,16 +20,14 @@

from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event
from ert.analysis import (
AnalysisEvent,
AnalysisStatusEvent,
AnalysisTimeEvent,
ErtAnalysisError,
smoother_update,
)
from ert.analysis.event import (
AnalysisCompleteEvent,
AnalysisDataEvent,
AnalysisErrorEvent,
AnalysisEvent,
)
from ert.config import HookRuntime, QueueSystem
from ert.config.analysis_module import BaseSettings
@@ -44,11 +42,6 @@
Monitor,
Realization,
)
from ert.ensemble_evaluator.event import (
EndEvent,
FullSnapshotEvent,
SnapshotUpdateEvent,
)
from ert.ensemble_evaluator.identifiers import STATUS
from ert.ensemble_evaluator.snapshot import EnsembleSnapshot
from ert.ensemble_evaluator.state import (
@@ -68,7 +61,6 @@
from ..config.analysis_config import UpdateSettings
from ..run_arg import RunArg
from .event import (
AnalysisEvent,
AnalysisStatusEvent,
AnalysisTimeEvent,
EndEvent,
39 changes: 26 additions & 13 deletions src/ert/run_models/event.py
@@ -1,11 +1,12 @@
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Annotated, Literal
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, TypeAdapter

from ert.analysis import (
AnalysisEvent,
AnalysisStatusEvent,
AnalysisTimeEvent,
)
@@ -17,30 +18,30 @@
)


@dataclass
class RunModelEvent:
class RunModelEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
iteration: int
run_id: UUID


@dataclass
class RunModelStatusEvent(RunModelEvent):
event_type: Literal["RunModelStatusEvent"] = "RunModelStatusEvent"
msg: str


@dataclass
class RunModelTimeEvent(RunModelEvent):
event_type: Literal["RunModelTimeEvent"] = "RunModelTimeEvent"
remaining_time: float
elapsed_time: float


@dataclass
class RunModelUpdateBeginEvent(RunModelEvent):
event_type: Literal["RunModelUpdateBeginEvent"] = "RunModelUpdateBeginEvent"
pass


@dataclass
class RunModelDataEvent(RunModelEvent):
event_type: Literal["RunModelDataEvent"] = "RunModelDataEvent"
name: str
data: DataSection

@@ -49,28 +50,27 @@ def write_as_csv(self, output_path: Path | None) -> None:
self.data.to_csv(self.name, output_path / str(self.run_id))


@dataclass
class RunModelUpdateEndEvent(RunModelEvent):
event_type: Literal["RunModelUpdateEndEvent"] = "RunModelUpdateEndEvent"
data: DataSection

def write_as_csv(self, output_path: Path | None) -> None:
if output_path and self.data:
self.data.to_csv("Report", output_path / str(self.run_id))


@dataclass
class RunModelErrorEvent(RunModelEvent):
event_type: Literal["RunModelErrorEvent"] = "RunModelErrorEvent"
error_msg: str
data: DataSection | None = None
data: DataSection

def write_as_csv(self, output_path: Path | None) -> None:
if output_path and self.data:
self.data.to_csv("Report", output_path / str(self.run_id))


StatusEvents = (
AnalysisEvent
| AnalysisStatusEvent
AnalysisStatusEvent
| AnalysisTimeEvent
| EndEvent
| FullSnapshotEvent
@@ -82,3 +82,16 @@ def write_as_csv(self, output_path: Path | None) -> None:
| RunModelDataEvent
| RunModelUpdateEndEvent
)


STATUS_EVENTS_ANNOTATION = Annotated[StatusEvents, Field(discriminator="event_type")]

StatusEventAdapter: TypeAdapter[StatusEvents] = TypeAdapter(STATUS_EVENTS_ANNOTATION)


def status_event_from_json(raw_msg: str | bytes) -> StatusEvents:
return StatusEventAdapter.validate_json(raw_msg)


def status_event_to_json(event: StatusEvents) -> str:
return event.model_dump_json()
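
The TypeAdapter above is what the serialization testing in this commit targets: any StatusEvents member should survive a to-JSON/from-JSON round trip, with event_type selecting the concrete class. A usage sketch (assuming an ert checkout containing this commit):

# Sketch: round-tripping one StatusEvents member through the helpers defined above.
from uuid import uuid4

from ert.run_models.event import (
    RunModelStatusEvent,
    status_event_from_json,
    status_event_to_json,
)

event = RunModelStatusEvent(iteration=0, run_id=uuid4(), msg="Starting update")
raw = status_event_to_json(event)
assert status_event_from_json(raw) == event  # event_type drives the dispatch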
2 changes: 2 additions & 0 deletions tests/ert/__init__.py
@@ -38,6 +38,7 @@ def build(
exec_hosts: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
message: str | None = None,
) -> EnsembleSnapshot:
snapshot = EnsembleSnapshot()
snapshot._ensemble_state = status
@@ -53,6 +54,7 @@
end_time=end_time,
exec_hosts=exec_hosts,
status=status,
message=message,
),
)
return snapshot
11 changes: 10 additions & 1 deletion tests/ert/unit_tests/cli/test_model_hook_order.py
@@ -1,3 +1,4 @@
import uuid
from unittest.mock import ANY, MagicMock, PropertyMock, call, patch

import pytest
@@ -48,6 +49,7 @@ def test_hook_call_order_ensemble_smoother(monkeypatch):

ens_mock = MagicMock()
ens_mock.iteration = 0
ens_mock.id = uuid.uuid1()
storage_mock = MagicMock()
storage_mock.create_ensemble.return_value = ens_mock

@@ -86,6 +88,7 @@ def test_hook_call_order_es_mda(monkeypatch):

ens_mock = MagicMock()
ens_mock.iteration = 0
ens_mock.id = uuid.uuid1()
storage_mock = MagicMock()
storage_mock.create_ensemble.return_value = ens_mock
test_class = MultipleDataAssimilation(
@@ -115,9 +118,15 @@ def test_hook_call_order_iterative_ensemble_smoother(monkeypatch):
monkeypatch.setattr(base_run_model, "_seed_sequence", MagicMock(return_value=0))
monkeypatch.setattr(base_run_model.BaseRunModel, "run_workflows", run_wfs_mock)

ens_mock = MagicMock()
ens_mock.iteration = 2
ens_mock.id = uuid.uuid1()
storage_mock = MagicMock()
storage_mock.create_ensemble.return_value = ens_mock

test_class = IteratedEnsembleSmoother(*[MagicMock()] * 13)
test_class.run_ensemble_evaluator = MagicMock(return_value=[0])

test_class._storage = storage_mock
# Mock the return values of iterative_smoother_update
# Mock the iteration property of IteratedEnsembleSmoother
with (