🚨🚨🚨 Uniformize kwargs for TrOCR Processor (#34587)
* Make kwargs uniform for TrOCR

* Add tests

* Put back current_processor

* Remove args

* Add todo comment

* Code review - breaking change
tibor-reiss authored Nov 29, 2024
1 parent 0b5b5e6 commit 89d7bf5
Showing 2 changed files with 155 additions and 11 deletions.
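
Note: with this change, `TrOCRProcessor.__call__` takes explicit `images` and `text` keyword arguments and routes the remaining kwargs per modality. Below is a minimal usage sketch, assuming a standard TrOCR checkpoint; the checkpoint name and the `padding`/`return_tensors` kwargs are illustrative and not part of this diff.

# Minimal sketch of the uniformized call; the checkpoint name is illustrative.
from PIL import Image

from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
image = Image.new("RGB", (384, 384))  # placeholder image, for illustration only

# `images` and `text` are explicit keywords; remaining kwargs are split between
# the image processor and the tokenizer by `_merge_kwargs` (see the diff below).
batch = processor(images=image, text="lower newer", padding=True, return_tensors="pt")
print(list(batch.keys()))  # per the new test: ["pixel_values", "labels"]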
37 changes: 26 additions & 11 deletions src/transformers/models/trocr/processing_trocr.py
@@ -18,8 +18,16 @@
 import warnings
 from contextlib import contextmanager
+from typing import List, Union
 
-from ...processing_utils import ProcessorMixin
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class TrOCRProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {}
 
 
 class TrOCRProcessor(ProcessorMixin):
@@ -61,7 +69,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
         self.current_processor = self.image_processor
         self._in_target_context_manager = False
 
-    def __call__(self, *args, **kwargs):
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[TrOCRProcessorKwargs],
+    ) -> BatchFeature:
         """
         When used in normal mode, this method forwards all its arguments to AutoImageProcessor's
         [`~AutoImageProcessor.__call__`] and returns its output. If used in the context
@@ -70,21 +85,21 @@ def __call__(self, *args, **kwargs):
         """
         # For backward compatibility
         if self._in_target_context_manager:
-            return self.current_processor(*args, **kwargs)
-
-        images = kwargs.pop("images", None)
-        text = kwargs.pop("text", None)
-        if len(args) > 0:
-            images = args[0]
-            args = args[1:]
+            return self.current_processor(images, **kwargs)
 
         if images is None and text is None:
             raise ValueError("You need to specify either an `images` or `text` input to process.")
 
+        output_kwargs = self._merge_kwargs(
+            TrOCRProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
         if images is not None:
-            inputs = self.image_processor(images, *args, **kwargs)
+            inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
         if text is not None:
-            encodings = self.tokenizer(text, **kwargs)
+            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
 
         if text is None:
             return inputs
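
For context on the `_merge_kwargs` call above: it merges `TrOCRProcessorKwargs._defaults`, the tokenizer's init kwargs, and the caller's kwargs, then buckets them per modality. The sketch below is a simplified, self-contained illustration of that routing idea; it is not the transformers implementation, and the kwarg names in it are examples only.

# Toy illustration of the per-modality kwargs routing done by `_merge_kwargs`.
# NOT the transformers implementation, just the routing idea in miniature.
IMAGES_KWARGS = {"do_resize", "do_normalize", "size"}  # example image kwargs
TEXT_KWARGS = {"padding", "truncation", "max_length"}  # example text kwargs
COMMON_KWARGS = {"return_tensors"}                     # fanned out to both


def merge_kwargs(defaults, **kwargs):
    """Merge defaults with caller kwargs and split them into per-component buckets."""
    merged = {**defaults, **kwargs}  # caller kwargs win over defaults
    output = {"images_kwargs": {}, "text_kwargs": {}}
    for key, value in merged.items():
        if key in IMAGES_KWARGS:
            output["images_kwargs"][key] = value
        elif key in TEXT_KWARGS:
            output["text_kwargs"][key] = value
        elif key in COMMON_KWARGS:
            output["images_kwargs"][key] = value
            output["text_kwargs"][key] = value
    return output


out = merge_kwargs({}, padding=True, do_normalize=False, return_tensors="pt")
# out["images_kwargs"] -> {"do_normalize": False, "return_tensors": "pt"}
# out["text_kwargs"]   -> {"padding": True, "return_tensors": "pt"}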
129 changes: 129 additions & 0 deletions tests/models/trocr/test_processor_trocr.py
@@ -0,0 +1,129 @@
import os
import shutil
import tempfile
import unittest

import pytest

from transformers.models.xlm_roberta.tokenization_xlm_roberta import VOCAB_FILES_NAMES
from transformers.testing_utils import (
require_sentencepiece,
require_tokenizers,
require_vision,
)
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
from transformers import TrOCRProcessor, ViTImageProcessor, XLMRobertaTokenizerFast


@require_sentencepiece
@require_tokenizers
@require_vision
class TrOCRProcessorTest(ProcessorTesterMixin, unittest.TestCase):
text_input_name = "labels"
processor_class = TrOCRProcessor

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()

vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"] # fmt: skip
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

image_processor = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
tokenizer = XLMRobertaTokenizerFast.from_pretrained("FacebookAI/xlm-roberta-base")
processor = TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer)
processor.save_pretrained(self.tmpdirname)

def tearDown(self):
shutil.rmtree(self.tmpdirname)

def get_tokenizer(self, **kwargs):
return XLMRobertaTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

def get_image_processor(self, **kwargs):
return ViTImageProcessor.from_pretrained(self.tmpdirname, **kwargs)

def test_save_load_pretrained_default(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer)

processor.save_pretrained(self.tmpdirname)
processor = TrOCRProcessor.from_pretrained(self.tmpdirname)

self.assertIsInstance(processor.tokenizer, XLMRobertaTokenizerFast)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
self.assertIsInstance(processor.image_processor, ViTImageProcessor)
self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())

def test_save_load_pretrained_additional_features(self):
processor = TrOCRProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)

processor = TrOCRProcessor.from_pretrained(
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
)

self.assertIsInstance(processor.tokenizer, XLMRobertaTokenizerFast)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())

self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
self.assertIsInstance(processor.image_processor, ViTImageProcessor)

def test_image_processor(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
image_input = self.prepare_image_inputs()

input_feat_extract = image_processor(image_input, return_tensors="np")
input_processor = processor(images=image_input, return_tensors="np")

for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

def test_tokenizer(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
input_str = "lower newer"

encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)

for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])

def test_processor_text(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)

self.assertListEqual(list(inputs.keys()), ["pixel_values", "labels"])

# test if it raises when no input is passed
with pytest.raises(ValueError):
processor()

def test_tokenizer_decode(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

decoded_processor = processor.batch_decode(predicted_ids)
decoded_tok = tokenizer.batch_decode(predicted_ids)

self.assertListEqual(decoded_tok, decoded_processor)
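
One behavioral note not covered directly by the tests above: the `_in_target_context_manager` branch keeps the deprecated target-processor flow working (the "Put back current_processor" bullet in the commit message). A hedged sketch of that legacy usage follows; it assumes the long-standing `as_target_processor` context manager, which does not appear in this diff.

# Legacy path, assuming `as_target_processor` from earlier TrOCRProcessor
# versions (not shown in this diff). Inside the context manager, __call__
# forwards directly to the tokenizer via `current_processor`.
with processor.as_target_processor():  # `processor` as in the sketch further up
    labels = processor("lower newer").input_ids  # text tokenized as labels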
