[Llama3] Add padding fixes to support continuous batching in vLLM
Signed-off-by: Salar Hosseini <[email protected]>
skhorasganiTT committed Dec 17, 2024
1 parent fec57ad commit 78e9f29
Showing 2 changed files with 28 additions and 4 deletions.
5 changes: 4 additions & 1 deletion models/demos/llama3/tt/llama_model.py
@@ -172,7 +172,10 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None):
             mesh_mapper=mesh_mapper,
         )

-        rope_idxs = self.rope_setup.get_rot_idxs(current_pos, on_host=True)
+        rot_current_pos = torch.maximum(
+            current_pos, torch.tensor(0, dtype=torch.int64)
+        )  # Ensure position indices are non-negative
+        rope_idxs = self.rope_setup.get_rot_idxs(rot_current_pos, on_host=True)
         current_pos_tt = ttnn.from_torch(
             current_pos,
             device=None,
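For illustration, a minimal standalone sketch of the clamping step above. The input values are hypothetical, and it assumes (the diff itself does not state this) that padded or inactive batch slots in continuous batching may carry negative position indices, which would otherwise be invalid rotary-embedding lookups:

    import torch

    # Hypothetical decode positions for a batch of 4 slots, two of them padded with -1
    current_pos = torch.tensor([17, -1, 5, -1], dtype=torch.int64)

    # Same clamping as in the diff: negative positions are mapped to 0 before
    # being used as rotary-embedding indices
    rot_current_pos = torch.maximum(current_pos, torch.tensor(0, dtype=torch.int64))
    print(rot_current_pos)  # tensor([17,  0,  5,  0])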
27 changes: 24 additions & 3 deletions models/demos/llama3/tt/multimodal/llama_vision_model.py
@@ -466,6 +466,10 @@ def prepare_decode_inputs_host(
         assert (
             B == self.configuration.max_batch_size
         ), f"Batch size must match max batch size. Got {B}, expected {self.configuration.max_batch_size}"
+        unpadded_batch_size = len(cross_attention_masks)
+        assert unpadded_batch_size == len(
+            full_text_row_masked_out_mask
+        ), f"cross_attention_masks batch dim ({unpadded_batch_size}) does not match full_text_row_masked_out_mask batch dim ({len(full_text_row_masked_out_mask)})"
         h = self.prepare_inputs_common(position_id, tokens)
         tt_h = self.configuration.prepare_residual_tensor_decode(
             h,
@@ -481,8 +485,20 @@ def prepare_decode_inputs_host(
             mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
         )

-        tt_rope_id = self.text_model.rope_setup.get_rot_idxs(position_id, on_host=True)
-        xattn_mask = torch.cat([cross_attention_masks[i][:, :, position_id[i]] for i in range(B)], dim=1).unsqueeze(0)
+        rot_position_id = torch.maximum(
+            position_id, torch.tensor(0, dtype=torch.int64)
+        )  # Ensure position indices are non-negative
+        tt_rope_id = self.text_model.rope_setup.get_rot_idxs(rot_position_id, on_host=True)
+
+        xattn_mask = torch.cat(
+            [cross_attention_masks[i][:, :, position_id[i]] for i in range(unpadded_batch_size)], dim=1
+        ).unsqueeze(0)
+        # Pad xattn_mask along batch if tokens have been padded
+        if B > unpadded_batch_size:
+            xattn_mask = torch.cat(
+                [xattn_mask, torch.zeros(1, 1, B - unpadded_batch_size, xattn_mask.shape[-1])], dim=2
+            )
+
         xattn_mask_expand = xattn_mask.expand(-1, self.configuration.n_heads // self.configuration.num_devices, -1, -1)
         xattn_mask_expand = xattn_mask_expand.transpose(1, 2).contiguous()

@@ -495,8 +511,13 @@ def prepare_decode_inputs_host(
         )

         full_text_mask = torch.cat(
-            [full_text_row_masked_out_mask[i][:, :, position_id[i]] for i in range(B)], dim=1
+            [full_text_row_masked_out_mask[i][:, :, position_id[i]] for i in range(unpadded_batch_size)], dim=1
         ).unsqueeze(0)
+        # Pad full_text_mask along batch if tokens have been padded
+        if B > unpadded_batch_size:
+            full_text_mask = torch.cat(
+                [full_text_mask, torch.zeros(1, 1, B - unpadded_batch_size, full_text_mask.shape[-1])], dim=2
+            )
         full_text_mask_expand_1NSH = full_text_mask.expand(
             -1, self.configuration.n_heads // self.configuration.num_devices, -1, self.configuration.head_dim
         )
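A minimal sketch of the batch-dimension padding applied to the cross-attention and full-text masks above, using hypothetical shapes (B = 4 padded slots, 2 active users, 512 vision tokens); the zero rows stand in for the padded slots:

    import torch

    B = 4                    # max_batch_size (padded)
    unpadded_batch_size = 2  # actual number of active users
    vision_tokens = 512      # hypothetical mask width

    # Mask gathered only for the active users: shape [1, 1, unpadded_batch_size, vision_tokens]
    xattn_mask = torch.randn(1, 1, unpadded_batch_size, vision_tokens)

    # Zero-pad along the batch dim (dim=2) up to B, mirroring the fix in the diff
    if B > unpadded_batch_size:
        xattn_mask = torch.cat(
            [xattn_mask, torch.zeros(1, 1, B - unpadded_batch_size, xattn_mask.shape[-1])], dim=2
        )
    print(xattn_mask.shape)  # torch.Size([1, 1, 4, 512])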
