Skip to content

Commit

Permalink
Cleanup and add return model
Browse files Browse the repository at this point in the history
  • Loading branch information
sondreso committed Sep 15, 2024
1 parent 869c246 commit 5e3619b
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 72 deletions.
2 changes: 1 addition & 1 deletion Hackathon.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ ert status <experiment_id>
This can be combined with the commands from the previous section to directly submit and open a status window by doing the following:

```bash
ert status $(<command from previous section> | jq -r .experiment_id)
ert status $(<command from previous section> | jq -r .id)
```

15 changes: 10 additions & 5 deletions src/ert/experiment_server/experiment_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging
import queue
from multiprocessing.queues import Queue
from typing import Dict, List
Expand All @@ -11,6 +12,9 @@
from ert.run_models.base_run_model import BaseRunModel, StatusEvents


logger = logging.getLogger(__name__)


class EndTaskEvent:
pass

Expand All @@ -34,12 +38,16 @@ def __init__(self, _id: str, model: BaseRunModel, status_queue: "Queue[StatusEve
self._subscribers: Dict[str, Subscriber] = {}
self._events: List[StatusEvents] = []

@property
def model_type(self) -> str:
return str(self._model.name())

def cancel(self) -> None:
self._model.cancel()

async def run(self):
loop = asyncio.get_running_loop()
print(f"Starting experiment {self._id}")
logger.info(f"Starting experiment {self._id}")

port_range = None
if self._model.queue_system == QueueSystem.LOCAL:
Expand All @@ -62,9 +70,6 @@ async def run(self):

if isinstance(item, _UpdateEvent):
item.snapshot = item.snapshot.to_dict()
# print(item)
# print()
# print()
event = jsonable_encoder(item)
self._events.append(event)
for sub in self._subscribers.values():
Expand All @@ -78,7 +83,7 @@ async def run(self):
break

await simulation_future
print(f"Experiment {self._id} done")
logger.info(f"Experiment {self._id} done")

async def get_event(self, subscriber_id: str) -> StatusEvents:
if subscriber_id not in self._subscribers:
Expand Down
64 changes: 15 additions & 49 deletions src/ert/experiment_server/main.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,36 @@
import asyncio
import multiprocessing as mp
import uuid
from contextlib import asynccontextmanager
from multiprocessing.queues import Queue
from typing import Dict, Union
from typing import Dict, List

from fastapi import BackgroundTasks, FastAPI, HTTPException, WebSocket
from pydantic import BaseModel, Field

from ert.config import ErtConfig
from ert.gui.simulation.ensemble_experiment_panel import (
Arguments as EnsembleExperimentArguments,
)
from ert.gui.simulation.ensemble_smoother_panel import (
Arguments as EnsembleSmootherArguments,
)
from ert.gui.simulation.evaluate_ensemble_panel import (
Arguments as EvaluateEnsembleArguments,
)
from ert.gui.simulation.iterated_ensemble_smoother_panel import (
Arguments as IteratedEnsembleSmootherArguments,
)
from ert.gui.simulation.manual_update_panel import Arguments as ManualUpdateArguments
from ert.gui.simulation.multiple_data_assimilation_panel import (
Arguments as MultipleDataAssimilationArguments,
)
from ert.gui.simulation.single_test_run_panel import Arguments as SingleTestRunArguments
from ert.run_models.base_run_model import StatusEvents
from ert.run_models.model_factory import create_model
from ert.storage import open_storage

from .experiment_task import EndTaskEvent, ExperimentTask
from .models import Experiment, ExperimentOut


class Experiment(BaseModel):
args: Union[
EnsembleExperimentArguments,
EnsembleSmootherArguments,
EvaluateEnsembleArguments,
IteratedEnsembleSmootherArguments,
ManualUpdateArguments,
MultipleDataAssimilationArguments,
SingleTestRunArguments,
] = Field(..., discriminator="mode")
ert_config: ErtConfig


mp_ctx = mp.get_context("fork")
experiments: Dict[str, ExperimentTask] = {}
app = FastAPI()

@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup actions
yield
# Shutdown actions

app = FastAPI(lifespan=lifespan)

@app.get("/")
async def root():
return {"message": "ping"}


@app.post("/experiments/")
@app.get("/experiments/", response_model=List[ExperimentOut])
async def get_experiments():
return [ExperimentOut(id=k, type=v.model_type) for k, v in experiments.items()]


@app.post("/experiments/", response_model=ExperimentOut)
async def submit_experiment(experiment: Experiment, background_tasks: BackgroundTasks):
storage = open_storage(experiment.ert_config.ens_path, "w")
status_queue: "Queue[StatusEvents]" = mp_ctx.Queue()
status_queue: "Queue[StatusEvents]" = mp.Queue()
try:
model = create_model(
experiment.ert_config,
Expand All @@ -83,18 +48,19 @@ async def submit_experiment(experiment: Experiment, background_tasks: Background
task = ExperimentTask(_id=experiment_id, model=model, status_queue=status_queue)
experiments[experiment_id] = task
background_tasks.add_task(task.run)
return {"message": "Experiment Started", "experiment_id": experiment_id}
return ExperimentOut(id=experiment_id, type=task.model_type)


@app.put("/experiments/{experiment_id}/cancel")
@app.put("/experiments/{experiment_id}/cancel", response_model=ExperimentOut)
async def cancel_experiment(experiment_id: str):
if experiment_id not in experiments:
return HTTPException(
status_code=404,
detail=f"Experiment with id {experiment_id} does not exist.",
)
experiments[experiment_id].cancel()
return {"message": "Experiment canceled", "experiment_id": experiment_id}
task = experiments[experiment_id]
task.cancel()
return ExperimentOut(id=experiment_id, type=task.model_type)


@app.websocket("/experiments/{experiment_id}/events")
Expand Down
39 changes: 39 additions & 0 deletions src/ert/experiment_server/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from uuid import UUID
from typing import Union, List

from pydantic import BaseModel, Field

from ert.config import ErtConfig
from ert.gui.simulation.ensemble_experiment_panel import (
Arguments as EnsembleExperimentArguments,
)
from ert.gui.simulation.ensemble_smoother_panel import (
Arguments as EnsembleSmootherArguments,
)
from ert.gui.simulation.evaluate_ensemble_panel import (
Arguments as EvaluateEnsembleArguments,
)
from ert.gui.simulation.iterated_ensemble_smoother_panel import (
Arguments as IteratedEnsembleSmootherArguments,
)
from ert.gui.simulation.manual_update_panel import Arguments as ManualUpdateArguments
from ert.gui.simulation.multiple_data_assimilation_panel import (
Arguments as MultipleDataAssimilationArguments,
)
from ert.gui.simulation.single_test_run_panel import Arguments as SingleTestRunArguments

class Experiment(BaseModel):
args: Union[
EnsembleExperimentArguments,
EnsembleSmootherArguments,
EvaluateEnsembleArguments,
IteratedEnsembleSmootherArguments,
ManualUpdateArguments,
MultipleDataAssimilationArguments,
SingleTestRunArguments,
] = Field(..., discriminator="mode")
ert_config: ErtConfig

class ExperimentOut(BaseModel):
id: UUID
type: str
24 changes: 8 additions & 16 deletions src/ert/gui/simulation/event_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,40 +49,32 @@ def __init__(
def consume_and_emit(self) -> None:
logger.debug("tracking...")
with connect(f"ws://127.0.0.1:8000/experiments/{self._experiment_id}/events") as websocket:
print("Connected")
logger.info("Connected")
while True:
try:
message = websocket.recv(timeout=1.0)
except TimeoutError:
message = None
if self._stopped:
logger.debug("stopped")
logger.info("Stopped")
break

if message is None:
logger.info("Sleeping")
sleep(0.1)
print("Sleep")
continue

print(message)
print()
print()
logger.info("Got message %s".format(message))
event_dict = json.loads(message)
print(event_dict)
print()
print()
if "snapshot" in event_dict:
event_dict["snapshot"] = Snapshot.from_nested_dict(event_dict["snapshot"])
print(event_dict)
print()
print()
try:
event_wrapper = EventWrapper(event=event_dict)
except ValidationError as e:
print(e)
print(event_wrapper)
print()
print()
logger.error("Error when processing event %s".format(str(event_dict)),exc_info=e)

event = event_wrapper.event

# pre-rendering in this thread to avoid work in main rendering thread
if (
isinstance(event, (FullSnapshotEvent, SnapshotUpdateEvent))
Expand Down
2 changes: 1 addition & 1 deletion src/ert/gui/simulation/experiment_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def run_experiment(self) -> None:


res = requests.post("http://127.0.0.1:8000/experiments/", json=jsonable_encoder(data))
experiment_id = res.json()["experiment_id"]
experiment_id = res.json()["id"]
dialog = RunDialog(
experiment_id,
self._notifier,
Expand Down

0 comments on commit 5e3619b

Please sign in to comment.