diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 084607c155cb0..ec64a82de84d4 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -378,7 +378,7 @@ Text Generation
     - ✅︎
   * - :code:`MiniCPMV`
     - MiniCPM-V
-    - Image\ :sup:`+`
+    - Image\ :sup:`E+`
     - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     - ✅︎
     - ✅︎
diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst
index b2262de238660..a3ee5da044220 100644
--- a/docs/source/models/vlm.rst
+++ b/docs/source/models/vlm.rst
@@ -57,12 +57,19 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
         print(generated_text)
 
     # Inference with image embeddings as input with additional parameters
-    # Specifically, we are conducting a trial run of Qwen2VL with the new input format, as the model utilizes additional parameters for calculating positional encoding.
-    image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
-    image_grid_thw = torch.load(...) # torch.Tensor of shape (1, 3)
+    # Specifically, we are conducting a trial run of Qwen2VL and MiniCPM-V with the new input format, which utilizes additional parameters.
+    mm_data = {}
+
+    image_embeds = torch.load(...) # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
+    # For Qwen2VL, image_grid_thw is needed to calculate positional encoding.
+    mm_data['image'] = {
+        "image_embeds": image_embeds,
+        "image_grid_thw": torch.load(...), # torch.Tensor of shape (1, 3)
+    }
+    # For MiniCPM-V, image_size_list is needed to calculate details of the sliced image.
     mm_data['image'] = {
         "image_embeds": image_embeds,
-        "image_grid_thw": image_grid_thw,
+        "image_size_list": [image.size] # list of image sizes
     }
     outputs = llm.generate({
         "prompt": prompt,
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 6d0fa34f299ad..9ee4dd0f0623b 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -24,8 +24,8 @@
 import math
 import re
 from functools import partial
-from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
-                    TypedDict)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, TypedDict, Union)
 
 import torch
 import torch.types
@@ -65,10 +65,12 @@
     "llm.lm_head": "lm_head",
 }
 
+RawImageType = Union[Image.Image, torch.Tensor]
 
-class MiniCPMVImageInput(TypedDict):
+
+class MiniCPMVRawImageInput(TypedDict):
     """Input mapper input with auxiliary data for computing image bounds."""
-    image: Image.Image
+    image: RawImageType
 
     # Image bounds token ids in 0-dim scaler tensor.
     im_start_id: torch.Tensor
@@ -78,7 +80,8 @@ class MiniCPMVImageInput(TypedDict):
 
 
 class MiniCPMVImagePixelInputs(TypedDict):
-    pixel_values: List[torch.Tensor]
+    type: Literal["pixel_values"]
+    data: List[torch.Tensor]
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
 
@@ -101,6 +104,26 @@ class MiniCPMVImagePixelInputs(TypedDict):
     """
 
 
+class MiniCPMVImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of the language model backbone.
+    """
+
+    image_bounds: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, 2)`
+
+    This should be in `(start, stop)` format.
+ """ + + +MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, + MiniCPMVImageEmbeddingInputs] + DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) @@ -194,22 +218,22 @@ def forward(self, x: torch.Tensor, def _build_image_input(ctx: InputContext, - image: Image.Image) -> MiniCPMVImageInput: + image: RawImageType) -> MiniCPMVRawImageInput: tokenizer = cached_get_tokenizer( ctx.model_config.tokenizer, trust_remote_code=ctx.model_config.trust_remote_code) if hasattr(tokenizer, "slice_start_id"): - return MiniCPMVImageInput( + return MiniCPMVRawImageInput( image=image, im_start_id=torch.tensor(tokenizer.im_start_id), im_end_id=torch.tensor(tokenizer.im_end_id), slice_start_id=torch.tensor(tokenizer.slice_start_id), slice_end_id=torch.tensor(tokenizer.slice_end_id)) else: - return MiniCPMVImageInput(image=image, - im_start_id=torch.tensor( - tokenizer.im_start_id), - im_end_id=torch.tensor(tokenizer.im_end_id)) + return MiniCPMVRawImageInput( + image=image, + im_start_id=torch.tensor(tokenizer.im_start_id), + im_end_id=torch.tensor(tokenizer.im_end_id)) def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: @@ -280,20 +304,25 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): pattern = "(./)" images = multi_modal_data["image"] - if isinstance(images, Image.Image): - images = [images] image_tags = re.findall(pattern, prompt) - if len(image_tags) == 0: new_token_ids = token_ids new_prompt = prompt else: + if isinstance(images, dict): + image_size_list = images.get("image_size_list") + images = [images.get("image_embeds")] + else: + if isinstance(images, Image.Image): + images = [images] + image_size_list = [image.size for image in images] + text_chunks = prompt.split(pattern) new_prompt_chunks: List[str] = [] - for i in range(len(images)): + for i in range(len(image_size_list)): new_prompt_chunks += [ text_chunks[i], - get_placeholder(images[i].size, i) + get_placeholder(image_size_list[i], i) ] new_prompt_chunks.append(text_chunks[-1]) new_prompt = "".join(new_prompt_chunks) @@ -323,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): if not isinstance(data, list): raise ValueError( "Image input must be list of MiniCPMVImageInput, got (%s)", data) - batch_data = image_processor \ - .preprocess([img["image"] for img in data], return_tensors="pt") \ - .data + + if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor): + batch_data = { + "image_embeds": data[0]['image'], + } + else: + batch_data = image_processor \ + .preprocess([img["image"] for img in data], return_tensors="pt") \ + .data if len(data) > 0: batch_data["im_start_id"] = data[0]["im_start_id"] @@ -380,7 +415,7 @@ def __init__( def get_embedding( self, input_ids: torch.Tensor, - image_inputs: Optional[MiniCPMVImagePixelInputs], + image_inputs: Optional[MiniCPMVImageInputs], ) -> Tuple[torch.Tensor, torch.Tensor]: vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) if hasattr(self.config, "scale_emb"): @@ -389,7 +424,12 @@ def get_embedding( if image_inputs is None: # No image vision_hidden_states = torch.tensor([], device=input_ids.device) else: - vision_hidden_states = self.get_vision_hidden_states(image_inputs) + if image_inputs["type"] == "image_embeds": + vision_hidden_states = (image_inputs["data"].type( + vlm_embedding.dtype).to(vlm_embedding.device)) + else: + vision_hidden_states = self.get_vision_hidden_states( + image_inputs) # See NOTE in _parse_and_validate_inputs image_bounds = image_inputs["image_bounds"] @@ -440,9 +480,23 @@ def 
         self,
         input_ids: torch.Tensor,
         **kwargs: object,
-    ) -> Optional[MiniCPMVImagePixelInputs]:
+    ) -> Optional[MiniCPMVImageInputs]:
         pixel_values = kwargs.pop("pixel_values", [])
         tgt_sizes = kwargs.pop("tgt_sizes", [])
+        im_start_id = kwargs.pop("im_start_id", None)
+        im_end_id = kwargs.pop("im_end_id", None)
+        slice_start_id = kwargs.pop("slice_start_id", None)
+        slice_end_id = kwargs.pop("slice_end_id", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if image_embeds is not None:
+            return MiniCPMVImageEmbeddingInputs(
+                image_bounds=self._get_image_bounds(input_ids, im_start_id,
+                                                    im_end_id, slice_start_id,
+                                                    slice_end_id),
+                data=image_embeds,
+                type="image_embeds",
+            )
 
         if not isinstance(pixel_values, (torch.Tensor, list)):
             raise ValueError("Incorrect type of pixel values. "
@@ -477,10 +531,6 @@ def _parse_and_validate_inputs(
 
         if len(pixel_values_flat) == 0:
             return None
 
-        im_start_id = kwargs.pop("im_start_id", None)
-        im_end_id = kwargs.pop("im_end_id", None)
-        slice_start_id = kwargs.pop("slice_start_id", None)
-        slice_end_id = kwargs.pop("slice_end_id", None)
         if im_start_id is None:
             return None
@@ -488,8 +538,9 @@ def _parse_and_validate_inputs(
             image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                 im_end_id, slice_start_id,
                                                 slice_end_id),
-            pixel_values=pixel_values_flat,
+            data=pixel_values_flat,
             tgt_sizes=torch.stack(tgt_sizes_flat),
+            type="pixel_values",
         )
 
     def forward(
@@ -610,8 +661,8 @@ def get_vision_embedding(
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
         raise NotImplementedError
 
     def is_default_weight_loading(self, name: str) -> bool:
@@ -705,9 +756,9 @@ def get_vision_embedding(
             res.append(self.resampler(vision_embedding, tgt_size))
         return torch.vstack(res)
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
 
         return self.get_vision_embedding(pixel_values)
 
@@ -793,9 +844,9 @@ def get_vision_embedding(
         vision_embedding = self.resampler(vision_embedding, tgt_sizes)
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device
@@ -909,9 +960,9 @@ def get_vision_embedding(
         )
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device
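Taken together, these changes let MiniCPM-V accept precomputed vision features ("image_embeds") in place of raw pixels, alongside the existing Qwen2VL path. Below is a minimal end-to-end sketch of the new input format, mirroring the updated `vlm.rst` snippet. The checkpoint name, file paths, and the bare `(<image>./</image>)` placeholder prompt are illustrative assumptions (real prompts are normally built with the tokenizer's chat template), and the embeddings are assumed to have been produced offline by the matching vision encoder.

```python
# Hedged usage sketch of the embedding-input path added in this diff.
import torch
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="openbmb/MiniCPM-V-2_6", trust_remote_code=True)  # assumed checkpoint

# The raw image is only needed for its size here, not its pixels.
image = Image.open("demo.jpg")  # hypothetical path

# Precomputed features: (num_images, image_feature_size, hidden_size of LM).
image_embeds = torch.load("image_embeds.pt")  # hypothetical path

outputs = llm.generate(
    {
        "prompt": "(<image>./</image>)\nDescribe this image.",
        "multi_modal_data": {
            "image": {
                "image_embeds": image_embeds,
                # image_size_list lets input_processor_for_minicpmv compute
                # slice placeholders without access to the raw image.
                "image_size_list": [image.size],
            }
        },
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```

With a tensor under the `image` key, `input_mapper_for_minicpmv` skips the HF image processor and forwards the embeddings as-is, and `_parse_and_validate_inputs` routes them through `MiniCPMVImageEmbeddingInputs` so that `get_embedding` can scatter them directly into the token embedding sequence.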