From 7a64d24aad69e4d2548aa0bf528d9fe63428ab01 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 3 Jun 2024 13:56:41 +0800 Subject: [PATCH] [Core] Support image processor (#4197) --- .github/workflows/mypy.yaml | 1 + docs/source/conf.py | 14 +- .../dev/multimodal/multimodal_index.rst | 51 ++++++ docs/source/index.rst | 6 +- docs/source/models/supported_models.rst | 4 + docs/source/models/vlm.rst | 56 +++++++ examples/llava_example.py | 29 ++-- format.sh | 1 + requirements-common.txt | 1 + requirements-dev.txt | 3 - tests/conftest.py | 45 ++--- tests/models/test_llava.py | 60 ++++--- tests/multimodal/__init__.py | 0 tests/multimodal/test_processor.py | 98 +++++++++++ tests/spec_decode/e2e/conftest.py | 3 +- tests/tokenization/test_image_processor.py | 20 +++ vllm/config.py | 6 +- vllm/engine/arg_utils.py | 108 ++++++++---- vllm/entrypoints/llm.py | 25 +-- vllm/model_executor/models/llava.py | 73 +++++--- vllm/multimodal/__init__.py | 7 + vllm/multimodal/base.py | 126 ++++++++++++++ vllm/multimodal/image.py | 141 ++++++++++++++++ vllm/multimodal/registry.py | 156 ++++++++++++++++++ vllm/sequence.py | 32 +--- vllm/transformers_utils/image_processor.py | 45 +++++ vllm/worker/cpu_model_runner.py | 57 ++++--- vllm/worker/embedding_model_runner.py | 10 +- vllm/worker/model_runner.py | 120 +++++++------- 29 files changed, 1042 insertions(+), 256 deletions(-) create mode 100644 docs/source/dev/multimodal/multimodal_index.rst create mode 100644 docs/source/models/vlm.rst create mode 100644 tests/multimodal/__init__.py create mode 100644 tests/multimodal/test_processor.py create mode 100644 tests/tokenization/test_image_processor.py create mode 100644 vllm/multimodal/__init__.py create mode 100644 vllm/multimodal/base.py create mode 100644 vllm/multimodal/image.py create mode 100644 vllm/multimodal/registry.py create mode 100644 vllm/transformers_utils/image_processor.py diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index a20753d8a7702..22e6c2ef0101e 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -37,6 +37,7 @@ jobs: mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml + mypy vllm/multimodal --config-file pyproject.toml mypy vllm/usage --config-file pyproject.toml mypy vllm/*.py --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml diff --git a/docs/source/conf.py b/docs/source/conf.py index cfebc2ff9bb33..f1a7013edd332 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -90,6 +90,7 @@ def setup(app): "sentencepiece", "vllm.cuda_utils", "vllm._C", + "PIL", "numpy", "tqdm", "tensorizer", @@ -116,12 +117,13 @@ def add_line(self, line: str, source: str, *lineno: int) -> None: autodoc.ClassDocumenter = MockedClassDocumenter intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'typing_extensions': - ('https://typing-extensions.readthedocs.io/en/latest', None), - 'numpy': ('https://numpy.org/doc/stable', None), - 'torch': ('https://pytorch.org/docs/stable', None), - 'psutil': ('https://psutil.readthedocs.io/en/stable', None), + "python": ("https://docs.python.org/3", None), + "typing_extensions": + ("https://typing-extensions.readthedocs.io/en/latest", None), + "pillow": ("https://pillow.readthedocs.io/en/stable", None), + "numpy": ("https://numpy.org/doc/stable", None), + "torch": ("https://pytorch.org/docs/stable", None), + "psutil": ("https://psutil.readthedocs.io/en/stable", None), } autodoc_preserve_defaults = True diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst new file mode 100644 index 0000000000000..a25eceecc276b --- /dev/null +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -0,0 +1,51 @@ +Multi-Modality +============== + +.. currentmodule:: vllm.multimodal + +vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. + +:class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data`` +which allows you to pass in multi-modal input alongside text and token prompts. + +By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, +you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data `, +as well as :meth:`MULTIMODAL_REGISTRY.register_input ` for each modality type to support. + +.. contents:: + :local: + :backlinks: none + +Module Contents ++++++++++++++++ + +.. automodule:: vllm.multimodal + +Registry +-------- + +.. data:: vllm.multimodal.MULTIMODAL_REGISTRY + + The global :class:`MultiModalRegistry` which is used by model runners. + +.. autoclass:: vllm.multimodal.MultiModalRegistry + :members: + :show-inheritance: + +Base Classes +------------ + +.. autoclass:: vllm.multimodal.MultiModalData + :members: + :show-inheritance: + +.. autoclass:: vllm.multimodal.MultiModalPlugin + :members: + :show-inheritance: + +Image Classes +------------- + +.. automodule:: vllm.multimodal.image + :members: + :show-inheritance: diff --git a/docs/source/index.rst b/docs/source/index.rst index 5f18fe9ae0a73..fad3c3b05b0c0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -88,6 +88,7 @@ Documentation models/adding_model models/engine_args models/lora + models/vlm models/performance .. toctree:: @@ -99,17 +100,18 @@ Documentation quantization/fp8_e4m3_kvcache .. toctree:: - :maxdepth: 2 + :maxdepth: 1 :caption: Developer Documentation dev/sampling_params dev/offline_inference/offline_index dev/engine/engine_index dev/kernel/paged_attention + dev/multimodal/multimodal_index dev/dockerfile/dockerfile .. toctree:: - :maxdepth: 2 + :maxdepth: 1 :caption: Community community/meetups diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 82e71e61975c8..24fa83df7d751 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -87,6 +87,10 @@ Alongside each architecture, we include some popular models that use it. - LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi - :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. - ✅︎ + * - :code:`LlavaForConditionalGeneration` + - LLaVA-1.5 + - :code:`llava-hf/llava-1.5-7b-hf`\*, :code:`llava-hf/llava-1.5-13b-hf`\*, etc. + - * - :code:`MiniCPMForCausalLM` - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst new file mode 100644 index 0000000000000..52afda747aab8 --- /dev/null +++ b/docs/source/models/vlm.rst @@ -0,0 +1,56 @@ +.. _vlm: + +Using VLMs +========== + +This document shows you how to run and serve Vision Language Models (VLMs) using vLLM. + +Engine Arguments +---------------- + +The following :ref:`engine arguments ` are specific to VLMs: + +.. argparse:: + :module: vllm.engine.arg_utils + :func: _vlm_engine_args_parser + :prog: -m vllm.entrypoints.openai.api_server + :nodefaultconst: + +Offline Batched Inference +------------------------- + +To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` class for instantiating the engine. + +.. code-block:: python + + llm = LLM( + model="llava-hf/llava-1.5-7b-hf", + image_input_type="pixel_values", + image_token_id=32000, + image_input_shape="1,3,336,336", + image_feature_size=576, + ) + +For now, we only support a single image per text prompt. To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`: + +* ``prompt``: The prompt should have a number of ```` tokens equal to ``image_feature_size``. +* ``multi_modal_data``: This should be an instance of :class:`~vllm.multimodal.image.ImagePixelData` or :class:`~vllm.multimodal.image.ImageFeatureData`. + +.. code-block:: python + + prompt = "" * 576 + ( + "\nUSER: What is the content of this image?\nASSISTANT:") + + # Load the image using PIL.Image + image = ... + + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": ImagePixelData(image), + }) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + +A code example can be found in `examples/llava_example.py `_. diff --git a/examples/llava_example.py b/examples/llava_example.py index 60250c4303fbf..980d7bf9f8a3c 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -3,33 +3,36 @@ import subprocess import torch +from PIL import Image from vllm import LLM -from vllm.sequence import MultiModalData +from vllm.multimodal.image import ImageFeatureData, ImagePixelData # The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`. +# You can use `.buildkite/download-images.sh` to download them -def run_llava_pixel_values(): +def run_llava_pixel_values(*, disable_image_processor: bool = False): llm = LLM( model="llava-hf/llava-1.5-7b-hf", image_input_type="pixel_values", image_token_id=32000, image_input_shape="1,3,336,336", image_feature_size=576, + disable_image_processor=disable_image_processor, ) prompt = "" * 576 + ( "\nUSER: What is the content of this image?\nASSISTANT:") - # This should be provided by another online or offline component. - image = torch.load("images/stop_sign_pixel_values.pt") + if disable_image_processor: + image = torch.load("images/stop_sign_pixel_values.pt") + else: + image = Image.open("images/stop_sign.jpg") outputs = llm.generate({ - "prompt": - prompt, - "multi_modal_data": - MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + "prompt": prompt, + "multi_modal_data": ImagePixelData(image), }) for o in outputs: @@ -49,15 +52,13 @@ def run_llava_image_features(): prompt = "" * 576 + ( "\nUSER: What is the content of this image?\nASSISTANT:") - # This should be provided by another online or offline component. - image = torch.load("images/stop_sign_image_features.pt") + image: torch.Tensor = torch.load("images/stop_sign_image_features.pt") outputs = llm.generate({ - "prompt": - prompt, - "multi_modal_data": - MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + "prompt": prompt, + "multi_modal_data": ImageFeatureData(image), }) + for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/format.sh b/format.sh index d110855f8c273..ca828457f9999 100755 --- a/format.sh +++ b/format.sh @@ -101,6 +101,7 @@ mypy vllm/core --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml +mypy vllm/multimodal --config-file pyproject.toml mypy vllm/usage --config-file pyproject.toml mypy vllm/*.py --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml diff --git a/requirements-common.txt b/requirements-common.txt index 3ea22276f63f4..f41873570aa67 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -12,6 +12,7 @@ aiohttp openai uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. +pillow # Required for image processing prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer diff --git a/requirements-dev.txt b/requirements-dev.txt index 2c6b33ea813a2..12b22a61ea162 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,8 +33,5 @@ sentence-transformers # required for embedding # Benchmarking aiohttp -# Multimodal -pillow - # quantization bitsandbytes==0.42.0 diff --git a/tests/conftest.py b/tests/conftest.py index d904058dc369c..e749338e1095a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,9 @@ from vllm.distributed import destroy_model_parallel from vllm.inputs import TextPrompt from vllm.logger import init_logger -from vllm.sequence import MultiModalData, SampleLogprobs +from vllm.multimodal import MultiModalData +from vllm.multimodal.image import ImageFeatureData, ImagePixelData +from vllm.sequence import SampleLogprobs logger = init_logger(__name__) @@ -24,6 +26,7 @@ _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] # Multi modal related +# You can use `.buildkite/download-images.sh` to download the assets _PIXEL_VALUES_FILES = [ os.path.join(_TEST_DIR, "images", filename) for filename in ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"] @@ -89,17 +92,23 @@ def hf_images() -> List[Image.Image]: @pytest.fixture() -def vllm_images(request) -> "torch.Tensor": +def vllm_images(request) -> List[MultiModalData]: vision_language_config = request.getfixturevalue("model_and_config")[1] - all_images = [] if vision_language_config.image_input_type == ( VisionLanguageConfig.ImageInputType.IMAGE_FEATURES): - filenames = _IMAGE_FEATURES_FILES + return [ + ImageFeatureData(torch.load(filename)) + for filename in _IMAGE_FEATURES_FILES + ] else: - filenames = _PIXEL_VALUES_FILES - for filename in filenames: - all_images.append(torch.load(filename)) - return torch.concat(all_images, dim=0) + return [ + ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES + ] + + +@pytest.fixture() +def vllm_image_tensors(request) -> List[torch.Tensor]: + return [torch.load(filename) for filename in _PIXEL_VALUES_FILES] @pytest.fixture() @@ -392,23 +401,17 @@ def generate( self, prompts: List[str], sampling_params: SamplingParams, - images: Optional[torch.Tensor] = None, + images: Optional[List[MultiModalData]] = None, ) -> List[Tuple[List[List[int]], List[str]]]: if images is not None: assert len(prompts) == len(images) - prompt_inputs: List[TextPrompt] = [] - for i, prompt in enumerate(prompts): - prompt = TextPrompt(prompt=prompt) - if images is not None: - prompt["multi_modal_data"] = MultiModalData( - type=MultiModalData.Type.IMAGE, - data=images[i:i + 1], - ) - - prompt_inputs.append(prompt) + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + if images is not None: + for i, image in enumerate(images): + inputs[i]["multi_modal_data"] = image - req_outputs = self.model.generate(prompt_inputs, + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) outputs: List[Tuple[List[List[int]], List[str]]] = [] @@ -447,7 +450,7 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, - images: Optional[torch.Tensor] = None, + images: Optional[List[MultiModalData]] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = self.generate(prompts, greedy_params, images=images) diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index f86cd3fa88f5d..cc0685ca9c5eb 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -1,7 +1,7 @@ import gc from dataclasses import fields from enum import Enum -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple import pytest import torch @@ -9,36 +9,50 @@ from vllm.config import VisionLanguageConfig + +def iter_llava_configs(model_name: str): + image_hw_to_feature_size = { + (336, 336): 576, + } + + for (h, w), f in image_hw_to_feature_size.items(): + for input_type, input_shape in [ + (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), + (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)), + ]: + yield (model_name, + VisionLanguageConfig(image_input_type=input_type, + image_feature_size=f, + image_token_id=32000, + image_input_shape=input_shape, + image_processor=model_name, + image_processor_revision=None)) + + model_and_vl_config = [ - ("llava-hf/llava-1.5-7b-hf", - VisionLanguageConfig( - image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, - image_feature_size=576, - image_token_id=32000, - image_input_shape=(1, 3, 336, 336))), - ("llava-hf/llava-1.5-7b-hf", - VisionLanguageConfig( - image_input_type=VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, - image_feature_size=576, - image_token_id=32000, - image_input_shape=(1, 576, 1024))) + *iter_llava_configs("llava-hf/llava-1.5-7b-hf"), + # Not enough memory + # *iter_llava_configs("llava-hf/llava-1.5-13b-hf"), ] -def as_dict(vision_language_config: VisionLanguageConfig) -> Dict: +def as_dict(vlm_config: VisionLanguageConfig) -> Dict[str, Any]: """Flatten vision language config to pure args. Compatible with what llm entrypoint expects. """ result = {} - for field in fields(vision_language_config): - value = getattr(vision_language_config, field.name) + for field in fields(vlm_config): + value = getattr(vlm_config, field.name) if isinstance(value, Enum): result[field.name] = value.name.lower() elif isinstance(value, tuple): result[field.name] = ",".join([str(item) for item in value]) else: result[field.name] = value + + result["disable_image_processor"] = vlm_config.image_processor is None + return result @@ -67,18 +81,19 @@ def sanitize_vllm_output(vllm_output: Tuple[List[int], str], @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images, - vllm_image_prompts, vllm_images, model_and_config: tuple, - dtype: str, max_tokens: int, worker_use_ray: bool) -> None: + vllm_image_prompts, vllm_images, model_and_config, dtype: str, + max_tokens: int, worker_use_ray: bool) -> None: """Inference result should be the same between hf and vllm. All the image fixtures for the test is under tests/images. - For huggingface runner, we provide the raw images as input. - For vllm runner, we provide image tensors and corresponding + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalData objects and corresponding vision language config as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ model_id, vision_language_config = model_and_config + hf_model = hf_runner(model_id, dtype=dtype) hf_outputs = hf_model.generate_greedy(hf_image_prompts, max_tokens, @@ -88,6 +103,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images, vllm_model = vllm_runner(model_id, dtype=dtype, worker_use_ray=worker_use_ray, + enforce_eager=True, **as_dict(vision_language_config)) vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, max_tokens, @@ -105,3 +121,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images, f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +# TODO: Add test for `tensor_parallel_size` [ref: PR #3883] +# (Requires multiple GPUs) diff --git a/tests/multimodal/__init__.py b/tests/multimodal/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py new file mode 100644 index 0000000000000..4aeae633d07f7 --- /dev/null +++ b/tests/multimodal/test_processor.py @@ -0,0 +1,98 @@ +import numpy as np +import pytest +from transformers import CLIPImageProcessor + +from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import ImagePixelData + + +@pytest.mark.parametrize("dtype", ["half", "bfloat16", "float"]) +def test_clip_image_processor(hf_images, dtype): + MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + IMAGE_HEIGHT = IMAGE_WIDTH = 33 + + hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) + assert isinstance(hf_processor, CLIPImageProcessor) + + model_config = ModelConfig( + model=MODEL_NAME, + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype=dtype, + revision=None, + ) + vlm_config = VisionLanguageConfig( + image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, + image_token_id=32000, + image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), + image_feature_size=576, + image_processor=MODEL_NAME, + image_processor_revision=None, + ) + + for image in hf_images: + hf_result = hf_processor.preprocess( + image, + return_tensors="np", + ) + vllm_result = MULTIMODAL_REGISTRY.process_input( + ImagePixelData(image), + model_config=model_config, + vlm_config=vlm_config, + ) + + assert hf_result.keys() == vllm_result.keys() + for key, hf_arr in hf_result.items(): + vllm_arr: np.ndarray = vllm_result[key].numpy() + + assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" + assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" + + +@pytest.mark.parametrize("dtype", ["float"]) +def test_image_pixel_types(hf_images, vllm_image_tensors, dtype): + MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + IMAGE_HEIGHT = IMAGE_WIDTH = 33 + + model_config = ModelConfig( + model=MODEL_NAME, + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype=dtype, + revision=None, + ) + vlm_config = VisionLanguageConfig( + image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, + image_token_id=32000, + image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), + image_feature_size=576, + image_processor=MODEL_NAME, + image_processor_revision=None, + ) + + for image, tensor in zip(hf_images, vllm_image_tensors): + image_result = MULTIMODAL_REGISTRY.process_input( + ImagePixelData(image), + model_config=model_config, + vlm_config=vlm_config, + ) + tensor_result = MULTIMODAL_REGISTRY.process_input( + ImagePixelData(tensor), + model_config=model_config, + vlm_config=vlm_config, + ) + + assert image_result.keys() == tensor_result.keys() + for key, image_arr in image_result.items(): + tensor_arr: np.ndarray = tensor_result[key].numpy() + + assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}" + + # The examples in PR#3042 have slightly different preprocessing from + # HuggingFace's LlavaProcessor, causing the test to fail. + # assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}" diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 7c5840baf3593..1d060e265848a 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -18,9 +18,10 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.lora.request import LoRARequest from vllm.model_executor.utils import set_random_seed +from vllm.multimodal import MultiModalData from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import Logprob, MultiModalData +from vllm.sequence import Logprob from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, random_uuid diff --git a/tests/tokenization/test_image_processor.py b/tests/tokenization/test_image_processor.py new file mode 100644 index 0000000000000..5ba2323367414 --- /dev/null +++ b/tests/tokenization/test_image_processor.py @@ -0,0 +1,20 @@ +import pytest +from transformers.image_processing_utils import BaseImageProcessor + +from vllm.transformers_utils.image_processor import get_image_processor + +IMAGE_PROCESSOR_NAMES = [ + "llava-hf/llava-1.5-7b-hf", + "llava-hf/llava-v1.6-34b-hf", +] + + +@pytest.mark.parametrize("processor_name", IMAGE_PROCESSOR_NAMES) +def test_image_processor_revision(processor_name: str): + # Assume that "main" branch always exists + image_processor = get_image_processor(processor_name, revision="main") + assert isinstance(image_processor, BaseImageProcessor) + + # Assume that "never" branch always does not exist + with pytest.raises(OSError, match='not a valid git identifier'): + get_image_processor(processor_name, revision="never") diff --git a/vllm/config.py b/vllm/config.py index ba4361ffb98b4..eee62d2683835 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1094,10 +1094,12 @@ class ImageInputType(enum.Enum): # worst case scenario (biggest supported resolution). image_input_shape: tuple image_feature_size: int + # The image processor to load from HuggingFace + image_processor: Optional[str] + image_processor_revision: Optional[str] @classmethod - def get_image_input_enum_type( - cls, value: str) -> "VisionLanguageConfig.ImageInputType": + def get_image_input_enum_type(cls, value: str) -> ImageInputType: """Get the image input type from a string.""" try: return cls.ImageInputType[value.upper()] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8a73fc931a95a..b315d4d2ece29 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,6 +1,7 @@ import argparse import dataclasses import json +import warnings from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -80,6 +81,10 @@ class EngineArgs: image_token_id: Optional[int] = None image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None + image_processor: Optional[str] = None + image_processor_revision: Optional[str] = None + disable_image_processor: bool = False + scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False @@ -98,6 +103,53 @@ def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model + @staticmethod + def add_cli_args_for_vlm( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser.add_argument('--image-input-type', + type=nullable_str, + default=None, + choices=[ + t.name.lower() + for t in VisionLanguageConfig.ImageInputType + ], + help=('The image input type passed into vLLM.')) + parser.add_argument('--image-token-id', + type=int, + default=None, + help=('Input id for image token.')) + parser.add_argument( + '--image-input-shape', + type=nullable_str, + default=None, + help=('The biggest image input shape (worst for memory footprint) ' + 'given an input type. Only used for vLLM\'s profile_run.')) + parser.add_argument( + '--image-feature-size', + type=int, + default=None, + help=('The image feature size along the context dimension.')) + parser.add_argument( + '--image-processor', + type=str, + default=EngineArgs.image_processor, + help='Name or path of the huggingface image processor to use. ' + 'If unspecified, model name or path will be used.') + parser.add_argument( + '--image-processor-revision', + type=str, + default=None, + help='Revision of the huggingface image processor version to use. ' + 'It can be a branch name, a tag name, or a commit id. ' + 'If unspecified, will use the default version.') + parser.add_argument( + '--disable-image-processor', + action='store_true', + help='Disables the use of image processor, even if one is defined ' + 'for the model on huggingface.') + + return parser + @staticmethod def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: @@ -113,7 +165,8 @@ def add_cli_args( '--tokenizer', type=nullable_str, default=EngineArgs.tokenizer, - help='Name or path of the huggingface tokenizer to use.') + help='Name or path of the huggingface tokenizer to use. ' + 'If unspecified, model name or path will be used.') parser.add_argument( '--skip-tokenizer-init', action='store_true', @@ -136,9 +189,9 @@ def add_cli_args( '--tokenizer-revision', type=nullable_str, default=None, - help='The specific tokenizer version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') + help='Revision of the huggingface tokenizer to use. ' + 'It can be a branch name, a tag name, or a commit id. ' + 'If unspecified, will use the default version.') parser.add_argument( '--tokenizer-mode', type=str, @@ -445,31 +498,10 @@ def add_cli_args( default=EngineArgs.device, choices=["auto", "cuda", "neuron", "cpu"], help='Device type for vLLM execution.') + # Related to Vision-language models such as llava - parser.add_argument( - '--image-input-type', - type=nullable_str, - default=None, - choices=[ - t.name.lower() for t in VisionLanguageConfig.ImageInputType - ], - help=('The image input type passed into vLLM. ' - 'Should be one of "pixel_values" or "image_features".')) - parser.add_argument('--image-token-id', - type=int, - default=None, - help=('Input id for image token.')) - parser.add_argument( - '--image-input-shape', - type=nullable_str, - default=None, - help=('The biggest image input shape (worst for memory footprint) ' - 'given an input type. Only used for vLLM\'s profile_run.')) - parser.add_argument( - '--image-feature-size', - type=int, - default=None, - help=('The image feature size along the context dimension.')) + parser = EngineArgs.add_cli_args_for_vlm(parser) + parser.add_argument( '--scheduler-delay-factor', type=float, @@ -488,7 +520,6 @@ def add_cli_args( default=EngineArgs.speculative_model, help= 'The name of the draft model to be used in speculative decoding.') - parser.add_argument( '--num-speculative-tokens', type=int, @@ -666,12 +697,27 @@ def create_engine_config(self, ) -> EngineConfig: raise ValueError( 'Specify `image_token_id`, `image_input_shape` and ' '`image_feature_size` together with `image_input_type`.') + + if self.image_processor is None: + self.image_processor = self.model + if self.disable_image_processor: + if self.image_processor != self.model: + warnings.warn( + "You've specified an image processor " + f"({self.image_processor}) but also disabled " + "it via `--disable-image-processor`.", + stacklevel=2) + + self.image_processor = None + vision_language_config = VisionLanguageConfig( image_input_type=VisionLanguageConfig. get_image_input_enum_type(self.image_input_type), image_token_id=self.image_token_id, image_input_shape=str_to_int_tuple(self.image_input_shape), image_feature_size=self.image_feature_size, + image_processor=self.image_processor, + image_processor_revision=self.image_processor_revision, ) else: vision_language_config = None @@ -734,3 +780,7 @@ def _engine_args_parser(): def _async_engine_args_parser(): return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(), async_args_only=True) + + +def _vlm_engine_args_parser(): + return EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser()) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index beee16d188eb5..d4a4c16f2a7d5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -14,7 +14,6 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, deprecate_kwargs @@ -164,7 +163,6 @@ def generate( prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: ... @@ -177,7 +175,6 @@ def generate( prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: ... @@ -191,7 +188,6 @@ def generate( prompt_token_ids: List[int], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: ... @@ -205,7 +201,6 @@ def generate( prompt_token_ids: List[List[int]], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: ... @@ -217,7 +212,6 @@ def generate( prompt_token_ids: Union[List[int], List[List[int]]], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: ... @@ -236,7 +230,6 @@ def generate( @deprecate_kwargs("prompts", "prompt_token_ids", - "multi_modal_data", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, additional_message="Please use the 'inputs' parameter " "instead.") @@ -249,7 +242,6 @@ def generate( prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -281,11 +273,10 @@ def generate( "LLM.generate() is only supported for generation models " "(XForCausalLM).") - if prompt_token_ids is not None or multi_modal_data is not None: + if prompt_token_ids is not None: inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, ) else: inputs = cast( @@ -314,7 +305,6 @@ def encode( prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: ... @@ -327,7 +317,6 @@ def encode( prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: ... @@ -341,7 +330,6 @@ def encode( prompt_token_ids: List[int], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: ... @@ -355,7 +343,6 @@ def encode( prompt_token_ids: List[List[int]], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: ... @@ -367,7 +354,6 @@ def encode( prompt_token_ids: Union[List[int], List[List[int]]], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: ... @@ -386,7 +372,6 @@ def encode( @deprecate_kwargs("prompts", "prompt_token_ids", - "multi_modal_data", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, additional_message="Please use the 'inputs' parameter " "instead.") @@ -399,7 +384,6 @@ def encode( prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: """Generates the completions for the input prompts. @@ -430,11 +414,10 @@ def encode( "LLM.encode() is only supported for embedding models (XModel)." ) - if prompt_token_ids is not None or multi_modal_data is not None: + if prompt_token_ids is not None: inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, ) else: inputs = cast( @@ -459,7 +442,6 @@ def _convert_v1_inputs( self, prompts: Optional[Union[str, List[str]]], prompt_token_ids: Optional[Union[List[int], List[List[int]]]], - multi_modal_data: Optional[MultiModalData], ): # skip_tokenizer_init is now checked in engine @@ -499,9 +481,6 @@ def _convert_v1_inputs( else: raise AssertionError - if multi_modal_data is not None: - item["multi_modal_data"] = multi_modal_data - inputs.append(item) return inputs diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index fbd7638097286..3332bcc578460 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -17,6 +17,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import get_dummy_image_data from vllm.sequence import SamplerOutput from .vlm_base import VisionLanguageModelBase @@ -82,6 +84,9 @@ class LlavaImageFeatureInputs(TypedDict): LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs] +@MULTIMODAL_REGISTRY.register_image_feature_input() +@MULTIMODAL_REGISTRY.register_image_pixel_input() +@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) class LlavaForConditionalGeneration(VisionLanguageModelBase): def __init__(self, @@ -131,30 +136,41 @@ def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor: return data def _parse_and_validate_image_input( - self, data: object) -> Optional[LlavaImageInputs]: + self, **kwargs: object) -> Optional[LlavaImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_features = kwargs.pop("image_features", None) + expected_input_type = self.vision_language_config.image_input_type ImageInputType = VisionLanguageConfig.ImageInputType - if data is None: - return None - if expected_input_type == ImageInputType.PIXEL_VALUES: - if not isinstance(data, torch.Tensor): - raise TypeError("Image pixel vector should be a tensor, " - f"but received type: {type(data)}") + if image_features is not None: + raise ValueError( + "Expected pixel values but got image features") + if pixel_values is None: + return None + + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values") return LlavaImagePixelInputs( type="pixel_values", - data=self._validate_image_data(data), + data=self._validate_image_data(pixel_values), ) - elif expected_input_type == ImageInputType.IMAGE_FEATURES: - if not isinstance(data, torch.Tensor): - raise TypeError("Image feature vector should be a tensor, " - f"but received type: {type(data)}") + + if expected_input_type == ImageInputType.IMAGE_FEATURES: + if pixel_values is not None: + raise ValueError( + "Expected image features but got pixel values") + if image_features is None: + return None + + if not isinstance(image_features, torch.Tensor): + raise ValueError("Incorrect type of image features") return LlavaImageFeatureInputs( type="image_features", - data=self._validate_image_data(data), + data=self._validate_image_data(image_features), ) return None @@ -201,12 +217,14 @@ def _process_image_input(self, return self.multi_modal_projector(image_features) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - image_input: Optional[torch.Tensor] = None) -> SamplerOutput: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs: object, + ) -> SamplerOutput: """Run forward pass for Llava 1.5. One key thing to understand is the `input_ids` already accounts for the @@ -227,10 +245,10 @@ def forward(self, This way, the `positions` and `attn_metadata` are consistent with the `input_ids`. - The model takes two types of image inputs: + The model takes two types of image inputs: PIXEL_VALUES and IMAGE_FEATURES. The following shows how each maps to huggingface implementation. - PIXEL_VALUES: + PIXEL_VALUES: - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353 IMAGE_FEATURES: - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430 @@ -239,14 +257,15 @@ def forward(self, Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - image_input: A batch of image inputs. - For PIXEL_VALUES, expecting [1, 3, 336, 336]. - For IMAGE_FEATURES, expecting [1, 576, 1024]. + pixel_values: For PIXEL_VALUES, expects a batch with shape + [1, 3, 336, 336]. + image_features: For IMAGE_FEATURES, expects a batch with shape + [1, 576, 1024]. """ - parsed_image_input = self._parse_and_validate_image_input(image_input) + image_input = self._parse_and_validate_image_input(**kwargs) - if parsed_image_input is not None: - vision_embeddings = self._process_image_input(parsed_image_input) + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = _merge_vision_embeddings( diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py new file mode 100644 index 0000000000000..270012e7d1c3b --- /dev/null +++ b/vllm/multimodal/__init__.py @@ -0,0 +1,7 @@ +from .base import MultiModalData, MultiModalPlugin +from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry + +__all__ = [ + "MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY", + "MultiModalRegistry" +] diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py new file mode 100644 index 0000000000000..847752449ba80 --- /dev/null +++ b/vllm/multimodal/base.py @@ -0,0 +1,126 @@ +from abc import ABC, abstractmethod +from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type, + TypeVar) + +from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.logger import init_logger + +if TYPE_CHECKING: + import torch + from torch import nn + +logger = init_logger(__name__) + + +class MultiModalData: + """ + Base class that contains multi-modal data. + + To add a new modality, add a new file under ``multimodal`` directory. + + In this new file, subclass :class:`~MultiModalData` and + :class:`~MultiModalPlugin`. + + Finally, register the new plugin to + :const:`vllm.multimodal.MULTIMODAL_REGISTRY`. + This enables models to call :meth:`MultiModalRegistry.register_input` for + the new modality. + """ + pass + + +D = TypeVar("D", bound=MultiModalData) +N = TypeVar("N", bound=Type["nn.Module"]) + +MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig], + Dict[str, "torch.Tensor"]] +"""Return a dictionary to be passed as keyword arguments to +:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers +and processors in HuggingFace Transformers.""" + + +class MultiModalPlugin(ABC, Generic[D]): + """ + Base class that defines data processing logic for a specific modality. + + In particular, we adopt a registry pattern to dispatch data processing + according to the model being used (considering that different models may + process the same data differently). This registry is in turn used by + :class:`~MultiModalRegistry` which acts at a higher level + (i.e., the modality of the data). + """ + + @classmethod + def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]: + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + return get_model_architecture(model_config)[0] + + def __init__(self) -> None: + self._input_processors: Dict[Type["nn.Module"], + MultiModalInputProcessor[D]] = {} + + @abstractmethod + def get_data_type(self) -> Type[D]: + """ + Get the modality (subclass of :class:`~MultiModalData`) served by + this plugin. + """ + raise NotImplementedError + + @abstractmethod + def _default_input_processor( + self, data: D, model_config: ModelConfig, + vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]: + """Return a dictionary to be passed as keyword arguments to + :meth:`torch.nn.Module.forward`. This is similar in concept to + tokenizers and processors in HuggingFace Transformers. + """ + raise NotImplementedError + + def register_input_processor(self, + processor: Optional[ + MultiModalInputProcessor[D]] = None): + """ + Register an input processor to a model class. + + When the model receives input data that matches the modality served by + this plugin (see :meth:`get_data_type`), the provided input processor is + applied to preprocess the data. If `None` is provided, then the default + input processor is applied instead. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._input_processors: + logger.warning( + "Model class %s already has an input processor " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._input_processors[model_cls] = processor \ + or self._default_input_processor + + return model_cls + + return wrapper + + def process_input( + self, data: D, model_config: ModelConfig, + vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]: + """ + Apply an input processor to a :class:`~MultiModalData` instance passed + to the model. + + The model is identified by ``model_config``. ``vlm_config`` is + for compatibility purposes and may be merged into ``model_config`` + in the near future. + """ + model_cls = self.get_model_cls(model_config) + + processor = self._input_processors.get(model_cls) + if processor is None: + raise KeyError(f"No input processor in {self} is registered for " + f"model class {model_cls.__name__}.") + + return processor(data, model_config, vlm_config) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py new file mode 100644 index 0000000000000..b964e9ee42624 --- /dev/null +++ b/vllm/multimodal/image.py @@ -0,0 +1,141 @@ +from typing import Dict, Tuple, Type, Union + +import torch +from PIL import Image + +from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.logger import init_logger +from vllm.sequence import SequenceData +from vllm.transformers_utils.image_processor import cached_get_image_processor + +from .base import MultiModalData, MultiModalPlugin + +logger = init_logger(__name__) + + +def _get_dummy_seq_data(seq_len: int, + vlm_config: VisionLanguageConfig) -> SequenceData: + # NOTE: We assume that token is repeated `image_feature_size` times + # and then concatenated with the text prompt + # TODO: Enable other ways of inserting the image into the prompt + + token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size + token_ids += [0] * (seq_len - vlm_config.image_feature_size) + + return SequenceData(token_ids) + + +def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor: + if vlm_config.image_processor is None: + values_dtype = torch.float16 + else: + values_dtype = torch.uint8 + + return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype) + + +def get_dummy_image_data( + seq_len: int, + model_config: ModelConfig, + vlm_config: VisionLanguageConfig, +) -> Tuple[SequenceData, MultiModalData]: + """Standard dummy data factory for image data (to be used in + :meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`).""" + seq_data = _get_dummy_seq_data(seq_len, vlm_config) + values = _get_dummy_values(vlm_config) + + config_input_type = vlm_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + + fake_mm_data: MultiModalData + if config_input_type == ImageInputType.PIXEL_VALUES: + fake_mm_data = ImagePixelData(values) + elif config_input_type == ImageInputType.IMAGE_FEATURES: + fake_mm_data = ImageFeatureData(values) + else: + raise NotImplementedError + + return seq_data, fake_mm_data + + +class ImagePixelData(MultiModalData): + """ + The pixel data of an image. Can be one of: + + - :class:``PIL.Image``: An image object. Requires that a HuggingFace + processor is available to the model. + - :class:``torch.Tensor``: The raw pixel data which is passed to the model + without additional pre-processing. + """ + + def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None: + if isinstance(image, Image.Image): + # So that this class can be created inside the Image context manager + image.load() + + self.image = image + + +class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]): + + def get_data_type(self) -> Type[ImagePixelData]: + return ImagePixelData + + def _get_hf_image_processor(self, model_config: ModelConfig, + vlm_config: VisionLanguageConfig): + if vlm_config is None or vlm_config.image_processor is None: + return None + + return cached_get_image_processor( + vlm_config.image_processor, + trust_remote_code=model_config.trust_remote_code, + revision=vlm_config.image_processor_revision, + ) + + def _default_input_processor( + self, data: ImagePixelData, model_config: ModelConfig, + vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]: + image = data.image + image_processor = self._get_hf_image_processor(model_config, + vlm_config) + + if isinstance(image, Image.Image): + if image_processor is None: + raise RuntimeError("No HuggingFace processor is available" + "to process the image object") + try: + return image_processor.preprocess(image, return_tensors="pt") \ + .to(model_config.dtype).data + except Exception: + logger.error("Failed to process image (%s)", image) + raise + elif isinstance(image, torch.Tensor): + pixel_values = image.to(model_config.dtype) + + return {"pixel_values": pixel_values} + + raise TypeError(f"Invalid image type: {type(image)}") + + +class ImageFeatureData(MultiModalData): + """ + The feature vector of an image, passed directly to the model. + + This should be the output of the vision tower. + """ + + def __init__(self, image_features: torch.Tensor) -> None: + self.image_features = image_features + + +class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]): + + def get_data_type(self) -> Type[ImageFeatureData]: + return ImageFeatureData + + def _default_input_processor( + self, data: ImageFeatureData, model_config: ModelConfig, + vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]: + image_features = data.image_features.to(model_config.dtype) + + return {"image_features": image_features} diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py new file mode 100644 index 0000000000000..4789ce5ce4cfe --- /dev/null +++ b/vllm/multimodal/registry.py @@ -0,0 +1,156 @@ +import functools +from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, + Tuple, Type, TypeVar) + +from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.logger import init_logger + +from .base import MultiModalData, MultiModalPlugin +from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData, + ImagePixelPlugin) + +if TYPE_CHECKING: + import torch + from torch import nn + + from vllm.sequence import SequenceData + +logger = init_logger(__name__) + +D = TypeVar("D", bound=MultiModalData) +N = TypeVar("N", bound=Type["nn.Module"]) + +MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig], + Dict[str, "torch.Tensor"]] +MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig], + Tuple["SequenceData", MultiModalData]] + + +class MultiModalRegistry: + """ + This registry is used by model runners to dispatch data processing + according to its modality and the target model. + """ + + DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin()) + + def __init__(self, + *, + plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS + ) -> None: + self._plugins_by_data_type = {p.get_data_type(): p for p in plugins} + self._dummy_factories_by_model_type: Dict[Type["nn.Module"], + MultiModalDummyFactory] = {} + + def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None: + data_type = plugin.get_data_type() + + if data_type in self._plugins_by_data_type: + logger.warning( + "A plugin is already registered for data type %s, " + "and will be overwritten by the new plugin %s.", data_type, + plugin) + + self._plugins_by_data_type[data_type] = plugin + + def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]): + for typ in data_type.mro(): + plugin = self._plugins_by_data_type.get(typ) + if plugin is not None: + return plugin + + msg = f"Unknown multi-modal data type: {data_type}" + raise NotImplementedError(msg) + + def register_dummy_data(self, factory: MultiModalDummyFactory): + """ + Register a dummy data factory to a model class. + + During memory profiling, the provided function is invoked to create + dummy data to be inputted into the model. The modality and shape of + the dummy data should be an upper bound of what the model would receive + at inference time. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._dummy_factories_by_model_type: + logger.warning( + "Model class %s already has dummy data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig, + vlm_config: VisionLanguageConfig): + """Create dummy data for memory profiling.""" + model_cls = MultiModalPlugin.get_model_cls(model_config) + dummy_factory = self._dummy_factories_by_model_type.get(model_cls) + if dummy_factory is None: + msg = f"No dummy data defined for model class: {model_cls}" + raise NotImplementedError(msg) + + return dummy_factory(seq_len, model_config, vlm_config) + + def register_input( + self, + data_type: Type[D], + processor: Optional[MultiModalInputProcessor[D]] = None): + """ + Register an input processor for a specific modality to a model class. + + See :meth:`MultiModalPlugin.register_input_processor` for more details. + """ + return self._get_plugin_for_data_type(data_type) \ + .register_input_processor(processor) + + def register_image_pixel_input( + self, + processor: Optional[ + MultiModalInputProcessor[ImagePixelData]] = None): + """ + Register an input processor for image pixel data to a model class. + + See :meth:`MultiModalPlugin.register_input_processor` for more details. + """ + return self.register_input(ImagePixelData, processor) + + def register_image_feature_input( + self, + processor: Optional[ + MultiModalInputProcessor[ImageFeatureData]] = None): + """ + Register an input processor for image feature data to a model class. + + See :meth:`MultiModalPlugin.register_input_processor` for more details. + """ + return self.register_input(ImageFeatureData, processor) + + def process_input(self, data: MultiModalData, model_config: ModelConfig, + vlm_config: VisionLanguageConfig): + """ + Apply an input processor to a :class:`~MultiModalData` instance passed + to the model. + + See :meth:`MultiModalPlugin.process_input` for more details. + """ + return self._get_plugin_for_data_type(type(data)) \ + .process_input(data, model_config, vlm_config) + + def create_input_processor(self, model_config: ModelConfig, + vlm_config: VisionLanguageConfig): + """ + Create an input processor (see :meth:`process_input`) for a + specific model. + """ + return functools.partial(self.process_input, + model_config=model_config, + vlm_config=vlm_config) + + +MULTIMODAL_REGISTRY = MultiModalRegistry() +"""The global :class:`~MultiModalRegistry` which is used by model runners.""" diff --git a/vllm/sequence.py b/vllm/sequence.py index ac5c234d052bd..2f27bf33b166e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,6 +5,8 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +import torch + from vllm.block import LogicalTokenBlock from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest @@ -12,8 +14,7 @@ from vllm.sampling_params import SamplingParams if TYPE_CHECKING: - import torch - + from vllm.multimodal import MultiModalData from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -398,25 +399,6 @@ class SequenceGroupState: generator: Optional = None # type: ignore -class MultiModalData: - """Multi modal request. - - Args: - type: The data type. - data: The actual data. - The required shape and semantic meaning of it depends on the vision - language config of the hosted model. - See `VisionLanguageConfig` in `config.py`. - """ - - class Type(enum.Enum): - IMAGE = enum.auto() - - def __init__(self, type: Type, data: "torch.Tensor"): - self.type = type - self.data = data - - class SequenceGroup: """A group of sequences that are generated from the same prompt. @@ -473,7 +455,7 @@ def prompt_token_ids(self) -> List[int]: return next(iter(self.seqs_dict.values())).prompt_token_ids @property - def multi_modal_data(self) -> Optional[MultiModalData]: + def multi_modal_data(self) -> Optional["MultiModalData"]: # All sequences in the group should have the same multi-modal data. # We use the multi-modal data of an arbitrary sequence. return next(iter(self.seqs_dict.values())).multi_modal_data @@ -655,7 +637,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, - multi_modal_data: Optional[MultiModalData] = None, + multi_modal_data: Optional["MultiModalData"] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, ) -> None: @@ -798,13 +780,13 @@ class SamplerOutput: outputs: List[CompletionSequenceGroupOutput] # On-device tensor containing probabilities of each token. - sampled_token_probs: Optional["torch.Tensor"] = None + sampled_token_probs: Optional[torch.Tensor] = None # On-device tensor containing the logprobs of each token. logprobs: Optional["torch.Tensor"] = None # On-device tensor containing the sampled token ids. - sampled_token_ids: Optional["torch.Tensor"] = None + sampled_token_ids: Optional[torch.Tensor] = None # Spec decode metrics populated by workers. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py new file mode 100644 index 0000000000000..3239b1d0cfa2f --- /dev/null +++ b/vllm/transformers_utils/image_processor.py @@ -0,0 +1,45 @@ +from functools import lru_cache +from typing import Optional + +from transformers import AutoImageProcessor +from transformers.image_processing_utils import BaseImageProcessor + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def get_image_processor( + processor_name: str, + *args, + trust_remote_code: bool = False, + revision: Optional[str] = None, + **kwargs, +) -> BaseImageProcessor: + """Gets an image processor for the given model name via HuggingFace.""" + try: + processor: BaseImageProcessor = AutoImageProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the image processor. If the image processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return processor + + +cached_get_image_processor = lru_cache(get_image_processor) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index bc88f2c5bed6c..eaf43247d4fc5 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Tuple +from collections import defaultdict +from typing import Dict, List, Optional, Tuple import torch from torch import nn @@ -11,6 +12,7 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad @@ -63,6 +65,16 @@ def __init__( self.block_size, ) + # Create processor for multi-modal data + if self.vision_language_config is not None: + self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ + .create_input_processor( + self.model_config, + self.vision_language_config, + ) + else: + self.multi_modal_input_processor = None + # Lazy initialization. self.model: nn.Module # Set after init_Model @@ -80,14 +92,15 @@ def load_model(self) -> None: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[ + str, torch.Tensor]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] - multi_modal_input_list: List[torch.Tensor] = [] + multi_modal_kwargs_list: Dict[str, + List[torch.Tensor]] = defaultdict(list) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -108,9 +121,17 @@ def _prepare_prompt( # is always the first token in the sequence. input_positions.extend(list(range(computed_len, seq_len))) - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) + mm_data = seq_group_metadata.multi_modal_data + if mm_data is not None: + # Process multi-modal data + if self.multi_modal_input_processor is None: + raise ValueError( + "Multi-modal inputs are only supported by " + "vision language models.") + + mm_kwargs = self.multi_modal_input_processor(mm_data) + for k, v in mm_kwargs.items(): + multi_modal_kwargs_list[k].append(v) # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] @@ -134,14 +155,10 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - if multi_modal_input_list: - assert self.vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None + multi_modal_kwargs = { + k: torch.cat(v, dim=0).to(self.device) + for k, v in multi_modal_kwargs_list.items() + } num_prompt_tokens = len(input_tokens) @@ -167,7 +184,7 @@ def _prepare_prompt( slot_mapping=slot_mapping, ) return (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_input) + multi_modal_kwargs) def _prepare_decode( self, @@ -257,8 +274,8 @@ def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Optional[torch.Tensor]]: - multi_modal_input = None + Optional[Dict[str, torch.Tensor]]]: + multi_modal_kwargs = None if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. @@ -266,7 +283,7 @@ def prepare_input_tensors( # Prepare input tensors. if is_prompt: (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_input + multi_modal_kwargs ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, @@ -307,7 +324,7 @@ def prepare_input_tensors( ) return (input_tokens, input_positions, attn_metadata, - sampling_metadata, multi_modal_input) + sampling_metadata, multi_modal_kwargs) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 0ba1200696cab..465130d10e2f9 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -90,7 +90,7 @@ def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, - Set[LoRARequest], LoRAMapping, torch.Tensor]: + Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: if self.is_driver_worker: assert seq_group_metadata_list is not None # Prepare input tensors. @@ -102,7 +102,7 @@ def prepare_input_tensors( _, lora_mapping, lora_requests, - multi_modal_input, + multi_modal_kwargs, slot_mapping, num_prefill_tokens, num_decode_tokens, @@ -117,7 +117,7 @@ def prepare_input_tensors( "input_positions": input_positions, "lora_requests": lora_requests, "lora_mapping": lora_mapping, - "multi_modal_input": multi_modal_input, + "multi_modal_kwargs": multi_modal_kwargs, "num_prefill_tokens": num_prefill_tokens, "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, @@ -132,7 +132,7 @@ def prepare_input_tensors( input_positions = metadata_dict.pop("input_positions") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") - multi_modal_input = metadata_dict.pop("multi_modal_input") + multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") if metadata_dict: attn_metadata = self.attn_backend.make_metadata( **metadata_dict) @@ -143,7 +143,7 @@ def prepare_input_tensors( prompt_lens=None) return (input_tokens, input_positions, attn_metadata, pooling_metadata, - lora_requests, lora_mapping, multi_modal_input) + lora_requests, lora_mapping, multi_modal_kwargs) def _prepare_pooling( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 47aa70dc617af..63ec22d79694f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,5 +1,6 @@ import time import warnings +from collections import defaultdict from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np @@ -18,9 +19,9 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, - SequenceGroupMetadata) +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) @@ -44,7 +45,7 @@ class ModelInput(NamedTuple): query_lens: List[int] lora_mapping: Optional[LoRAMapping] lora_requests: Set[LoRARequest] - multi_modal_input: Optional[torch.Tensor] + multi_modal_kwargs: Dict[str, torch.Tensor] slot_mapping: torch.Tensor num_prefill_tokens: int num_decode_tokens: int @@ -60,7 +61,7 @@ def empty(cls, device): query_lens=[], lora_mapping=None, lora_requests=set(), - multi_modal_input=None, + multi_modal_kwargs={}, slot_mapping=torch.empty(0, device=device), num_prefill_tokens=0, num_decode_tokens=0, @@ -122,6 +123,16 @@ def __init__( self.block_size, ) + # Create processor for multi-modal data + if self.vision_language_config is not None: + self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ + .create_input_processor( + self.model_config, + self.vision_language_config, + ) + else: + self.multi_modal_input_processor = None + # Lazy initialization self.model: nn.Module # Set after load_model # Set if the backend is flashinfer. @@ -242,7 +253,8 @@ def _prepare_model_input( context_lens: List[int] = [] query_lens: List[int] = [] block_tables: List[List[int]] = [] - multi_modal_input_list: List[torch.Tensor] = [] + multi_modal_kwargs_list: Dict[str, + List[torch.Tensor]] = defaultdict(list) decode_only = True num_prefills = 0 num_prefill_tokens = 0 @@ -417,9 +429,17 @@ def _prepare_model_input( and seq_group_metadata.sampling_params.prompt_logprobs else 1)) - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) + mm_data = seq_group_metadata.multi_modal_data + if mm_data is not None: + # Process multi-modal data + if self.multi_modal_input_processor is None: + raise ValueError( + "Multi-modal inputs are only supported by " + "vision language models.") + + mm_kwargs = self.multi_modal_input_processor(mm_data) + for k, v in mm_kwargs.items(): + multi_modal_kwargs_list[k].append(v) if _is_block_tables_empty(seq_group_metadata.block_tables): # During memory profiling, the block tables are not @@ -508,16 +528,6 @@ def _prepare_model_input( context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) - - if multi_modal_input_list: - assert self.vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None - query_lens_tensor = torch.tensor(query_lens, dtype=torch.long, device=self.device) @@ -614,6 +624,11 @@ def _prepare_model_input( else: lora_mapping = None + multi_modal_kwargs = { + k: torch.cat(v, dim=0).to(self.device) + for k, v in multi_modal_kwargs_list.items() + } + return ModelInput( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, @@ -622,7 +637,7 @@ def _prepare_model_input( query_lens=query_lens, lora_mapping=lora_mapping, lora_requests=lora_requests, - multi_modal_input=multi_modal_input, + multi_modal_kwargs=multi_modal_kwargs, slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, @@ -633,7 +648,7 @@ def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[LoRARequest], LoRAMapping, torch.Tensor]: + Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: if self.is_driver_worker: assert seq_group_metadata_list is not None # Prepare input tensors. @@ -645,7 +660,7 @@ def prepare_input_tensors( query_lens, lora_mapping, lora_requests, - multi_modal_input, + multi_modal_kwargs, slot_mapping, num_prefill_tokens, num_decode_tokens, @@ -662,7 +677,7 @@ def prepare_input_tensors( sampling_metadata.selected_token_indices, "lora_requests": lora_requests, "lora_mapping": lora_mapping, - "multi_modal_input": multi_modal_input, + "multi_modal_kwargs": multi_modal_kwargs, "num_prefill_tokens": num_prefill_tokens, "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, @@ -679,7 +694,7 @@ def prepare_input_tensors( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") - multi_modal_input = metadata_dict.pop("multi_modal_input") + multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") if metadata_dict: attn_metadata = self.attn_backend.make_metadata( **metadata_dict) @@ -694,7 +709,7 @@ def prepare_input_tensors( return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, - multi_modal_input) + multi_modal_kwargs) @torch.inference_mode() def execute_model( @@ -703,7 +718,7 @@ def execute_model( kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, - lora_requests, lora_mapping, multi_modal_input + lora_requests, lora_mapping, multi_modal_kwargs ) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: @@ -717,15 +732,14 @@ def execute_model( model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model - execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, - "kv_caches": kv_caches, - "attn_metadata": attn_metadata, - } - if self.vision_language_config: - execute_model_kwargs.update({"image_input": multi_modal_input}) - hidden_states = model_executable(**execute_model_kwargs) + + hidden_states = model_executable( + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + **multi_modal_kwargs, + ) # Compute the logits. logits = self.model.compute_logits(hidden_states, sampling_metadata) @@ -781,16 +795,24 @@ def profile_run(self) -> None: # To exercise the worst scenario for GPU memory consumption, # the number of seqs (batch_size) is chosen to maximize the number # of images processed. - if self.vision_language_config: + model_config = self.model_config + vlm_config = self.vision_language_config + + if vlm_config: max_num_seqs = min( max_num_seqs, - int(max_num_batched_tokens / - self.vision_language_config.image_feature_size)) + int(max_num_batched_tokens / vlm_config.image_feature_size)) for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - seq_data, fake_multi_modal_input = _prepare_fake_inputs( - seq_len, self.vision_language_config) + + if vlm_config is None: + seq_data = SequenceData([0] * seq_len) + dummy_multi_modal_data = None + else: + seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \ + .dummy_data_for_profiling(seq_len, model_config, vlm_config) + seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -799,7 +821,7 @@ def profile_run(self) -> None: block_tables=None, lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, - multi_modal_data=fake_multi_modal_input, + multi_modal_data=dummy_multi_modal_data, ) seqs.append(seq) @@ -1034,24 +1056,6 @@ def _get_graph_batch_size(batch_size: int) -> int: _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) -def _prepare_fake_inputs( - seq_len: int, vision_language_config: Optional[VisionLanguageConfig]): - """Prepare fake inputs for profile run.""" - if vision_language_config: - prompt_tokens = [ - vision_language_config.image_token_id - ] * vision_language_config.image_feature_size + [0] * ( - seq_len - vision_language_config.image_feature_size) - fake_image_input = MultiModalData( - type=MultiModalData.Type.IMAGE, - data=torch.zeros(vision_language_config.image_input_shape, - dtype=torch.float16)) - else: - prompt_tokens = [0] * seq_len - fake_image_input = None - return SequenceData(prompt_tokens), fake_image_input - - def _is_block_tables_empty(block_tables: Union[None, Dict]): """ Check if block_tables is None or a dictionary with all None values.