diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl_warmup/test_flash_qwen2_vl_simple.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl_warmup/test_flash_qwen2_vl_simple.json
new file mode 100644
index 00000000000..a986510f239
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl_warmup/test_flash_qwen2_vl_simple.json
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "The correct answer is: blue",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1733445131,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "2.4.2-dev0-native",
+  "usage": {
+    "completion_tokens": 7,
+    "prompt_tokens": 27,
+    "total_tokens": 34
+  }
+}
diff --git a/integration-tests/models/test_flash_qwen2_vl_warmup.py b/integration-tests/models/test_flash_qwen2_vl_warmup.py
index 74456e48df1..5be87ee21a3 100644
--- a/integration-tests/models/test_flash_qwen2_vl_warmup.py
+++ b/integration-tests/models/test_flash_qwen2_vl_warmup.py
@@ -5,7 +5,7 @@
 def flash_qwen2_vl_handle(launcher):
     with launcher(
         "Qwen/Qwen2-VL-2B-Instruct",
-        max_input_tokens=40,
+        max_input_length=40,
         max_batch_prefill_tokens=50,
         max_total_tokens=51,
     ) as handle:
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index fcc79608645..eb980d0e6f9 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -29,6 +29,7 @@
     BloomForCausalLM,
 )
 from text_generation_server.models.globals import ATTENTION
+import text_generation_server.models.globals as globals
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -1208,6 +1209,11 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
+        # TODO: remove edge case when cuda graph issue is resolved for BS=2 with Qwen2-VL
+        logger.warning(
+            "Qwen2-VL requires cuda graph batch sizes greater than 2. Removing all cuda graphs with a batch size less than or equal to 2."
+        )
+        globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
         return VlmCausalLM(
             model_id=model_id,
             model_class=Qwen2VLForConditionalGeneration,
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index cc4039b1cbc..01d3bf1a377 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -138,7 +138,12 @@ def forward(
             dim=-1,
         )
 
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(
+            query,
+            torch.select(kv, dim=1, index=0),
+            cos[: query.shape[0], ...],
+            sin[: query.shape[0], ...],
+        )
 
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
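
Note: a minimal sketch of the cuda-graph filtering applied above for Qwen2-VL, assuming CUDA_GRAPHS defaults to [1, 2, 4, 8, 16, 32]; the concrete values here are illustrative only and not taken from this patch:

    # Mirrors globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
    cuda_graphs = [1, 2, 4, 8, 16, 32]  # assumed default cuda graph batch sizes
    cuda_graphs = list(filter(lambda bs: bs > 2, cuda_graphs))
    print(cuda_graphs)  # [4, 8, 16, 32]

The cos/sin slicing in flash_qwen2_modeling.py appears to serve the same path: when the position-embedding tensors are longer than the number of query tokens, they are truncated to query.shape[0] before the rotary embedding is applied.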