generated from runpod-workers/worker-template
Commit 878c544 (parent af9ab86): 9 changed files with 276 additions and 85 deletions.
requirements.txt
@@ -1,2 +1,3 @@
 runpod==1.6.2
 infinity-emb[all]
+hf_transfer
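A note on the new dependency: installing hf_transfer alone does nothing; huggingface_hub only uses it when HF_HUB_ENABLE_HF_TRANSFER=1 is set before the first download. A minimal sketch (where the variable gets set, e.g. in the Dockerfile, is an assumption, not part of this commit):

# Sketch: hf_transfer is opt-in. The env var must be set before
# huggingface_hub downloads anything (typically ENV in the Dockerfile).
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # enable the Rust-based downloader

from huggingface_hub import snapshot_download  # now accelerated by hf_transfer
snapshot_download("BAAI/bge-large-en-v1.5")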
Empty file.
config.py (new file)
@@ -0,0 +1,36 @@
import os
from dotenv import load_dotenv
from constants import DEFAULT_BATCH_SIZE, DEFAULT_BACKEND


class EmbeddingServiceConfig:
    def __init__(self):
        load_dotenv()
        self.backend = os.environ.get("BACKEND", DEFAULT_BACKEND)
        self.model_names = self._get_model_names()
        self.batch_sizes = self._get_batch_sizes()
        self.default_model_name = self._get_default_model_name()

    def _get_model_names(self):
        model_names = os.environ.get("MODEL_NAMES")
        if not model_names:
            raise ValueError("MODEL_NAMES environment variable is required")
        model_names = model_names.split(";")
        model_names.append("BAAI/bge-large-en-v1.5")
        return model_names

    def _get_batch_sizes(self):
        batch_sizes = os.getenv("BATCH_SIZES", f"{DEFAULT_BATCH_SIZE};" * len(self.model_names)).split(";")
        batch_sizes = [batch_size for batch_size in batch_sizes if batch_size]
        if len(batch_sizes) != len(self.model_names):
            raise ValueError("BATCH_SIZES must have the same number of elements as MODEL_NAMES")
        batch_sizes = [int(batch_size) for batch_size in batch_sizes]
        return batch_sizes

    def _get_default_model_name(self):
        env_default_model_name = os.environ.get("DEFAULT_MODEL_NAME")
        if env_default_model_name in self.model_names:
            return env_default_model_name
        elif len(self.model_names) == 1:
            return self.model_names[0]
        else:
            return None
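A quick sketch of how this parsing behaves (not part of the commit; model names and sizes are illustrative). Note that _get_model_names unconditionally appends BAAI/bge-large-en-v1.5, so BATCH_SIZES must carry one extra entry for it:

# Illustrative environment for EmbeddingServiceConfig (values are assumptions):
import os
os.environ["MODEL_NAMES"] = "intfloat/e5-small-v2"  # ';'-separated list
os.environ["BATCH_SIZES"] = "64;32"                 # one per model, incl. the appended bge model

from config import EmbeddingServiceConfig

cfg = EmbeddingServiceConfig()
print(cfg.model_names)         # ['intfloat/e5-small-v2', 'BAAI/bge-large-en-v1.5']
print(cfg.batch_sizes)         # [64, 32]
print(cfg.default_model_name)  # None: two models loaded and no DEFAULT_MODEL_NAME set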
constants.py (new file)
@@ -0,0 +1,2 @@
DEFAULT_BATCH_SIZE = 128
DEFAULT_BACKEND = "torch"
embedding_service.py (new file)
@@ -0,0 +1,54 @@
from config import EmbeddingServiceConfig
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
from utils import create_error_response, OpenAIEmbeddingResult, OpenAIModelInfo, ModelInfo, list_embeddings_to_response, to_rerank_response


class EmbeddingService:
    def __init__(self):
        self.config = EmbeddingServiceConfig()
        self.models = self._initialize_models()

    def _initialize_models(self):
        models = {}
        for model_name, batch_size in zip(self.config.model_names, self.config.batch_sizes):
            models[model_name] = EmbeddingModel(model_name, batch_size, self.config.backend)
        return models

    async def _embed(self, input, engine):
        if not isinstance(input, list):
            input = [input]
        async with engine:
            embeddings, usage = await engine.embed(input)
        return embeddings, usage

    def openai_get_models(self):
        return OpenAIModelInfo(data=[
            ModelInfo(id=model.engine_args.model_name_or_path,
                      stats=dict(batch_size=model.engine_args.batch_size,
                                 backend=model.engine_args.engine.name))
            for model in self.models.values()]).model_dump()

    async def openai_get_embeddings(self, embedding_input, model):
        embeddings, usage = await self._embed(embedding_input, model.engine)
        result = list_embeddings_to_response(embeddings, model.engine_args.model_name_or_path, usage)
        return OpenAIEmbeddingResult(**result).model_dump()

    async def infinity_embed(self, embedding_input, model):
        embeddings, usage = await self._embed(embedding_input, model.engine)
        return list_embeddings_to_response(embeddings, model.engine_args.model_name_or_path, usage)

    async def infinity_rerank(self, input, model):
        query, docs, return_docs = input["query"], input["docs"], input["return_docs"]
        async with model.engine:
            scores, usage = await model.engine.rerank(query=query, docs=docs, raw_scores=False)
        if not return_docs:
            docs = None
        return to_rerank_response(scores=scores, documents=docs, model=model.engine_args.model_name_or_path, usage=usage)


class EmbeddingModel:
    def __init__(self, model_name, batch_size, backend):
        print(f"Initializing model {model_name} with batch size {batch_size} and backend {backend}")
        self.model_name = model_name
        self.batch_size = batch_size
        self.engine_args = EngineArgs(model_name_or_path=model_name, batch_size=batch_size, engine=backend)
        self.engine = AsyncEmbeddingEngine.from_args(self.engine_args)
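One design note: every request path enters `async with engine`, so the infinity engine is started and stopped per request rather than held open across jobs. A smoke-test sketch follows; it assumes the environment from the config example above and the OpenAI-style response shape produced by list_embeddings_to_response (an assumption about the helper, not confirmed by this diff):

# Sketch only; assumes MODEL_NAMES/BATCH_SIZES are already set.
import asyncio
from embedding_service import EmbeddingService

async def main():
    service = EmbeddingService()
    # Pick the default model, or fall back to the first one loaded.
    name = service.config.default_model_name or service.config.model_names[0]
    model = service.models[name]
    response = await service.openai_get_embeddings("hello world", model)
    # Assuming OpenAI-style output: data[0].embedding holds the vector.
    print(len(response["data"][0]["embedding"]))

asyncio.run(main())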
handler.py
@@ -1,83 +1,46 @@
""" Example handler file. """ | ||
import os | ||
|
||
import runpod | ||
from utils import create_error_response, OpenAIEmbeddingInput, OpenAIEmbeddingResult, OpenAIModelInfo, ModelInfo, list_embeddings_to_response | ||
from typing import Optional, Union | ||
from infinity_emb import AsyncEmbeddingEngine, EngineArgs | ||
|
||
model_names = os.environ.get("MODEL_NAMES") | ||
if not model_names: | ||
raise ValueError("MODEL_NAMES environment variable is required") | ||
model_names = model_names.split(";") | ||
default_model_name = None | ||
if len(model_names) == 1: | ||
default_model_name = model_names[0] | ||
|
||
batch_size = int(os.environ.get("BATCH_SIZE", 64)) | ||
from utils import create_error_response | ||
from embedding_service import EmbeddingService | ||
|
||
backend = os.environ.get("BACKEND", "torch") | ||
|
||
engine_args_list = {model_name: EngineArgs(model_name_or_path=model_name, engine=backend, batch_size=batch_size) for model_name in model_names} | ||
engines = {model_name: AsyncEmbeddingEngine(engine_args) for model_name, engine_args in engine_args_list.items()} | ||
embedding_service = EmbeddingService() | ||
|
||
async def async_generator_handler(job): | ||
job_input = job['input'] | ||
|
||
openai_route = job_input.get("openai_route") | ||
openai_route, openai_input = job_input.get("openai_route"), job_input.get("openai_input") | ||
if openai_route: | ||
openai_input = job_input.get("openai_input") | ||
model_name = openai_input.get("model") | ||
engine = engines.get(model_name) | ||
engine_args = engine_args_list.get(model_name) | ||
if not engine: | ||
return create_error_response(f"Model '{model_name}' not found").model_dump() | ||
|
||
if openai_route == "/v1/embeddings": | ||
embedding_input = openai_input.get("input") | ||
if isinstance(embedding_input, str): | ||
embedding_input = [embedding_input] | ||
try: | ||
embeddings, usage = await engine.embed(embedding_input) | ||
result = list_embeddings_to_response(embeddings, model_name, usage) | ||
return OpenAIEmbeddingResult(**result).model_dump() | ||
except Exception as e: | ||
return create_error_response(str(e)).model_dump() | ||
elif openai_route == "/v1/models": | ||
return OpenAIModelInfo(data=ModelInfo(id=engine_args.model_name_or_path, stats=dict(batch_size=engine_args.batch_size), backend=engine_args.engine.name)).model_dump() | ||
if openai_route == "/v1/models": | ||
return embedding_service.openai_get_models() | ||
elif openai_route == "/v1/embeddings": | ||
if not openai_input: | ||
return create_error_response("Missing input").model_dump() | ||
requested_model_name = openai_input.get("model") | ||
input_to_process = openai_input.get("input") | ||
processor_func = embedding_service.openai_get_embeddings | ||
else: | ||
return create_error_response(f"Invalid route: {openai_route}").model_dump() | ||
|
||
return create_error_response(f"Invalid OpenAI Route: {openai_route}").model_dump() | ||
else: | ||
model_name = job_input.get("model_name") | ||
request_type = job_input.get("request_type", "embed") | ||
|
||
if not model_name: | ||
if not default_model_name: | ||
return {"error": "model_name input is required when there is more than one model"} | ||
else: | ||
model_name = default_model_name | ||
engine = engines.get(model_name) | ||
requested_model_name = job_input.get("model") or embedding_service.config.default_model_name | ||
input = job_input.get("input") | ||
query, docs, return_docs = job_input.get("query"), job_input.get("docs"), job_input.get("return_docs") | ||
if query: | ||
input_to_process = {"query": query, "docs": docs, "return_docs": return_docs} | ||
processor_func = embedding_service.infinity_rerank | ||
elif input: | ||
### Non-OpenAI Embed not available until fixed | ||
# input_to_process = job_input.get("sentences", []) | ||
# processor_func = embedding_service.infinity_embed | ||
input_to_process = input | ||
processor_func = embedding_service.openai_get_embeddings | ||
else: | ||
return create_error_response(f"Invalid input: {job_input}").model_dump() | ||
|
||
if request_type == "embed": | ||
sentences = job_input.get("sentences", []) | ||
if not sentences: | ||
return {"error": "'sentences' input is required for embedding"} | ||
if not all(isinstance(sentence, str) for sentence in sentences): | ||
return {"error": "sentences must be a list of strings"} | ||
async with engine: | ||
embeddings, usage = await engine.embed(sentences=sentences) | ||
return {"embeddings": embeddings, "usage": usage} | ||
elif request_type == "rerank": | ||
query = job_input.get("query") | ||
if not query: | ||
return {"error": "'query' input is required"} | ||
docs = job_input.get("docs", []) | ||
if not docs: | ||
return {"error": "'docs' input is required for reranking"} | ||
async with engine: | ||
ranking, usage = await engine.rerank(query=query, docs=docs) | ||
return {"ranking": ranking, "usage": usage} | ||
|
||
model = embedding_service.models.get(requested_model_name) | ||
if not model: | ||
return create_error_response(f"Model '{requested_model_name}' not found").model_dump() | ||
|
||
try: | ||
return await processor_func(input_to_process, model) | ||
except Exception as e: | ||
return create_error_response(str(e)).model_dump() | ||
|
||
runpod.serverless.start({"handler": async_generator_handler}) | ||
runpod.serverless.start({"handler": async_generator_handler, "concurrency_modifier": lambda x: 3000}) |