Skip to content

Commit

Permalink
Serialization testing for all StatusEvents form the base runner
Browse files Browse the repository at this point in the history
  • Loading branch information
DanSava committed Jan 10, 2025
1 parent 46e184d commit ca2d82b
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 91 deletions.
30 changes: 0 additions & 30 deletions src/ert/ensemble_evaluator/event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from dataclasses import dataclass

from .snapshot import EnsembleSnapshot
Expand Down Expand Up @@ -28,32 +27,3 @@ class SnapshotUpdateEvent(_UpdateEvent):
class EndEvent:
failed: bool
msg: str | None = None


def snapshot_event_from_json(json_str: str) -> FullSnapshotEvent | SnapshotUpdateEvent:
json_dict = json.loads(json_str)
snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"])
json_dict["snapshot"] = snapshot
match json_dict.pop("type"):
case "FullSnapshotEvent":
return FullSnapshotEvent(**json_dict)
case "SnapshotUpdateEvent":
return SnapshotUpdateEvent(**json_dict)
case unknown:
raise TypeError(f"Unknown snapshot update event type {unknown}")


def snapshot_event_to_json(event: FullSnapshotEvent | SnapshotUpdateEvent) -> str:
assert event.snapshot is not None
return json.dumps(
{
"iteration_label": event.iteration_label,
"total_iterations": event.total_iterations,
"progress": event.progress,
"realization_count": event.realization_count,
"status_count": event.status_count,
"iteration": event.iteration,
"snapshot": event.snapshot.to_dict(),
"type": event.__class__.__name__,
}
)
24 changes: 20 additions & 4 deletions src/ert/ensemble_evaluator/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,17 @@ def to_dict(self) -> dict[str, Any]:
if self._ensemble_state:
dict_["status"] = self._ensemble_state
if self._realization_snapshots:
dict_["reals"] = self._realization_snapshots
dict_["reals"] = {
k: _filter_nones(v) for k, v in self._realization_snapshots.items()
}

for (real_id, fm_id), fm_values_dict in self._fm_step_snapshots.items():
if "reals" not in dict_:
dict_["reals"] = {}
if real_id not in dict_["reals"]:
dict_["reals"][real_id] = RealizationSnapshot(fm_steps={})
dict_["reals"][real_id] = _filter_nones(
RealizationSnapshot(fm_steps={})
)
if "fm_steps" not in dict_["reals"][real_id]:
dict_["reals"][real_id]["fm_steps"] = {}

Expand Down Expand Up @@ -397,15 +401,27 @@ class RealizationSnapshot(TypedDict, total=False):
def _realization_dict_to_realization_snapshot(
source: dict[str, Any],
) -> RealizationSnapshot:
start_time = source.get("start_time")
if start_time and isinstance(start_time, str):
start_time = datetime.fromisoformat(start_time)
end_time = source.get("end_time")
if end_time and isinstance(end_time, str):
end_time = datetime.fromisoformat(end_time)

realization = RealizationSnapshot(
status=source.get("status"),
active=source.get("active"),
start_time=source.get("start_time"),
end_time=source.get("end_time"),
start_time=start_time,
end_time=end_time,
exec_hosts=source.get("exec_hosts"),
message=source.get("message"),
fm_steps=source.get("fm_steps", {}),
)
for step in realization["fm_steps"].values():
if "start_time" in step and isinstance(step["start_time"], str):
step["start_time"] = datetime.fromisoformat(step["start_time"])
if "end_time" in step and isinstance(step["end_time"], str):
step["end_time"] = datetime.fromisoformat(step["end_time"])
return _filter_nones(realization)


Expand Down
8 changes: 0 additions & 8 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@

from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event
from ert.analysis import (
AnalysisEvent,
AnalysisStatusEvent,
AnalysisTimeEvent,
ErtAnalysisError,
smoother_update,
)
Expand All @@ -40,11 +37,6 @@
Monitor,
Realization,
)
from ert.ensemble_evaluator.event import (
EndEvent,
FullSnapshotEvent,
SnapshotUpdateEvent,
)
from ert.ensemble_evaluator.identifiers import STATUS
from ert.ensemble_evaluator.snapshot import EnsembleSnapshot
from ert.ensemble_evaluator.state import (
Expand Down
74 changes: 73 additions & 1 deletion src/ert/run_models/event.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from dataclasses import dataclass
import json
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from uuid import UUID

Expand All @@ -15,6 +17,7 @@
FullSnapshotEvent,
SnapshotUpdateEvent,
)
from ert.ensemble_evaluator.snapshot import EnsembleSnapshot


