From 6206dcb29eb99b3eebf5f00c97a5690c9b7df4f1 Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Sat, 6 Jul 2024 18:25:50 -0700
Subject: [PATCH] [Model] Add PaliGemma (#5189)

Co-authored-by: Woosuk Kwon
---
 docs/source/models/supported_models.rst |   4 +
 examples/paligemma_example.py           |  52 ++++
 tests/models/test_paligemma.py          | 150 ++++++++++
 vllm/model_executor/models/__init__.py  |   2 +
 vllm/model_executor/models/gemma.py     |  10 +-
 vllm/model_executor/models/paligemma.py | 365 ++++++++++++++++++++++++
 6 files changed, 581 insertions(+), 2 deletions(-)
 create mode 100644 examples/paligemma_example.py
 create mode 100644 tests/models/test_paligemma.py
 create mode 100644 vllm/model_executor/models/paligemma.py

diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index e64a072394680..f56679c3c6d00 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -186,6 +186,10 @@ Vision Language Models
     - LLaVA-NeXT
     - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
     -
+  * - :code:`PaliGemmaForConditionalGeneration`
+    - PaliGemma
+    - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
+    -
   * - :code:`Phi3VForCausalLM`
     - Phi-3-Vision
     - :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
diff --git a/examples/paligemma_example.py b/examples/paligemma_example.py
new file mode 100644
index 0000000000000..b315eafe5dda4
--- /dev/null
+++ b/examples/paligemma_example.py
@@ -0,0 +1,52 @@
+import os
+import subprocess
+
+from PIL import Image
+
+from vllm import LLM
+
+# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
+# You can use `.buildkite/download-images.sh` to download them
+
+
+def run_paligemma():
+    llm = LLM(model="google/paligemma-3b-mix-224")
+
+    prompt = "caption es"
+
+    image = Image.open("images/stop_sign.jpg")
+
+    outputs = llm.generate({
+        "prompt": prompt,
+        "multi_modal_data": {
+            "image": image
+        },
+    })
+
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
+
+def main():
+    run_paligemma()
+
+
+if __name__ == "__main__":
+    # Download from s3
+    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
+    local_directory = "images"
+
+    # Make sure the local directory exists or create it
+    os.makedirs(local_directory, exist_ok=True)
+
+    # Use AWS CLI to sync the directory, assume anonymous access
+    subprocess.check_call([
+        "aws",
+        "s3",
+        "sync",
+        s3_bucket_path,
+        local_directory,
+        "--no-sign-request",
+    ])
+    main()
diff --git a/tests/models/test_paligemma.py b/tests/models/test_paligemma.py
new file mode 100644
index 0000000000000..2b1d3c5b43b44
--- /dev/null
+++ b/tests/models/test_paligemma.py
@@ -0,0 +1,150 @@
+from typing import List, Optional, Tuple, Type
+
+import pytest
+from transformers import AutoTokenizer
+
+from vllm.multimodal.utils import rescale_image_size
+from vllm.sequence import SampleLogprobs
+
+from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_logprobs_close
+
+pytestmark = pytest.mark.vlm
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign": "caption es",
+    "cherry_blossom": "What is in the picture?",
+    "boardwalk": "What is in the picture?",
+})
+
+IMAGE_TOKEN_ID = 257152
+
+models = ["google/paligemma-3b-mix-224"]
+
+
+def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
+                                         Optional[SampleLogprobs]],
+                      model: str):
+    """Sanitize vllm output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    eos_token_id = tokenizer.eos_token_id
+
+    # Collapse each run of image tokens to one. `idx == 0` keeps the first
+    # token from being compared against output_ids[-1] (the last token).
+    hf_output_ids = [
+        token_id for idx, token_id in enumerate(output_ids)
+        if (token_id != IMAGE_TOKEN_ID or idx == 0
+            or output_ids[idx - 1] != IMAGE_TOKEN_ID)
+    ]
+
+    hf_output_str = output_str
+
+    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
+        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
+
+    return hf_output_ids, hf_output_str, out_logprobs
+
+
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test are under tests/images.
+    For huggingface runner, we provide the PIL images as input.
+    For vllm runner, we provide MultiModalDataDict objects
+    and corresponding vision language config as input.
+    Note, the text input is also adjusted to abide by vllm contract.
+    The text output is sanitized to be able to compare with hf.
+    """
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_image = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
+        vllm_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs_per_image
+        ]
+
+    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
+        hf_outputs_per_image = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images)
+            for prompts, images in inputs_per_image
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
+                                        vllm_outputs_per_image):
+
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                vllm_to_hf_output(vllm_output, model)
+                for vllm_output in vllm_outputs
+            ],
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
+                dtype: str, max_tokens: int, num_logprobs: int) -> None:
+    run_test(
+        hf_runner,
+        vllm_runner,
+        image_assets,
+        model,
+        size_factors=size_factors,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index a4fe18d52d608..644b95aae3656 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -49,6 +49,8 @@
     "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
+    "PaliGemmaForConditionalGeneration":
+    ("paligemma", "PaliGemmaForConditionalGeneration"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index b603a59110915..16548c6c1e8c7 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -268,16 +268,22 @@ def __init__(
         normalizer = self.config.hidden_size**0.5
         self.register_buffer("normalizer", torch.tensor(normalizer))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
         hidden_states *= self.normalizer
-
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
new file mode 100644
index 0000000000000..2af2bedd8e48e
--- /dev/null
+++ b/vllm/model_executor/models/paligemma.py
@@ -0,0 +1,365 @@
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+
+import torch
+from PIL import Image
+from torch import nn
+from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel
+
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig, MultiModalConfig
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import ColumnParallelLinear
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.gemma import GemmaModel
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import cached_get_tokenizer
+from vllm.sequence import SamplerOutput, SequenceData
+
+from .interfaces import SupportsVision
+from .utils import merge_vision_embeddings
+
+logger = init_logger(__name__)
+
+_KEYS_TO_MODIFY_MAPPING = {
+    "language_model.model": "language_model",
+}
+
+
+def get_max_paligemma_image_tokens(ctx: InputContext):
+    """Return `text_config.num_image_tokens` from the PaliGemma HF config."""
+    hf_config = ctx.get_hf_config(PaliGemmaConfig)
+    text_config = hf_config.text_config
+
+    return text_config.num_image_tokens
+
+
+def dummy_seq_data_for_paligemma(
+    hf_config: PaliGemmaConfig,
+    seq_len: int,
+    *,
+    image_token_id: int,
+    image_feature_size_override: Optional[int] = None,
+):
+    """Build dummy sequence data: image tokens followed by zero padding."""
+    if image_feature_size_override is None:
+        image_feature_size = hf_config.text_config.num_image_tokens
+    else:
+        image_feature_size = image_feature_size_override
+
+    token_ids = [image_token_id] * image_feature_size
+    token_ids += [0] * (seq_len - image_feature_size)
+    return SequenceData(token_ids)
+
+
+def dummy_image_for_paligemma(
+    hf_config: SiglipVisionConfig,
+    *,
+    image_width_override: Optional[int] = None,
+    image_height_override: Optional[int] = None,
+):
+    """Create a blank RGB image sized from the vision config or overrides."""
+    width = height = hf_config.image_size
+    if image_width_override is not None:
+        width = image_width_override
+    if image_height_override is not None:
+        height = image_height_override
+
+    image = Image.new("RGB", (width, height), color=0)
+    return {"image": image}
+
+
+def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
+    """Return dummy sequence data and a dummy image for the model."""
+    hf_config = ctx.get_hf_config(PaliGemmaConfig)
+    vision_config = hf_config.vision_config
+
+    seq_data = dummy_seq_data_for_paligemma(
+        hf_config,
+        seq_len,
+        image_token_id=hf_config.image_token_index,
+    )
+
+    mm_data = dummy_image_for_paligemma(vision_config)
+    return seq_data, mm_data
+
+
+def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
+    """Prepend the image placeholder tokens and reformat the prompt.
+
+    The correct prompt format needs to be:
+    '<image>' * image_feature_size + '<bos>' + prompt + '\n'
+
+    Inputs without image data are returned unchanged. Image tokens that
+    the caller already put in the prompt are stripped (with a warning)
+    before the placeholders are prepended.
+
+    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
+    """  # noqa
+
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
+
+    model_config = ctx.model_config
+    hf_config = ctx.get_hf_config(PaliGemmaConfig)
+
+    tokenizer = cached_get_tokenizer(model_config.tokenizer)
+    image_token_id = hf_config.image_token_index
+    image_feature_size = hf_config.text_config.num_image_tokens
+    image_token_str = tokenizer.decode(image_token_id)
+    bos_token = tokenizer.decode(hf_config.bos_token_id)
+    image_token_str_pad = image_token_str * image_feature_size
+    image_token_ids_pad = [image_token_id] * image_feature_size
+
+    # `prompt` may be missing; `.get()` then yields None.
+    orig_prompt = llm_inputs.get("prompt")
+    orig_prompt_ids = llm_inputs.get("prompt_token_ids")
+
+    prompt_has_image_token = (orig_prompt is not None
+                              and image_token_str in orig_prompt)
+    if prompt_has_image_token or image_token_id in orig_prompt_ids:
+        logger.warning(
+            "The image token '%s' was detected in the prompt and "
+            "will be removed. Please follow the proper prompt format"
+            " documented on HuggingFace.", image_token_str)
+        if prompt_has_image_token:
+            orig_prompt = orig_prompt.replace(image_token_str, "")
+        # Filter into a new list: this drops every occurrence (matching
+        # str.replace above) and leaves the caller's list untouched.
+        orig_prompt_ids = [
+            token_id for token_id in orig_prompt_ids
+            if token_id != image_token_id
+        ]
+
+    new_prompt = None
+    if orig_prompt is not None:
+        new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
+    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  #newline
+
+    # NOTE: Create a defensive copy of the original inputs
+    return LLMInputs(prompt_token_ids=new_token_ids,
+                     prompt=new_prompt,
+                     multi_modal_data=multi_modal_data)
+
+
+class PaliGemmaMultiModalProjector(nn.Module):
+    """Linear projection applied to the vision tower's image features."""
+
+    def __init__(self, vision_hidden_size: int, projection_dim: int):
+        super().__init__()
+
+        self.linear = ColumnParallelLinear(vision_hidden_size,
+                                           projection_dim,
+                                           bias=True)
+
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.linear(image_features)
+        return hidden_states
+
+
+class PaliGemmaImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
+    """Shape: (batch_size, num_channels, height, width)"""
+
+
+PaliGemmaImageInputs = PaliGemmaImagePixelInputs
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
+class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
+
+    def __init__(self,
+                 config: PaliGemmaConfig,
+                 multimodal_config: MultiModalConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
+        super().__init__()
+
+        self.config = config
+        self.multimodal_config = multimodal_config
+
+        # TODO(ywang96): Port over SiglipVisionModel & TP
+        self.vision_tower = SiglipVisionModel(config.vision_config)
+        self.multi_modal_projector = PaliGemmaMultiModalProjector(
+            vision_hidden_size=config.vision_config.hidden_size,
+            projection_dim=config.vision_config.projection_dim)
+
+        self.quant_config = quant_config
+        self.language_model = GemmaModel(config.text_config, cache_config,
+                                         quant_config)
+        self.unpadded_vocab_size = config.text_config.vocab_size
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
+        self.sampler = Sampler()
+
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+        actual_dims = tuple(data.shape[1:])
+
+        if actual_dims != expected_dims:
+            expected_expr = ("batch_size", *map(str, expected_dims))
+            raise ValueError(
+                f"The expected shape of pixel values is {expected_expr}. "
+                f"You supplied {tuple(data.shape)}.")
+
+        return data
+
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
+        pixel_values = kwargs.pop("pixel_values", None)
+
+        if pixel_values is None:
+            return None
+
+        if not isinstance(pixel_values, torch.Tensor):
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")
+
+        return PaliGemmaImagePixelInputs(
+            type="pixel_values",
+            data=self._validate_pixel_values(pixel_values),
+        )
+
+    def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
+                                  pixel_values: torch.Tensor) -> torch.Tensor:
+        """Run the vision tower and return its last hidden state."""
+        image_outputs = vision_tower(pixel_values)
+
+        selected_image_features = image_outputs.last_hidden_state
+
+        return selected_image_features
+
+    def _process_image_pixels(
+            self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor:
+        assert self.vision_tower is not None
+
+        pixel_values = inputs["data"]
+
+        return self._image_pixels_to_features(self.vision_tower, pixel_values)
+
+    def _process_image_input(
+            self, image_input: PaliGemmaImageInputs) -> torch.Tensor:
+
+        assert self.vision_tower is not None
+        image_features = self._process_image_pixels(image_input)
+
+        return self.multi_modal_projector(image_features)
+
+    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+                kv_caches: List[torch.Tensor],
+                attn_metadata: AttentionMetadata,
+                **kwargs: object) -> torch.Tensor:
+        """Run the language model, merging image embeddings if given."""
+        parsed_image_input = self._parse_and_validate_image_input(**kwargs)
+
+        if parsed_image_input is not None:
+            vision_embeddings = self._process_image_input(parsed_image_input)
+            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
+            vision_embeddings = vision_embeddings * (self.config.hidden_size**
+                                                     -0.5)
+
+            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+
+            inputs_embeds = merge_vision_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.config.image_token_index)
+
+            input_ids = None
+        else:
+            inputs_embeds = None
+
+        hidden_states = self.language_model(input_ids,
+                                            positions,
+                                            kv_caches,
+                                            attn_metadata,
+                                            inputs_embeds=inputs_embeds)
+
+        return hidden_states
+
+    # Copied from vllm/model_executor/models/gemma.py
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.language_model.embed_tokens,
+                                       hidden_states, sampling_metadata)
+        return logits
+
+    # Copied from vllm/model_executor/models/gemma.py
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    # Adapted from vllm/model_executor/models/gemma.py
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params = set()
+        for name, loaded_weight in weights:
+            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
+                if key_to_modify in name:
+                    name = name.replace(key_to_modify, new_key)
+            use_default_weight_loading = False
+            if "vision" in name:
+                if self.vision_tower is not None:
+                    # We only do sharding for language model and
+                    # not vision model for now.
+                    use_default_weight_loading = True
+            else:
+                for (param_name, shard_name,
+                     shard_id) in stacked_params_mapping:
+                    if shard_name not in name:
+                        continue
+                    name = name.replace(shard_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    # lm_head is not used in vllm as it is tied with
+                    # embed_token. To prevent errors, skip loading
+                    # lm_head.weight.
+                    if "lm_head.weight" in name:
+                        continue
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    use_default_weight_loading = True
+
+            if use_default_weight_loading:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+            loaded_params.add(name)
+
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")