feat: DIA-1018: Output streaming #90

Merged
merged 16 commits on Apr 9, 2024
2 changes: 2 additions & 0 deletions docker-compose.yml
@@ -31,6 +31,7 @@ services:
redis:
condition: service_healthy
environment:
- KAFKA_BOOTSTRAP_SERVERS=kafka:9093
- REDIS_URL=redis://redis:6379/0
command:
["poetry", "run", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
@@ -47,6 +48,7 @@ services:
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
- AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN}
- MODULE_NAME=process_file.app
- KAFKA_BOOTSTRAP_SERVERS=kafka:9093
command:
'sh -c "cd tasks && poetry run celery -A $$MODULE_NAME worker --loglevel=info"'
redis:
23 changes: 18 additions & 5 deletions server/app.py
@@ -19,8 +19,13 @@

from log_middleware import LogMiddleware
from tasks.process_file import app as celery_app
from tasks.process_file import process_file, process_file_streaming
from utils import get_input_topic
from tasks.process_file import (
process_file,
process_file_streaming,
process_streaming_output,
)
from utils import get_input_topic, ResultHandler, Settings


logger = logging.getLogger(__name__)

@@ -147,6 +152,7 @@ class SubmitStreamingRequest(BaseModel):
"""

agent: Agent
result_handler: str
task_name: str = "process_file_streaming"


@@ -204,10 +210,17 @@ async def submit_streaming(request: SubmitStreamingRequest):
serialized_agent = pickle.dumps(request.agent)

logger.info(f"Submitting task {task.name} with agent {serialized_agent}")
result = task.delay(serialized_agent=serialized_agent)
print(f"Task {task.name} submitted with job_id {result.id}")
input_result = task.delay(serialized_agent=serialized_agent)
input_job_id = input_result.id
logger.info(f"Task {task.name} submitted with job_id {input_job_id}")

return Response[JobCreated](data=JobCreated(job_id=result.id))
task = process_streaming_output
logger.info(f"Submitting task {task.name}")
output_result = task.delay(job_id=input_job_id, result_handler=request.result_handler)
output_job_id = output_result.id
logger.info(f"Task {task.name} submitted with job_id {output_job_id}")

return Response[JobCreated](data=JobCreated(job_id=input_job_id))


@app.post("/jobs/submit-batch", response_model=Response)
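For reference, a minimal sketch of how a client might drive the new submit_streaming flow above. The route path, port, and agent payload are assumptions (none of them appear in this diff); result_handler and task_name come from SubmitStreamingRequest, and the endpoint returns only the input job's id while the output-polling task runs alongside it.

import requests

# Hypothetical request -- the URL and the agent body are placeholders, not taken from this PR.
payload = {
    "agent": {"skills": []},                # placeholder; the real Agent schema is not shown here
    "result_handler": "DUMMY",              # must match a name defined on utils.ResultHandler
    "task_name": "process_file_streaming",
}
resp = requests.post("http://localhost:8000/jobs/submit-streaming", json=payload)
print(resp.json())                          # data.job_id is the *input* job id (per Response[JobCreated] above)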
63 changes: 61 additions & 2 deletions server/tasks/process_file.py
@@ -1,9 +1,13 @@
import asyncio
import json
import pickle
import os
import logging
from celery import Celery
from server.utils import get_input_topic, get_output_topic

from aiokafka import AIOKafkaConsumer
from celery import Celery, states
from celery.exceptions import Ignore
from server.utils import get_input_topic, get_output_topic, ResultHandler, Settings


logger = logging.getLogger(__name__)
@@ -38,3 +42,58 @@ def process_file_streaming(self, serialized_agent: bytes):

# Run the agent
asyncio.run(agent.arun())


async def async_process_streaming_output(input_job_id: str, result_handler: str, batch_size: int):
logger.info(f"Polling for results {input_job_id=}")

try:
result_handler = ResultHandler.__dict__[result_handler]
except KeyError as e:
logger.error(f"{result_handler} is not a valid ResultHandler")
raise e

topic = get_output_topic(input_job_id)
settings = Settings()

consumer = AIOKafkaConsumer(
topic,
bootstrap_servers=settings.kafka_bootstrap_servers,
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
auto_offset_reset="earliest",
)
await consumer.start()
logger.info(f"consumer started {input_job_id=}")

input_job_running = True

while input_job_running:
try:
data = await consumer.getmany(timeout_ms=3000, max_records=batch_size)
for tp, messages in data.items():
if messages:
result_handler(messages)
else:
logger.info(f"No messages in topic {tp.topic}")
finally:
job = process_file_streaming.AsyncResult(input_job_id)
if job.status in ['SUCCESS', 'FAILURE', 'REVOKED']:
input_job_running = False
logger.info(f"Input job done, stopping output job")
else:
logger.info(f"Input job still running, keeping output job running")

await consumer.stop()


@app.task(name="process_streaming_output", track_started=True, bind=True)
def process_streaming_output(self, job_id: str, result_handler: str, batch_size: int = 2):
try:
asyncio.run(async_process_streaming_output(job_id, result_handler, batch_size))
except KeyError:
# Set own status to failure
self.update_state(state=states.FAILURE)

# Ignore the task so no other state is recorded
# TODO check if this updates state to Ignored, or keeps Failed
raise Ignore()
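Each non-empty batch from consumer.getmany() is passed straight to the selected handler, so a handler receives a list of aiokafka ConsumerRecord objects whose value has already been JSON-decoded. A hypothetical handler following that contract (not part of this PR) could look like:

import logging

logger = logging.getLogger(__name__)


def log_offsets_handler(batch):
    # batch: the list of ConsumerRecord objects for one topic-partition,
    # as collected by getmany() in async_process_streaming_output above
    for record in batch:
        # record.value is already a dict thanks to the JSON value_deserializer
        logger.info("offset=%s value=%s", record.offset, record.value)

To be selectable by name it would also need an entry on ResultHandler in server/utils.py (e.g. LOG_OFFSETS = log_offsets_handler).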
42 changes: 42 additions & 0 deletions server/utils.py
@@ -0,0 +1,42 @@
import logging

from enum import Enum
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import List, Union

logger = logging.getLogger(__name__)


class Settings(BaseSettings):
"""
Can hardcode settings here, read from env file, or pass as env vars
https://docs.pydantic.dev/latest/concepts/pydantic_settings/#field-value-priority
"""

kafka_bootstrap_servers: Union[str, List[str]]

model_config = SettingsConfigDict(
env_file=".env",
)


def dummy_handler(batch):
"""
Dummy handler to test streaming output flow
Can delete once we have a real handler
"""

logger.info(f"\n\nHandler received batch: {batch}\n\n")


class ResultHandler(Enum):
DUMMY = dummy_handler


def get_input_topic(job_id: str):
return f"adala-input-{job_id}"


def get_output_topic(job_id: str):
return f"adala-output-{job_id}"
