Authors: Mark O'Connor, Djordje Ivanovic, Jack (Xun) Cai
- LLMs in TT-NN
This document provides guidance on how to bring up high-performance multi-chip models on Tenstorrent hardware using the TT-Metal stack. It targets users with previous experience on TT-Metal and shares our current best practices, tips, caveats, and workarounds on model bringup.
Basic Requirements:
- Access to TT hardware - This document is specifically for bringing models up on Wormhole (WH), but much of this document applies to Grayskull.
- Good grasp of PyTorch and transformers - This document skims over some basics; for example, it assumes you understand what a KV cache is and know the difference between prefill (reading tokens and generating the KV cache entries) and decode (auto-regressively generating new tokens one at a time). Beginner tutorials will follow; this document helps experts get up to speed deploying LLMs on Metal.
- Familiarity with Metal and ttnn - How to install, build, run examples, etc.
Other useful resources:
- Reference ViT guide if this document seems unclear or intimidating.
- Reference Building llama from scratch for further information about LLMs in general.
- Iterative update system
- When to use our fused op
- Replicated layernorm vs distributed layernorm
- Layernorm/rmsnorm weights in row major / wrapped around tile size trick
Attention in TT-NN is implemented in custom TT-NN kernels. In PyTorch, the attention op is usually implemented in the following way with 6 steps:
- QKV projections matmuls
- Reshape Q, K, V to match the expected input shape for the attention op
- Apply RoPE to Q and K
- Cache K and V
- Scaled Dot Product Attention
- Output reshape and output matmul
For example, the Llama model is implemented as follows:
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
"""
Forward pass of the attention module.
Args:
x (torch.Tensor): Input tensor.
start_pos (int): Starting position for caching.
freqs_cis (torch.Tensor): Precomputed frequency tensor.
mask (torch.Tensor, optional): Attention mask tensor.
Returns:
torch.Tensor: Output tensor after attention.
"""
# (1) QKV projections matmuls
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
# (2) Reshape Q, K, V to match the expected input shape for the attention op
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# (3) Apply RoPE to Q and K
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
# (4) Cache K and V
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
# (5) Scaled Dot Product Attention
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
output = torch.scaled_dot_product_attention(xq, keys, values, attn_mask=mask)
# (6) Output reshape and output matmul
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
The generic torch implementation is agnostic to prefill and decode modes; our implementation, however, differentiates between them. To learn more about the differences between the two modes and how we handle them in TT-NN, please see 3.2 Prefill and Decode. In general, our high-performance attention module uses specialized implementations for each mode, as they have different memory and compute patterns and bottlenecks, requiring different optimizations.
The rest of this section is organized as follows. We split the attention module into two parts -- prefill and decode -- and describe the implementation of the 6 steps for each. Then, we discuss some limitations of the current implementation and useful facts that will help with debugging and performance optimization.
Some common terminology used in this section:
Term | Description |
---|---|
bsz | batch size |
batch_id | batch index (used for prefill) |
cur_pos/cur_pos_tensor | list/tensor of current positions in the sequence for each batch |
cache_len | length of the KV cache |
seqlen | sequence length |
dim | hidden dimension of input x |
head_dim | hidden dimension of Q, K, V |
n_q_heads | number of heads in Q |
n_kv_heads | number of heads in K, V |
The attention module in prefill mode expects input shape `(1, bsz=1, seqlen, hidden_dim)` and outputs a tensor of the same shape. Note that `bsz=1` is required. For multiple batches, we simply run prefill iteratively and populate the KV cache at `batch_id`.
An end-to-end example of the prefill attention module is in the `models/demos/llama3/tt/llama_attention.py` file, under the `forward_prefill` method. In short, we break down the attention module in prefill mode into the following steps:
- QKV projections matmuls.
  - We combine the QKV projection weights into a single tensor and perform a standard `ttnn.linear`. Example:
    `xqkv_fused = ttnn.linear(x, wqkv, dtype=ttnn.bfloat16)`
  - Input/Output shapes: `(1, 1, seqlen, dim) -> (1, 1, seqlen, (n_q_heads+2*n_kv_heads)*head_dim)`
- Reshape Q, K, V to match the expected input shape for scaled dot product attention.
  - We split the fused QKV tensor into individual Q, K, V tensors using a custom optimized TM op, `ttnn.experimental.nlp_create_qkv_heads`. Example:
    `Q, K, V = ttnn.experimental.nlp_create_qkv_heads(xqkv_fused, num_heads=n_q_heads, num_kv_heads=n_kv_heads, transpose_k_heads=False)`
  - Input/Output shapes: `(1, 1, seqlen, (n_q_heads+2*n_kv_heads)*head_dim) -> (1, n_q_heads, seqlen, head_dim), (1, n_kv_heads, seqlen, head_dim), (1, n_kv_heads, seqlen, head_dim)`
- Apply RoPE to Q and K
  - We apply the RoPE transformation to Q and K using the rotary embedding op outlined in 2.2 RoPE. The input/output shapes remain the same as in step 2.
- Cache K and V
  - We populate the KV cache at `batch_id` with the current K and V tensors using the `ttnn.fill_cache` op. Example:
    `ttnn.fill_cache(K_cache, K, batch_id)`
    `ttnn.fill_cache(V_cache, V, batch_id)`
  - If a page table is used, we use the `ttnn.experimental.paged_fill_cache` op instead. Example:
    `ttnn.experimental.paged_fill_cache(K_cache, K, page_table, batch_idx=batch_id)`
    `ttnn.experimental.paged_fill_cache(V_cache, V, page_table, batch_idx=batch_id)`
- Scaled Dot Product Attention
  - We perform scaled dot product attention using our custom flash attention kernel, `ttnn.transformer.scaled_dot_product_attention`. It takes in the following arguments:
    - `q`: Query tensor of shape `(1, n_q_heads, seqlen, head_dim)`.
    - `k`: Key tensor of shape `(1, n_kv_heads, cache_len, head_dim)`.
    - `v`: Value tensor of shape `(1, n_kv_heads, cache_len, head_dim)`.
    - `attn_mask`: Attention mask of shape `[b x 1 x cache_len x seqlen]`; head broadcasting is implied. Defaults to `None`.
    - `is_causal`: bool, defaults to `True`. Whether to apply causal masking.
    - `scale`: float, defaults to `None`.
    - `program_config`: Defaults to `None`.
    - `compute_kernel_config`: Defaults to `None`.
  - For the general prefill use case with causal attention, it is recommended to set `is_causal=True`. This removes the need for `attn_mask`, and attention scores are only computed in the lower triangular half of the attention matrix. For example:
    `attn_output = ttnn.transformer.scaled_dot_product_attention(Q, K, V, is_causal=True)`
  - For non-causal attention, `attn_mask` must be provided. An example is the cross-attention case in visual language models. For example:
    `attn_output = ttnn.transformer.scaled_dot_product_attention(Q, K, V, attn_mask=mask, is_causal=False)`
- Output reshape and output matmul
  - At last, we use `ttnn.experimental.nlp_concat_heads` to reshape the output of the attention op, followed by a standard `ttnn.linear` for the output projection. Example:
    `attn_output = ttnn.experimental.nlp_concat_heads(attn_output)`
    `output = ttnn.linear(attn_output, wo)`
  - Input/Output shapes: `(1, n_q_heads, seqlen, head_dim) -> (1, 1, seqlen, hidden_dim) -> (1, 1, seqlen, hidden_dim)`
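Putting the prefill steps together, a minimal sketch of a `forward_prefill` method could look like the following. The member names (`wqkv`, `wo`, `K_cache`, `V_cache`) and the `apply_rope` helper are illustrative assumptions; the real implementation in `llama_attention.py` additionally manages program configs, memory configs, and multi-device sharding:

```python
def forward_prefill(self, x, batch_id, page_table=None):
    # (1) Fused QKV projection: (1, 1, seqlen, dim) -> (1, 1, seqlen, (n_q_heads + 2*n_kv_heads)*head_dim)
    xqkv_fused = ttnn.linear(x, self.wqkv, dtype=ttnn.bfloat16)

    # (2) Split into per-head Q, K, V tensors
    Q, K, V = ttnn.experimental.nlp_create_qkv_heads(
        xqkv_fused, num_heads=self.n_q_heads, num_kv_heads=self.n_kv_heads, transpose_k_heads=False
    )
    ttnn.deallocate(xqkv_fused)

    # (3) RoPE on Q and K (apply_rope is a placeholder for the rotary embedding op from 2.2 RoPE)
    Q, K = self.apply_rope(Q), self.apply_rope(K)

    # (4) Write K and V into this user's slot of the KV cache
    if page_table is not None:
        ttnn.experimental.paged_fill_cache(self.K_cache, K, page_table, batch_idx=batch_id)
        ttnn.experimental.paged_fill_cache(self.V_cache, V, page_table, batch_idx=batch_id)
    else:
        ttnn.fill_cache(self.K_cache, K, batch_id)
        ttnn.fill_cache(self.V_cache, V, batch_id)

    # (5) Causal flash attention
    attn_output = ttnn.transformer.scaled_dot_product_attention(Q, K, V, is_causal=True)

    # (6) Concat heads and output projection
    attn_output = ttnn.experimental.nlp_concat_heads(attn_output)
    return ttnn.linear(attn_output, self.wo)
```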
The attention module in decode mode expects input shape `(1, seqlen=1, bsz, hidden_dim)` and outputs a tensor of the same shape. Decode mode expects a sequence length of 1 and parallelizes over the batch dimension, due to the auto-regressive nature of decoding.
An end-to-end example of the decode attention module is in the `models/demos/llama3/tt/llama_attention.py` file, under the `forward_decode` method. The decode mode is broken down into the following steps:
- QKV projections matmuls.
  - This works the same as in prefill mode, using `ttnn.linear`. Note that the input shape is `(1, 1, bsz, dim)` instead of `(1, 1, seqlen, dim)`.
  - Input/Output shapes: `(1, 1, bsz, dim) -> (1, 1, bsz, (n_q_heads+2*n_kv_heads)*head_dim)`
- Reshape Q, K, V to match the expected input shape for scaled dot product attention.
  - We split the fused QKV tensor into individual Q, K, V tensors using `ttnn.experimental.nlp_create_qkv_heads_decode`. Note that this is a different op than the `ttnn.experimental.nlp_create_qkv_heads` used in prefill mode. Example:
    `Q, K, V = ttnn.experimental.nlp_create_qkv_heads_decode(xqkv_fused, num_heads=n_q_heads, num_kv_heads=n_kv_heads, memory_config=ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1))`
  - Input/Output shapes: The output is height sharded across the batch dimension on `bsz` cores. `(1, 1, bsz, (n_q_heads+2*n_kv_heads)*head_dim) -> (1, bsz, n_q_heads, head_dim), (1, bsz, n_kv_heads, head_dim), (1, bsz, n_kv_heads, head_dim)`
- Apply RoPE to Q and K
  - Again, we apply the RoPE transformation to Q and K using the rotary embedding op outlined in 2.2 RoPE. The input/output shapes remain the same as in step 2.
- Cache K and V
  - We populate the KV cache at `cur_pos` for all batches with the current K and V tensors using the `ttnn.experimental.paged_update_cache` op. This op takes in an optional `page_table` argument to support paged KV cache updates. Example:
    `ttnn.experimental.paged_update_cache(keys, K, update_idxs=cur_pos, page_table=page_table)`
    `ttnn.experimental.paged_update_cache(values, V, update_idxs=cur_pos, page_table=page_table)`
  - If the current position is given as `cur_pos_tensor`, a `ttnn.Tensor` rather than a list, we use the `update_idxs_tensor` argument instead:
    `ttnn.experimental.paged_update_cache(keys, K, update_idxs_tensor=cur_pos_tensor, page_table=page_table)`
- Scaled Dot Product Attention Decode
  - We perform scaled dot product attention using our custom flash attention kernels optimized for decode mode: `ttnn.transformer.scaled_dot_product_attention_decode` and, for a paged KV cache, `ttnn.transformer.paged_scaled_dot_product_attention_decode`.
  - `ttnn.transformer.scaled_dot_product_attention_decode` takes in the following arguments:
    - `q`: Query tensor of shape `(1, bsz, n_q_heads, head_dim)`.
    - `k`: Key tensor of shape `(1, bsz, cache_len, head_dim)`.
    - `v`: Value tensor of shape `(1, bsz, cache_len, head_dim)`.
    - `is_causal`: bool, defaults to `True`. Whether to apply causal masking.
    - `attn_mask`: Optional attention mask tensor. Defaults to `None` and is only used if `is_causal=False`.
    - `cur_pos`: (Required for `is_causal=True`) List of current positions in the sequence for each batch. Defaults to `None`. Must be provided if `cur_pos_tensor` is not provided.
    - `cur_pos_tensor`: (Required for `is_causal=True`) Optional current position tensor. Defaults to `None`. Must be provided if `cur_pos` is not provided.
    - `scale`: Optional scale factor. Defaults to `None`.
    - `program_config`: Optional program configuration. Defaults to `None`.
    - `compute_kernel_config`: Optional compute kernel configuration. Defaults to `None`.
    - `memory_config`: Optional memory configuration for the output tensor. Defaults to `None`.
  - `ttnn.transformer.paged_scaled_dot_product_attention_decode` takes in the same arguments as `ttnn.transformer.scaled_dot_product_attention_decode`, plus an additional `page_table_tensor` argument.
  - For general decode use cases, it is recommended to set `is_causal=True`. This removes the need for `attn_mask`, which greatly reduces memory bandwidth usage. For example:
    `attn_output = ttnn.transformer.paged_scaled_dot_product_attention_decode(Q, K, V, cur_pos_tensor=cur_pos, page_table=page_table)`
  - For non-causal attention, `attn_mask` must be provided. An example is the cross-attention case in visual language models. For example:
    `attn_output = ttnn.transformer.paged_scaled_dot_product_attention_decode(Q, K, V, attn_mask=mask, is_causal=False)`
- Output reshape and output matmul
  - Lastly, we use `ttnn.experimental.nlp_concat_heads_decode` to reshape the output of the attention op, followed by a standard `ttnn.linear` for the output projection. Example:
    `attn_output = ttnn.experimental.nlp_concat_heads_decode(attn_output, num_heads=n_q_heads)`
    `output = ttnn.linear(attn_output, wo)`
  - Input/Output shapes: `(1, bsz, n_q_heads, head_dim) -> (1, 1, bsz, hidden_dim) -> (1, 1, bsz, hidden_dim)`
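Putting the decode steps together, a minimal sketch of a `forward_decode` method could look like the following. As with the prefill sketch, the member names (`wqkv`, `wo`, `K_cache`, `V_cache`) and the `apply_rope` helper are illustrative assumptions; the real implementation in `llama_attention.py` additionally manages sharded memory configs, program configs, and multi-device handling:

```python
def forward_decode(self, x, cur_pos_tensor, page_table=None):
    # (1) Fused QKV projection: (1, 1, bsz, dim) -> (1, 1, bsz, (n_q_heads + 2*n_kv_heads)*head_dim)
    xqkv_fused = ttnn.linear(x, self.wqkv, dtype=ttnn.bfloat16)

    # (2) Split into per-head Q, K, V, height-sharded over bsz cores in L1
    Q, K, V = ttnn.experimental.nlp_create_qkv_heads_decode(
        xqkv_fused,
        num_heads=self.n_q_heads,
        num_kv_heads=self.n_kv_heads,
        memory_config=ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1),
    )
    ttnn.deallocate(xqkv_fused)

    # (3) RoPE on Q and K (placeholder for the rotary embedding op from 2.2 RoPE)
    Q, K = self.apply_rope(Q), self.apply_rope(K)

    # (4) In-place KV cache update at each user's current position
    ttnn.experimental.paged_update_cache(self.K_cache, K, update_idxs_tensor=cur_pos_tensor, page_table=page_table)
    ttnn.experimental.paged_update_cache(self.V_cache, V, update_idxs_tensor=cur_pos_tensor, page_table=page_table)

    # (5) Flash decode over the (paged) KV cache
    attn_output = ttnn.transformer.paged_scaled_dot_product_attention_decode(
        Q, self.K_cache, self.V_cache, cur_pos_tensor=cur_pos_tensor, page_table=page_table
    )

    # (6) Concat heads and output projection
    attn_output = ttnn.experimental.nlp_concat_heads_decode(attn_output, num_heads=self.n_q_heads)
    return ttnn.linear(attn_output, self.wo)
```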
Flash attention and flash decode are the major ops for attention. They are optimized for latency and throughput, and perform much better than vanilla implementations. If you are interested in how they work, please refer to our Flash Attention Tech Report.
TLDR -- here are some useful things about the attention ops to keep in mind that will help you write efficient and bug-free code:
- Program configs in flash attention (and flash decode) ops: The program config has the following parameters:
  - `compute_with_storage_grid_size`: The grid size.
  - `q_chunk_size`: The size of a chunk to process at a time for Q.
  - `k_chunk_size`: The size of a chunk to process at a time for K and V.
  - `exp_approx_mode`: Whether to use the exponential approximation mode for softmax.
  - `max_cores_per_head_batch`: The maximum number of cores to use for each head-batch in flash decode.

  Flash attention processes Q, K, V in chunks of size `q_chunk_size` and `k_chunk_size`. The chunk size must be a power of 2 and a multiple of 32. By default, the chunk size is set to 512, but you should experiment with different values to find the best performance. Flash attention is parallelized over the cores specified in `compute_with_storage_grid_size`. For example, if you are running on a grid size of 8x8, then flash attention is parallelized over 64 cores. The parallelization is divided by batch, then by head, then by the number of Q chunks.

  Flash decode processes the entire Q (since the query in decode mode is small) and K/V in chunks of size `k_chunk_size`. As a result, the `q_chunk_size` field is not used for flash decode. It is parallelized over the cores specified in `compute_with_storage_grid_size`, and the parallelization is divided by batch, then by kv_head. In many cases there are more cores than `heads*batch`; this is why flash decode is needed, since it allows multiple cores to process a single head. In extreme cases where too many cores process a single head, the NoC bandwidth between cores becomes the bottleneck. We found experimentally that more than 16 cores per head-batch no longer provides any benefit and starts degrading performance. The `max_cores_per_head_batch` field limits the number of cores used for each head-batch in flash decode, and is set to 16 by default.

  Lastly, the `exp_approx_mode` field sets the exponential approximation mode for softmax in flash attention and flash decode. We recommend setting this to `true` for small `seqlen/chunk_size` values. For large `seqlen/chunk_size` values, the error introduced by the exponential approximation can accumulate across chunks, causing major degradation in PCC. For example, in Llama3 models we use a `q_chunk_size` and `k_chunk_size` of 512, and set `exp_approx_mode` to `false` for long sequence lengths greater than 16K. (An illustrative sketch of constructing such a program config follows this list.)
- Current position tensor for flash decode and KV cache ops: In decode mode, you can provide the current positions either as a list or as a tensor. The tensor version can be more efficient because it supports tracing. To learn more about what tracing is and how to use it, please refer to 4.1 Tracing. In short, tracing requires the traced variables to be statically known at compile time, so if you provide a list of current positions, you cannot modify it for the next token generation. However, if you provide a tensor, the position values are stored in device memory and can be updated using a binary addition op, e.g. `ttnn.add` (see the second sketch below).
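As referenced above, here is an illustrative sketch of constructing a flash-attention program config. The exact class (`ttnn.SDPAProgramConfig` is assumed here) and its available fields should be checked against your ttnn version:

```python
import ttnn

# Assumed program-config class and fields -- verify the names in your ttnn version.
sdpa_program_config = ttnn.SDPAProgramConfig(
    compute_with_storage_grid_size=(8, 8),  # parallelize across an 8x8 core grid
    q_chunk_size=512,                       # must be a power of 2 and a multiple of 32
    k_chunk_size=512,
    exp_approx_mode=False,                  # exact softmax exponent, safer for long sequences
)

attn_output = ttnn.transformer.scaled_dot_product_attention(
    Q, K, V,
    is_causal=True,
    program_config=sdpa_program_config,
)
```

And a minimal sketch of keeping the current position on-device so that it works with tracing, incrementing it with `ttnn.add` after each decode step (`batch_size` and `mesh_device` are illustrative names from your own setup):

```python
import torch
import ttnn

# One position per user in the batch, stored on device so traced ops can read it.
cur_pos_tensor = ttnn.from_torch(
    torch.zeros(batch_size, dtype=torch.int32),
    device=mesh_device,
    dtype=ttnn.int32,
)

# ... run one traced decode iteration that consumes cur_pos_tensor ...

# Advance every user's position by one without any host round-trip.
cur_pos_tensor = ttnn.add(cur_pos_tensor, 1)
```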
The Decode forward pass implementation below follows the diagram above. Keep in mind that, in order to optimize memory usage, it is recommended to deallocate tensors after their usage, which can be crucial under tighter memory constraints.
To optimize performance in decode mode, we maintain the residual stream in L1 and shard it across cores and devices. However, determining the optimal number of cores for sharding can be challenging, especially for operations like DRAM-sharded matmuls. The Llama model config contains a helper that produces a core grid which divides the N and K dims of a matmul evenly; a sketch of the idea is shown below.
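The following is a simplified, illustrative sketch of that idea (choose the largest core grid whose total core count evenly divides both the K and N tile counts, so that every core gets an equal slice of the weights); the real helper in the Llama model config is more involved:

```python
import ttnn

def core_grid_dividing_k_and_n(k: int, n: int, max_rows: int = 8, max_cols: int = 8, tile: int = 32) -> ttnn.CoreGrid:
    # Illustrative only: pick the largest (rows x cols) grid whose total core count
    # divides both the K and N tile counts of the matmul.
    k_tiles, n_tiles = k // tile, n // tile
    best_rows, best_cols = 1, 1
    for rows in range(1, max_rows + 1):
        for cols in range(1, max_cols + 1):
            cores = rows * cols
            if k_tiles % cores == 0 and n_tiles % cores == 0 and cores > best_rows * best_cols:
                best_rows, best_cols = rows, cols
    return ttnn.CoreGrid(y=best_rows, x=best_cols)
```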
When it’s not feasible to keep the streams sharded, we use the `ttnn.sharded_to_interleaved` and `ttnn.interleaved_to_sharded` ops to switch between the two layouts as needed. In our implementation of Llama3.1 there are some ops that require interleaved tensors, so resharding is needed around them.
def forward(
    self,
    x: ttnn.Tensor,
    current_pos,
    rot_mat=None,
    transformation_mats=None,
    user_id=0,
    mode="decode",
    page_table=None,
) -> ttnn.Tensor:
    if mode == "prefill":
        skip_mem_cfg = ttnn.DRAM_MEMORY_CONFIG
    elif mode == "decode":
        skip_mem_cfg = self.model_config["DEC_SKIP_OUTPUT_MEMCFG"]

    # Attention RMSNorm
    attn_in = self.attention_norm(x)
    # Attention
    attn_out = self.attention.forward(
        attn_in,
        current_pos,
        rot_mat,
        transformation_mats,
        user_id,
        mode,
        page_table,
    )
    ttnn.deallocate(attn_in)
    # Residual add of inputs and attention output
    h = ttnn.add(x, attn_out, memory_config=skip_mem_cfg)
    ttnn.deallocate(attn_out)
    # MLP and RMSNorm
    ff_out = self.feed_forward.forward(self.ffn_norm(h), mode)
    # Residual add of attention output and MLP output
    out = ttnn.add(h, ff_out, memory_config=skip_mem_cfg)
    ttnn.deallocate(ff_out)
    ttnn.deallocate(h)
    return out
The `LMHead` is unique because LLMs typically have large vocabulary sizes, which are independent of the model size (i.e. the number of model parameters). As a result, the `LMHead` has a large `last_dim` in its weight matrix. Given the substantial size of the `LMHead` weights and the memory limitations of the hardware, these weights must be distributed across multiple devices and processed in iterations, while the activations are replicated across devices.

The number of iterations required depends on the size of the weights and the number of devices available, ranging from one to several iterations. For example, in Llama 3.1's decode mode, the LMHead matrix multiplication involves shapes of `(32, 8K) x (8K, 128K)`.

Below is an illustration of how the LMHead weights are partitioned across two devices, followed by its implementation. For illustrative purposes it uses 128K for the `vocab_size` instead of the real Llama 3.1 value of `128256`.
size_per_device = self.vocab_size // self.num_devices
num_splits = math.ceil(size_per_device / max_columns_per_device)

split_sizes = [min(size_per_device, max_columns_per_device)] * (num_splits - 1)
split_sizes.append(size_per_device - sum(split_sizes))  # remaining columns

# Split the output weights
torch_output_weights = state_dict[f"{state_dict_prefix}output.weight"].permute(1, 0)

self.output_weights = []

for i, split_size in enumerate(split_sizes):
    cache_file_name = (
        None if args.dummy_weights else weight_cache_path / f"output_lm_head_{num_splits}_split_shard_{i}"
    )

    # Create a list to store the split tensors for each device
    device_splits = []
    for device in range(self.num_devices):
        start = device * size_per_device + sum(split_sizes[:i])
        end = start + split_size
        device_splits.append(torch_output_weights[:, start:end])

    # Concatenate the splits from all devices
    combined_split = torch.cat(device_splits, dim=-1)

    memory_config = args.create_dram_sharded_mem_config(
        k=args.dim, n=combined_split.shape[-1] // self.num_devices
    )
    self.output_weights.append(
        ttnn.as_tensor(
            combined_split,
            device=mesh_device,
            mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
            layout=ttnn.TILE_LAYOUT,
            dtype=dtype,
            memory_config=memory_config,
            cache_file_name=cache_file_name,
        )
    )
We use a DRAM-sharded matmul for the `LMHead`, with the `program_config` and `memory_config` generated by the code below. For more information check Section: Op Configs.

The primary reason for having multiple `program_configs` is that the weight shapes may result in unequal split sizes. This variability means the same configuration cannot be used for every matrix multiplication.
# Generate dram-sharded memory_config
memory_config = args.create_dram_sharded_mem_config(
    k=args.dim, n=combined_split.shape[-1] // self.num_devices
)

# Generate dram-sharded program_config
self.program_configs = [
    args.dram_matmul_config(
        args.tile_padded_batch_rows,
        args.dim,
        split_size,
        args.lm_head_core_grid.num_cores,
    )
    for split_size in split_sizes
]
Once the weights are pushed to the devices and the decoders have been executed, the `LMHead` forward pass must be executed in iterations. The code below shows that after each iteration the outputs are converted from sharded to interleaved tensors. Once all iterations are completed, the final output is produced by concatenation over the last dim and returned as `output`.

When executing the model, it is essential to ensure that the output of the last decoder is already replicated across devices. Since this replication is enforced earlier, no additional code is required in the `LMHead` forward pass to handle it.
def forward(self, x: ttnn.Tensor):
    outputs = []
    for weight, pc in zip(self.output_weights, self.program_configs):
        output = ttnn.linear(
            x,
            weight,
            compute_kernel_config=self.compute_kernel_config,
            program_config=pc,
            memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
            dtype=ttnn.bfloat8_b,
        )
        outputs.append(output)

    # Concatenate the outputs
    output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
    return output
Once the model components (discussed in previous sections) are implemented, there isn’t much left to finalize. In our implementation, embeddings are managed outside the model class, as explained in Section 2.1 Embedding.
The model’s constructor initializes N decoders (e.g. 80 for Llama3.1-70B), the `RMSNorm` and the `LMHead`, ensuring that the weights for all components are loaded onto the appropriate devices.

During the forward pass, the decoders are executed sequentially, followed by normalization and the `LMHead` computation at the end.

A specific optimization is applied for prefill mode: since only the last token is relevant, the `LMHead` is executed only on the final tile in this mode.

In prefill mode, the RMSNorm output is interleaved, but the `LMHead` requires a sharded tensor. To accommodate this, the `ttnn.interleaved_to_sharded` op is used to prepare the output accordingly.
def forward(
    self,
    x: ttnn.Tensor,
    current_pos,
    rot_mat=None,
    transformation_mats=None,
    user_id=0,
    mode="decode",
    page_table=None,
    get_last_token=-1,
):
    for layer in self.layers:
        x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode, page_table)

    if mode == "prefill" and get_last_token == -1:
        return x

    # Slicing the tensor to the nearest ceiling/floor multiples of 32 for the prefill_len, to get the last token
    if get_last_token != -1:
        x = ttnn.slice(x, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, x.shape[-1]))

    # Output norm
    x = self.norm(x, mode=mode)

    if mode == "prefill":
        x = ttnn.interleaved_to_sharded(
            x,
            self.model_config["LM_HEAD_INPUT_MEMCFG"],
        )

    return self.lm_head(x)
- submodules, tests
- how to combine prefill and decode,
- slicing prefill to fit in L1
- device mesh
- column parallel followed by row parallel
- sharding, CCL ops, reducing CCL overheads, etc.
- quick intro and how it is implemented in demos.
- Our vLLM repo and what's needed to integrate with it.
Reference the Metal Trace guide for background on tracing. Tracing allows you to record a single pass of your model and store the list of commands and buffers used on-device. You can then execute that trace in a single command with no additional work performed on the host. This eliminates overhead in stages 1-3; you are still responsible for transferring any data needed to and from the device, but the host-device transfer of commands is eliminated.
We typically use tracing for the decode pass of LLMs but not the prefill pass. The main reasons for this are linked to tracing’s key limitation:
- You cannot allocate or deallocate tensors during a trace. When executing a trace every buffer will be the same size every time.
Tracing doesn’t work for prefill, where the sequence length (and therefore the matmul row counts) will likely change between runs. It does work for decode; see the sections on handling the KV cache and paging with tracing. Conveniently, prefill consists of large operations in the millisecond-plus range which the host can dispatch quickly enough. In decode, with a comparatively small batch size, we iterate through the entire model in around 10 ms with microsecond-length op times, where we can't afford to wait on CPU or Linux process scheduling; the limit is the speed at which electrons coruscate from DRAM and through the NoC into our cores.
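As a rough sketch of the tracing flow (this assumes the generic `ttnn` trace APIs `begin_trace_capture` / `end_trace_capture` / `execute_trace` and illustrative model and tensor names; the demos additionally manage persistent device-resident input buffers and asynchronous readback):

```python
# Compile/warm-up run outside the trace.
tt_out = tt_model(decode_input, current_pos_tensor)

# Capture one decode iteration. No tensors may be allocated or deallocated while
# capturing, so all inputs must already live in persistent device buffers.
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
tt_out = tt_model(decode_input, current_pos_tensor)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)

# Replay the recorded command stream once per new token: update the persistent
# inputs on-device (new token embedding, incremented position tensor), then execute.
for _ in range(max_new_tokens):
    ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
    # ... read tt_out back, sample the next token, write it into decode_input ...
```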
Async mode allows the host to continuously send commands to the device without blocking until data is read back from device, improving performance. Enable it with:
mesh_device.enable_async(True)
Without async mode, each python call to ttnn will block until the device has finished and results are available. This is good for debugging: any crash or error will show you the offending line of code. With async mode enabled, your python thread keeps running while the host and device handle background calls, only blocking when data needs to be read back from the device.
Async mode is faster, but in the case of asserts or crashes your python stack trace will be several lines past the call that caused the problem. For performance work async mode should always be enabled; for debugging it can be useful to disable it.
- how to feed back output to input and read output asynchronously
- Writing correct program configs and shard specs
- Deciding how many cores to run an op on
- Why did we use 16 cores for MLP
- Which matmul to use when @Colman Glagovich
- 1d, 2d, dram-sharded, ...
- Implicitly padding weights in program config for matmuls
- How we measure it (PCC, perplexity, top-1/top-5, end-user tests, benchmarking)
- How much PCC is enough? Rules of thumb.
- Accuracy tests
- Debugging PCC issues
ttnn performance has five components:
- Main Python Thread - Main python thread is your code that executes ttnn calls and other logical OPs. The speed of the main python thread determines the speed at which python calls are dispatched to the API. You are in control of any overheads. When counting in microseconds python is slower than you think.
- Host API - Most ttnn calls are immediately dispatched onto multiple C++ threads for further processing before any hardware changes. You are generally not in control of any overheads in this part of the stack.
- Host-device Communications - Data is heavy; avoid moving it. PCIe bandwidth and latency aren't negligible at the speeds needed to run models. In addition, Tenstorrent converts data into tiles of 32x32 elements for faster processing. Tilizing and untilizing data must be specified explicitly, takes time, and should be performed on-device where possible.
- Device Dispatch - We can measure the time between one OP finishing and the next starting. The lower limit of device dispatch is single-digit microseconds. Work is underway to reduce this lower limit to zero. However, for various reasons you might see much higher dispatch times, most notably if there are a lot of runtime arguments to a function or if OPs are running between calls.
- Device OP Performance - Device OP performance measures how long it takes the hardware to run a given operation. We want performance limited by either DRAM bandwidth or math throughput. For larger OPs both of these are achievable. Device OP performance is about how data is placed (DRAM vs L1, sharded vs interleaved) and how the compute kernels are configured (process more than one tile at once and use smaller data formats).
Each component is discussed in more detail below. First, confirm whether tracing is enabled (see 4.1 Tracing for more details); tracing should be used for decode mode but not prefill mode.
This means that for decode mode you won't have to worry about components 1-3, but for prefill mode you will.
Optimize the main python thread if you are not tracing; if you are tracing, the main python thread is not important. The Metal Profiler/Tracy can also show python performance, but for pure python analysis VizTracer is a recommended tool. Install it with:
pip install viztracer
Find the line of code to profile - it is usually the part that calls your model’s forward function - and wrap it, e.g.:
# ...
# setup code above

from viztracer import VizTracer

with VizTracer(output_file='trace.json') as tracer:
    tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
You can view this file with `vizviewer trace.json` - it's self-sufficient, so if you're working on a remote machine you can copy it back to your laptop and run it there (remember to `pip install viztracer` locally as well). Use WASD to navigate the UI and use the mouse to expand processes to see the call stacks. Look for any non-ttnn code that takes a significant amount of time between the ttnn calls and find a way to remove or optimize it.
What to look for:
- The model forward pass running quickly and then waiting in a ttnn.to_torch or similar call reading data back from device.
- Time from the start to end of the forward pass of your model. If this is shorter than target latency of your device, it is Fast Enough™ and you are done with this section.
Top tips:
- Torch modules add overhead to every function call and member access. We don't subclass `torch.nn.Module` for anything that might have to run quickly.
- Generate shard spec and compute kernel config objects once (e.g. in a constructor) instead of recreating them every time you run the forward pass. Keep the forward pass clean.
- Make sure Metal is compiled in Release mode (the default) and that you are using ttnn's async mode (see above).
Any overhead here is outside your control and in our experience is minimal. Use a C++ profiler or Metal Profiler/Tracy with host stack traces enabled to see this time.
As little communication as possible between the host and the device is preferred. For LLMs this means:
- Perform embeddings on-device (token ids are smaller than embeddings).
- Return only the last token from prefill, not all the tokens.
- Perform sampling (argmax etc) on-device if you can (at time of writing only argmax is implemented).
- Avoid pushing attention masks, rotation matrices if they can be generated on-device or re-used between iterations.
Note where data is tilized and untilized. Do not tilize or untilize data on the host. The `to_torch` API will by default do this on the host. You can untilize on-device like this:
tt_out_tiled = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
tt_out_row_major = ttnn.untilize(tt_out_tiled, use_multicore=True)
tt_tok = ttnn.argmax(tt_out_row_major, dim=3, use_multicore=True)
torch_tok = ttnn.to_torch(tt_tok)
Looking at host-device communications in a python profiler like `viztracer` is possible, but be careful: when async mode is on, any time spent in a communication call like `to_torch` can comprise up to three measures:
- Time spent waiting for the device
- Time spent transferring data
- Time spent untilizing data
If you want to measure calls this way, turn async mode off. The time your main python thread spends in `to_torch` will then not include any time spent waiting for the device and will be a closer approximation of the measures above.
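A small, illustrative way to take that measurement (`mesh_device` and `tt_out` are assumed to exist in your own test code):

```python
import time

mesh_device.enable_async(False)  # so to_torch time excludes waiting on the device

start = time.perf_counter()
torch_out = ttnn.to_torch(tt_out)  # transfer + (host-side) untilize, if not already untilized on-device
elapsed_ms = (time.perf_counter() - start) * 1e3
print(f"to_torch took {elapsed_ms:.2f} ms")
```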
This is the fun bit, but we need to do a little prep to get started. First, Metal must be compiled with `-p` to enable device profiling:
./build_metal.sh -p
Then we can record an OP performance csv file with tracy. For the pytests, run it like this:
python -m tracy -r -p -v -m pytest path/to/test.py
This produces a file with a naming convention similar to `ops_perf_results_2024_11_01_15_33_18.csv`; this is the file we need from the profiler. For more information see: Metal Profiler tech report.
Warning: Only use a single trace execution step when profiling. Profiler support with tracing is still a work in progress, and more iterations will result in an `AssertionError: Device data mismatch error`.
Note: If you see errors while running tracy, try this device-only profiling process instead: run with `TT_METAL_DEVICE_PROFILER=1 pytest path/to/test.py`. After the run completes, run `tt_metal/tools/profiler/process_ops_logs.py --date` to generate the CSV file.
This CSV file contains information recorded from all devices during program execution. To summarize it, we run the `perf_report.py` tool:
python models/perf/perf_report.py OPS_CSV_FILE
For device performance we recommend looking at a single layer. You can do this by using `--id-range` or by changing your test to run only a single layer of the model. For more information, including how to select specific ranges of OPs, see the Performance Report Analysis Tool documentation.
Ideally you should run your model in as close to end-user form as possible, simplifying it as much as possible. In practice this means:
- Use tracing (if you are using tracing in production).
- Skip the first compilation iteration - this adds a lot of one-time host overhead between OPs.
- Run a single layer of the model - but be aware of which OPs are run for every layer and which ones are only run at the start and end (e.g. embedding, final norm and LM head).
- Add a tracy signpost e.g. `tracy.signpost("Performance pass")` before the part you want to record - this will be focused on by default by `perf_report.py`, saving you some work.
Here is an example without tracing enabled. You can instantly see that more time (756us) is spent in between OPs (op-to-op gap) than running OPs on device (362us)!
Reducing op-to-op gap
There are two main contributors to op-to-op gap: host time and dispatch time.
- Host time is optimized in steps 1-3. If you are already using tracing or are using async mode and have ensured that your python thread is dispatching faster than the device is generating outputs, then this has already been minimized.
- Dispatch time is out of your hands, but as an example, it is influenced by the number of runtime args a kernel uses.
- You can examine the source code for any kernel with high op-to-op latency and see if you can convert some runtime args into compile-time args for your use case.
- You can fuse multiple OPs into a single kernel. Examples where this was worthwhile in the past include `LayerNorm` and `ScaledDotProductAttentionDecode`.
Typically tracing reduces the op-to-op gap below 6us and as of November 2024 there are roadmap plans to reduce this to zero, so as long as your OPs are below this level, your opportunities for optimization here are limited.
See the next section for tips on how to optimize OP performance.
There are many individual tips, let’s start with overall advice:
- Use as many cores as possible.
- Move data as little as possible.
The perfect OP runs on the entire core grid using sharded inputs from L1. Let’s look more at data movement first, then specific tips.
OPs can read data from:
- DRAM Interleaved - Each tile (32x32 datums) is read from a different DRAM bank. This is the ttnn default and is the slowest way to read data. A matmul can expect to read around 190 GB/s on a Wormhole like this.
- DRAM Sharded - Specifically used for DRAM-bound matmuls and nothing else, this splits the data across DRAM banks and uses the closest core to each bank on the chip to read from that bank. This achieves around 240 GB/s on a Wormhole.
- L1 Interleaved - Tiles are interleaved across the L1 of all the cores and are read across the NoC (network-on-chip).
- L1 Sharded - Tiles are sharded across a particular grid of cores.
Note that the term sharding is used in two ways in the metal stack. Here we are talking about sharding across cores within a single chip. It is also used to refer to sharding a dimension across multiple devices - an analogous operation but confusing in this context.
L1 sharded is particularly fast when the data an OP requires is already placed in L1 of the correct core, avoiding the NoC entirely and reading at maximum speed.
Activations are placed in L1 and weights placed in DRAM.
See the op config section for more details on writing shard specs in your code.
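As a small illustration of an L1-sharded activation (a sketch: the shard shape, the grid, and the `ttnn.create_sharded_memory_config` arguments shown here are assumptions to adapt to your own tensor shapes):

```python
import ttnn

# Width-shard a (1, 1, 32, 4096) activation across 32 cores; each core holds a (32, 128) slice in L1.
act_mem_cfg = ttnn.create_sharded_memory_config(
    shape=(32, 4096 // 32),                 # per-core shard shape
    core_grid=ttnn.CoreGrid(y=4, x=8),      # 32 cores
    strategy=ttnn.ShardStrategy.WIDTH,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)

x_sharded = ttnn.to_memory_config(tt_x, act_mem_cfg)  # move the tensor into L1 with this sharding
```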
Situation: OPs are reading from the fastest memory they can, sharded if possible. What might still make things slow?
- Unnecessary `ShardedToInterleaved` and `InterleavedToSharded` calls. The fastest work is work that you don't have to do. These calls are pure data movement, and it is often better to have some OPs use fewer cores if it means they can use the same sharding of their input data as the previous and subsequent OPs. Always avoid data movement!
- Always use `ScaledDotProductAttention` (SDPA) OPs if possible. These implement FlashAttention / FlashDecode and are much faster than writing attention using individual operations.
- Cross-device communication OPs (`AllGather`, `ReduceScatter`, etc.). Avoid these where possible, and try using `bfp8` inputs instead of `bf16` if you can. There is an `AllGatherMatmul` OP that overlaps `AllGather` with a `Matmul` that you can investigate further too - see `ttnn.experimental.all_gather_matmul`, with an example of its use looking like this:
_, dense_out_sharded, _ = ttnn.experimental.all_gather_matmul(
    input_tensor,
    weights,
    dim=3,
    all_gather_core_grid_offset=(0, 4),
    num_links=1,
    memory_config_ag=all_gather_memcfg,
    memory_config_mm=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
    program_config=all_gather_matmul_progcfg,
    compute_kernel_config=compute_kernel_config_hifi2,
)
Matmuls are usually the most significant workload. They should be memory-bound, compute-bound, or too small to matter. `perf_report.py` gives good advice for your matmuls and you should follow it, which usually involves specifying a program config:
- Output subblock size should be at least 2x1 or 1x2.
- DRAM-sharded matmuls should be used for any DRAM-bound cases, e.g. most decode matmuls.
- The inner dim number of tiles (`in0_block_w`) should be at least 2 if possible.
- Use the lowest precision you can for weights and inputs - we find BFP8 weights always work and BFP4 weights work for some matmuls, particularly in the MLP.
- Use an appropriate math fidelity in the compute kernel config. This controls the number of bits multiplied together and is especially important for compute-bound matmuls as the Tensix core’s math throughput is 2x higher with HiFi2 and 3.6x faster with LoFi.
- Use HiFi4 for BF16 weights or if accuracy is very important (you often see this in attention ops)
- Use HiFi2 for BFP8 weights - this drops the least-significant bit of a BF16 @ BFP8 matmul but this is usually not an issue. You may find that LoFi works as well.
- Use LoFi for BFP4 weights.
You can specify a compute kernel like this:
self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig(
    math_fidelity=ttnn.MathFidelity.HiFi2,
    math_approx_mode=False,
    fp32_dest_acc_en=False,
    packer_l1_acc=True,
)
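For the DRAM-sharded decode matmuls recommended above, the program config follows the same pattern as the `dram_matmul_config` helper shown earlier in the LM head section. A hedged sketch with illustrative shapes (the class name `ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig` and the exact block-size arithmetic should be verified against your ttnn version and model config):

```python
import math
import ttnn

# Hypothetical decode matmul: m=32 activation rows, k=4096, n=4096, 12 DRAM-reader cores.
m, k, n, num_cores, tile = 32, 4096, 4096, 12, 32

dram_sharded_pc = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
    in0_block_w=math.ceil(k / (tile * num_cores)),   # inner-dim tiles handled per core
    per_core_M=math.ceil(m / tile),                  # output row tiles per core
    per_core_N=math.ceil(n / (tile * num_cores)),    # output column tiles per core
    fused_activation=None,
)
```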
As always, do not recreate these every single forward pass if you want your python thread to be fast (which you do).
- Running out of L1
- Shard spec and program config mismatches
- For some TTNN ops (e.g. ttnn.all_gather) it's not supported to pass -1 in the dim argument.
- You'll see an error related to op invocation where the arguments don't match
- link to any other description, and mention it is needed for N300 and T3K
- Host communications cause tracing to hang
- Running without async mode enabled causes tracing to hang
- Careful with print in tracing
- Large matmuls hanging? Link to appropriate ticket with workaround
- Issue is being investigated with a workaround of setting the output subblock to 1,1 and grid size to 8x7