Skip to content

Commit

Permalink
Add event queue and cancel functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
sondreso committed Jan 25, 2024
1 parent a620ea3 commit 7660192
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 35 deletions.
10 changes: 5 additions & 5 deletions src/ert/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import contextlib
import logging
import os
import queue
import sys
import threading
import time
from typing import Any, TextIO

from ert.cli import (
Expand Down Expand Up @@ -89,12 +89,15 @@ def run_cli(args: Namespace, _: Any = None) -> None:
observations=ert_config.observations,
)

status_queue = queue.SimpleQueue()

try:
model = create_model(
ert_config,
storage,
args,
experiment.id,
status_queue,
)
except ValueError as e:
raise ErtCliError(e) from e
Expand Down Expand Up @@ -132,12 +135,9 @@ def run_cli(args: Namespace, _: Any = None) -> None:
else:
out = sys.stderr
monitor = Monitor(out=out, color_always=args.color_always)
monitor.start()
model.add_send_event_callback(monitor.on_event)
thread.start()
try:
while not monitor.done:
time.sleep(0.5)
monitor.monitor(status_queue)
except (SystemExit, KeyboardInterrupt):
print("\nKilling simulations...")
# tracker.request_termination()
Expand Down
34 changes: 27 additions & 7 deletions src/ert/cli/model_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from queue import SimpleQueue
from typing import TYPE_CHECKING
from uuid import UUID

