From 3f674a49b5033a6ed778ab960e86e03cfa64aa1f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 15 Aug 2024 01:55:42 +0800 Subject: [PATCH] [VLM][Core] Support profiling with multiple multi-modal inputs per prompt (#7126) --- .../input_processing_pipeline.rst | 2 +- .../dev/multimodal/multimodal_index.rst | 3 + .../models/enabling_multimodal_inputs.rst | 2 +- tests/engine/test_arg_utils.py | 24 ++++ tests/models/test_blip2.py | 2 +- tests/models/test_fuyu.py | 2 +- tests/models/test_internvl.py | 2 +- tests/models/test_llava.py | 2 +- tests/models/test_llava_next.py | 2 +- tests/models/test_minicpmv.py | 5 +- tests/models/test_paligemma.py | 2 +- tests/models/test_phi3v.py | 2 +- tests/multimodal/test_mapper.py | 84 ++++++++++++- vllm/config.py | 14 ++- vllm/engine/arg_utils.py | 48 ++++++- vllm/engine/llm_engine.py | 11 +- vllm/inputs/registry.py | 83 ++++++++++--- vllm/model_executor/model_loader/loader.py | 2 +- vllm/model_executor/models/blip.py | 7 +- vllm/model_executor/models/blip2.py | 37 ++++-- vllm/model_executor/models/chameleon.py | 23 ++-- vllm/model_executor/models/clip.py | 8 +- vllm/model_executor/models/fuyu.py | 27 ++-- vllm/model_executor/models/interfaces.py | 7 +- vllm/model_executor/models/internvl.py | 11 +- vllm/model_executor/models/llava.py | 16 ++- vllm/model_executor/models/llava_next.py | 14 ++- vllm/model_executor/models/minicpmv.py | 21 ++-- vllm/model_executor/models/paligemma.py | 13 +- vllm/model_executor/models/phi3v.py | 12 +- vllm/model_executor/models/siglip.py | 8 +- vllm/multimodal/base.py | 48 ++++--- vllm/multimodal/image.py | 9 +- vllm/multimodal/registry.py | 117 +++++++++++++++--- vllm/utils.py | 28 ----- vllm/worker/enc_dec_model_runner.py | 19 ++- vllm/worker/model_runner.py | 37 +++--- vllm/worker/xpu_model_runner.py | 36 +++--- 38 files changed, 573 insertions(+), 217 deletions(-) create mode 100644 tests/engine/test_arg_utils.py diff --git a/docs/source/dev/input_processing/input_processing_pipeline.rst b/docs/source/dev/input_processing/input_processing_pipeline.rst index e0c773781115f..48abec8f75286 100644 --- a/docs/source/dev/input_processing/input_processing_pipeline.rst +++ b/docs/source/dev/input_processing/input_processing_pipeline.rst @@ -17,4 +17,4 @@ Input Processing Pipeline 6. If the data contains multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input `. - - For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model. + - For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision model. diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index f70fd03e259ff..a45bc885dc122 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -15,6 +15,9 @@ by following :ref:`this guide `. Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here `. +.. + TODO: Add usage of --limit-mm-per-prompt when multi-image input is officially supported + Guides ++++++ diff --git a/docs/source/models/enabling_multimodal_inputs.rst b/docs/source/models/enabling_multimodal_inputs.rst index dc76f921d5b09..3d0d1aec69845 100644 --- a/docs/source/models/enabling_multimodal_inputs.rst +++ b/docs/source/models/enabling_multimodal_inputs.rst @@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i 3. Register maximum number of multi-modal tokens ------------------------------------------------ -For each modality type that the model accepts as input, calculate the maximum possible number of tokens +For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance and register it via :meth:`INPUT_REGISTRY.register_dummy_data `. .. code-block:: diff diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py new file mode 100644 index 0000000000000..3208d6bb48bdc --- /dev/null +++ b/tests/engine/test_arg_utils.py @@ -0,0 +1,24 @@ +import pytest + +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser + + +@pytest.mark.parametrize(("arg", "expected"), [ + (None, None), + ("image=16", { + "image": 16 + }), + ("image=16,video=2", { + "image": 16, + "video": 2 + }), +]) +def test_limit_mm_per_prompt_parser(arg, expected): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + if arg is None: + args = parser.parse_args([]) + else: + args = parser.parse_args(["--limit-mm-per-prompt", arg]) + + assert args.limit_mm_per_prompt == expected diff --git a/tests/models/test_blip2.py b/tests/models/test_blip2.py index 26afd57ae6106..64b7a77404b98 100644 --- a/tests/models/test_blip2.py +++ b/tests/models/test_blip2.py @@ -59,7 +59,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalData objects and corresponding - vision language config as input. + MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ diff --git a/tests/models/test_fuyu.py b/tests/models/test_fuyu.py index 7d0f3be5ea008..0d666d8f71a92 100644 --- a/tests/models/test_fuyu.py +++ b/tests/models/test_fuyu.py @@ -49,7 +49,7 @@ def run_test( All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects - and corresponding vision language config as input. + and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py index 6aa0189648d72..6007a897d65ab 100644 --- a/tests/models/test_internvl.py +++ b/tests/models/test_internvl.py @@ -117,7 +117,7 @@ def run_test( All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects - and corresponding vision language config as input. + and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index 2724a0855117e..edaf7d400eb53 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -69,7 +69,7 @@ def run_test( All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects - and corresponding vision language config as input. + and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index 60c7fc33b72fe..2bd27f888680d 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -177,7 +177,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects - and corresponding vision language config as input. + and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ diff --git a/tests/models/test_minicpmv.py b/tests/models/test_minicpmv.py index 32f1cb2c2ed33..bf72dad0d1f5b 100644 --- a/tests/models/test_minicpmv.py +++ b/tests/models/test_minicpmv.py @@ -61,7 +61,7 @@ def run_test( All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects - and corresponding vision language config as input. + and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ @@ -176,7 +176,7 @@ def run_multi_image_test( All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects - and corresponding vision language config as input. + and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ @@ -197,6 +197,7 @@ def run_multi_image_test( with vllm_runner(model, max_model_len=4096, max_num_seqs=1, + limit_mm_per_prompt={"image": len(images)}, dtype=dtype, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, diff --git a/tests/models/test_paligemma.py b/tests/models/test_paligemma.py index f3f682b1c2cda..038a22f71acad 100644 --- a/tests/models/test_paligemma.py +++ b/tests/models/test_paligemma.py @@ -72,7 +72,7 @@ def run_test( All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects - and corresponding vision language config as input. + and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index 3737dc2bd076e..ccfc98a325982 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -73,7 +73,7 @@ def run_test( All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects - and corresponding vision language config as input. + and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ diff --git a/tests/multimodal/test_mapper.py b/tests/multimodal/test_mapper.py index 321566ad53a50..6b0c02c799c4a 100644 --- a/tests/multimodal/test_mapper.py +++ b/tests/multimodal/test_mapper.py @@ -1,15 +1,22 @@ +from contextlib import nullcontext + import numpy as np import pytest from transformers import CLIPImageProcessor, LlavaNextImageProcessor -from vllm.config import ModelConfig -from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.config import ModelConfig, MultiModalConfig +from vllm.multimodal import MultiModalRegistry from vllm.multimodal.utils import rescale_image_size +@pytest.fixture +def mm_registry(): + return MultiModalRegistry() + + @pytest.mark.parametrize("dtype", ["half", "float"]) @pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0]) -def test_clip_image_processor(image_assets, dtype, size_factor): +def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor): MODEL_NAME = "llava-hf/llava-1.5-7b-hf" hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) @@ -24,6 +31,9 @@ def test_clip_image_processor(image_assets, dtype, size_factor): dtype=dtype, revision=None, ) + mm_config = MultiModalConfig(limit_per_prompt={"image": 1}) + + mm_registry.init_mm_limits_per_prompt(model_config, mm_config) for asset in image_assets: image = rescale_image_size(asset.pil_image, size_factor) @@ -32,7 +42,7 @@ def test_clip_image_processor(image_assets, dtype, size_factor): image, return_tensors="pt", ) - vllm_result = MULTIMODAL_REGISTRY.map_input( + vllm_result = mm_registry.map_input( model_config, {"image": image}, ) @@ -48,7 +58,8 @@ def test_clip_image_processor(image_assets, dtype, size_factor): @pytest.mark.parametrize("dtype", ["half", "float"]) @pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0]) -def test_llava_next_image_processor(image_assets, dtype, size_factor): +def test_llava_next_image_processor(image_assets, mm_registry, dtype, + size_factor): MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf" hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME) @@ -63,6 +74,9 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor): dtype=dtype, revision=None, ) + mm_config = MultiModalConfig(limit_per_prompt={"image": 1}) + + mm_registry.init_mm_limits_per_prompt(model_config, mm_config) for asset in image_assets: image = rescale_image_size(asset.pil_image, size_factor) @@ -71,7 +85,7 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor): image, return_tensors="pt", ) - vllm_result = MULTIMODAL_REGISTRY.map_input( + vllm_result = mm_registry.map_input( model_config, {"image": image}, ) @@ -83,3 +97,61 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor): assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" + + +@pytest.mark.parametrize( + ("num_images", "limit", "is_valid"), + [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), + (2, 1, False), (2, 2, True)], +) +def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid): + MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + + model_config = ModelConfig( + model=MODEL_NAME, + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="half", + revision=None, + ) + mm_config = MultiModalConfig(limit_per_prompt={"image": limit}) + + mm_registry.init_mm_limits_per_prompt(model_config, mm_config) + + image = image_assets[0].pil_image + if num_images == 0: + mm_inputs = {} + elif num_images == 1: + mm_inputs = {"image": image} + else: + mm_inputs = {"image": [image] * num_images} + + with nullcontext() if is_valid else pytest.raises(ValueError): + mm_registry.map_input(model_config, mm_inputs) + + +# NOTE: We don't test zero images since the HF processor doesn't support it +@pytest.mark.parametrize("num_images", [1, 2]) +def test_image_mapper_multi(image_assets, mm_registry, num_images): + MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + + model_config = ModelConfig( + model=MODEL_NAME, + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="half", + revision=None, + ) + mm_config = MultiModalConfig(limit_per_prompt={"image": num_images}) + + mm_registry.init_mm_limits_per_prompt(model_config, mm_config) + + image = image_assets[0].pil_image + mm_inputs = {"image": [image] * num_images} + + mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + assert len(mapped_inputs["pixel_values"]) == num_images diff --git a/vllm/config.py b/vllm/config.py index a39f5307931e5..15d17d5e42a54 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,7 +1,8 @@ import enum import json from dataclasses import dataclass, field, fields -from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union +from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple, + Type, Union) import torch from transformers import PretrainedConfig @@ -1429,10 +1430,15 @@ def verify_with_model_config(self, model_config: ModelConfig): @dataclass class MultiModalConfig: - """Configs the input data format and how models should run for - multimodal models.""" + """Controls the behavior of multimodal models.""" + + limit_per_prompt: Mapping[str, int] + """ + The maximum number of multi-modal input instances allowed per prompt + for each :class:`~vllm.multimodal.MultiModalPlugin`. + """ + # TODO: Add configs to init vision tower or not. - pass _STR_DTYPE_TO_TORCH_DTYPE = { diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8911d91420d70..76bd3b630c54b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -2,7 +2,8 @@ import dataclasses import json from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union +from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type, + Union) from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -15,8 +16,7 @@ from vllm.utils import FlexibleArgumentParser if TYPE_CHECKING: - from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( - BaseTokenizerGroup) + from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup logger = init_logger(__name__) @@ -29,11 +29,32 @@ def nullable_str(val: str): return val +def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: + if len(val) == 0: + return None + + out_dict: Dict[str, int] = {} + for item in val.split(","): + try: + key, value = item.split("=") + except TypeError as exc: + msg = "Each item should be in the form KEY=VALUE" + raise ValueError(msg) from exc + + try: + out_dict[key] = int(value) + except ValueError as exc: + msg = f"Failed to parse value of item {key}={value}" + raise ValueError(msg) from exc + + return out_dict + + @dataclass class EngineArgs: """Arguments for vLLM engine.""" model: str = 'facebook/opt-125m' - served_model_name: Optional[Union[List[str]]] = None + served_model_name: Optional[Union[str, List[str]]] = None tokenizer: Optional[str] = None skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' @@ -81,6 +102,7 @@ class EngineArgs: # notice. tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray" tokenizer_pool_extra_config: Optional[dict] = None + limit_mm_per_prompt: Optional[Mapping[str, int]] = None enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 @@ -435,6 +457,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'This should be a JSON string that will be ' 'parsed into a dictionary. Ignored if ' 'tokenizer_pool_size is 0.') + + # Multimodal related configs + parser.add_argument( + '--limit-mm-per-prompt', + type=nullable_kvs, + default=EngineArgs.limit_mm_per_prompt, + # The default value is given in + # MultiModalRegistry.init_mm_limits_per_prompt + help=('For each multimodal plugin, limit how many ' + 'input instances to allow for each prompt. ' + 'Expects a comma-separated list of items, ' + 'e.g.: `image=16,video=2` allows a maximum of 16 ' + 'images and 2 videos per prompt. Defaults to 1 for ' + 'each modality.')) + # LoRA related configs parser.add_argument('--enable-lora', action='store_true', @@ -709,7 +746,8 @@ def create_engine_config(self, ) -> EngineConfig: "CPU offload space must be non-negative" f", but got {self.cpu_offload_gb}") - multimodal_config = MultiModalConfig() + multimodal_config = MultiModalConfig( + limit_per_prompt=self.limit_mm_per_prompt or {}) device_config = DeviceConfig(device=self.device) model_config = ModelConfig( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a25d60bc0aa33..979555eb6a05d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -24,8 +24,9 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs, - PromptInputs, SingletonPromptInputs) +from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, + InputRegistry, LLMInputs, PromptInputs, + SingletonPromptInputs) from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -180,6 +181,7 @@ def __init__( log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + input_registry: InputRegistry = INPUT_REGISTRY, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -265,8 +267,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.generation_config_fields = _load_generation_config_dict( model_config) - self.input_processor = INPUT_REGISTRY.create_input_processor( - self.model_config) + self.input_registry = input_registry + self.input_processor = input_registry.create_input_processor( + model_config) self.model_executor = executor_class( model_config=model_config, diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 006dc8e146a6c..2ca8b10f71593 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,6 +1,8 @@ import functools +from collections import UserDict from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type +from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol, + Tuple, Type) from torch import nn from transformers import PretrainedConfig @@ -12,7 +14,7 @@ if TYPE_CHECKING: from vllm.config import ModelConfig, MultiModalConfig - from vllm.multimodal import MultiModalDataDict + from vllm.multimodal import MultiModalDataDict, MultiModalRegistry from vllm.sequence import SequenceData logger = init_logger(__name__) @@ -65,15 +67,38 @@ def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C: N = TypeVar("N", bound=Type[nn.Module]) -DummyDataFactory = Callable[[InputContext, int], - Tuple["SequenceData", - Optional["MultiModalDataDict"]]] -""" -Create dummy data to be inputted into the model. -Note: - :data:`InputProcessor` is not applied to the dummy data. -""" +class DummyDataFactory(Protocol): + + def __call__( + self, + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + """ + Create dummy data to be inputted into the model. + + Note: + :data:`InputProcessor` is not applied to the dummy data. + """ + ... + + +class _MultiModalCounts(UserDict): + """ + Wraps `mm_counts` for a more informative error message + when attempting to access a plugin that does not exist. + """ + + def __getitem__(self, key: str) -> int: + try: + return super().__getitem__(key) + except KeyError as exc: + msg = (f"There is no multi-modal plugin with the key: {key}. " + f"Available keys: {set(self.keys())}") + raise KeyError(msg) from exc + InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs] """Preprocess the inputs to the model.""" @@ -95,6 +120,7 @@ def _default_dummy_data_factory( self, ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int], ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ The default dummy data factory represents the longest possible text @@ -133,8 +159,12 @@ def wrapper(model_cls: N) -> N: return wrapper - def dummy_data_for_profiling(self, model_config: "ModelConfig", - seq_len: int): + def dummy_data_for_profiling( + self, + model_config: "ModelConfig", + seq_len: int, + mm_registry: "MultiModalRegistry", + ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data for profiling the memory usage of a model. @@ -142,6 +172,10 @@ def dummy_data_for_profiling(self, model_config: "ModelConfig", See also: :ref:`enabling_multimodal_inputs` + + Note: + This should be called after + :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`. """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture @@ -149,8 +183,29 @@ def dummy_data_for_profiling(self, model_config: "ModelConfig", model_cls, _ = get_model_architecture(model_config) dummy_factory = self._dummy_factories_by_model_type \ .get(model_cls, self._default_dummy_data_factory) - - return dummy_factory(InputContext(model_config), seq_len) + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + + seq_data, mm_data = dummy_factory( + InputContext(model_config), + seq_len, + _MultiModalCounts(mm_counts), + ) + + # Having more tokens is over-conservative but otherwise fine + num_tokens = seq_data.prompt_token_ids + assert len(num_tokens) >= seq_len, ( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but found {len(num_tokens)} tokens instead.") + + if mm_data is not None: + for k, v in mm_data.items(): + num_items = len(v) if isinstance(v, list) else 1 + num_expected = mm_counts[k] + assert num_items >= num_expected, ( + f"Expected at least {num_expected} dummy '{k}' instances " + f"for profiling, but found {num_items} instances instead.") + + return seq_data, mm_data def _default_input_processor(self, ctx: InputContext, inputs: LLMInputs) -> LLMInputs: diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index b07a05828ed15..302bcb2e9fd5a 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -133,7 +133,7 @@ def _get_model_initialization_kwargs( if supports_multimodal(model_class): if multimodal_config is None: - raise ValueError("Provide vision related configurations " + raise ValueError("Provide multi-modal related configurations " "through LLM entrypoint or engine arguments.") extra_kwargs["multimodal_config"] = multimodal_config diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 0b124d5e8a85a..a6fd5f58b3cb6 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -31,13 +31,13 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int: def get_blip_image_feature_size( - hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int: + hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int: return get_blip_num_patches(image_size=hf_config.image_size, patch_size=hf_config.patch_size) def get_max_blip_image_tokens( - hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int: + hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int: return get_blip_image_feature_size(hf_config) @@ -60,6 +60,7 @@ def dummy_seq_data_for_blip( def dummy_image_for_blip( hf_config: Union[BlipVisionConfig, Blip2VisionConfig], + num_images: int, *, image_width_override: Optional[int] = None, image_height_override: Optional[int] = None, @@ -71,7 +72,7 @@ def dummy_image_for_blip( height = image_height_override image = Image.new("RGB", (width, height), color=0) - return {"image": image} + return {"image": image if num_images == 1 else [image] * num_images} def input_processor_for_blip( diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index c64428a4d7c75..386dfeb5bb1e5 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,4 +1,5 @@ -from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) import torch import torch.nn as nn @@ -413,17 +414,39 @@ def get_max_blip2_image_tokens(ctx: InputContext): raise NotImplementedError(msg) -def dummy_data_for_blip2(ctx: InputContext, seq_len: int): +def dummy_seq_data_for_blip2( + hf_config: Blip2Config, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_blip2_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + token_ids = [image_token_id] * image_feature_size * num_images + token_ids += [0] * (seq_len - image_feature_size * num_images) + return SequenceData(token_ids) + + +def dummy_data_for_blip2(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): hf_config = ctx.get_hf_config(Blip2Config) vision_config = hf_config.vision_config + num_images = mm_counts["image"] - image_feature_size = get_blip2_image_feature_size(hf_config) - token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size - token_ids += [0] * (seq_len - image_feature_size) - seq_data = SequenceData(token_ids) + seq_data = dummy_seq_data_for_blip2( + hf_config, + seq_len, + num_images, + image_token_id=BLIP2_IMAGE_TOKEN_ID, + ) if isinstance(vision_config, Blip2VisionConfig): - mm_data = dummy_image_for_blip(vision_config) + mm_data = dummy_image_for_blip(vision_config, num_images) return seq_data, mm_data diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index fd694119932df..6776b93d126b0 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,6 +1,6 @@ from functools import cached_property -from typing import (Any, Dict, Iterable, List, Literal, Optional, Tuple, - TypedDict) +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, + Tuple, TypedDict) import torch import torch.nn.functional as F @@ -19,8 +19,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -61,6 +60,7 @@ def get_max_chameleon_image_tokens(ctx: InputContext): def dummy_seq_data_for_chameleon( seq_len: int, + num_images: int, *, image_token_id: int, image_feature_size_override: Optional[int] = None, @@ -70,12 +70,14 @@ def dummy_seq_data_for_chameleon( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size - token_ids += [0] * (seq_len - image_feature_size) + token_ids = [image_token_id] * image_feature_size * num_images + token_ids += [0] * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) def dummy_image_for_chameleon( + num_images: int, + *, image_width_override: Optional[int] = None, image_height_override: Optional[int] = None, ): @@ -87,17 +89,20 @@ def dummy_image_for_chameleon( height = image_height_override image = Image.new("RGB", (width, height), color=0) - return {"image": image} + return {"image": image if num_images == 1 else [image] * num_images} -def dummy_data_for_chameleon(ctx: InputContext, seq_len: int): +def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] seq_data = dummy_seq_data_for_chameleon( seq_len, + num_images, image_token_id=CHAMELEON_IMAGE_TOKEN_ID, ) - mm_data = dummy_image_for_chameleon() + mm_data = dummy_image_for_chameleon(num_images) return seq_data, mm_data diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 8ec72eeb14e52..fcd360ce8fd72 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -43,6 +43,7 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: def dummy_seq_data_for_clip( hf_config: CLIPVisionConfig, seq_len: int, + num_images: int, *, image_token_id: int, image_feature_size_override: Optional[int] = None, @@ -52,13 +53,14 @@ def dummy_seq_data_for_clip( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size - token_ids += [0] * (seq_len - image_feature_size) + token_ids = [image_token_id] * image_feature_size * num_images + token_ids += [0] * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) def dummy_image_for_clip( hf_config: CLIPVisionConfig, + num_images: int, *, image_width_override: Optional[int] = None, image_height_override: Optional[int] = None, @@ -70,7 +72,7 @@ def dummy_image_for_clip( height = image_height_override image = Image.new("RGB", (width, height), color=0) - return {"image": image} + return {"image": image if num_images == 1 else [image] * num_images} def input_processor_for_clip( diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 5bb871d5a093b..e8184e466c5bf 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -16,7 +16,7 @@ # limitations under the License. """ PyTorch Fuyu model.""" import math -from typing import Iterable, List, Literal, Optional, Tuple, TypedDict +from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict import torch import torch.nn as nn @@ -29,8 +29,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger from vllm.model_executor.layers.linear import ColumnParallelLinear -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -94,27 +93,33 @@ def get_max_fuyu_image_tokens(ctx: InputContext): return (ncol + 1) * nrow -def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int): +def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): ncol, nrow = get_max_fuyu_image_feature_size() image_feature_size = get_max_fuyu_image_tokens(ctx) - token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow - token_ids += [0] * (seq_len - image_feature_size) + image_token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow + token_ids = image_token_ids * num_images + token_ids += [0] * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) def dummy_image_for_fuyu( + num_images: int, + *, image_width: int, image_height: int, ): image = Image.new("RGB", (image_width, image_height), color=0) - return {"image": image} + return {"image": image if num_images == 1 else [image] * num_images} -def dummy_data_for_fuyu(ctx: InputContext, seq_len: int): - seq_data = dummy_seq_data_for_fuyu(ctx, seq_len) - mm_data = dummy_image_for_fuyu(MAX_IMAGE_FEATURE_SIZE_WIDTH, - MAX_IMAGE_FEATURE_SIZE_HEIGHT) +def dummy_data_for_fuyu(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images) + mm_data = dummy_image_for_fuyu(num_images, + image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT) return seq_data, mm_data diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 2f323ea552ccb..069948f812253 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -11,14 +11,11 @@ @runtime_checkable class SupportsMultiModal(Protocol): - """ - The interface required for all multimodal (vision or audio) language - models. - """ + """The interface required for all multi-modal models.""" supports_multimodal: ClassVar[Literal[True]] = True """ - A flag that indicates this model supports multimodal inputs. + A flag that indicates this model supports multi-modal inputs. Note: There is no need to redefine this flag if this class is in the diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index bf772a80a343c..b379c86c1912b 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -5,7 +5,8 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import itertools -from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) import torch import torch.nn as nn @@ -230,7 +231,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): def input_mapper_for_internvl(ctx: InputContext, data: object): - hf_config = ctx.get_hf_config(PretrainedConfig) + hf_config = ctx.get_hf_config() use_thumbnail = hf_config.use_thumbnail min_num = hf_config.min_dynamic_patch @@ -256,7 +257,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): }) -def dummy_data_for_internvl(ctx: InputContext, seq_len: int): +def dummy_data_for_internvl(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] image_feature_size = get_max_internvl_image_tokens(ctx) model_config = ctx.model_config @@ -268,6 +271,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int): seq_data = dummy_seq_data_for_clip( vision_config, seq_len, + num_images, image_token_id=tokenizer.encode(IMG_CONTEXT, add_special_tokens=False)[0], image_feature_size_override=image_feature_size, @@ -281,6 +285,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int): mm_data = dummy_image_for_clip( vision_config, + num_images, image_width_override=max_image_width, image_height_override=max_image_height, ) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index d4faf82b49697..46db364895b13 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,5 +1,6 @@ import itertools -from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) import torch import torch.nn as nn @@ -9,8 +10,7 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -88,9 +88,11 @@ def get_max_llava_image_tokens(ctx: InputContext): raise ValueError(f"Unexpected select feature strategy: {strategy}") -def dummy_data_for_llava(ctx: InputContext, seq_len: int): +def dummy_data_for_llava(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): hf_config = ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config + num_images = mm_counts["image"] image_feature_size = get_max_llava_image_tokens(ctx) @@ -98,21 +100,23 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int): seq_data = dummy_seq_data_for_clip( vision_config, seq_len, + num_images, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) - mm_data = dummy_image_for_clip(vision_config) + mm_data = dummy_image_for_clip(vision_config, num_images) return seq_data, mm_data elif isinstance(vision_config, SiglipVisionConfig): seq_data = dummy_seq_data_for_siglip( vision_config, seq_len, + num_images, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) - mm_data = dummy_image_for_siglip(vision_config) + mm_data = dummy_image_for_siglip(vision_config, num_images) return seq_data, mm_data msg = f"Unsupported vision config: {type(vision_config)}" diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 4ae545461eef8..c1277359182e4 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,5 +1,6 @@ import itertools -from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) import torch import torch.nn as nn @@ -13,8 +14,7 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -158,9 +158,11 @@ def get_max_llava_next_image_tokens(ctx: InputContext): ) -def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): +def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): hf_config = ctx.get_hf_config(LlavaNextConfig) vision_config = hf_config.vision_config + num_images = mm_counts["image"] image_feature_size = get_max_llava_next_image_tokens(ctx) @@ -168,12 +170,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): seq_data = dummy_seq_data_for_clip( vision_config, seq_len, + num_images, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) mm_data = dummy_image_for_clip( vision_config, + num_images, image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) @@ -183,12 +187,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): seq_data = dummy_seq_data_for_siglip( vision_config, seq_len, + num_images, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) mm_data = dummy_image_for_siglip( vision_config, + num_images, image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 70002e2d532d4..ef2323398abd0 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -24,8 +24,8 @@ import math import re from functools import partial -from typing import (Any, Callable, Iterable, List, Optional, Tuple, TypedDict, - Union) +from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, + TypedDict, Union) import numpy as np import torch @@ -42,8 +42,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype @@ -408,22 +407,24 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): return getattr(hf_config, "query_num", 64) -def dummy_seq_data_for_minicpmv(seq_len: int): +def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): token_ids = [0] * seq_len return SequenceData(token_ids) -def dummy_image_for_minicpmv(hf_config: PretrainedConfig): +def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int): width = height = hf_config.image_size image = Image.new("RGB", (width, height), color=0) - return {"image": image} + return {"image": image if num_images == 1 else [image] * num_images} -def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int): +def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): hf_config = ctx.get_hf_config() + num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_minicpmv(seq_len) - mm_data = dummy_image_for_minicpmv(hf_config) + seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images) + mm_data = dummy_image_for_minicpmv(hf_config, num_images) return seq_data, mm_data diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 51ff8c5d6fd13..8beb2778fe37a 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,4 +1,5 @@ -from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) import torch from torch import nn @@ -9,8 +10,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.gemma import GemmaModel @@ -57,17 +57,20 @@ def get_max_paligemma_image_tokens(ctx: InputContext): return get_max_siglip_image_tokens(vision_config) -def dummy_data_for_paligemma(ctx: InputContext, seq_len: int): +def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): hf_config = ctx.get_hf_config(PaliGemmaConfig) vision_config = hf_config.vision_config + num_images = mm_counts["image"] seq_data = dummy_seq_data_for_siglip( vision_config, seq_len, + num_images, image_token_id=hf_config.image_token_index, ) - mm_data = dummy_image_for_siglip(vision_config) + mm_data = dummy_image_for_siglip(vision_config, num_images) return seq_data, mm_data diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index f0ae0b6fdfb93..1c8bb8a837c86 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -15,7 +15,8 @@ # limitations under the License. import re from functools import lru_cache -from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) import numpy as np import torch @@ -28,8 +29,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -347,18 +347,22 @@ def get_max_phi3v_image_tokens(ctx: InputContext): ) -def dummy_data_for_phi3v(ctx: InputContext, seq_len: int): +def dummy_data_for_phi3v(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] image_feature_size = get_max_phi3v_image_tokens(ctx) seq_data = dummy_seq_data_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, seq_len, + num_images, image_token_id=_IMAGE_TOKEN_ID, image_feature_size_override=image_feature_size, ) mm_data = dummy_image_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, + num_images, image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index afe57bf573ad5..4df8c0b54201c 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -52,6 +52,7 @@ def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int: def dummy_seq_data_for_siglip( hf_config: SiglipVisionConfig, seq_len: int, + num_images: int, *, image_token_id: int, image_feature_size_override: Optional[int] = None, @@ -61,13 +62,14 @@ def dummy_seq_data_for_siglip( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size - token_ids += [0] * (seq_len - image_feature_size) + token_ids = [image_token_id] * image_feature_size * num_images + token_ids += [0] * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) def dummy_image_for_siglip( hf_config: SiglipVisionConfig, + num_images: int, *, image_width_override: Optional[int] = None, image_height_override: Optional[int] = None, @@ -79,7 +81,7 @@ def dummy_image_for_siglip( height = image_height_override image = Image.new("RGB", (width, height), color=0) - return {"image": image} + return {"image": image if num_images == 1 else [image] * num_images} def input_processor_for_siglip( diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 7717d77198a19..8ada60c8fd6ae 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,9 +1,9 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Mapping, Optional from typing import Sequence as GenericSequence -from typing import Tuple, Type, TypedDict, TypeVar, Union, cast +from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final import numpy as np import torch @@ -116,17 +116,30 @@ def as_kwargs( batched_inputs) +_T = TypeVar("_T") + +MultiModalData: TypeAlias = Union[_T, List[_T]] +""" +Either a single data instance, or a list of data instances. + +The number of data instances allowed per modality is restricted by +`--limit-mm-per-prompt`. +""" + + +@final class MultiModalDataBuiltins(TypedDict, total=False): """Modality types that are predefined by vLLM.""" - image: Image.Image - """The input image.""" + image: MultiModalData[Image.Image] + """The input image(s).""" - audio: Tuple[np.ndarray, Union[int, float]] - """The input audio and its sampling rate.""" + audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]] + """The input audio item(s) and corresponding sampling rate(s).""" -MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]] +MultiModalDataDict = Union[MultiModalDataBuiltins, + Mapping[str, MultiModalData[object]]] """ A dictionary containing an item for each modality type to input. @@ -137,7 +150,8 @@ class MultiModalDataBuiltins(TypedDict, total=False): Read more on that :ref:`here `. """ -MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs] +MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], + MultiModalInputs] """ Return a dictionary to be passed as keyword arguments to :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers @@ -181,8 +195,11 @@ def get_data_key(self) -> str: raise NotImplementedError @abstractmethod - def _default_input_mapper(self, ctx: InputContext, - data: object) -> MultiModalInputs: + def _default_input_mapper( + self, + ctx: InputContext, + data: MultiModalData[object], + ) -> MultiModalInputs: """ Return a dictionary to be passed as keyword arguments to :meth:`~torch.nn.Module.forward`. This is similar in concept to @@ -225,7 +242,7 @@ def wrapper(model_cls: N) -> N: return wrapper def map_input(self, model_config: ModelConfig, - data: object) -> MultiModalInputs: + data: MultiModalData[object]) -> MultiModalInputs: """ Transform the data into a dictionary of model inputs using the input mapper registered for that model. @@ -254,8 +271,8 @@ def map_input(self, model_config: ModelConfig, @abstractmethod def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: """ - Calculate the maximum number of multimodal tokens input to the language - model. This does not include tokens that correspond to the input text. + Calculate the maximum number of tokens, corresponding to a single + instance of multimodal data, that are passed to the language model. """ raise NotImplementedError @@ -269,8 +286,9 @@ def register_max_multimodal_tokens( max_mm_tokens: Optional[MultiModalTokensCalc] = None, ): """ - Register the maximum number of multi-modal tokens input to the - language model for a model class. + Register the maximum number of tokens, corresponding to a single + instance of multimodal data, that are passed to the language model + for a model class. If `None` is provided, then the default calculation is used instead. diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 3e83c9ef381ac..916bd5e601bb7 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -11,7 +11,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.utils import is_list_of -from .base import MultiModalInputs, MultiModalPlugin +from .base import MultiModalData, MultiModalInputs, MultiModalPlugin logger = init_logger(__name__) @@ -110,8 +110,11 @@ def _get_hf_image_processor(self, model_config: ModelConfig): model_config.model, trust_remote_code=model_config.trust_remote_code) - def _default_input_mapper(self, ctx: InputContext, - data: object) -> MultiModalInputs: + def _default_input_mapper( + self, + ctx: InputContext, + data: MultiModalData[object], + ) -> MultiModalInputs: model_config = ctx.model_config # PIL image diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 19c26123c2df3..d487d20011b45 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,19 +1,33 @@ import functools -from typing import Dict, Optional, Sequence +from collections import UserDict +from typing import Dict, Mapping, Optional, Sequence -import torch - -from vllm.config import ModelConfig +from vllm.config import ModelConfig, MultiModalConfig from vllm.logger import init_logger from .audio import AudioPlugin from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs, - MultiModalPlugin, MultiModalTokensCalc) + MultiModalPlugin, MultiModalTokensCalc, NestedTensors) from .image import ImagePlugin logger = init_logger(__name__) +class _MultiModalLimits(UserDict): + """ + Wraps `_limits_by_model` for a more informative error message + when attempting to access a model that does not exist. + """ + + def __getitem__(self, key: ModelConfig) -> Dict[str, int]: + try: + return super().__getitem__(key) + except KeyError as exc: + msg = (f"Cannot find `mm_limits` for model={key.model}. Did you " + "forget to call `init_mm_limits_per_prompt`?") + raise KeyError(msg) from exc + + class MultiModalRegistry: """ A registry that dispatches data processing to the @@ -28,6 +42,11 @@ def __init__( plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: self._plugins = {p.get_data_key(): p for p in plugins} + # This is used for non-multimodal models + self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} + + self._limits_by_model = _MultiModalLimits() + def register_plugin(self, plugin: MultiModalPlugin) -> None: """ Register a multi-modal plugin so it can be recognized by vLLM. @@ -86,13 +105,24 @@ def map_input(self, model_config: ModelConfig, via the input mapper registered for that model. See :meth:`MultiModalPlugin.map_input` for more details. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. """ - merged_dict: Dict[str, torch.Tensor] = {} + merged_dict: Dict[str, NestedTensors] = {} for data_key, data_value in data.items(): - input_dict = self._get_plugin(data_key) \ - .map_input(model_config, data_value) + plugin = self._get_plugin(data_key) + num_items = len(data_value) if isinstance(data_value, list) else 1 + max_items = self._limits_by_model[model_config][data_key] + if num_items > max_items: + raise ValueError( + f"You set {data_key}={max_items} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but found {num_items} items " + "in the same prompt.") + + input_dict = plugin.map_input(model_config, data_value) for input_key, input_tensor in input_dict.items(): if input_key in merged_dict: raise ValueError(f"The input mappers (keys={set(data)}) " @@ -115,8 +145,9 @@ def register_max_multimodal_tokens( max_mm_tokens: Optional[MultiModalTokensCalc] = None, ): """ - Register the maximum number of tokens, belonging to a - specific modality, input to the language model for a model class. + Register the maximum number of tokens, corresponding to a single + instance of multimodal data belonging to a specific modality, that are + passed to the language model for a model class. """ return self._get_plugin(data_type_key) \ .register_max_multimodal_tokens(max_mm_tokens) @@ -126,8 +157,8 @@ def register_max_image_tokens( max_mm_tokens: Optional[MultiModalTokensCalc] = None, ): """ - Register the maximum number of image tokens - input to the language model for a model class. + Register the maximum number of image tokens, corresponding to a single + image, that are passed to the language model for a model class. """ return self.register_max_multimodal_tokens("image", max_mm_tokens) @@ -135,9 +166,63 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: """ Get the maximum number of multi-modal tokens for profiling the memory usage of a model. - + See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. + """ + limits_per_plugin = self._limits_by_model[model_config] + + return sum((limits_per_plugin[key] * + plugin.get_max_multimodal_tokens(model_config)) + for key, plugin in self._plugins.items()) + + def init_mm_limits_per_prompt( + self, + model_config: ModelConfig, + multimodal_config: Optional[MultiModalConfig], + ) -> None: + """ + Initialize the maximum number of multi-modal input instances for each + modality that are allowed per prompt for a model class. + """ + if model_config in self._limits_by_model: + logger.warning( + "`mm_limits` has already been set for model=%s, and will " + "be overwritten by the new values.", model_config.model) + + if multimodal_config is None: + limits_per_plugin = self._disabled_limits_per_plugin + else: + config_limits_per_plugin = multimodal_config.limit_per_prompt + + extra_keys = config_limits_per_plugin.keys() - self._plugins.keys() + if extra_keys: + logger.warning( + "Detected extra keys in `--limit-mm-per-prompt` which " + "are not registered as multi-modal plugins: %s. " + "They will be ignored.", extra_keys) + + # NOTE: Currently the default is set to 1 for each plugin + # TODO: Automatically determine the limits based on budget + # once more models support multi-image inputs + limits_per_plugin = { + key: config_limits_per_plugin.get(key, 1) + for key in self._plugins + } + + self._limits_by_model[model_config] = limits_per_plugin + + def get_mm_limits_per_prompt( + self, + model_config: ModelConfig, + ) -> Mapping[str, int]: + """ + Get the maximum number of multi-modal input instances for each modality + that are allowed per prompt for a model class. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. """ - return sum( - plugin.get_max_multimodal_tokens(model_config) - for plugin in self._plugins.values()) + return self._limits_by_model[model_config] diff --git a/vllm/utils.py b/vllm/utils.py index 753efca3e2a61..39fe742203a47 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -13,7 +13,6 @@ import uuid import warnings from asyncio import FIRST_COMPLETED, ensure_future -from collections import defaultdict from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, @@ -760,16 +759,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): gc.collect() -def str_to_int_tuple(s: str) -> Tuple[int, ...]: - """Convert a string to a tuple of integers.""" - try: - return tuple(map(int, s.split(","))) - except ValueError as e: - raise ValueError( - "String must be a series of integers separated by commas " - f"(e.g., 1, 2, 3). Given input: {s}") from e - - def make_ndarray_with_pad( x: List[List[T]], pad: T, @@ -863,23 +852,6 @@ def is_list_of( assert_never(check) -def merge_dicts(dict1: Dict[K, List[T]], - dict2: Dict[K, List[T]]) -> Dict[K, List[T]]: - """Merge 2 dicts that have key -> List of items. - - When a key conflicts, the values in dict1 is prioritized. - """ - merged_dict: Dict[K, List[T]] = defaultdict(list) - - for key, value in dict1.items(): - merged_dict[key].extend(value) - - for key, value in dict2.items(): - merged_dict[key].extend(value) - - return dict(merged_dict) - - JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"], Tuple["JSONTree[T]", ...], T] """A nested JSON structure where the leaves need not be JSON-serializable.""" diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 4e66a04674c2a..4aec8d1d408d7 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -12,9 +12,10 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) -from vllm.inputs import INPUT_REGISTRY +from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, SequenceGroupMetadata) @@ -83,6 +84,8 @@ def __init__( prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, observability_config: Optional[ObservabilityConfig] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): ''' EncoderDecoderModelRunner constructor. @@ -271,6 +274,16 @@ def profile_run(self) -> None: seqs: List[SequenceGroupMetadata] = [] model_config = self.model_config + mm_config = self.multimodal_config + + input_registry = self.input_registry + mm_registry = self.mm_registry + mm_registry.init_mm_limits_per_prompt(model_config, mm_config) + + max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config) + if max_mm_tokens > 0: + raise NotImplementedError( + "Multi-modal encoder-decoder models are not supported yet") batch_size = 0 for group_id in range(max_num_seqs): @@ -278,8 +291,8 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, _ = INPUT_REGISTRY \ - .dummy_data_for_profiling(model_config, seq_len) + seq_data, _ = input_registry \ + .dummy_data_for_profiling(model_config, seq_len, mm_registry) # Having more tokens is over-conservative but otherwise fine assert len(seq_data.prompt_token_ids) >= seq_len, ( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a4ce1b512dd05..47068f77ec5df 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -31,7 +31,7 @@ ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture -from vllm.inputs import INPUT_REGISTRY +from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -43,7 +43,7 @@ supports_multimodal) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs) + MultiModalInputs, MultiModalRegistry) from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.worker_manager import ( @@ -807,6 +807,8 @@ def __init__( multimodal_config: Optional[MultiModalConfig] = None, return_hidden_states: bool = False, observability_config: Optional[ObservabilityConfig] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): self.model_config = model_config self.parallel_config = parallel_config @@ -860,8 +862,10 @@ def __init__( ) if num_attn_heads else None # Multi-modal data support - self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ - .create_input_mapper(self.model_config) + self.input_registry = input_registry + self.mm_registry = mm_registry + self.multi_modal_input_mapper = mm_registry \ + .create_input_mapper(model_config) # Lazy initialization self.model: nn.Module # Set after load_model @@ -902,7 +906,7 @@ def load_model(self) -> None: assert supports_lora(self.model), "Model does not support LoRA" assert not supports_multimodal( self.model - ), "To be tested: multimodal language model with LoRA settings." + ), "To be tested: Multi-modal model with LoRA settings." self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, @@ -1046,17 +1050,21 @@ def profile_run(self) -> None: # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. seqs: List[SequenceGroupMetadata] = [] - # Additional GPU memory may be needed for vision encoding, which needs - # to be accounted for when calculating the GPU blocks for + # Additional GPU memory may be needed for multi-modal encoding, which + # needs to be accounted for when calculating the GPU blocks for # vLLM blocker manager. # To exercise the worst scenario for GPU memory consumption, # the number of seqs (batch_size) is chosen to maximize the number # of images processed. model_config = self.model_config + mm_config = self.multimodal_config - if supports_multimodal(self.model): - max_mm_tokens = MULTIMODAL_REGISTRY \ - .get_max_multimodal_tokens(model_config) + input_registry = self.input_registry + mm_registry = self.mm_registry + mm_registry.init_mm_limits_per_prompt(model_config, mm_config) + + max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config) + if max_mm_tokens > 0: max_num_seqs_orig = max_num_seqs max_num_seqs = min(max_num_seqs, max_num_batched_tokens // max_mm_tokens) @@ -1074,13 +1082,8 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ - .dummy_data_for_profiling(model_config, seq_len) - - # Having more tokens is over-conservative but otherwise fine - assert len(seq_data.prompt_token_ids) >= seq_len, ( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but got: {len(seq_data.prompt_token_ids)}") + seq_data, dummy_multi_modal_data = input_registry \ + .dummy_data_for_profiling(model_config, seq_len, mm_registry) seq = SequenceGroupMetadata( request_id=str(group_id), diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index a1e1c1bef6336..d4b450199bb5d 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -9,12 +9,11 @@ ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.distributed import broadcast_tensor_dict -from vllm.inputs import INPUT_REGISTRY +from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models.interfaces import supports_multimodal from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs) + MultiModalInputs, MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) @@ -89,6 +88,8 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, *args, **kwargs, ): @@ -120,8 +121,10 @@ def __init__( ) # Multi-modal data support - self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ - .create_input_mapper(self.model_config) + self.input_registry = input_registry + self.mm_registry = mm_registry + self.multi_modal_input_mapper = mm_registry \ + .create_input_mapper(model_config) # Lazy initialization. self.model: nn.Module # Set after init_Model @@ -157,17 +160,21 @@ def profile_run(self) -> None: # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. seqs: List[SequenceGroupMetadata] = [] - # Additional GPU memory may be needed for vision encoding, which needs - # to be accounted for when calculating the GPU blocks for + # Additional GPU memory may be needed for multi-modal encoding, which + # needs to be accounted for when calculating the GPU blocks for # vLLM blocker manager. # To exercise the worst scenario for GPU memory consumption, # the number of seqs (batch_size) is chosen to maximize the number # of images processed. model_config = self.model_config + mm_config = self.multimodal_config - if supports_multimodal(self.model): - max_mm_tokens = MULTIMODAL_REGISTRY \ - .get_max_multimodal_tokens(model_config) + input_registry = self.input_registry + mm_registry = self.mm_registry + mm_registry.init_mm_limits_per_prompt(model_config, mm_config) + + max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config) + if max_mm_tokens > 0: max_num_seqs_orig = max_num_seqs max_num_seqs = min(max_num_seqs, max_num_batched_tokens // max_mm_tokens) @@ -183,13 +190,8 @@ def profile_run(self) -> None: seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ - .dummy_data_for_profiling(model_config, seq_len) - - # Having more tokens is over-conservative but otherwise fine - assert len(seq_data.prompt_token_ids) >= seq_len, ( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but got: {len(seq_data.prompt_token_ids)}") + seq_data, dummy_multi_modal_data = input_registry \ + .dummy_data_for_profiling(model_config, seq_len, mm_registry) seq = SequenceGroupMetadata( request_id=str(group_id),