Commit 945f2cd: Mixtral TGI example
gongy committed Dec 11, 2023
1 parent fb58381 commit 945f2cd
Showing 1 changed file with 225 additions and 0 deletions: 06_gpu_and_ml/tgi_mixtral.py
@@ -0,0 +1,225 @@
# # Hosting Mixtral 8x7B with Text Generation Inference (TGI)
#
# In this example, we show how to run an optimized inference server using [Text Generation Inference (TGI)](https://github.com/huggingface/text-generation-inference),
# which offers performance advantages over standard text generation pipelines, including:
# - continuous batching, so that multiple generations can take place at the same time on a single container, and
# - PagedAttention, an optimization that increases throughput.
#
# This example deployment, [accessible here](https://modal-labs--tgi-mixtral.modal.run), can serve Mixtral 8x7B on two 80GB A100s, with
# 70-second cold starts, up to 200 tokens/s of throughput, and per-token latency of 55 ms.

# ## Setup
#
# First we import the components we need from `modal`.

import subprocess
from pathlib import Path

from modal import Image, Mount, Stub, asgi_app, gpu, method

# Next, we set which model to serve, taking care to specify the GPU configuration required
# to fit the model into VRAM, and the quantization method (`bitsandbytes` or `gptq`) if desired.
# Note that quantization does degrade token generation performance significantly.
#
# Any model supported by TGI can be chosen here.

GPU_CONFIG = gpu.A100(memory=80, count=2)
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Add `["--quantize", "gptq"]` for TheBloke GPTQ models.
LAUNCH_FLAGS = [
    "--model-id",
    MODEL_ID,
    "--port",
    "8000",
]
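
# For example, a quantized deployment would swap in a GPTQ checkpoint and add the
# `--quantize` flag. This is only a sketch: the model id below is an assumption
# (any TGI-supported GPTQ checkpoint would work) and is not used elsewhere in this example.
#
# ```
# MODEL_ID = "TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ"  # hypothetical GPTQ checkpoint
# LAUNCH_FLAGS = ["--model-id", MODEL_ID, "--port", "8000", "--quantize", "gptq"]
# ```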

# ## Define a container image
#
# We want to create a Modal image with the Hugging Face model cache pre-populated.
# The benefit of this is that the container no longer has to re-download the model from Hugging Face;
# instead, it can take advantage of Modal's internal filesystem for faster cold starts.
# The 95GB model can be loaded in as little as 70 seconds.
#
# ### Download the weights
# We can use the included utilities to download the model weights (and convert to safetensors, if necessary)
# as part of the image build.
#


def download_model():
    # `check=True` ensures the image build fails loudly if the download does not succeed.
    subprocess.run(
        [
            "text-generation-server",
            "download-weights",
            MODEL_ID,
        ],
        check=True,
    )
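
# An alternative sketch (not used here) would be to pull the snapshot with
# `huggingface_hub.snapshot_download`, which the TGI image should then pick up from its
# Hugging Face hub cache. The built-in `download-weights` command above also handles
# safetensors conversion, so we stick with it.
#
# ```
# from huggingface_hub import snapshot_download
#
# def download_model():
#     snapshot_download(repo_id=MODEL_ID)
# ```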


# ### Image definition
# We'll start from the official TGI container image (hosted on the GitHub Container Registry) and override the default `ENTRYPOINT`
# so that Modal can run its own entrypoint, which enables seamless serverless deployments.
#
# Next we run the download step to pre-populate the image with our model weights.
#
# Finally, we install the `text-generation` client to interface with TGI's Rust webserver over `localhost`.

image = (
    Image.from_registry("ghcr.io/huggingface/text-generation-inference:1.3.1")
    .dockerfile_commands("ENTRYPOINT []")
    .run_function(download_model, timeout=60 * 20)
    .pip_install("text-generation")
)

stub = Stub("example-tgi-" + MODEL_ID.split("/")[-1], image=image)


# ## The model class
#
# The inference function is best represented with Modal's [class syntax](/docs/guide/lifecycle-functions).
# The class syntax is a special representation for a Modal function which splits logic into two parts:
# 1. the `__enter__` method, which runs once per container when it starts up, and
# 2. the `@method()` function, which runs per inference request.
#
# This means the model is loaded into GPU memory and the TGI backend is launched only once, when each
# container starts, and this state is reused for each subsequent invocation of the function.
# Note that on start-up, we must wait for the Rust webserver to accept connections before considering the
# container ready.
#
# Here, we also
# - specify how many A100s we need per container
# - specify that each container is allowed to handle up to 10 inputs (i.e. requests) simultaneously
# - keep idle containers for 10 minutes before spinning down
# - raise the timeout of each request.


@stub.cls(
    gpu=GPU_CONFIG,
    allow_concurrent_inputs=10,
    container_idle_timeout=60 * 10,
    timeout=60 * 60,
)
class Model:
    def __enter__(self):
        import socket
        import time

        from text_generation import AsyncClient

        self.launcher = subprocess.Popen(
            ["text-generation-launcher"] + LAUNCH_FLAGS
        )
        self.client = AsyncClient("http://127.0.0.1:8000", timeout=60)
        self.template = "[INST] {user} [/INST]"

        # Poll until the webserver at 127.0.0.1:8000 accepts connections before running inputs.
        webserver_ready = False
        while not webserver_ready:
            try:
                socket.create_connection(("127.0.0.1", 8000), timeout=1).close()
                webserver_ready = True
                print("Webserver ready!")
            except (socket.timeout, ConnectionRefusedError):
                # If the launcher process has exited, a connection can never be made.
                if (retcode := self.launcher.poll()) is not None:
                    raise RuntimeError(f"launcher exited with code {retcode}")
                time.sleep(1.0)

    def __exit__(self, _exc_type, _exc_value, _traceback):
        self.launcher.terminate()

    @method()
    async def generate(self, question: str):
        prompt = self.template.format(user=question)
        result = await self.client.generate(prompt, max_new_tokens=1024)

        return result.generated_text

    @method()
    async def generate_stream(self, question: str):
        prompt = self.template.format(user=question)

        async for response in self.client.generate_stream(
            prompt, max_new_tokens=2048
        ):
            if not response.token.special:
                yield response.token.text


# ## Run the model
# We define a [`local_entrypoint`](/docs/guide/apps#entrypoints-for-ephemeral-apps) to invoke
# our remote function. You can run this script locally with `modal run tgi_mixtral.py`.
@stub.local_entrypoint()
def main():
    print(
        Model().generate.remote(
            "Implement a Python function to compute the Fibonacci numbers."
        )
    )


# ## Serve the model
# Once we deploy this model with `modal deploy tgi_mixtral.py`, we can serve it
# behind an ASGI app front-end. The front-end code (a single file of Alpine.js) is available
# [here](https://github.com/modal-labs/modal-examples/blob/main/06_gpu_and_ml/llm-frontend/index.html).
#
# You can try our deployment [here](https://modal-labs--tgi-mixtral.modal.run).

frontend_path = Path(__file__).parent / "llm-frontend"


@stub.function(
    mounts=[Mount.from_local_dir(frontend_path, remote_path="/assets")],
    keep_warm=1,
    allow_concurrent_inputs=10,
    timeout=60 * 10,
)
@asgi_app(label="tgi-mixtral")
def app():
    import json

    import fastapi
    import fastapi.staticfiles
    from fastapi.responses import StreamingResponse

    web_app = fastapi.FastAPI()

    @web_app.get("/stats")
    async def stats():
        stats = await Model().generate_stream.get_current_stats.aio()
        return {
            "backlog": stats.backlog,
            "num_total_runners": stats.num_total_runners,
            "model": MODEL_ID,
        }

    @web_app.get("/completion/{question}")
    async def completion(question: str):
        from urllib.parse import unquote

        async def generate():
            async for text in Model().generate_stream.remote_gen.aio(
                unquote(question)
            ):
                yield f"data: {json.dumps(dict(text=text), ensure_ascii=False)}\n\n"

        return StreamingResponse(generate(), media_type="text/event-stream")

    web_app.mount(
        "/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True)
    )
    return web_app
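
# Once deployed, the streaming endpoint can also be exercised directly over HTTP.
# A minimal sketch, assuming the deployment above is live and `requests` is installed
# locally (the question string is just an example):
#
# ```
# from urllib.parse import quote
#
# import requests
#
# question = "Tell me a story about a fox and some grapes."
# url = f"https://modal-labs--tgi-mixtral.modal.run/completion/{quote(question)}"
# with requests.get(url, stream=True) as resp:
#     for line in resp.iter_lines():
#         if line.startswith(b"data:"):
#             print(line.decode())  # each line carries a JSON payload like {"text": "..."}
# ```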


# ## Invoke the model from other apps
# Once the model is deployed, we can invoke inference from other apps, sharing the same pool
# of GPU containers with all other apps we might need.
#
# ```
# $ python
# >>> import modal
# >>> f = modal.Function.lookup("example-tgi-Mixtral-8x7B-Instruct-v0.1", "Model.generate")
# >>> f.remote("What is the story about the fox and grapes?")
# 'The story about the fox and grapes ...
# ```
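#
# Streaming works the same way through the generator method. A sketch, assuming the
# generator interface behaves the same for looked-up functions as it does in-app:
#
# ```
# >>> f = modal.Function.lookup("example-tgi-Mixtral-8x7B-Instruct-v0.1", "Model.generate_stream")
# >>> for text in f.remote_gen("Tell me a story about a fox and some grapes."):
# ...     print(text, end="")
# ```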
