From 19dc942da5a287581dc5cf12a8c7f83445f6843b Mon Sep 17 00:00:00 2001 From: DanSava Date: Tue, 17 Dec 2024 15:19:45 +0200 Subject: [PATCH 1/4] Add event serialization testing --- tests/everest/test_everserver.py | 34 +++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index 4075e3fbc41..3793eab37f2 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -7,7 +7,8 @@ from seba_sqlite.snapshot import SebaSnapshot -from ert.run_models.everest_run_model import EverestExitCode +from _ert.events import event_from_json, event_to_json +from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel from everest.config import EverestConfig, OptimizationConfig, ServerConfig from everest.detached import ServerStatus, everserver_status from everest.detached.jobs import everserver @@ -248,3 +249,34 @@ def test_everserver_status_contains_max_runtime_failure( "sleep Failed with: The run is cancelled due to reaching MAX_RUNTIME" in status["message"] ) + + +def test_event_serialization( + copy_math_func_test_data_to_tmp, + evaluator_server_config_generator, +): + config = EverestConfig.load_file("config_minimal.yml") + + def check_status_round_tripping(status): + round_trip_status = json.loads(json.dumps(status)) + assert round_trip_status == status + + run_model = EverestRunModel.create( + config, + simulation_callback=check_status_round_tripping, + ) + + send_snapshot_event = run_model.send_snapshot_event + + def check_event_serialization_round_trip(*args, **_): + event, _ = args + event_json = event_to_json(event) + round_trip_event = event_from_json(str(event_json)) + assert event == round_trip_event + send_snapshot_event(*args) + + run_model.send_snapshot_event = check_event_serialization_round_trip + + evaluator_server_config = evaluator_server_config_generator(run_model) + + run_model.run_experiment(evaluator_server_config) From c5c4996c6753b5ba12cd698a7e22097a47c2bebb Mon Sep 17 00:00:00 2001 From: DanSava Date: Mon, 6 Jan 2025 16:51:43 +0200 Subject: [PATCH 2/4] Make sure FullSnapshotEvent and SnapshotUpdateEvent is json serializable --- src/ert/ensemble_evaluator/event.py | 30 ++++++++++++++++++++++++++ src/ert/ensemble_evaluator/snapshot.py | 5 +++++ src/ert/run_models/base_run_model.py | 2 +- tests/everest/test_everserver.py | 21 +++++++++++++++--- 4 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/ert/ensemble_evaluator/event.py b/src/ert/ensemble_evaluator/event.py index 5639269fd6f..deabab3c7c7 100644 --- a/src/ert/ensemble_evaluator/event.py +++ b/src/ert/ensemble_evaluator/event.py @@ -1,3 +1,4 @@ +import json from dataclasses import dataclass from .snapshot import EnsembleSnapshot @@ -27,3 +28,32 @@ class SnapshotUpdateEvent(_UpdateEvent): class EndEvent: failed: bool msg: str | None = None + + +def snapshot_event_from_json(json_str: str) -> FullSnapshotEvent | SnapshotUpdateEvent: + json_dict = json.loads(json_str) + snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"]) + json_dict["snapshot"] = snapshot + match json_dict.pop("type"): + case "FullSnapshotEvent": + return FullSnapshotEvent(**json_dict) + case "SnapshotUpdateEvent": + return SnapshotUpdateEvent(**json_dict) + case unknown: + raise TypeError(f"Unknown snapshot update event type {unknown}") + + +def snapshot_event_to_json(event: FullSnapshotEvent | SnapshotUpdateEvent) -> str: + assert event.snapshot is not None + return json.dumps( + { + "iteration_label": event.iteration_label, + "total_iterations": event.total_iterations, + "progress": event.progress, + "realization_count": event.realization_count, + "status_count": event.status_count, + "iteration": event.iteration, + "snapshot": event.snapshot.to_dict(), + "type": event.__class__.__name__, + } + ) diff --git a/src/ert/ensemble_evaluator/snapshot.py b/src/ert/ensemble_evaluator/snapshot.py index 1b4c5ebb6de..76f017172f2 100644 --- a/src/ert/ensemble_evaluator/snapshot.py +++ b/src/ert/ensemble_evaluator/snapshot.py @@ -113,6 +113,11 @@ def __init__(self) -> None: sorted_fm_step_ids=defaultdict(list), ) + def __eq__(self, other: object) -> bool: + if not isinstance(other, EnsembleSnapshot): + return NotImplemented + return self.to_dict() == other.to_dict() + @classmethod def from_nested_dict(cls, source: Mapping[Any, Any]) -> EnsembleSnapshot: ensemble = EnsembleSnapshot() diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index d8c795e10ac..83d10ac06a2 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -419,7 +419,7 @@ def get_current_status(self) -> dict[str, int]: status["Finished"] += ( self._get_number_of_finished_realizations_from_reruns() ) - return status + return dict(status) def _get_number_of_finished_realizations_from_reruns(self) -> int: return self.active_realizations.count( diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index 3793eab37f2..68f8af189a5 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -8,6 +8,11 @@ from seba_sqlite.snapshot import SebaSnapshot from _ert.events import event_from_json, event_to_json +from ert.ensemble_evaluator import FullSnapshotEvent, SnapshotUpdateEvent +from ert.ensemble_evaluator.event import ( + snapshot_event_from_json, + snapshot_event_to_json, +) from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel from everest.config import EverestConfig, OptimizationConfig, ServerConfig from everest.detached import ServerStatus, everserver_status @@ -265,18 +270,28 @@ def check_status_round_tripping(status): config, simulation_callback=check_status_round_tripping, ) - + send_event = run_model.send_event send_snapshot_event = run_model.send_snapshot_event - def check_event_serialization_round_trip(*args, **_): + def check_snapshot_event_serialization_round_trip(*args, **_): event, _ = args event_json = event_to_json(event) round_trip_event = event_from_json(str(event_json)) assert event == round_trip_event send_snapshot_event(*args) - run_model.send_snapshot_event = check_event_serialization_round_trip + def check_event_serialization_round_trip(event): + if isinstance(event, (FullSnapshotEvent | SnapshotUpdateEvent)): + json_str = snapshot_event_to_json(event) + round_trip = snapshot_event_from_json(json_str) + assert event == round_trip + send_event(event) + + run_model.send_event = check_event_serialization_round_trip + run_model.send_snapshot_event = check_snapshot_event_serialization_round_trip evaluator_server_config = evaluator_server_config_generator(run_model) run_model.run_experiment(evaluator_server_config) + + assert run_model.result is not None From 8e9996f979c824b92de321d95e4a530b0c497637 Mon Sep 17 00:00:00 2001 From: DanSava Date: Thu, 9 Jan 2025 14:51:39 +0200 Subject: [PATCH 3/4] Move StatusEvents to run_models/events --- src/ert/cli/main.py | 2 +- src/ert/cli/monitor.py | 2 +- src/ert/run_models/base_run_model.py | 22 +++++++--------------- src/ert/run_models/event.py | 26 ++++++++++++++++++++++++++ src/ert/run_models/model_factory.py | 2 +- 5 files changed, 36 insertions(+), 18 deletions(-) diff --git a/src/ert/cli/main.py b/src/ert/cli/main.py index 547ec323cd0..7552851de7a 100644 --- a/src/ert/cli/main.py +++ b/src/ert/cli/main.py @@ -24,7 +24,7 @@ ) from ert.namespace import Namespace from ert.plugins import ErtPluginManager -from ert.run_models.base_run_model import StatusEvents +from ert.run_models.event import StatusEvents from ert.run_models.model_factory import create_model from ert.storage import open_storage from ert.storage.local_storage import local_storage_set_ert_config diff --git a/src/ert/cli/monitor.py b/src/ert/cli/monitor.py index 921b584f222..429b997ca53 100644 --- a/src/ert/cli/monitor.py +++ b/src/ert/cli/monitor.py @@ -22,11 +22,11 @@ FORWARD_MODEL_STATE_FAILURE, REAL_STATE_TO_COLOR, ) -from ert.run_models.base_run_model import StatusEvents from ert.run_models.event import ( RunModelDataEvent, RunModelErrorEvent, RunModelUpdateEndEvent, + StatusEvents, ) from ert.shared.status.utils import format_running_time diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index 83d10ac06a2..4fa3b84feee 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -68,12 +68,19 @@ from ..config.analysis_config import UpdateSettings from ..run_arg import RunArg from .event import ( + AnalysisEvent, + AnalysisStatusEvent, + AnalysisTimeEvent, + EndEvent, + FullSnapshotEvent, RunModelDataEvent, RunModelErrorEvent, RunModelStatusEvent, RunModelTimeEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent, + SnapshotUpdateEvent, + StatusEvents, ) logger = logging.getLogger(__name__) @@ -81,21 +88,6 @@ if TYPE_CHECKING: from ert.config import QueueConfig -StatusEvents = ( - FullSnapshotEvent - | SnapshotUpdateEvent - | EndEvent - | AnalysisEvent - | AnalysisStatusEvent - | AnalysisTimeEvent - | RunModelErrorEvent - | RunModelStatusEvent - | RunModelTimeEvent - | RunModelUpdateBeginEvent - | RunModelDataEvent - | RunModelUpdateEndEvent -) - class OutOfOrderSnapshotUpdateException(ValueError): pass diff --git a/src/ert/run_models/event.py b/src/ert/run_models/event.py index 424c8ec9235..795935767ff 100644 --- a/src/ert/run_models/event.py +++ b/src/ert/run_models/event.py @@ -4,7 +4,17 @@ from pathlib import Path from uuid import UUID +from ert.analysis import ( + AnalysisEvent, + AnalysisStatusEvent, + AnalysisTimeEvent, +) from ert.analysis.event import DataSection +from ert.ensemble_evaluator.event import ( + EndEvent, + FullSnapshotEvent, + SnapshotUpdateEvent, +) @dataclass @@ -56,3 +66,19 @@ class RunModelErrorEvent(RunModelEvent): def write_as_csv(self, output_path: Path | None) -> None: if output_path and self.data: self.data.to_csv("Report", output_path / str(self.run_id)) + + +StatusEvents = ( + AnalysisEvent + | AnalysisStatusEvent + | AnalysisTimeEvent + | EndEvent + | FullSnapshotEvent + | SnapshotUpdateEvent + | RunModelErrorEvent + | RunModelStatusEvent + | RunModelTimeEvent + | RunModelUpdateBeginEvent + | RunModelDataEvent + | RunModelUpdateEndEvent +) diff --git a/src/ert/run_models/model_factory.py b/src/ert/run_models/model_factory.py index e1df99a9d5f..0cede97e4d2 100644 --- a/src/ert/run_models/model_factory.py +++ b/src/ert/run_models/model_factory.py @@ -32,7 +32,7 @@ import numpy.typing as npt from ert.namespace import Namespace - from ert.run_models.base_run_model import StatusEvents + from ert.run_models.event import StatusEvents from ert.storage import Storage From 9abf27ff11c9dd9d33029f18d1c36962912bbc0e Mon Sep 17 00:00:00 2001 From: DanSava Date: Fri, 10 Jan 2025 11:44:00 +0200 Subject: [PATCH 4/4] Serialization testing for all StatusEvents form the base runner --- src/ert/analysis/event.py | 18 +- src/ert/ensemble_evaluator/event.py | 70 +++---- src/ert/ensemble_evaluator/snapshot.py | 23 ++- src/ert/run_models/base_run_model.py | 10 +- src/ert/run_models/event.py | 39 ++-- tests/ert/__init__.py | 2 + .../unit_tests/cli/test_model_hook_order.py | 11 +- .../test_status_events_serialization.py | 180 ++++++++++++++++++ tests/everest/test_everserver.py | 49 +---- 9 files changed, 277 insertions(+), 125 deletions(-) create mode 100644 tests/ert/unit_tests/run_models/test_status_events_serialization.py diff --git a/src/ert/analysis/event.py b/src/ert/analysis/event.py index 44f5772402b..dffeecf281f 100644 --- a/src/ert/analysis/event.py +++ b/src/ert/analysis/event.py @@ -4,28 +4,30 @@ from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path +from typing import Literal import pandas as pd +from pydantic import BaseModel, ConfigDict -@dataclass -class AnalysisEvent: +class AnalysisEvent(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") pass -@dataclass class AnalysisStatusEvent(AnalysisEvent): + event_type: Literal["AnalysisStatusEvent"] = "AnalysisStatusEvent" msg: str -@dataclass class AnalysisTimeEvent(AnalysisEvent): + event_type: Literal["AnalysisTimeEvent"] = "AnalysisTimeEvent" remaining_time: float elapsed_time: float -@dataclass class AnalysisReportEvent(AnalysisEvent): + event_type: Literal["AnalysisReportEvent"] = "AnalysisReportEvent" report: str @@ -56,18 +58,18 @@ def to_csv(self, name: str, output_path: Path) -> None: df.to_csv(f_path.with_suffix(".csv")) -@dataclass class AnalysisDataEvent(AnalysisEvent): + event_type: Literal["AnalysisDataEvent"] = "AnalysisDataEvent" name: str data: DataSection -@dataclass class AnalysisErrorEvent(AnalysisEvent): + event_type: Literal["AnalysisErrorEvent"] = "AnalysisErrorEvent" error_msg: str data: DataSection | None = None -@dataclass class AnalysisCompleteEvent(AnalysisEvent): + event_type: Literal["AnalysisCompleteEvent"] = "AnalysisCompleteEvent" data: DataSection diff --git a/src/ert/ensemble_evaluator/event.py b/src/ert/ensemble_evaluator/event.py index deabab3c7c7..df7e5b7ea94 100644 --- a/src/ert/ensemble_evaluator/event.py +++ b/src/ert/ensemble_evaluator/event.py @@ -1,59 +1,49 @@ -import json -from dataclasses import dataclass +from collections.abc import Mapping +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, field_serializer, field_validator from .snapshot import EnsembleSnapshot -@dataclass -class _UpdateEvent: +class _UpdateEvent(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") iteration_label: str total_iterations: int progress: float realization_count: int status_count: dict[str, int] iteration: int + snapshot: EnsembleSnapshot | None = None + + @field_serializer("snapshot") + def serialize_snapshot( + self, value: EnsembleSnapshot | None + ) -> dict[str, Any] | None: + if value is None: + return None + return value.to_dict() + + @field_validator("snapshot", mode="before") + @classmethod + def validate_snapshot( + cls, value: EnsembleSnapshot | Mapping[Any, Any] + ) -> EnsembleSnapshot: + if isinstance(value, EnsembleSnapshot): + return value + return EnsembleSnapshot.from_nested_dict(value) -@dataclass class FullSnapshotEvent(_UpdateEvent): - snapshot: EnsembleSnapshot | None = None + event_type: Literal["FullSnapshotEvent"] = "FullSnapshotEvent" -@dataclass class SnapshotUpdateEvent(_UpdateEvent): - snapshot: EnsembleSnapshot | None = None + event_type: Literal["SnapshotUpdateEvent"] = "SnapshotUpdateEvent" -@dataclass -class EndEvent: +class EndEvent(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + event_type: Literal["EndEvent"] = "EndEvent" failed: bool - msg: str | None = None - - -def snapshot_event_from_json(json_str: str) -> FullSnapshotEvent | SnapshotUpdateEvent: - json_dict = json.loads(json_str) - snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"]) - json_dict["snapshot"] = snapshot - match json_dict.pop("type"): - case "FullSnapshotEvent": - return FullSnapshotEvent(**json_dict) - case "SnapshotUpdateEvent": - return SnapshotUpdateEvent(**json_dict) - case unknown: - raise TypeError(f"Unknown snapshot update event type {unknown}") - - -def snapshot_event_to_json(event: FullSnapshotEvent | SnapshotUpdateEvent) -> str: - assert event.snapshot is not None - return json.dumps( - { - "iteration_label": event.iteration_label, - "total_iterations": event.total_iterations, - "progress": event.progress, - "realization_count": event.realization_count, - "status_count": event.status_count, - "iteration": event.iteration, - "snapshot": event.snapshot.to_dict(), - "type": event.__class__.__name__, - } - ) + msg: str diff --git a/src/ert/ensemble_evaluator/snapshot.py b/src/ert/ensemble_evaluator/snapshot.py index 76f017172f2..77170810e92 100644 --- a/src/ert/ensemble_evaluator/snapshot.py +++ b/src/ert/ensemble_evaluator/snapshot.py @@ -65,8 +65,10 @@ class UnsupportedOperationException(ValueError): def convert_iso8601_to_datetime( - timestamp: datetime | str, -) -> datetime: + timestamp: datetime | str | None, +) -> datetime | None: + if timestamp is None: + return None if isinstance(timestamp, datetime): return timestamp @@ -161,13 +163,17 @@ def to_dict(self) -> dict[str, Any]: if self._ensemble_state: dict_["status"] = self._ensemble_state if self._realization_snapshots: - dict_["reals"] = self._realization_snapshots + dict_["reals"] = { + k: _filter_nones(v) for k, v in self._realization_snapshots.items() + } for (real_id, fm_id), fm_values_dict in self._fm_step_snapshots.items(): if "reals" not in dict_: dict_["reals"] = {} if real_id not in dict_["reals"]: - dict_["reals"][real_id] = RealizationSnapshot(fm_steps={}) + dict_["reals"][real_id] = _filter_nones( + RealizationSnapshot(fm_steps={}) + ) if "fm_steps" not in dict_["reals"][real_id]: dict_["reals"][real_id]["fm_steps"] = {} @@ -400,12 +406,17 @@ def _realization_dict_to_realization_snapshot( realization = RealizationSnapshot( status=source.get("status"), active=source.get("active"), - start_time=source.get("start_time"), - end_time=source.get("end_time"), + start_time=convert_iso8601_to_datetime(source.get("start_time")), + end_time=convert_iso8601_to_datetime(source.get("end_time")), exec_hosts=source.get("exec_hosts"), message=source.get("message"), fm_steps=source.get("fm_steps", {}), ) + for step in realization["fm_steps"].values(): + if step.get("start_time"): + step["start_time"] = convert_iso8601_to_datetime(step["start_time"]) + if step.get("end_time"): + step["end_time"] = convert_iso8601_to_datetime(step["end_time"]) return _filter_nones(realization) diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index 4fa3b84feee..fb187096ed1 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -20,9 +20,6 @@ from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event from ert.analysis import ( - AnalysisEvent, - AnalysisStatusEvent, - AnalysisTimeEvent, ErtAnalysisError, smoother_update, ) @@ -30,6 +27,7 @@ AnalysisCompleteEvent, AnalysisDataEvent, AnalysisErrorEvent, + AnalysisEvent, ) from ert.config import HookRuntime, QueueSystem from ert.config.analysis_module import BaseSettings @@ -44,11 +42,6 @@ Monitor, Realization, ) -from ert.ensemble_evaluator.event import ( - EndEvent, - FullSnapshotEvent, - SnapshotUpdateEvent, -) from ert.ensemble_evaluator.identifiers import STATUS from ert.ensemble_evaluator.snapshot import EnsembleSnapshot from ert.ensemble_evaluator.state import ( @@ -68,7 +61,6 @@ from ..config.analysis_config import UpdateSettings from ..run_arg import RunArg from .event import ( - AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent, EndEvent, diff --git a/src/ert/run_models/event.py b/src/ert/run_models/event.py index 795935767ff..32021ec19b4 100644 --- a/src/ert/run_models/event.py +++ b/src/ert/run_models/event.py @@ -1,11 +1,12 @@ from __future__ import annotations -from dataclasses import dataclass from pathlib import Path +from typing import Annotated, Literal from uuid import UUID +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter + from ert.analysis import ( - AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent, ) @@ -17,30 +18,30 @@ ) -@dataclass -class RunModelEvent: +class RunModelEvent(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") iteration: int run_id: UUID -@dataclass class RunModelStatusEvent(RunModelEvent): + event_type: Literal["RunModelStatusEvent"] = "RunModelStatusEvent" msg: str -@dataclass class RunModelTimeEvent(RunModelEvent): + event_type: Literal["RunModelTimeEvent"] = "RunModelTimeEvent" remaining_time: float elapsed_time: float -@dataclass class RunModelUpdateBeginEvent(RunModelEvent): + event_type: Literal["RunModelUpdateBeginEvent"] = "RunModelUpdateBeginEvent" pass -@dataclass class RunModelDataEvent(RunModelEvent): + event_type: Literal["RunModelDataEvent"] = "RunModelDataEvent" name: str data: DataSection @@ -49,8 +50,8 @@ def write_as_csv(self, output_path: Path | None) -> None: self.data.to_csv(self.name, output_path / str(self.run_id)) -@dataclass class RunModelUpdateEndEvent(RunModelEvent): + event_type: Literal["RunModelUpdateEndEvent"] = "RunModelUpdateEndEvent" data: DataSection def write_as_csv(self, output_path: Path | None) -> None: @@ -58,10 +59,10 @@ def write_as_csv(self, output_path: Path | None) -> None: self.data.to_csv("Report", output_path / str(self.run_id)) -@dataclass class RunModelErrorEvent(RunModelEvent): + event_type: Literal["RunModelErrorEvent"] = "RunModelErrorEvent" error_msg: str - data: DataSection | None = None + data: DataSection def write_as_csv(self, output_path: Path | None) -> None: if output_path and self.data: @@ -69,8 +70,7 @@ def write_as_csv(self, output_path: Path | None) -> None: StatusEvents = ( - AnalysisEvent - | AnalysisStatusEvent + AnalysisStatusEvent | AnalysisTimeEvent | EndEvent | FullSnapshotEvent @@ -82,3 +82,16 @@ def write_as_csv(self, output_path: Path | None) -> None: | RunModelDataEvent | RunModelUpdateEndEvent ) + + +STATUS_EVENTS_ANNOTATION = Annotated[StatusEvents, Field(discriminator="event_type")] + +StatusEventAdapter: TypeAdapter[StatusEvents] = TypeAdapter(STATUS_EVENTS_ANNOTATION) + + +def status_event_from_json(raw_msg: str | bytes) -> StatusEvents: + return StatusEventAdapter.validate_json(raw_msg) + + +def status_event_to_json(event: StatusEvents) -> str: + return event.model_dump_json() diff --git a/tests/ert/__init__.py b/tests/ert/__init__.py index ee05dbc228c..e3db923e86d 100644 --- a/tests/ert/__init__.py +++ b/tests/ert/__init__.py @@ -38,6 +38,7 @@ def build( exec_hosts: str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, + message: str | None = None, ) -> EnsembleSnapshot: snapshot = EnsembleSnapshot() snapshot._ensemble_state = status @@ -53,6 +54,7 @@ def build( end_time=end_time, exec_hosts=exec_hosts, status=status, + message=message, ), ) return snapshot diff --git a/tests/ert/unit_tests/cli/test_model_hook_order.py b/tests/ert/unit_tests/cli/test_model_hook_order.py index cec455baff6..d17b67e7232 100644 --- a/tests/ert/unit_tests/cli/test_model_hook_order.py +++ b/tests/ert/unit_tests/cli/test_model_hook_order.py @@ -1,3 +1,4 @@ +import uuid from unittest.mock import ANY, MagicMock, PropertyMock, call, patch import pytest @@ -48,6 +49,7 @@ def test_hook_call_order_ensemble_smoother(monkeypatch): ens_mock = MagicMock() ens_mock.iteration = 0 + ens_mock.id = uuid.uuid1() storage_mock = MagicMock() storage_mock.create_ensemble.return_value = ens_mock @@ -86,6 +88,7 @@ def test_hook_call_order_es_mda(monkeypatch): ens_mock = MagicMock() ens_mock.iteration = 0 + ens_mock.id = uuid.uuid1() storage_mock = MagicMock() storage_mock.create_ensemble.return_value = ens_mock test_class = MultipleDataAssimilation( @@ -115,9 +118,15 @@ def test_hook_call_order_iterative_ensemble_smoother(monkeypatch): monkeypatch.setattr(base_run_model, "_seed_sequence", MagicMock(return_value=0)) monkeypatch.setattr(base_run_model.BaseRunModel, "run_workflows", run_wfs_mock) + ens_mock = MagicMock() + ens_mock.iteration = 2 + ens_mock.id = uuid.uuid1() + storage_mock = MagicMock() + storage_mock.create_ensemble.return_value = ens_mock + test_class = IteratedEnsembleSmoother(*[MagicMock()] * 13) test_class.run_ensemble_evaluator = MagicMock(return_value=[0]) - + test_class._storage = storage_mock # Mock the return values of iterative_smoother_update # Mock the iteration property of IteratedEnsembleSmoother with ( diff --git a/tests/ert/unit_tests/run_models/test_status_events_serialization.py b/tests/ert/unit_tests/run_models/test_status_events_serialization.py new file mode 100644 index 00000000000..da53ea5803c --- /dev/null +++ b/tests/ert/unit_tests/run_models/test_status_events_serialization.py @@ -0,0 +1,180 @@ +import uuid +from collections import defaultdict +from datetime import datetime as dt + +import pytest + +from ert.analysis.event import DataSection +from ert.ensemble_evaluator import state +from ert.ensemble_evaluator.snapshot import EnsembleSnapshotMetadata +from ert.run_models.event import ( + AnalysisStatusEvent, + AnalysisTimeEvent, + EndEvent, + FullSnapshotEvent, + RunModelDataEvent, + RunModelStatusEvent, + RunModelTimeEvent, + RunModelUpdateBeginEvent, + RunModelUpdateEndEvent, + SnapshotUpdateEvent, + status_event_from_json, + status_event_to_json, +) +from tests.ert import SnapshotBuilder + +METADATA = EnsembleSnapshotMetadata( + aggr_fm_step_status_colors=defaultdict(dict), + real_status_colors={}, + sorted_real_ids=[], + sorted_fm_step_ids=defaultdict(list), +) + + +@pytest.mark.parametrize( + "event", + [ + pytest.param( + FullSnapshotEvent( + snapshot=( + SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="0", + index="0", + name="fm_step_0", + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + max_memory_usage="1000", + stdout="job_fm_step_0.stdout", + stderr="job_fm_step_0.stderr", + start_time=dt(1999, 1, 1), + ) + .add_fm_step( + fm_step_id="1", + index="1", + name="fm_step_1", + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + max_memory_usage="1000", + stdout="job_fm_step_1.stdout", + stderr="job_fm_step_1.stderr", + start_time=dt(1999, 1, 1), + end_time=None, + ) + .build( + real_ids=["0", "1"], + status=state.REALIZATION_STATE_UNKNOWN, + start_time=dt(1999, 1, 1), + exec_hosts="12121.121", + message="Some message", + ) + ), + iteration_label="Foo", + total_iterations=1, + progress=0.25, + realization_count=4, + status_count={"Finished": 1, "Pending": 2, "Unknown": 1}, + iteration=0, + ), + id="FullSnapshotEvent", + ), + pytest.param( + SnapshotUpdateEvent( + snapshot=SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="0", + index="0", + status=state.FORWARD_MODEL_STATE_FINISHED, + name="fm_step_0", + end_time=dt(2019, 1, 1), + ) + .build( + real_ids=["1"], + status=state.REALIZATION_STATE_RUNNING, + ), + iteration_label="Foo", + total_iterations=1, + progress=0.5, + realization_count=4, + status_count={"Finished": 2, "Running": 1, "Unknown": 1}, + iteration=0, + ), + id="SnapshotUpdateEvent1", + ), + pytest.param( + SnapshotUpdateEvent( + snapshot=SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="1", + index="1", + status=state.FORWARD_MODEL_STATE_FAILURE, + name="fm_step_1", + ) + .build( + real_ids=["0"], + status=state.REALIZATION_STATE_FAILED, + end_time=dt(2019, 1, 1), + ), + iteration_label="Foo", + total_iterations=1, + progress=0.5, + realization_count=4, + status_count={"Finished": 2, "Failed": 1, "Unknown": 1}, + iteration=0, + ), + id="SnapshotUpdateEvent2", + ), + pytest.param(AnalysisStatusEvent(msg="hello"), id="AnalysisStatusEvent"), + pytest.param( + AnalysisTimeEvent(remaining_time=22.2, elapsed_time=200.42), + id="AnalysisTimeEvent", + ), + pytest.param(EndEvent(failed=False, msg=""), id="EndEvent"), + pytest.param( + RunModelStatusEvent(iteration=1, run_id=uuid.uuid1(), msg="Hello"), + id="RunModelStatusEvent", + ), + pytest.param( + RunModelTimeEvent( + iteration=1, + run_id=uuid.uuid1(), + remaining_time=10.42, + elapsed_time=100.42, + ), + id="RunModelTimeEvent", + ), + pytest.param( + RunModelUpdateBeginEvent(iteration=2, run_id=uuid.uuid1()), + id="RunModelUpdateBeginEvent", + ), + pytest.param( + RunModelDataEvent( + iteration=1, + run_id=uuid.uuid1(), + name="Micky", + data=DataSection( + header=["Some", "string", "elements"], + data=[["a", 1.1, "b"], ["c", 3]], + extra={"a": "b", "c": "d"}, + ), + ), + id="RunModelDataEvent", + ), + pytest.param( + RunModelUpdateEndEvent( + iteration=3, + run_id=uuid.uuid1(), + data=DataSection( + header=["Some", "string", "elements"], + data=[["a", 1.1, "b"], ["c", 3]], + extra={"a": "b", "c": "d"}, + ), + ), + id="RunModelUpdateEndEvent", + ), + ], +) +def test_status_event_serialization(event): + json_res = status_event_to_json(event) + round_trip_event = status_event_from_json(json_res) + assert event == round_trip_event diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index 68f8af189a5..4075e3fbc41 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -7,13 +7,7 @@ from seba_sqlite.snapshot import SebaSnapshot -from _ert.events import event_from_json, event_to_json -from ert.ensemble_evaluator import FullSnapshotEvent, SnapshotUpdateEvent -from ert.ensemble_evaluator.event import ( - snapshot_event_from_json, - snapshot_event_to_json, -) -from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel +from ert.run_models.everest_run_model import EverestExitCode from everest.config import EverestConfig, OptimizationConfig, ServerConfig from everest.detached import ServerStatus, everserver_status from everest.detached.jobs import everserver @@ -254,44 +248,3 @@ def test_everserver_status_contains_max_runtime_failure( "sleep Failed with: The run is cancelled due to reaching MAX_RUNTIME" in status["message"] ) - - -def test_event_serialization( - copy_math_func_test_data_to_tmp, - evaluator_server_config_generator, -): - config = EverestConfig.load_file("config_minimal.yml") - - def check_status_round_tripping(status): - round_trip_status = json.loads(json.dumps(status)) - assert round_trip_status == status - - run_model = EverestRunModel.create( - config, - simulation_callback=check_status_round_tripping, - ) - send_event = run_model.send_event - send_snapshot_event = run_model.send_snapshot_event - - def check_snapshot_event_serialization_round_trip(*args, **_): - event, _ = args - event_json = event_to_json(event) - round_trip_event = event_from_json(str(event_json)) - assert event == round_trip_event - send_snapshot_event(*args) - - def check_event_serialization_round_trip(event): - if isinstance(event, (FullSnapshotEvent | SnapshotUpdateEvent)): - json_str = snapshot_event_to_json(event) - round_trip = snapshot_event_from_json(json_str) - assert event == round_trip - send_event(event) - - run_model.send_event = check_event_serialization_round_trip - run_model.send_snapshot_event = check_snapshot_event_serialization_round_trip - - evaluator_server_config = evaluator_server_config_generator(run_model) - - run_model.run_experiment(evaluator_server_config) - - assert run_model.result is not None