[Llama3.2-11b vLLM Integration] Add support for paged cross attention, fixes for continuous batching, simplified decode forward call (#16076)
skhorasganiTT authored Dec 17, 2024
1 parent 9208f77 commit 05e6c61
Showing 9 changed files with 400 additions and 93 deletions.
models/demos/llama3/demo/simple_vision_demo.py (24 changes: 8 additions & 16 deletions)
@@ -189,22 +189,14 @@ def test_llama_multimodal_demo_text(
 position_id = prefill_lens + gen_idx
 next_token_tensor = next_tokens.reshape(max_batch_size, 1)

-if enable_trace:
-    logits = generator.easy_trace(
-        position_id,
-        next_token_tensor,
-        batch_xattn_masks,
-        batch_text_masks,
-        xattn_caches,
-    )
-else:
-    logits = generator.decode_forward(
-        position_id,
-        next_token_tensor,
-        batch_xattn_masks,
-        batch_text_masks,
-        xattn_caches,
-    )
+logits = generator.decode_forward(
+    position_id,
+    next_token_tensor,
+    batch_xattn_masks,
+    batch_text_masks,
+    xattn_caches,
+    enable_trace=enable_trace,
+)

 next_tokens, next_texts = sampler(logits)
 # Update next token
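For context, the change moves the trace-vs-eager branch out of the demo and into the generator. A minimal sketch of what that implies on the generator side is below; only decode_forward's argument list comes from the diff above, while the internal method names _easy_trace and _decode_forward_no_trace are assumptions, not the repository's actual implementation.

    class Generator:
        """Hypothetical sketch: decode_forward dispatches on enable_trace
        internally, so callers no longer need an if/else per decode step."""

        def decode_forward(
            self,
            position_id,
            tokens,
            batch_xattn_masks,
            batch_text_masks,
            xattn_caches,
            enable_trace=True,
        ):
            # Single entry point; the flag selects the traced or eager path.
            if enable_trace:
                return self._easy_trace(
                    position_id, tokens, batch_xattn_masks, batch_text_masks, xattn_caches
                )
            return self._decode_forward_no_trace(
                position_id, tokens, batch_xattn_masks, batch_text_masks, xattn_caches
            )

        def _easy_trace(self, *args):
            # Placeholder (assumed name): capture/replay a device trace of the step.
            ...

        def _decode_forward_no_trace(self, *args):
            # Placeholder (assumed name): plain eager decode step.
            ...

With this shape, the demo's call site shrinks to the single decode_forward(..., enable_trace=enable_trace) call shown in the hunk, which is what accounts for the 16 deleted vs. 8 added lines.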
