#5337: skip reshape op in attention for batch 32
caixunshiren committed Jul 19, 2024
1 parent b5393a9 commit ec46ffd
Showing 1 changed file with 5 additions and 4 deletions.

models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py
@@ -187,10 +187,11 @@ def forward_decode(
         )
 
         # Reshape such that true unpadded batch is tracked in shape
-        fqkv_shape = xqkv_fused.shape
-        xqkv_fused = ttnn.reshape(
-            xqkv_fused, ttnn.Shape((1, 1, self.max_batch_size, fqkv_shape[3]), (1, 1, 32, fqkv_shape[3]))
-        )
+        if self.max_batch_size < 32:
+            fqkv_shape = xqkv_fused.shape
+            xqkv_fused = ttnn.reshape(
+                xqkv_fused, ttnn.Shape((1, 1, self.max_batch_size, fqkv_shape[3]), (1, 1, 32, fqkv_shape[3]))
+            )
 
         # split qkv into heads
         (
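For context: the reshape exists to record the true (unpadded) batch size in the tensor's logical shape while the on-device tensor stays padded to 32 rows. When max_batch_size is already 32, the logical and padded shapes coincide, so the reshape would be a no-op and can be skipped. A minimal annotated sketch of the patched block follows; the padding-to-32 behavior is an assumption inferred from the diff, while the ttnn.reshape and ttnn.Shape calls are taken verbatim from it:

    # xqkv_fused is assumed to arrive padded to 32 rows in its batch dim.
    if self.max_batch_size < 32:
        # Record the true (unpadded) batch in the logical shape while the
        # padded on-device shape stays at 32 rows.
        fqkv_shape = xqkv_fused.shape
        xqkv_fused = ttnn.reshape(
            xqkv_fused,
            ttnn.Shape(
                (1, 1, self.max_batch_size, fqkv_shape[3]),  # logical (unpadded) shape
                (1, 1, 32, fqkv_shape[3]),                   # padded on-device shape
            ),
        )
    # At max_batch_size == 32 the two shapes are identical, so skipping the
    # reshape saves one op per decode step with no change in behavior.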
