Skip to content

Commit

Permalink
Add event type and ID to server sent events
Browse files Browse the repository at this point in the history
To let client know when it's done, and also reconnect
when the connection breaks.
  • Loading branch information
antoineleclair committed Apr 8, 2024
1 parent dee4052 commit 77a6585
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 15 deletions.
31 changes: 24 additions & 7 deletions disco/endpoints/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from datetime import datetime
from typing import Annotated

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm.session import Session as DBSession
from sse_starlette import ServerSentEvent
from sse_starlette.sse import EventSourceResponse

from disco.auth import get_api_key, get_api_key_wo_tx
Expand Down Expand Up @@ -106,7 +107,9 @@ def deployments_post(
dependencies=[Depends(get_api_key_wo_tx)],
)
async def deployment_output_get(
project_name: str, deployment_number: int, after: datetime | None = None
project_name: str,
deployment_number: int,
last_event_id: Annotated[str | None, Header()] = None,
):
with Session() as dbsession:
with dbsession.begin():
Expand All @@ -122,6 +125,11 @@ async def deployment_output_get(
if deployment is None:
raise HTTPException(status_code=404)
source = f"DEPLOYMENT_{deployment.id}"
after = None
if last_event_id is not None:
output = commandoutputs.get_by_id(dbsession, last_event_id)
if output is not None:
after = output.created

async def get_build_output(source: str, after: datetime | None):
while True:
Expand All @@ -130,13 +138,22 @@ async def get_build_output(source: str, after: datetime | None):
output = commandoutputs.get_next(dbsession, source, after=after)
if output is not None:
if output.text is None:
yield ServerSentEvent(
id=output.id,
event="end",
data="",
)
return
after = output.created
yield json.dumps(
{
"timestamp": output.created.isoformat(),
"text": output.text,
}
yield ServerSentEvent(
id=output.id,
event="output",
data=json.dumps(
{
"timestamp": output.created.isoformat(),
"text": output.text,
}
),
)
if output is None:
await asyncio.sleep(0.1)
Expand Down
6 changes: 5 additions & 1 deletion disco/endpoints/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random

from fastapi import APIRouter, Depends, HTTPException
from sse_starlette import ServerSentEvent
from sse_starlette.sse import EventSourceResponse

from disco.auth import get_api_key_wo_tx
Expand Down Expand Up @@ -62,7 +63,10 @@ async def read_logs(project_name: str | None, service_name: str | None):
try:
while True:
log_obj = await log_queue.get()
yield json.dumps(log_obj)
yield ServerSentEvent(
event="output",
data=json.dumps(log_obj),
)
finally:
try:
await start_logspout_process.wait()
Expand Down
31 changes: 24 additions & 7 deletions disco/endpoints/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from datetime import datetime
from typing import Annotated

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Field, ValidationError
from pydantic_core import InitErrorDetails, PydanticCustomError
from sqlalchemy.orm.session import Session as DBSession
from sse_starlette import ServerSentEvent
from sse_starlette.sse import EventSourceResponse

from disco.auth import get_api_key, get_api_key_wo_tx
Expand Down Expand Up @@ -130,7 +131,9 @@ def run_post(
dependencies=[Depends(get_api_key_wo_tx)],
)
async def run_output_get(
project_name: str, run_number: int, after: datetime | None = None
project_name: str,
run_number: int,
last_event_id: Annotated[str | None, Header()] = None,
):
with Session() as dbsession:
with dbsession.begin():
Expand All @@ -141,6 +144,11 @@ async def run_output_get(
if run is None:
raise HTTPException(status_code=404)
source = f"RUN_{run.id}"
after = None
if last_event_id is not None:
output = commandoutputs.get_by_id(dbsession, last_event_id)
if output is not None:
after = output.created

# TODO refactor, this is copy-pasted from deployment output
async def get_run_output(source: str, after: datetime | None):
Expand All @@ -150,13 +158,22 @@ async def get_run_output(source: str, after: datetime | None):
output = commandoutputs.get_next(dbsession, source, after=after)
if output is not None:
if output.text is None:
yield ServerSentEvent(
id=output.id,
event="end",
data="",
)
return
after = output.created
yield json.dumps(
{
"timestamp": output.created.isoformat(),
"text": output.text,
}
yield ServerSentEvent(
id=output.id,
event="output",
data=json.dumps(
{
"timestamp": output.created.isoformat(),
"text": output.text,
}
),
)
if output is None:
await asyncio.sleep(0.1)
Expand Down
4 changes: 4 additions & 0 deletions disco/utils/commandoutputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ def get_next(

def delete_output_for_source(dbsession: DBSession, source: str) -> None:
dbsession.query(CommandOutput).filter(CommandOutput.source == source).delete()


def get_by_id(dbsession, output_id) -> CommandOutput | None:
return dbsession.query(CommandOutput).get(output_id)

0 comments on commit 77a6585

Please sign in to comment.