#0: Initial review
mtairum committed Nov 21, 2024
1 parent e64c3b4 commit 69b787b
Showing 5 changed files with 26 additions and 17 deletions.
35 changes: 22 additions & 13 deletions tech_reports/LLMs/llms.md
Almost every LLM generates text in the same manner: Given a prompt from the user, the LLM predicts the next token. Then, the LLM takes that new token and uses it as context to predict the following token. This process repeats until the LLM generates a token that indicates the end of the sequence, or until the user decides to stop the generation. The process is called "autoregressive generation" because each new token is used to predict the next token.
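
To make this concrete, here is a minimal sketch of the autoregressive loop. The `model`, `tokenizer`, and `sample` objects are placeholders standing in for a real LLM stack, not any specific API.

```python
def generate(model, tokenizer, sample, prompt: str, max_new_tokens: int = 32):
    """Minimal sketch of autoregressive generation (placeholder objects, not a real API)."""
    tokens = tokenizer.encode(prompt)  # prompt tokens form the initial context
    kv_cache = None                    # populated on the first forward pass
    for _ in range(max_new_tokens):
        logits, kv_cache = model(tokens, kv_cache)  # predict the next token
        next_token = sample(logits)                 # e.g. argmax or top-p/top-k
        if next_token == tokenizer.eos_token_id:    # end-of-sequence token
            break
        tokens.append(next_token)                   # new token becomes context
    return tokenizer.decode(tokens)
```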

#### Model Inputs and Outputs
Inputs to the model for generative decoding are generally:
- tokens: produced by the tokenizer
- position ids: the position of the tokens in the sequence
- KV cache: an inference optimization that caches intermediate values

In the model, tokens are embedded from the vocabulary space to the embedding space. Position ids are necessary for updating the KV cache and for positional embeddings like RoPE [TODO: Refer to the RoPE section].

The model outputs:
- logits for the next token
- an updated KV cache

The logits are unnormalized probabilities over the vocabulary. Given these probabilities, the sampler must decide which of these tokens in the vocabulary will be chosen. There are a few sampling methods that are commonly used to pick the next token:
- Greedy decoding (argmax of the logits, picks the most likely next token)
- Top-p/top-k sampling (restricts the logits according to p and k values, then samples according to the remaining probabilities)
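
As a concrete illustration, the sketch below implements both strategies over a vector of logits with NumPy. It is a simplified reference, not the sampler used by any model in this repository.

```python
import numpy as np

def sample_next_token(logits, top_k=0, top_p=1.0, greedy=False) -> int:
    """Pick the next token id from unnormalized logits (simplified reference)."""
    if greedy:
        return int(np.argmax(logits))               # most likely token
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()                            # softmax over the vocabulary
    if top_k > 0:
        cutoff = np.sort(probs)[-top_k]             # k-th largest probability
        probs = np.where(probs >= cutoff, probs, 0.0)
    if top_p < 1.0:
        order = np.argsort(probs)[::-1]             # tokens from most to least likely
        cumulative = np.cumsum(probs[order])
        keep = order[cumulative - probs[order] < top_p]  # smallest set covering top_p mass
        mask = np.zeros_like(probs)
        mask[keep] = probs[keep]
        probs = mask
    probs /= probs.sum()                            # renormalize remaining probabilities
    return int(np.random.choice(len(probs), p=probs))
```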

#### KV cache
The KV cache is an inference optimization. It allows us to cache some intermediate values during the first inference step which are reused in later steps.
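
To make the idea concrete, the sketch below preallocates per-layer key/value buffers and writes new entries as tokens are processed. The shapes and helper names are illustrative assumptions, not the layout used by any particular implementation.

```python
import numpy as np

# Illustrative KV cache: one (keys, values) buffer per layer, preallocated to max_seq_len.
batch, n_layers, n_kv_heads, max_seq_len, head_dim = 1, 2, 4, 128, 64
kv_cache = [
    (np.zeros((batch, n_kv_heads, max_seq_len, head_dim), dtype=np.float16),  # keys
     np.zeros((batch, n_kv_heads, max_seq_len, head_dim), dtype=np.float16))  # values
    for _ in range(n_layers)
]

def update_kv_cache(layer: int, pos: int, k: np.ndarray, v: np.ndarray) -> None:
    """Write keys/values for the token(s) starting at `pos` so later steps can reuse them."""
    keys, values = kv_cache[layer]
    keys[:, :, pos : pos + k.shape[2], :] = k
    values[:, :, pos : pos + v.shape[2], :] = v
```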

LLMs use batching to process multiple sequences in parallel.
However, there are tradeoffs with batching. As the batch size increases, the latency per decode step will also increase. It is typical to use different batch sizes for different use cases, depending on the goal of the system.

#### Performance Metrics
**Time to first token (TTFT)** measures the latency to generate the first token of the sequence. This is the time to prefill a prompt and generate the first token. It is a measure of interactivity.

**Total throughput (tokens per second)** tells us the total number of tokens that the model can generate per second. `total throughput = batch size / decode step latency`. Total throughput is important for cost-sensitive systems or offline processing, where interactivity is less important than throughput. Generally, increasing batch size will increase total throughput.

**User throughput (tokens per second per user)** is calculated as `user throughput = 1 / decode step latency`. User throughput tells us how interactive the model is, and tells us how fast the generation is for a single user. Generally, decreasing batch size will increase user throughput.

Note that each of these metrics changes with batch size and sequence length. When reporting TTFT, total throughput, and user throughput, the batch size and sequence length must be specified.
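
The arithmetic behind these metrics is straightforward. The helper below computes them from measured latencies under the convention described above; the numbers in the example are made up.

```python
def performance_metrics(prefill_latency_s: float, decode_step_latency_s: float, batch_size: int) -> dict:
    """Compute TTFT, total throughput, and user throughput from measured latencies."""
    return {
        "ttft_s": prefill_latency_s + decode_step_latency_s,       # prefill plus the first decode step
        "total_throughput_tok_s": batch_size / decode_step_latency_s,
        "user_throughput_tok_s_per_user": 1.0 / decode_step_latency_s,
    }

# Example with made-up numbers: 400 ms prefill, 25 ms per decode step, batch of 32 users.
print(performance_metrics(prefill_latency_s=0.4, decode_step_latency_s=0.025, batch_size=32))
# {'ttft_s': 0.425, 'total_throughput_tok_s': 1280.0, 'user_throughput_tok_s_per_user': 40.0}
```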

- device mesh
- column parallel followed by row parallel
- sharding, CCL ops, reducing CCL overheads, etc.

### 3.4 Continuous Batching
Continuous batching is a serving optimization. To describe continuous batching, it is useful to first discuss LLM serving without continuous batching.

Without continuous batching, an LLM service waits for `batch_size` requests to come in. The service then prefills each request. Then, the service decodes the batched requests token by token. Once all users in the batch finish generation, the service accepts new requests. This is suboptimal because 1) some requests might end generation early, so 2) some slots in the batch are not doing useful computation, while 3) new requests are waiting.

In contrast, continuous batching allows the service to process new requests as soon as there is a free slot in the batch. The pseudo-code for this algorithm is shown below.

```python
while True:
    if not is_full(current_batch) and not prefill_q.empty():
        # A slot is free: prefill the next waiting request into that slot
        # (loop body reconstructed here for illustration).
        add_user_to_batch(current_batch, prefill(prefill_q.pop()))
    decode(current_batch)                # one decode step for all active users
    evict_finished_users(current_batch)  # users that hit end-of-sequence free their slot
```
Continuous batching improves TTFT by reducing wait times for incoming users. It also increases total throughput by keeping the decode batch full of useful work.

Continuous batching is an LLM serving optimization, but it requires some support in the model. The model has to support single-user prefill so that when a slot is open, the model can prefill a new request into a specific slot of the batch. The model also has to support batched decode where position ids can be different for each user in the batch, to avoid context contamination.
Implementing continuous batching requires that the serving code track data for each slot of the batch. An example of our continuous batching demo can be found [here](../../models/demos/t3000/llama2_70b/demo/demo_continuous_batching.py). In production deployment, vLLM handles continuous batching for the LLM service.
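
As an illustration of that bookkeeping, the hypothetical `BatchSlot` structure below captures the kind of per-slot state a serving loop needs; the field names are invented for this sketch.

```python
from dataclasses import dataclass, field

@dataclass
class BatchSlot:
    """Hypothetical per-slot state tracked by the serving code for continuous batching."""
    user_id: str | None = None      # None means the slot is free for a new request
    position: int = 0               # current position id for this user's next token
    page_table: list[int] = field(default_factory=list)  # KV-cache pages assigned to this user
    finished: bool = False          # set once the user emits an end-of-sequence token
```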

### 3.5 vLLM Integration

#### Overview
vLLM is an open-source LLM serving library. We use vLLM to serve our models in production because of the features it enables. On the serving side, vLLM supports continuous batching and paged attention. In addition, vLLM provides an OpenAI-compatible server, which is useful for deployment.

Tenstorrent maintains a [fork of vLLM](https://github.com/tenstorrent/vllm/tree/dev) for serving models on Tenstorrent hardware. The [README](https://github.com/tenstorrent/vllm/tree/dev/tt_metal/README.md) has instructions for setting up the environment.

#### Implementation Requirements
In order to add vLLM support to a new model, the model must conform to a certain interface. An example of the interface is the [Llama2-70b generation code](../../models/demos/t3000/llama2_70b/tt/llama_generation.py), which implements `prefill_forward`, `decode_forward`, and `initialize_vllm_model`.
Beyond implementing the functionality needed for continuous batching, a model must also implement paged attention. For an example, see [Llama2-70b attention](../../models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py).
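
A skeleton of that interface is sketched below. The method names come from the generation code linked above, but the argument lists are assumptions made for illustration, not the exact signatures.

```python
class TtModelForGeneration:
    """Sketch of the interface vLLM expects; argument lists are illustrative only."""

    @classmethod
    def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size):
        # Construct the TT model (load weights, set up devices) for vLLM to drive.
        ...

    def prefill_forward(self, tokens, page_table, kv_cache, prompt_lens):
        # Prefill a user's prompt into its batch slot and return logits for the first token.
        ...

    def decode_forward(self, tokens, start_pos, page_table, kv_cache):
        # One batched decode step; position ids (start_pos) can differ per user.
        ...
```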

#### vLLM Modifications
On the vLLM side there may be additional changes needed to support the new model.

- Modify [`tt_loader.py`](https://github.com/tenstorrent/vllm/blob/dev/vllm/model_executor/model_loader/tt_loader.py) if the model requires a different initialization.
- Modify [`tt_model_runner.py`](https://github.com/tenstorrent/vllm/blob/dev/vllm/worker/tt_model_runner.py) if it is missing functionality for the new model.

#### Testing
Finally, test the new model through vLLM. Register the new model as seen in [`offline_inference_tt.py`](https://github.com/tenstorrent/vllm/blob/dev/examples/offline_inference_tt.py).

```python
from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaModelForGeneration)
```
2 changes: 1 addition & 1 deletion tt_metal/third_party/umd
Submodule umd updated 48 files
+1 −1 Doxyfile
+25 −0 blackhole_1chip_cluster.yaml
+2 −2 device/CMakeLists.txt
+1 −1 device/blackhole/blackhole_implementation.cpp
+1 −1 device/cpuset_lib.cpp
+1 −1 device/device_api_metal.h
+1 −1 device/grayskull/grayskull_implementation.cpp
+1 −1 device/hugepage.h
+0 −6 device/libs/create_ethernet_map.h
+2 −2 device/mockup/tt_mockup_device.hpp
+1 −1 device/pcie/pci_device.hpp
+1 −1 device/simulation/deprecated/tt_emulation_device.cpp
+2 −2 device/simulation/deprecated/tt_emulation_device.h
+1 −1 device/simulation/deprecated/tt_emulation_stub.cpp
+2 −2 device/simulation/deprecated/tt_versim_device.cpp
+2 −2 device/simulation/deprecated/tt_versim_device.h
+2 −2 device/simulation/deprecated/tt_versim_stub.cpp
+1 −1 device/simulation/tt_simulation_device.cpp
+2 −2 device/simulation/tt_simulation_device.h
+6 −23 device/tt_cluster_descriptor.cpp
+4 −10 device/tt_cluster_descriptor.h
+32 −0 device/tt_device.cpp
+8 −13 device/tt_device.h
+4 −6 device/tt_io.hpp
+130 −145 device/tt_silicon_driver.cpp
+1 −1 device/tt_silicon_driver_common.cpp
+1 −1 device/wormhole/wormhole_implementation.cpp
+0 −16 device/xy_pair.cpp
+6 −2 device/xy_pair.h
+14 −70 tests/api/test_chip.cpp
+16 −6 tests/api/test_cluster.cpp
+11 −4 tests/api/test_cluster_descriptor.cpp
+1 −1 tests/api/test_mockup_device.cpp
+4 −6 tests/blackhole/test_bh_common.h
+14 −16 tests/blackhole/test_silicon_driver_bh.cpp
+1 −1 tests/emulation/test_emulation_device.cpp
+2 −2 tests/galaxy/test_galaxy_common.cpp
+3 −5 tests/galaxy/test_galaxy_common.h
+4 −4 tests/galaxy/test_umd_concurrent_threads.cpp
+4 −4 tests/galaxy/test_umd_remote_api.cpp
+1 −1 tests/galaxy/test_umd_remote_api_stability.cpp
+11 −14 tests/grayskull/test_silicon_driver.cpp
+3 −3 tests/microbenchmark/device_fixture.hpp
+1 −1 tests/test_utils/device_test_utils.hpp
+7 −7 tests/test_utils/stimulus_generators.hpp
+14 −16 tests/wormhole/test_silicon_driver_wh.cpp
+1 −1 tests/wormhole/test_umd_remote_api_stability.cpp
+4 −4 tests/wormhole/test_wh_common.h
