Commit
This commit adds a new LLM pipeline to the ai-worker.
1 parent 2d158a3 · commit 6b00498
Showing 14 changed files with 1,004 additions and 58 deletions.
@@ -0,0 +1,18 @@
import torch
import subprocess

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")

# Check system CUDA version via nvcc
try:
    nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
    cuda_version = nvcc_output.split("release ")[-1].split(",")[0]
    print(f"System CUDA version: {cuda_version}")
except (subprocess.CalledProcessError, FileNotFoundError):
    print("Unable to check system CUDA version")

# Print the current device name (guarded so the script also runs on CPU-only hosts)
if torch.cuda.is_available():
    print(f"Current device: {torch.cuda.get_device_name(0)}")
@@ -0,0 +1,201 @@
import asyncio
import logging
import os
import psutil
from typing import Dict, Any, List, Optional, AsyncGenerator, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from huggingface_hub import file_download, snapshot_download
from threading import Thread

logger = logging.getLogger(__name__)


def get_max_memory():
    num_gpus = torch.cuda.device_count()
    gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)}
    cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB"
    max_memory = {**gpu_memory, "cpu": cpu_memory}

    logger.info(f"Max memory configuration: {max_memory}")
    return max_memory


def load_model_8bit(model_id: str, **kwargs):
    max_memory = get_max_memory()

    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
        max_memory=max_memory,
        offload_folder="offload",
        low_cpu_mem_usage=True,
        **kwargs
    )

    return tokenizer, model


def load_model_fp16(model_id: str, **kwargs):
    device = get_torch_device()
    max_memory = get_max_memory()

    # Check for fp16 variant
    local_model_path = os.path.join(
        get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
    has_fp16_variant = any(".fp16.safetensors" in fname for _, _,
                           files in os.walk(local_model_path) for fname in files)

    if device != "cpu" and has_fp16_variant:
        logger.info("Loading fp16 variant for %s", model_id)
        kwargs["torch_dtype"] = torch.float16
        kwargs["variant"] = "fp16"
    elif device != "cpu":
        kwargs["torch_dtype"] = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

    config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    checkpoint_dir = snapshot_download(
        model_id, cache_dir=get_model_dir(), local_files_only=True)

    model = load_checkpoint_and_dispatch(
        model,
        checkpoint_dir,
        device_map="auto",
        max_memory=max_memory,
        # Adjust based on your model architecture
        no_split_module_classes=["LlamaDecoderLayer"],
        dtype=kwargs.get("torch_dtype", torch.float32),
        offload_folder="offload",
        offload_state_dict=True,
    )

    return tokenizer, model


class LLMPipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {
            "cache_dir": get_model_dir(),
            "local_files_only": True,
        }
        self.device = get_torch_device()

        # Generate the correct folder name
        folder_path = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model")
        self.local_model_path = os.path.join(get_model_dir(), folder_path)
        self.checkpoint_dir = snapshot_download(
            model_id, cache_dir=get_model_dir(), local_files_only=True)

        logger.info(f"Local model path: {self.local_model_path}")
        logger.info(f"Directory contents: {os.listdir(self.local_model_path)}")

        use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true"

        if use_8bit:
            logger.info("Using 8-bit quantization")
            self.tokenizer, self.model = load_model_8bit(model_id, **kwargs)
        else:
            logger.info("Using fp16/bf16 precision")
            self.tokenizer, self.model = load_model_fp16(model_id, **kwargs)

        logger.info(
            f"Model loaded and distributed. Device map: {self.model.hf_device_map}"
        )

        # Set up generation config
        self.generation_config = self.model.generation_config

        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        # Optional: Add optimizations
        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        if sfast_enabled:
            logger.info(
                "LLMPipeline will be dynamically compiled with stable-fast for %s",
                model_id,
            )
            from app.pipelines.optim.sfast import compile_model
            self.model = compile_model(self.model)

    async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
        conversation = []
        if system_msg:
            conversation.append({"role": "system", "content": system_msg})
        if history:
            conversation.extend(history)
        conversation.append({"role": "user", "content": prompt})

        input_ids = self.tokenizer.apply_chat_template(
            conversation, return_tensors="pt").to(self.model.device)
        attention_mask = torch.ones_like(input_ids)

        max_new_tokens = kwargs.get("max_tokens", 256)
        temperature = kwargs.get("temperature", 0.7)

        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

        generate_kwargs = self.generation_config.to_dict()
        generate_kwargs.update({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "temperature": temperature,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.eos_token_id,
        })

        thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs)
        thread.start()

        total_tokens = 0
        try:
            for text in streamer:
                total_tokens += 1
                yield text
                await asyncio.sleep(0)  # Allow other tasks to run
        except Exception as e:
            logger.error(f"Error during streaming: {str(e)}")
            raise

        input_length = input_ids.size(1)
        yield {"tokens_used": input_length + total_tokens}

    def model_generate_wrapper(self, **kwargs):
        try:
            logger.debug("Entering model.generate")
            with torch.cuda.amp.autocast():  # Use automatic mixed precision
                self.model.generate(**kwargs)
            logger.debug("Exiting model.generate")
        except Exception as e:
            logger.error(f"Error in model.generate: {str(e)}", exc_info=True)
            raise

    def __str__(self):
        return f"LLMPipeline(model_id={self.model_id})"