@dataclass
Expand Down Expand Up @@ -82,3 +85,72 @@ def write_as_csv(self, output_path: Path | None) -> None:
| RunModelDataEvent
| RunModelUpdateEndEvent
)


EVENT_MAPPING = {
"AnalysisEvent": AnalysisEvent,
"AnalysisStatusEvent": AnalysisStatusEvent,
"AnalysisTimeEvent": AnalysisTimeEvent,
"EndEvent": EndEvent,
"FullSnapshotEvent": FullSnapshotEvent,
"SnapshotUpdateEvent": SnapshotUpdateEvent,
"RunModelErrorEvent": RunModelErrorEvent,
"RunModelStatusEvent": RunModelStatusEvent,
"RunModelTimeEvent": RunModelTimeEvent,
"RunModelUpdateBeginEvent": RunModelUpdateBeginEvent,
"RunModelDataEvent": RunModelDataEvent,
"RunModelUpdateEndEvent": RunModelUpdateEndEvent,
}


def status_event_from_json(json_str: str) -> StatusEvents:
json_dict = json.loads(json_str)
event_type = json_dict.pop("event_type", None)

match event_type:
case FullSnapshotEvent.__name__:
snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"])
json_dict["snapshot"] = snapshot
return FullSnapshotEvent(**json_dict)
case SnapshotUpdateEvent.__name__:
snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"])
json_dict["snapshot"] = snapshot
return SnapshotUpdateEvent(**json_dict)
case RunModelDataEvent.__name__ | RunModelUpdateEndEvent.__name__:
if "run_id" in json_dict and isinstance(json_dict["run_id"], str):
json_dict["run_id"] = UUID(json_dict["run_id"])
if json_dict.get("data"):
json_dict["data"] = DataSection(**json_dict["data"])
return EVENT_MAPPING[event_type](**json_dict)
case _:
if event_type in EVENT_MAPPING:
if "run_id" in json_dict and isinstance(json_dict["run_id"], str):
json_dict["run_id"] = UUID(json_dict["run_id"])
return EVENT_MAPPING[event_type](**json_dict)
else:
raise TypeError(f"Unknown status event type {event_type}")


def status_event_to_json(event: StatusEvents) -> str:
match event:
case FullSnapshotEvent() | SnapshotUpdateEvent():
assert event.snapshot is not None
event_dict = asdict(event)
event_dict.update(
{
"snapshot": event.snapshot.to_dict(),
"event_type": event.__class__.__name__,
}
)
return json.dumps(
event_dict,
default=lambda o: o.strftime("%Y-%m-%dT%H:%M:%S")
if isinstance(o, datetime)
else None,
)
case StatusEvents:
event_dict = asdict(event)
event_dict["event_type"] = StatusEvents.__class__.__name__
return json.dumps(
event_dict, default=lambda o: str(o) if isinstance(o, UUID) else None
)
2 changes: 2 additions & 0 deletions tests/ert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def build(
exec_hosts: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
message: str | None = None,
) -> EnsembleSnapshot:
snapshot = EnsembleSnapshot()
snapshot._ensemble_state = status
Expand All @@ -53,6 +54,7 @@ def build(
end_time=end_time,
exec_hosts=exec_hosts,
status=status,
message=message,
),
)
return snapshot
Expand Down
160 changes: 160 additions & 0 deletions tests/ert/unit_tests/run_models/test_status_evens_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import uuid
from collections import defaultdict
from datetime import datetime as dt

import pytest

from ert.analysis.event import DataSection
from ert.ensemble_evaluator import state
from ert.ensemble_evaluator.snapshot import EnsembleSnapshotMetadata
from ert.run_models.event import (
AnalysisEvent,
AnalysisStatusEvent,
AnalysisTimeEvent,
EndEvent,
FullSnapshotEvent,
RunModelDataEvent,
RunModelStatusEvent,
RunModelTimeEvent,
RunModelUpdateBeginEvent,
RunModelUpdateEndEvent,
SnapshotUpdateEvent,
status_event_from_json,
status_event_to_json,
)
from tests.ert import SnapshotBuilder

