initial solution using pickle/pydantic
matt-bernstein committed Apr 16, 2024
1 parent da620e0 commit dc07ac9
Showing 3 changed files with 79 additions and 14 deletions.
server/app.py (5 changes: 3 additions & 2 deletions)
@@ -152,7 +152,7 @@ class SubmitStreamingRequest(BaseModel):
"""

agent: Agent
result_handler: str
result_handler: ResultHandler
task_name: str = "process_file_streaming"


@@ -216,8 +216,9 @@ async def submit_streaming(request: SubmitStreamingRequest):

     task = process_streaming_output
     logger.info(f"Submitting task {task.name}")
+    serialized_result_handler = pickle.dumps(request.result_handler)
     output_result = task.delay(
-        job_id=input_job_id, result_handler=request.result_handler
+        job_id=input_job_id, serialized_result_handler=serialized_result_handler
     )
     output_job_id = output_result.id
     logger.info(f"Task {task.name} submitted with job_id {output_job_id}")
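The heart of the app.py change: instead of passing a handler name string through Celery, submit_streaming now pickles the configured ResultHandler instance from the request and passes the resulting bytes as the task argument. A minimal sketch of why that round trip works, using an illustrative PrintHandler subclass that is not part of this commit:

import pickle

from pydantic import BaseModel


class ResultHandler(BaseModel):
    def __call__(self, batch):
        ...


class PrintHandler(ResultHandler):
    # Illustrative subclass: any pydantic fields become per-job configuration
    prefix: str = ">>"

    def __call__(self, batch):
        print(self.prefix, batch)


# Pydantic models pickle like ordinary Python objects, so the worker can
# rebuild a fully configured handler instance from the raw bytes.
handler = PrintHandler(prefix="results:")
blob = pickle.dumps(handler)   # bytes are a safe Celery task argument
restored = pickle.loads(blob)
restored(["example"])          # prints: results: ['example']

Unlike the old name string, the pickled object carries its own state, which is what lets LSEHandler below ship an api_key, url, and job_id along with the behavior.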
server/tasks/process_file.py (17 changes: 10 additions & 7 deletions)
@@ -45,14 +45,17 @@ def process_file_streaming(self, serialized_agent: bytes):


 async def async_process_streaming_output(
-    input_job_id: str, result_handler: str, batch_size: int
+    input_job_id: str, serialized_result_handler: bytes, batch_size: int
 ):
     logger.info(f"Polling for results {input_job_id=}")

     try:
-        result_handler = ResultHandler.__dict__[result_handler]
-    except KeyError as e:
-        logger.error(f"{result_handler} is not a valid ResultHandler")
+        # result_handler = ResultHandler.__dict__[result_handler]
+        result_handler = pickle.loads(serialized_result_handler)
+        assert isinstance(result_handler, ResultHandler)
+    except Exception as e:
+        # logger.error(f"{result_handler} is not a valid ResultHandler")
+        logger.error(f"not a valid ResultHandler")
         raise e

     topic = get_output_topic(input_job_id)
@@ -90,11 +93,11 @@ async def async_process_streaming_output(

 @app.task(name="process_streaming_output", track_started=True, bind=True)
 def process_streaming_output(
-    self, job_id: str, result_handler: str, batch_size: int = 2
+    self, job_id: str, serialized_result_handler: bytes, batch_size: int = 2
 ):
     try:
-        asyncio.run(async_process_streaming_output(job_id, result_handler, batch_size))
-    except KeyError:
+        asyncio.run(async_process_streaming_output(job_id, serialized_result_handler, batch_size))
+    except Exception as e:
         # Set own status to failure
         self.update_state(state=states.FAILURE)

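On the worker side, the task now receives bytes instead of a name, rebuilds the handler with pickle.loads, and type-checks it before polling begins. A standalone sketch of that validation step; rebuild_handler is a hypothetical helper (the commit does this inline in async_process_streaming_output), and the import path is assumed from the repo layout:

import pickle

from server.utils import ResultHandler  # shared module the worker can import


def rebuild_handler(serialized_result_handler: bytes) -> ResultHandler:
    handler = pickle.loads(serialized_result_handler)
    # Fail fast if the payload is not a handler, mirroring the commit's assert
    if not isinstance(handler, ResultHandler):
        raise TypeError(f"{type(handler).__name__} is not a valid ResultHandler")
    return handler

One constraint pickle imposes is that the handler's class must be importable in the worker process, which is a reason to keep the handler classes in the shared server/utils.py rather than in the API layer.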
server/utils.py (71 changes: 66 additions & 5 deletions)
@@ -1,9 +1,15 @@
 import logging
+import json

-from enum import Enum
+# from enum import Enum
+from abc import abstractmethod
+from pydantic import BaseModel, computed_field, ConfigDict, model_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from typing import List, Union

+from label_studio_sdk import Client
+

 logger = logging.getLogger(__name__)


@@ -20,17 +26,72 @@ class Settings(BaseSettings):
     )


-def dummy_handler(batch):
+class ResultHandler(BaseModel):
+
+    @abstractmethod
+    def __call__(self, batch):
+        '''
+        Callable to do something with a batch of results.
+        '''
+
+class DummyHandler(ResultHandler):
     """
     Dummy handler to test streaming output flow
     Can delete once we have a real handler
     """

-    logger.info(f"\n\nHandler received batch: {batch}\n\n")
+    def __call__(self, batch):
+        logger.info(f"\n\nHandler received batch: {batch}\n\n")
+
+
+class LSEHandler(ResultHandler):
+    """
+    Handler to use the Label Studio SDK to load a batch of results back into a Label Studio project
+    """
+    model_config = ConfigDict(arbitrary_types_allowed=True)  # for @computed_field
+
+    api_key: str
+    url: str
+    job_id: str
+
+    @computed_field
+    def client(self) -> Client:
+        return Client(
+            api_key=self.api_key,
+            url=self.url,
+        )
+
+    @model_validator(mode="after")
+    def ready(self):
+        # Need this to make POST requests using the SDK client
+        self.client.headers.update(
+            {
+                'accept': 'application/json',
+                'Content-Type': 'application/json',
+            }
+        )
+
+        self.client.check_connection()
+
+        return self
+
+    def __call__(self, batch):
+        logger.info(f"\n\nHandler received batch: {batch}\n\n")
+        self.client.make_request(
+            "POST",
+            "/api/model-run/batch-predictions",
+            data=json.dumps(
+                {
+                    'job_id': self.job_id,
+                    'results': batch,
+                }
+            ),
+        )


-class ResultHandler(Enum):
-    DUMMY = dummy_handler
+# class ResultHandler(Enum):
+# DUMMY = dummy_handler
+# LSE = LSEHandler


 def get_input_topic(job_id: str):
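Taken together, the utils.py change replaces the function-valued enum with a small class hierarchy, so a configured handler can cross the process boundary intact. A hypothetical end-to-end check using the classes added above; the URL, API key, and job id are placeholders, and constructing LSEHandler needs a reachable Label Studio instance because its model_validator calls check_connection():

import pickle

# Stateless round trip: DummyHandler has no fields to carry
dummy = DummyHandler()
assert isinstance(pickle.loads(pickle.dumps(dummy)), ResultHandler)

# Stateful handler: validated at construction time, before any job runs
lse = LSEHandler(
    api_key="<placeholder-token>",     # hypothetical credentials
    url="http://localhost:8080",
    job_id="job-123",
)
lse([{"task": 1, "result": "label"}])  # POSTs the batch back to Label Studio

A convenient side effect of exposing client as a computed field is that the SDK Client is derived from api_key and url on access rather than stored, so only plain string fields have to survive pickling.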
