feat: integrate image tokens into inputs embeds
drbh committed Dec 17, 2024
1 parent 305db7e commit a59b7fa
Showing 3 changed files with 138 additions and 27 deletions.
17 changes: 14 additions & 3 deletions router/src/config.rs
@@ -112,9 +112,20 @@ pub struct ClipVisionModel {
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
-pub struct Idefics3 {
-    pub(crate) vision_encoder_max_image_size: usize,
-    pub(crate) image_seq_len: usize,
+pub struct Idefics3 {}
+
+impl Idefics3 {
+    pub fn get_max_longest_edge(&self) -> usize {
+        364
+    }
+
+    pub fn get_number_of_features(&self) -> usize {
+        169
+    }
+
+    pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
+        1456
+    }
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
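Note: these getters hardcode the Idefics3 preprocessing constants that the deleted config fields used to carry: a 364-pixel longest edge per vision-encoder patch, 169 image-feature tokens per patch, and a 1456-pixel longest edge for the initial resize. A quick sanity check of how the values relate, using the formula left commented out in vlm_causal_lm.py below, and assuming the 14-pixel vision patch size and pixel-shuffle scale factor of 2 that Idefics3 models ship with (assumptions, not part of this commit):

    # Sanity check for the hardcoded constants (assumed: patch_size=14, scale_factor=2)
    image_size = 364    # get_max_longest_edge()
    patch_size = 14     # assumed vision patch size
    scale_factor = 2    # assumed pixel-shuffle scale factor
    image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
    assert image_seq_len == 169      # get_number_of_features()
    assert 4 * image_size == 1456    # get_max_longest_edge_for_image_resize()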
48 changes: 29 additions & 19 deletions server/text_generation_server/models/custom_modeling/idefics2.py
@@ -753,6 +753,19 @@ def __init__(self, prefix, config, weights):
             config.pad_token_id if config.pad_token_id is not None else -1
         )
 
+    def _merge_input_ids_with_image_features(
+        self,
+        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
+        image_features: torch.Tensor,
+    ):
+        """In-place merge of vision embeddings into inputs_embeds."""
+        # mask = input_ids == self.config.image_token_index
+        mask = input_ids == self.config.image_token_id
+        # Assumes enough slots were allocated to hold every image token.
+        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -835,25 +848,22 @@ def forward(
 
                 all_states.append(image_hidden_states)
             image_hidden_states = torch.stack(all_states, dim=0)
-            # When we generate, we don't want to replace the potential image_token_id that we generated by images
-            # that simply don't exist
-            # TODO: finish implementing the image token replacement
-
-            # inputs_embeds = self.inputs_merger(
-            #     input_ids=input_ids,
-            #     inputs_embeds=inputs_embeds,
-            #     image_hidden_states=image_hidden_states,
-            # )
-
-            # import ipdb; ipdb.set_trace()
-            # num_images, _, vision_hidden_size = image_hidden_states.shape
-            # special_image_token_mask = input_ids == self.image_token_id
-            # new_inputs_embeds = inputs_embeds.clone()
-            # reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size).to(
-            #     inputs_embeds.dtype
-            # )  # cast to the dtype of the input_embeds to support quantized models
-            # new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
-            # inputs_embeds = new_inputs_embeds
+            # TODO: remove when prefill image tokens are handled correctly;
+            # for now dummy tokens are added instead of the image tokens output by the vision model
+            mask_size = (input_ids == self.config.image_token_id).sum().item()
+            unrolled_image_size = (
+                image_hidden_states.shape[1] * image_hidden_states.shape[2]
+            )
+            diff = mask_size - unrolled_image_size
+            if diff > 0:
+                print(
+                    f"Mask size {mask_size} is greater than the unrolled image feature size {unrolled_image_size}."
+                )
+
+            if mask_size == unrolled_image_size:
+                inputs_embeds = self._merge_input_ids_with_image_features(
+                    input_ids, inputs_embeds, image_hidden_states
+                )
 
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
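Note: the merge is a boolean-mask scatter: every position whose input id equals the image token id is overwritten with one row of the flattened vision features, and the mask_size == unrolled_image_size guard skips the merge when the counts cannot line up. A minimal sketch of the mechanism on toy tensors (token id, shapes, and values are illustrative, not taken from the model):

    import torch

    IMAGE_TOKEN_ID = 5                          # illustrative placeholder id
    input_ids = torch.tensor([1, 5, 5, 2])      # two image slots
    inputs_embeds = torch.zeros(4, 8)           # text embeddings, hidden_size=8
    image_features = torch.ones(1, 2, 8)        # (num_images, seq_len, hidden)

    mask = input_ids == IMAGE_TOKEN_ID
    # view(-1, hidden) flattens to (num_images * seq_len, hidden); the assignment
    # needs exactly mask.sum() rows, hence the size check in forward().
    inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
    assert inputs_embeds[1].sum() == 8          # image rows were filled in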
100 changes: 95 additions & 5 deletions server/text_generation_server/models/vlm_causal_lm.py
@@ -23,6 +23,75 @@
 IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
 IDEFICS2_IMAGE_TOKEN = "<image>"
 
+IDEFICS3_IMAGE_TOKEN = "<image>"
+IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
+IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
+
+
+def _prompt_split_image(
+    image_seq_len,
+    image_rows,
+    image_cols,
+    fake_token_around_image,
+    image_token,
+    global_img_token,
+):
+    """Prompt with expanded image tokens for when the image is split into patches."""
+    text_split_images = ""
+    for n_h in range(image_rows):
+        for n_w in range(image_cols):
+            text_split_images += (
+                f"{fake_token_around_image}"
+                + f"<row_{n_h + 1}_col_{n_w + 1}>"
+                + f"{image_token}" * image_seq_len
+            )
+        text_split_images += "\n"
+
+    text_split_images += (
+        f"\n{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+    return text_split_images
+
+
+def _prompt_single_image(
+    image_seq_len, fake_token_around_image, image_token, global_img_token
+):
+    """Prompt with expanded image tokens for a single image."""
+    return (
+        f"{fake_token_around_image}"
+        + f"{global_img_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+
+
+def get_image_prompt_string(
+    image_rows,
+    image_cols,
+    image_seq_len,
+    fake_token_around_image,
+    image_token,
+    global_img_token,
+):
+    if image_rows == 0 and image_cols == 0:
+        return _prompt_single_image(
+            image_seq_len,
+            fake_token_around_image=fake_token_around_image,
+            image_token=image_token,
+            global_img_token=global_img_token,
+        )
+    return _prompt_split_image(
+        image_seq_len,
+        image_rows,
+        image_cols,
+        fake_token_around_image,
+        image_token,
+        global_img_token,
+    )
+
+
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -55,8 +124,22 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
         image_str *= 5
         return image_str
     if config.model_type == "idefics3":
-        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}"
+        # TODO: implement this in a more general way
+        n_rows = image_input["rows"][0][image_id]
+        n_cols = image_input["cols"][0][image_id]
+
+        # TODO: avoid using hardcoded values
+        image_seq_len = 169  # default value
+        # image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
+
+        image_str = get_image_prompt_string(
+            n_rows,
+            n_cols,
+            image_seq_len,
+            image_token=IDEFICS3_IMAGE_TOKEN,
+            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
+            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+        )
         return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
@@ -85,6 +168,10 @@ def image_text_replacement_fixup(config, text: str) -> str:
         return text.replace(
             f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
         )
+    if config.model_type == "idefics3":
+        return text.replace(
+            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
+        )
     return text
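Note: the new idefics3 branch reuses the idefics2 fixup verbatim, which works because both models use the same "<fake_token_around_image>" string. The collapse is presumably needed because adjacent image chunks each emit a leading and a trailing fake token, doubling it at the seam; for example:

    text = IDEFICS2_FAKE_TOKEN + "<image>" + IDEFICS2_FAKE_TOKEN + IDEFICS2_FAKE_TOKEN + "<image>" + IDEFICS2_FAKE_TOKEN
    fixed = text.replace(f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN)
    # fixed == "<fake_token_around_image><image><fake_token_around_image><image><fake_token_around_image>"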


@@ -198,7 +285,9 @@ def batch_tokenized_inputs(
                 raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
         if images:
-            image_inputs = processor.image_processor(images, return_tensors="pt")
+            image_inputs = processor.image_processor(
+                images, return_tensors="pt", return_row_col_info=True
+            )
         else:
             image_inputs = None
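Note: return_row_col_info=True asks the Idefics3 image processor to return, alongside the pixel values, how many patch rows and columns each image was split into; image_text_replacement above indexes the result as image_input["rows"][0][image_id]. A sketch of the structure that indexing implies (the nested-list layout is inferred from this diff, and the values are illustrative):

    image_inputs = {
        "pixel_values": ...,   # image tensors as usual
        "rows": [[4]],         # rows[batch_idx][image_id] -> patch rows for that image
        "cols": [[3]],         # cols[batch_idx][image_id] -> patch cols for that image
    }
    n_rows = image_inputs["rows"][0][0]  # 4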

@@ -212,9 +301,10 @@ def batch_tokenized_inputs(
                 if chunk_type == "text":
                     full_text += chunk.text
                 elif chunk_type == "image":
-                    full_text += image_text_replacement(
+                    replacement_text = image_text_replacement(
                         processor, image_inputs, config, image_id
                     )
+                    full_text += replacement_text
                     image_id += 1
 
         full_text = image_text_replacement_fixup(config, full_text)
@@ -289,7 +379,7 @@ def __init__(
             model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
-            **processor_kwargs,
+            # **processor_kwargs,
         )
         self.batch_class = batch_class
         # import ipdb; ipdb.set_trace()
