From 3f8fd03abeef64b1d7b6f58dfe0eda96e377ab2e Mon Sep 17 00:00:00 2001
From: Rick Staa
Date: Fri, 13 Dec 2024 19:51:09 +0100
Subject: [PATCH] feat(runner): add transformers pipeline logic

This commit adds the initial transformers pipeline, which uses the
pipeline abstraction of the Hugging Face transformers package to run
any supported transformers pipeline through a single route.
---
 runner/app/main.py                |  10 +-
 runner/app/routes/transformers.py | 160 ++++++++++++++++++++++++++++++
 2 files changed, 169 insertions(+), 1 deletion(-)
 create mode 100644 runner/app/routes/transformers.py

diff --git a/runner/app/main.py b/runner/app/main.py
index 0ba3ee3a..5a34cbd0 100644
--- a/runner/app/main.py
+++ b/runner/app/main.py
@@ -21,7 +21,9 @@ async def lifespan(app: FastAPI):
     app.include_router(hardware.router)
 
     pipeline = os.environ["PIPELINE"]
-    model_id = os.environ["MODEL_ID"]
+    model_id = os.environ.get("MODEL_ID", "")
+    if pipeline != "transformers" and not model_id:
+        raise EnvironmentError(f"MODEL_ID must be set when using pipeline {pipeline}")
 
     app.pipeline = load_pipeline(pipeline, model_id)
     app.include_router(load_route(pipeline))
@@ -78,6 +80,8 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
             from app.pipelines.text_to_speech import TextToSpeechPipeline
 
             return TextToSpeechPipeline(model_id)
+        case "transformers":
+            return None
         case _:
             raise EnvironmentError(
                 f"{pipeline} is not a valid pipeline for model {model_id}"
@@ -128,6 +132,10 @@ def load_route(pipeline: str) -> any:
             from app.routes import text_to_speech
 
             return text_to_speech.router
+        case "transformers":
+            from app.routes import transformers
+
+            return transformers.router
         case _:
             raise EnvironmentError(f"{pipeline} is not a valid pipeline")
 
diff --git a/runner/app/routes/transformers.py b/runner/app/routes/transformers.py
new file mode 100644
index 00000000..737708e9
--- /dev/null
+++ b/runner/app/routes/transformers.py
@@ -0,0 +1,160 @@
+import logging
+import torch
+import os
+from typing import Union, Annotated, Dict, Tuple, Any
+from fastapi.responses import JSONResponse
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from fastapi import APIRouter, status, Depends
+from pydantic import BaseModel, Field, HttpUrl
+from transformers import pipeline
+from app.pipelines.utils import get_torch_device
+
+from app.routes.utils import http_error, handle_pipeline_exception, HTTPError
+
+router = APIRouter()
+
+logger = logging.getLogger(__name__)
+
+# Pipeline-specific error handling configuration.
+PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = {
+    # Error strings.
+    "Unknown task string": (
+        "",
+        status.HTTP_400_BAD_REQUEST,
+    ),
+    "unexpected keyword argument": (
+        "Unexpected keyword argument provided.",
+        status.HTTP_400_BAD_REQUEST,
+    ),
+    # Specific error types.
+    "OutOfMemoryError": (
+        "Out of memory error. Try using a smaller model or reducing the input size.",
+        status.HTTP_500_INTERNAL_SERVER_ERROR,
+    ),
+}
+
+
+class InferenceRequest(BaseModel):
+    # TODO: Make params optional once Go codegen tool supports OAPI 3.1
+    # https://github.com/deepmap/oapi-codegen/issues/373
+    task: Annotated[
+        str,
+        Field(
+            description=(
+                "The transformer task to perform. E.g. 'automatic-speech-recognition'."
+            ),
+        ),
+    ]
+    model_name: Annotated[
+        str,
+        Field(
+            description=(
+                "The transformer model to use for the task. E.g. 'openai/whisper-base'."
+            ),
+        ),
+    ]
+    input: Annotated[
+        Union[str, HttpUrl],
+        Field(
+            description=(
+                "The input data to be transformed. Can be a string or a URL to a file."
+            ),
+        ),
+    ]
+    pipeline_params: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional keyword arguments to pass to the transformer pipeline during inference. E.g. {'return_timestamps': True, 'max_length': 50}.",
+    )
+
+
+class InferenceResponse(BaseModel):
+    """Response model for transformer inference."""
+
+    output: Any = Field(
+        ..., description="The output data transformed by the transformer pipeline."
+    )
+
+
+RESPONSES = {
+    status.HTTP_200_OK: {
+        "content": {
+            "application/json": {
+                "schema": {
+                    "x-speakeasy-name-override": "data",
+                }
+            }
+        },
+    },
+    status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
+    status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
+    status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
+}
+
+
+@router.post(
+    "/transformers",
+    response_model=InferenceResponse,
+    responses=RESPONSES,
+    description="Perform inference using a Hugging Face transformer model.",
+    operation_id="genTransformers",
+    summary="Transformers",
+    tags=["generate"],
+    openapi_extra={"x-speakeasy-name-override": "transformers"},
+)
+@router.post("/transformers/", responses=RESPONSES, include_in_schema=False)
+async def transformers(
+    request: InferenceRequest,
+    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
+):
+    auth_token = os.environ.get("AUTH_TOKEN")
+    if auth_token:
+        if not token or token.credentials != auth_token:
+            return JSONResponse(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                headers={"WWW-Authenticate": "Bearer"},
+                content=http_error("Invalid bearer token."),
+            )
+
+    if not request.task and not request.model_name:
+        return JSONResponse(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            content=http_error("Either 'task' or 'model_name' must be provided."),
+        )
+    if not request.input:
+        return JSONResponse(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            content=http_error("'input' field is required."),
+        )
+
+    torch_device = get_torch_device()
+
+    # Initialize the pipeline with the specified task and model ID.
+    pipeline_kwargs = {}
+    if request.task:
+        pipeline_kwargs["task"] = request.task
+    if request.model_name:
+        pipeline_kwargs["model"] = request.model_name
+    try:
+        pipe = pipeline(device=torch_device, **pipeline_kwargs)
+    except Exception as e:
+        return handle_pipeline_exception(
+            e,
+            default_error_message=f"Pipeline initialization error: {e}.",
+            custom_error_config=PIPELINE_ERROR_CONFIG,
+        )
+
+    # Perform inference using the pipeline. Cast the input to a plain string
+    # so HttpUrl values are accepted by the underlying pipeline.
+    try:
+        out = pipe(str(request.input), **request.pipeline_params)
+    except Exception as e:
+        if isinstance(e, torch.cuda.OutOfMemoryError):
+            # TODO: Investigate why not all VRAM memory is cleared.
+            torch.cuda.empty_cache()
+        logger.error(f"TransformersPipeline error: {e}")
+        return handle_pipeline_exception(
+            e,
+            default_error_message="transformers pipeline error.",
+            custom_error_config=PIPELINE_ERROR_CONFIG,
+        )
+
+    return {"output": out}
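
For reviewers, a minimal sketch of how a client might exercise the new route
once the runner is started with PIPELINE=transformers. The host, port, and
token placeholder are assumptions for illustration, and `return_timestamps`
is just an example pipeline parameter for the whisper task:

    import requests

    # Hypothetical client call; adjust the URL and AUTH_TOKEN to match
    # your deployment.
    resp = requests.post(
        "http://localhost:8000/transformers",
        headers={"Authorization": "Bearer <AUTH_TOKEN>"},
        json={
            "task": "automatic-speech-recognition",
            "model_name": "openai/whisper-base",
            "input": "https://example.com/audio.wav",
            "pipeline_params": {"return_timestamps": True},
        },
    )
    resp.raise_for_status()
    # The route wraps the pipeline result in an "output" field.
    print(resp.json()["output"])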