diff --git a/tech_reports/LLMs/images/continuous_batching.png b/tech_reports/LLMs/images/continuous_batching.png
new file mode 100644
index 00000000000..e6c3e4681ec
Binary files /dev/null and b/tech_reports/LLMs/images/continuous_batching.png differ
diff --git a/tech_reports/LLMs/llms.md b/tech_reports/LLMs/llms.md
index 3279e8dc714..2ead7130e1f 100644
--- a/tech_reports/LLMs/llms.md
+++ b/tech_reports/LLMs/llms.md
@@ -1,6 +1,6 @@
 # LLMs in TT-NN
 
-Authors: Mark O'Connor, Djordje Ivanovic, Jack (Xun) Cai, Kartik Paigwar, Johanna Rock, Stuti Raizada, Ammar Vora
+Authors: Mark O'Connor, Djordje Ivanovic, Jack (Xun) Cai, Kartik Paigwar, Johanna Rock, Stuti Raizada, Ammar Vora, Colman Glagovich
 
 ## Contents
 - [LLMs in TT-NN](#llms-in-tt-nn)
@@ -1025,6 +1025,50 @@ def forward(
 ## 3. Features
 
 ### 3.1 Generative Decoding
+Almost every LLM generates text in the same manner: given a prompt from the user, the LLM predicts the next token. The LLM then takes that new token and uses it as context to predict the following token. This process repeats until the LLM generates a token that indicates the end of the sequence, or until the user decides to stop the generation. The process is called "autoregressive generation" because each new token is used to predict the next token.
+
+#### Model Inputs and Outputs
+Inputs to the model for generative decoding are generally:
+- tokens: produced by the tokenizer
+- position ids: the positions of the tokens in the sequence
+- KV cache: an inference optimization that caches intermediate values
+
+In the model, tokens are embedded from the vocabulary space into the embedding space. Position ids are necessary for updating the KV cache and for positional embeddings like RoPE.
+
+The model outputs:
+- logits for the next token
+- an updated KV cache
+
+The logits are unnormalized scores over the vocabulary; a softmax turns them into probabilities. Given these probabilities, the sampler decides which token in the vocabulary is emitted next. A few sampling methods are commonly used to pick the next token:
+- Greedy decoding (argmax of the logits, picks the most likely next token)
+- Top-p/top-k sampling (restricts the logits according to p and k values, then samples according to the remaining probabilities)
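+
+The sketch below shows both strategies on the host side. It assumes the final-position logits arrive as a 1D PyTorch tensor; the function name and default values are illustrative and are not part of any TT-NN or demo API.
+
+```python
+import torch
+
+def sample_next_token(logits: torch.Tensor, temperature: float = 1.0,
+                      top_k: int = 50, top_p: float = 0.9) -> int:
+    """Pick the next token id from a [vocab_size] vector of logits (illustrative defaults)."""
+    if temperature == 0.0:
+        # Greedy decoding: take the single most likely token.
+        return torch.argmax(logits).item()
+
+    # Top-k: keep only the k highest-scoring tokens.
+    topk_values, topk_indices = torch.topk(logits / temperature, k=top_k)
+    probs = torch.softmax(topk_values, dim=-1)
+
+    # Top-p (nucleus): keep the smallest prefix of tokens whose cumulative
+    # probability reaches p; the most likely token is always kept.
+    sorted_probs, sorted_order = torch.sort(probs, descending=True)
+    cumulative = torch.cumsum(sorted_probs, dim=-1)
+    sorted_probs = sorted_probs * (cumulative - sorted_probs < top_p)
+    sorted_probs = sorted_probs / sorted_probs.sum()
+
+    # Sample from the filtered distribution and map back to vocabulary ids.
+    choice = torch.multinomial(sorted_probs, num_samples=1)
+    return topk_indices[sorted_order[choice]].item()
+```
+Greedy decoding is deterministic, while top-k/top-p sampling trades determinism for more varied generations.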
+
+#### KV cache
+The KV cache is an inference optimization: it caches intermediate values computed during the first inference step so they can be reused in later steps.
+On the first inference step, the model processes the full prompt and caches the K and V projections for each layer. Subsequent inference steps compute a Q, K, V projection only for the new token, then use the cached K and V projections in attention. Therefore, the first step (prefill) creates the KV cache, and subsequent steps (decode) use and update the cache.
+
+The size of the KV cache depends on the batch size and sequence length. Since accelerators have finite memory, it can be necessary to trade off batch size against sequence length so that the KV cache fits in memory.
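+
+As a back-of-the-envelope illustration of that tradeoff, the sketch below estimates KV-cache memory for a hypothetical grouped-query-attention configuration. The helper and the example numbers are illustrative, not taken from a specific demo.
+
+```python
+def kv_cache_bytes(n_layers, n_kv_heads, head_dim, max_seq_len, batch_size, bytes_per_element=2):
+    """K and V each store one [batch, n_kv_heads, max_seq_len, head_dim] tensor per layer."""
+    return 2 * n_layers * batch_size * n_kv_heads * max_seq_len * head_dim * bytes_per_element
+
+# Example: 80 layers, 8 KV heads, head_dim 128 (a Llama3-70B-like shape) in bfloat16.
+size_gb = kv_cache_bytes(n_layers=80, n_kv_heads=8, head_dim=128,
+                         max_seq_len=8192, batch_size=32) / 1e9
+print(f"KV cache: {size_gb:.1f} GB")  # ~85.9 GB; halving batch size or sequence length halves it
+```
+Because the footprint scales linearly in both batch size and sequence length, the two often have to be chosen together on a device with fixed memory.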
+
+#### Batching
+LLMs use batching to process multiple sequences in parallel. There are a few reasons why batching is useful:
+- Real-world LLM services need to handle multiple concurrent requests.
+- LLM inference is bound by the time to read model weights from DRAM. Batching allows those weights to be reused across multiple sequences.
+- Total throughput of the system increases with batch size.
+
+However, there are tradeoffs with batching. In decode mode, latency scales sublinearly with batch size up to a point, because decode is bound by the time to read model weights from DRAM rather than by compute. If the batch grows very large, decode eventually becomes compute bound, and latency then scales linearly with batch size. In prefill mode, latency scales linearly with batch size because prefill is compute bound.
+
+It is typical to use different batch sizes for different use cases, depending on the goal of the system.
+
+#### Performance Metrics
+**Time to first token (TTFT)** measures the latency to generate the first token of the sequence, i.e. the time to prefill a prompt and produce the first token. It is a measure of interactivity.
+
+**Total throughput (tokens per second)** is the total number of tokens the model can generate per second: `total throughput = batch size / decode step latency`. Total throughput is important for cost-sensitive systems or offline processing, where interactivity matters less than throughput. Generally, increasing batch size increases total throughput.
+
+**User throughput (tokens per second per user)** is calculated as `user throughput = 1 / decode step latency`. User throughput measures how interactive the model is: how fast generation proceeds for a single user. Generally, decreasing batch size increases user throughput.
+
+Note that each of these metrics changes with batch size and sequence length. When reporting TTFT, total throughput, and user throughput, the batch size and sequence length must be specified.
+
+
 ### 3.2 Prefill and Decode
 
 Large language models require two distinct phases for inference due to the fundamental nature of transformer attention and autoregressive generation: prefill and decode.
@@ -1333,9 +1377,56 @@ For our [Llama3 family of models](../../models/demos/llama3) we are using the fo
 
 ### 3.4 Continuous Batching
- - quick intro and how it is implemented in demos.
+Continuous batching is a serving optimization. To describe continuous batching, it is useful to first discuss LLM serving without it.
+
+Without continuous batching, an LLM service waits for `batch_size` requests to come in. The service then prefills each request and decodes the batched requests token by token. Only once all users in the batch finish generation does the service accept new requests. This is suboptimal: some requests end generation early, so their slots in the batch stop doing useful computation while new requests sit waiting.
+
+In contrast, continuous batching allows the service to process new requests as soon as there is a free slot in the batch. The pseudo-code for this algorithm is shown below.
+
+```python
+while True:
+    if not is_full(current_batch) and not prefill_q.empty():
+        # A slot is free and a request is waiting: prefill it into the batch.
+        model_prefill(prefill_q.pop())
+    elif not is_empty(current_batch):
+        # Otherwise, run one decode step for every active user in the batch.
+        model_decode(current_batch)
+    else:
+        break
+```
+
+![Continuous batching](images/continuous_batching.png)
+The image above, from Anyscale (https://www.anyscale.com/blog/continuous-batching-llm-inference), shows how continuous batching inserts prefill sequences into the batch as soon as there is a free slot.
+
+Continuous batching improves TTFT by reducing wait times for incoming users. It also increases total throughput by keeping the decode batch full of useful work.
+
+Continuous batching is an LLM serving optimization, but it requires some support in the model. The model has to support single-user prefill so that when a slot opens up, the model can prefill a new request into that specific slot of the batch. The model also has to support batched decode in which position ids can differ for each user in the batch, to avoid context contamination.
+Implementing continuous batching requires that the serving code track data for each slot of the batch. An example of our continuous batching demo can be found [here](../../models/demos/t3000/llama2_70b/demo/demo_continuous_batching.py). In production deployment, vLLM handles continuous batching for the LLM service.
+
 ### 3.5 vLLM Integration
- - Our vLLM repo and what's needed to integrate with it.
+
+#### Overview
+vLLM is an [open-source LLM serving library](https://github.com/vllm-project/vllm). We use vLLM to serve our models in production because of the features it enables. On the serving side, vLLM supports continuous batching and [paged attention](https://arxiv.org/pdf/2309.06180). In addition, vLLM provides an OpenAI-compatible server, which is useful for deployment.
+
+Tenstorrent maintains a [fork of vLLM](https://github.com/tenstorrent/vllm/tree/dev) for serving models on Tenstorrent hardware. The [README](https://github.com/tenstorrent/vllm/tree/dev/tt_metal/README.md) has instructions for setting up the environment.
+
+#### Implementation Requirements
+To add vLLM support for a new model, the model must conform to a certain interface. An example of the interface is the [Llama2-70b generation code](../../models/demos/t3000/llama2_70b/tt/llama_generation.py), which implements `prefill_forward`, `decode_forward`, and `initialize_vllm_model`.
+Beyond implementing the functionality needed for continuous batching, a model must also implement paged attention. For an example, see [Llama2-70b attention](../../models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py).
+
+#### vLLM Modifications
+On the vLLM side, additional changes may be needed to support the new model.
+
+- Modify [`tt_loader.py`](https://github.com/tenstorrent/vllm/blob/dev/vllm/model_executor/model_loader/tt_loader.py) if the model requires a different initialization.
+- Modify [`tt_model_runner.py`](https://github.com/tenstorrent/vllm/blob/dev/vllm/worker/tt_model_runner.py) if it is missing functionality for the new model.
+
+#### Testing
+Finally, test the new model through vLLM. Register the new model as seen in [`offline_inference_tt.py`](https://github.com/tenstorrent/vllm/blob/dev/examples/offline_inference_tt.py):
+
+```python
+from vllm import ModelRegistry
+from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
+ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaModelForGeneration)
+```
+Then run `offline_inference_tt.py` to generate outputs with vLLM.
+
 ## 4. Best Practices and Optimizations
 ### 4.1 Tracing
 Reference [Metal Trace guide](https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/AdvancedPerformanceOptimizationsForModels/AdvancedPerformanceOptimizationsForModels.md) for background on tracing. Tracing allows you to record a single pass of your model and store the list of commands and buffers used on-device. You can then execute that trace in a single command with no additional work performed on the host. This eliminates overhead in stages 1-3, you are still responsible for transferring any data needed to and from the device, but host-device transfer of commands is eliminated.