diff --git a/06_gpu_and_ml/tgi_mixtral.py b/06_gpu_and_ml/tgi_mixtral.py index 63efb02e6..9fda681dd 100644 --- a/06_gpu_and_ml/tgi_mixtral.py +++ b/06_gpu_and_ml/tgi_mixtral.py @@ -23,7 +23,7 @@ # # Any model supported by TGI can be chosen here. -GPU_CONFIG = gpu.A100(memory=80, count=2) +GPU_CONFIG = gpu.A100(memory=40, count=4) MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1" # Add `["--quantize", "gptq"]` for TheBloke GPTQ models. LAUNCH_FLAGS = [ @@ -190,7 +190,7 @@ async def stats(): return { "backlog": stats.backlog, "num_total_runners": stats.num_total_runners, - "model": MODEL_ID, + "model": MODEL_ID + " (TGI)", } @web_app.get("/completion/{question}") diff --git a/06_gpu_and_ml/vllm_mixtral.py b/06_gpu_and_ml/vllm_mixtral.py index f92a5c366..91fce6abd 100644 --- a/06_gpu_and_ml/vllm_mixtral.py +++ b/06_gpu_and_ml/vllm_mixtral.py @@ -12,10 +12,6 @@ # The larger the batch of prompts, the higher the throughput (up to about 300 tokens/second). # For example, with the 60 prompts below, we can produce 30k tokens in 100 seconds. # -# To run -# [any of the other supported models](https://vllm.readthedocs.io/en/latest/models/supported_models.html), -# simply replace the model name in the download step. You may also need to enable `trust_remote_code` for MPT models (see comment below).. -# # ## Setup # # First we import the components we need from `modal`. @@ -43,6 +39,7 @@ # Tip: avoid using global variables in this function. Changes to code outside this function will not be detected and the download step will not re-run. def download_model_to_folder(): from huggingface_hub import snapshot_download + from transformers.utils import move_cache os.makedirs(MODEL_DIR, exist_ok=True) @@ -51,6 +48,7 @@ def download_model_to_folder(): local_dir=MODEL_DIR, ignore_patterns="*.safetensors", # vLLM doesn't support Mixtral safetensors anyway. ) + move_cache() # ### Image definition @@ -86,12 +84,16 @@ def download_model_to_folder(): # The `vLLM` library allows the code to remain quite clean. There are, however, some # outstanding issues and performance improvements that we patch here, such as multi-GPU setup and # suboptimal Ray CPU pinning. -@stub.cls(gpu=GPU_CONFIG, timeout=60 * 10, container_idle_timeout=60 * 10) +@stub.cls( + gpu=GPU_CONFIG, + timeout=60 * 10, + container_idle_timeout=60 * 10, + allow_concurrent_inputs=10 +) class Model: def __enter__(self): - import subprocess - - from vllm import LLM + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.engine.async_llm_engine import AsyncLLMEngine if GPU_CONFIG.count > 1: # Patch issue from https://github.com/vllm-project/vllm/issues/1116 @@ -100,109 +102,135 @@ def __enter__(self): ray.shutdown() ray.init(num_gpus=GPU_CONFIG.count) - # Load the model. Tip: MPT models may require `trust_remote_code=true`. 
-        self.llm = LLM(MODEL_DIR, tensor_parallel_size=GPU_CONFIG.count)
+        engine_args = AsyncEngineArgs(
+            model=MODEL_DIR,
+            tensor_parallel_size=GPU_CONFIG.count,
+            gpu_memory_utilization=0.90,
+        )
+
+        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
         self.template = " [INST] {user} [/INST] "

+        # Performance improvement from https://github.com/vllm-project/vllm/issues/2073#issuecomment-1853422529
         if GPU_CONFIG.count > 1:
-            # Performance improvement from https://github.com/vllm-project/vllm/issues/2073#issuecomment-1853422529
+            import subprocess
+
             RAY_CORE_PIN_OVERRIDE = "cpuid=0 ; for pid in $(ps xo '%p %c' | grep ray:: | awk '{print $1;}') ; do taskset -cp $cpuid $pid ; cpuid=$(($cpuid + 1)) ; done"
             subprocess.call(RAY_CORE_PIN_OVERRIDE, shell=True)

     @method()
-    def generate(self, user_questions):
+    async def completion_stream(self, user_question):
         from vllm import SamplingParams
-
-        prompts = [self.template.format(user=q) for q in user_questions]
+        from vllm.utils import random_uuid

         sampling_params = SamplingParams(
             temperature=0.75,
-            top_p=1,
-            max_tokens=800,
-            presence_penalty=1.15,
+            max_tokens=1024,
+            repetition_penalty=1.1,
         )

         t0 = time.time()
-        result = self.llm.generate(prompts, sampling_params)
-        num_tokens = 0
-        for output in result:
-            num_tokens += len(output.outputs[0].token_ids)
-            print(output.prompt, output.outputs[0].text, "\n\n", sep="")
+        request_id = random_uuid()
+        result_generator = self.engine.generate(
+            self.template.format(user=user_question),
+            sampling_params,
+            request_id,
+        )
+        index = 0
+        async for output in result_generator:
+            # Skip outputs that currently end in an incompletely decoded
+            # multi-byte character (rendered as the replacement char U+FFFD).
+            if output.outputs[0].text.endswith("\ufffd"):
+                continue
+            # The engine yields the cumulative text so far; slice off only
+            # the suffix we haven't emitted yet.
+            text_delta = output.outputs[0].text[index:]
+            index = len(output.outputs[0].text)
+
+            yield text_delta

-        print(f"Generated {num_tokens} tokens in {time.time() - t0:.2f}s")
+        print(f"Generated {index} characters in {time.time() - t0:.2f}s")


# ## Run the model
# We define a [`local_entrypoint`](/docs/guide/apps#entrypoints-for-ephemeral-apps) to call our remote function
-# sequentially for a list of inputs. You can run this locally with `modal run vllm_mixtral.py`.
+# sequentially for a list of inputs. You can run this locally with `modal run -q vllm_mixtral.py`. The `-q` flag
+# enables the text to stream in your local terminal.
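+#
+# Each chunk printed below is an incremental delta rather than the cumulative
+# text, so concatenating the chunks reproduces the full completion. As a
+# minimal, illustrative sketch of the slicing done in `completion_stream`
+# above (plain Python, no vLLM required):
+#
+# ```
+# cumulative = ["Hel", "Hello", "Hello wor", "Hello world!"]
+# index = 0
+# for text in cumulative:
+#     delta = text[index:]  # only the part we haven't emitted yet
+#     index = len(text)
+#     print(delta, end="|")  # prints: Hel|lo| wor|ld!|
+# ```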
@stub.local_entrypoint()
def main():
    model = Model()
    questions = [
-        # Coding questions
        "Implement a Python function to compute the Fibonacci numbers.",
-        "Write a Rust function that performs binary exponentiation.",
-        "How do I allocate memory in C?",
-        "What are the differences between Javascript and Python?",
-        "How do I find invalid indices in Postgres?",
-        "How can you implement a LRU (Least Recently Used) cache in Python?",
-        "What approach would you use to detect and prevent race conditions in a multithreaded application?",
-        "Can you explain how a decision tree algorithm works in machine learning?",
-        "How would you design a simple key-value store database from scratch?",
-        "How do you handle deadlock situations in concurrent programming?",
-        "What is the logic behind the A* search algorithm, and where is it used?",
-        "How can you design an efficient autocomplete system?",
-        "What approach would you take to design a secure session management system in a web application?",
-        "How would you handle collision in a hash table?",
-        "How can you implement a load balancer for a distributed system?",
-        # Literature
        "What is the fable involving a fox and grapes?",
-        "Write a story in the style of James Joyce about a trip to the Australian outback in 2083, to see robots in the beautiful desert.",
-        "Who does Harry turn into a balloon?",
-        "Write a tale about a time-traveling historian who's determined to witness the most significant events in human history.",
-        "Describe a day in the life of a secret agent who's also a full-time parent.",
-        "Create a story about a detective who can communicate with animals.",
-        "What is the most unusual thing about living in a city floating in the clouds?",
-        "In a world where dreams are shared, what happens when a nightmare invades a peaceful dream?",
-        "Describe the adventure of a lifetime for a group of friends who found a map leading to a parallel universe.",
-        "Tell a story about a musician who discovers that their music has magical powers.",
-        "In a world where people age backwards, describe the life of a 5-year-old man.",
-        "Create a tale about a painter whose artwork comes to life every night.",
-        "What happens when a poet's verses start to predict future events?",
-        "Imagine a world where books can talk. How does a librarian handle them?",
-        "Tell a story about an astronaut who discovered a planet populated by plants.",
-        "Describe the journey of a letter traveling through the most sophisticated postal service ever.",
-        "Write a tale about a chef whose food can evoke memories from the eater's past.",
-        # History
        "What were the major contributing factors to the fall of the Roman Empire?",
-        "How did the invention of the printing press revolutionize European society?",
-        "What are the effects of quantitative easing?",
-        "How did the Greek philosophers influence economic thought in the ancient world?",
-        "What were the economic and philosophical factors that led to the fall of the Soviet Union?",
-        "How did decolonization in the 20th century change the geopolitical map?",
-        "What was the influence of the Khmer Empire on Southeast Asia's history and culture?",
-        # Thoughtfulness
        "Describe the city of the future, considering advances in technology, environmental changes, and societal shifts.",
-        "In a dystopian future where water is the most valuable commodity, how would society function?",
-        "If a scientist discovers immortality, how could this impact society, economy, and the environment?",
-        "What could be the potential implications of contact with an advanced alien civilization?",
-        # Math
        "What is the product of 9 and 8?",
-        "If a train travels 120 kilometers in 2 hours, what is its average speed?",
-        "Think through this step by step. If the sequence a_n is defined by a_1 = 3, a_2 = 5, and a_n = a_(n-1) + a_(n-2) for n > 2, find a_6.",
-        "Think through this step by step. Calculate the sum of an arithmetic series with first term 3, last term 35, and total terms 11.",
-        "Think through this step by step. What is the area of a triangle with vertices at the points (1,2), (3,-4), and (-2,5)?",
-        "Think through this step by step. Solve the following system of linear equations: 3x + 2y = 14, 5x - y = 15.",
-        # Facts
        "Who was Emperor Norton I, and what was his significance in San Francisco's history?",
-        "What is the Voynich manuscript, and why has it perplexed scholars for centuries?",
-        "What was Project A119 and what were its objectives?",
-        "What is the 'Dyatlov Pass incident' and why does it remain a mystery?",
-        "What is the 'Emu War' that took place in Australia in the 1930s?",
-        "What is the 'Phantom Time Hypothesis' proposed by Heribert Illig?",
-        "Who was the 'Green Children of Woolpit' as per 12th-century English legend?",
-        "What are 'zombie stars' in the context of astronomy?",
-        "Who were the 'Dog-Headed Saint' and the 'Lion-Faced Saint' in medieval Christian traditions?",
-        "What is the story of the 'Globsters', unidentified organic masses washed up on the shores?",
    ]
-    model.generate.remote(questions)
+    for question in questions:
+        print("Sending new request:", question)
+        for text in model.completion_stream.remote_gen(question):
+            print(text, end="", flush=True)
+
+
+# ## Deploy and invoke the model
+# Once we deploy this model with `modal deploy vllm_mixtral.py`,
+# we can invoke inference from other apps, sharing the same pool
+# of GPU containers with all other apps we might need. (The first argument
+# to `Function.lookup` below is the name given to this app's `Stub`.)
+#
+# ```
+# $ python
+# >>> import modal
+# >>> f = modal.Function.lookup("example-vllm-mixtral", "Model.completion_stream")
+# >>> for text in f.remote_gen("What is the story about the fox and grapes?"):
+# ...     print(text, end="", flush=True)
+# The story about the fox and grapes ...
+# ```
+
+# ## Coupling a frontend web application
+#
+# We can stream inference from a FastAPI backend, also deployed on Modal.
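+# The `/completion/{question}` endpoint defined below streams its response as
+# server-sent events, one JSON payload per text chunk. As a sketch of a Python
+# client using the `requests` library, assuming the app is deployed at the URL
+# below (substitute your own deployment's URL):
+#
+# ```
+# import json
+# import requests
+#
+# url = "https://modal-labs--vllm-mixtral.modal.run/completion/What%20is%202%2B2%3F"
+# with requests.get(url, stream=True) as response:
+#     for line in response.iter_lines():
+#         if line.startswith(b"data: "):
+#             event = json.loads(line[len(b"data: "):])
+#             print(event["text"], end="", flush=True)
+# ```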
+#
+# You can try our deployment [here](https://modal-labs--vllm-mixtral.modal.run).

+from pathlib import Path
+
+from modal import Mount, asgi_app
+
+frontend_path = Path(__file__).parent / "llm-frontend"


+@stub.function(
+    mounts=[Mount.from_local_dir(frontend_path, remote_path="/assets")],
+    keep_warm=1,
+    allow_concurrent_inputs=20,
+    timeout=60 * 10,
+)
+@asgi_app(label="vllm-mixtral")
+def app():
+    import json
+
+    import fastapi
+    import fastapi.staticfiles
+    from fastapi.responses import StreamingResponse
+
+    web_app = fastapi.FastAPI()
+
+    @web_app.get("/stats")
+    async def stats():
+        stats = await Model().completion_stream.get_current_stats.aio()
+        return {
+            "backlog": stats.backlog,
+            "num_total_runners": stats.num_total_runners,
+            "model": BASE_MODEL + " (vLLM)",
+        }
+
+    @web_app.get("/completion/{question}")
+    async def completion(question: str):
+        from urllib.parse import unquote
+
+        async def generate():
+            # The question arrives percent-encoded in the URL path, so decode
+            # it before handing it to the model. Each chunk is formatted as a
+            # server-sent event.
+            async for text in Model().completion_stream.remote_gen.aio(
+                unquote(question)
+            ):
+                yield f"data: {json.dumps(dict(text=text), ensure_ascii=False)}\n\n"
+
+        return StreamingResponse(generate(), media_type="text/event-stream")
+
+    web_app.mount(
+        "/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True)
+    )
+    return web_app
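+
+
+# The `/stats` endpoint above can be polled to watch the app scale up and
+# down. A sketch, assuming the same deployment URL as above (the values in
+# the comment are illustrative):
+#
+# ```
+# import requests
+#
+# stats = requests.get("https://modal-labs--vllm-mixtral.modal.run/stats").json()
+# print(stats)  # e.g. {"backlog": 0, "num_total_runners": 1, "model": "... (vLLM)"}
+# ```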