Commit

vLLM as inference server
gongy committed Dec 13, 2023
1 parent e1172b1 commit abbef63
Showing 2 changed files with 114 additions and 86 deletions.
4 changes: 2 additions & 2 deletions 06_gpu_and_ml/tgi_mixtral.py
@@ -23,7 +23,7 @@
#
# Any model supported by TGI can be chosen here.

-GPU_CONFIG = gpu.A100(memory=80, count=2)
+GPU_CONFIG = gpu.A100(memory=40, count=4)
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Add `["--quantize", "gptq"]` for TheBloke GPTQ models.
LAUNCH_FLAGS = [
@@ -190,7 +190,7 @@ async def stats():
        return {
            "backlog": stats.backlog,
            "num_total_runners": stats.num_total_runners,
-            "model": MODEL_ID,
+            "model": MODEL_ID + " (TGI)",
        }

    @web_app.get("/completion/{question}")
196 changes: 112 additions & 84 deletions 06_gpu_and_ml/vllm_mixtral.py
@@ -12,10 +12,6 @@
# The larger the batch of prompts, the higher the throughput (up to about 300 tokens/second).
# For example, with the 60 prompts below, we can produce 30k tokens in 100 seconds.
#
-# To run
-# [any of the other supported models](https://vllm.readthedocs.io/en/latest/models/supported_models.html),
-# simply replace the model name in the download step. You may also need to enable `trust_remote_code` for MPT models (see comment below).
-#
# ## Setup
#
# First we import the components we need from `modal`.
@@ -43,6 +39,7 @@
# Tip: avoid using global variables in this function. Changes to code outside this function will not be detected and the download step will not re-run.
def download_model_to_folder():
    from huggingface_hub import snapshot_download
+    from transformers.utils import move_cache

    os.makedirs(MODEL_DIR, exist_ok=True)

@@ -51,6 +48,7 @@ def download_model_to_folder():
        local_dir=MODEL_DIR,
        ignore_patterns="*.safetensors",  # vLLM doesn't support Mixtral safetensors anyway.
    )
+    move_cache()


# ### Image definition
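The body of this section is collapsed in the diff view. For orientation only, a representative Modal image definition in the style of this example might look like the sketch below; the base image, package pins, timeout, and app name are assumptions, not the file's actual contents.

```python
from modal import Image, Stub

# Hypothetical sketch of the collapsed image definition; names and pins are assumptions.
image = (
    Image.debian_slim()
    .pip_install("vllm", "huggingface_hub", "hf-transfer")
    # Bake the weights into the image at build time so cold starts skip the download.
    .run_function(download_model_to_folder, timeout=60 * 20)
)

stub = Stub("example-vllm-mixtral", image=image)
```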
@@ -86,12 +84,16 @@
# The `vLLM` library allows the code to remain quite clean. There are, however, some
# outstanding issues and performance improvements that we patch here, such as multi-GPU setup and
# suboptimal Ray CPU pinning. We also let each container serve up to 10 requests concurrently
# (`allow_concurrent_inputs=10`), which pairs well with vLLM's asynchronous engine.
-@stub.cls(gpu=GPU_CONFIG, timeout=60 * 10, container_idle_timeout=60 * 10)
+@stub.cls(
+    gpu=GPU_CONFIG,
+    timeout=60 * 10,
+    container_idle_timeout=60 * 10,
+    allow_concurrent_inputs=10,
+)
class Model:
    def __enter__(self):
-        import subprocess
-
-        from vllm import LLM
+        from vllm.engine.arg_utils import AsyncEngineArgs
+        from vllm.engine.async_llm_engine import AsyncLLMEngine

        if GPU_CONFIG.count > 1:
            # Patch issue from https://github.com/vllm-project/vllm/issues/1116
@@ -100,109 +102,135 @@ def __enter__(self):
            ray.shutdown()
            ray.init(num_gpus=GPU_CONFIG.count)

-        # Load the model. Tip: MPT models may require `trust_remote_code=true`.
-        self.llm = LLM(MODEL_DIR, tensor_parallel_size=GPU_CONFIG.count)
+        engine_args = AsyncEngineArgs(
+            model=MODEL_DIR,
+            tensor_parallel_size=GPU_CONFIG.count,
+            gpu_memory_utilization=0.90,
+        )
+
+        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
+        self.template = "<s> [INST] {user} [/INST] "
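The template wraps each question in Mixtral-8x7B-Instruct's expected instruction format: `self.template.format(user="Hi")` yields `<s> [INST] Hi [/INST] `, the prompt shape the instruct-tuned model was trained on.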

-        # Performance improvement from https://github.com/vllm-project/vllm/issues/2073#issuecomment-1853422529
+        if GPU_CONFIG.count > 1:
+            # Performance improvement from https://github.com/vllm-project/vllm/issues/2073#issuecomment-1853422529
+            import subprocess
+
+            RAY_CORE_PIN_OVERRIDE = "cpuid=0 ; for pid in $(ps xo '%p %c' | grep ray:: | awk '{print $1;}') ; do taskset -cp $cpuid $pid ; cpuid=$(($cpuid + 1)) ; done"
+            subprocess.call(RAY_CORE_PIN_OVERRIDE, shell=True)
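The one-liner above lists every `ray::` worker process via `ps`, then pins each one to a distinct CPU core with `taskset -cp`, incrementing the core id as it goes; this works around the suboptimal default CPU pinning mentioned earlier.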

    @method()
-    def generate(self, user_questions):
+    async def completion_stream(self, user_question):
        from vllm import SamplingParams

-        prompts = [self.template.format(user=q) for q in user_questions]
+        from vllm.utils import random_uuid

        sampling_params = SamplingParams(
            temperature=0.75,
            top_p=1,
-            max_tokens=800,
-            presence_penalty=1.15,
+            max_tokens=1024,
+            repetition_penalty=1.1,
        )

        t0 = time.time()
-        result = self.llm.generate(prompts, sampling_params)
-        num_tokens = 0
-        for output in result:
-            num_tokens += len(output.outputs[0].token_ids)
-            print(output.prompt, output.outputs[0].text, "\n\n", sep="")
+        request_id = random_uuid()
+        result_generator = self.engine.generate(
+            self.template.format(user=user_question),
+            sampling_params,
+            request_id,
+        )
+        index = 0
+        async for output in result_generator:
+            if "\ufffd" == output.outputs[0].text[-1]:
+                continue
+            text_delta = output.outputs[0].text[index:]
+            index = len(output.outputs[0].text)
+
+            yield text_delta

-        print(f"Generated {num_tokens} tokens in {time.time() - t0:.2f}s")
+        print(f"Generated {index} tokens in {time.time() - t0:.2f}s")
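The `\ufffd` (Unicode replacement character) check skips any output whose last character is still a partially decoded multi-byte token; the loop emits that text on a later iteration, once the character is complete.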


# ## Run the model
# We define a [`local_entrypoint`](/docs/guide/apps#entrypoints-for-ephemeral-apps) to call our remote function
-# sequentially for a list of inputs. You can run this locally with `modal run vllm_mixtral.py`.
+# sequentially for a list of inputs. You can run this locally with `modal run -q vllm_mixtral.py`. The `-q` flag
+# lets the generated text stream in your local terminal.
@stub.local_entrypoint()
def main():
model = Model()
questions = [
# Coding questions
"Implement a Python function to compute the Fibonacci numbers.",
"Write a Rust function that performs binary exponentiation.",
"How do I allocate memory in C?",
"What are the differences between Javascript and Python?",
"How do I find invalid indices in Postgres?",
"How can you implement a LRU (Least Recently Used) cache in Python?",
"What approach would you use to detect and prevent race conditions in a multithreaded application?",
"Can you explain how a decision tree algorithm works in machine learning?",
"How would you design a simple key-value store database from scratch?",
"How do you handle deadlock situations in concurrent programming?",
"What is the logic behind the A* search algorithm, and where is it used?",
"How can you design an efficient autocomplete system?",
"What approach would you take to design a secure session management system in a web application?",
"How would you handle collision in a hash table?",
"How can you implement a load balancer for a distributed system?",
# Literature
"What is the fable involving a fox and grapes?",
"Write a story in the style of James Joyce about a trip to the Australian outback in 2083, to see robots in the beautiful desert.",
"Who does Harry turn into a balloon?",
"Write a tale about a time-traveling historian who's determined to witness the most significant events in human history.",
"Describe a day in the life of a secret agent who's also a full-time parent.",
"Create a story about a detective who can communicate with animals.",
"What is the most unusual thing about living in a city floating in the clouds?",
"In a world where dreams are shared, what happens when a nightmare invades a peaceful dream?",
"Describe the adventure of a lifetime for a group of friends who found a map leading to a parallel universe.",
"Tell a story about a musician who discovers that their music has magical powers.",
"In a world where people age backwards, describe the life of a 5-year-old man.",
"Create a tale about a painter whose artwork comes to life every night.",
"What happens when a poet's verses start to predict future events?",
"Imagine a world where books can talk. How does a librarian handle them?",
"Tell a story about an astronaut who discovered a planet populated by plants.",
"Describe the journey of a letter traveling through the most sophisticated postal service ever.",
"Write a tale about a chef whose food can evoke memories from the eater's past.",
# History
"What were the major contributing factors to the fall of the Roman Empire?",
"How did the invention of the printing press revolutionize European society?",
"What are the effects of quantitative easing?",
"How did the Greek philosophers influence economic thought in the ancient world?",
"What were the economic and philosophical factors that led to the fall of the Soviet Union?",
"How did decolonization in the 20th century change the geopolitical map?",
"What was the influence of the Khmer Empire on Southeast Asia's history and culture?",
# Thoughtfulness
"Describe the city of the future, considering advances in technology, environmental changes, and societal shifts.",
"In a dystopian future where water is the most valuable commodity, how would society function?",
"If a scientist discovers immortality, how could this impact society, economy, and the environment?",
"What could be the potential implications of contact with an advanced alien civilization?",
# Math
"What is the product of 9 and 8?",
"If a train travels 120 kilometers in 2 hours, what is its average speed?",
"Think through this step by step. If the sequence a_n is defined by a_1 = 3, a_2 = 5, and a_n = a_(n-1) + a_(n-2) for n > 2, find a_6.",
"Think through this step by step. Calculate the sum of an arithmetic series with first term 3, last term 35, and total terms 11.",
"Think through this step by step. What is the area of a triangle with vertices at the points (1,2), (3,-4), and (-2,5)?",
"Think through this step by step. Solve the following system of linear equations: 3x + 2y = 14, 5x - y = 15.",
# Facts
"Who was Emperor Norton I, and what was his significance in San Francisco's history?",
"What is the Voynich manuscript, and why has it perplexed scholars for centuries?",
"What was Project A119 and what were its objectives?",
"What is the 'Dyatlov Pass incident' and why does it remain a mystery?",
"What is the 'Emu War' that took place in Australia in the 1930s?",
"What is the 'Phantom Time Hypothesis' proposed by Heribert Illig?",
"Who was the 'Green Children of Woolpit' as per 12th-century English legend?",
"What are 'zombie stars' in the context of astronomy?",
"Who were the 'Dog-Headed Saint' and the 'Lion-Faced Saint' in medieval Christian traditions?",
"What is the story of the 'Globsters', unidentified organic masses washed up on the shores?",
]
-    model.generate.remote(questions)
+    for question in questions:
+        print("Sending new request:", question)
+        for text in model.completion_stream.remote_gen(question):
+            print(text, end="", flush=True)
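Because `completion_stream` is a generator, it is called with `remote_gen` rather than `remote`, so yielded text chunks stream back to the local terminal as they are produced.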

# ## Deploy and invoke the model
# Once we deploy this model with `modal deploy vllm_mixtral.py`,
# we can invoke inference from other apps, sharing the same pool
# of GPU containers with all other apps we might need.
#
# ```
# $ python
# >>> import modal
# >>> f = modal.Function.lookup("example-vllm-mixtral", "Model.completion_stream")
# >>> for text in f.remote_gen("What is the story about the fox and grapes?"):
# ...     print(text, end="")
# The story about the fox and grapes ...
# ```

+# ## Coupling a frontend web application
+#
+# We can stream inference from a FastAPI backend, also deployed on Modal.
+#
+# You can try our deployment [here](https://modal-labs--vllm-mixtral.modal.run).
+
+from modal import Mount, asgi_app
+from pathlib import Path
+
+frontend_path = Path(__file__).parent / "llm-frontend"
+
+
+@stub.function(
+    mounts=[Mount.from_local_dir(frontend_path, remote_path="/assets")],
+    keep_warm=1,
+    allow_concurrent_inputs=20,
+    timeout=60 * 10,
+)
+@asgi_app(label="vllm-mixtral")
+def app():
+    import json
+
+    import fastapi
+    import fastapi.staticfiles
+    from fastapi.responses import StreamingResponse
+
+    web_app = fastapi.FastAPI()
+
+    @web_app.get("/stats")
+    async def stats():
+        stats = await Model().completion_stream.get_current_stats.aio()
+        return {
+            "backlog": stats.backlog,
+            "num_total_runners": stats.num_total_runners,
+            "model": BASE_MODEL + " (vLLM)",
+        }
+
+    @web_app.get("/completion/{question}")
+    async def completion(question: str):
+        from urllib.parse import unquote
+
+        async def generate():
+            async for text in Model().completion_stream.remote_gen.aio(
+                unquote(question)
+            ):
+                yield f"data: {json.dumps(dict(text=text), ensure_ascii=False)}\n\n"
+
+        return StreamingResponse(generate(), media_type="text/event-stream")
+
+    web_app.mount(
+        "/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True)
+    )
+    return web_app
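Outside the browser frontend, the `/completion/{question}` endpoint can be consumed by any SSE-aware client. Below is a minimal illustrative sketch in Python; the `requests` dependency and the deployment URL are assumptions, not part of the example.

```python
import json
from urllib.parse import quote

import requests  # assumed client-side dependency


def stream_completion(base_url: str, question: str) -> None:
    """Hypothetical client for the /completion/{question} SSE endpoint above."""
    url = f"{base_url}/completion/{quote(question)}"
    with requests.get(url, stream=True) as resp:
        resp.raise_for_status()
        # Each event arrives as a line of the form: data: {"text": "..."}
        for line in resp.iter_lines(decode_unicode=True):
            if line and line.startswith("data: "):
                payload = json.loads(line[len("data: "):])
                print(payload["text"], end="", flush=True)


# Example usage; substitute your own deployment URL.
stream_completion(
    "https://modal-labs--vllm-mixtral.modal.run",
    "What is the fable involving a fox and grapes?",
)
```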
