Commit

add openai-compat
Signed-off-by: pandyamarut <[email protected]>
pandyamarut committed Jul 29, 2024
1 parent 8dffcc2 commit 9fc3da1
Showing 2 changed files with 92 additions and 43 deletions.
61 changes: 35 additions & 26 deletions src/engine.py
@@ -1,31 +1,40 @@
 import os
-from sglang import Runtime
+import subprocess
+import time
+import requests
 
-MODEL_PATH = os.environ.get("MODEL_PATH", "meta-llama/Llama-2-7b-chat-hf")
 
-class SGLangEngine:
-    def __init__(self):
-        self.runtime = Runtime(model_path=MODEL_PATH)
-        self.tokenizer = self.runtime.get_tokenizer()
+class OpenAICompatibleEngine:
+    def __init__(self, model="meta-llama/Meta-Llama-3-8B-Instruct", host="0.0.0.0", port=30000):
+        self.model = model
+        self.host = host
+        self.port = port
+        self.base_url = f"http://{host}:{port}"  # no /v1 suffix; callers append the route they need
+        self.process = None
 
-    async def generate(self, prompt, sampling_params):
-        messages = [
-            {
-                "role": "system",
-                "content": "You will be given question answer tasks.",
-            },
-            {"role": "user", "content": prompt},
-        ]
-        prompt = self.tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-
-        result = []
-        stream = self.runtime.add_request(prompt, sampling_params)
-        async for output in stream:
-            result.append(output)
-
-        return "".join(result)
+    def start_server(self):
+        # Launch the SGLang OpenAI-compatible server as a subprocess.
+        command = [
+            "python3", "-m", "sglang.launch_server",
+            "--model-path", self.model,
+            "--host", self.host,
+            "--port", str(self.port),
+        ]
+        self.process = subprocess.Popen(command, stdout=None, stderr=None)
+        print(f"Server started with PID: {self.process.pid}")
+
+    def wait_for_server(self, timeout=300, interval=5):
+        # Poll the models endpoint until the server answers or the timeout expires.
+        start_time = time.time()
+        while time.time() - start_time < timeout:
+            try:
+                response = requests.get(f"{self.base_url}/v1/models")
+                if response.status_code == 200:
+                    print("Server is ready!")
+                    return True
+            except requests.RequestException:
+                pass
+            time.sleep(interval)
+        raise TimeoutError("Server failed to start within the timeout period.")
 
     def shutdown(self):
-        self.runtime.shutdown()
+        # Terminate the launched server subprocess, if any.
+        if self.process:
+            self.process.terminate()
+            self.process.wait()
+            print("Server shut down.")
74 changes: 57 additions & 17 deletions src/handler.py
@@ -1,26 +1,66 @@
+import subprocess
+import time
+import requests
+import openai
 import runpod
-import asyncio
 import logging
-from engine import SGLangEngine
+import json
+from engine import OpenAICompatibleEngine
+
+# Initialize the engine and block until the SGLang server is reachable
+engine = OpenAICompatibleEngine()
+engine.start_server()
+engine.wait_for_server()
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Initialize the SGLangEngine globally
-engine = SGLangEngine()
-logger.info("--- SGLang Engine ready ---")
+# Initialize the OpenAI client against the local SGLang server
+client = openai.Client(base_url=f"{engine.base_url}/v1", api_key="EMPTY")
 
-async def handler(job):
-    job_input = job["input"]
-    prompt = job_input.get("prompt", "")
-    sampling_params = job_input.get("sampling_params", {"max_new_tokens": 128})
-
-    response = await engine.generate(prompt, sampling_params)
-
-    return {"generated_text": response}
+def handler(job):
+    try:
+        job_input = job["input"]
+        openai_route = job_input.get("openai_route")
+
+        if openai_route:
+            # Handle OpenAI-compatible routes
+            openai_input = job_input.get("openai_input", {})
+            if openai_route == "/v1/chat/completions":
+                response = client.chat.completions.create(
+                    model="default",
+                    messages=openai_input.get("messages", []),
+                    max_tokens=openai_input.get("max_tokens", 100),
+                    temperature=openai_input.get("temperature", 0.7),
+                )
+            elif openai_route == "/v1/completions":
+                response = client.completions.create(
+                    model="default",
+                    prompt=openai_input.get("prompt", ""),
+                    max_tokens=openai_input.get("max_tokens", 100),
+                    temperature=openai_input.get("temperature", 0.7),
+                )
+            elif openai_route == "/v1/models":
+                response = client.models.list()
+            else:
+                return {"error": f"Unsupported openai_route: {openai_route}"}
+
+            return response.model_dump()
+        else:
+            # Fall back to SGLang's native /generate endpoint
+            generate_url = f"{engine.base_url}/generate"
+            headers = {"Content-Type": "application/json"}
+            generate_data = {
+                "text": job_input.get("prompt", ""),
+                "sampling_params": job_input.get("sampling_params", {})
+            }
+            response = requests.post(generate_url, json=generate_data, headers=headers)
+
+            if response.status_code == 200:
+                return response.json()
+            return {"error": f"Generate request failed with status code {response.status_code}", "details": response.text}
+
+    except Exception as e:
+        logger.error(f"Error in handler: {str(e)}")
+        return {"error": str(e)}
+
+# Ensure the server is shut down when the serverless worker exits.
+# Registered before the blocking start() call so it is in place for the
+# lifetime of the worker.
+import atexit
+atexit.register(engine.shutdown)
 
 runpod.serverless.start({"handler": handler})
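
For reference, the new handler accepts two job-input shapes. Both are sketched below as Python literals with illustrative values; the field names come directly from the handler code above.

# OpenAI-compatible route: proxied through the openai client
openai_job = {"input": {
    "openai_route": "/v1/chat/completions",
    "openai_input": {
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 64,
        "temperature": 0.7,
    },
}}

# Default route: forwarded to SGLang's native /generate endpoint
generate_job = {"input": {
    "prompt": "Hello!",
    "sampling_params": {"max_new_tokens": 64},
}}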
