
Commit

checks
rachelspark committed Oct 16, 2023
1 parent b3be84f commit 2c0fb13
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions 06_gpu_and_ml/vllm_inference.py
@@ -6,8 +6,8 @@
# `vLLM` can also be run as a FastAPI server, which we will explore in a future guide. This example
# walks through setting up an environment that works with `vLLM` for basic inference.
#
# We are running the [Mistral 7B Instruct](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) model here, which is an instruct fine-tuned version of Mistral's 7B model best fit for conversation.
# You can expect 20 second cold starts and well over 100 tokens/second. The larger the batch of prompts, the higher the throughput.
# For example, with the 60 prompts below, we can produce 19k tokens in 15 seconds, which is around 1.25k tokens/second.
#
# To run
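The throughput note above refers to vLLM's offline batched generation, where the whole list of prompts goes through `LLM.generate` in a single call. A minimal standalone sketch of that pattern, with an illustrative model name, prompts, and sampling values (none of them taken from this commit):

from vllm import LLM, SamplingParams

# Load the model once; vLLM batches the prompts and manages the KV cache itself.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1")

prompts = [
    "<s>[INST] What is the fable involving a fox and grapes? [/INST] ",
    "<s>[INST] Name the largest freshwater lake in the world. [/INST] ",
]

# Sampling values are placeholders; the example file configures its own below.
sampling_params = SamplingParams(temperature=0.75, top_p=1, max_tokens=800)

# One call over the full batch is what produces the tokens/second figures quoted above.
for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)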
@@ -58,14 +58,13 @@ def download_model_to_folder():
#
image = (
    Image.from_dockerhub("nvcr.io/nvidia/pytorch:22.12-py3")
-    .pip_install(
-        "torch==2.0.1", index_url="https://download.pytorch.org/whl/cu118"
-    )
+    .pip_install("torch==2.0.1", index_url="https://download.pytorch.org/whl/cu118")
    .apt_install("git")
    # Download latest version of vLLM
    .run_commands(
        "git clone https://github.com/vllm-project/vllm.git",
-        "cd vllm && pip install -e .",)
+        "cd vllm && pip install -e .",
+    )
    # Use the barebones hf-transfer package for maximum download speeds. No progress bar, but expect 700MB/s.
    .pip_install("hf-transfer~=0.1")
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
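The hunk header above points at `download_model_to_folder()`, and the image sets `HF_HUB_ENABLE_HF_TRANSFER=1` so that `huggingface_hub` uses the hf-transfer backend. A sketch of the kind of snapshot download that benefits; the target directory and ignore pattern here are assumptions, not taken from this diff:

import os
from huggingface_hub import snapshot_download

MODEL_DIR = "/model"  # assumed path for this sketch; the real example defines its own

def download_model_to_folder():
    # hf-transfer is only used when this variable is set, as the image above does.
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    os.makedirs(MODEL_DIR, exist_ok=True)
    snapshot_download(
        "mistralai/Mistral-7B-Instruct-v0.1",
        local_dir=MODEL_DIR,
        ignore_patterns=["*.pt"],  # skip duplicate checkpoint formats; an assumption
    )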
@@ -93,15 +92,18 @@ def __enter__(self):

        # Load the model. Tip: MPT models may require `trust_remote_code=True`.
        self.llm = LLM(MODEL_DIR)
-        self.template = """SYSTEM: You are a helpful assistant.
-USER: {}
-ASSISTANT: """
+        self.template = """<s>[INST] <<SYS>>
+{system}
+<</SYS>>
+{user} [/INST] """

    @method()
    def generate(self, user_questions):
        from vllm import SamplingParams

-        prompts = [self.template.format(q) for q in user_questions]
+        prompts = [self.template.format(system="", user=q) for q in user_questions]

        sampling_params = SamplingParams(
            temperature=0.75,
            top_p=1,
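To see what the new template produces, here is the formatted prompt for a single question with an empty system message, mirroring the `format(system="", user=q)` call in `generate` (the question itself is just an example):

template = """<s>[INST] <<SYS>>
{system}
<</SYS>>
{user} [/INST] """

# With system="" the system block is empty, matching the call in generate().
print(template.format(system="", user="What is the fable involving a fox and grapes?"))
# <s>[INST] <<SYS>>
#
# <</SYS>>
# What is the fable involving a fox and grapes? [/INST]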
@@ -189,4 +191,4 @@ def main():
"Who were the 'Dog-Headed Saint' and the 'Lion-Faced Saint' in medieval Christian traditions?",
"What is the story of the 'Globsters', unidentified organic masses washed up on the shores?",
    ]
-    model.generate.call(questions)
+    model.generate.remote(questions)

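The final hunk replaces `.call(...)` with `.remote(...)`, Modal's newer spelling for invoking a function in the cloud from a local entrypoint. A self-contained sketch of that calling pattern, where the stub name, GPU choice, and placeholder `generate` body are all assumptions for illustration rather than code from this file:

import modal

stub = modal.Stub("vllm-inference-sketch")  # app name is an assumption

@stub.cls(gpu="A100")  # GPU choice here is illustrative
class Model:
    @modal.method()
    def generate(self, user_questions):
        # Placeholder body; the real class wraps vLLM as shown in the diff above.
        return [f"echo: {q}" for q in user_questions]

@stub.local_entrypoint()
def main():
    model = Model()
    # .remote() runs generate inside the Modal container; .call() was the older name.
    print(model.generate.remote(["What is the fable involving a fox and grapes?"]))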