Skip to content

Commit

Permalink
Use pattern-matching for _on_event
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Nov 20, 2024
1 parent ba1791e commit fbb31a1
Showing 1 changed file with 70 additions and 62 deletions.
132 changes: 70 additions & 62 deletions src/ert/gui/simulation/run_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from queue import SimpleQueue
from typing import Callable, Optional
from uuid import UUID

from qtpy.QtCore import QModelIndex, QSize, Qt, QThread, QTimer, Signal, Slot
from qtpy.QtGui import (
Expand Down Expand Up @@ -32,6 +33,7 @@
)

from _ert.threading import ErtThread
from ert.analysis.event import DataSection
from ert.config import QueueSystem
from ert.ensemble_evaluator import (
EndEvent,
Expand Down Expand Up @@ -426,72 +428,78 @@ def _on_ticker(self) -> None:

@Slot(object)
def _on_event(self, event: object) -> None:
if isinstance(event, EndEvent):
self.simulation_done.emit(event.failed, event.msg)
self._ticker.stop()
elif isinstance(event, FullSnapshotEvent):
if event.snapshot is not None:
if self._restart:
match event:
case EndEvent(failed=failed, msg=msg):
self.simulation_done.emit(failed, msg)
self._ticker.stop()
case FullSnapshotEvent():
if event.snapshot is not None:
if self._restart:
self._snapshot_model._update_snapshot(
event.snapshot, str(event.iteration)
)
else:
self._snapshot_model._add_snapshot(
event.snapshot, str(event.iteration)
)
self.update_total_progress(event.progress, event.iteration_label)
self._progress_widget.update_progress(
event.status_count, event.realization_count
)
self.progress_update_event.emit(
event.status_count, event.realization_count
)
case SnapshotUpdateEvent():
if event.snapshot is not None:
self._snapshot_model._update_snapshot(
event.snapshot, str(event.iteration)
)
else:
self._snapshot_model._add_snapshot(
event.snapshot, str(event.iteration)
)
self.update_total_progress(event.progress, event.iteration_label)
self._progress_widget.update_progress(
event.status_count, event.realization_count
)
self.progress_update_event.emit(event.status_count, event.realization_count)
elif isinstance(event, SnapshotUpdateEvent):
if event.snapshot is not None:
self._snapshot_model._update_snapshot(
event.snapshot, str(event.iteration)
self._progress_widget.update_progress(
event.status_count, event.realization_count
)
self._progress_widget.update_progress(
event.status_count, event.realization_count
)
self.update_total_progress(event.progress, event.iteration_label)
self.progress_update_event.emit(event.status_count, event.realization_count)
elif isinstance(event, RunModelUpdateBeginEvent):
iteration = event.iteration
widget = UpdateWidget(iteration)
tab_index = self._tab_widget.addTab(widget, f"Update {iteration}")

if self._tab_widget.currentIndex() == self._tab_widget.count() - 2:
self._tab_widget.setCurrentIndex(tab_index)

widget.begin(event)

elif isinstance(event, RunModelUpdateEndEvent):
self._progress_widget.stop_waiting_progress_bar()
if (widget := self._get_update_widget(event.iteration)) is not None:
widget.end(event)

elif (isinstance(event, (RunModelStatusEvent, RunModelTimeEvent))) and (
widget := self._get_update_widget(event.iteration)
) is not None:
widget.update_status(event)

elif (isinstance(event, RunModelDataEvent)) and (
widget := self._get_update_widget(event.iteration)
) is not None:
widget.add_table(event)

elif isinstance(event, RunModelErrorEvent):
if (widget := self._get_update_widget(event.iteration)) is not None:
widget.error(event)

if (
isinstance(
event, (RunModelDataEvent, RunModelUpdateEndEvent, RunModelErrorEvent)
)
and self.output_path
):
name = event.name if hasattr(event, "name") else "Report"
if event.data:
event.data.to_csv(name, self.output_path / str(event.run_id))
self.update_total_progress(event.progress, event.iteration_label)
self.progress_update_event.emit(
event.status_count, event.realization_count
)
case RunModelUpdateBeginEvent():
iteration = event.iteration
widget = UpdateWidget(iteration)
tab_index = self._tab_widget.addTab(widget, f"Update {iteration}")

if self._tab_widget.currentIndex() == self._tab_widget.count() - 2:
self._tab_widget.setCurrentIndex(tab_index)

widget.begin(event)

case RunModelUpdateEndEvent():
self._progress_widget.stop_waiting_progress_bar()
widget = self._get_update_widget(event.iteration)
if widget is not None:
widget.end(event)
self._dump_event_data_to_json(event.data, "Report", event.run_id)

case RunModelStatusEvent() | RunModelTimeEvent():
widget = self._get_update_widget(event.iteration)
if widget is not None:
widget.update_status(event)

case RunModelDataEvent():
widget = self._get_update_widget(event.iteration)
if widget is not None:
widget.add_table(event)
self._dump_event_data_to_json(event.data, event.name, event.run_id)

case RunModelErrorEvent():
widget = self._get_update_widget(event.iteration)
if widget is not None:
widget.error(event)
self._dump_event_data_to_json(event.data, "Report", event.run_id)

def _dump_event_data_to_json(

Check failure on line 498 in src/ert/gui/simulation/run_dialog.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function is missing a return type annotation
self, data: Optional[DataSection], name: str, run_id: UUID
):
if self.output_path and data:
data.to_csv(name, self.output_path / str(run_id))

def _get_update_widget(self, iteration: int) -> UpdateWidget:
for i in range(0, self._tab_widget.count()):
Expand Down

0 comments on commit fbb31a1

Please sign in to comment.