METADATA = EnsembleSnapshotMetadata(
aggr_fm_step_status_colors=defaultdict(dict),
real_status_colors={},
sorted_real_ids=[],
sorted_fm_step_ids=defaultdict(list),
)


@pytest.mark.parametrize(
"events",
[
pytest.param(
[
FullSnapshotEvent(
snapshot=(
SnapshotBuilder(metadata=METADATA)
.add_fm_step(
fm_step_id="0",
index="0",
name="fm_step_0",
status=state.FORWARD_MODEL_STATE_START,
current_memory_usage="500",
max_memory_usage="1000",
stdout="job_fm_step_0.stdout",
stderr="job_fm_step_0.stderr",
start_time=dt(1999, 1, 1),
)
.add_fm_step(
fm_step_id="1",
index="1",
name="fm_step_1",
status=state.FORWARD_MODEL_STATE_START,
current_memory_usage="500",
max_memory_usage="1000",
stdout="job_fm_step_1.stdout",
stderr="job_fm_step_1.stderr",
start_time=dt(1999, 1, 1),
end_time=None,
)
.build(
real_ids=["0", "1"],
status=state.REALIZATION_STATE_UNKNOWN,
start_time=dt(1999, 1, 1),
exec_hosts="12121.121",
message="Some message",
)
),
iteration_label="Foo",
total_iterations=1,
progress=0.25,
realization_count=4,
status_count={"Finished": 1, "Pending": 2, "Unknown": 1},
iteration=0,
),
SnapshotUpdateEvent(
snapshot=SnapshotBuilder(metadata=METADATA)
.add_fm_step(
fm_step_id="0",
index="0",
status=state.FORWARD_MODEL_STATE_FINISHED,
name="fm_step_0",
end_time=dt(2019, 1, 1),
)
.build(
real_ids=["1"],
status=state.REALIZATION_STATE_RUNNING,
),
iteration_label="Foo",
total_iterations=1,
progress=0.5,
realization_count=4,
status_count={"Finished": 2, "Running": 1, "Unknown": 1},
iteration=0,
),
SnapshotUpdateEvent(
snapshot=SnapshotBuilder(metadata=METADATA)
.add_fm_step(
fm_step_id="1",
index="1",
status=state.FORWARD_MODEL_STATE_FAILURE,
name="fm_step_1",
)
.build(
real_ids=["0"],
status=state.REALIZATION_STATE_FAILED,
end_time=dt(2019, 1, 1),
),
iteration_label="Foo",
total_iterations=1,
progress=0.5,
realization_count=4,
status_count={"Finished": 2, "Failed": 1, "Unknown": 1},
iteration=0,
),
AnalysisEvent(),
AnalysisStatusEvent(msg="hello"),
AnalysisTimeEvent(remaining_time=22.2, elapsed_time=200.42),
EndEvent(failed=False, msg=""),
RunModelStatusEvent(iteration=1, run_id=uuid.uuid1(), msg="Hello"),
RunModelTimeEvent(
iteration=1,
run_id=uuid.uuid1(),
remaining_time=10.42,
elapsed_time=100.42,
),
RunModelUpdateBeginEvent(iteration=2, run_id=uuid.uuid1()),
RunModelDataEvent(
iteration=1,
run_id=uuid.uuid1(),
name="Micky",
data=DataSection(
header=["Some", "string", "elements"],
data=[["a", 1.1, "b"], ["c", 3]],
extra={"a": "b", "c": "d"},
),
),
RunModelUpdateEndEvent(
iteration=3,
run_id=uuid.uuid1(),
data=DataSection(
header=["Some", "string", "elements"],
data=[["a", 1.1, "b"], ["c", 3]],
extra={"a": "b", "c": "d"},
),
),
],
),
],
)
def test_status_event_serialization(events):
for event in events:
json_res = status_event_to_json(event)
round_trip_event = status_event_from_json(json_res)
assert event == round_trip_event
Loading

0 comments on commit ca2d82b

Please sign in to comment.