Skip to content

Commit

Permalink
Make sure FullSnapshotEvent and SnapshotUpdateEvent is json serializable
Browse files Browse the repository at this point in the history
  • Loading branch information
DanSava committed Jan 8, 2025
1 parent 002662f commit 7089a3b
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 4 deletions.
30 changes: 30 additions & 0 deletions src/ert/ensemble_evaluator/event.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from dataclasses import dataclass

from .snapshot import EnsembleSnapshot
Expand Down Expand Up @@ -27,3 +28,32 @@ class SnapshotUpdateEvent(_UpdateEvent):
class EndEvent:
failed: bool
msg: str | None = None


def snapshot_event_from_json(json_str: str) -> FullSnapshotEvent | SnapshotUpdateEvent:
json_dict = json.loads(json_str)
snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"])
json_dict["snapshot"] = snapshot
match json_dict.pop("type"):
case "FullSnapshotEvent":
return FullSnapshotEvent(**json_dict)
case "SnapshotUpdateEvent":
return SnapshotUpdateEvent(**json_dict)
case unknown:
raise TypeError(f"Unknown snapshot update event type {unknown}")


def snapshot_event_to_json(event: FullSnapshotEvent | SnapshotUpdateEvent) -> str:
assert event.snapshot is not None
return json.dumps(
{
"iteration_label": event.iteration_label,
"total_iterations": event.total_iterations,
"progress": event.progress,
"realization_count": event.realization_count,
"status_count": event.status_count,
"iteration": event.iteration,
"snapshot": event.snapshot.to_dict(),
"type": event.__class__.__name__,
}
)
5 changes: 5 additions & 0 deletions src/ert/ensemble_evaluator/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def __init__(self) -> None:
sorted_fm_step_ids=defaultdict(list),
)

def __eq__(self, other: object) -> bool:
if not isinstance(other, EnsembleSnapshot):
return NotImplemented
return self.to_dict() == other.to_dict()

@classmethod
def from_nested_dict(cls, source: Mapping[Any, Any]) -> EnsembleSnapshot:
ensemble = EnsembleSnapshot()
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def get_current_status(self) -> dict[str, int]:
status["Finished"] += (
self._get_number_of_finished_realizations_from_reruns()
)
return status
return dict(status)

def _get_number_of_finished_realizations_from_reruns(self) -> int:
return self.active_realizations.count(
Expand Down
21 changes: 18 additions & 3 deletions tests/everest/test_everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from seba_sqlite.snapshot import SebaSnapshot

from _ert.events import event_from_json, event_to_json
from ert.ensemble_evaluator import FullSnapshotEvent, SnapshotUpdateEvent
from ert.ensemble_evaluator.event import (
snapshot_event_from_json,
snapshot_event_to_json,
)
from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel
from everest.config import EverestConfig, OptimizationConfig, ServerConfig
from everest.detached import ServerStatus, everserver_status
Expand Down Expand Up @@ -265,18 +270,28 @@ def check_status_round_tripping(status):
config,
simulation_callback=check_status_round_tripping,
)

send_event = run_model.send_event
send_snapshot_event = run_model.send_snapshot_event

def check_event_serialization_round_trip(*args, **_):
def check_snapshot_event_serialization_round_trip(*args, **_):
event, _ = args
event_json = event_to_json(event)
round_trip_event = event_from_json(str(event_json))
assert event == round_trip_event
send_snapshot_event(*args)

run_model.send_snapshot_event = check_event_serialization_round_trip
def check_event_serialization_round_trip(event):
if isinstance(event, (FullSnapshotEvent | SnapshotUpdateEvent)):
json_str = snapshot_event_to_json(event)
round_trip = snapshot_event_from_json(json_str)
assert event == round_trip
send_event(event)

run_model.send_event = check_event_serialization_round_trip
run_model.send_snapshot_event = check_snapshot_event_serialization_round_trip

evaluator_server_config = evaluator_server_config_generator(run_model)

run_model.run_experiment(evaluator_server_config)

assert run_model.result is not None

0 comments on commit 7089a3b

Please sign in to comment.