feat: adjust rotary embed and avoid cuda graphs of size 2 and smaller
drbh committed Jan 7, 2025
1 parent 901156c commit 90b83e2
Showing 4 changed files with 39 additions and 2 deletions.

New file, 26 additions & 0 deletions (file path not shown in this capture): the expected chat-completion response snapshot.

@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The correct answer is: blue",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1733445131,
  "id": "",
  "model": "Qwen/Qwen2-VL-2B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "2.4.2-dev0-native",
  "usage": {
    "completion_tokens": 7,
    "prompt_tokens": 27,
    "total_tokens": 34
  }
}

integration-tests/models/test_flash_qwen2_vl_warmup.py (1 addition & 1 deletion)

@@ -5,7 +5,7 @@
 def flash_qwen2_vl_handle(launcher):
     with launcher(
         "Qwen/Qwen2-VL-2B-Instruct",
-        max_input_tokens=40,
+        max_input_length=40,
         max_batch_prefill_tokens=50,
         max_total_tokens=51,
     ) as handle:
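
The renamed launcher limit lines up with the snapshot above. As a quick sanity check, here is a small hedged Python sketch (values taken from this diff and from the snapshot's "usage" block; the variable names merely mirror the fixture's keyword arguments) showing that the warmup request fits the configured budget:

# Sanity-check sketch: the snapshot's token usage fits the test's limits.
max_input_length = 40          # per-request input cap from the fixture
max_batch_prefill_tokens = 50  # prefill token budget during warmup
max_total_tokens = 51          # input + output cap per request

prompt_tokens, completion_tokens = 27, 7  # from the snapshot's "usage"
assert prompt_tokens <= max_input_length
assert prompt_tokens <= max_batch_prefill_tokens
assert prompt_tokens + completion_tokens <= max_total_tokens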

server/text_generation_server/models/__init__.py (6 additions & 0 deletions)

@@ -29,6 +29,7 @@
     BloomForCausalLM,
 )
 from text_generation_server.models.globals import ATTENTION
+import text_generation_server.models.globals as globals
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (

@@ -1208,6 +1209,11 @@ def get_model(
     else:
         raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
+        # TODO: remove edge case when cuda graph issue is resolved for BS=2 with Qwen2-VL
+        logger.warning(
+            "Qwen2-VL requires cuda graphs to be greater than 2. Removing all cuda graphs with a batch size equal or less than 2."
+        )
+        globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
         return VlmCausalLM(
             model_id=model_id,
             model_class=Qwen2VLForConditionalGeneration,
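
The filter above rewrites the module-level CUDA_GRAPHS list so that no graph is captured for batch sizes of 2 or smaller; those batches simply run eagerly. A minimal standalone sketch of its effect, assuming TGI's usual default graph sizes (the real list lives in text_generation_server.models.globals and is configurable via --cuda-graphs):

# Standalone sketch of the batch-size filter added above (assumed defaults).
CUDA_GRAPHS = [1, 2, 4, 8, 16, 32]

# Drop graphs captured for batch sizes <= 2; those batches fall back to eager.
CUDA_GRAPHS = list(filter(lambda x: x > 2, CUDA_GRAPHS))

print(CUDA_GRAPHS)  # [4, 8, 16, 32]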

Fourth changed file, 6 additions & 1 deletion (file path not shown in this capture):

@@ -138,7 +138,12 @@ def forward(
             dim=-1,
         )
 
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(
+            query,
+            torch.select(kv, dim=1, index=0),
+            cos[: query.shape[0], ...],
+            sin[: query.shape[0], ...],
+        )
 
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
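
The slice above trims the precomputed cos/sin position tables to the number of tokens actually present in query, so applying the rotary embedding cannot hit a length mismatch when the tables were built for a longer sequence (for example during warmup). A minimal torch sketch of the shape fix, with made-up dimensions (not the TGI module itself):

import torch

num_tokens, num_heads, head_dim = 5, 4, 64
query = torch.randn(num_tokens, num_heads, head_dim)

# Tables precomputed for a longer sequence than this batch actually has.
cos = torch.randn(8, 1, head_dim)
sin = torch.randn(8, 1, head_dim)

# The same slice as in the diff: keep exactly one row per query token.
cos = cos[: query.shape[0], ...]
sin = sin[: query.shape[0], ...]
assert cos.shape[0] == sin.shape[0] == query.shape[0]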
