diff --git a/src/ert/analysis/event.py b/src/ert/analysis/event.py index 44f5772402b..dffeecf281f 100644 --- a/src/ert/analysis/event.py +++ b/src/ert/analysis/event.py @@ -4,28 +4,30 @@ from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path +from typing import Literal import pandas as pd +from pydantic import BaseModel, ConfigDict -@dataclass -class AnalysisEvent: +class AnalysisEvent(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") pass -@dataclass class AnalysisStatusEvent(AnalysisEvent): + event_type: Literal["AnalysisStatusEvent"] = "AnalysisStatusEvent" msg: str -@dataclass class AnalysisTimeEvent(AnalysisEvent): + event_type: Literal["AnalysisTimeEvent"] = "AnalysisTimeEvent" remaining_time: float elapsed_time: float -@dataclass class AnalysisReportEvent(AnalysisEvent): + event_type: Literal["AnalysisReportEvent"] = "AnalysisReportEvent" report: str @@ -56,18 +58,18 @@ def to_csv(self, name: str, output_path: Path) -> None: df.to_csv(f_path.with_suffix(".csv")) -@dataclass class AnalysisDataEvent(AnalysisEvent): + event_type: Literal["AnalysisDataEvent"] = "AnalysisDataEvent" name: str data: DataSection -@dataclass class AnalysisErrorEvent(AnalysisEvent): + event_type: Literal["AnalysisErrorEvent"] = "AnalysisErrorEvent" error_msg: str data: DataSection | None = None -@dataclass class AnalysisCompleteEvent(AnalysisEvent): + event_type: Literal["AnalysisCompleteEvent"] = "AnalysisCompleteEvent" data: DataSection diff --git a/src/ert/cli/main.py b/src/ert/cli/main.py index 547ec323cd0..7552851de7a 100644 --- a/src/ert/cli/main.py +++ b/src/ert/cli/main.py @@ -24,7 +24,7 @@ ) from ert.namespace import Namespace from ert.plugins import ErtPluginManager -from ert.run_models.base_run_model import StatusEvents +from ert.run_models.event import StatusEvents from ert.run_models.model_factory import create_model from ert.storage import open_storage from ert.storage.local_storage import local_storage_set_ert_config diff --git a/src/ert/cli/monitor.py b/src/ert/cli/monitor.py index 921b584f222..429b997ca53 100644 --- a/src/ert/cli/monitor.py +++ b/src/ert/cli/monitor.py @@ -22,11 +22,11 @@ FORWARD_MODEL_STATE_FAILURE, REAL_STATE_TO_COLOR, ) -from ert.run_models.base_run_model import StatusEvents from ert.run_models.event import ( RunModelDataEvent, RunModelErrorEvent, RunModelUpdateEndEvent, + StatusEvents, ) from ert.shared.status.utils import format_running_time diff --git a/src/ert/ensemble_evaluator/event.py b/src/ert/ensemble_evaluator/event.py index 5639269fd6f..df7e5b7ea94 100644 --- a/src/ert/ensemble_evaluator/event.py +++ b/src/ert/ensemble_evaluator/event.py @@ -1,29 +1,49 @@ -from dataclasses import dataclass +from collections.abc import Mapping +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, field_serializer, field_validator from .snapshot import EnsembleSnapshot -@dataclass -class _UpdateEvent: +class _UpdateEvent(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") iteration_label: str total_iterations: int progress: float realization_count: int status_count: dict[str, int] iteration: int + snapshot: EnsembleSnapshot | None = None + + @field_serializer("snapshot") + def serialize_snapshot( + self, value: EnsembleSnapshot | None + ) -> dict[str, Any] | None: + if value is None: + return None + return value.to_dict() + + @field_validator("snapshot", mode="before") + @classmethod + def validate_snapshot( + cls, value: EnsembleSnapshot | Mapping[Any, Any] + ) -> EnsembleSnapshot: + if isinstance(value, EnsembleSnapshot): + return value + return EnsembleSnapshot.from_nested_dict(value) -@dataclass class FullSnapshotEvent(_UpdateEvent): - snapshot: EnsembleSnapshot | None = None + event_type: Literal["FullSnapshotEvent"] = "FullSnapshotEvent" -@dataclass class SnapshotUpdateEvent(_UpdateEvent): - snapshot: EnsembleSnapshot | None = None + event_type: Literal["SnapshotUpdateEvent"] = "SnapshotUpdateEvent" -@dataclass -class EndEvent: +class EndEvent(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + event_type: Literal["EndEvent"] = "EndEvent" failed: bool - msg: str | None = None + msg: str diff --git a/src/ert/ensemble_evaluator/snapshot.py b/src/ert/ensemble_evaluator/snapshot.py index 1b4c5ebb6de..77170810e92 100644 --- a/src/ert/ensemble_evaluator/snapshot.py +++ b/src/ert/ensemble_evaluator/snapshot.py @@ -65,8 +65,10 @@ class UnsupportedOperationException(ValueError): def convert_iso8601_to_datetime( - timestamp: datetime | str, -) -> datetime: + timestamp: datetime | str | None, +) -> datetime | None: + if timestamp is None: + return None if isinstance(timestamp, datetime): return timestamp @@ -113,6 +115,11 @@ def __init__(self) -> None: sorted_fm_step_ids=defaultdict(list), ) + def __eq__(self, other: object) -> bool: + if not isinstance(other, EnsembleSnapshot): + return NotImplemented + return self.to_dict() == other.to_dict() + @classmethod def from_nested_dict(cls, source: Mapping[Any, Any]) -> EnsembleSnapshot: ensemble = EnsembleSnapshot() @@ -156,13 +163,17 @@ def to_dict(self) -> dict[str, Any]: if self._ensemble_state: dict_["status"] = self._ensemble_state if self._realization_snapshots: - dict_["reals"] = self._realization_snapshots + dict_["reals"] = { + k: _filter_nones(v) for k, v in self._realization_snapshots.items() + } for (real_id, fm_id), fm_values_dict in self._fm_step_snapshots.items(): if "reals" not in dict_: dict_["reals"] = {} if real_id not in dict_["reals"]: - dict_["reals"][real_id] = RealizationSnapshot(fm_steps={}) + dict_["reals"][real_id] = _filter_nones( + RealizationSnapshot(fm_steps={}) + ) if "fm_steps" not in dict_["reals"][real_id]: dict_["reals"][real_id]["fm_steps"] = {} @@ -395,12 +406,17 @@ def _realization_dict_to_realization_snapshot( realization = RealizationSnapshot( status=source.get("status"), active=source.get("active"), - start_time=source.get("start_time"), - end_time=source.get("end_time"), + start_time=convert_iso8601_to_datetime(source.get("start_time")), + end_time=convert_iso8601_to_datetime(source.get("end_time")), exec_hosts=source.get("exec_hosts"), message=source.get("message"), fm_steps=source.get("fm_steps", {}), ) + for step in realization["fm_steps"].values(): + if step.get("start_time"): + step["start_time"] = convert_iso8601_to_datetime(step["start_time"]) + if step.get("end_time"): + step["end_time"] = convert_iso8601_to_datetime(step["end_time"]) return _filter_nones(realization) diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index d8c795e10ac..fb187096ed1 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -20,9 +20,6 @@ from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event from ert.analysis import ( - AnalysisEvent, - AnalysisStatusEvent, - AnalysisTimeEvent, ErtAnalysisError, smoother_update, ) @@ -30,6 +27,7 @@ AnalysisCompleteEvent, AnalysisDataEvent, AnalysisErrorEvent, + AnalysisEvent, ) from ert.config import HookRuntime, QueueSystem from ert.config.analysis_module import BaseSettings @@ -44,11 +42,6 @@ Monitor, Realization, ) -from ert.ensemble_evaluator.event import ( - EndEvent, - FullSnapshotEvent, - SnapshotUpdateEvent, -) from ert.ensemble_evaluator.identifiers import STATUS from ert.ensemble_evaluator.snapshot import EnsembleSnapshot from ert.ensemble_evaluator.state import ( @@ -68,12 +61,18 @@ from ..config.analysis_config import UpdateSettings from ..run_arg import RunArg from .event import ( + AnalysisStatusEvent, + AnalysisTimeEvent, + EndEvent, + FullSnapshotEvent, RunModelDataEvent, RunModelErrorEvent, RunModelStatusEvent, RunModelTimeEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent, + SnapshotUpdateEvent, + StatusEvents, ) logger = logging.getLogger(__name__) @@ -81,21 +80,6 @@ if TYPE_CHECKING: from ert.config import QueueConfig -StatusEvents = ( - FullSnapshotEvent - | SnapshotUpdateEvent - | EndEvent - | AnalysisEvent - | AnalysisStatusEvent - | AnalysisTimeEvent - | RunModelErrorEvent - | RunModelStatusEvent - | RunModelTimeEvent - | RunModelUpdateBeginEvent - | RunModelDataEvent - | RunModelUpdateEndEvent -) - class OutOfOrderSnapshotUpdateException(ValueError): pass @@ -419,7 +403,7 @@ def get_current_status(self) -> dict[str, int]: status["Finished"] += ( self._get_number_of_finished_realizations_from_reruns() ) - return status + return dict(status) def _get_number_of_finished_realizations_from_reruns(self) -> int: return self.active_realizations.count( diff --git a/src/ert/run_models/event.py b/src/ert/run_models/event.py index 424c8ec9235..32021ec19b4 100644 --- a/src/ert/run_models/event.py +++ b/src/ert/run_models/event.py @@ -1,36 +1,47 @@ from __future__ import annotations -from dataclasses import dataclass from pathlib import Path +from typing import Annotated, Literal from uuid import UUID +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter + +from ert.analysis import ( + AnalysisStatusEvent, + AnalysisTimeEvent, +) from ert.analysis.event import DataSection +from ert.ensemble_evaluator.event import ( + EndEvent, + FullSnapshotEvent, + SnapshotUpdateEvent, +) -@dataclass -class RunModelEvent: +class RunModelEvent(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") iteration: int run_id: UUID -@dataclass class RunModelStatusEvent(RunModelEvent): + event_type: Literal["RunModelStatusEvent"] = "RunModelStatusEvent" msg: str -@dataclass class RunModelTimeEvent(RunModelEvent): + event_type: Literal["RunModelTimeEvent"] = "RunModelTimeEvent" remaining_time: float elapsed_time: float -@dataclass class RunModelUpdateBeginEvent(RunModelEvent): + event_type: Literal["RunModelUpdateBeginEvent"] = "RunModelUpdateBeginEvent" pass -@dataclass class RunModelDataEvent(RunModelEvent): + event_type: Literal["RunModelDataEvent"] = "RunModelDataEvent" name: str data: DataSection @@ -39,8 +50,8 @@ def write_as_csv(self, output_path: Path | None) -> None: self.data.to_csv(self.name, output_path / str(self.run_id)) -@dataclass class RunModelUpdateEndEvent(RunModelEvent): + event_type: Literal["RunModelUpdateEndEvent"] = "RunModelUpdateEndEvent" data: DataSection def write_as_csv(self, output_path: Path | None) -> None: @@ -48,11 +59,39 @@ def write_as_csv(self, output_path: Path | None) -> None: self.data.to_csv("Report", output_path / str(self.run_id)) -@dataclass class RunModelErrorEvent(RunModelEvent): + event_type: Literal["RunModelErrorEvent"] = "RunModelErrorEvent" error_msg: str - data: DataSection | None = None + data: DataSection def write_as_csv(self, output_path: Path | None) -> None: if output_path and self.data: self.data.to_csv("Report", output_path / str(self.run_id)) + + +StatusEvents = ( + AnalysisStatusEvent + | AnalysisTimeEvent + | EndEvent + | FullSnapshotEvent + | SnapshotUpdateEvent + | RunModelErrorEvent + | RunModelStatusEvent + | RunModelTimeEvent + | RunModelUpdateBeginEvent + | RunModelDataEvent + | RunModelUpdateEndEvent +) + + +STATUS_EVENTS_ANNOTATION = Annotated[StatusEvents, Field(discriminator="event_type")] + +StatusEventAdapter: TypeAdapter[StatusEvents] = TypeAdapter(STATUS_EVENTS_ANNOTATION) + + +def status_event_from_json(raw_msg: str | bytes) -> StatusEvents: + return StatusEventAdapter.validate_json(raw_msg) + + +def status_event_to_json(event: StatusEvents) -> str: + return event.model_dump_json() diff --git a/src/ert/run_models/model_factory.py b/src/ert/run_models/model_factory.py index e1df99a9d5f..0cede97e4d2 100644 --- a/src/ert/run_models/model_factory.py +++ b/src/ert/run_models/model_factory.py @@ -32,7 +32,7 @@ import numpy.typing as npt from ert.namespace import Namespace - from ert.run_models.base_run_model import StatusEvents + from ert.run_models.event import StatusEvents from ert.storage import Storage diff --git a/tests/ert/__init__.py b/tests/ert/__init__.py index ee05dbc228c..e3db923e86d 100644 --- a/tests/ert/__init__.py +++ b/tests/ert/__init__.py @@ -38,6 +38,7 @@ def build( exec_hosts: str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, + message: str | None = None, ) -> EnsembleSnapshot: snapshot = EnsembleSnapshot() snapshot._ensemble_state = status @@ -53,6 +54,7 @@ def build( end_time=end_time, exec_hosts=exec_hosts, status=status, + message=message, ), ) return snapshot diff --git a/tests/ert/unit_tests/cli/test_model_hook_order.py b/tests/ert/unit_tests/cli/test_model_hook_order.py index cec455baff6..d17b67e7232 100644 --- a/tests/ert/unit_tests/cli/test_model_hook_order.py +++ b/tests/ert/unit_tests/cli/test_model_hook_order.py @@ -1,3 +1,4 @@ +import uuid from unittest.mock import ANY, MagicMock, PropertyMock, call, patch import pytest @@ -48,6 +49,7 @@ def test_hook_call_order_ensemble_smoother(monkeypatch): ens_mock = MagicMock() ens_mock.iteration = 0 + ens_mock.id = uuid.uuid1() storage_mock = MagicMock() storage_mock.create_ensemble.return_value = ens_mock @@ -86,6 +88,7 @@ def test_hook_call_order_es_mda(monkeypatch): ens_mock = MagicMock() ens_mock.iteration = 0 + ens_mock.id = uuid.uuid1() storage_mock = MagicMock() storage_mock.create_ensemble.return_value = ens_mock test_class = MultipleDataAssimilation( @@ -115,9 +118,15 @@ def test_hook_call_order_iterative_ensemble_smoother(monkeypatch): monkeypatch.setattr(base_run_model, "_seed_sequence", MagicMock(return_value=0)) monkeypatch.setattr(base_run_model.BaseRunModel, "run_workflows", run_wfs_mock) + ens_mock = MagicMock() + ens_mock.iteration = 2 + ens_mock.id = uuid.uuid1() + storage_mock = MagicMock() + storage_mock.create_ensemble.return_value = ens_mock + test_class = IteratedEnsembleSmoother(*[MagicMock()] * 13) test_class.run_ensemble_evaluator = MagicMock(return_value=[0]) - + test_class._storage = storage_mock # Mock the return values of iterative_smoother_update # Mock the iteration property of IteratedEnsembleSmoother with ( diff --git a/tests/ert/unit_tests/run_models/test_status_events_serialization.py b/tests/ert/unit_tests/run_models/test_status_events_serialization.py new file mode 100644 index 00000000000..da53ea5803c --- /dev/null +++ b/tests/ert/unit_tests/run_models/test_status_events_serialization.py @@ -0,0 +1,180 @@ +import uuid +from collections import defaultdict +from datetime import datetime as dt + +import pytest + +from ert.analysis.event import DataSection +from ert.ensemble_evaluator import state +from ert.ensemble_evaluator.snapshot import EnsembleSnapshotMetadata +from ert.run_models.event import ( + AnalysisStatusEvent, + AnalysisTimeEvent, + EndEvent, + FullSnapshotEvent, + RunModelDataEvent, + RunModelStatusEvent, + RunModelTimeEvent, + RunModelUpdateBeginEvent, + RunModelUpdateEndEvent, + SnapshotUpdateEvent, + status_event_from_json, + status_event_to_json, +) +from tests.ert import SnapshotBuilder + +METADATA = EnsembleSnapshotMetadata( + aggr_fm_step_status_colors=defaultdict(dict), + real_status_colors={}, + sorted_real_ids=[], + sorted_fm_step_ids=defaultdict(list), +) + + +@pytest.mark.parametrize( + "event", + [ + pytest.param( + FullSnapshotEvent( + snapshot=( + SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="0", + index="0", + name="fm_step_0", + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + max_memory_usage="1000", + stdout="job_fm_step_0.stdout", + stderr="job_fm_step_0.stderr", + start_time=dt(1999, 1, 1), + ) + .add_fm_step( + fm_step_id="1", + index="1", + name="fm_step_1", + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + max_memory_usage="1000", + stdout="job_fm_step_1.stdout", + stderr="job_fm_step_1.stderr", + start_time=dt(1999, 1, 1), + end_time=None, + ) + .build( + real_ids=["0", "1"], + status=state.REALIZATION_STATE_UNKNOWN, + start_time=dt(1999, 1, 1), + exec_hosts="12121.121", + message="Some message", + ) + ), + iteration_label="Foo", + total_iterations=1, + progress=0.25, + realization_count=4, + status_count={"Finished": 1, "Pending": 2, "Unknown": 1}, + iteration=0, + ), + id="FullSnapshotEvent", + ), + pytest.param( + SnapshotUpdateEvent( + snapshot=SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="0", + index="0", + status=state.FORWARD_MODEL_STATE_FINISHED, + name="fm_step_0", + end_time=dt(2019, 1, 1), + ) + .build( + real_ids=["1"], + status=state.REALIZATION_STATE_RUNNING, + ), + iteration_label="Foo", + total_iterations=1, + progress=0.5, + realization_count=4, + status_count={"Finished": 2, "Running": 1, "Unknown": 1}, + iteration=0, + ), + id="SnapshotUpdateEvent1", + ), + pytest.param( + SnapshotUpdateEvent( + snapshot=SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="1", + index="1", + status=state.FORWARD_MODEL_STATE_FAILURE, + name="fm_step_1", + ) + .build( + real_ids=["0"], + status=state.REALIZATION_STATE_FAILED, + end_time=dt(2019, 1, 1), + ), + iteration_label="Foo", + total_iterations=1, + progress=0.5, + realization_count=4, + status_count={"Finished": 2, "Failed": 1, "Unknown": 1}, + iteration=0, + ), + id="SnapshotUpdateEvent2", + ), + pytest.param(AnalysisStatusEvent(msg="hello"), id="AnalysisStatusEvent"), + pytest.param( + AnalysisTimeEvent(remaining_time=22.2, elapsed_time=200.42), + id="AnalysisTimeEvent", + ), + pytest.param(EndEvent(failed=False, msg=""), id="EndEvent"), + pytest.param( + RunModelStatusEvent(iteration=1, run_id=uuid.uuid1(), msg="Hello"), + id="RunModelStatusEvent", + ), + pytest.param( + RunModelTimeEvent( + iteration=1, + run_id=uuid.uuid1(), + remaining_time=10.42, + elapsed_time=100.42, + ), + id="RunModelTimeEvent", + ), + pytest.param( + RunModelUpdateBeginEvent(iteration=2, run_id=uuid.uuid1()), + id="RunModelUpdateBeginEvent", + ), + pytest.param( + RunModelDataEvent( + iteration=1, + run_id=uuid.uuid1(), + name="Micky", + data=DataSection( + header=["Some", "string", "elements"], + data=[["a", 1.1, "b"], ["c", 3]], + extra={"a": "b", "c": "d"}, + ), + ), + id="RunModelDataEvent", + ), + pytest.param( + RunModelUpdateEndEvent( + iteration=3, + run_id=uuid.uuid1(), + data=DataSection( + header=["Some", "string", "elements"], + data=[["a", 1.1, "b"], ["c", 3]], + extra={"a": "b", "c": "d"}, + ), + ), + id="RunModelUpdateEndEvent", + ), + ], +) +def test_status_event_serialization(event): + json_res = status_event_to_json(event) + round_trip_event = status_event_from_json(json_res) + assert event == round_trip_event