Skip to content

Commit

Permalink
[skip ci] #0: ViT tech report (#13032)
Browse files Browse the repository at this point in the history
#0: ViT tech report
  • Loading branch information
mbahnasTT authored Sep 24, 2024
1 parent 0adc624 commit 7c36ba9
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions tech_reports/ViT-TTNN/vit.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Authors: Vishal Shenoy, Mohamed Bahnas
- [1. Overview](#1-overview)
- [2. ViT TT-NN Optimization Techniques](#2-vit-tt-nn-optimization-techniques)
- [2.1 Sharding on all relevant OPs](#21-sharding-on-all-relevant-ops)
- [2.2 Matmul sharding variants](#22-matmul-sharding-variants)
- [2.2 Matmul sharding variants in ViT](#22-matmul-sharding-variants-in-vit)
- [2.3 Transformer optimizations](#23-transformer-optimizations)
- [3. ViT TT-NN Code Structure](#3-vit-tt-nn-code-structure)
- [3.1 Top-level modules](#31-top-level-modules)
Expand Down Expand Up @@ -45,18 +45,22 @@ The implemented optimization techniques in TT-NN compared to the conventional fl
![Sharding Concept](images/sharding_concept.png)
- Illustrative example
![Sharding Example](images/sharding_example.png)
### 2.2 Matmul sharding variants
### 2.2 Matmul sharding variants in ViT
#### 2.2.1 Matmul Reuse (BMM)
The batch Matmul(BMM) Reuse case used in ViT model is in the Multi-head Self Attention module, where both inputs (in0 and in1) as well as the output are height sharded. There no multi-cast (mcast) technique applied on the inputs here. Each core will be responsible for the Matmul of single head of one image of the batch.

![BMM Height](images/bmm_height.png)
#### 2.2.2 Matmul Reuse Mcast (2D)
The Reuse Mcast case used in ViT model is the block sharded Matmul cases in QKV generation as well as the Feed-Forward Network.
The implemented config is Block sharded as Row_Major, where the in0 outer dimension (M) is sharded along the y-axis of the core grid. On the inner dimension of in0, the sharded slices are mcasted along the x-direction of the core grid. The mcast process is done in turn from one core to all other cores in the row, so the whole inner dimension of in0 exists per each core during its Matmul operation.
The in1 is interleaved (on L1 or DRAM) and its slices along the N (outer) dimension are mcasted along the cores in the same column, where each slide has the full inner dimension (K). This is aligned with the previously mentioned mcast of in0 slices.
Worth to mention that in some cases it may be better to implement the Column_Major (and mcast transposed = True) config, where the in0 M dimension is sharded along the x-axis of the core as shown in the figure. All the mcast techniques in the Column_Major will be transposed with respect to the Row_Major config mentioned in the previous paragraph.
- The implemented config is Block sharded orientation is Row_Major, where the in0 outer dimension (M) is sharded along the y-axis of the core grid. On the inner dimension of in0, the sharded slices are mcasted along the x-direction of the core grid. The mcast process is done in turn from one core to all other cores in the row, so the whole inner dimension of in0 exists per each core during its Matmul operation.
- Please note that the Row_Major term mentioned here is referring to the sharded blocks placement on the core grid. It's different than the Row_Major data layout that is compared to the Tile layout in the report [tensor_layouts](https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/tensor_layouts/tensor_layouts.md)
- The in1 is interleaved (on L1 or DRAM) and its slices along the N (outer) dimension are mcasted along the cores in the same column, where each slide has the full inner dimension (K). This is aligned with the previously mentioned mcast of in0 slices.
- Worth to mention that in some cases it may be better to implement the Column_Major (and mcast transposed = True) config, where the in0 M dimension is sharded along the x-axis of the core as shown in the figure. All the mcast techniques in the Column_Major will be transposed with respect to the Row_Major config mentioned in the previous paragraph.

![Mcast Block](images/block_mcast.png)
#### 2.2.3 Matmul Reuse Mcast (1D)
The other Reuse Mcast case (not used in ViT) is the height sharded on in0, while in1 is still interleaved, as shown in the figure.

![Mcast Height](images/height_mcast.png)

### 2.3 Transformer optimizations
Expand Down

0 comments on commit 7c36ba9

Please sign in to comment.