Skip to content

Commit

Permalink
work on multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
sondreso committed Sep 7, 2024
1 parent 5513112 commit ea0f79d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
33 changes: 22 additions & 11 deletions src/ert/experiment_server/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import multiprocessing as mp
import uuid
from multiprocessing.queues import Queue
import queue

from fastapi import BackgroundTasks, FastAPI, HTTPException, WebSocket
from pydantic import BaseModel, Field
import json
import dataclasses
import asyncio

from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.run_models.model_factory import create_model
Expand All @@ -17,9 +18,9 @@
from ert.gui.simulation.multiple_data_assimilation_panel import Arguments as MultipleDataAssimilationArguments
from ert.gui.simulation.single_test_run_panel import Arguments as SingleTestRunArguments
from ert.storage import open_storage
from ert.ensemble_evaluator.event import _UpdateEvent, EndEvent

from typing import Dict, Union, Tuple
import uuid

from ert.config import ErtConfig, QueueSystem
from fastapi.encoders import jsonable_encoder
Expand All @@ -28,22 +29,26 @@ class Experiment(BaseModel):
args: Union[EnsembleExperimentArguments, EnsembleSmootherArguments, EvaluateEnsembleArguments, IteratedEnsembleSmootherArguments, ManualUpdateArguments, MultipleDataAssimilationArguments, SingleTestRunArguments] = Field(..., discriminator='mode')
ert_config: ErtConfig


mp_ctx = mp.get_context('fork')
app = FastAPI()


@app.get("/")
async def root():
return {"message": "ping"}

experiments : Dict[str, Tuple[BaseRunModel, queue.SimpleQueue]]= {}
experiments : Dict[str, Tuple[BaseRunModel, "Queue[StatusEvents]"]]= {}

def run_experiment(experiment_id:str, evaluator_server_config: EvaluatorServerConfig):
experiments[experiment_id][0].start_simulations_thread(evaluator_server_config=evaluator_server_config)
p = mp_ctx.Process(target=experiments[experiment_id][0].start_simulations_thread, args=(evaluator_server_config,))
p.start()
p.join()

@app.post("/experiments/")
async def submit_experiment(experiment: Experiment, background_tasks: BackgroundTasks):
storage = open_storage(experiment.ert_config.ens_path, "w")
status_queue: queue.SimpleQueue[StatusEvents] = queue.SimpleQueue()
status_queue: "Queue[StatusEvents]" = mp_ctx.Queue()
try:
model = create_model(
experiment.ert_config,
Expand All @@ -52,7 +57,7 @@ async def submit_experiment(experiment: Experiment, background_tasks: Background
status_queue,
)
except ValueError as e:
return HTTPException(status_code=404, detail=f"{experiment.args.mode} was not valid, failed with: {e}")
return HTTPException(status_code=420, detail=f"{experiment.args.mode} was not valid, failed with: {e}")

port_range = None
if model.queue_system == QueueSystem.LOCAL:
Expand All @@ -74,13 +79,19 @@ async def cancel_experiment(experiment_id: str):

@app.websocket("/experiments/{experiment_id}/events")
async def websocket_endpoint(websocket: WebSocket, experiment_id: str):
if experiment_id not in experiments:
return
await websocket.accept()
print(experiment_id)
print(experiments)
q = experiments[experiment_id][1]
while True:
item: StatusEvents = q.get()
from ert.ensemble_evaluator.event import _UpdateEvent, EndEvent
try:
item: StatusEvents = q.get()
except queue.Empty:
asyncio.sleep(0.01)
continue

if isinstance(item, _UpdateEvent):
item.snapshot = item.snapshot.to_dict()
print(item)
Expand All @@ -90,4 +101,4 @@ async def websocket_endpoint(websocket: WebSocket, experiment_id: str):
await asyncio.sleep(0.1)
if isinstance(item, EndEvent):
break

2 changes: 1 addition & 1 deletion test-data/poly_example/poly_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ def _evaluate(coeffs, x):
output = [_evaluate(coeffs, x) for x in range(10)]
import time
import random
time.sleep(random.randint(5,65))
# time.sleep(random.randint(5,65))
with open("poly.out", "w", encoding="utf-8") as f:
f.write("\n".join(map(str, output)))

0 comments on commit ea0f79d

Please sign in to comment.