[LLM Tech report] Sections 2.6, 2.7 and 2.8
Co-authored-by: mtairum <[email protected]>
djordje-tt and mtairum authored Dec 4, 2024
1 parent 56c3e45 commit 98be31d
Showing 4 changed files with 215 additions and 2 deletions.
Binary file added tech_reports/LLMs/images/2.6-decoder.png
Binary file added tech_reports/LLMs/images/2.7-lm-head.png
Binary file added tech_reports/LLMs/images/2.8-llama-model.png
217 changes: 215 additions & 2 deletions tech_reports/LLMs/llms.md
@@ -1,5 +1,5 @@
# LLMs in TT-NN
Authors: Mark O'Connor,
Authors: Mark O'Connor, Djordje Ivanovic

## Contents
- [LLMs in TT-NN](#llms-in-tt-nn)
@@ -64,8 +64,221 @@ Other useful resources:
- limitations
- which dims are parallelized
### 2.5 MLP

### 2.6 Decoder
<div align="center">
<img src="images/2.6-decoder.png" alt="Decoder Diagram" title="Decoder Title" width="350" height="400">
</div>
When the components described in the previous sections (MLP, attention, RMSNorm) are implemented, bringing up the decoder should be relatively straightforward.
As shown in the diagram (based on the Llama 3.1 example), the components are stacked sequentially during the forward pass.
The only real decision left is whether the residual additions of the attention and MLP outputs should be stored in L1 or in DRAM.

<br>

The decode forward pass implementation below follows the diagram above. Keep in mind that, to optimize memory usage, tensors should be deallocated once they are no longer needed; this can be crucial under tighter memory constraints.
<br>

To optimize performance in decode mode, we keep the residual stream in L1 and shard it across cores and devices. However, determining the optimal number of cores for sharding can be challenging, especially for operations like DRAM-sharded matmuls. Here is the [code](https://github.com/tenstorrent/tt-metal/blob/53c32c0c0da926f97bd0eb042e70fd54c2866f44/models/demos/llama3/tt/model_config.py#L931) in the Llama model config that produces a core grid dividing the N and K dims of a matmul evenly.
When it is not feasible to keep the streams sharded, we use the ttnn ops `ttnn.sharded_to_interleaved` and `ttnn.interleaved_to_sharded` to switch between the two layouts as needed.
In our Llama 3.1 implementation, some ops require interleaved tensors, so resharding happens around them.
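
To make the core-grid choice concrete, below is a simplified, stand-alone sketch of the idea (not the linked `model_config` code): pick the largest core count, up to an assumed maximum grid size, that divides both the K and N tile counts of the matmul evenly. The tile size and 64-core maximum are illustrative assumptions.

```py
# Simplified sketch only; the linked model_config code is the source of truth.
# Assumes 32x32 tiles and a hypothetical 64-core maximum grid.
TILE_SIZE = 32

def find_even_core_count(k: int, n: int, max_cores: int = 64) -> int:
    """Largest core count <= max_cores that divides both the K and N tile counts evenly."""
    k_tiles, n_tiles = k // TILE_SIZE, n // TILE_SIZE
    for cores in range(max_cores, 0, -1):
        if k_tiles % cores == 0 and n_tiles % cores == 0:
            return cores
    return 1

print(find_even_core_count(8192, 4096))  # e.g. an 8K x 4K matmul -> 64 cores
```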

<br>

```py
def forward(
    self,
    x: ttnn.Tensor,
    current_pos,
    rot_mat=None,
    transformation_mats=None,
    user_id=0,
    mode="decode",
    page_table=None,
) -> ttnn.Tensor:
    if mode == "prefill":
        skip_mem_cfg = ttnn.DRAM_MEMORY_CONFIG
    elif mode == "decode":
        skip_mem_cfg = self.model_config["DEC_SKIP_OUTPUT_MEMCFG"]

    # Attention RMSNorm
    attn_in = self.attention_norm(x)
    # Attention
    attn_out = self.attention.forward(
        attn_in,
        current_pos,
        rot_mat,
        transformation_mats,
        user_id,
        mode,
        page_table,
    )
    ttnn.deallocate(attn_in)
    # Residual add of the input and the attention output
    h = ttnn.add(x, attn_out, memory_config=skip_mem_cfg)
    ttnn.deallocate(attn_out)
    # FFN RMSNorm followed by the MLP
    ff_out = self.feed_forward.forward(self.ffn_norm(h), mode)
    # Residual add of the attention residual and the MLP output
    out = ttnn.add(h, ff_out, memory_config=skip_mem_cfg)

    ttnn.deallocate(ff_out)
    ttnn.deallocate(h)

    return out
```


### 2.7 LM Head

The `LMHead` is unique because LLMs typically have large vocabulary sizes, which are independent of the model size (i.e. the number of model parameters).
As a result, the `LMHead` has a large `last_dim` in its weight matrix. Given the substantial size of the `LMHead` weights and the memory limitations of the hardware, these weights must be distributed across multiple devices and processed in iterations, while the activations are replicated across devices.

The number of iterations required depends on the size of the weights and the number of available devices, and ranges from one to several. For example, in Llama 3.1’s decode mode, the LMHead matrix multiplication involves shapes of `(32, 8K) x (8K, 128K)`.
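
To make the iteration count concrete, here is a small worked example with round, hypothetical numbers: a 128K vocabulary split across two devices, with an assumed `max_columns_per_device` of 32K (the real values come from the model configuration).

```py
import math

# Hypothetical illustration only: 128K vocab, 2 devices, assumed 32K-column cap per iteration.
vocab_size = 128 * 1024
num_devices = 2
max_columns_per_device = 32 * 1024

size_per_device = vocab_size // num_devices                       # 65536 columns per device
num_splits = math.ceil(size_per_device / max_columns_per_device)  # 2 iterations

split_sizes = [min(size_per_device, max_columns_per_device)] * (num_splits - 1)
split_sizes.append(size_per_device - sum(split_sizes))            # [32768, 32768]
print(num_splits, split_sizes)
```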

Below is an illustration of how the LMHead weights are partitioned across two devices, followed by the implementation. For illustrative purposes it uses 128K for the `vocab_size` instead of the real Llama 3.1 value of `128256`.

<div align="center">
<img src="images/2.7-lm-head.png" alt="LM Head Diagram" title="LM_Head" width="650" height="350">
</div>

```py
size_per_device = self.vocab_size // self.num_devices
num_splits = math.ceil(size_per_device / max_columns_per_device)

split_sizes = [min(size_per_device, max_columns_per_device)] * (num_splits - 1)
split_sizes.append(size_per_device - sum(split_sizes))  # remaining columns

# Split the output weights
torch_output_weights = state_dict[f"{state_dict_prefix}output.weight"].permute(1, 0)

self.output_weights = []

for i, split_size in enumerate(split_sizes):
    cache_file_name = (
        None if args.dummy_weights else weight_cache_path / f"output_lm_head_{num_splits}_split_shard_{i}"
    )

    # Create a list to store the split tensors for each device
    device_splits = []
    for device in range(self.num_devices):
        start = device * size_per_device + sum(split_sizes[:i])
        end = start + split_size
        device_splits.append(torch_output_weights[:, start:end])

    # Concatenate the splits from all devices
    combined_split = torch.cat(device_splits, dim=-1)

    memory_config = args.create_dram_sharded_mem_config(
        k=args.dim, n=combined_split.shape[-1] // self.num_devices
    )
    self.output_weights.append(
        ttnn.as_tensor(
            combined_split,
            device=mesh_device,
            mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
            layout=ttnn.TILE_LAYOUT,
            dtype=dtype,
            memory_config=memory_config,
            cache_file_name=cache_file_name,
        )
    )
```

We use a DRAM-sharded matmul for the LMHead, with the `program_config` and `memory_config` generated by the code below.
For more information, check [Section: Op Configs](#44-op-configs).
The primary reason for having multiple `program_config`s is that the weight splits may have unequal sizes, so the same configuration cannot be used for every matrix multiplication.

```py
# Generate dram-sharded memory_config
memory_config = args.create_dram_sharded_mem_config(
    k=args.dim, n=combined_split.shape[-1] // self.num_devices
)
# Generate dram-sharded program_config
self.program_configs = [
    args.dram_matmul_config(
        args.tile_padded_batch_rows,
        args.dim,
        split_size,
        args.lm_head_core_grid.num_cores,
    )
    for split_size in split_sizes
]
```
Once the weights are pushed to the devices and the decoders have been executed, the `LMHead` forward pass is run in iterations, one per weight split.
In the code below, each iteration produces a width-sharded output in L1; once all iterations are completed, the outputs are concatenated over the last dim into an interleaved DRAM tensor and returned as `output`.

When executing the model, it is essential that the output of the last decoder is already replicated across devices. Since this replication is enforced earlier, no additional code is required in the `LMHead` forward pass to handle it.

```py
def forward(self, x: ttnn.Tensor):
    outputs = []
    for weight, pc in zip(self.output_weights, self.program_configs):
        output = ttnn.linear(
            x,
            weight,
            compute_kernel_config=self.compute_kernel_config,
            program_config=pc,
            memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
            dtype=ttnn.bfloat8_b,
        )
        outputs.append(output)

    # Concatenate the outputs
    output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.DRAM_MEMORY_CONFIG)

    return output
```
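
For reference, here is a hedged sketch of how a caller might consume this output on the host. The mesh composer, the `mesh_device` handle, and greedy sampling are illustrative assumptions rather than part of the `LMHead` itself.

```py
# Host-side sketch (assumption: `tt_logits` is the LMHead output, still sharded
# over the vocab dimension across devices, and `mesh_device` is the device mesh).
logits = ttnn.to_torch(
    tt_logits,
    mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1),  # reassemble the full vocab dim
)
next_tokens = logits.argmax(dim=-1)  # greedy pick per row; replace with your sampling of choice
```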


### 2.8 Model

<div align="center">
<img src="images/2.8-llama-model.png" alt="Llama model" title="Llama model" width="350" height="350">
</div> <br>

Once the model components (discussed in previous sections) are implemented, there isn’t much left to finalize. In our implementation, embeddings are managed outside the model class, as explained in [Section 2.1 Embedding](#21-embedding).

The model’s constructor initializes the N decoders (e.g. 80 for Llama 3.1-70B), the final `RMSNorm`, and the `LMHead`, ensuring that the weights for all components are loaded onto the appropriate devices.
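
A structural sketch of such a constructor is shown below. The class and argument names are illustrative assumptions, with `decoder_cls`, `norm_cls`, and `lm_head_cls` standing in for the decoder, RMSNorm, and LMHead implementations from the previous sections; they are not the exact tt-metal code.

```py
# Structural sketch only; all names here are placeholders.
class TtModelSketch:
    def __init__(self, args, mesh_device, dtype, state_dict, weight_cache_path,
                 decoder_cls, norm_cls, lm_head_cls):
        # One decoder per transformer layer, e.g. 80 for Llama 3.1-70B.
        self.layers = [
            decoder_cls(args, mesh_device, dtype, state_dict, i, weight_cache_path)
            for i in range(args.n_layers)
        ]
        # Final norm and LM head; each loads its weights onto the mesh device(s).
        self.norm = norm_cls(args, mesh_device, dtype, state_dict, weight_cache_path)
        self.lm_head = lm_head_cls(args, mesh_device, dtype, state_dict, weight_cache_path)
```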

During the forward pass, the decoders are executed sequentially, followed by the final normalization and the `LMHead` computation.
A specific optimization is applied in prefill mode: since only the last token is relevant, the `LMHead` is executed only on the tile that contains it.
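
For instance, here is a sketch of how a caller might derive `get_last_token` for the slice in the forward pass below; this calling convention is an assumption based on the 32-row tile slice, not documented API.

```py
# Hedged sketch: the model slices a full 32-row tile, so the last-token index is
# rounded down to a multiple of 32 and the exact row is recovered afterwards.
prefill_len = 500                                   # example prompt length
last_token_idx = prefill_len - 1                    # 499
get_last_token = (last_token_idx // 32) * 32        # 480: start of the tile holding it
row_within_tile = last_token_idx - get_last_token   # 19: offset inside the sliced tile
```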

In prefill mode, the RMSNorm output is interleaved, but the LMHead requires a sharded tensor. To accommodate this, the `interleaved_to_sharded` function is used to prepare the output accordingly.

```py
def forward(
    self,
    x: ttnn.Tensor,
    current_pos,
    rot_mat=None,
    transformation_mats=None,
    user_id=0,
    mode="decode",
    page_table=None,
    get_last_token=-1,
):
    for layer in self.layers:
        x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode, page_table)

    if mode == "prefill" and get_last_token == -1:
        return x

    # Slice out the 32-row tile containing the last token (get_last_token is expected to be a multiple of 32)
    if get_last_token != -1:
        x = ttnn.slice(x, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, x.shape[-1]))

    # Output norm
    x = self.norm(x, mode=mode)

    if mode == "prefill":
        x = ttnn.interleaved_to_sharded(
            x,
            self.model_config["LM_HEAD_INPUT_MEMCFG"],
        )

    return self.lm_head(x)
```


## 3. Features
### 3.1 Generative Decoding
### 3.2 Prefill and Decode
@@ -221,7 +434,7 @@ This CSV file contains information recorded from all devices during program exec
python models/perf/perf_report.py OPS_CSV_FILE
```

For device performance we recommend looking at a single layer. You can do this by using `--id-range` or by changing your test to run only a single layer of the model. For more information see: [Performance Report Analysis Tool](https://github.com/tenstorrent/tt-metal/tree/main/models/perf). The Performance Report Analysis Tool document describes how to select specific ranges of OPs.

##### What makes a good performance test?

