[Model] support input image embedding for minicpmv (vllm-project#9237)
whyiug authored Oct 10, 2024
1 parent 4058c2d commit 25bb619
Showing 3 changed files with 101 additions and 43 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
@@ -378,7 +378,7 @@ Text Generation
- ✅︎
* - :code:`MiniCPMV`
- MiniCPM-V
- Image\ :sup:`E+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
- ✅︎
- ✅︎
15 changes: 11 additions & 4 deletions docs/source/models/vlm.rst
@@ -57,12 +57,19 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`
print(generated_text)
# Inference with image embeddings as input with additional parameters
# Specifically, we are conducting a trial run of Qwen2VL and MiniCPM-V with the new input format, which utilizes additional parameters.
mm_data = {}

image_embeds = torch.load(...) # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)

# For Qwen2VL, image_grid_thw is needed to calculate positional encoding.
mm_data['image'] = {
"image_embeds": image_embeds,
"image_grid_thw": torch.load(...), # torch.Tensor of shape (1, 3)
}

# For MiniCPM-V, image_size_list is needed to calculate details of the sliced image.
mm_data['image'] = {
"image_embeds": image_embeds,
"image_size_list": [image.size], # list of image sizes
}
outputs = llm.generate({
"prompt": prompt,
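For reference, a minimal end-to-end sketch of the MiniCPM-V embedding path documented above. The checkpoint name, prompt string, and tensor shapes are illustrative placeholders rather than values taken from this commit; only the `image_embeds`/`image_size_list` keys and the `multi_modal_data` plumbing come from the docs change.

import torch
from vllm import LLM

llm = LLM(model="openbmb/MiniCPM-V-2_6", trust_remote_code=True)  # placeholder checkpoint
prompt = "(<image>./</image>)\nWhat is shown in this image?"       # placeholder prompt containing the image tag

# Placeholder embedding: (num_images, image_feature_size, hidden_size of the LM backbone).
image_embeds = torch.randn(1, 64, 3584)
mm_data = {
    "image": {
        "image_embeds": image_embeds,
        "image_size_list": [(448, 448)],  # original (width, height) of each image
    }
}

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": mm_data,
})
print(outputs[0].outputs[0].text)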
127 changes: 89 additions & 38 deletions vllm/model_executor/models/minicpmv.py
@@ -24,8 +24,8 @@
import math
import re
from functools import partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict, Union)

import torch
import torch.types
@@ -65,10 +65,12 @@
"llm.lm_head": "lm_head",
}

RawImageType = Union[Image.Image, torch.Tensor]


class MiniCPMVRawImageInput(TypedDict):
"""Input mapper input with auxiliary data for computing image bounds."""
image: RawImageType

# Image bounds token ids in 0-dim scalar tensor.
im_start_id: torch.Tensor
@@ -78,7 +80,8 @@ class MiniCPMVImageInput(TypedDict):


class MiniCPMVImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: List[torch.Tensor]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
@@ -101,6 +104,27 @@ class MiniCPMVImagePixelInputs(TypedDict):
"""


class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""
Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of the language model backbone.
"""

image_bounds: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(start, stop)` format.
"""


MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs]
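# Illustrative example (hypothetical shapes, not part of this diff): for two images
# whose features each occupy 64 placeholder tokens in a backbone with hidden size
# 3584, an embedding-type input would look like
#   {"type": "image_embeds",
#    "data": torch.randn(2, 64, 3584),                    # (num_images, feature_size, hidden_size)
#    "image_bounds": torch.tensor([[5, 69], [80, 144]])}  # (start, stop) token positions per image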

DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


@@ -194,22 +218,22 @@ def forward(self, x: torch.Tensor,


def _build_image_input(ctx: InputContext,
image: RawImageType) -> MiniCPMVRawImageInput:
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code)
if hasattr(tokenizer, "slice_start_id"):
return MiniCPMVRawImageInput(
image=image,
im_start_id=torch.tensor(tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id),
slice_start_id=torch.tensor(tokenizer.slice_start_id),
slice_end_id=torch.tensor(tokenizer.slice_end_id))
else:
return MiniCPMVRawImageInput(
image=image,
im_start_id=torch.tensor(tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id))


def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
@@ -280,20 +304,25 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):

pattern = "(<image>./</image>)"
images = multi_modal_data["image"]
image_tags = re.findall(pattern, prompt)

if len(image_tags) == 0:
new_token_ids = token_ids
new_prompt = prompt
else:
if isinstance(images, dict):
image_size_list = images.get("image_size_list")
images = [images.get("image_embeds")]
else:
if isinstance(images, Image.Image):
images = [images]
image_size_list = [image.size for image in images]

text_chunks = prompt.split(pattern)
new_prompt_chunks: List[str] = []
for i in range(len(image_size_list)):
new_prompt_chunks += [
text_chunks[i],
get_placeholder(image_size_list[i], i)
]
new_prompt_chunks.append(text_chunks[-1])
new_prompt = "".join(new_prompt_chunks)
@@ -323,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
if not isinstance(data, list):
raise ValueError(
"Image input must be list of MiniCPMVImageInput, got (%s)", data)

if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor):
batch_data = {
"image_embeds": data[0]['image'],
}
else:
batch_data = image_processor \
.preprocess([img["image"] for img in data], return_tensors="pt") \
.data

if len(data) > 0:
batch_data["im_start_id"] = data[0]["im_start_id"]
@@ -380,7 +415,7 @@ def __init__(
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
if hasattr(self.config, "scale_emb"):
@@ -389,7 +424,12 @@ def get_embedding(
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
else:
if image_inputs["type"] == "image_embeds":
vision_hidden_states = (image_inputs["data"].type(
vlm_embedding.dtype).to(vlm_embedding.device))
else:
vision_hidden_states = self.get_vision_hidden_states(
image_inputs)

# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
@@ -440,9 +480,23 @@ def _parse_and_validate_inputs(
self,
input_ids: torch.Tensor,
**kwargs: object,
) -> Optional[MiniCPMVImageInputs]:
pixel_values = kwargs.pop("pixel_values", [])
tgt_sizes = kwargs.pop("tgt_sizes", [])
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
image_embeds = kwargs.pop("image_embeds", None)

if image_embeds is not None:
return MiniCPMVImageEmbeddingInputs(
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=image_embeds,
type="image_embeds",
)

if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
@@ -477,19 +531,16 @@ def _parse_and_validate_inputs(
if len(pixel_values_flat) == 0:
return None

if im_start_id is None:
return None

return MiniCPMVImagePixelInputs(
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
type="pixel_values",
)

def forward(
Expand Down Expand Up @@ -610,8 +661,8 @@ def get_vision_embedding(
) -> torch.Tensor:
raise NotImplementedError

def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
raise NotImplementedError

def is_default_weight_loading(self, name: str) -> bool:
@@ -705,9 +756,9 @@ def get_vision_embedding(
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)

def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]

return self.get_vision_embedding(pixel_values)

@@ -793,9 +844,9 @@ def get_vision_embedding(
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
return vision_embedding

def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"]

device = self.vpm.embeddings.position_embedding.weight.device
@@ -909,9 +960,9 @@ def get_vision_embedding(
)
return vision_embedding

def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"]

device = self.vpm.embeddings.position_embedding.weight.device
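The `image_bounds` tensor consumed in `get_embedding` comes from the model's existing `_get_image_bounds` helper, which this diff calls but does not show. As a rough sketch of the idea only (the function and variable names below are hypothetical, not from vLLM), the `(start, stop)` pairs can be recovered from the prompt token ids by pairing each image/slice start marker with its matching end marker:

import torch

def image_bounds_from_token_ids(input_ids: torch.Tensor,
                                start_token_ids: torch.Tensor,
                                end_token_ids: torch.Tensor) -> torch.Tensor:
    # Positions of the image/slice start and end marker tokens in the prompt.
    start_pos = torch.where(torch.isin(input_ids, start_token_ids))[0]
    end_pos = torch.where(torch.isin(input_ids, end_token_ids))[0]
    if start_pos.numel() == 0:
        return torch.zeros((0, 2), dtype=torch.long)
    # Each image's embeddings are scattered into the positions strictly between a
    # start marker and its matching end marker: each row is (start, stop).
    return torch.hstack([(start_pos + 1).unsqueeze(-1), end_pos.unsqueeze(-1)])

# Example: token ids 101/102 stand in for im_start_id/im_end_id.
ids = torch.tensor([1, 101, 5, 5, 5, 102, 2])
print(image_bounds_from_token_ids(ids, torch.tensor([101]), torch.tensor([102])))
# tensor([[2, 5]])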
