Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix Idefics3 processor - handle gracefully cases with text and no images #35363

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 43 additions & 35 deletions src/transformers/models/idefics3/processing_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,45 +283,53 @@ def __call__(
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
inputs.update(image_inputs)

if text is not None:
if n_images_in_images != n_images_in_text:
raise ValueError(
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
)

image_rows = inputs.pop("rows", [[0] * len(text)])
image_cols = inputs.pop("cols", [[0] * len(text)])

fake_image_token = self.fake_image_token.content
image_token = self.image_token.content
global_img_token = self.global_image_tag

prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = get_image_prompt_string(
n_rows,
n_cols,
image_seq_len,
image_token=image_token,
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
if text is not None:
if n_images_in_images != n_images_in_text:
raise ValueError(
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
)
image_prompt_strings.append(image_prompt_string)

split_sample = sample.split(image_token)
if len(split_sample) == 0:
raise ValueError("The image token should be present in the text.")
image_rows = inputs.pop("rows", [[0] * len(text)])
image_cols = inputs.pop("cols", [[0] * len(text)])

fake_image_token = self.fake_image_token.content
image_token = self.image_token.content
global_img_token = self.global_image_tag

prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = get_image_prompt_string(
n_rows,
n_cols,
image_seq_len,
image_token=image_token,
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
)
image_prompt_strings.append(image_prompt_string)

# Place in the image prompt strings where the image tokens are
sample = split_sample[0]
for i, image_prompt_string in enumerate(image_prompt_strings):
sample += image_prompt_string + split_sample[i + 1]
prompt_strings.append(sample)
split_sample = sample.split(image_token)
if len(split_sample) == 0:
raise ValueError("The image token should be present in the text.")

text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
# Place in the image prompt strings where the image tokens are
sample = split_sample[0]
for i, image_prompt_string in enumerate(image_prompt_strings):
sample += image_prompt_string + split_sample[i + 1]
prompt_strings.append(sample)

text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)

elif text is not None:
if any(n_images_in_text):
raise ValueError(
f"Found {sum(n_images_in_text)} {self.image_token.content} tokens in the text but no images were passed."
)
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)

return inputs
Expand Down
71 changes: 71 additions & 0 deletions tests/models/idefics3/test_processor_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,74 @@ def test_unstructured_kwargs(self):

self.assertEqual(inputs["pixel_values"].shape[3], 32)
self.assertEqual(len(inputs["input_ids"][0]), 120)

@require_torch
@require_vision
def test_text_only_inference(self):
"""Test that the processor works correctly with text-only input."""
processor = self.get_processor()

text = "This is a simple text without images."
inputs = processor(text=text)

tokenized_sentence = processor.tokenizer(text, add_special_tokens=False)
expected_input_ids = [[self.bos_token_id] + tokenized_sentence["input_ids"]]

self.assertEqual(inputs["input_ids"], expected_input_ids)
self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
self.assertTrue("pixel_values" not in inputs)
self.assertTrue("pixel_attention_mask" not in inputs)

# Test batch of texts without image tokens
texts = ["First text.", "Second piece of text."]
batch_inputs = processor(text=texts, padding=True)

tokenized_1 = processor.tokenizer(texts[0], add_special_tokens=False)
tokenized_2 = processor.tokenizer(texts[1], add_special_tokens=False)

expected_1 = [self.bos_token_id] + tokenized_1["input_ids"]
expected_2 = [self.bos_token_id] + tokenized_2["input_ids"]

# Pad the shorter sequence
pad_len = len(expected_2) - len(expected_1)
if pad_len > 0:
padded_expected_1 = [self.padding_token_id] * pad_len + expected_1
expected_attention_1 = [0] * pad_len + [1] * len(expected_1)
self.assertEqual(batch_inputs["input_ids"], [padded_expected_1, expected_2])
self.assertEqual(batch_inputs["attention_mask"], [expected_attention_1, [1] * len(expected_2)])
else:
pad_len = -pad_len
padded_expected_2 = [self.padding_token_id] * pad_len + expected_2
expected_attention_2 = [0] * pad_len + [1] * len(expected_2)
self.assertEqual(batch_inputs["input_ids"], [expected_1, padded_expected_2])
self.assertEqual(batch_inputs["attention_mask"], [[1] * len(expected_1), expected_attention_2])

@require_torch
@require_vision
def test_missing_images_error(self):
"""Test that appropriate error is raised when images are referenced but not provided."""
processor = self.get_processor()

# Test single text with image token but no image
text = "Let me show you this image: <image> What do you think?"
with self.assertRaises(ValueError) as context:
processor(text=text)
self.assertTrue("tokens in the text but no images were passed" in str(context.exception))

# Test batch with image tokens but no images
texts = [
"First text with <image> token.",
"Second text <image> with token.",
]
with self.assertRaises(ValueError) as context:
processor(text=texts)
self.assertTrue("tokens in the text but no images were passed" in str(context.exception))

# Test with None as Images
with self.assertRaises(ValueError) as context:
processor(text=text, images=None)
self.assertTrue("tokens in the text but no images were passed" in str(context.exception))

with self.assertRaises(ValueError) as context:
processor(text=texts, images=None)
self.assertTrue("tokens in the text but no images were passed" in str(context.exception))
Loading