feat: DIA-977: Get streaming job status #97

Merged
merged 31 commits on Apr 24, 2024

Changes from 21 commits

Commits (31)
d31318c
feat: DIA-953: Stream results from adala inference server into LSE (A…
matt-bernstein Apr 10, 2024
da620e0
Merge remote-tracking branch 'origin/master' into fb-dia-953/stream-r…
matt-bernstein Apr 12, 2024
dc07ac9
initial solution using pickle/pydantic
matt-bernstein Apr 16, 2024
fd2e40e
polymorphism for resulthandlers like in rest of adala
matt-bernstein Apr 17, 2024
73887a9
enable pickling of agents and handlers
matt-bernstein Apr 17, 2024
1f5ec71
black
matt-bernstein Apr 17, 2024
c286862
add label studio sdk
matt-bernstein Apr 17, 2024
42c19ee
dedup settings object
matt-bernstein Apr 17, 2024
a956784
pass in job id
matt-bernstein Apr 17, 2024
e1c2748
make kafka env base class usable
matt-bernstein Apr 17, 2024
466e299
fix serialization
matt-bernstein Apr 19, 2024
fc6e6d8
fix sdk client headers
matt-bernstein Apr 19, 2024
58c686a
black
matt-bernstein Apr 19, 2024
825ce54
revert timeout
matt-bernstein Apr 19, 2024
4812cb1
update LSE api
matt-bernstein Apr 19, 2024
082af3f
feat: DIA-977: Get status of streaming job
Apr 22, 2024
8dc5bf7
Merge branch 'master' of github.com:HumanSignal/Adala into fb-dia-977
Apr 22, 2024
090d151
Merge branch 'fb-dia-953/stream-results' of github.com:HumanSignal/Ad…
Apr 22, 2024
0a403f5
rename task variables
Apr 22, 2024
b5a7afc
add comments
Apr 22, 2024
963d180
Update docker-compose.yml
pakelley Apr 22, 2024
e619a3d
bugfix for nonexistent job id
matt-bernstein Apr 23, 2024
d37523f
replace original get status implementation
Apr 24, 2024
e2fec3a
Merge branch 'fb-dia-977' of github.com:HumanSignal/Adala into fb-dia…
Apr 24, 2024
344d123
Merge branch 'master' of github.com:HumanSignal/Adala into fb-dia-977
Apr 24, 2024
7a68c42
remove merge conflict bs
Apr 24, 2024
8478b89
cleanup
Apr 24, 2024
7da1ff8
add comment + format
Apr 24, 2024
3fb7f76
use settings.kafka... for agent
Apr 24, 2024
1ea6c29
add retry logic around consumer creation race condition
Apr 24, 2024
012376c
break out of retry loop when successful
Apr 24, 2024
25 changes: 25 additions & 0 deletions adala/environments/kafka.py
@@ -32,13 +32,36 @@ class AsyncKafkaEnvironment(AsyncEnvironment):
kafka_input_topic: str
kafka_output_topic: str

async def initialize(self):
# claim kafka topic from shared pool here?
pass

async def finalize(self):
# release kafka topic to shared pool here?
pass

async def get_feedback(
self,
skills: SkillSet,
predictions: InternalDataFrame,
num_feedbacks: Optional[int] = None,
) -> EnvironmentFeedback:
raise NotImplementedError("Feedback is not supported in Kafka environment")

async def restore(self):
raise NotImplementedError("Restore is not supported in Kafka environment")

async def save(self):
raise NotImplementedError("Save is not supported in Kafka environment")

async def message_receiver(self, consumer: AIOKafkaConsumer, timeout: int = 3):
await consumer.start()
try:
while True:
try:
# Wait for the next message with a timeout
msg = await asyncio.wait_for(consumer.getone(), timeout=timeout)
print_text(f"Received message: {msg.value}")
yield msg.value
except asyncio.TimeoutError:
print_text(
@@ -55,8 +78,10 @@ async def message_sender(
try:
for record in data:
await producer.send_and_wait(topic, value=record)
print_text(f"Sent message: {record} to {topic=}")
finally:
await producer.stop()
print_text(f"No more messages for {topic=}")

async def get_next_batch(self, data_iterator, batch_size: int) -> List[Dict]:
batch = []
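
For context, a minimal sketch of how message_receiver might be driven from a worker coroutine; the broker address and topic names below are assumptions rather than values from this diff, and it assumes the collapsed timeout branch ends the loop:

import asyncio
from aiokafka import AIOKafkaConsumer
from adala.environments.kafka import AsyncKafkaEnvironment

async def consume_input_topic():
    # Topic names are illustrative; the server builds them via get_input_topic().
    env = AsyncKafkaEnvironment(
        kafka_input_topic="adala-input",
        kafka_output_topic="adala-output",
    )
    consumer = AIOKafkaConsumer(
        env.kafka_input_topic,
        bootstrap_servers="localhost:9093",  # assumed broker address
    )
    # message_receiver starts the consumer itself and yields raw message values
    # until no message arrives within the timeout window.
    async for value in env.message_receiver(consumer, timeout=3):
        print(value)

asyncio.run(consume_input_topic())
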
1 change: 1 addition & 0 deletions docker-compose.yml
@@ -49,6 +49,7 @@ services:
- AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN}
- MODULE_NAME=process_file.app
- KAFKA_BOOTSTRAP_SERVERS=kafka:9093
- C_FORCE_ROOT=true # needed when using pickle serializer in celery + running as root - remove when we don't run as root
command:
'sh -c "cd tasks && poetry run celery -A $$MODULE_NAME worker --loglevel=info"'
redis:
216 changes: 215 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ fastapi = "^0.104.1"
celery = {version = "^5.3.6", extras = ["redis"]}
uvicorn = "*"
pydantic-settings = "^2.2.1"
label-studio-sdk = "^0.0.32"

[tool.poetry.dev-dependencies]
pytest = "^7.4.3"
123 changes: 88 additions & 35 deletions server/app.py
@@ -10,7 +10,7 @@
from aiokafka import AIOKafkaProducer
from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pydantic import BaseModel, SerializeAsAny, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic.functional_validators import AfterValidator
from typing_extensions import Annotated
@@ -23,26 +23,15 @@
process_file,
process_file_streaming,
process_streaming_output,
streaming_parent_task,
)
from utils import get_input_topic, ResultHandler, Settings
from utils import get_input_topic, Settings
from server.handlers.result_handlers import ResultHandler


logger = logging.getLogger(__name__)


class Settings(BaseSettings):
"""
Can hardcode settings here, read from env file, or pass as env vars
https://docs.pydantic.dev/latest/concepts/pydantic_settings/#field-value-priority
"""

kafka_bootstrap_servers: Union[str, List[str]]

model_config = SettingsConfigDict(
env_file=".env",
)


settings = Settings()

app = fastapi.FastAPI()
@@ -148,13 +137,25 @@ class SubmitRequest(BaseModel):
class SubmitStreamingRequest(BaseModel):
"""
Request model for submitting a streaming job.
Only difference from SubmitRequest is the task_name
"""

agent: Agent
result_handler: str
# SerializeAsAny is for allowing subclasses of ResultHandler
result_handler: SerializeAsAny[ResultHandler]
task_name: str = "process_file_streaming"

@field_validator("result_handler", mode="before")
def validate_result_handler(cls, value: Dict) -> ResultHandler:
"""
Allows polymorphism for ResultHandlers created from a dict; same implementation as the Skills, Environment, and Runtime within an Agent
"""
if "type" not in value:
raise HTTPException(
status_code=400, detail="Missing type in result_handler"
)
result_handler = ResultHandler.create_from_registry(value.pop("type"), **value)
return result_handler
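
For illustration, a request body shaped like the one below would pass this validator, assuming the registry key is the class name (mirroring how Skills, Environments, and Runtimes are specified) and using LSEHandler's fields from server/handlers/result_handlers.py:

payload = {
    "agent": {},  # Agent spec omitted in this sketch; same shape as in SubmitRequest
    "result_handler": {
        "type": "LSEHandler",            # registry key popped by validate_result_handler
        "api_key": "<label-studio-token>",
        "url": "http://localhost:8080",  # assumed Label Studio URL
        "modelrun_id": 42,
    },
    "task_name": "process_file_streaming",
}
# Sent as the JSON body of POST /jobs/submit-streaming.
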


class BatchData(BaseModel):
"""
@@ -184,10 +185,10 @@ async def submit(request: SubmitRequest):

# TODO: get task by name, e.g. request.task_name
task = process_file
serialized_agent = pickle.dumps(request.agent)
agent = request.agent

logger.debug(f"Submitting task {task.name} with agent {serialized_agent}")
result = task.delay(serialized_agent=serialized_agent)
logger.debug(f"Submitting task {task.name} with agent {agent}")
result = task.delay(agent=agent)
logger.debug(f"Task {task.name} submitted with job_id {result.id}")

return Response[JobCreated](data=JobCreated(job_id=result.id))
@@ -206,23 +207,13 @@ async def submit_streaming(request: SubmitStreamingRequest):
"""

# TODO: get task by name, e.g. request.task_name
task = process_file_streaming
serialized_agent = pickle.dumps(request.agent)

logger.info(f"Submitting task {task.name} with agent {serialized_agent}")
input_result = task.delay(serialized_agent=serialized_agent)
input_job_id = input_result.id
logger.info(f"Task {task.name} submitted with job_id {input_job_id}")

task = process_streaming_output
logger.info(f"Submitting task {task.name}")
output_result = task.delay(
job_id=input_job_id, result_handler=request.result_handler
task = streaming_parent_task
result = task.apply_async(
kwargs={"agent": request.agent, "result_handler": request.result_handler}
)
output_job_id = output_result.id
logger.info(f"Task {task.name} submitted with job_id {output_job_id}")
logger.info(f"Submitted {task.name} with ID {result.id}")

return Response[JobCreated](data=JobCreated(job_id=input_job_id))
return Response[JobCreated](data=JobCreated(job_id=result.id))
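
streaming_parent_task itself lives in the Celery tasks module and is not part of this diff; the status endpoint below only assumes it exposes the two child job IDs through its task meta, roughly along these lines (a sketch, not the actual implementation):

from celery import shared_task
from process_file import process_file_streaming, process_streaming_output  # assumed import path

@shared_task(bind=True)
def streaming_parent_task(self, agent, result_handler):
    # Spawn the child jobs: one streams records into Kafka, the other drains
    # the output topic into the result handler.
    input_result = process_file_streaming.delay(agent=agent)
    output_result = process_streaming_output.delay(
        job_id=input_result.id, result_handler=result_handler
    )
    meta = {"input_job_id": input_result.id, "output_job_id": output_result.id}
    # Expose the IDs while running; returning them keeps AsyncResult(...).info
    # populated after the parent finishes, which get_status_streaming relies on.
    self.update_state(state="STARTED", meta=meta)
    return meta
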


@app.post("/jobs/submit-batch", response_model=Response)
@@ -284,6 +275,68 @@ def get_status(job_id):
return Response[JobStatusResponse](data=JobStatusResponse(status=status))


def aggregate_statuses(input_job_id: str, output_job_id: str):
input_job_status = process_file_streaming.AsyncResult(input_job_id).status
output_job_status = process_streaming_output.AsyncResult(output_job_id).status

statuses = [input_job_status, output_job_status]

if "PENDING" in statuses:
return "PENDING"
if "FAILURE" in statuses:
return "FAILURE"
if "REVOKED" in statuses:
return "REVOKED"
if "STARTED" in statuses or "RETRY" in statuses:
return "STARTED"

return "SUCCESS"


@app.get("/streaming-jobs/{job_id}", response_model=Response[JobStatusResponse])
def get_status_streaming(job_id):
"""
Get the status of a job.

Args:
job_id (str)

Returns:
JobStatusResponse: The response model for getting the status of a job.
"""
celery_status_map = {
"PENDING": Status.PENDING,
"STARTED": Status.INPROGRESS,
"SUCCESS": Status.COMPLETED,
"FAILURE": Status.FAILED,
"REVOKED": Status.CANCELED,
"RETRY": Status.INPROGRESS,
}
job = streaming_parent_task.AsyncResult(job_id)
logger.info(f"\n\nParent task meta : {job.info}\n\n")

# If parent task meta does not contain input/output job IDs - return FAILED
if "input_job_id" not in job.info or "output_job_id" not in job.info:
logger.error(
"Parent task does not contain input job ID and/or output_job_id - unable to return proper status"
)
return Response[JobStatusResponse](data=JobStatusResponse(status=Status.FAILED))

input_job_id = job.info["input_job_id"]
output_job_id = job.info["output_job_id"]

try:
status: Status = celery_status_map[
aggregate_statuses(input_job_id, output_job_id)
]
except Exception as e:
logger.error(f"Error getting job status: {e}")
status = Status.FAILED
else:
logger.info(f"Job {job_id} status: {status}")
return Response[JobStatusResponse](data=JobStatusResponse(status=status))
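
A quick way to exercise the new endpoint once a streaming job has been submitted; the server address and the exact status strings depend on Settings and the Status enum, neither of which is shown here, so treat them as assumptions:

import time
import requests

BASE_URL = "http://localhost:30001"  # assumed Adala server address
job_id = "<id returned by POST /jobs/submit-streaming>"

while True:
    resp = requests.get(f"{BASE_URL}/streaming-jobs/{job_id}")
    resp.raise_for_status()
    status = resp.json()["data"]["status"]  # Response[JobStatusResponse] envelope
    print(f"streaming job {job_id}: {status}")
    if status.lower() in ("completed", "failed", "canceled"):  # assumed enum values
        break
    time.sleep(5)
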


@app.delete("/jobs/{job_id}", response_model=Response[JobStatusResponse])
def cancel_job(job_id):
"""
Empty file added server/handlers/__init__.py
Empty file.
77 changes: 77 additions & 0 deletions server/handlers/result_handlers.py
@@ -0,0 +1,77 @@
from typing import Optional
import logging
import json
from abc import abstractmethod
from pydantic import computed_field, ConfigDict, model_validator

from adala.utils.registry import BaseModelInRegistry
from label_studio_sdk import Client


logger = logging.getLogger(__name__)


class ResultHandler(BaseModelInRegistry):
@abstractmethod
def __call__(self, batch):
"""
Callable to do something with a batch of results.
"""


class DummyHandler(ResultHandler):
"""
Dummy handler to test streaming output flow
Can delete once we have a real handler
"""

def __call__(self, batch):
logger.info(f"\n\nHandler received batch: {batch}\n\n")


class LSEHandler(ResultHandler):
"""
Handler to use the Label Studio SDK to load a batch of results back into a Label Studio project
"""

model_config = ConfigDict(arbitrary_types_allowed=True) # for @computed_field

api_key: str
url: str
modelrun_id: int

@computed_field
def client(self) -> Client:
_client = Client(
api_key=self.api_key,
url=self.url,
)
# Need this to make POST requests using the SDK client
# TODO headers can only be set in this function, since client is a computed field. Need to rethink approach if we make non-POST requests, should probably just make a PR in label_studio_sdk to allow setting this in make_request()
_client.headers.update(
{
"accept": "application/json",
"Content-Type": "application/json",
}
)
return _client

@model_validator(mode="after")
def ready(self):
conn = self.client.check_connection()
assert conn["status"] == "UP", "Label Studio is not available"

return self

def __call__(self, batch):
logger.info(f"\n\nHandler received batch: {batch}\n\n")
self.client.make_request(
"POST",
"/api/model-run/batch-predictions",
data=json.dumps(
{
"modelrun_id": self.modelrun_id,
"results": batch,
}
),
)
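
New handlers plug in by subclassing ResultHandler, which makes them reachable through the same 'type' lookup the submit-streaming validator uses (assuming subclasses register themselves the way DummyHandler and LSEHandler do); a toy sketch:

import csv
from server.handlers.result_handlers import ResultHandler

class CSVHandler(ResultHandler):
    """Toy handler that appends each batch of results to a local CSV file."""

    path: str = "results.csv"

    def __call__(self, batch):
        with open(self.path, "a", newline="") as f:
            writer = csv.writer(f)
            for record in batch:
                writer.writerow([record])

# Selected via {"type": "CSVHandler", "path": "out.csv"} in a submit-streaming request.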