vllm mixtral (#519)

gongy authored Dec 13, 2023
1 parent e2bcf5d commit 0fce7f2

Showing 2 changed files with 240 additions and 2 deletions.
4 changes: 2 additions & 2 deletions 06_gpu_and_ml/tgi_mixtral.py
@@ -23,7 +23,7 @@
#
# Any model supported by TGI can be chosen here.

-GPU_CONFIG = gpu.A100(memory=80, count=2)
+GPU_CONFIG = gpu.A100(memory=40, count=4)
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Add `["--quantize", "gptq"]` for TheBloke GPTQ models.
LAUNCH_FLAGS = [
@@ -190,7 +190,7 @@ async def stats():
        return {
            "backlog": stats.backlog,
            "num_total_runners": stats.num_total_runners,
-            "model": MODEL_ID,
+            "model": MODEL_ID + " (TGI)",
        }

    @web_app.get("/completion/{question}")
238 changes: 238 additions & 0 deletions 06_gpu_and_ml/vllm_mixtral.py
@@ -0,0 +1,238 @@
# # Fast inference with vLLM (Mixtral 8x7B)
#
# In this example, we show how to run basic inference, using [`vLLM`](https://github.com/vllm-project/vllm)
# to take advantage of PagedAttention, which speeds up sequential inferences with optimized key-value caching.
#
# `vLLM` can also be used as a FastAPI server, which we will explore in a future guide. This example
# walks through setting up an environment that works with `vLLM` for basic inference.
#
# We are running the [Mixtral 8x7B Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) model here, which is a mixture-of-experts model finetuned for conversation.
# You can expect cold starts of about 3 minutes.
# For a single request, the throughput is about 11 tokens/second, but there are upcoming `vLLM` optimizations to improve this.
# The larger the batch of prompts, the higher the throughput (up to about 300 tokens/second).
# For example, with a batch of 60 prompts we can produce 30k tokens in 100 seconds, which works out to roughly 300 tokens/second.
#
# ## Setup
#
# First we import the components we need from `modal`.

import os
import time

from modal import Image, Stub, gpu, method

MODEL_DIR = "/model"
BASE_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
GPU_CONFIG = gpu.A100(memory=80, count=2)


# ## Define a container image
#
# We want to create a Modal image which has the model weights pre-saved to a directory. The benefit of this
# is that the container no longer has to re-download the model from Huggingface - instead, it will take
# advantage of Modal's internal filesystem for faster cold starts.
#
# ### Download the weights
#
# We can download the model to a particular directory using the HuggingFace utility function `snapshot_download`.
#
# Tip: avoid using global variables in this function. Changes to code outside this function will not be detected and the download step will not re-run.
def download_model_to_folder():
    from huggingface_hub import snapshot_download
    from transformers.utils import move_cache

    os.makedirs(MODEL_DIR, exist_ok=True)

    snapshot_download(
        BASE_MODEL,
        local_dir=MODEL_DIR,
        ignore_patterns="*.safetensors",  # vLLM doesn't support Mixtral safetensors anyway.
    )
    move_cache()


# ### Image definition
# We'll start from a Docker Hub image recommended by `vLLM`, and use
# `run_function` to run the function defined above, ensuring the model
# weights are saved within the container image.

VLLM_HASH = "89523c8293bc02a4dfaaa80079a5347dc3952464a33a501d5de329921eea7ec7"

image = (
    Image.from_registry(
        f"vllm/vllm-openai@sha256:{VLLM_HASH}",
        setup_dockerfile_commands=[
            "RUN apt-get install python-is-python3",
            "RUN mv /workspace/* /root",
        ],
    )
    .dockerfile_commands("ENTRYPOINT []")
    .pip_install("huggingface_hub==0.19.4", "hf-transfer==0.1.4")
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_function(download_model_to_folder, timeout=60 * 20)
)

stub = Stub("example-vllm-inference", image=image)


# ## The model class
#
# The inference function is best represented with Modal's [class syntax](/docs/guide/lifecycle-functions) and the `__enter__` method.
# This enables us to load the model into memory just once per container start-up, and keep it cached
# on the GPU for each subsequent invocation of the function.
#
# The `vLLM` library allows the code to remain quite clean. There are, however, some
# outstanding issues and performance improvements that we patch here, such as multi-GPU setup and
# suboptimal Ray CPU pinning.
@stub.cls(
    gpu=GPU_CONFIG,
    timeout=60 * 10,
    container_idle_timeout=60 * 10,
    allow_concurrent_inputs=10,
)
class Model:
    def __enter__(self):
        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine

        if GPU_CONFIG.count > 1:
            # Patch issue from https://github.com/vllm-project/vllm/issues/1116
            import ray

            ray.shutdown()
            ray.init(num_gpus=GPU_CONFIG.count)

        engine_args = AsyncEngineArgs(
            model=MODEL_DIR,
            tensor_parallel_size=GPU_CONFIG.count,
            gpu_memory_utilization=0.90,
        )

        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        self.template = "<s> [INST] {user} [/INST] "

        # Performance improvement from https://github.com/vllm-project/vllm/issues/2073#issuecomment-1853422529
        if GPU_CONFIG.count > 1:
            import subprocess

            RAY_CORE_PIN_OVERRIDE = "cpuid=0 ; for pid in $(ps xo '%p %c' | grep ray:: | awk '{print $1;}') ; do taskset -cp $cpuid $pid ; cpuid=$(($cpuid + 1)) ; done"
            subprocess.call(RAY_CORE_PIN_OVERRIDE, shell=True)

    @method()
    async def completion_stream(self, user_question):
        from vllm import SamplingParams
        from vllm.utils import random_uuid

        sampling_params = SamplingParams(
            temperature=0.75,
            max_tokens=1024,
            repetition_penalty=1.1,
        )

        t0 = time.time()
        request_id = random_uuid()
        result_generator = self.engine.generate(
            self.template.format(user=user_question),
            sampling_params,
            request_id,
        )
        index = 0
        async for output in result_generator:
            if "\ufffd" == output.outputs[0].text[-1]:
                continue
            text_delta = output.outputs[0].text[index:]
            index = len(output.outputs[0].text)

            yield text_delta

        print(f"Generated {index} tokens in {time.time() - t0:.2f}s")


# ## Run the model
# We define a [`local_entrypoint`](/docs/guide/apps#entrypoints-for-ephemeral-apps) to call our remote function
# sequentially for a list of inputs. You can run this locally with `modal run -q vllm_mixtral.py`. The `-q` flag
# enables the text to stream in your local terminal.
@stub.local_entrypoint()
def main():
    model = Model()
    questions = [
        "Implement a Python function to compute the Fibonacci numbers.",
        "What is the fable involving a fox and grapes?",
        "What were the major contributing factors to the fall of the Roman Empire?",
        "Describe the city of the future, considering advances in technology, environmental changes, and societal shifts.",
        "What is the product of 9 and 8?",
        "Who was Emperor Norton I, and what was his significance in San Francisco's history?",
    ]
    for question in questions:
        print("Sending new request:", question)
        for text in model.completion_stream.remote_gen(question):
            print(text, end="", flush=True)


# ## Deploy and invoke the model
# Once we deploy this model with `modal deploy vllm_mixtral.py`,
# we can invoke inference from other apps, sharing the same pool
# of GPU containers with all other apps we might need.
#
# ```
# $ python
# >>> import modal
# >>> f = modal.Function.lookup("example-vllm-inference", "Model.completion_stream")
# >>> for text in f.remote_gen("What is the story about the fox and grapes?"):
# ...     print(text, end="", flush=True)
# The story about the fox and grapes ...
# ```
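#
# As a minimal sketch (assuming the app above is deployed as `example-vllm-inference`),
# another Modal app can stream from the same pool of GPU containers; the caller stub
# name and function below are hypothetical:
#
# ```
# import modal
#
# stub = modal.Stub("mixtral-caller")  # hypothetical caller app
#
#
# @stub.function()
# def ask(question: str) -> str:
#     # Look up the deployed generator method and join its streamed chunks.
#     completion = modal.Function.lookup(
#         "example-vllm-inference", "Model.completion_stream"
#     )
#     return "".join(completion.remote_gen(question))
# ```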

# ## Coupling a frontend web application
#
# We can stream inference from a FastAPI backend, also deployed on Modal.
#
# You can try our deployment [here](https://modal-labs--vllm-mixtral.modal.run).

from pathlib import Path

from modal import Mount, asgi_app

frontend_path = Path(__file__).parent / "llm-frontend"


@stub.function(
    mounts=[Mount.from_local_dir(frontend_path, remote_path="/assets")],
    keep_warm=1,
    allow_concurrent_inputs=20,
    timeout=60 * 10,
)
@asgi_app(label="vllm-mixtral")
def app():
    import json

    import fastapi
    import fastapi.staticfiles
    from fastapi.responses import StreamingResponse

    web_app = fastapi.FastAPI()

    @web_app.get("/stats")
    async def stats():
        stats = await Model().completion_stream.get_current_stats.aio()
        return {
            "backlog": stats.backlog,
            "num_total_runners": stats.num_total_runners,
            "model": BASE_MODEL + " (vLLM)",
        }

    @web_app.get("/completion/{question}")
    async def completion(question: str):
        from urllib.parse import unquote

        async def generate():
            async for text in Model().completion_stream.remote_gen.aio(
                unquote(question)
            ):
                yield f"data: {json.dumps(dict(text=text), ensure_ascii=False)}\n\n"

        return StreamingResponse(generate(), media_type="text/event-stream")

    web_app.mount(
        "/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True)
    )
    return web_app
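
As a further illustration, here is a minimal client-side sketch for consuming the deployed endpoint's server-sent events stream; the base URL is an assumption based on the deployment linked above (substitute your own workspace's URL), and `requests` is assumed to be installed on the client:

```
import json
from urllib.parse import quote

import requests  # assumed client-side dependency

# Assumed base URL; substitute <your-workspace>--vllm-mixtral.modal.run.
BASE_URL = "https://modal-labs--vllm-mixtral.modal.run"


def stream_completion(question: str):
    """Yield text chunks from the /completion/{question} SSE stream."""
    url = f"{BASE_URL}/completion/{quote(question)}"
    with requests.get(url, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if line and line.startswith("data: "):
                yield json.loads(line[len("data: "):])["text"]


if __name__ == "__main__":
    for chunk in stream_completion("What is the fable involving a fox and grapes?"):
        print(chunk, end="", flush=True)
    print()
```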
