From ca2d82bf8c8178a5dea034f93708f375c8c13c65 Mon Sep 17 00:00:00 2001 From: DanSava Date: Fri, 10 Jan 2025 11:44:00 +0200 Subject: [PATCH] Serialization testing for all StatusEvents from the base runner --- src/ert/ensemble_evaluator/event.py | 30 ---- src/ert/ensemble_evaluator/snapshot.py | 24 ++- src/ert/run_models/base_run_model.py | 8 - src/ert/run_models/event.py | 74 +++++++- tests/ert/__init__.py | 2 + .../test_status_evens_serialization.py | 160 ++++++++++++++++++ tests/everest/test_everserver.py | 49 +----- 7 files changed, 256 insertions(+), 91 deletions(-) create mode 100644 tests/ert/unit_tests/run_models/test_status_evens_serialization.py diff --git a/src/ert/ensemble_evaluator/event.py b/src/ert/ensemble_evaluator/event.py index deabab3c7c7..5639269fd6f 100644 --- a/src/ert/ensemble_evaluator/event.py +++ b/src/ert/ensemble_evaluator/event.py @@ -1,4 +1,3 @@ -import json from dataclasses import dataclass from .snapshot import EnsembleSnapshot @@ -28,32 +27,3 @@ class SnapshotUpdateEvent(_UpdateEvent): class EndEvent: failed: bool msg: str | None = None - - -def snapshot_event_from_json(json_str: str) -> FullSnapshotEvent | SnapshotUpdateEvent: - json_dict = json.loads(json_str) - snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"]) - json_dict["snapshot"] = snapshot - match json_dict.pop("type"): - case "FullSnapshotEvent": - return FullSnapshotEvent(**json_dict) - case "SnapshotUpdateEvent": - return SnapshotUpdateEvent(**json_dict) - case unknown: - raise TypeError(f"Unknown snapshot update event type {unknown}") - - -def snapshot_event_to_json(event: FullSnapshotEvent | SnapshotUpdateEvent) -> str: - assert event.snapshot is not None - return json.dumps( - { - "iteration_label": event.iteration_label, - "total_iterations": event.total_iterations, - "progress": event.progress, - "realization_count": event.realization_count, - "status_count": event.status_count, - "iteration": event.iteration, - "snapshot": event.snapshot.to_dict(), 
- "type": event.__class__.__name__, - } - ) diff --git a/src/ert/ensemble_evaluator/snapshot.py b/src/ert/ensemble_evaluator/snapshot.py index 76f017172f2..15918b4907b 100644 --- a/src/ert/ensemble_evaluator/snapshot.py +++ b/src/ert/ensemble_evaluator/snapshot.py @@ -161,13 +161,17 @@ def to_dict(self) -> dict[str, Any]: if self._ensemble_state: dict_["status"] = self._ensemble_state if self._realization_snapshots: - dict_["reals"] = self._realization_snapshots + dict_["reals"] = { + k: _filter_nones(v) for k, v in self._realization_snapshots.items() + } for (real_id, fm_id), fm_values_dict in self._fm_step_snapshots.items(): if "reals" not in dict_: dict_["reals"] = {} if real_id not in dict_["reals"]: - dict_["reals"][real_id] = RealizationSnapshot(fm_steps={}) + dict_["reals"][real_id] = _filter_nones( + RealizationSnapshot(fm_steps={}) + ) if "fm_steps" not in dict_["reals"][real_id]: dict_["reals"][real_id]["fm_steps"] = {} @@ -397,15 +401,27 @@ class RealizationSnapshot(TypedDict, total=False): def _realization_dict_to_realization_snapshot( source: dict[str, Any], ) -> RealizationSnapshot: + start_time = source.get("start_time") + if start_time and isinstance(start_time, str): + start_time = datetime.fromisoformat(start_time) + end_time = source.get("end_time") + if end_time and isinstance(end_time, str): + end_time = datetime.fromisoformat(end_time) + realization = RealizationSnapshot( status=source.get("status"), active=source.get("active"), - start_time=source.get("start_time"), - end_time=source.get("end_time"), + start_time=start_time, + end_time=end_time, exec_hosts=source.get("exec_hosts"), message=source.get("message"), fm_steps=source.get("fm_steps", {}), ) + for step in realization["fm_steps"].values(): + if "start_time" in step and isinstance(step["start_time"], str): + step["start_time"] = datetime.fromisoformat(step["start_time"]) + if "end_time" in step and isinstance(step["end_time"], str): + step["end_time"] = 
datetime.fromisoformat(step["end_time"]) return _filter_nones(realization) diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index fd008a5158a..37c46ba0352 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -20,9 +20,6 @@ from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event from ert.analysis import ( - AnalysisEvent, - AnalysisStatusEvent, - AnalysisTimeEvent, ErtAnalysisError, smoother_update, ) @@ -40,11 +37,6 @@ Monitor, Realization, ) -from ert.ensemble_evaluator.event import ( - EndEvent, - FullSnapshotEvent, - SnapshotUpdateEvent, -) from ert.ensemble_evaluator.identifiers import STATUS from ert.ensemble_evaluator.snapshot import EnsembleSnapshot from ert.ensemble_evaluator.state import ( diff --git a/src/ert/run_models/event.py b/src/ert/run_models/event.py index 795935767ff..153a509b9c1 100644 --- a/src/ert/run_models/event.py +++ b/src/ert/run_models/event.py @@ -1,6 +1,8 @@ from __future__ import annotations -from dataclasses import dataclass +import json +from dataclasses import asdict, dataclass +from datetime import datetime from pathlib import Path from uuid import UUID @@ -15,6 +17,7 @@ FullSnapshotEvent, SnapshotUpdateEvent, ) +from ert.ensemble_evaluator.snapshot import EnsembleSnapshot @dataclass @@ -82,3 +85,72 @@ def write_as_csv(self, output_path: Path | None) -> None: | RunModelDataEvent | RunModelUpdateEndEvent ) + + +EVENT_MAPPING = { + "AnalysisEvent": AnalysisEvent, + "AnalysisStatusEvent": AnalysisStatusEvent, + "AnalysisTimeEvent": AnalysisTimeEvent, + "EndEvent": EndEvent, + "FullSnapshotEvent": FullSnapshotEvent, + "SnapshotUpdateEvent": SnapshotUpdateEvent, + "RunModelErrorEvent": RunModelErrorEvent, + "RunModelStatusEvent": RunModelStatusEvent, + "RunModelTimeEvent": RunModelTimeEvent, + "RunModelUpdateBeginEvent": RunModelUpdateBeginEvent, + "RunModelDataEvent": RunModelDataEvent, + "RunModelUpdateEndEvent": 
RunModelUpdateEndEvent, +} + + +def status_event_from_json(json_str: str) -> StatusEvents: + json_dict = json.loads(json_str) + event_type = json_dict.pop("event_type", None) + + match event_type: + case FullSnapshotEvent.__name__: + snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"]) + json_dict["snapshot"] = snapshot + return FullSnapshotEvent(**json_dict) + case SnapshotUpdateEvent.__name__: + snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"]) + json_dict["snapshot"] = snapshot + return SnapshotUpdateEvent(**json_dict) + case RunModelDataEvent.__name__ | RunModelUpdateEndEvent.__name__: + if "run_id" in json_dict and isinstance(json_dict["run_id"], str): + json_dict["run_id"] = UUID(json_dict["run_id"]) + if json_dict.get("data"): + json_dict["data"] = DataSection(**json_dict["data"]) + return EVENT_MAPPING[event_type](**json_dict) + case _: + if event_type in EVENT_MAPPING: + if "run_id" in json_dict and isinstance(json_dict["run_id"], str): + json_dict["run_id"] = UUID(json_dict["run_id"]) + return EVENT_MAPPING[event_type](**json_dict) + else: + raise TypeError(f"Unknown status event type {event_type}") + + +def status_event_to_json(event: StatusEvents) -> str: + match event: + case FullSnapshotEvent() | SnapshotUpdateEvent(): + assert event.snapshot is not None + event_dict = asdict(event) + event_dict.update( + { + "snapshot": event.snapshot.to_dict(), + "event_type": event.__class__.__name__, + } + ) + return json.dumps( + event_dict, + default=lambda o: o.strftime("%Y-%m-%dT%H:%M:%S") + if isinstance(o, datetime) + else None, + ) + case _: # wildcard; a bare name here is a capture pattern, not a type check + event_dict = asdict(event) + event_dict["event_type"] = event.__class__.__name__ + return json.dumps( + event_dict, default=lambda o: str(o) if isinstance(o, UUID) else None + ) diff --git a/tests/ert/__init__.py b/tests/ert/__init__.py index ee05dbc228c..e3db923e86d 100644 --- a/tests/ert/__init__.py +++ b/tests/ert/__init__.py @@ -38,6 +38,7 @@ def build( 
exec_hosts: str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, + message: str | None = None, ) -> EnsembleSnapshot: snapshot = EnsembleSnapshot() snapshot._ensemble_state = status @@ -53,6 +54,7 @@ def build( end_time=end_time, exec_hosts=exec_hosts, status=status, + message=message, ), ) return snapshot diff --git a/tests/ert/unit_tests/run_models/test_status_evens_serialization.py b/tests/ert/unit_tests/run_models/test_status_evens_serialization.py new file mode 100644 index 00000000000..758b8546965 --- /dev/null +++ b/tests/ert/unit_tests/run_models/test_status_evens_serialization.py @@ -0,0 +1,160 @@ +import uuid +from collections import defaultdict +from datetime import datetime as dt + +import pytest + +from ert.analysis.event import DataSection +from ert.ensemble_evaluator import state +from ert.ensemble_evaluator.snapshot import EnsembleSnapshotMetadata +from ert.run_models.event import ( + AnalysisEvent, + AnalysisStatusEvent, + AnalysisTimeEvent, + EndEvent, + FullSnapshotEvent, + RunModelDataEvent, + RunModelStatusEvent, + RunModelTimeEvent, + RunModelUpdateBeginEvent, + RunModelUpdateEndEvent, + SnapshotUpdateEvent, + status_event_from_json, + status_event_to_json, +) +from tests.ert import SnapshotBuilder + +METADATA = EnsembleSnapshotMetadata( + aggr_fm_step_status_colors=defaultdict(dict), + real_status_colors={}, + sorted_real_ids=[], + sorted_fm_step_ids=defaultdict(list), +) + + +@pytest.mark.parametrize( + "events", + [ + pytest.param( + [ + FullSnapshotEvent( + snapshot=( + SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="0", + index="0", + name="fm_step_0", + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + max_memory_usage="1000", + stdout="job_fm_step_0.stdout", + stderr="job_fm_step_0.stderr", + start_time=dt(1999, 1, 1), + ) + .add_fm_step( + fm_step_id="1", + index="1", + name="fm_step_1", + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + 
max_memory_usage="1000", + stdout="job_fm_step_1.stdout", + stderr="job_fm_step_1.stderr", + start_time=dt(1999, 1, 1), + end_time=None, + ) + .build( + real_ids=["0", "1"], + status=state.REALIZATION_STATE_UNKNOWN, + start_time=dt(1999, 1, 1), + exec_hosts="12121.121", + message="Some message", + ) + ), + iteration_label="Foo", + total_iterations=1, + progress=0.25, + realization_count=4, + status_count={"Finished": 1, "Pending": 2, "Unknown": 1}, + iteration=0, + ), + SnapshotUpdateEvent( + snapshot=SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="0", + index="0", + status=state.FORWARD_MODEL_STATE_FINISHED, + name="fm_step_0", + end_time=dt(2019, 1, 1), + ) + .build( + real_ids=["1"], + status=state.REALIZATION_STATE_RUNNING, + ), + iteration_label="Foo", + total_iterations=1, + progress=0.5, + realization_count=4, + status_count={"Finished": 2, "Running": 1, "Unknown": 1}, + iteration=0, + ), + SnapshotUpdateEvent( + snapshot=SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="1", + index="1", + status=state.FORWARD_MODEL_STATE_FAILURE, + name="fm_step_1", + ) + .build( + real_ids=["0"], + status=state.REALIZATION_STATE_FAILED, + end_time=dt(2019, 1, 1), + ), + iteration_label="Foo", + total_iterations=1, + progress=0.5, + realization_count=4, + status_count={"Finished": 2, "Failed": 1, "Unknown": 1}, + iteration=0, + ), + AnalysisEvent(), + AnalysisStatusEvent(msg="hello"), + AnalysisTimeEvent(remaining_time=22.2, elapsed_time=200.42), + EndEvent(failed=False, msg=""), + RunModelStatusEvent(iteration=1, run_id=uuid.uuid1(), msg="Hello"), + RunModelTimeEvent( + iteration=1, + run_id=uuid.uuid1(), + remaining_time=10.42, + elapsed_time=100.42, + ), + RunModelUpdateBeginEvent(iteration=2, run_id=uuid.uuid1()), + RunModelDataEvent( + iteration=1, + run_id=uuid.uuid1(), + name="Micky", + data=DataSection( + header=["Some", "string", "elements"], + data=[["a", 1.1, "b"], ["c", 3]], + extra={"a": "b", "c": "d"}, + ), + ), + 
RunModelUpdateEndEvent( + iteration=3, + run_id=uuid.uuid1(), + data=DataSection( + header=["Some", "string", "elements"], + data=[["a", 1.1, "b"], ["c", 3]], + extra={"a": "b", "c": "d"}, + ), + ), + ], + ), + ], +) +def test_status_event_serialization(events): + for event in events: + json_res = status_event_to_json(event) + round_trip_event = status_event_from_json(json_res) + assert event == round_trip_event diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index 68f8af189a5..4075e3fbc41 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -7,13 +7,7 @@ from seba_sqlite.snapshot import SebaSnapshot -from _ert.events import event_from_json, event_to_json -from ert.ensemble_evaluator import FullSnapshotEvent, SnapshotUpdateEvent -from ert.ensemble_evaluator.event import ( - snapshot_event_from_json, - snapshot_event_to_json, -) -from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel +from ert.run_models.everest_run_model import EverestExitCode from everest.config import EverestConfig, OptimizationConfig, ServerConfig from everest.detached import ServerStatus, everserver_status from everest.detached.jobs import everserver @@ -254,44 +248,3 @@ def test_everserver_status_contains_max_runtime_failure( "sleep Failed with: The run is cancelled due to reaching MAX_RUNTIME" in status["message"] ) - - -def test_event_serialization( - copy_math_func_test_data_to_tmp, - evaluator_server_config_generator, -): - config = EverestConfig.load_file("config_minimal.yml") - - def check_status_round_tripping(status): - round_trip_status = json.loads(json.dumps(status)) - assert round_trip_status == status - - run_model = EverestRunModel.create( - config, - simulation_callback=check_status_round_tripping, - ) - send_event = run_model.send_event - send_snapshot_event = run_model.send_snapshot_event - - def check_snapshot_event_serialization_round_trip(*args, **_): - event, _ = args - event_json = 
event_to_json(event) - round_trip_event = event_from_json(str(event_json)) - assert event == round_trip_event - send_snapshot_event(*args) - - def check_event_serialization_round_trip(event): - if isinstance(event, (FullSnapshotEvent | SnapshotUpdateEvent)): - json_str = snapshot_event_to_json(event) - round_trip = snapshot_event_from_json(json_str) - assert event == round_trip - send_event(event) - - run_model.send_event = check_event_serialization_round_trip - run_model.send_snapshot_event = check_snapshot_event_serialization_round_trip - - evaluator_server_config = evaluator_server_config_generator(run_model) - - run_model.run_experiment(evaluator_server_config) - - assert run_model.result is not None