From 89d7bf584f52ff0b4f90cb0964e4c558be2c0141 Mon Sep 17 00:00:00 2001
From: Tibor Reiss <75096465+tibor-reiss@users.noreply.github.com>
Date: Fri, 29 Nov 2024 12:58:11 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=F0=9F=9A=A8=20Uniformi?=
 =?UTF-8?q?ze=20kwargs=20for=20TrOCR=20Processor=20(#34587)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Make kwargs uniform for TrOCR

* Add tests

* Put back current_processor

* Remove args

* Add todo comment

* Code review - breaking change
---
 .../models/trocr/processing_trocr.py       |  37 +++--
 tests/models/trocr/test_processor_trocr.py | 129 ++++++++++++++++++
 2 files changed, 155 insertions(+), 11 deletions(-)
 create mode 100644 tests/models/trocr/test_processor_trocr.py

diff --git a/src/transformers/models/trocr/processing_trocr.py b/src/transformers/models/trocr/processing_trocr.py
index b0d2e823fe6816..16b75b9812b482 100644
--- a/src/transformers/models/trocr/processing_trocr.py
+++ b/src/transformers/models/trocr/processing_trocr.py
@@ -18,8 +18,16 @@
 import warnings
 from contextlib import contextmanager
+from typing import List, Union
 
-from ...processing_utils import ProcessorMixin
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class TrOCRProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {}
 
 
 class TrOCRProcessor(ProcessorMixin):
     r"""
@@ -61,7 +69,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
         self.current_processor = self.image_processor
         self._in_target_context_manager = False
 
-    def __call__(self, *args, **kwargs):
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[TrOCRProcessorKwargs],
+    ) -> BatchFeature:
         """
         When used in normal mode, this method forwards all its arguments to AutoImageProcessor's
         [`~AutoImageProcessor.__call__`] and returns its output. If used in the context
@@ -70,21 +85,21 @@ def __call__(self, *args, **kwargs):
         """
         # For backward compatibility
         if self._in_target_context_manager:
-            return self.current_processor(*args, **kwargs)
-
-        images = kwargs.pop("images", None)
-        text = kwargs.pop("text", None)
-        if len(args) > 0:
-            images = args[0]
-            args = args[1:]
+            return self.current_processor(images, **kwargs)
 
         if images is None and text is None:
             raise ValueError("You need to specify either an `images` or `text` input to process.")
 
+        output_kwargs = self._merge_kwargs(
+            TrOCRProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
         if images is not None:
-            inputs = self.image_processor(images, *args, **kwargs)
+            inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
         if text is not None:
-            encodings = self.tokenizer(text, **kwargs)
+            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
 
         if text is None:
             return inputs
diff --git a/tests/models/trocr/test_processor_trocr.py b/tests/models/trocr/test_processor_trocr.py
new file mode 100644
index 00000000000000..b76af40280f2fe
--- /dev/null
+++ b/tests/models/trocr/test_processor_trocr.py
@@ -0,0 +1,129 @@
+import os
+import shutil
+import tempfile
+import unittest
+
+import pytest
+
+from transformers.models.xlm_roberta.tokenization_xlm_roberta import VOCAB_FILES_NAMES
+from transformers.testing_utils import (
+    require_sentencepiece,
+    require_tokenizers,
+    require_vision,
+)
+from transformers.utils import is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+    from transformers import TrOCRProcessor, ViTImageProcessor, XLMRobertaTokenizerFast
+
+
+@require_sentencepiece
+@require_tokenizers
+@require_vision
+class TrOCRProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    text_input_name = "labels"
+    processor_class = TrOCRProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+
+        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"]  # fmt: skip
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+        image_processor = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
+        tokenizer = XLMRobertaTokenizerFast.from_pretrained("FacebookAI/xlm-roberta-base")
+        processor = TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer)
+        processor.save_pretrained(self.tmpdirname)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return XLMRobertaTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
+    def get_image_processor(self, **kwargs):
+        return ViTImageProcessor.from_pretrained(self.tmpdirname, **kwargs)
+
+    def test_save_load_pretrained_default(self):
+        image_processor = self.get_image_processor()
+        tokenizer = self.get_tokenizer()
+        processor = TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer)
+
+        processor.save_pretrained(self.tmpdirname)
+        processor = TrOCRProcessor.from_pretrained(self.tmpdirname)
+
+        self.assertIsInstance(processor.tokenizer, XLMRobertaTokenizerFast)
+        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
+        self.assertIsInstance(processor.image_processor, ViTImageProcessor)
+        self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
+
+    def test_save_load_pretrained_additional_features(self):
+        processor = TrOCRProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
+        processor.save_pretrained(self.tmpdirname)
+        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+        image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
+
+        processor = TrOCRProcessor.from_pretrained(
+            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
+        )
+
+        self.assertIsInstance(processor.tokenizer, XLMRobertaTokenizerFast)
+        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+
+        self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
+        self.assertIsInstance(processor.image_processor, ViTImageProcessor)
+
+    def test_image_processor(self):
+        image_processor = self.get_image_processor()
+        tokenizer = self.get_tokenizer()
+        processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
+        image_input = self.prepare_image_inputs()
+
+        input_feat_extract = image_processor(image_input, return_tensors="np")
+        input_processor = processor(images=image_input, return_tensors="np")
+
+        for key in input_feat_extract.keys():
+            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+    def test_tokenizer(self):
+        image_processor = self.get_image_processor()
+        tokenizer = self.get_tokenizer()
+        processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
+        input_str = "lower newer"
+
+        encoded_processor = processor(text=input_str)
+        encoded_tok = tokenizer(input_str)
+
+        for key in encoded_tok.keys():
+            self.assertListEqual(encoded_tok[key], encoded_processor[key])
+
+    def test_processor_text(self):
+        image_processor = self.get_image_processor()
+        tokenizer = self.get_tokenizer()
+        processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        inputs = processor(text=input_str, images=image_input)
+
+        self.assertListEqual(list(inputs.keys()), ["pixel_values", "labels"])
+
+        # test if it raises when no input is passed
+        with pytest.raises(ValueError):
+            processor()
+
+    def test_tokenizer_decode(self):
+        image_processor = self.get_image_processor()
+        tokenizer = self.get_tokenizer()
+        processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
+        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
+
+        decoded_processor = processor.batch_decode(predicted_ids)
+        decoded_tok = tokenizer.batch_decode(predicted_ids)
+
+        self.assertListEqual(decoded_tok, decoded_processor)
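
Usage sketch (not part of the patch): the snippet below illustrates the uniform
kwargs API this change introduces and the breaking change it implies. It is a
minimal sketch only; the checkpoint name "microsoft/trocr-base-handwritten" and
the image path "line.png" are illustrative assumptions, not taken from the diff.

    from PIL import Image

    from transformers import TrOCRProcessor

    # Assumed checkpoint; any TrOCR checkpoint with a saved processor should work.
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    image = Image.open("line.png").convert("RGB")  # assumed local image file

    # Image-only call: kwargs are validated against TrOCRProcessorKwargs and
    # routed to the image processor via output_kwargs["images_kwargs"].
    pixel_inputs = processor(images=image, return_tensors="pt")

    # Text-only call: tokenizer options travel via output_kwargs["text_kwargs"].
    labels = processor(text="lower newer", padding="max_length", max_length=32)

    # Combined call returns pixel_values plus the tokenized text as labels,
    # matching test_processor_text above.
    batch = processor(images=image, text="lower newer", return_tensors="pt")
    print(list(batch.keys()))  # ['pixel_values', 'labels']

    # Breaking change: extra positional arguments after `images` are no longer
    # forwarded to the image processor; pass such options as keyword arguments.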