From 77a6585d5a85e5ec4c32569563a92b7102740cd0 Mon Sep 17 00:00:00 2001 From: Antoine Leclair Date: Sun, 7 Apr 2024 09:34:46 -0400 Subject: [PATCH] Add event type and ID to server sent events To let client know when it's done, and also reconnect when the connection breaks. --- disco/endpoints/deployments.py | 31 ++++++++++++++++++++++++------- disco/endpoints/logs.py | 6 +++++- disco/endpoints/run.py | 31 ++++++++++++++++++++++++------- disco/utils/commandoutputs.py | 4 ++++ 4 files changed, 57 insertions(+), 15 deletions(-) diff --git a/disco/endpoints/deployments.py b/disco/endpoints/deployments.py index e772797..bfaca05 100644 --- a/disco/endpoints/deployments.py +++ b/disco/endpoints/deployments.py @@ -4,9 +4,10 @@ from datetime import datetime from typing import Annotated -from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException from pydantic import BaseModel, Field, model_validator from sqlalchemy.orm.session import Session as DBSession +from sse_starlette import ServerSentEvent from sse_starlette.sse import EventSourceResponse from disco.auth import get_api_key, get_api_key_wo_tx @@ -106,7 +107,9 @@ def deployments_post( dependencies=[Depends(get_api_key_wo_tx)], ) async def deployment_output_get( - project_name: str, deployment_number: int, after: datetime | None = None + project_name: str, + deployment_number: int, + last_event_id: Annotated[str | None, Header()] = None, ): with Session() as dbsession: with dbsession.begin(): @@ -122,6 +125,11 @@ async def deployment_output_get( if deployment is None: raise HTTPException(status_code=404) source = f"DEPLOYMENT_{deployment.id}" + after = None + if last_event_id is not None: + output = commandoutputs.get_by_id(dbsession, last_event_id) + if output is not None: + after = output.created async def get_build_output(source: str, after: datetime | None): while True: @@ -130,13 +138,22 @@ async def get_build_output(source: str, after: datetime | None): output = commandoutputs.get_next(dbsession, source, after=after) if output is not None: if output.text is None: + yield ServerSentEvent( + id=output.id, + event="end", + data="", + ) return after = output.created - yield json.dumps( - { - "timestamp": output.created.isoformat(), - "text": output.text, - } + yield ServerSentEvent( + id=output.id, + event="output", + data=json.dumps( + { + "timestamp": output.created.isoformat(), + "text": output.text, + } + ), ) if output is None: await asyncio.sleep(0.1) diff --git a/disco/endpoints/logs.py b/disco/endpoints/logs.py index 14f90f6..bf7c9cd 100644 --- a/disco/endpoints/logs.py +++ b/disco/endpoints/logs.py @@ -4,6 +4,7 @@ import random from fastapi import APIRouter, Depends, HTTPException +from sse_starlette import ServerSentEvent from sse_starlette.sse import EventSourceResponse from disco.auth import get_api_key_wo_tx @@ -62,7 +63,10 @@ async def read_logs(project_name: str | None, service_name: str | None): try: while True: log_obj = await log_queue.get() - yield json.dumps(log_obj) + yield ServerSentEvent( + event="output", + data=json.dumps(log_obj), + ) finally: try: await start_logspout_process.wait() diff --git a/disco/endpoints/run.py b/disco/endpoints/run.py index 77b81d7..22b83b1 100644 --- a/disco/endpoints/run.py +++ b/disco/endpoints/run.py @@ -4,11 +4,12 @@ from datetime import datetime from typing import Annotated -from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException from fastapi.exceptions import RequestValidationError from pydantic import BaseModel, Field, ValidationError from pydantic_core import InitErrorDetails, PydanticCustomError from sqlalchemy.orm.session import Session as DBSession +from sse_starlette import ServerSentEvent from sse_starlette.sse import EventSourceResponse from disco.auth import get_api_key, get_api_key_wo_tx @@ -130,7 +131,9 @@ def run_post( dependencies=[Depends(get_api_key_wo_tx)], ) async def run_output_get( - project_name: str, run_number: int, after: datetime | None = None + project_name: str, + run_number: int, + last_event_id: Annotated[str | None, Header()] = None, ): with Session() as dbsession: with dbsession.begin(): @@ -141,6 +144,11 @@ async def run_output_get( if run is None: raise HTTPException(status_code=404) source = f"RUN_{run.id}" + after = None + if last_event_id is not None: + output = commandoutputs.get_by_id(dbsession, last_event_id) + if output is not None: + after = output.created # TODO refactor, this is copy-pasted from deployment output async def get_run_output(source: str, after: datetime | None): @@ -150,13 +158,22 @@ async def get_run_output(source: str, after: datetime | None): output = commandoutputs.get_next(dbsession, source, after=after) if output is not None: if output.text is None: + yield ServerSentEvent( + id=output.id, + event="end", + data="", + ) return after = output.created - yield json.dumps( - { - "timestamp": output.created.isoformat(), - "text": output.text, - } + yield ServerSentEvent( + id=output.id, + event="output", + data=json.dumps( + { + "timestamp": output.created.isoformat(), + "text": output.text, + } + ), ) if output is None: await asyncio.sleep(0.1) diff --git a/disco/utils/commandoutputs.py b/disco/utils/commandoutputs.py index 2c05c83..18220cf 100644 --- a/disco/utils/commandoutputs.py +++ b/disco/utils/commandoutputs.py @@ -24,3 +24,7 @@ def get_next( def delete_output_for_source(dbsession: DBSession, source: str) -> None: dbsession.query(CommandOutput).filter(CommandOutput.source == source).delete() + + +def get_by_id(dbsession, output_id) -> CommandOutput | None: + return dbsession.query(CommandOutput).get(output_id)