diff --git a/tech_reports/LLMs/images/2.6-decoder.png b/tech_reports/LLMs/images/2.6-decoder.png
new file mode 100644
index 00000000000..e3040fa1bbf
Binary files /dev/null and b/tech_reports/LLMs/images/2.6-decoder.png differ
diff --git a/tech_reports/LLMs/images/2.7-lm-head.png b/tech_reports/LLMs/images/2.7-lm-head.png
new file mode 100644
index 00000000000..7f3a5787fe4
Binary files /dev/null and b/tech_reports/LLMs/images/2.7-lm-head.png differ
diff --git a/tech_reports/LLMs/images/2.8-llama-model.png b/tech_reports/LLMs/images/2.8-llama-model.png
new file mode 100644
index 00000000000..16af55088d6
Binary files /dev/null and b/tech_reports/LLMs/images/2.8-llama-model.png differ
diff --git a/tech_reports/LLMs/llms.md b/tech_reports/LLMs/llms.md
index 6415f783819..e06dd5b5479 100644
--- a/tech_reports/LLMs/llms.md
+++ b/tech_reports/LLMs/llms.md
@@ -1,5 +1,5 @@
# LLMs in TT-NN
-Authors: Mark O'Connor,
+Authors: Mark O'Connor, Djordje Ivanovic
## Contents
- [LLMs in TT-NN](#llms-in-tt-nn)
@@ -64,8 +64,221 @@ Other useful resources:
- limitations
- which dims are parallelized
### 2.5 MLP
+
### 2.6 Decoder
+
+
+
+When the components explained in previous sections (MLP, Attention, RMSNorm) are implemented, bringing up the decoder should be relatively straightforward.
+According to the diagram (based on the Llama3.1 example), the components are stacked sequentially during the forward pass.
+The only thing to consider is whether addition of MLP and Attention outputs should be stored in L1 or in DRAM.
+
+
+
+The Decode forward pass implementation below follows the diagram above. Keep in mind that, in order to optimize memory usage, it is recommended to deallocate tensors after their usage, which can be crucial under tighter memory constraints.
+
+
+To optimize performance in decode mode, we maintain the residual stream in L1 and shard it across cores and devices. However, determining the optimal number of cores for sharding can be challenging, especially for operations like DRAM-sharded matmuls. Here is the [code](https://github.com/tenstorrent/tt-metal/blob/53c32c0c0da926f97bd0eb042e70fd54c2866f44/models/demos/llama3/tt/model_config.py#L931) in Llama model config, that produces the core grid that will divide the N and K dims of a matmul evenly.
+When it’s not feasible to keep the streams sharded, we use the ttnn op `interleave_to_sharded`, and conversely, switch back as needed.
+In our implementation of Llama3.1 there are some ops that require interleaved tensors and resharding.
+
+
+
+```py
+def forward(
+ self,
+ x: ttnn.Tensor,
+ current_pos,
+ rot_mat=None,
+ transformation_mats=None,
+ user_id=0,
+ mode="decode",
+ page_table=None,
+ ) -> ttnn.Tensor:
+ if mode == "prefill":
+ skip_mem_cfg = ttnn.DRAM_MEMORY_CONFIG
+ elif mode == 'decode':
+ skip_mem_cfg = self.model_config["DEC_SKIP_OUTPUT_MEMCFG"]
+ # Attention RMSNorm
+ attn_in = self.attention_norm(x)
+ # Attention
+ attn_out = self.attention.forward(
+ attn_in,
+ current_pos,
+ rot_mat,
+ transformation_mats,
+ user_id,
+ mode,
+ page_table,
+ )
+ ttnn.deallocate(attn_in)
+ # Residual add of inputs and attention output
+ h = ttnn.add(x, attn_out, memory_config=skip_mem_cfg)
+ ttnn.deallocate(attn_out)
+ # MLP and RMSNorm
+ ff_out = self.feed_forward.forward(self.ffn_norm(h), mode)
+ # Residual add of attention output and mlp output
+ out = ttnn.add(h, ff_out, memory_config=skip_mem_cfg)
+
+ ttnn.deallocate(ff_out)
+ ttnn.deallocate(h)
+
+ return out
+```
+
+
### 2.7 LM Head
+
+The `LMHead` is unique because LLMs typically have large vocabulary sizes, which are independent of the model size (i.e. model parameters).
+As a result, the `LMHead` has a large `last_dim` in its weight matrix. Given the substantial size of `LMHead` weights and the memory limitations of the hardware, these weights must be distributed across multiple devices and processed in iterations, while activations are replicated across devices.
+
+The number of iterations required depends on the size of the weights and the number of devices available, ranging from 1 to several iterations. For example, in Llama 3.1’s decode mode, the LMHead matrix multiplication involves shapes of ```(32, 8K) x (8K, 128K)```.
+
+Below is an illustration of how the LMHead weights are partitioned across two devices, followed by its implementation. For ilustrative purposes it uses 128K for the `vocab_size` instead of the real Llama3.1 value of `128256`.
+
+
+
+
+
+```py
+size_per_device = self.vocab_size // self.num_devices
+num_splits = math.ceil(size_per_device / max_columns_per_device)
+
+split_sizes = [min(size_per_device, max_columns_per_device)] * (num_splits - 1)
+split_sizes.append(size_per_device - sum(split_sizes)) # remaining columns
+
+# Split the output weights
+torch_output_weights = state_dict[f"{state_dict_prefix}output.weight"].permute(1, 0)
+
+self.output_weights = []
+
+for i, split_size in enumerate(split_sizes):
+ cache_file_name = (
+ None if args.dummy_weights else weight_cache_path / f"output_lm_head_{num_splits}_split_shard_{i}"
+ )
+
+ # Create a list to store the split tensors for each device
+ device_splits = []
+ for device in range(self.num_devices):
+ start = device * size_per_device + sum(split_sizes[:i])
+ end = start + split_size
+ device_splits.append(torch_output_weights[:, start:end])
+
+ # Concatenate the splits from all devices
+ combined_split = torch.cat(device_splits, dim=-1)
+
+ memory_config = args.create_dram_sharded_mem_config(
+ k=args.dim, n=combined_split.shape[-1] // self.num_devices
+ )
+ self.output_weights.append(
+ ttnn.as_tensor(
+ combined_split,
+ device=mesh_device,
+ mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
+ layout=ttnn.TILE_LAYOUT,
+ dtype=dtype,
+ memory_config=memory_config,
+ cache_file_name=cache_file_name,
+ )
+ )
+```
+
+We use dram-sharded matmul for LMHead with `program_config` and `memory_config` generated by the code below.
+For more information check [Section: Op Configs](#44-op-configs).
+The primary reason for having multiple `program_configs` is that the weight shapes may result in unequal split sizes. This variability means the same configuration cannot be used for every matrix multiplication.
+
+```py
+# Generate dram-sharded memory_config
+memory_config = args.create_dram_sharded_mem_config(
+ k=args.dim, n=combined_split.shape[-1] // self.num_devices
+)
+# Generate dram-sharded program_config
+self.program_configs = [
+ args.dram_matmul_config(
+ args.tile_padded_batch_rows,
+ args.dim,
+ split_size,
+ args.lm_head_core_grid.num_cores,
+ )
+ for split_size in split_sizes
+]
+```
+Once weights are pushed to the devices and the decoders are executed, the `LMHead` forward pass needs to be executed in iterations.
+The code below shows that after each iteration outputs are converted from sharded to interleaved tensors. Once all iterations are completed, the final output is produced by concatenation over the last dim and returned as `output`.
+
+When executing the model, it is essential to ensure that the output of the last decoder is already replicated across tensors. Since this replication is enforced earlier, no additional code is required in the `LMHead` forward pass to handle it.
+
+```py
+def forward(self, x: ttnn.Tensor):
+ outputs = []
+ for weight, pc in zip(self.output_weights, self.program_configs):
+ output = ttnn.linear(
+ x,
+ weight,
+ compute_kernel_config=self.compute_kernel_config,
+ program_config=pc,
+ memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
+ dtype=ttnn.bfloat8_b,
+ )
+ outputs.append(output)
+
+ # Concatenate the outputs
+ output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+
+ return output
+```
+
+
+### 2.8 Model
+
+
+
+
+
+Once the model components (discussed in previous sections) are implemented, there isn’t much left to finalize. In our implementation, embeddings are managed outside the model class, as explained in [Section 2.1 Embedding](#21-embedding).
+
+The model’s constructor initializes N decoders (e.g. 80 for Llama3.1-70b), the `RMSNorm` and the `LMHead`, ensuring that weights for all components are loaded onto the appropriate devices.
+
+During the forward pass, the decoders are executed sequentially, followed by normalization and the `LMHead` computation at the end.
+A specific optimization is applied for the prefill mode: since only the last token is relevant, the `LMHead` is executed only on the final tile in this mode.
+
+In prefill mode, the RMSNorm output is interleaved, but the LMHead requires a sharded tensor. To accommodate this, the `interleaved_to_sharded` function is used to prepare the output accordingly.
+
+```py
+def forward(
+ self,
+ x: ttnn.Tensor,
+ current_pos,
+ rot_mat=None,
+ transformation_mats=None,
+ user_id=0,
+ mode="decode",
+ page_table=None,
+ get_last_token=-1,
+):
+ for layer in self.layers:
+ x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode, page_table)
+
+ if mode == "prefill" and get_last_token == -1:
+ return x
+
+ # Slicing the tensor to the nearest ceiling/floor multiples of 32 for the prefill_len, to get the last token
+ if get_last_token != -1:
+ x = ttnn.slice(x, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, x.shape[-1]))
+
+ # Output norm
+ x = self.norm(x, mode=mode)
+
+ if mode == "prefill":
+ x = ttnn.interleaved_to_sharded(
+ x,
+ self.model_config["LM_HEAD_INPUT_MEMCFG"],
+ )
+
+ return self.lm_head(x)
+```
+
+
## 3. Features
### 3.1 Generative Decoding
### 3.2 Prefill and Decode
@@ -221,7 +434,7 @@ This CSV file contains information recorded from all devices during program exec
python models/perf/perf_report.py OPS_CSV_FILE
```
-For device performance we recommend looking at a single layer. You can do this by using `--id-range` or by changing your test to run only a single layer of the model. For more information see: [Performance Report Analysis Tool](https://github.com/tenstorrent/tt-metal/tree/main/models/perf). The Performance Report Analysis Tool document describes how to select specific ranges of OPs.
+For device performance we recommend looking at a single layer. You can do this by using `--id-range` or by changing your test to run only a single layer of the model. For more information see: [Performance Report Analysis Tool](https://github.com/tenstorrent/tt-metal/tree/main/models/perf). The Performance Report Analysis Tool document describes how to select specific ranges of OPs.
##### What makes a good performance test?