@@ -0,0 +1,118 @@
import logging
import os
from typing import Annotated
from fastapi import APIRouter, Depends, Form, status
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, LLMResponse, http_error
import json

router = APIRouter()

logger = logging.getLogger(__name__)

RESPONSES = {
    status.HTTP_200_OK: {"model": LLMResponse},
    status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
    status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
    status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post(
    "/llm",
    response_model=LLMResponse,
    responses=RESPONSES,
    operation_id="genLLM",
    description="Generate text using a language model.",
    summary="LLM",
    tags=["generate"],
    openapi_extra={"x-speakeasy-name-override": "llm"},
)
@router.post("/llm/", response_model=LLMResponse, responses=RESPONSES, include_in_schema=False)
async def llm(
    prompt: Annotated[str, Form()],
    model_id: Annotated[str, Form()] = "",
    system_msg: Annotated[str, Form()] = "",
    temperature: Annotated[float, Form()] = 0.7,
    max_tokens: Annotated[int, Form()] = 256,
    history: Annotated[str, Form()] = "[]",  # Parsed as a JSON array
    stream: Annotated[bool, Form()] = False,
    pipeline: Pipeline = Depends(get_pipeline),
    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
    auth_token = os.environ.get("AUTH_TOKEN")
    if auth_token:
        if not token or token.credentials != auth_token:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                headers={"WWW-Authenticate": "Bearer"},
                content=http_error("Invalid bearer token"),
            )

    if model_id != "" and model_id != pipeline.model_id:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=http_error(
                f"pipeline configured with {pipeline.model_id} but called with "
                f"{model_id}"
            ),
        )

    try:
        history_list = json.loads(history)
        if not isinstance(history_list, list):
            raise ValueError("History must be a JSON array")

        generator = pipeline(
            prompt=prompt,
            history=history_list,
            system_msg=system_msg if system_msg else None,
            temperature=temperature,
            max_tokens=max_tokens
        )

        if stream:
            return StreamingResponse(stream_generator(generator), media_type="text/event-stream")
        else:
            full_response = ""
            tokens_used = 0  # Default in case the generator finishes without a usage dict
            async for chunk in generator:
                if isinstance(chunk, dict):
                    tokens_used = chunk["tokens_used"]
                    break
                full_response += chunk

            return LLMResponse(response=full_response, tokens_used=tokens_used)

    except json.JSONDecodeError:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "Invalid JSON format for history"}
        )
    except ValueError as ve:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(ve)}
        )
    except Exception as e:
        logger.error(f"LLM processing error: {str(e)}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "Internal server error during LLM processing."}
        )


async def stream_generator(generator):
    try:
        async for chunk in generator:
            if isinstance(chunk, dict):  # This is the final result
                yield f"data: {json.dumps(chunk)}\n\n"
                break
            else:
                yield f"data: {json.dumps({'chunk': chunk})}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as e:
        logger.error(f"Streaming error: {str(e)}")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"