From 1646ffb4d19b6777fd45ac727c7a7c323d51e7f8 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Mon, 18 Nov 2024 13:21:07 +0100
Subject: [PATCH 001/174] VLMs: `patch_size` -> `num_image_tokens` in processing (#33424)

* use num additional tokens

* fix copies + docs

* another fix copies :)

* add docs

* move order for BC
---
 docs/source/en/model_doc/blip-2.md | 4 +++
 docs/source/en/model_doc/instructblip.md | 4 +++
 docs/source/en/model_doc/instructblipvideo.md | 4 +++
 docs/source/en/model_doc/llava.md | 7 ++++++
 docs/source/en/model_doc/llava_next.md | 6 +++++
 docs/source/en/model_doc/llava_next_video.md | 6 +++++
 docs/source/en/model_doc/video_llava.md | 6 +++++
 docs/source/en/model_doc/vipllava.md | 6 +++++
 .../models/llava/processing_llava.py | 19 +++++++++++---
 .../llava_next/processing_llava_next.py | 17 ++++++++++---
 .../processing_llava_next_video.py | 25 +++++++++++++++----
 .../video_llava/processing_video_llava.py | 24 +++++++++++++++---
 tests/models/llava/test_modeling_llava.py | 2 ++
 .../llava_next/test_modeling_llava_next.py | 4 +++
 .../test_modeling_llava_next_video.py | 6 +++++
 .../video_llava/test_modeling_video_llava.py | 4 +++
 .../models/vipllava/test_modeling_vipllava.py | 2 ++
 17 files changed, 131 insertions(+), 15 deletions(-)

diff --git a/docs/source/en/model_doc/blip-2.md b/docs/source/en/model_doc/blip-2.md
index b57c69ca6b321b..4125d372d55ad5 100644
--- a/docs/source/en/model_doc/blip-2.md
+++ b/docs/source/en/model_doc/blip-2.md
@@ -40,6 +40,10 @@ The original code can be found [here](https://github.com/salesforce/LAVIS/tree/5
 - BLIP-2 can be used for conditional text generation given an image and an optional text prompt. At inference time, it's recommended to use the [`generate`] method.
 - One can use [`Blip2Processor`] to prepare images for the model, and decode the predicted tokens ID's back to text.
 
+> [!NOTE]
+> BLIP models after release v4.46 will raise warnings about adding `processor.num_query_tokens = {{num_query_tokens}}` and expanding the model embeddings layer to add the special `<image>` token. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or open a PR if it is not owned by you. Adding these attributes means that BLIP will add the number of query tokens required per image and expand the text with as many `<image>` placeholders as there will be query tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise there will be a failure when merging the embeddings.
+The attributes can be obtained from the model config as `model.config.num_query_tokens`, and the model embeddings expansion can be done by following [this link](https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042).
+
 ## Resources
 
 A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BLIP-2.
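The BLIP-2 note above asks checkpoint owners to backfill `processor.num_query_tokens` and to register the `<image>` placeholder in the embedding matrix. Below is a minimal sketch of that one-time update, assuming a BLIP-2 checkpoint such as `Salesforce/blip2-opt-2.7b` (substitute the repository you actually own); it follows the spirit of the linked gist rather than reproducing it exactly, and the `image_token_index` config attribute is an assumption about recent releases.

```python
from transformers import Blip2ForConditionalGeneration, Blip2Processor

repo = "Salesforce/blip2-opt-2.7b"  # example checkpoint; use the one you maintain
processor = Blip2Processor.from_pretrained(repo)
model = Blip2ForConditionalGeneration.from_pretrained(repo)

# Tell the processor how many query tokens the Q-Former emits per image,
# so it can expand the prompt with that many `<image>` placeholders.
processor.num_query_tokens = model.config.num_query_tokens

# Register the placeholder token and grow the embedding matrix to match.
processor.tokenizer.add_tokens(["<image>"], special_tokens=True)
model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64)
# Assumed attribute name; point the config at the new token id.
model.config.image_token_index = processor.tokenizer.convert_tokens_to_ids("<image>")

# Finally, push both artifacts back to the Hub so the warning goes away:
# processor.push_to_hub(repo); model.push_to_hub(repo)
```

InstructBLIP and InstructBLIP-Video checkpoints can be updated the same way with their respective placeholder tokens.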
diff --git a/docs/source/en/model_doc/instructblip.md b/docs/source/en/model_doc/instructblip.md
index b5fc634b621626..904a96bc786f07 100644
--- a/docs/source/en/model_doc/instructblip.md
+++ b/docs/source/en/model_doc/instructblip.md
@@ -33,6 +33,10 @@ The original code can be found [here](https://github.com/salesforce/LAVIS/tree/m
 
 InstructBLIP uses the same architecture as [BLIP-2](blip2) with a tiny but important difference: it also feeds the text prompt (instruction) to the Q-Former.
 
+> [!NOTE]
+> BLIP models after release v4.46 will raise warnings about adding `processor.num_query_tokens = {{num_query_tokens}}` and expanding the model embeddings layer to add the special `<image>` token. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or open a PR if it is not owned by you. Adding these attributes means that BLIP will add the number of query tokens required per image and expand the text with as many `<image>` placeholders as there will be query tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise there will be a failure when merging the embeddings.
+The attributes can be obtained from the model config as `model.config.num_query_tokens`, and the model embeddings expansion can be done by following [this link](https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042).
+
 ## InstructBlipConfig
 
 [[autodoc]] InstructBlipConfig
diff --git a/docs/source/en/model_doc/instructblipvideo.md b/docs/source/en/model_doc/instructblipvideo.md
index aa93feb6b6dced..8b2207ce176566 100644
--- a/docs/source/en/model_doc/instructblipvideo.md
+++ b/docs/source/en/model_doc/instructblipvideo.md
@@ -35,6 +35,10 @@ The original code can be found [here](https://github.com/salesforce/LAVIS/tree/m
 
 - The model was trained by sampling 4 frames per video, so it's recommended to sample 4 frames
 
+> [!NOTE]
+> BLIP models after release v4.46 will raise warnings about adding `processor.num_query_tokens = {{num_query_tokens}}` and expanding the model embeddings layer to add the special `<video>` token. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or open a PR if it is not owned by you. Adding these attributes means that BLIP will add the number of query tokens required per image and expand the text with as many `<video>` placeholders as there will be query tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise there will be a failure when merging the embeddings.
+The attributes can be obtained from the model config as `model.config.num_query_tokens`, and the model embeddings expansion can be done by following [this link](https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042).
+
 ## InstructBlipVideoConfig
 
 [[autodoc]] InstructBlipVideoConfig
diff --git a/docs/source/en/model_doc/llava.md b/docs/source/en/model_doc/llava.md
index 7f326bd0c006db..dec19ca5ef45db 100644
--- a/docs/source/en/model_doc/llava.md
+++ b/docs/source/en/model_doc/llava.md
@@ -40,6 +40,13 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
 
 - Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results.
 
+
+> [!NOTE]
+> LLaVA models after release v4.46 will raise warnings about adding `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or open a PR if it is not owned by you.
+Adding these attributes means that LLaVA will try to infer the number of image tokens required per image and expand the text with as many `<image>` placeholders as there will be tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated as otherwise there will be failure when merging the embeddings.
+The attributes can be obtained from model config, as `model.config.vision_config.patch_size` or `model.config.vision_feature_select_strategy`. 
The `num_additional_image_tokens` should be `1` if the vision backbone adds a CLS token or `0` if nothing extra is added to the vision patches. + + ### Single image inference For best results, we recommend users to use the processor's `apply_chat_template()` method to format your prompt correctly. For that you need to construct a conversation history, passing in a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities, as follows: diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md index b9146fbd33478a..88bd63e7101f17 100644 --- a/docs/source/en/model_doc/llava_next.md +++ b/docs/source/en/model_doc/llava_next.md @@ -53,6 +53,12 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/ +> [!NOTE] +> LLaVA models after release v4.46 will raise warnings about adding `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or open a PR if it is not owned by you. +Adding these attributes means that LLaVA will try to infer the number of image tokens required per image and expand the text with as many `` placeholders as there will be tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated as otherwise there will be failure when merging the embeddings. +The attributes can be obtained from model config, as `model.config.vision_config.patch_size` or `model.config.vision_feature_select_strategy`. The `num_additional_image_tokens` should be `1` if the vision backbone adds a CLS token or `0` if nothing extra is added to the vision patches. + + - Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the processor's `apply_chat_template` to format your prompts correctly. For that you have to construct a conversation history, passing a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities. Below is an example of how to do that and the list of formats accepted by each checkpoint. We will use [llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) and a conversation history of text and image. Each content field has to be a list of dicts, as follows: diff --git a/docs/source/en/model_doc/llava_next_video.md b/docs/source/en/model_doc/llava_next_video.md index fe905dfb7932ab..f8a149f12b6779 100644 --- a/docs/source/en/model_doc/llava_next_video.md +++ b/docs/source/en/model_doc/llava_next_video.md @@ -50,6 +50,12 @@ The original code can be found [here](https://github.com/LLaVA-VL/LLaVA-NeXT/tre +> [!NOTE] +> LLaVA models after release v4.46 will raise warnings about adding `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or open a PR if it is not owned by you. 
+Adding these attributes means that LLaVA will try to infer the number of image tokens required per image and expand the text with as many `` placeholders as there will be tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated as otherwise there will be failure when merging the embeddings. +The attributes can be obtained from model config, as `model.config.vision_config.patch_size` or `model.config.vision_feature_select_strategy`. The `num_additional_image_tokens` should be `1` if the vision backbone adds a CLS token or `0` if nothing extra is added to the vision patches. + + - Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use tokenizer's `apply_chat_template` to format your prompts correctly. Below is an example of how to do that. We will use [LLaVA-NeXT-Video-7B-hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-hf) and a conversation history of videos and images. Each content field has to be a list of dicts, as follows: diff --git a/docs/source/en/model_doc/video_llava.md b/docs/source/en/model_doc/video_llava.md index 1c4b5b4b874dd7..105307196effd0 100644 --- a/docs/source/en/model_doc/video_llava.md +++ b/docs/source/en/model_doc/video_llava.md @@ -54,6 +54,12 @@ This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanT The original code can be found [here](https://github.com/PKU-YuanGroup/Video-LLaVA). +> [!NOTE] +> LLaVA models after release v4.46 will raise warnings about adding `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or open a PR if it is not owned by you. +Adding these attributes means that LLaVA will try to infer the number of image tokens required per image and expand the text with as many `` placeholders as there will be tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated as otherwise there will be failure when merging the embeddings. +The attributes can be obtained from model config, as `model.config.vision_config.patch_size` or `model.config.vision_feature_select_strategy`. The `num_additional_image_tokens` should be `1` if the vision backbone adds a CLS token or `0` if nothing extra is added to the vision patches. + + ## Usage example ### Single Media Mode diff --git a/docs/source/en/model_doc/vipllava.md b/docs/source/en/model_doc/vipllava.md index b3e76cd292e40a..328310f3e26b77 100644 --- a/docs/source/en/model_doc/vipllava.md +++ b/docs/source/en/model_doc/vipllava.md @@ -39,6 +39,12 @@ This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada) - Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results. +> [!NOTE] +> LLaVA models after release v4.46 will raise warnings about adding `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or open a PR if it is not owned by you. 
+Adding these attributes means that LLaVA will try to infer the number of image tokens required per image and expand the text with as many `` placeholders as there will be tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated as otherwise there will be failure when merging the embeddings. +The attributes can be obtained from model config, as `model.config.vision_config.patch_size` or `model.config.vision_feature_select_strategy`. The `num_additional_image_tokens` should be `1` if the vision backbone adds a CLS token or `0` if nothing extra is added to the vision patches. + + - For better results, we recommend users to use the processor's `apply_chat_template()` method to format your prompt correctly. For that you need to construct a conversation history, passing in a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities, as follows: ```python diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 820fa581711a63..08caa3d1d8a75a 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -58,10 +58,19 @@ class LlavaProcessor(ProcessorMixin): in a chat into a tokenizable string. image_token (`str`, *optional*, defaults to `""`): Special token used to denote image location. + num_additional_image_tokens (`int`, *optional*, defaults to 0): + Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other + extra tokens appended, no need to set this arg. """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token"] + valid_kwargs = [ + "chat_template", + "patch_size", + "vision_feature_select_strategy", + "image_token", + "num_additional_image_tokens", + ] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" @@ -73,9 +82,11 @@ def __init__( vision_feature_select_strategy=None, chat_template=None, image_token="", # set the default and let users change if they have peculiar special tokens in rare cases + num_additional_image_tokens=0, **kwargs, ): self.patch_size = patch_size + self.num_additional_image_tokens = num_additional_image_tokens self.vision_feature_select_strategy = vision_feature_select_strategy self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token super().__init__(image_processor, tokenizer, chat_template=chat_template) @@ -147,9 +158,11 @@ def __call__( # Replace the image token with the expanded image token sequence pixel_values = image_inputs["pixel_values"] height, width = get_image_size(to_numpy_array(pixel_values[0])) - num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1 + num_image_tokens = (height // self.patch_size) * ( + width // self.patch_size + ) + self.num_additional_image_tokens if self.vision_feature_select_strategy == "default": - num_image_tokens -= 1 + num_image_tokens -= self.num_additional_image_tokens prompt_strings = [] for sample in text: diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index 89b885f0f1abb2..09f9e621a5873e 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ 
b/src/transformers/models/llava_next/processing_llava_next.py @@ -61,10 +61,19 @@ class LlavaNextProcessor(ProcessorMixin): in a chat into a tokenizable string. image_token (`str`, *optional*, defaults to `""`): Special token used to denote image location. + num_additional_image_tokens (`int`, *optional*, defaults to 0): + Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other + extra tokens appended, no need to set this arg. """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token"] + valid_kwargs = [ + "chat_template", + "patch_size", + "vision_feature_select_strategy", + "image_token", + "num_additional_image_tokens", + ] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" @@ -76,9 +85,11 @@ def __init__( vision_feature_select_strategy=None, chat_template=None, image_token="", # set the default and let users change if they have peculiar special tokens in rare cases + num_additional_image_tokens=0, **kwargs, ): self.patch_size = patch_size + self.num_additional_image_tokens = num_additional_image_tokens self.vision_feature_select_strategy = vision_feature_select_strategy self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token super().__init__(image_processor, tokenizer, chat_template=chat_template) @@ -155,7 +166,7 @@ def __call__( orig_height, orig_width = image_size num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) if self.vision_feature_select_strategy == "default": - num_image_tokens -= 1 + num_image_tokens -= self.num_additional_image_tokens sample = sample.replace(self.image_token, "" * num_image_tokens, 1) prompt_strings.append(sample) prompt_strings = [sample.replace("", self.image_token) for sample in prompt_strings] @@ -178,7 +189,7 @@ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int orig_height, orig_width, patches_height, patches_width, scale_height, scale_width ) # The base patch covers the entire image (+1 for the CLS) - base_features = patches_height * patches_width + 1 + base_features = patches_height * patches_width + self.num_additional_image_tokens num_image_tokens = unpadded_features + newline_features + base_features return num_image_tokens diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py index a42aafcadd64c6..db4999a2a8ae04 100644 --- a/src/transformers/models/llava_next_video/processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py @@ -58,12 +58,22 @@ class LlavaNextVideoProcessor(ProcessorMixin): Special token used to denote video location. image_token (`str`, *optional*, defaults to `""`): Special token used to denote image location. + num_additional_image_tokens (`int`, *optional*, defaults to 0): + Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other + extra tokens appended, no need to set this arg. 
""" # video and image processor share same args, but have different processing logic # only image processor config is saved in the hub attributes = ["video_processor", "image_processor", "tokenizer"] - valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token", "video_token"] + valid_kwargs = [ + "chat_template", + "patch_size", + "vision_feature_select_strategy", + "image_token", + "video_token", + "num_additional_image_tokens", + ] image_processor_class = "LlavaNextImageProcessor" video_processor_class = "LlavaNextVideoImageProcessor" tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") @@ -78,9 +88,11 @@ def __init__( vision_feature_select_strategy=None, video_token="