feat: DIA-977: Get streaming job status (#97)
Co-authored-by: Matt Bernstein <[email protected]>
Co-authored-by: pakelley <[email protected]>
3 people authored Apr 24, 2024
1 parent 589f868 commit c3ddb83
Showing 5 changed files with 164 additions and 39 deletions.
1 change: 1 addition & 0 deletions docker-compose.yml
@@ -48,6 +48,7 @@ services:
- AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN}
- MODULE_NAME=process_file.app
- KAFKA_BOOTSTRAP_SERVERS=kafka:9093
- C_FORCE_ROOT=true # needed when using the pickle serializer in Celery while running as root - remove when we don't run as root
command:
'sh -c "cd tasks && poetry run celery -A $$MODULE_NAME worker --loglevel=info"'
redis:
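The new C_FORCE_ROOT flag is tied to the serializer="pickle" setting on the worker tasks below: Celery refuses to start a root worker that accepts pickled task payloads unless this variable is set. A minimal sketch of the kind of Celery configuration this implies (the broker/backend URLs and app name here are assumptions, not taken from this diff):

# Hypothetical Celery app configuration using the pickle serializer.
# Broker/backend URLs are placeholders; only the serializer settings matter here.
from celery import Celery

app = Celery(
    "process_file",
    broker="redis://redis:6379/0",   # assumed; the compose file defines a redis service
    backend="redis://redis:6379/0",  # assumed
)
app.conf.update(
    task_serializer="pickle",
    accept_content=["pickle", "json"],
    result_serializer="pickle",
)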
3 changes: 1 addition & 2 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ fastapi = "^0.104.1"
celery = {version = "^5.3.6", extras = ["redis"]}
uvicorn = "*"
pydantic-settings = "^2.2.1"
label-studio-sdk = "^0.0.32"

[tool.poetry.dev-dependencies]
pytest = "^7.4.3"
62 changes: 44 additions & 18 deletions server/app.py
@@ -23,6 +23,7 @@
process_file,
process_file_streaming,
process_streaming_output,
streaming_parent_task,
)
from utils import get_input_topic, Settings
from server.handlers.result_handlers import ResultHandler
@@ -207,24 +208,13 @@ async def submit_streaming(request: SubmitStreamingRequest):
Response[JobCreated]: The response model for a job created.
"""

# TODO: get task by name, e.g. request.task_name
task = process_file_streaming
agent = request.agent

logger.info(f"Submitting task {task.name} with agent {agent}")
input_result = task.delay(agent=agent)
input_job_id = input_result.id
logger.debug(f"Task {task.name} submitted with job_id {input_job_id}")

task = process_streaming_output
logger.debug(f"Submitting task {task.name}")
output_result = task.delay(
job_id=input_job_id, result_handler=request.result_handler
task = streaming_parent_task
result = task.apply_async(
kwargs={"agent": request.agent, "result_handler": request.result_handler}
)
output_job_id = output_result.id
logger.debug(f"Task {task.name} submitted with job_id {output_job_id}")
logger.info(f"Submitted {task.name} with ID {result.id}")

return Response[JobCreated](data=JobCreated(job_id=input_job_id))
return Response[JobCreated](data=JobCreated(job_id=result.id))


@app.post("/jobs/submit-batch", response_model=Response)
@@ -256,6 +246,24 @@ async def submit_batch(batch: BatchData):
return Response[BatchSubmitted](data=BatchSubmitted(job_id=batch.job_id))


def aggregate_statuses(input_job_id: str, output_job_id: str):
input_job_status = process_file_streaming.AsyncResult(input_job_id).status
output_job_status = process_streaming_output.AsyncResult(output_job_id).status

statuses = [input_job_status, output_job_status]

if "PENDING" in statuses:
return "PENDING"
if "FAILURE" in statuses:
return "FAILURE"
if "REVOKED" in statuses:
return "REVOKED"
if "STARTED" in statuses or "RETRY" in statuses:
return "STARTED"

return "SUCCESS"


@app.get("/jobs/{job_id}", response_model=Response[JobStatusResponse])
def get_status(job_id):
"""
@@ -275,9 +283,27 @@ def get_status(job_id):
"REVOKED": Status.CANCELED,
"RETRY": Status.INPROGRESS,
}
job = process_file.AsyncResult(job_id)
job = streaming_parent_task.AsyncResult(job_id)
logger.info(f"\n\nParent task meta : {job.info}\n\n")

# If the parent task meta does not contain the input/output job IDs, return FAILED
if (
job.info is None
or "input_job_id" not in job.info
or "output_job_id" not in job.info
):
logger.error(
"Parent task does not contain input job ID and/or output_job_id - unable to return proper status"
)
return Response[JobStatusResponse](data=JobStatusResponse(status=Status.FAILED))

input_job_id = job.info["input_job_id"]
output_job_id = job.info["output_job_id"]

try:
status: Status = celery_status_map[job.status]
status: Status = celery_status_map[
aggregate_statuses(input_job_id, output_job_id)
]
except Exception as e:
logger.error(f"Error getting job status: {e}")
status = Status.FAILED
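Taken together, the changes above mean a client receives only the parent task's ID from POST /jobs/submit-streaming and queries GET /jobs/{job_id} for a single aggregated status: aggregate_statuses reports PENDING first, then FAILURE, then REVOKED, then STARTED/RETRY, and SUCCESS only once both child jobs have finished. A rough client-side polling sketch (the base URL is an assumption, and the exact serialized values of the Status enum are not shown in this diff):

# Hypothetical polling helper for the status endpoint added above.
import time
import requests

BASE_URL = "http://localhost:8000"  # assumed deployment address

def poll_job(job_id: str, attempts: int = 30, delay: float = 2.0) -> None:
    """Print the aggregated status of a streaming job a few times."""
    for _ in range(attempts):
        resp = requests.get(f"{BASE_URL}/jobs/{job_id}")
        resp.raise_for_status()
        # Response[JobStatusResponse] nests the status under "data"
        print(resp.json())
        time.sleep(delay)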
136 changes: 117 additions & 19 deletions server/tasks/process_file.py
@@ -3,10 +3,12 @@
import pickle
import os
import logging
import time

from adala.agents import Agent

from aiokafka import AIOKafkaConsumer
from aiokafka.errors import UnknownTopicOrPartitionError
from celery import Celery, states
from celery.exceptions import Ignore
from server.utils import get_input_topic, get_output_topic, Settings
@@ -23,6 +25,10 @@

@app.task(name="process_file", track_started=True, serializer="pickle")
def process_file(agent: Agent):
# Override kafka_bootstrap_servers with value from settings
settings = Settings()
agent.environment.kafka_bootstrap_servers = settings.kafka_bootstrap_servers

# Read data from a file and send it to the Kafka input topic
asyncio.run(agent.environment.initialize())

@@ -33,35 +39,119 @@ def process_file(agent: Agent):
asyncio.run(agent.environment.finalize())


@app.task(
name="streaming_parent_task", track_started=True, bind=True, serializer="pickle"
)
def streaming_parent_task(
self, agent: Agent, result_handler: ResultHandler, batch_size: int = 2
):
"""
Launch the two tasks that do the real work and store their job IDs as
metadata of this parent task, so that the status of the entire job can be
read from a single task ID.
"""

# Parent job ID is used for input/output topic names
parent_job_id = self.request.id

# Override kafka_bootstrap_servers with value from settings
settings = Settings()
agent.environment.kafka_bootstrap_servers = settings.kafka_bootstrap_servers

inference_task = process_file_streaming
logger.info(f"Submitting task {inference_task.name} with agent {agent}")
input_result = inference_task.delay(agent=agent, parent_job_id=parent_job_id)
input_job_id = input_result.id
logger.info(f"Task {inference_task.name} submitted with job_id {input_job_id}")

result_handler_task = process_streaming_output
logger.info(f"Submitting task {result_handler_task.name}")
output_result = result_handler_task.delay(
input_job_id=input_job_id,
parent_job_id=parent_job_id,
result_handler=result_handler,
batch_size=batch_size,
)
output_job_id = output_result.id
logger.info(
f"Task {result_handler_task.name} submitted with job_id {output_job_id}"
)

# Store input and output job IDs in the parent task's metadata.
# The state must be passed as well, otherwise it's overwritten to None.
self.update_state(
state=states.STARTED,
meta={"input_job_id": input_job_id, "output_job_id": output_job_id},
)

input_job = process_file_streaming.AsyncResult(input_job_id)
output_job = process_streaming_output.AsyncResult(output_job_id)

terminal_statuses = ["SUCCESS", "FAILURE", "REVOKED"]

while (
input_job.status not in terminal_statuses
or output_job.status not in terminal_statuses
):
time.sleep(1)

logger.info("Both input and output jobs complete")

# Update the parent task status to SUCCESS and pass the metadata again,
# otherwise it's overwritten to None.
self.update_state(
state=states.SUCCESS,
meta={"input_job_id": input_job_id, "output_job_id": output_job_id},
)

# This prevents Celery from updating the task's state again, which would wipe out
# the custom metadata we added; the state set above is retained.
raise Ignore()


@app.task(
name="process_file_streaming", track_started=True, bind=True, serializer="pickle"
)
def process_file_streaming(self, agent: Agent):
# Get own job ID to set Consumer topic accordingly
job_id = self.request.id
agent.environment.kafka_input_topic = get_input_topic(job_id)
agent.environment.kafka_output_topic = get_output_topic(job_id)
def process_file_streaming(self, agent: Agent, parent_job_id: str):
# Set input and output topics using parent job ID
agent.environment.kafka_input_topic = get_input_topic(parent_job_id)
agent.environment.kafka_output_topic = get_output_topic(parent_job_id)

# Run the agent
asyncio.run(agent.arun())


async def async_process_streaming_output(
input_job_id: str, result_handler: ResultHandler, batch_size: int
input_job_id: str,
parent_job_id: str,
result_handler: ResultHandler,
batch_size: int,
):
logger.info(f"Polling for results {input_job_id=}")
logger.info(f"Polling for results {parent_job_id=}")

topic = get_output_topic(input_job_id)
topic = get_output_topic(parent_job_id)
settings = Settings()

consumer = AIOKafkaConsumer(
topic,
bootstrap_servers=settings.kafka_bootstrap_servers,
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
auto_offset_reset="earliest",
)
await consumer.start()
logger.info(f"consumer started {input_job_id=}")
# Retry to work around a race condition in topic creation
retries = 5
while retries > 0:
try:
consumer = AIOKafkaConsumer(
topic,
bootstrap_servers=settings.kafka_bootstrap_servers,
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
auto_offset_reset="earliest",
)
await consumer.start()
logger.info(f"consumer started {parent_job_id=}")
break
except UnknownTopicOrPartitionError as e:
logger.error(msg=e)
logger.info(f"Retrying to create consumer with topic {topic}")

await consumer.stop()
retries -= 1
time.sleep(1)

input_job_running = True

Expand All @@ -78,6 +168,7 @@ async def async_process_streaming_output(
)
else:
logger.debug(f"No messages in topic {tp.topic}")

if not data:
logger.info(f"No messages in any topic")
finally:
@@ -96,16 +187,23 @@ async def async_process_streaming_output(
name="process_streaming_output", track_started=True, bind=True, serializer="pickle"
)
def process_streaming_output(
self, job_id: str, result_handler: ResultHandler, batch_size: int = 2
self,
input_job_id: str,
parent_job_id: str,
result_handler: ResultHandler,
batch_size: int,
):
try:
asyncio.run(async_process_streaming_output(job_id, result_handler, batch_size))
asyncio.run(
async_process_streaming_output(
input_job_id, parent_job_id, result_handler, batch_size
)
)
except Exception as e:
# Set own status to failure
self.update_state(state=states.FAILURE)

logger.error(msg=e)

# Ignore the task so no other state is recorded
# TODO check if this updates state to Ignored, or keeps Failed
raise Ignore()

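The core pattern introduced here is a supervisor task: streaming_parent_task fans out the two worker tasks, records their IDs via self.update_state(meta=...), and raises Ignore() so Celery does not overwrite that metadata, which is what lets get_status in server/app.py recover both child IDs from a single parent ID. A condensed, self-contained sketch of the same mechanics; all task names, the broker URL, and the reader function below are illustrative, not part of this commit:

# Standalone sketch of the parent-task-metadata pattern used above.
from celery import Celery, states
from celery.exceptions import Ignore

app = Celery("sketch", broker="redis://localhost:6379/0", backend="redis://localhost:6379/0")

@app.task(bind=True)
def child_task(self, n: int) -> int:
    return n * 2

@app.task(bind=True)
def parent_task(self, n: int):
    a = child_task.delay(n)
    b = child_task.delay(n + 1)
    # Persist the child job IDs on the parent so one ID is enough to query everything.
    self.update_state(state=states.SUCCESS, meta={"input_job_id": a.id, "output_job_id": b.id})
    # Prevent Celery from recording a new result that would clobber the meta above.
    raise Ignore()

def read_child_ids(parent_id: str) -> dict:
    # AsyncResult.info returns the meta dict stored via update_state.
    return parent_task.AsyncResult(parent_id).info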