generated from runpod-workers/worker-template
Commit 9fc3da1 (1 parent: 8dffcc2)
Signed-off-by: pandyamarut <[email protected]>
Showing 2 changed files with 92 additions and 43 deletions.

The commit replaces the in-process SGLang Runtime wrapper (SGLangEngine) with an OpenAICompatibleEngine that launches the SGLang server as a subprocess and exposes its OpenAI-compatible API, and it rewrites the handler to route jobs either through an OpenAI client or to the server's native /generate endpoint.
engine.py
@@ -1,31 +1,40 @@
-import os
-from sglang import Runtime
+import subprocess
+import time
+import requests

-MODEL_PATH = os.environ.get("MODEL_PATH", "meta-llama/Llama-2-7b-chat-hf")

-class SGLangEngine:
-    def __init__(self):
-        self.runtime = Runtime(model_path=MODEL_PATH)
-        self.tokenizer = self.runtime.get_tokenizer()
+class OpenAICompatibleEngine:
+    def __init__(self, model="meta-llama/Meta-Llama-3-8B-Instruct", host="0.0.0.0", port=30000):
+        self.model = model
+        self.host = host
+        self.port = port
+        self.base_url = f"http://{host}:{port}"
+        self.process = None

-    async def generate(self, prompt, sampling_params):
-        messages = [
-            {
-                "role": "system",
-                "content": "You will be given question answer tasks.",
-            },
-            {"role": "user", "content": prompt},
-        ]
-        prompt = self.tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+    def start_server(self):
+        # Launch the SGLang OpenAI-compatible server as a subprocess.
+        command = [
+            "python3", "-m", "sglang.launch_server",
+            "--model", self.model,
+            "--host", self.host,
+            "--port", str(self.port)
+        ]
+        # stdout/stderr are inherited so server logs show up in the worker logs.
+        self.process = subprocess.Popen(command, stdout=None, stderr=None)
+        print(f"Server started with PID: {self.process.pid}")

-        result = []
-        stream = self.runtime.add_request(prompt, sampling_params)
-        async for output in stream:
-            result.append(output)
-
-        return "".join(result)
+    def wait_for_server(self, timeout=300, interval=5):
+        # Poll the OpenAI-compatible models endpoint until the server
+        # responds or the timeout expires.
+        start_time = time.time()
+        while time.time() - start_time < timeout:
+            try:
+                response = requests.get(f"{self.base_url}/v1/models")
+                if response.status_code == 200:
+                    print("Server is ready!")
+                    return True
+            except requests.RequestException:
+                pass
+            time.sleep(interval)
+        raise TimeoutError("Server failed to start within the timeout period.")

     def shutdown(self):
-        self.runtime.shutdown()
+        # Terminate the server subprocess and wait for it to exit.
+        if self.process:
+            self.process.terminate()
+            self.process.wait()
+            print("Server shut down.")
handler.py
@@ -1,26 +1,66 @@
+import requests
+import openai
 import runpod
-import asyncio
 import logging
-from engine import SGLangEngine
+from engine import OpenAICompatibleEngine

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-# Initialize the SGLangEngine globally
-engine = SGLangEngine()
-logger.info("--- SGLang Engine ready ---")
+# Start the SGLang server and block until it is ready.
+engine = OpenAICompatibleEngine()
+engine.start_server()
+engine.wait_for_server()

-async def handler(job):
+# OpenAI client pointed at the local SGLang server.
+client = openai.Client(base_url=f"{engine.base_url}/v1", api_key="EMPTY")
+
+def handler(job):
     try:
         job_input = job["input"]
-        prompt = job_input.get("prompt", "")
-        sampling_params = job_input.get("sampling_params", {"max_new_tokens": 128})
+        openai_route = job_input.get("openai_route")

-        response = await engine.generate(prompt, sampling_params)
-
-        return {"generated_text": response}
+        if openai_route:
+            # Proxy OpenAI-compatible routes through the OpenAI client.
+            openai_input = job_input.get("openai_input", {})
+            if openai_route == "/v1/chat/completions":
+                response = client.chat.completions.create(
+                    model="default",
+                    messages=openai_input.get("messages", []),
+                    max_tokens=openai_input.get("max_tokens", 100),
+                    temperature=openai_input.get("temperature", 0.7),
+                )
+            elif openai_route == "/v1/completions":
+                response = client.completions.create(
+                    model="default",
+                    prompt=openai_input.get("prompt", ""),
+                    max_tokens=openai_input.get("max_tokens", 100),
+                    temperature=openai_input.get("temperature", 0.7),
+                )
+            elif openai_route == "/v1/models":
+                response = client.models.list()
+            else:
+                return {"error": f"Unsupported openai_route: {openai_route}"}
+
+            return response.model_dump()
+        else:
+            # Fall back to SGLang's native /generate endpoint.
+            generate_url = f"{engine.base_url}/generate"
+            headers = {"Content-Type": "application/json"}
+            generate_data = {
+                "text": job_input.get("prompt", ""),
+                "sampling_params": job_input.get("sampling_params", {})
+            }
+            response = requests.post(generate_url, json=generate_data, headers=headers)
+
+            if response.status_code == 200:
+                return response.json()
+            else:
+                return {"error": f"Generate request failed with status code {response.status_code}", "details": response.text}
+
     except Exception as e:
         logger.error(f"Error in handler: {str(e)}")
         return {"error": str(e)}

 runpod.serverless.start({"handler": handler})
+
+# Ensure the server is shut down when the worker is terminated.
+import atexit
+atexit.register(engine.shutdown)
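For reference, a sketch of the two job payload shapes the rewritten handler accepts. The field names follow the handler code above; the values themselves are illustrative:

# Illustrative job payloads for the handler above.

# OpenAI-compatible route:
openai_job = {
    "input": {
        "openai_route": "/v1/chat/completions",
        "openai_input": {
            "messages": [{"role": "user", "content": "Hello!"}],
            "max_tokens": 64,
            "temperature": 0.7,
        },
    }
}

# Native SGLang /generate route (taken when "openai_route" is absent):
native_job = {
    "input": {
        "prompt": "Hello!",
        "sampling_params": {"max_new_tokens": 64},
    }
}

# e.g. handler(openai_job) returns the chat completion as a dict via model_dump().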