From 6eb7862d3a8cf797cec223858f3a014b0ce6b619 Mon Sep 17 00:00:00 2001 From: Sondre Sortland Date: Sun, 15 Sep 2024 15:01:49 +0200 Subject: [PATCH] format experiment server --- src/ert/experiment_server/main.py | 81 ++++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 24 deletions(-) diff --git a/src/ert/experiment_server/main.py b/src/ert/experiment_server/main.py index 8171162512a..c7f0ebbb773 100644 --- a/src/ert/experiment_server/main.py +++ b/src/ert/experiment_server/main.py @@ -1,53 +1,81 @@ import asyncio import multiprocessing as mp import os -from concurrent.futures import ProcessPoolExecutor +import queue import uuid +from concurrent.futures import ProcessPoolExecutor from multiprocessing.queues import Queue -import queue +from typing import Dict, Tuple, Union from fastapi import BackgroundTasks, FastAPI, HTTPException, WebSocket +from fastapi.encoders import jsonable_encoder from pydantic import BaseModel, Field +from ert.config import ErtConfig, QueueSystem from ert.ensemble_evaluator import EvaluatorServerConfig -from ert.run_models.model_factory import create_model -from ert.run_models.base_run_model import BaseRunModel, StatusEvents -from ert.gui.simulation.ensemble_experiment_panel import Arguments as EnsembleExperimentArguments -from ert.gui.simulation.ensemble_smoother_panel import Arguments as EnsembleSmootherArguments -from ert.gui.simulation.evaluate_ensemble_panel import Arguments as EvaluateEnsembleArguments -from ert.gui.simulation.iterated_ensemble_smoother_panel import Arguments as IteratedEnsembleSmootherArguments +from ert.ensemble_evaluator.event import EndEvent, _UpdateEvent +from ert.gui.simulation.ensemble_experiment_panel import ( + Arguments as EnsembleExperimentArguments, +) +from ert.gui.simulation.ensemble_smoother_panel import ( + Arguments as EnsembleSmootherArguments, +) +from ert.gui.simulation.evaluate_ensemble_panel import ( + Arguments as EvaluateEnsembleArguments, +) +from ert.gui.simulation.iterated_ensemble_smoother_panel import ( + Arguments as IteratedEnsembleSmootherArguments, +) from ert.gui.simulation.manual_update_panel import Arguments as ManualUpdateArguments -from ert.gui.simulation.multiple_data_assimilation_panel import Arguments as MultipleDataAssimilationArguments +from ert.gui.simulation.multiple_data_assimilation_panel import ( + Arguments as MultipleDataAssimilationArguments, +) from ert.gui.simulation.single_test_run_panel import Arguments as SingleTestRunArguments +from ert.run_models.base_run_model import BaseRunModel, StatusEvents +from ert.run_models.model_factory import create_model from ert.storage import open_storage -from ert.ensemble_evaluator.event import _UpdateEvent, EndEvent -from typing import Dict, Union, Tuple - -from ert.config import ErtConfig, QueueSystem -from fastapi.encoders import jsonable_encoder class Experiment(BaseModel): - args: Union[EnsembleExperimentArguments, EnsembleSmootherArguments, EvaluateEnsembleArguments, IteratedEnsembleSmootherArguments, ManualUpdateArguments, MultipleDataAssimilationArguments, SingleTestRunArguments] = Field(..., discriminator='mode') + args: Union[ + EnsembleExperimentArguments, + EnsembleSmootherArguments, + EvaluateEnsembleArguments, + IteratedEnsembleSmootherArguments, + ManualUpdateArguments, + MultipleDataAssimilationArguments, + SingleTestRunArguments, + ] = Field(..., discriminator="mode") ert_config: ErtConfig -mp_ctx = mp.get_context('fork') -process_pool = ProcessPoolExecutor(max_workers=max((os.cpu_count() or 1) - 2, 1), mp_context=mp_ctx) + +mp_ctx = mp.get_context("fork") +process_pool = ProcessPoolExecutor( + max_workers=max((os.cpu_count() or 1) - 2, 1), mp_context=mp_ctx +) app = FastAPI() +experiments: Dict[str, Tuple[BaseRunModel, "Queue[StatusEvents]"]] = {} @app.get("/") async def root(): return {"message": "ping"} -experiments : Dict[str, Tuple[BaseRunModel, "Queue[StatusEvents]"]]= {} -async def run_experiment(experiment_id:str, evaluator_server_config: EvaluatorServerConfig): +async def run_experiment( + experiment_id: str, evaluator_server_config: EvaluatorServerConfig +): loop = asyncio.get_running_loop() print(f"Starting experiment {experiment_id}") - await loop.run_in_executor(None, lambda: experiments[experiment_id][0].start_simulations_thread(evaluator_server_config)) + await loop.run_in_executor( + None, + lambda: experiments[experiment_id][0].start_simulations_thread( + evaluator_server_config + ), + ) print(f"Experiment {experiment_id} done") + @app.post("/experiments/") async def submit_experiment(experiment: Experiment, background_tasks: BackgroundTasks): storage = open_storage(experiment.ert_config.ens_path, "w") @@ -60,7 +88,10 @@ async def submit_experiment(experiment: Experiment, background_tasks: Background status_queue, ) except ValueError as e: - return HTTPException(status_code=420, detail=f"{experiment.args.mode} was not valid, failed with: {e}") + return HTTPException( + status_code=420, + detail=f"{experiment.args.mode} was not valid, failed with: {e}", + ) port_range = None if model.queue_system == QueueSystem.LOCAL: @@ -70,9 +101,12 @@ async def submit_experiment(experiment: Experiment, background_tasks: Background experiment_id = str(uuid.uuid4()) experiments[experiment_id] = (model, status_queue) - background_tasks.add_task(run_experiment, experiment_id, evaluator_server_config=evaluator_server_config) + background_tasks.add_task( + run_experiment, experiment_id, evaluator_server_config=evaluator_server_config + ) return {"message": "Experiment Started", "experiment_id": experiment_id} + @app.put("/experiments/{experiment_id}/cancel") async def cancel_experiment(experiment_id: str): if experiment_id in experiments: @@ -92,7 +126,7 @@ async def websocket_endpoint(websocket: WebSocket, experiment_id: str): try: item: StatusEvents = q.get(block=False) except queue.Empty: - asyncio.sleep(0.01) + await asyncio.sleep(0.01) continue if isinstance(item, _UpdateEvent): @@ -104,4 +138,3 @@ async def websocket_endpoint(websocket: WebSocket, experiment_id: str): await asyncio.sleep(0.1) if isinstance(item, EndEvent): break -