diff --git a/router/src/config.rs b/router/src/config.rs
index 8510d3560f1..4d5fcfa0639 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -112,9 +112,20 @@ pub struct ClipVisionModel {
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
-pub struct Idefics3 {
-    pub(crate) vision_encoder_max_image_size: usize,
-    pub(crate) image_seq_len: usize,
+pub struct Idefics3 {}
+
+impl Idefics3 {
+    pub fn get_max_longest_edge(&self) -> usize {
+        364
+    }
+
+    pub fn get_number_of_features(&self) -> usize {
+        169
+    }
+
+    pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
+        1456
+    }
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
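The constants returned by the new getters follow from the Idefics3 vision configuration. A minimal sketch of the arithmetic, assuming the SigLIP encoder settings Idefics3 ships with (364-pixel tiles, 14-pixel patches, pixel-shuffle factor 2); the patch size and shuffle factor are assumptions, only 364, 169 and 1456 appear in the diff:

```python
# Sketch of where the hardcoded Idefics3 getters come from.
vision_longest_edge = 364           # get_max_longest_edge()
patch_size = 14                     # assumed SigLIP patch size
pixel_shuffle_factor = 2            # assumed pixel-shuffle factor

patches_per_side = vision_longest_edge // patch_size                   # 26
features_per_tile = patches_per_side**2 // pixel_shuffle_factor**2     # 676 // 4 = 169
resize_longest_edge = 4 * vision_longest_edge                          # 1456

assert features_per_tile == 169      # get_number_of_features()
assert resize_longest_edge == 1456   # get_max_longest_edge_for_image_resize()
```

The same relation appears as the commented-out formula `image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))` in the `vlm_causal_lm.py` changes below.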
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py
index 1b7b7983552..6040625bfb2 100644
--- a/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -753,6 +753,19 @@ def __init__(self, prefix, config, weights):
             config.pad_token_id if config.pad_token_id is not None else -1
         )
 
+    def _merge_input_ids_with_image_features(
+        self,
+        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
+        image_features: torch.Tensor,
+    ):
+        """In place merges in vision_embeddings with inputs_embeds."""
+        # mask = input_ids == self.config.image_token_index
+        mask = input_ids == self.config.image_token_id
+        # Let's pray we have enabled enough slots !
+        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -835,25 +848,22 @@ def forward(
                 all_states.append(image_hidden_states)
             image_hidden_states = torch.stack(all_states, dim=0)
 
-            # When we generate, we don't want to replace the potential image_token_id that we generated by images
-            # that simply don't exist
-            # TODO: finish implementing the image token replacement
-
-            # inputs_embeds = self.inputs_merger(
-            #     input_ids=input_ids,
-            #     inputs_embeds=inputs_embeds,
-            #     image_hidden_states=image_hidden_states,
-            # )
-
-            # import ipdb; ipdb.set_trace()
-            # num_images, _, vision_hidden_size = image_hidden_states.shape
-            # special_image_token_mask = input_ids == self.image_token_id
-            # new_inputs_embeds = inputs_embeds.clone()
-            # reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size).to(
-            #     inputs_embeds.dtype
-            # )  # cast to the dtype of the input_embeds to support quantized models
-            # new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
-            # inputs_embeds = new_inputs_embeds
+            # TODO: remove when prefill image tokens are handled correctly
+            # * for now dummy tokens are added instead of the image tokens output by the vision model
+            mask_size = (input_ids == self.config.image_token_id).sum().item()
+            unrolled_image_size = (
+                image_hidden_states.shape[1] * image_hidden_states.shape[2]
+            )
+            diff = mask_size - unrolled_image_size
+            if diff > 0:
+                print(
+                    f"Mask size {mask_size} is greater than the number of image features {unrolled_image_size}."
+                )
+
+            if mask_size == unrolled_image_size:
+                inputs_embeds = self._merge_input_ids_with_image_features(
+                    input_ids, inputs_embeds, image_hidden_states
+                )
 
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
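The added `_merge_input_ids_with_image_features` is a boolean-mask scatter: every position holding the image token id gets its embedding overwritten by one vision feature vector, and `forward` only performs the merge when the number of image tokens in the prompt matches the number of feature vectors. A minimal, self-contained sketch with dummy tensors; the token id, the shapes, and the way the feature count is computed here are illustrative, not taken from the model:

```python
import torch

image_token_id = 32001                      # hypothetical image token id
input_ids = torch.tensor([5, image_token_id, image_token_id, 9])
inputs_embeds = torch.zeros(4, 8)           # (sequence_length, hidden_size)
image_hidden_states = torch.randn(1, 2, 8)  # (num_images, features_per_image, hidden_size)

# Guard: merge only when the prompt reserved exactly as many slots
# (image tokens) as the vision tower produced feature vectors.
mask_size = (input_ids == image_token_id).sum().item()
feature_count = image_hidden_states.shape[0] * image_hidden_states.shape[1]

if mask_size == feature_count:
    mask = input_ids == image_token_id
    # Scatter the flattened vision features into the rows that hold image tokens.
    inputs_embeds[mask] = image_hidden_states.view(-1, image_hidden_states.shape[-1])
```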
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 61562f8a2da..bc1fd073113 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -23,6 +23,75 @@
 IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
 IDEFICS2_IMAGE_TOKEN = "<image>"
 
+IDEFICS3_IMAGE_TOKEN = "<image>"
+IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
+IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
+
+
+def _prompt_split_image(
+    image_seq_len,
+    image_rows,
+    image_cols,
+    fake_token_around_image,
+    image_token,
+    global_img_token,
+):
+    """Prompt with expanded image tokens for when the image is split into patches."""
+    text_split_images = ""
+    for n_h in range(image_rows):
+        for n_w in range(image_cols):
+            text_split_images += (
+                f"{fake_token_around_image}"
+                + f"<row_{n_h + 1}_col_{n_w + 1}>"
+                + f"{image_token}" * image_seq_len
+            )
+        text_split_images += "\n"
+
+    text_split_images += (
+        f"\n{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+    return text_split_images
+
+
+def _prompt_single_image(
+    image_seq_len, fake_token_around_image, image_token, global_img_token
+):
+    """Prompt with expanded image tokens for a single image."""
+    return (
+        f"{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+
+
+def get_image_prompt_string(
+    image_rows,
+    image_cols,
+    image_seq_len,
+    fake_token_around_image,
+    image_token,
+    global_img_token,
+):
+    if image_rows == 0 and image_cols == 0:
+        return _prompt_single_image(
+            image_seq_len,
+            fake_token_around_image=fake_token_around_image,
+            image_token=image_token,
+            global_img_token=global_img_token,
+        )
+    return _prompt_split_image(
+        image_seq_len,
+        image_rows,
+        image_cols,
+        fake_token_around_image,
+        image_token,
+        global_img_token,
+    )
+
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -55,8 +124,22 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
             image_str *= 5
         return image_str
     if config.model_type == "idefics3":
-        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}"
-        image_str = ""
+        # TODO: implement this in a more general way
+        n_rows = image_input["rows"][0][image_id]
+        n_cols = image_input["cols"][0][image_id]
+
+        # TODO: avoid using hardcoded values
+        image_seq_len = 169  # default value
+        # image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
+
+        image_str = get_image_prompt_string(
+            n_rows,
+            n_cols,
+            image_seq_len,
+            image_token=IDEFICS3_IMAGE_TOKEN,
+            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
+            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+        )
         return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
@@ -85,6 +168,10 @@ def image_text_replacement_fixup(config, text: str) -> str:
         return text.replace(
             f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
         )
+    if config.model_type == "idefics3":
+        return text.replace(
+            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
+        )
     return text
 
 
@@ -198,7 +285,9 @@ def batch_tokenized_inputs(
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
         if images:
-            image_inputs = processor.image_processor(images, return_tensors="pt")
+            image_inputs = processor.image_processor(
+                images, return_tensors="pt", return_row_col_info=True
+            )
         else:
             image_inputs = None
 
@@ -212,9 +301,10 @@ def batch_tokenized_inputs(
                 if chunk_type == "text":
                     full_text += chunk.text
                 elif chunk_type == "image":
-                    full_text += image_text_replacement(
+                    replacement_text = image_text_replacement(
                         processor, image_inputs, config, image_id
                     )
+                    full_text += replacement_text
                     image_id += 1
 
         full_text = image_text_replacement_fixup(config, full_text)
@@ -289,7 +379,7 @@ def __init__(
             model_id,
            revision=revision,
             trust_remote_code=trust_remote_code,
-            **processor_kwargs,
+            # **processor_kwargs,
         )
         self.batch_class = batch_class
         # import ipdb; ipdb.set_trace()
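For a quick sanity check of the expansion above: `rows == cols == 0` means the processor did not split the image and the single-image prompt is used; otherwise one `<fake_token_around_image><row_i_col_j>` block per tile is emitted, followed by the global image. A usage sketch with `image_seq_len` shrunk to 1 so the output stays readable (the value hardcoded in the diff is 169); it assumes the helper added above is importable from the server package:

```python
from text_generation_server.models.vlm_causal_lm import get_image_prompt_string

# Image split into a 2x2 grid of tiles, one image token per tile for brevity.
print(
    get_image_prompt_string(
        image_rows=2,
        image_cols=2,
        image_seq_len=1,
        fake_token_around_image="<fake_token_around_image>",
        image_token="<image>",
        global_img_token="<global-img>",
    )
)
# <fake_token_around_image><row_1_col_1><image><fake_token_around_image><row_1_col_2><image>
# <fake_token_around_image><row_2_col_1><image><fake_token_around_image><row_2_col_2><image>
# (blank line)
# <fake_token_around_image><global-img><image><fake_token_around_image>
```

The `image_text_replacement_fixup` branch then collapses the doubled `<fake_token_around_image>` that appears when two images sit next to each other in the prompt.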