Skip to content

Commit

Permalink
[VLM] Abstract out multi-modal data parsing in merged processor (#11620)
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored Dec 30, 2024
1 parent b12e87f commit 8d9b672
Show file tree
Hide file tree
Showing 15 changed files with 560 additions and 312 deletions.
4 changes: 2 additions & 2 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ steps:
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'

- label: Multi-Modal Models Test (Standard) # 28min
- label: Multi-Modal Models Test (Standard) # 40min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
Expand All @@ -372,7 +372,7 @@ steps:
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model

- label: Multi-Modal Models Test (Extended) 1 # 1h16m
- label: Multi-Modal Models Test (Extended) 1 # 48m
optional: true
source_file_dependencies:
- vllm/
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs,
from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
Expand All @@ -54,7 +54,7 @@ def calculate_image_placeholder(vision_config):

def mm_input_mapper_for_glmv(
ctx: InputContext,
data: MultiModalData[object],
data: ModalityData[object],
) -> Dict:
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
Expand Down
18 changes: 11 additions & 7 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalInputsV2,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement,
MultiModalDataItems, ProcessorInputs,
PromptReplacement,
full_groupby_modality)
from vllm.sequence import IntermediateTensors

Expand Down Expand Up @@ -179,7 +181,9 @@ def _get_prompt_replacements(
assert isinstance(vision_config, PixtralVisionConfig)

def get_replacement_pixtral(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)

(
num_width_tokens,
num_height_tokens,
Expand Down Expand Up @@ -591,8 +595,8 @@ def apply(

result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)

mm_items = self._get_mm_items(mm_data)
mm_item_counts = mm_items.get_item_counts()
mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts()
mm_kwargs = result["mm_kwargs"]

# We reimplement the functionality of MLlavaProcessor from
Expand Down
19 changes: 12 additions & 7 deletions vllm/model_executor/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalInputsV2,
MultiModalKwargs, NestedTensors,
PlaceholderRange)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement,
MultiModalDataItems, ProcessorInputs,
PromptReplacement,
_BoundPromptReplacement,
_PlaceholderInfo)
from vllm.sequence import IntermediateTensors
Expand Down Expand Up @@ -381,20 +382,24 @@ def _get_prompt_replacements(
assert isinstance(bos_token_id, int)

def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)

num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=image_size.width,
height=image_size.height,
)

return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]

num_images = mm_items.get_count("image", strict=False)

return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_phi3v,
) for image_token in image_tokens[:len(mm_items.images)]
) for image_token in image_tokens[:num_images]
]

def _apply_prompt_replacements(
Expand Down
22 changes: 9 additions & 13 deletions vllm/model_executor/models/qwen2_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import cached_property
from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)

import numpy as np
import torch
Expand All @@ -38,10 +38,12 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement)
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsMultiModal, SupportsPP
Expand Down Expand Up @@ -99,15 +101,9 @@ def _get_hf_processor(
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().feature_extractor # type: ignore

def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)

return super()._get_hf_mm_data(mm_items)
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

def _call_hf_processor(
self,
Expand Down
153 changes: 72 additions & 81 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Set, Tuple, Type, TypedDict, Union)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -55,15 +54,16 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
NestedTensors, VideoItem)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement)
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_list_of

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
Expand Down Expand Up @@ -719,61 +719,81 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
data_type_key="video")


class Qwen2VLMultiModalDataItems(MultiModalDataItems):
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):

@staticmethod
def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = Qwen2VLMultiModalDataItems()

for k, v in data.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = ( # type: ignore[index]
v if (
isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment]
or is_list_of(v, list)
or isinstance(v[0], (np.ndarray, torch.Tensor))
and v[0].ndim == 4
) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = ( # type: ignore[index]
v if isinstance(v, (dict, torch.Tensor, list)) else [v]
)
else:
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
# yapf: enable
def __init__(self, data: dict, modality: str) -> None:
super().__init__(data)

return multi_data
self.modality = modality

def get_item_counts(self) -> Mapping[str, int]:
return {
m: (
len(items[f"{m}_grid_thw"]) # type: ignore
if isinstance(items, dict) else len(items))
for m, items in self.items()
}
grid_thw = data[f"{modality}_grid_thw"]
slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
self._slices = [
slice(slice_idxs[i], slice_idxs[i + 1])
for i in range(len(grid_thw))
]

def has_embedding_inputs(self) -> bool:
return any(
isinstance(items, dict) or any(
isinstance(item, torch.Tensor) for item in items)
for items in self.values())
def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r})")

def get_count(self) -> int:
return len(self.data[f"{self.modality}_grid_thw"])

class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
def get(self, index: int) -> dict[str, torch.Tensor]:
out = {}
for k, v in self.data.items():
if v != f"{self.modality}_grid_thw":
v = v[self._slices[index]]

out[k] = v

return out

def get_processor_data(self) -> Mapping[str, object]:
return {}

def get_passthrough_data(self) -> Mapping[str, object]:
return self.data


class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems):

def __init__(self, data: dict) -> None:
super().__init__(data, "image")


class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems):

def _get_mm_items(
def __init__(self, data: dict) -> None:
super().__init__(data, "video")


class Qwen2MultiModalDataParser(MultiModalDataParser):

def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return Qwen2EmbeddingItems(data, modality="image")

return super()._parse_image_data(data)

def _parse_video_data(
self,
mm_data: MultiModalDataDict,
) -> MultiModalDataItems:
return Qwen2VLMultiModalDataItems.from_dict(mm_data)
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return Qwen2EmbeddingItems(data, modality="video")

return super()._parse_video_data(data)


class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):

def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser()

def _get_hf_processor(
self,
Expand All @@ -796,35 +816,6 @@ def _get_hf_processor(

return hf_processor

def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()

for k, v in mm_items.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
if k in ("image", "video", "audio"):
if isinstance(v, dict):
# Pass through embedding inputs (dict)
passthrough_data.update(v)
elif isinstance(v, torch.Tensor) and v.ndim == 3:
# Pass through embedding inputs (single)
passthrough_data[f"{k}_embeds"] = [v]
elif (is_list_of(v, torch.Tensor) and len(v) > 0
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
elif len(v) > 0:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
processor_data[k] = v

return processor_data, passthrough_data

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
Expand Down
Loading

0 comments on commit 8d9b672

Please sign in to comment.