diff --git a/dev/check_torch_cuda.py b/dev/check_torch_cuda.py new file mode 100644 index 00000000..eaa297a1 --- /dev/null +++ b/dev/check_torch_cuda.py @@ -0,0 +1,18 @@ +import torch +import subprocess + +print(f"PyTorch version: {torch.__version__}") +print(f"CUDA available: {torch.cuda.is_available()}") +if torch.cuda.is_available(): + print(f"CUDA version: {torch.version.cuda}") + +# Check system CUDA version +try: + nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8") + cuda_version = nvcc_output.split("release ")[-1].split(",")[0] + print(f"System CUDA version: {cuda_version}") +except: + print("Unable to check system CUDA version") + +# Print the current device +print(f"Current device: {torch.cuda.get_device_name(0)}") diff --git a/runner/app/main.py b/runner/app/main.py index 147a0c21..b6441a04 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -54,6 +54,9 @@ def load_pipeline(pipeline: str, model_id: str) -> any: from app.pipelines.segment_anything_2 import SegmentAnything2Pipeline return SegmentAnything2Pipeline(model_id) + case "llm": + from app.pipelines.llm import LLMPipeline + return LLMPipeline(model_id) case _: raise EnvironmentError( f"{pipeline} is not a valid pipeline for model {model_id}" @@ -88,6 +91,9 @@ def load_route(pipeline: str) -> any: from app.routes import segment_anything_2 return segment_anything_2.router + case "llm": + from app.routes import llm + return llm.router case _: raise EnvironmentError(f"{pipeline} is not a valid pipeline") diff --git a/runner/app/pipelines/llm.py b/runner/app/pipelines/llm.py new file mode 100644 index 00000000..7d3440d7 --- /dev/null +++ b/runner/app/pipelines/llm.py @@ -0,0 +1,201 @@ +import asyncio +import logging +import os +import psutil +from typing import Dict, Any, List, Optional, AsyncGenerator, Union + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig +from accelerate import init_empty_weights, load_checkpoint_and_dispatch +from app.pipelines.base import Pipeline +from app.pipelines.utils import get_model_dir, get_torch_device +from huggingface_hub import file_download, snapshot_download +from threading import Thread + +logger = logging.getLogger(__name__) + + +def get_max_memory(): + num_gpus = torch.cuda.device_count() + gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)} + cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB" + max_memory = {**gpu_memory, "cpu": cpu_memory} + + logger.info(f"Max memory configuration: {max_memory}") + return max_memory + + +def load_model_8bit(model_id: str, **kwargs): + max_memory = get_max_memory() + + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) + + model = AutoModelForCausalLM.from_pretrained( + model_id, + quantization_config=quantization_config, + device_map="auto", + max_memory=max_memory, + offload_folder="offload", + low_cpu_mem_usage=True, + **kwargs + ) + + return tokenizer, model + + +def load_model_fp16(model_id: str, **kwargs): + device = get_torch_device() + max_memory = get_max_memory() + + # Check for fp16 variant + local_model_path = os.path.join( + get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model")) + has_fp16_variant = any(".fp16.safetensors" in fname for _, _, + files in os.walk(local_model_path) for fname in files) 
+ + if device != "cpu" and has_fp16_variant: + logger.info("Loading fp16 variant for %s", model_id) + kwargs["torch_dtype"] = torch.float16 + kwargs["variant"] = "fp16" + elif device != "cpu": + kwargs["torch_dtype"] = torch.bfloat16 + + tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) + + config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + checkpoint_dir = snapshot_download( + model_id, cache_dir=get_model_dir(), local_files_only=True) + + model = load_checkpoint_and_dispatch( + model, + checkpoint_dir, + device_map="auto", + max_memory=max_memory, + # Adjust based on your model architecture + no_split_module_classes=["LlamaDecoderLayer"], + dtype=kwargs.get("torch_dtype", torch.float32), + offload_folder="offload", + offload_state_dict=True, + ) + + return tokenizer, model + + +class LLMPipeline(Pipeline): + def __init__(self, model_id: str): + self.model_id = model_id + kwargs = { + "cache_dir": get_model_dir(), + "local_files_only": True, + } + self.device = get_torch_device() + + # Generate the correct folder name + folder_path = file_download.repo_folder_name( + repo_id=model_id, repo_type="model") + self.local_model_path = os.path.join(get_model_dir(), folder_path) + self.checkpoint_dir = snapshot_download( + model_id, cache_dir=get_model_dir(), local_files_only=True) + + logger.info(f"Local model path: {self.local_model_path}") + logger.info(f"Directory contents: {os.listdir(self.local_model_path)}") + + use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true" + + if use_8bit: + logger.info("Using 8-bit quantization") + self.tokenizer, self.model = load_model_8bit(model_id, **kwargs) + else: + logger.info("Using fp16/bf16 precision") + self.tokenizer, self.model = load_model_fp16(model_id, **kwargs) + + logger.info( + f"Model loaded and distributed. 
Device map: {self.model.hf_device_map}" + ) + + # Set up generation config + self.generation_config = self.model.generation_config + + self.terminators = [ + self.tokenizer.eos_token_id, + self.tokenizer.convert_tokens_to_ids("<|eot_id|>") + ] + + # Optional: Add optimizations + sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true" + if sfast_enabled: + logger.info( + "LLMPipeline will be dynamically compiled with stable-fast for %s", + model_id, + ) + from app.pipelines.optim.sfast import compile_model + self.model = compile_model(self.model) + + async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]: + conversation = [] + if system_msg: + conversation.append({"role": "system", "content": system_msg}) + if history: + conversation.extend(history) + conversation.append({"role": "user", "content": prompt}) + + input_ids = self.tokenizer.apply_chat_template( + conversation, return_tensors="pt").to(self.model.device) + attention_mask = torch.ones_like(input_ids) + + max_new_tokens = kwargs.get("max_tokens", 256) + temperature = kwargs.get("temperature", 0.7) + + streamer = TextIteratorStreamer( + self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) + + generate_kwargs = self.generation_config.to_dict() + generate_kwargs.update({ + "input_ids": input_ids, + "attention_mask": attention_mask, + "streamer": streamer, + "max_new_tokens": max_new_tokens, + "do_sample": temperature > 0, + "temperature": temperature, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.eos_token_id, + }) + + thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs) + thread.start() + + total_tokens = 0 + try: + for text in streamer: + total_tokens += 1 + yield text + await asyncio.sleep(0) # Allow other tasks to run + except Exception as e: + logger.error(f"Error during streaming: {str(e)}") + raise + + input_length = input_ids.size(1) + yield {"tokens_used": input_length + total_tokens} + + def model_generate_wrapper(self, **kwargs): + try: + logger.debug("Entering model.generate") + with torch.cuda.amp.autocast(): # Use automatic mixed precision + self.model.generate(**kwargs) + logger.debug("Exiting model.generate") + except Exception as e: + logger.error(f"Error in model.generate: {str(e)}", exc_info=True) + raise + + def __str__(self): + return f"LLMPipeline(model_id={self.model_id})" diff --git a/runner/app/routes/llm.py b/runner/app/routes/llm.py new file mode 100644 index 00000000..366a306a --- /dev/null +++ b/runner/app/routes/llm.py @@ -0,0 +1,118 @@ +import logging +import os +from typing import Annotated +from fastapi import APIRouter, Depends, Form, status +from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from app.dependencies import get_pipeline +from app.pipelines.base import Pipeline +from app.routes.util import HTTPError, LLMResponse, http_error +import json + +router = APIRouter() + +logger = logging.getLogger(__name__) + +RESPONSES = { + status.HTTP_200_OK: {"model": LLMResponse}, + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + +@router.post( + "/llm", + response_model=LLMResponse, + responses=RESPONSES, + operation_id="genLLM", + description="Generate text using a language model.", + 
summary="LLM", + tags=["generate"], + openapi_extra={"x-speakeasy-name-override": "llm"}, +) +@router.post("/llm/", response_model=LLMResponse, responses=RESPONSES, include_in_schema=False) +async def llm( + prompt: Annotated[str, Form()], + model_id: Annotated[str, Form()] = "", + system_msg: Annotated[str, Form()] = "", + temperature: Annotated[float, Form()] = 0.7, + max_tokens: Annotated[int, Form()] = 256, + history: Annotated[str, Form()] = "[]", # We'll parse this as JSON + stream: Annotated[bool, Form()] = False, + pipeline: Pipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), +): + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token"), + ) + + if model_id != "" and model_id != pipeline.model_id: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error( + f"pipeline configured with {pipeline.model_id} but called with " + f"{model_id}" + ), + ) + + try: + history_list = json.loads(history) + if not isinstance(history_list, list): + raise ValueError("History must be a JSON array") + + generator = pipeline( + prompt=prompt, + history=history_list, + system_msg=system_msg if system_msg else None, + temperature=temperature, + max_tokens=max_tokens + ) + + if stream: + return StreamingResponse(stream_generator(generator), media_type="text/event-stream") + else: + full_response = "" + async for chunk in generator: + if isinstance(chunk, dict): + tokens_used = chunk["tokens_used"] + break + full_response += chunk + + return LLMResponse(response=full_response, tokens_used=tokens_used) + + except json.JSONDecodeError: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "Invalid JSON format for history"} + ) + except ValueError as ve: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(ve)} + ) + except Exception as e: + logger.error(f"LLM processing error: {str(e)}") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": "Internal server error during LLM processing."} + ) + + +async def stream_generator(generator): + try: + async for chunk in generator: + if isinstance(chunk, dict): # This is the final result + yield f"data: {json.dumps(chunk)}\n\n" + break + else: + yield f"data: {json.dumps({'chunk': chunk})}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + logger.error(f"Streaming error: {str(e)}") + yield f"data: {json.dumps({'error': str(e)})}\n\n" diff --git a/runner/app/routes/util.py b/runner/app/routes/util.py index 8a319e84..2371c9e1 100644 --- a/runner/app/routes/util.py +++ b/runner/app/routes/util.py @@ -58,6 +58,11 @@ class TextResponse(BaseModel): chunks: List[chunk] = Field(..., description="The generated text chunks.") +class LLMResponse(BaseModel): + response: str + tokens_used: int + + class APIError(BaseModel): """API error response model.""" diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 71914003..5a241fdf 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -76,6 +76,9 @@ function download_restricted_models() { # Download text-to-image and image-to-image models. 
huggingface-cli download black-forest-labs/FLUX.1-dev --include "*.safetensors" "*.json" "*.txt" "*.model" --exclude ".onnx" ".onnx_data" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"} + # Download LLM models (Warning: large model size) + huggingface-cli download meta-llama/Meta-Llama-3.1-8B-Instruct --include "*.json" "*.bin" "*.safetensors" "*.txt" --cache-dir models + } # Enable HF transfer acceleration. diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index db86753a..a8bba1b1 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -2,7 +2,7 @@ openapi: 3.1.0 info: title: Livepeer AI Runner description: An application to run AI pipelines - version: v0.5.0 + version: '' servers: - url: https://dream-gateway.livepeer.cloud description: Livepeer Cloud Community Gateway @@ -303,6 +303,53 @@ paths: security: - HTTPBearer: [] x-speakeasy-name-override: segmentAnything2 + /llm: + post: + tags: + - generate + summary: LLM + description: Generate text using a language model. + operationId: genLLM + requestBody: + content: + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/Body_genLLM' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/LLMResponse' + '400': + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '500': + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + security: + - HTTPBearer: [] + x-speakeasy-name-override: llm components: schemas: APIError: @@ -467,6 +514,40 @@ components: - image - model_id title: Body_genImageToVideo + Body_genLLM: + properties: + prompt: + type: string + title: Prompt + model_id: + type: string + title: Model Id + default: '' + system_msg: + type: string + title: System Msg + default: '' + temperature: + type: number + title: Temperature + default: 0.7 + max_tokens: + type: integer + title: Max Tokens + default: 256 + history: + type: string + title: History + default: '[]' + stream: + type: boolean + title: Stream + default: false + type: object + required: + - prompt + - model_id + title: Body_genLLM Body_genSegmentAnything2: properties: image: @@ -593,6 +674,19 @@ components: - images title: ImageResponse description: Response model for image generation. 
+ LLMResponse: + properties: + response: + type: string + title: Response + tokens_used: + type: integer + title: Tokens Used + type: object + required: + - response + - tokens_used + title: LLMResponse MasksResponse: properties: masks: diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index f8b9c613..6f557055 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -12,6 +12,7 @@ segment_anything_2, text_to_image, upscale, + llm ) from fastapi.openapi.utils import get_openapi import subprocess @@ -123,6 +124,7 @@ def write_openapi(fname: str, entrypoint: str = "runner", version: str = "0.0.0" app.include_router(upscale.router) app.include_router(audio_to_text.router) app.include_router(segment_anything_2.router) + app.include_router(llm.router) logger.info(f"Generating OpenAPI schema for '{entrypoint}' entrypoint...") openapi = get_openapi( diff --git a/runner/openapi.yaml b/runner/openapi.yaml index f25b3d02..377777a9 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -2,7 +2,7 @@ openapi: 3.1.0 info: title: Livepeer AI Runner description: An application to run AI pipelines - version: v0.5.0 + version: '' servers: - url: https://dream-gateway.livepeer.cloud description: Livepeer Cloud Community Gateway @@ -314,6 +314,53 @@ paths: security: - HTTPBearer: [] x-speakeasy-name-override: segmentAnything2 + /llm: + post: + tags: + - generate + summary: LLM + description: Generate text using a language model. + operationId: genLLM + requestBody: + content: + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/Body_genLLM' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/LLMResponse' + '400': + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '500': + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + security: + - HTTPBearer: [] + x-speakeasy-name-override: llm components: schemas: APIError: @@ -475,6 +522,39 @@ components: required: - image title: Body_genImageToVideo + Body_genLLM: + properties: + prompt: + type: string + title: Prompt + model_id: + type: string + title: Model Id + default: '' + system_msg: + type: string + title: System Msg + default: '' + temperature: + type: number + title: Temperature + default: 0.7 + max_tokens: + type: integer + title: Max Tokens + default: 256 + history: + type: string + title: History + default: '[]' + stream: + type: boolean + title: Stream + default: false + type: object + required: + - prompt + title: Body_genLLM Body_genSegmentAnything2: properties: image: @@ -607,6 +687,19 @@ components: - images title: ImageResponse description: Response model for image generation. 
+ LLMResponse: + properties: + response: + type: string + title: Response + tokens_used: + type: integer + title: Tokens Used + type: object + required: + - response + - tokens_used + title: LLMResponse MasksResponse: properties: masks: diff --git a/runner/requirements.txt b/runner/requirements.txt index 24f2442f..87f72e43 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -1,6 +1,6 @@ diffusers==0.30.0 accelerate==0.30.1 -transformers==4.41.1 +transformers==4.43.3 fastapi==0.111.0 pydantic==2.7.2 Pillow==10.3.0 @@ -17,3 +17,5 @@ numpy==1.26.4 av==12.1.0 sentencepiece== 0.2.0 protobuf==5.27.2 +bitsandbytes==0.43.3 +psutil==6.0.0 diff --git a/worker/docker.go b/worker/docker.go index f42c0e49..509a8a01 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -36,7 +36,8 @@ var containerHostPorts = map[string]string{ "image-to-video": "8200", "upscale": "8300", "audio-to-text": "8400", - "segment-anything-2": "8500", + "llm": "8500", + "segment-anything-2": "8600", } // Mapping for per pipeline container images. diff --git a/worker/multipart.go b/worker/multipart.go index 16bc425d..26241972 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -246,6 +246,56 @@ func NewAudioToTextMultipartWriter(w io.Writer, req GenAudioToTextMultipartReque return mw, nil } +func NewLLMMultipartWriter(w io.Writer, req BodyGenLLM) (*multipart.Writer, error) { + mw := multipart.NewWriter(w) + + if err := mw.WriteField("prompt", req.Prompt); err != nil { + return nil, fmt.Errorf("failed to write prompt field: %w", err) + } + + if req.History != nil { + if err := mw.WriteField("history", *req.History); err != nil { + return nil, fmt.Errorf("failed to write history field: %w", err) + } + } + + if req.ModelId != nil { + if err := mw.WriteField("model_id", *req.ModelId); err != nil { + return nil, fmt.Errorf("failed to write model_id field: %w", err) + } + } + + if req.SystemMsg != nil { + if err := mw.WriteField("system_msg", *req.SystemMsg); err != nil { + return nil, fmt.Errorf("failed to write system_msg field: %w", err) + } + } + + if req.Temperature != nil { + if err := mw.WriteField("temperature", fmt.Sprintf("%f", *req.Temperature)); err != nil { + return nil, fmt.Errorf("failed to write temperature field: %w", err) + } + } + + if req.MaxTokens != nil { + if err := mw.WriteField("max_tokens", strconv.Itoa(*req.MaxTokens)); err != nil { + return nil, fmt.Errorf("failed to write max_tokens field: %w", err) + } + } + + if req.Stream != nil { + if err := mw.WriteField("stream", fmt.Sprintf("%v", *req.Stream)); err != nil { + return nil, fmt.Errorf("failed to write stream field: %w", err) + } + } + + if err := mw.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + return mw, nil +} + func NewSegmentAnything2MultipartWriter(w io.Writer, req GenSegmentAnything2MultipartRequestBody) (*multipart.Writer, error) { mw := multipart.NewWriter(w) writer, err := mw.CreateFormFile("image", req.Image.Filename()) diff --git a/worker/runner.gen.go b/worker/runner.gen.go index a6b5c281..03812d9d 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -113,6 +113,17 @@ type BodyGenImageToVideo struct { Width *int `json:"width,omitempty"` } +// BodyGenLLM defines model for Body_genLLM. 
+type BodyGenLLM struct { + History *string `json:"history,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + ModelId *string `json:"model_id,omitempty"` + Prompt string `json:"prompt"` + Stream *bool `json:"stream,omitempty"` + SystemMsg *string `json:"system_msg,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` +} + // BodyGenSegmentAnything2 defines model for Body_genSegmentAnything2. type BodyGenSegmentAnything2 struct { // Box A length 4 array given as a box prompt to the model, in XYXY format. @@ -186,6 +197,12 @@ type ImageResponse struct { Images []Media `json:"images"` } +// LLMResponse defines model for LLMResponse. +type LLMResponse struct { + Response string `json:"response"` + TokensUsed int `json:"tokens_used"` +} + // MasksResponse Response model for object segmentation. type MasksResponse struct { // Logits The raw, unnormalized predictions (logits) for the masks. @@ -297,6 +314,9 @@ type GenImageToImageMultipartRequestBody = BodyGenImageToImage // GenImageToVideoMultipartRequestBody defines body for GenImageToVideo for multipart/form-data ContentType. type GenImageToVideoMultipartRequestBody = BodyGenImageToVideo +// GenLLMFormdataRequestBody defines body for GenLLM for application/x-www-form-urlencoded ContentType. +type GenLLMFormdataRequestBody = BodyGenLLM + // GenSegmentAnything2MultipartRequestBody defines body for GenSegmentAnything2 for multipart/form-data ContentType. type GenSegmentAnything2MultipartRequestBody = BodyGenSegmentAnything2 @@ -453,6 +473,11 @@ type ClientInterface interface { // GenImageToVideoWithBody request with any body GenImageToVideoWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // GenLLMWithBody request with any body + GenLLMWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + GenLLMWithFormdataBody(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + // GenSegmentAnything2WithBody request with any body GenSegmentAnything2WithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -513,6 +538,30 @@ func (c *Client) GenImageToVideoWithBody(ctx context.Context, contentType string return c.Client.Do(req) } +func (c *Client) GenLLMWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGenLLMRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GenLLMWithFormdataBody(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGenLLMRequestWithFormdataBody(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + func (c *Client) GenSegmentAnything2WithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { req, err := NewGenSegmentAnything2RequestWithBody(c.Server, contentType, body) if err != nil { @@ -675,6 +724,46 @@ func NewGenImageToVideoRequestWithBody(server string, contentType string, body i return req, nil } 
+// NewGenLLMRequestWithFormdataBody calls the generic GenLLM builder with application/x-www-form-urlencoded body +func NewGenLLMRequestWithFormdataBody(server string, body GenLLMFormdataRequestBody) (*http.Request, error) { + var bodyReader io.Reader + bodyStr, err := runtime.MarshalForm(body, nil) + if err != nil { + return nil, err + } + bodyReader = strings.NewReader(bodyStr.Encode()) + return NewGenLLMRequestWithBody(server, "application/x-www-form-urlencoded", bodyReader) +} + +// NewGenLLMRequestWithBody generates requests for GenLLM with any type of body +func NewGenLLMRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/llm") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + // NewGenSegmentAnything2RequestWithBody generates requests for GenSegmentAnything2 with any type of body func NewGenSegmentAnything2RequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { var err error @@ -828,6 +917,11 @@ type ClientWithResponsesInterface interface { // GenImageToVideoWithBodyWithResponse request with any body GenImageToVideoWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenImageToVideoResponse, error) + // GenLLMWithBodyWithResponse request with any body + GenLLMWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) + + GenLLMWithFormdataBodyWithResponse(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) + // GenSegmentAnything2WithBodyWithResponse request with any body GenSegmentAnything2WithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenSegmentAnything2Response, error) @@ -941,6 +1035,32 @@ func (r GenImageToVideoResponse) StatusCode() int { return 0 } +type GenLLMResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *LLMResponse + JSON400 *HTTPError + JSON401 *HTTPError + JSON422 *HTTPValidationError + JSON500 *HTTPError +} + +// Status returns HTTPResponse.Status +func (r GenLLMResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GenLLMResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + type GenSegmentAnything2Response struct { Body []byte HTTPResponse *http.Response @@ -1055,6 +1175,23 @@ func (c *ClientWithResponses) GenImageToVideoWithBodyWithResponse(ctx context.Co return ParseGenImageToVideoResponse(rsp) } +// GenLLMWithBodyWithResponse request with arbitrary body returning *GenLLMResponse +func (c *ClientWithResponses) GenLLMWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) { + rsp, err := c.GenLLMWithBody(ctx, contentType, body, reqEditors...) 
+ if err != nil { + return nil, err + } + return ParseGenLLMResponse(rsp) +} + +func (c *ClientWithResponses) GenLLMWithFormdataBodyWithResponse(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) { + rsp, err := c.GenLLMWithFormdataBody(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseGenLLMResponse(rsp) +} + // GenSegmentAnything2WithBodyWithResponse request with arbitrary body returning *GenSegmentAnything2Response func (c *ClientWithResponses) GenSegmentAnything2WithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenSegmentAnything2Response, error) { rsp, err := c.GenSegmentAnything2WithBody(ctx, contentType, body, reqEditors...) @@ -1285,6 +1422,60 @@ func ParseGenImageToVideoResponse(rsp *http.Response) (*GenImageToVideoResponse, return response, nil } +// ParseGenLLMResponse parses an HTTP response from a GenLLMWithResponse call +func ParseGenLLMResponse(rsp *http.Response) (*GenLLMResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GenLLMResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest LLMResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest HTTPValidationError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON500 = &dest + + } + + return response, nil +} + // ParseGenSegmentAnything2Response parses an HTTP response from a GenSegmentAnything2WithResponse call func ParseGenSegmentAnything2Response(rsp *http.Response) (*GenSegmentAnything2Response, error) { bodyBytes, err := io.ReadAll(rsp.Body) @@ -1461,6 +1652,9 @@ type ServerInterface interface { // Image To Video // (POST /image-to-video) GenImageToVideo(w http.ResponseWriter, r *http.Request) + // LLM + // (POST /llm) + GenLLM(w http.ResponseWriter, r *http.Request) // Segment Anything 2 // (POST /segment-anything-2) GenSegmentAnything2(w http.ResponseWriter, r *http.Request) @@ -1500,6 +1694,12 @@ func (_ Unimplemented) GenImageToVideo(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotImplemented) } +// LLM +// (POST /llm) +func (_ Unimplemented) GenLLM(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} + // Segment Anything 2 // (POST /segment-anything-2) func (_ Unimplemented) GenSegmentAnything2(w http.ResponseWriter, r *http.Request) { @@ -1593,6 +1793,23 @@ func (siw *ServerInterfaceWrapper) GenImageToVideo(w http.ResponseWriter, r *htt 
handler.ServeHTTP(w, r.WithContext(ctx)) } +// GenLLM operation middleware +func (siw *ServerInterfaceWrapper) GenLLM(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.GenLLM(w, r) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + // GenSegmentAnything2 operation middleware func (siw *ServerInterfaceWrapper) GenSegmentAnything2(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -1769,6 +1986,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/image-to-video", wrapper.GenImageToVideo) }) + r.Group(func(r chi.Router) { + r.Post(options.BaseURL+"/llm", wrapper.GenLLM) + }) r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/segment-anything-2", wrapper.GenSegmentAnything2) }) @@ -1785,60 +2005,64 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xbe28bNxL/KsTeAXUAyZLcujkY6B9OmsbGOalhK02LxBCo3dEu411yy4clNefvfuCQ", - "u9qXHk4T9y7VX5bFx/xmOPOb4UMfg1BkueDAtQpOPgYqTCCj+PH08vyFlELazxGoULJcM8GDE9tCwDYR", - "CSoXXAHJRATpYdALcilykJoBzpGpuD18nIAfnoFSNAY7TjOdQnASvFKx/W+Z23+UlozHwf19L5Dwu2ES", - "ouDkHc56sxpSAi3HiekHCHVw3wueiWg5iYGfmoiJsRjDQltAdZTUNrZxvslTQSOICLaTGUuBaEGmQLSk", - "3PacQmSxz4TMqA5OginjVC4r2qDYtj69AO01YZGTOqMmteODXgPCmYljxmPyEw29jcn5j8QoiMhMyBIH", - "dq9Z0XWNtprSqV4xZpfBNtj1PKMxjAX+aRs2NiyiPISJCmkKNV2fHh43lX3BQ2EkjUF5VbUgMXCQVANh", - "GTaEqVCQLknK+C1EtodOgGhYaJJLkeWaHCQsTkCSO5oaOxNdEgmRCf0U5HdDU6aXT6rmeulxkmvEWerL", - "TTYFafVlhYJrXMTNrYVFzmZLMmc6QWg5yyFlHDb7ibNfh5/gvJMNdhy17fgjxBIQzDxhoYNR2LFAyhTJ", - "jUrQhHMqI4W9GGea0dT1OWziI9vNlArpyGODT5+SC3F1Sg4uxLx/RfktOY1orqltfeIXnvKIMK1IKKRj", - "mMgGwRxYnGh0fKeEV8r6PnmxoFmewgn5SN4HKdXAdT8UXDGlgYfLQRpmfYuur6JF+j44IaPDYY+8DzhI", - "9kENcraAtE+l7hetR/dVA1ygYl8skFv67BTLvYBDTDW7g4lz/i0gxqswOVBPMLwMi4DME6rtf7AIUxMB", - "mUmRdZj4POZCWg+akbpDkvdmOPw2JKMq7NceGrl00LrQm2zi4nqSg+zSYdRU4TW6GhGzghCqHJGD9OrV", - "gJiMnLvOlyBbcBjXEDvvRTx8BhJQNQ153ZdHw+F6PBFwwZRdYxx4SF4JCe4zMcrQ1LIWUOQsT1GeigpV", - "pkYTlYo5SFKisNNEJsXInS6J0hJ4rJOWfkV/co2ou7SrmncXr9jkk+vXVNEZ6OUkTCC8rRlPSwNN612C", - "tJxIKHHDCA5DV1SaZcj7syZ3WVowaWTzsJjNgCvrZEKShMpsZtIqzGs363MEU4KdCpEC5YgWIGpb5Bp8", - "WErKI5ERx29rTGE7d9q7WKuaFYaH/1pD12Lm0rlLEkxwQvM8ZaskJ6FYY7cyB0PbMqolsutCZoubG3k/", - "LxbQJbaOAqCW2bdXAL+wCES7Apg1Quj7XkcxOJM0A4XhqyAUPEJj1LLWnZ2+qulPa7w8wSRRk3n8tFOq", - "60kYJ0j+agehZ27yLrk7FwglW1E3P7LtJ1YHnyf5OBgPTz6ZsL0nUxPegm6iGB09bcJ4Uwi0S8zslxaU", - "NTnNhOHaLoCb0xW3ST394Jo54rRNPijtx8wyrR85Z2lqqYFxbGot4SvX7RmCrilWTQSCKZhQE0/WBPHw", - "qFXVlCrgYEKjaBW6NYVdcUXOamWqL1ElKMimKRZZa8e68oiHEqgq9K4lBARwamKyng62J7uj4//jXLfP", - "QoUl5ixqeO9oePRdFx9izwfR4Vucuy21kWu2pRiXOjakmGuIM+D6lC91wnh81E4zU7HoOKYgKToQ+Y5Q", - "KemSxOwOOKGKUDIVi2LD6OMMebFn9f/1t19/I46Nq9o+E4u1O7S28POC75UD/6kMT9XthPHc6E79xLwv", - "QYnUIKnZzgQ7N5TSy5yFGJVY2lOSS7hjwij7IWIhjmba+1VvtXvFuBgtzhZvycHZD29/ODr+Hl3y+vRV", - "re54ZSWfI8z/uT1SZlIbxep2IowuDbmBD85tJWagt7KgyyoStJE2rdhyzU6oEBfNpiw21pjO9M6tVI+I", - "mQZu/41MaPWagtYg/UidUG4Zh/E4hcoy1LQqkJOfHfIu8uDWqVL2B0xCIWSkHqZeLhjXBEcyTjWoMoGW", - "864KUMpjIO+GvdGNdxEc7eUSWOQQatd9Cq6D3b1DiO7oli9imeVKwVU9Y3lZ5LnToUvRqrB2MLxeHPko", - 
"FzOvlV+IRizME5BAgIYePmF24cjBr73fnqzYr7bZwW5NZCsHc8BSOoW0A9gFfl9WNDVoBZoRYTxiIdqf", - "2q4QS2F45HvbfD+sdZnS8LbapQ3Xie2C69x4koqY6Qd4ixumiOF9GwEqEamtcNA93VyEcaVt1hczCxE5", - "Dtur6K5cEF046e113jV3tHLChvzxJi/Pzepp46861fs8hGicWtGnnx5tKQGfHv+Njjt2sub+3GNbxfng", - "c4YiODvi92w8vlxzBWWbdryDikBTluI9T5r+PAtO3n0M/ilhFpwE/xisbr8G/uprUF4n3d+0j2rsVBB5", - "yYyXhzWHLc292IrGK3XW6PoLTVmE05Var1OFacjwq02aNOe7X2FxmqyAYOpEHapomxN04Qaa6uR54fZ1", - "vEpTbRr3AT//u3ZehR26bqlWZy4rAR3ykWOvvAu0/eSq5hxr68iOtKC6Ly6bQWlH77QYryBitLoE7ky6", - "awlaCVBV3aiucYdJbDGuHmQSN7bYqqyxSrVgaFpF0nmPGF6pGVcVrSIHbuiTsgjCErh+vVIvB+oboK1L", - "0ZoPTdBJ06GQ65YW7fGNpVk+YxGmF9cdcWPFWBdZo0M38daraw9MFd29VW8a2DeuL3pSx3Ywsw3FYoaC", - "a8rcqRGvHCpPhd0e1s1nx7UXnKvZvC3mbQK6OINzAudUkVlK4xgiu7l+ff3T21rCttPsnoTsStgWV+dU", - "D0xLiTsdfBiZdk/+5urCl90rFULKbV6lYQhKuUv9QsAbmW5dVYN9lIOCZquuJy5XxzraSuRBYYrX3JuI", - "K0wM3x4tOI3rujN7Yfcqez13oprs1Qu0f2KxDUHVxvVXBmuMrF0nr+NNffSmeLHt/iLjkkrqlP1aHyp8", - "zvuP1jOADfcf+5v//c3/13vzf/y3vvgn15BTtDOepeZ4YOnO1vBE5Jv/fGNdQ5k8F9IDLk/c9tvnv+zC", - "psXfO17YeIdppNh6Cu3Is1u3r6kIa3tXypd+P970h48tiDf3VUoOUUxH9eEfnK5qL3xg2uVx7otVV8RM", - "xvbbbZWI1cOJ8j0rltphy4z3XQ8q/Lpu6BvvLPAJxba6q3hwYPvWSr8H7mCbJV/xJsOB2LKj9VCrNqsZ", - "pMNirvrs2PFgAzq+5TIkI0o0y0BpmuVtM60vTnECH0E46/b61LZ7SWvmLJpbExf2rhhvXM61xX662tEC", - "q1jSGaplQaSs0Eiml9d2MZ0xzsbjy2dAJcjy5TfynPuqnCTROg/u7Rx2H9mxCv6lkotJy8LScHJ6Xh6S", - "q2o1xe4gB5C2/cpwjoLuQCo3193w8PhwaE0rcuA0Z8FJ8O3h6HBoV5LqBHEP8NFyX4t+sZy5UF3LWr7S", - "rrzgdtdBfv9hXQNRn0e2uG6+erZWB6WfiWiJuyvBbXGJ79oxD1KpBzYR9SOq6er1/LY46npifV9fZZv1", - "8AsXE6j20XDYQFEx++CDsjrvCqG2ZULZjVRmcCc8MylZdesF331GCKvT0A75z2hErpz1ndzR48h9w6nR", - "iZDsD4hQ8OjbxxHslSUvuLaF4VgIckFl7Kx+dPRZQbSOhdtwVl1IeXR8/FiLf841SE5Tcg3yDmSBoEJi", - "WDNU6evdzf1NL1Amy6hcFr+7IGNBCuamsbLcWaRCy5mLvsqB3gJVyz6nGfTFHUjJImTeWmj2gkGCp9B4", - "VACoe5063CF18AUjtnoMvmvA3ldN4iGiNlgYWgItL0C7GfQ0z9NlcQtae5yKNErtNsDWFJVSs0WpjXek", - "X5hTa9IemVTrB/N7Vl3PqntCeyihuedkY0HKNwUPZDRWD4wqCdyVT7c7SeBl14PlB8V+8cDvcWLfSXvk", - "2K9vYfaxv4/9LxD75UPZT4v9IjB6wcDf9Papfz7VP1of//6llb9XxMdylG8I+o6XWV848FsSHzn46ze2", - "++DfB//nC/4i+grnJkefQACqHSC9YKBhoXfYBLxs3Gxi+q9cZKpOFqicGG8kgD93hlE/k97X+/uw/0rC", - "Hu/m/kS5ryvhh8FuKm+gO8Pcv8MsczuZLosfmeEbIq3I6qcmnSG/esn5hfN9IWgf7/t4/0rivfIK+oGR", - "bqrBoBCAQnGNn6EUFy/PU2Ei8lxkmeFML8lLqmFOl4F/MYfXPepkMIgk0Kwfu9bD1A8/DO1wvKFdM/+1", - "xrPXddOWEynsN6A5G0xB00Gp7/3N/X8DAAD//0kjq9GWSAAA", + "H4sIAAAAAAAC/+xba28bN9b+K8S8L9AEkCzLbZqFgX5w0jQx1k4DW2laJIZAzRyNWHPIKS+W1Kz/+4KH", + "MyPORbds4u6m+hRHvJzn3A/JMx+jWGa5FCCMjk4/RjqeQUbxz7M35y+Uksr9nYCOFcsNkyI6dSME3BBR", + "oHMpNJBMJsCPol6UK5mDMgxwj0yn7eWjGRTLM9CapuDWGWY4RKfRpU7d/5a5+482iok0ur/vRQr+sExB", + "Ep2+x11vVksqoNU6OfkdYhPd96JnMlmOUxBnNmFyJEewMA5QHSV1g22cb3MuaQIJwXEyZRyIkWQCxCgq", + "3MwJJA77VKqMmug0mjBB1TLgBsm2+elFKK8xSzzVKbXcrY96DQivbJoykZKfaFzImJz/SKyGhEylqnDg", + "9JoU/dRkqyg964EwuwS2Qa7nGU1hJPGftmBTyxIqYhjrmHKo8fr06EmT2RcillbRFHTBqpEkBQGKGiAs", + "w4GYSw18STgTt5C4GWYGxMDCkFzJLDfk0YylM1DkjnLrdqJLoiCxcbEF+cNSzszycSiulwVOco04K36F", + "zSagHL+sZHCNifi9jXTI2XRJ5szMEFrOcuBMwGY78fLrsBPcd7xBjsO2HH+EVAGCmc9Y7GGUciyRMk1y", + "q2cowjlVicZZTDDDKPdzjpr4yHYxcal88Nhg02fkQl6dkUcXct6/ouKWnCU0N9SNPi4UT0VCmNEklspH", + "mMQ5wRxYOjNo+J6Jgiln++TFgmY5h1PykXyIODUgTD+WQjNtQMTLAY+zvkPX18mCf4hOyfDouEc+RAIU", + "+10PcrYA3qfK9MvRk/tQABfI2Bdz5BY/O/lyLxKQUsPuYOyNfwuI0cpNHunH6F6WJUDmM2rc/2ARc5sA", + "mSqZdYj4PBVSOQuakrpBkg/2+PjbmAxD2K8LaOSNh9aF3mZj79fjHFQXD8MmC6/R1IiclgEhjBE5qIK9", + "GhCbkXM/+Q2oFhwmDKTeehGPmIICZM1AXrfl4fHxejwJCMm00zEuPCKXUoH/m1htKXdRCyjGrCJEFaGo", + 
"ZGViDdFczkGRCoXbJrEcPXeyJNooEKmZtfgr55NrRN3FXSjeXaxik02u16mmUzDLcTyD+LYmPKMsNKX3", + "BpSLiYQSv4zgMjRFbViGcX/ajF0uLFieuDwsp1MQ2hmZVGRGVTa1PIR57Xd9jmAqsBMpOVCBaAGStkSu", + "oXBLRUUiM+Lj2xpRuMmd8i51VZPC8dE/1oRrOfXp3CcJJgWhec7ZKskpKHXsNfPo2I0Ma4nsuqTZis2N", + "vJ+XCvSJraMAqGX27RXALywB2a4Apg0X+r7XUQxOFc1Ao/tqiKVIUBi1rHXntg85/WmNlc8wSdRoPnna", + "SdXPJEwQDP56B6Kv/OZddHcuEKpoRf3+GG0/sTr4PMnHw9g/+WTSzR5PbHwLpoliePK0CeNtSdCpmLkf", + "HSgncppJK4xTgN/TF7ezevpBnfnA6YYKp3R/Zi7SFivnjHMXGpjAoZYKL/20Zwi6xliYCCTTMKY2Ha9x", + "4uOTVlVTsYCLCU2SlevWGPbFFXlVK1OLElWBhmzCschau9aXRyJWQHXJdy0hIIAzm5L14WB7sjt58j+c", + "6w5ZqJTEnCUN6x0en3zXFQ9x5l7h8B3u3abayDXbUoxPHRtSzMXFZTuzzJg2Ui3roe/9TRitixldoYsu", + "xkbegmja/PdBpKALMvJzugS7NvbuEjpXtdgOFZVRQLMamSnlGupZn2bdprXUBrJxdQ/TgfMap5DOi5de", + "ZCDLnf6tgkYMfLraYhRM2rHy6DAHp+YNVnANaQbCnImlmTGRnrRNYiIXHZdVhGMYId8RqhRdkpTdgSBU", + "E0omclFeGxTRFrXac17w62+//kZ8Tg5t/plcrD2nt4mfl1lfe/Cfmuepvh0zkVvTyZ+c9xVoyS2mNjeZ", + "4OQGU2aZsxhjMx7wKMkV3DFptfsjYTGuZqaILr3VHQZGx+Hi1eIdefTqh3c/nDz5HgPT9dllrfq8dJTP", + "EeZ/3Uk5s9zFcn07ltZUgtyQFc5dPW6ht5Kgry0UGKtcceGKdrehRlw0m7DUOmF60Xuz0j0ipwaE+29i", + "Y8fXBIwBVaw0Mypc3mEi5RCoocZViZz87JF3+blwRsXZnzCOpVSJ3o+9XDJhCK5kghrQVRlV7bs6hlCR", + "Anl/3BveFCaCqwu6BBY5xMZPn4CfoEC7H91PXn0Jy1zGlELX65aCFnnueehiNCTWdobXi5PCy+W04KpQ", + "RMMX5jNQQIDGBXzCnOLIo197vz1e5cDakRenNZEFIR2BcToB3gHsAn+v6toatBLNkDCRsBjlT91USJW0", + "Iilmu6rvuDZlQuPbcEobrifbBdeb8ZjLlJk9rMUv08SKvvMAPZPc1blonn4vwoQ2rvaTUwcRYxyOh+iu", + "vBNdeOptPe9aQbRywob88Tavbk/raeOvutv9PAHReraST79D3HIQePrkb3TptZM0D7df284de982lc7Z", + "4b+vRqM3ax4i3dCOL5EJGMo4vvZx/vM0On3/Mfp/BdPoNPq/weoNdFA8gA6qR8X7m/aFndsKkoIyE9WV", + "3VGL84JswPGKnTW8/kI5S3C7iut1rDADGf60iZPmfvcrLJ6TFRBMnchDiLa5QRduoNzMnpdmX8erDTW2", + "8Sr08z9rt5Y4oeutcnXztiLQQR9j7FVhAm07uaoZx9o6siMt6O7n66ZTutU7KeMSEkZDFfiXiS4VtBKg", + "Ds2oznGHSC4uLkOB1HlTwcgqJzc3C06BeA4eu6QTLvHHY/JW7xIEVLB/sF3AUwi5gyN3vNB7KdmvLQ9f", + "a/QclkBNPSs67xErgip4VaNr8sgvfVyVdVjU158N6wVO/Ui31bha+6EIOhNPLNU6Y0V5fOMSh5iyBBOm", + "n464sQauk6wFeL/x1paMApgupxdSvWlg36hf9I2OA27mBkplxlIYyvxtqAgeSybSHXjr4nPr2goXejpv", + "k3k3A1PeLXuCc6rJlNM0hYRQTV5f//SuVoK4bXZPq04TbsRXbuFDQEVxpws9q3j35m+vLoqDxIqFmApX", + "KdA4Bq19s0pJ4K3iW7VqcY72UFBsoT5RXR16dLXVXm6K7RubQnE8s2K7t+A2furO8Rinh/H4uSfVjMe9", + "yBStQ9sQhDKud8+sEbLxkwoeb+qrN/mLGy8e6N5QRT2zX2sDzud812u1t2x41zt0tBw6Wr7ejpYnf+uG", + "FnINOUU54+1wjlew/rYQ73i++dc3zjS0zXOpCsDVHeLhQuAve4hsxe8dHyLbT0/tFNqRZ7ceyLmMa6dx", + "KpbFDUPTHj62IN7chyE5RjId1UfxgLeqvbBxuvOohj+spiJmMnK/bqtEHB+eVDEzkNQOlwD4jrtX4dfV", + "edLoH8LWoG11V9lI4+bWSr89z+TNkq/sNfIgtpzRC6ihzGoC6ZCYrz47Tjw4gIbvYhkGI0oMy0AbmuVt", + "Ma0vTnGDwoNw1+31qRsvKK3ZsxxubVzKOxDeqNpri/xMONEBCyTpBdWSIIas2CpmltdOmV4Yr0ajN8+A", + "KlDVFw0Y5/xP1SYzY/Lo3u3hzpEdWig68LxPuiisrCBn59W1vw6rKXYHOYBy41dWCCR0B0r7vZxQZQ6C", + "5iw6jb49Gh4dOx1SM0PEA2zD7xvZLxWZS92l0Oq7g+CbBP+0VZw8ZF640nniyupmH7+TN2jzTCbYOOHO", + "0iCQkM+AVJmBS0H9hBq6+h5kmwd1fTRwX9evy3f4g/cGZPvk+LiBIhD44HfteN4VQu2whLQbScziGXhq", + "OVlN60XffUYIq5vdDvrPaEKuvPQ93eHD0H0rqDUzqdifkCDh4bcPQ7hglrwQxpWEIynJBVWpl/rJyWcF", + "0bribsNZTSHVNfiTh1L+uTCgBOXkGtQdqBJBEL6wWggD1/ub+5tepG2WUbUsvyQiI0nKmE1T7aJmmQRd", + "tFz0dQ70Fqhe9gXNoC/vQCmWYMytuWYvGszwRh0vCQB5r4cOf+EefUGPDa/0d3XY+1AkBUTkBktCF0Cr", + "x9zuCHqW53xZvujW2q0xjFJ3AHDVRFBktkJqozP6C8fUGrUHDqr1R4ZDVF0fVQ8Bbd+A5lvjRpJU/RF7", + "RjRWd4wwCNxVHyN0BoGXXS34e/l+2bL6ML7vqT2w79cPLwffP/j+F/D9qvX703y/dIxeNOA828Hh8TRs", + "8VKTEk5Fah2Q6j6v5e6+JXm9l4ciXvTn83kfvd0qDiKWib9N28/nHckHdvXw3f3g6AdH/3yOXrT07+nd", + "zpfRqYvGjT4t+jv7J+t9vGgFLdoEsJuXig2ZvKN19Atn8xbFB3bzegPGwdEPjv75HL30vtK4yckn+L1u", + 
"O0gvGricvcPJ/mWjUQFr+qAvQXdGgeABaOdEv//FZP2J6XCIP7j9V+L2+NT+H5zhTeB+6Ow2+Eij082L", + "RvEqt5PJsvwWGlsCjSarb+E6XX7Vav6F831J6ODvB3//Svw9+ExjT0+3oTNoBKCRXOM7ufId9TmXNiHP", + "ZZZZwcySvKQG5nQZFQ2w+HqrTweDRAHN+qkfPeLF8qPYLceGizX7Xxt8UFm3bbWRxnkDmrPBBAwdVPze", + "39z/OwAA//87pybUPU8AAA==", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index f1e176de..31bc3a64 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -1,12 +1,16 @@ package worker import ( + "bufio" "bytes" "context" "encoding/json" "errors" + "fmt" "log/slog" + "net/http" "strconv" + "strings" "sync" ) @@ -303,6 +307,41 @@ func (w *Worker) AudioToText(ctx context.Context, req GenAudioToTextMultipartReq return resp.JSON200, nil } +func (w *Worker) LLM(ctx context.Context, req GenLLMFormdataRequestBody) (interface{}, error) { + c, err := w.borrowContainer(ctx, "llm", *req.ModelId) + if err != nil { + return nil, err + } + if c == nil { + return nil, errors.New("borrowed container is nil") + } + if c.Client == nil { + return nil, errors.New("container client is nil") + } + + slog.Info("Container borrowed successfully", "model_id", *req.ModelId) + + var buf bytes.Buffer + mw, err := NewLLMMultipartWriter(&buf, req) + if err != nil { + return nil, err + } + + if req.Stream != nil && *req.Stream { + resp, err := c.Client.GenLLMWithBody(ctx, mw.FormDataContentType(), &buf) + if err != nil { + return nil, err + } + return w.handleStreamingResponse(ctx, c, resp) + } + + resp, err := c.Client.GenLLMWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + if err != nil { + return nil, err + } + return w.handleNonStreamingResponse(c, resp) +} + func (w *Worker) SegmentAnything2(ctx context.Context, req GenSegmentAnything2MultipartRequestBody) (*MasksResponse, error) { c, err := w.borrowContainer(ctx, "segment-anything-2", *req.ModelId) if err != nil { @@ -433,3 +472,93 @@ func (w *Worker) returnContainer(rc *RunnerContainer) { // Noop because we allow concurrent in-flight requests for external containers } } + +func (w *Worker) handleNonStreamingResponse(c *RunnerContainer, resp *GenLLMResponse) (*LLMResponse, error) { + defer w.returnContainer(c) + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) + if err != nil { + return nil, err + } + slog.Error("LLM container returned 400", slog.String("err", string(val))) + return nil, errors.New("LLM container returned 400") + } + + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) + if err != nil { + return nil, err + } + slog.Error("LLM container returned 401", slog.String("err", string(val))) + return nil, errors.New("LLM container returned 401") + } + + if resp.JSON500 != nil { + val, err := json.Marshal(resp.JSON500) + if err != nil { + return nil, err + } + slog.Error("LLM container returned 500", slog.String("err", string(val))) + return nil, errors.New("LLM container returned 500") + } + + return resp.JSON200, nil +} + +type LlmStreamChunk struct { + Chunk string `json:"chunk,omitempty"` + TokensUsed int `json:"tokens_used,omitempty"` + Done bool `json:"done,omitempty"` +} + +func (w *Worker) handleStreamingResponse(ctx context.Context, c *RunnerContainer, resp *http.Response) (<-chan LlmStreamChunk, error) { + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + outputChan := make(chan LlmStreamChunk, 10) + + go func() { + defer close(outputChan) + defer w.returnContainer(c) + + scanner := 
bufio.NewScanner(resp.Body) + totalTokens := 0 + + for scanner.Scan() { + select { + case <-ctx.Done(): + return + default: + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + outputChan <- LlmStreamChunk{Chunk: "[DONE]", Done: true, TokensUsed: totalTokens} + return + } + + var streamData LlmStreamChunk + if err := json.Unmarshal([]byte(data), &streamData); err != nil { + slog.Error("Error unmarshaling stream data", slog.String("err", err.Error())) + continue + } + + totalTokens += streamData.TokensUsed + + select { + case outputChan <- streamData: + case <-ctx.Done(): + return + } + } + } + } + + if err := scanner.Err(); err != nil { + slog.Error("Error reading stream", slog.String("err", err.Error())) + } + }() + + return outputChan, nil +}
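
Note: the patch adds Worker.LLM but no caller for it. The sketch below is not part of the diff; it is a minimal illustration of how the two possible return shapes of Worker.LLM might be consumed, assuming an already-initialized *Worker and the Meta-Llama-3.1-8B-Instruct checkpoint downloaded by dl_checkpoints.sh. The function name runLLMExample and the prompt are placeholders.

package worker

import (
	"context"
	"fmt"
)

// runLLMExample sends one request through an already-initialized Worker and
// prints the output. With Stream=true, Worker.LLM returns <-chan LlmStreamChunk;
// otherwise it returns *LLMResponse, so callers must type-switch on the result.
func runLLMExample(ctx context.Context, w *Worker) error {
	modelID := "meta-llama/Meta-Llama-3.1-8B-Instruct"
	stream := true
	req := GenLLMFormdataRequestBody{
		Prompt:  "Explain how rainbows form in one paragraph.",
		ModelId: &modelID,
		Stream:  &stream,
	}

	out, err := w.LLM(ctx, req)
	if err != nil {
		return err
	}

	switch res := out.(type) {
	case <-chan LlmStreamChunk:
		// Streaming path: chunks arrive until the worker emits the final
		// "[DONE]" chunk with Done=true and the running token count.
		for chunk := range res {
			if chunk.Done {
				fmt.Printf("\n[tokens used: %d]\n", chunk.TokensUsed)
				break
			}
			fmt.Print(chunk.Chunk)
		}
	case *LLMResponse:
		// Non-streaming path: the full response arrives in one JSON body.
		fmt.Printf("%s\n[tokens used: %d]\n", res.Response, res.TokensUsed)
	}
	return nil
}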