feat: DIA-1018: Output streaming (#90)
Co-authored-by: Matt Bernstein <[email protected]>
hakan458 and matt-bernstein authored Apr 9, 2024
Parent: 9b7ff60 · Commit: 41aacaf
Showing 3 changed files with 86 additions and 8 deletions.
docker-compose.yml (2 changes: 1 addition & 1 deletion)
@@ -31,8 +31,8 @@ services:
redis:
condition: service_healthy
environment:
- REDIS_URL=redis://redis:6379/0
- KAFKA_BOOTSTRAP_SERVERS=kafka:9093
- REDIS_URL=redis://redis:6379/0
command:
["poetry", "run", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
worker:
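The KAFKA_BOOTSTRAP_SERVERS variable added here is presumably consumed by the Settings object imported in server/app.py and server/tasks/process_file.py below; its definition is not part of this diff. A minimal sketch of what such a settings class could look like, assuming a pydantic-style BaseSettings whose field names map to these environment variables:

# Hypothetical sketch of server/utils.Settings; not part of this commit.
from pydantic import BaseSettings

class Settings(BaseSettings):
    # Populated from the environment, e.g. KAFKA_BOOTSTRAP_SERVERS=kafka:9093
    kafka_bootstrap_servers: str = "localhost:9092"
    redis_url: str = "redis://localhost:6379/0"
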
server/app.py (25 changes: 20 additions & 5 deletions)
@@ -19,8 +19,13 @@

from log_middleware import LogMiddleware
from tasks.process_file import app as celery_app
from tasks.process_file import process_file, process_file_streaming
from utils import get_input_topic
from tasks.process_file import (
process_file,
process_file_streaming,
process_streaming_output,
)
from utils import get_input_topic, ResultHandler, Settings


logger = logging.getLogger(__name__)

@@ -147,6 +152,7 @@ class SubmitStreamingRequest(BaseModel):
"""

agent: Agent
result_handler: str
task_name: str = "process_file_streaming"


@@ -204,10 +210,19 @@ async def submit_streaming(request: SubmitStreamingRequest):
serialized_agent = pickle.dumps(request.agent)

logger.info(f"Submitting task {task.name} with agent {serialized_agent}")
result = task.delay(serialized_agent=serialized_agent)
print(f"Task {task.name} submitted with job_id {result.id}")
input_result = task.delay(serialized_agent=serialized_agent)
input_job_id = input_result.id
logger.info(f"Task {task.name} submitted with job_id {input_job_id}")

task = process_streaming_output
logger.info(f"Submitting task {task.name}")
output_result = task.delay(
job_id=input_job_id, result_handler=request.result_handler
)
output_job_id = output_result.id
logger.info(f"Task {task.name} submitted with job_id {output_job_id}")

return Response[JobCreated](data=JobCreated(job_id=result.id))
return Response[JobCreated](data=JobCreated(job_id=input_job_id))


@app.post("/jobs/submit-batch", response_model=Response)
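With this change, submitting a streaming job schedules two linked Celery tasks: process_file_streaming runs the agent against the input topic, and process_streaming_output polls the matching output topic and hands batches of results to the named result handler. Only the input job's id is returned to the caller; the output job's id is only logged. A hedged client-side sketch, assuming the submit_streaming route is /jobs/submit-streaming (its decorator is outside this hunk) and using placeholder payload values:

# Hypothetical client call; the route path, handler name, and agent payload are assumptions.
import requests

agent_dict = {}  # stand-in for a serialized Agent definition
payload = {
    "agent": agent_dict,
    "result_handler": "DummyHandler",  # must name an attribute of server.utils.ResultHandler
}
resp = requests.post("http://localhost:8000/jobs/submit-streaming", json=payload)
job_id = resp.json()["data"]["job_id"]  # id of the input job
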
server/tasks/process_file.py (67 changes: 65 additions & 2 deletions)
@@ -1,9 +1,13 @@
import asyncio
import json
import pickle
import os
import logging
from celery import Celery
from server.utils import get_input_topic, get_output_topic

from aiokafka import AIOKafkaConsumer
from celery import Celery, states
from celery.exceptions import Ignore
from server.utils import get_input_topic, get_output_topic, ResultHandler, Settings


logger = logging.getLogger(__name__)
@@ -38,3 +42,62 @@ def process_file_streaming(self, serialized_agent: bytes):

# Run the agent
asyncio.run(agent.arun())


async def async_process_streaming_output(
input_job_id: str, result_handler: str, batch_size: int
):
logger.info(f"Polling for results {input_job_id=}")

try:
result_handler = ResultHandler.__dict__[result_handler]
except KeyError as e:
logger.error(f"{result_handler} is not a valid ResultHandler")
raise e

topic = get_output_topic(input_job_id)
settings = Settings()

consumer = AIOKafkaConsumer(
topic,
bootstrap_servers=settings.kafka_bootstrap_servers,
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
auto_offset_reset="earliest",
)
await consumer.start()
logger.info(f"consumer started {input_job_id=}")

input_job_running = True

while input_job_running:
try:
data = await consumer.getmany(timeout_ms=3000, max_records=batch_size)
for tp, messages in data.items():
if messages:
result_handler(messages)
else:
logger.info(f"No messages in topic {tp.topic}")
finally:
job = process_file_streaming.AsyncResult(input_job_id)
if job.status in ["SUCCESS", "FAILURE", "REVOKED"]:
input_job_running = False
logger.info(f"Input job done, stopping output job")
else:
logger.info(f"Input job still running, keeping output job running")

await consumer.stop()


@app.task(name="process_streaming_output", track_started=True, bind=True)
def process_streaming_output(
self, job_id: str, result_handler: str, batch_size: int = 2
):
try:
asyncio.run(async_process_streaming_output(job_id, result_handler, batch_size))
except KeyError:
# Set own status to failure
self.update_state(state=states.FAILURE)

# Ignore the task so no other state is recorded
# TODO check if this updates state to Ignored, or keeps Failed
raise Ignore()
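
async_process_streaming_output resolves the handler by name through ResultHandler.__dict__ and then calls it directly with a batch of consumer records, so whatever is registered there must be callable as result_handler(messages). ResultHandler itself is not shown in this commit; a minimal sketch of a compatible handler, with all names hypothetical (using @staticmethod this way assumes Python 3.10+, where staticmethod objects are callable):

# Hypothetical sketch of a handler on server/utils.ResultHandler; not part of this commit.
import logging

logger = logging.getLogger(__name__)

class ResultHandler:
    @staticmethod
    def DummyHandler(messages):
        # messages are aiokafka ConsumerRecords whose values were JSON-decoded
        # by the consumer's value_deserializer.
        for message in messages:
            logger.info("Received record: %s", message.value)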
