
feat: DIA-977: Get streaming job status #97

Merged · 31 commits · Apr 24, 2024

Commits
d31318c
feat: DIA-953: Stream results from adala inference server into LSE (A…
matt-bernstein Apr 10, 2024
da620e0
Merge remote-tracking branch 'origin/master' into fb-dia-953/stream-r…
matt-bernstein Apr 12, 2024
dc07ac9
initial solution using pickle/pydantic
matt-bernstein Apr 16, 2024
fd2e40e
polymorphism for resulthandlers like in rest of adala
matt-bernstein Apr 17, 2024
73887a9
enable pickling of agents and handlers
matt-bernstein Apr 17, 2024
1f5ec71
black
matt-bernstein Apr 17, 2024
c286862
add label studio sdk
matt-bernstein Apr 17, 2024
42c19ee
dedup settings object
matt-bernstein Apr 17, 2024
a956784
pass in job id
matt-bernstein Apr 17, 2024
e1c2748
make kafka env base class usable
matt-bernstein Apr 17, 2024
466e299
fix serialization
matt-bernstein Apr 19, 2024
fc6e6d8
fix sdk client headers
matt-bernstein Apr 19, 2024
58c686a
black
matt-bernstein Apr 19, 2024
825ce54
revert timeout
matt-bernstein Apr 19, 2024
4812cb1
update LSE api
matt-bernstein Apr 19, 2024
082af3f
feat: DIA-977: Get status of streaming job
Apr 22, 2024
8dc5bf7
Merge branch 'master' of github.com:HumanSignal/Adala into fb-dia-977
Apr 22, 2024
090d151
Merge branch 'fb-dia-953/stream-results' of github.com:HumanSignal/Ad…
Apr 22, 2024
0a403f5
rename task variables
Apr 22, 2024
b5a7afc
add comments
Apr 22, 2024
963d180
Update docker-compose.yml
pakelley Apr 22, 2024
e619a3d
bugfix for nonexistent job id
matt-bernstein Apr 23, 2024
d37523f
replace original get status implementation
Apr 24, 2024
e2fec3a
Merge branch 'fb-dia-977' of github.com:HumanSignal/Adala into fb-dia…
Apr 24, 2024
344d123
Merge branch 'master' of github.com:HumanSignal/Adala into fb-dia-977
Apr 24, 2024
7a68c42
remove merge conflict bs
Apr 24, 2024
8478b89
cleanup
Apr 24, 2024
7da1ff8
add comment + format
Apr 24, 2024
3fb7f76
use settings.kafka... for agent
Apr 24, 2024
1ea6c29
add retry logic around consumer creation race condition
Apr 24, 2024
012376c
break out of retry loop when successful
Apr 24, 2024
1 change: 1 addition & 0 deletions docker-compose.yml
@@ -48,6 +48,7 @@ services:
- AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN}
- MODULE_NAME=process_file.app
- KAFKA_BOOTSTRAP_SERVERS=kafka:9093
- C_FORCE_ROOT=true # needed when using pickle serializer in celery + running as root - remove when we dont run as root
command:
'sh -c "cd tasks && poetry run celery -A $$MODULE_NAME worker --loglevel=info"'
redis:
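For context on the C_FORCE_ROOT line above: Celery refuses to run a root worker that accepts pickled content unless that variable is set. A minimal sketch of the kind of Celery configuration this flag pairs with (app name and broker URL are placeholders, not taken from this repo):

```python
# Illustrative only: a Celery app that accepts pickled task arguments.
from celery import Celery

app = Celery("tasks", broker="redis://redis:6379/0", backend="redis://redis:6379/0")
app.conf.update(
    task_serializer="pickle",           # matches serializer="pickle" on the tasks in this PR
    accept_content=["json", "pickle"],  # pickle must be whitelisted explicitly
    result_serializer="pickle",
)
# With pickle accepted, Celery aborts startup for a worker running as root unless
# the environment sets C_FORCE_ROOT=true - hence the docker-compose addition.
```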
3 changes: 1 addition & 2 deletions poetry.lock


1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ fastapi = "^0.104.1"
celery = {version = "^5.3.6", extras = ["redis"]}
uvicorn = "*"
pydantic-settings = "^2.2.1"
label-studio-sdk = "^0.0.32"

[tool.poetry.dev-dependencies]
pytest = "^7.4.3"
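The label-studio-sdk dependency added here backs the result handler that streams records into LSE (see the DIA-953 commits above). A hedged sketch of the SDK surface being pulled in; the URL, API key, project ID, and prediction payload are placeholders, and the actual calls live in server/handlers/result_handlers.py and may differ:

```python
# Hedged sketch of label-studio-sdk usage; all values below are placeholders.
from label_studio_sdk import Client

client = Client(url="http://localhost:8080", api_key="<LS_API_TOKEN>")
client.check_connection()         # verify the Label Studio instance is reachable
project = client.get_project(1)   # project that receives the streamed results
project.create_predictions(
    [{"task": 1, "result": [], "model_version": "adala"}]
)
```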
62 changes: 44 additions & 18 deletions server/app.py
@@ -23,6 +23,7 @@
process_file,
process_file_streaming,
process_streaming_output,
streaming_parent_task,
)
from utils import get_input_topic, Settings
from server.handlers.result_handlers import ResultHandler
@@ -207,24 +208,13 @@ async def submit_streaming(request: SubmitStreamingRequest):
Response[JobCreated]: The response model for a job created.
"""

# TODO: get task by name, e.g. request.task_name
- task = process_file_streaming
- agent = request.agent

- logger.info(f"Submitting task {task.name} with agent {agent}")
- input_result = task.delay(agent=agent)
- input_job_id = input_result.id
- logger.debug(f"Task {task.name} submitted with job_id {input_job_id}")

- task = process_streaming_output
- logger.debug(f"Submitting task {task.name}")
- output_result = task.delay(
- job_id=input_job_id, result_handler=request.result_handler
+ task = streaming_parent_task
+ result = task.apply_async(
+ kwargs={"agent": request.agent, "result_handler": request.result_handler}
)
- output_job_id = output_result.id
- logger.debug(f"Task {task.name} submitted with job_id {output_job_id}")
+ logger.info(f"Submitted {task.name} with ID {result.id}")

- return Response[JobCreated](data=JobCreated(job_id=input_job_id))
+ return Response[JobCreated](data=JobCreated(job_id=result.id))


@app.post("/jobs/submit-batch", response_model=Response)
@@ -256,6 +246,24 @@ async def submit_batch(batch: BatchData):
return Response[BatchSubmitted](data=BatchSubmitted(job_id=batch.job_id))


def aggregate_statuses(input_job_id: str, output_job_id: str):
input_job_status = process_file_streaming.AsyncResult(input_job_id).status
output_job_status = process_streaming_output.AsyncResult(output_job_id).status

statuses = [input_job_status, output_job_status]

if "PENDING" in statuses:
return "PENDING"
if "FAILURE" in statuses:
return "FAILURE"
if "REVOKED" in statuses:
return "REVOKED"
if "STARTED" in statuses or "RETRY" in statuses:
return "STARTED"

return "SUCCESS"


@app.get("/jobs/{job_id}", response_model=Response[JobStatusResponse])
def get_status(job_id):
"""
@@ -275,9 +283,27 @@ def get_status(job_id):
"REVOKED": Status.CANCELED,
"RETRY": Status.INPROGRESS,
}
- job = process_file.AsyncResult(job_id)
+ job = streaming_parent_task.AsyncResult(job_id)
logger.info(f"\n\nParent task meta : {job.info}\n\n")

# If parent task meta does not contain input/output job IDs - return FAILED
if (
job.info is None
or "input_job_id" not in job.info
or "output_job_id" not in job.info
):
logger.error(
"Parent task does not contain input job ID and/or output_job_id - unable to return proper status"
)
return Response[JobStatusResponse](data=JobStatusResponse(status=Status.FAILED))

input_job_id = job.info["input_job_id"]
output_job_id = job.info["output_job_id"]

try:
- status: Status = celery_status_map[job.status]
+ status: Status = celery_status_map[
+ aggregate_statuses(input_job_id, output_job_id)
+ ]
except Exception as e:
logger.error(f"Error getting job status: {e}")
status = Status.FAILED
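With the parent task now owning both child job IDs, a client only needs the single job ID returned by /jobs/submit-streaming and can poll this endpoint until a terminal state. A minimal polling sketch; the base URL and the serialized Status strings are assumptions (only CANCELED, FAILED, and INPROGRESS are visible in this hunk), so adjust them to the actual Status enum:

```python
# Hedged sketch of a client polling GET /jobs/{job_id}; the base URL and the
# exact Status wire values are assumptions, not taken from this repo.
import time
import requests

BASE_URL = "http://localhost:8000"
TERMINAL = {"completed", "failed", "canceled"}  # assumed serialized Status values

def wait_for_job(job_id: str, poll_interval: float = 2.0) -> str:
    while True:
        resp = requests.get(f"{BASE_URL}/jobs/{job_id}", timeout=10)
        resp.raise_for_status()
        # Response[JobStatusResponse] wraps the payload in a "data" field,
        # mirroring Response[JobStatusResponse](data=JobStatusResponse(...)).
        status = resp.json()["data"]["status"]
        if str(status).lower() in TERMINAL:
            return status
        time.sleep(poll_interval)
```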
136 changes: 117 additions & 19 deletions server/tasks/process_file.py
@@ -3,10 +3,12 @@
import pickle
import os
import logging
import time

from adala.agents import Agent

from aiokafka import AIOKafkaConsumer
from aiokafka.errors import UnknownTopicOrPartitionError
from celery import Celery, states
from celery.exceptions import Ignore
from server.utils import get_input_topic, get_output_topic, Settings
@@ -23,6 +25,10 @@

@app.task(name="process_file", track_started=True, serializer="pickle")
def process_file(agent: Agent):
# Override kafka_bootstrap_servers with value from settings
settings = Settings()
agent.environment.kafka_bootstrap_servers = settings.kafka_bootstrap_servers

# # Read data from a file and send it to the Kafka input topic
asyncio.run(agent.environment.initialize())

@@ -33,35 +39,119 @@ def process_file(agent: Agent):
asyncio.run(agent.environment.finalize())


@app.task(
name="streaming_parent_task", track_started=True, bind=True, serializer="pickle"
)
def streaming_parent_task(
self, agent: Agent, result_handler: ResultHandler, batch_size: int = 2
):
"""
This task is used to launch the two tasks that are doing the real work, so that
we store those two job IDs as metadata of this parent task, and be able to get
the status of the entire job from one task ID
"""

# Parent job ID is used for input/output topic names
parent_job_id = self.request.id

# Override kafka_bootstrap_servers with value from settings
settings = Settings()
agent.environment.kafka_bootstrap_servers = settings.kafka_bootstrap_servers

inference_task = process_file_streaming
logger.info(f"Submitting task {inference_task.name} with agent {agent}")
input_result = inference_task.delay(agent=agent, parent_job_id=parent_job_id)
input_job_id = input_result.id
logger.info(f"Task {inference_task.name} submitted with job_id {input_job_id}")

result_handler_task = process_streaming_output
logger.info(f"Submitting task {result_handler_task.name}")
output_result = result_handler_task.delay(
input_job_id=input_job_id,
parent_job_id=parent_job_id,
result_handler=result_handler,
batch_size=batch_size,
)
output_job_id = output_result.id
logger.info(
f"Task {result_handler_task.name} submitted with job_id {output_job_id}"
)

# Store input and output job IDs in parent task metadata
# Need to pass state as well otherwise its overwritten to None
self.update_state(
state=states.STARTED,
meta={"input_job_id": input_job_id, "output_job_id": output_job_id},
)

input_job = process_file_streaming.AsyncResult(input_job_id)
output_job = process_streaming_output.AsyncResult(output_job_id)

terminal_statuses = ["SUCCESS", "FAILURE", "REVOKED"]

while (
input_job.status not in terminal_statuses
or output_job.status not in terminal_statuses
):
time.sleep(1)

logger.info("Both input and output jobs complete")

# Update parent task status to SUCCESS and pass metadata again
# otherwise its overwritten to None
self.update_state(
state=states.SUCCESS,
meta={"input_job_id": input_job_id, "output_job_id": output_job_id},
)

# This makes it so Celery doesnt update the tasks state again, which would wipe out the custom metadata we added
# It will retain that state we set above
raise Ignore()


@app.task(
name="process_file_streaming", track_started=True, bind=True, serializer="pickle"
)
- def process_file_streaming(self, agent: Agent):
- # Get own job ID to set Consumer topic accordingly
- job_id = self.request.id
- agent.environment.kafka_input_topic = get_input_topic(job_id)
- agent.environment.kafka_output_topic = get_output_topic(job_id)
+ def process_file_streaming(self, agent: Agent, parent_job_id: str):
+ # Set input and output topics using parent job ID
+ agent.environment.kafka_input_topic = get_input_topic(parent_job_id)
+ agent.environment.kafka_output_topic = get_output_topic(parent_job_id)

# Run the agent
asyncio.run(agent.arun())


async def async_process_streaming_output(
- input_job_id: str, result_handler: ResultHandler, batch_size: int
+ input_job_id: str,
+ parent_job_id: str,
+ result_handler: ResultHandler,
+ batch_size: int,
):
logger.info(f"Polling for results {input_job_id=}")
logger.info(f"Polling for results {parent_job_id=}")

- topic = get_output_topic(input_job_id)
+ topic = get_output_topic(parent_job_id)
settings = Settings()

- consumer = AIOKafkaConsumer(
- topic,
- bootstrap_servers=settings.kafka_bootstrap_servers,
- value_deserializer=lambda v: json.loads(v.decode("utf-8")),
- auto_offset_reset="earliest",
- )
- await consumer.start()
- logger.info(f"consumer started {input_job_id=}")
+ # Retry to workaround race condition of topic creation
+ retries = 5
+ while retries > 0:
+ try:
+ consumer = AIOKafkaConsumer(
+ topic,
+ bootstrap_servers=settings.kafka_bootstrap_servers,
+ value_deserializer=lambda v: json.loads(v.decode("utf-8")),
+ auto_offset_reset="earliest",
+ )
+ await consumer.start()
+ logger.info(f"consumer started {parent_job_id=}")
+ break
+ except UnknownTopicOrPartitionError as e:
+ logger.error(msg=e)
+ logger.info(f"Retrying to create consumer with topic {topic}")

+ await consumer.stop()
+ retries -= 1
+ time.sleep(1)

input_job_running = True

@@ -78,6 +168,7 @@ async def async_process_streaming_output(
)
else:
logger.debug(f"No messages in topic {tp.topic}")

if not data:
logger.info(f"No messages in any topic")
finally:
@@ -96,16 +187,23 @@ async def async_process_streaming_output(
name="process_streaming_output", track_started=True, bind=True, serializer="pickle"
)
def process_streaming_output(
- self, job_id: str, result_handler: ResultHandler, batch_size: int = 2
+ self,
+ input_job_id: str,
+ parent_job_id: str,
+ result_handler: ResultHandler,
+ batch_size: int,
):
try:
- asyncio.run(async_process_streaming_output(job_id, result_handler, batch_size))
+ asyncio.run(
+ async_process_streaming_output(
+ input_job_id, parent_job_id, result_handler, batch_size
+ )
+ )
except Exception as e:
# Set own status to failure
self.update_state(state=states.FAILURE)

logger.error(msg=e)

# Ignore the task so no other state is recorded
# TODO check if this updates state to Ignored, or keeps Failed
raise Ignore()
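The update_state / Ignore sequence in streaming_parent_task above is a general Celery pattern for attaching custom metadata to a task. A self-contained sketch of just that pattern, independent of this repo (broker URL and the child IDs are placeholders):

```python
# Minimal sketch of the custom-metadata pattern used by streaming_parent_task.
from celery import Celery, states
from celery.exceptions import Ignore

app = Celery("sketch", broker="redis://localhost:6379/0", backend="redis://localhost:6379/0")

@app.task(bind=True)
def parent(self):
    meta = {"input_job_id": "child-1", "output_job_id": "child-2"}  # placeholders
    # update_state replaces state and meta together, so the IDs must be
    # re-sent on every state change or AsyncResult(...).info loses them.
    self.update_state(state=states.STARTED, meta=meta)
    # ... wait for the child jobs to reach a terminal state ...
    self.update_state(state=states.SUCCESS, meta=meta)
    # Raising Ignore keeps the worker from storing its own return value,
    # which would overwrite the custom meta set above.
    raise Ignore()
```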