TGI Llama3 (#705)
gongy authored Apr 18, 2024
1 parent f22657d commit 46d5fd6
58 changes: 29 additions & 29 deletions 06_gpu_and_ml/llm-serving/text_generation_inference.py
@@ -1,11 +1,11 @@
-# # Hosting any LLaMA 2 model with Text Generation Inference (TGI)
+# # Hosting any LLaMA 3 model with Text Generation Inference (TGI)
#
# In this example, we show how to run an optimized inference server using [Text Generation Inference (TGI)](https://github.com/huggingface/text-generation-inference)
# with performance advantages over standard text generation pipelines including:
# - continuous batching, so multiple generations can take place at the same time on a single container
# - PagedAttention, which applies memory paging to the attention mechanism's key-value cache, increasing throughput
#
-# This example deployment, [accessible here](https://modal-labs--tgi-app.modal.run), can serve LLaMA 2 70B with
+# This example deployment, [accessible here](https://modal-labs--tgi-app.modal.run), can serve LLaMA 3 70B with
# 70 second cold starts, up to 200 tokens/s of throughput, and a per-token latency of 55ms.

# ## Setup
@@ -24,16 +24,16 @@
#
# Any model supported by TGI can be chosen here.

MODEL_ID = "meta-llama/Llama-2-70b-chat-hf"
REVISION = "e1ce257bd76895e0864f3b4d6c7ed3c4cdec93e2"
MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct"
MODEL_REVISION = "81ca4500337d94476bda61d84f0c93af67e4495f"
# Add `["--quantize", "gptq"]` for TheBloke GPTQ models.
LAUNCH_FLAGS = [
"--model-id",
MODEL_ID,
"--port",
"8000",
"--revision",
-REVISION,
+MODEL_REVISION,
]
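As the comment above notes, a quantized GPTQ checkpoint only needs the extra `--quantize` flag. A minimal sketch, not part of this commit; the `GPTQ_MODEL_ID` and `GPTQ_LAUNCH_FLAGS` names and the chosen repository are illustrative:

GPTQ_MODEL_ID = "TheBloke/Llama-2-70B-chat-GPTQ"  # hypothetical: any TheBloke GPTQ repo supported by TGI
GPTQ_LAUNCH_FLAGS = [
    "--model-id",
    GPTQ_MODEL_ID,
    "--port",
    "8000",
    "--quantize",
    "gptq",
]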

# ## Define a container image
@@ -56,13 +56,8 @@ def download_model():
"download-weights",
MODEL_ID,
"--revision",
-REVISION,
+MODEL_REVISION,
],
-env={
-**os.environ,
-"HUGGING_FACE_HUB_TOKEN": os.environ["HF_TOKEN"],
-},
check=True,
)
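Pieced together from the fragments above, the download step after this change reduces to a single subprocess call. A minimal sketch of the resulting function (the exact surrounding code may differ; the explicit env= mapping was dropped because the attached Hugging Face secret already injects HF_TOKEN into the environment):

import subprocess

def download_model():
    # Fetch the model weights at image-build time so containers start with them baked in.
    subprocess.run(
        [
            "text-generation-server",
            "download-weights",
            MODEL_ID,
            "--revision",
            MODEL_REVISION,
        ],
        check=True,  # fail the image build if the download fails
    )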


@@ -73,23 +68,23 @@ def download_model():
# Next we run the download step to pre-populate the image with our model weights.
#
# For this step to work on a [gated model](https://github.com/huggingface/text-generation-inference#using-a-private-or-gated-model)
-# such as LLaMA 2, the `HF_TOKEN` environment variable must be set.
+# such as LLaMA 3, the `HF_TOKEN` environment variable must be set.
#
# After [creating a HuggingFace access token](https://huggingface.co/settings/tokens)
-# and accepting the [LLaMA 2 license](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),
+# and accepting the [LLaMA 3 license](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct),
# head to the [secrets page](https://modal.com/secrets) to share it with Modal.
#
# Finally, we install the `text-generation` client to interface with TGI's Rust webserver over `localhost`.

stub = Stub("example-tgi-" + MODEL_ID.split("/")[-1])

tgi_image = (
-Image.from_registry(
-"ghcr.io/huggingface/text-generation-inference:1.4", add_python="3.10"
-)
+Image.from_registry("ghcr.io/huggingface/text-generation-inference:1.4")
.dockerfile_commands("ENTRYPOINT []")
.run_function(
-download_model, secrets=[Secret.from_name("huggingface-secret")]
+download_model,
+secrets=[Secret.from_name("huggingface-secret")],
+timeout=3600,
)
.pip_install("text-generation")
)
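The reason the env= block could be deleted above: Modal exposes the keys of an attached Secret as plain environment variables inside the function, so the TGI downloader sees HF_TOKEN directly. A small sketch of how this could be checked during the image build (the helper name is hypothetical and not part of the commit):

import os

def assert_hf_token_present():
    # With secrets=[Secret.from_name("huggingface-secret")], HF_TOKEN arrives via
    # the environment; no manual os.environ plumbing is required.
    assert "HF_TOKEN" in os.environ, "huggingface-secret must define HF_TOKEN"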
@@ -115,13 +110,13 @@ def download_model():
# - increase the timeout limit


-GPU_CONFIG = gpu.A100(memory=80, count=2) # 2 A100s for LLaMA 2 70B
+GPU_CONFIG = gpu.H100(count=2) # 2 H100s
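For a smaller checkpoint, fewer GPUs suffice. A hedged sketch of what that swap might look like; the model ID here is illustrative and not part of this commit:

# Hypothetical smaller-model configuration (illustrative only).
# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
GPU_CONFIG = gpu.H100(count=1)  # a single H100 comfortably fits an 8B model in fp16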


@stub.cls(
secrets=[Secret.from_name("huggingface-secret")],
gpu=GPU_CONFIG,
-allow_concurrent_inputs=10,
+allow_concurrent_inputs=15,
container_idle_timeout=60 * 10,
timeout=60 * 60,
image=tgi_image,
@@ -142,11 +137,11 @@ def start_server(self):
},
)
self.client = AsyncClient("http://127.0.0.1:8000", timeout=60)
self.template = """<s>[INST] <<SYS>>
{system}
<</SYS>>
self.template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{user} [/INST] """
"""

# Poll until webserver at 127.0.0.1:8000 accepts connections before running inputs.
def webserver_ready():
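The body of webserver_ready is collapsed in this view. One plausible implementation, a sketch using only the standard library and not necessarily the exact code in the file, simply probes the TGI port:

import socket

def webserver_ready() -> bool:
    # Return True once TGI's webserver accepts TCP connections on port 8000.
    try:
        socket.create_connection(("127.0.0.1", 8000), timeout=1).close()
        return True
    except (socket.timeout, ConnectionRefusedError):
        return False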
@@ -174,19 +169,24 @@ def terminate_server(self):

@method()
async def generate(self, question: str):
-prompt = self.template.format(system="", user=question)
-result = await self.client.generate(prompt, max_new_tokens=1024)
+prompt = self.template.format(user=question)
+result = await self.client.generate(
+prompt, max_new_tokens=1024, stop_sequences=["<|eot_id|>"]
+)

return result.generated_text

@method()
async def generate_stream(self, question: str):
-prompt = self.template.format(system="", user=question)
+prompt = self.template.format(user=question)

async for response in self.client.generate_stream(
-prompt, max_new_tokens=1024
+prompt, max_new_tokens=1024, stop_sequences=["<|eot_id|>"]
):
-if not response.token.special:
+if (
+not response.token.special
+and response.token.text != "<|eot_id|>"
+):
yield response.token.text
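To make the new Llama 3 chat format concrete, here is roughly what the template defined in start_server expands to for one question (a sketch; the newlines come from the triple-quoted template literal, and generation then stops at the first <|eot_id|> the model emits):

question = "What is the story about the fox and grapes?"
prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    + question
    + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)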


@@ -262,7 +262,7 @@ async def generate():
# ```
# $ python
# >>> import modal
# >>> f = modal.Function.lookup("example-tgi-Llama-2-70b-chat-hf", "Model.generate")
# >>> f = modal.Function.lookup("example-tgi-Meta-Llama-3-70B-Instruct", "Model.generate")
# >>> f.remote("What is the story about the fox and grapes?")
# 'The story about the fox and grapes ...
# ```
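Streaming works the same way from a lookup handle. A hedged sketch, assuming the installed modal version exposes Function.remote_gen for generator functions:

import modal

f = modal.Function.lookup(
    "example-tgi-Meta-Llama-3-70B-Instruct", "Model.generate_stream"
)
for chunk in f.remote_gen("What is the story about the fox and grapes?"):
    print(chunk, end="", flush=True)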