Expand Down Expand Up @@ -55,6 +56,7 @@ def create_model(
storage: StorageAccessor,
args: Namespace,
experiment_id: UUID,
status_queue: SimpleQueue,
) -> BaseRunModel:
logger = logging.getLogger(__name__)
logger.info(
Expand All @@ -75,28 +77,36 @@ def create_model(
)

if args.mode == TEST_RUN_MODE:
return _setup_single_test_run(config, storage, args, experiment_id)
return _setup_single_test_run(
config, storage, args, experiment_id, status_queue
)
elif args.mode == ENSEMBLE_EXPERIMENT_MODE:
return _setup_ensemble_experiment(config, storage, args, experiment_id)
return _setup_ensemble_experiment(
config, storage, args, experiment_id, status_queue
)
elif args.mode == ENSEMBLE_SMOOTHER_MODE:
return _setup_ensemble_smoother(
config, storage, args, experiment_id, update_settings
config, storage, args, experiment_id, update_settings, status_queue
)
elif args.mode == ES_MDA_MODE:
return _setup_multiple_data_assimilation(
config, storage, args, experiment_id, update_settings
config, storage, args, experiment_id, update_settings, status_queue
)
elif args.mode == ITERATIVE_ENSEMBLE_SMOOTHER_MODE:
return _setup_iterative_ensemble_smoother(
config, storage, args, experiment_id, update_settings
config, storage, args, experiment_id, update_settings, status_queue
)

else:
raise NotImplementedError(f"Run type not supported {args.mode}")


def _setup_single_test_run(
config: ErtConfig, storage: StorageAccessor, args: Namespace, experiment_id: UUID
config: ErtConfig,
storage: StorageAccessor,
args: Namespace,
experiment_id: UUID,
status_queue: SimpleQueue,
) -> SingleTestRun:
return SingleTestRun(
SingleTestRunArguments(
Expand All @@ -113,7 +123,11 @@ def _setup_single_test_run(


def _setup_ensemble_experiment(
config: ErtConfig, storage: StorageAccessor, args: Namespace, experiment_id: UUID
config: ErtConfig,
storage: StorageAccessor,
args: Namespace,
experiment_id: UUID,
status_queue: SimpleQueue,
) -> EnsembleExperiment:
min_realizations_count = config.analysis_config.minimum_required_realizations
active_realizations = _realizations(args, config.model_config.num_realizations)
Expand All @@ -140,6 +154,7 @@ def _setup_ensemble_experiment(
storage,
config.queue_config,
experiment_id,
status_queue,
)


Expand All @@ -149,6 +164,7 @@ def _setup_ensemble_smoother(
args: Namespace,
experiment_id: UUID,
update_settings: UpdateSettings,
status_queue: SimpleQueue,
) -> EnsembleSmoother:
return EnsembleSmoother(
ESRunArguments(
Expand All @@ -168,6 +184,7 @@ def _setup_ensemble_smoother(
experiment_id,
es_settings=config.analysis_config.es_module,
update_settings=update_settings,
status_queue=status_queue,
)


Expand All @@ -177,6 +194,7 @@ def _setup_multiple_data_assimilation(
args: Namespace,
experiment_id: UUID,
update_settings: UpdateSettings,
status_queue: SimpleQueue,
) -> MultipleDataAssimilation:
# Because the configuration of the CLI is different from the gui, we
# have a different way to get the restart information.
Expand Down Expand Up @@ -207,6 +225,7 @@ def _setup_multiple_data_assimilation(
prior_ensemble,
es_settings=config.analysis_config.es_module,
update_settings=update_settings,
status_queue=status_queue,
)


Expand All @@ -216,6 +235,7 @@ def _setup_iterative_ensemble_smoother(
args: Namespace,
id_: UUID,
update_settings: UpdateSettings,
status_queue: SimpleQueue,
) -> IteratedEnsembleSmoother:
return IteratedEnsembleSmoother(
SIESRunArguments(
Expand Down
20 changes: 19 additions & 1 deletion src/ert/cli/monitor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import sys
from datetime import datetime, timedelta
from queue import SimpleQueue
from typing import Dict, Optional, TextIO, Tuple

from tqdm import tqdm
Expand Down Expand Up @@ -60,8 +61,25 @@ def __init__(self, out: TextIO = sys.stdout, color_always: bool = False) -> None
self.dot = ""
self.done = False

def start(self) -> None:
def monitor(
self,
event_queue: SimpleQueue,
) -> None:
self._start_time = datetime.now()
while True:
event = event_queue.get()
if isinstance(event, FullSnapshotEvent):
if event.snapshot is not None:
self._snapshots[event.iteration] = event.snapshot
self._progress = event.progress
elif isinstance(event, SnapshotUpdateEvent):
if event.partial_snapshot is not None:
self._snapshots[event.iteration].merge_event(event.partial_snapshot)
self._print_progress(event)
if isinstance(event, EndEvent):
self._print_result(event.failed, event.failed_msg)
self._print_job_errors()
return

def on_event(
self,
Expand Down
56 changes: 56 additions & 0 deletions src/ert/gui/simulation/queue_emitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging
from queue import SimpleQueue

from qtpy.QtCore import QObject, Signal, Slot

from ert.ensemble_evaluator import EndEvent, FullSnapshotEvent, SnapshotUpdateEvent
from ert.gui.model.snapshot import SnapshotModel

logger = logging.getLogger(__name__)


class QueueEmitter(QObject):
"""A worker that emits items put on a queue to qt subscribers."""

new_event = Signal(object)
done = Signal()

def __init__(
self,
event_queue: SimpleQueue,
parent=None,
):
super().__init__(parent)
logger.debug("init QueueEmitter")
self._event_queue = event_queue
self._stopped = False

@Slot()
def consume_and_emit(self):
logger.debug("tracking...")
while True:
event = self._event_queue.get()
if self._stopped:
logger.debug("stopped")
break

# pre-rendering in this thread to avoid work in main rendering thread
if isinstance(event, FullSnapshotEvent) and event.snapshot:
SnapshotModel.prerender(event.snapshot)
elif isinstance(event, SnapshotUpdateEvent) and event.partial_snapshot:
SnapshotModel.prerender(event.partial_snapshot)

logger.debug(f"emit {event}")
self.new_event.emit(event)

if isinstance(event, EndEvent):
logger.debug("got end event")
break

self.done.emit()
logger.debug("tracking done.")

@Slot()
def stop(self):
logger.debug("stopping...")
self._stopped = True
27 changes: 19 additions & 8 deletions src/ert/gui/simulation/run_dialog.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
from queue import SimpleQueue
from threading import Thread
from typing import Optional

from PyQt5.QtWidgets import QAbstractItemView
from qtpy.QtCore import QModelIndex, QSize, Qt, QTimer, Signal, Slot
from qtpy.QtCore import QModelIndex, QSize, Qt, QThread, QTimer, Signal, Slot
from qtpy.QtGui import QMovie
from qtpy.QtWidgets import (
QDialog,
Expand Down Expand Up @@ -42,6 +43,7 @@
)
from ert.shared.status.utils import format_running_time

from .queue_emitter import QueueEmitter
from .view import LegendView, ProgressView, RealizationWidget, UpdateWidget

_TOTAL_PROGRESS_TEMPLATE = "Total progress {total_progress}% — {phase_name}"
Expand All @@ -55,6 +57,7 @@ def __init__(
self,
config_file: str,
run_model: BaseRunModel,
event_queue: SimpleQueue,
notifier: ErtNotifier,
parent=None,
):
Expand All @@ -66,6 +69,7 @@ def __init__(

self._snapshot_model = SnapshotModel(self)
self._run_model = run_model
self._event_queue = event_queue
self._notifier = notifier

self._isDetailedDialog = False
Expand Down Expand Up @@ -174,7 +178,6 @@ def __init__(
self._setSimpleDialog()
self.finished.connect(self._on_finished)

self._run_model.add_send_event_callback(self.on_run_model_event.emit)
self.on_run_model_event.connect(self._on_event)

def _current_tab_changed(self, index: int) -> None:
Expand Down Expand Up @@ -274,8 +277,20 @@ def startSimulation(self):
args=(evaluator_server_config,),
)

simulation_thread.start()
worker = QueueEmitter(self._event_queue)
worker_thread = QThread()
self._worker = worker
self._worker_thread = worker_thread

worker.done.connect(worker_thread.quit)
worker.new_event.connect(self._on_event)
worker.moveToThread(worker_thread)
self.simulation_done.connect(worker.stop)
worker_thread.started.connect(worker.consume_and_emit)

self._ticker.start(1000)
self._worker_thread.start()
simulation_thread.start()
self._notifier.set_is_simulation_running(True)

def killJobs(self):
Expand All @@ -287,9 +302,7 @@ def killJobs(self):
if kill_job == QMessageBox.Yes:
# Normally this slot would be invoked by the signal/slot system,
# but the worker is busy tracking the evaluation.
# self._tracker.request_termination()
# self._worker_thread.quit()
# self._worker_thread.wait()
self._run_model.cancel()
self._on_finished()
self.finished.emit(-1)
return kill_job
Expand Down Expand Up @@ -329,7 +342,6 @@ def _on_event(self, event: object):
self._show_done_button()
elif isinstance(event, FullSnapshotEvent):
if event.snapshot is not None:
SnapshotModel.prerender(event.snapshot)
self._snapshot_model._add_snapshot(event.snapshot, event.iteration)
self._progress_view.setIndeterminate(event.indeterminate)
progress = int(event.progress * 100)
Expand All @@ -343,7 +355,6 @@ def _on_event(self, event: object):

elif isinstance(event, SnapshotUpdateEvent):
if event.partial_snapshot is not None:
SnapshotModel.prerender(event.partial_snapshot)
self._snapshot_model._add_partial_snapshot(
event.partial_snapshot, event.iteration
)
Expand Down
5 changes: 4 additions & 1 deletion src/ert/gui/simulation/simulation_panel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import OrderedDict
from queue import SimpleQueue
from typing import Any, Dict

from qtpy.QtCore import QSize, Qt
Expand Down Expand Up @@ -181,6 +182,7 @@ def runSimulation(self):
abort = False
QApplication.setOverrideCursor(Qt.CursorShape.WaitCursor)
config = self.facade.config
event_queue = SimpleQueue()
try:
experiment = self._notifier.storage.create_experiment(
parameters=config.ensemble_config.parameter_configuration,
Expand All @@ -194,6 +196,7 @@ def runSimulation(self):
self._notifier.storage,
args,
experiment.id,
event_queue,
)
experiment.write_simulation_arguments(model.simulation_arguments)

Expand Down Expand Up @@ -254,7 +257,7 @@ def runSimulation(self):
QApplication.restoreOverrideCursor()

dialog = RunDialog(
self._config_file, model, self._notifier, self.parent()
self._config_file, model, event_queue, self._notifier, self.parent()
)
self.run_button.setEnabled(False)
self.run_button.setText(EXPERIMENT_IS_RUNNING_BUTTON_MESSAGE)
Expand Down
4 changes: 2 additions & 2 deletions src/ert/gui/tools/run_analysis/run_analysis_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def run(self):
update_settings,
config.analysis_config.es_module,
rng,
self.smoother_event_callback,
self.send_smoother_event,
log_path=config.analysis_config.log_path,
)
except ErtAnalysisError as e:
Expand All @@ -73,7 +73,7 @@ def run(self):

self.finished.emit(error, self._source_fs.name)

def smoother_event_callback(self, event: AnalysisEvent) -> None:
def send_smoother_event(self, event: AnalysisEvent) -> None:
if isinstance(event, AnalysisStatusEvent):
self.progress_update.emit(RunModelStatusEvent(iteration=0, msg=event.msg))
elif isinstance(event, AnalysisTimeEvent):
Expand Down
Loading

0 comments on commit 7660192

Please sign in to comment.