#0: added normalization details in the tech report (#15124)
### Problem description
Added Distributed and Non-Distributed normalization details in LLM tech
report.
kpaigwar authored Dec 19, 2024
1 parent 26a041c commit a8a812b
Showing 1 changed file with 146 additions and 3 deletions.
149 changes: 146 additions & 3 deletions tech_reports/LLMs/llms.md
@@ -1,5 +1,6 @@
# LLMs in TT-NN

Authors: Mark O'Connor, Djordje Ivanovic, Jack (Xun) Cai, Kartik Paigwar

## Contents
- [LLMs in TT-NN](#llms-in-tt-nn)
@@ -56,8 +57,150 @@ Other useful resources:
- Iterative update system
- When to use our fused op
### 2.3 Norm

Normalization is a critical operation in Large Language Models (LLMs), ensuring stable training and efficient inference. Two widely adopted normalization techniques in modern LLMs, **LayerNorm** and **RMSNorm**, are fully supported in TT-NN.
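
For reference, the standard definitions of the two operations are (with learned scale \(\gamma\), bias \(\beta\), and a small \(\epsilon\) for numerical stability):

```math
\mathrm{LayerNorm}(x) = \frac{x - E[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta,
\qquad
\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{E[x^2] + \epsilon}} \cdot \gamma
```

RMSNorm omits the mean subtraction and the bias, which is why the distributed implementation described below only needs \(E[x^2]\) for RMSNorm but both \(E[x]\) and \(E[x^2]\) for LayerNorm.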

#### Implementations of Normalization Operations

TT-NN includes two primary implementations of normalization operations to handle diverse activation layouts efficiently:

1. **Non-Distributed Norm**
2. **Distributed Norm**


#### 1. Non-Distributed Norm

**Non-Distributed Norm** refers to the standard implementation of normalization operations applied to activations that are not distributed across multiple devices. This type of normalization is suitable for setups where the entire activation or embedding is available locally on a single device or is replicated identically across multiple devices in a data-parallel setup. This implementation supports both sharded and interleaved inputs.

**Example: RMSNorm on Single Device (Decode Scenario)**

```python
import torch
import ttnn

# Assumes an open device handle, e.g. device = ttnn.open_device(device_id=0)

def torch_rms_norm(x, gamma, eps):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * gamma

batch, seq_len, embedding_dim = 32, 1, 8192
torch_input = torch.randn((batch, seq_len, embedding_dim))
torch_gamma = torch.randn((embedding_dim,))
torch_output = torch_rms_norm(torch_input, torch_gamma, eps=1e-5)

# Reshape inputs/weights to 4D tensors (seq_len = 1 for decode)
torch_input = torch_input.view(1, 1, batch, embedding_dim)
torch_gamma = torch_gamma.view(1, 1, 1, embedding_dim)

# Convert tensors to TT-NN tensors
ttnn_input = ttnn.as_tensor(
    torch_input,
    device=device,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
ttnn_gamma = ttnn.as_tensor(
    torch_gamma,
    device=device,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
)

# Perform RMSNorm
ttnn_output = ttnn.rms_norm(ttnn_input, epsilon=1e-5, weight=ttnn_gamma)
```
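
As a quick sanity check, the device output can be read back to host and compared with the PyTorch reference computed above; the loose tolerances below are an assumption to accommodate bfloat16 precision:

```python
# Read the TT-NN result back to host and compare against the PyTorch reference.
tt_output_torch = ttnn.to_torch(ttnn_output).to(torch.float32)  # [1, 1, batch, embedding_dim]
torch_reference = torch_output.view(1, 1, batch, embedding_dim)
print(torch.allclose(tt_output_torch, torch_reference, atol=1e-1, rtol=1e-1))
```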

**Optimization for Efficient Weight Reads from DRAM**

In the example above, the weights are pushed to the device in **TILE layout**. Because the gamma vector is only one row tall, it must be padded up to TILE_HEIGHT, which increases the memory footprint and reduces DRAM access efficiency. To avoid this, the weights are instead wrapped into **TILE_WIDTH** sticks and converted to **ROW_MAJOR_LAYOUT**, which requires no padding. This transformation adds no runtime overhead, since it is performed only once during initialization.

```python
# Optimized weight layout for DRAM
TILE_WIDTH = 32  # tile width in TT-NN (tiles are 32x32 elements)
torch_gamma = torch_gamma.view(1, 1, embedding_dim // TILE_WIDTH, TILE_WIDTH)
ttnn_gamma_rm = ttnn.as_tensor(
    torch_gamma,
    device=device,
    dtype=ttnn.bfloat16,
    layout=ttnn.ROW_MAJOR_LAYOUT,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
```
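
For example, with `embedding_dim = 8192` the weight tensor becomes `[1, 1, 256, 32]`: 256 row-major sticks of `TILE_WIDTH = 32` elements with no padding, compared to the `[1, 1, 32, 8192]` tensor that would result from padding a single row up to TILE_HEIGHT (32) in TILE layout.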




#### 2. Distributed Norm

The distributed implementation is designed for cases where activations are **sharded along the embedding dimension** across multiple devices. It ensures the correct computation of mean and variance across shards by leveraging cross-device communication. We provide support for both interleaved and width-sharded inputs.
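
The steps below assume a `tt_distributed_input_tensor` that is already sharded along the embedding dimension across a `mesh_device`. A minimal sketch of how such a tensor might be created (the `mesh_device` handle and the `torch_input` shape are assumptions carried over from the example above):

```python
# Shard the activation along the embedding dimension (dim=3) across the mesh.
# Assumes mesh_device was opened beforehand (e.g. via ttnn.open_mesh_device).
num_devices = mesh_device.get_num_devices()
tt_distributed_input_tensor = ttnn.as_tensor(
    torch_input,  # shape [1, 1, batch, embedding_dim]
    device=mesh_device,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
    mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3),
)
# Each device now holds a shard of shape [1, 1, batch, embedding_dim // num_devices].
```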

#### Steps to Perform Distributed Normalization on TT-Devices

1. **Compute Local Statistics**
Each device computes the required statistics (e.g., \(E[x]\), \(E[x^2]\)) locally on its shard of the input tensor.
- For **RMSNorm**, only \(E[x^2]\) is required.
- For **LayerNorm**, both \(E[x]\) and \(E[x^2]\) are computed.

```python
tt_distributed_stats = ttnn.rms_norm_pre_all_gather(tt_distributed_input_tensor)
```

- **Output**: A `stats` tensor of shape `[1, 1, batch, TILE_WIDTH * num_stats]`.
- **Note**:
- `num_stats=1` for RMSNorm.
- `num_stats=2` for LayerNorm.
- Only the first column of the stats tile contains meaningful data; the rest are padding.

2. **Gather Statistics Across Devices**
The statistics are gathered from all devices along the specified dimension (`dim=3`) and replicated across the device mesh.

```python
tt_gathered_stats = ttnn.all_gather(
    tt_distributed_stats,
    dim=3,
    num_links=1,
    cluster_axis=1,
    mesh_device=mesh_device,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
    topology=ttnn.Topology.Linear,
)
```

- **Output**: A tensor of shape `[1, 1, batch, TILE_WIDTH * num_stats * num_devices]`.

3. **Global Normalization**
The gathered statistics are used to compute the global mean and variance, and normalization is performed on the sharded input.

```python
tt_distributed_output_tensor = ttnn.rms_norm_post_all_gather(
    tt_distributed_input_tensor,
    epsilon=eps,
    weight=tt_distributed_weights,
    program_config=sharded_program_config,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
    stats=tt_gathered_stats,
)
```
- **Output**: A tensor of shape `[1, 1, batch, embedding_dim // num_devices]`.
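
The normalized output stays sharded across devices (one `embedding_dim // num_devices` slice per device). If a full host-side copy is needed, for instance in a unit test, the shards can be re-assembled with a mesh composer; a minimal sketch, assuming the same `mesh_device`:

```python
# Concatenate the per-device shards back along the embedding dimension on host.
torch_norm_output = ttnn.to_torch(
    tt_distributed_output_tensor,
    mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=3),
)  # shape: [1, 1, batch, embedding_dim]
```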


#### Key Notes (Valid for Both Implementations)

- **Interleaved Inputs**:
  For interleaved inputs, the kernel parallelizes work across the sequence length (`seq_len`),
  which makes it well suited for **prefill**, where sequences are long.

- **Width-Sharded Inputs**:
  For width-sharded inputs, the kernel splits the work across the embedding dimension,
  which makes it better suited for **decode**, where typically `seq_len = 1` (see the sketch below).
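
To illustrate the decode path, a width-sharded input for the single-device example could be produced with a sharded memory config. The core grid and shapes below are illustrative assumptions (a 32-user decode batch with `embedding_dim = 8192` split across an 8x8 core grid), not values taken from this report:

```python
# Hypothetical width-sharded layout: 8192 embedding elements split across 64 cores
# (128 elements per core, a multiple of TILE_WIDTH). Adjust to your device and model.
width_sharded_config = ttnn.create_sharded_memory_config(
    shape=(32, 8192),                    # (batch, embedding_dim), seq_len = 1 for decode
    core_grid=ttnn.CoreGrid(y=8, x=8),
    strategy=ttnn.ShardStrategy.WIDTH,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
)
ttnn_input_sharded = ttnn.to_memory_config(ttnn_input, width_sharded_config)
```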


#### References
- Non-Distributed Norm Op Code [[1]](https://github.com/tenstorrent/tt-metal/tree/main/ttnn/cpp/ttnn/operations/normalization/layernorm) [[2]](https://github.com/tenstorrent/tt-metal/tree/main/ttnn/cpp/ttnn/operations/normalization/rmsnorm)
- Distributed Norm Op Code [[3]](https://github.com/tenstorrent/tt-metal/tree/main/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed) [[4]](https://github.com/tenstorrent/tt-metal/tree/main/ttnn/cpp/ttnn/operations/normalization/rmsnorm_distributed)
- Non-Distributed Norms Unit Tests [[5]](https://github.com/tenstorrent/tt-metal/blob/main/tests/tt_eager/python_api_testing/unit_testing/misc/test_layernorm_sharded.py) [[6]](https://github.com/tenstorrent/tt-metal/blob/main/tests/tt_eager/python_api_testing/unit_testing/misc/test_layernorm.py)
- Distributed Norms Unit Tests [[7]](https://github.com/tenstorrent/tt-metal/blob/main/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py) [[8]](https://github.com/tenstorrent/tt-metal/blob/main/tests/ttnn/unit_tests/operations/test_distributed_layernorm_sharded.py)
- Distributed Norm in LLama3 [[9]](https://github.com/tenstorrent/tt-metal/blob/main/models/demos/llama3/tt/distributed_norm.py)

### 2.4 Attention

Attention in TT-NN is implemented in custom TT-NN kernels. In PyTorch, the attention op is usually implemented in the following way with 6 steps:
