diff --git a/src/transformers/models/pixtral/image_processing_pixtral_fast.py b/src/transformers/models/pixtral/image_processing_pixtral_fast.py index 82fbf3b2c09..5fa23923fe7 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral_fast.py +++ b/src/transformers/models/pixtral/image_processing_pixtral_fast.py @@ -346,4 +346,7 @@ def preprocess( batch_images.append(images) batch_image_sizes.append(image_sizes) - return BatchMixFeature(data={"pixel_values": batch_images, "image_sizes": batch_image_sizes}, tensor_type=None) + return BatchMixFeature( + data={"pixel_values": batch_images, "image_sizes": batch_image_sizes}, + tensor_type=None, + ) diff --git a/tests/models/pixtral/test_image_processing_pixtral.py b/tests/models/pixtral/test_image_processing_pixtral.py index a45ead50612..1377b676917 100644 --- a/tests/models/pixtral/test_image_processing_pixtral.py +++ b/tests/models/pixtral/test_image_processing_pixtral.py @@ -19,8 +19,15 @@ import numpy as np import requests - -from transformers.testing_utils import require_torch, require_vision +from packaging import version + +from transformers.testing_utils import ( + require_torch, + require_torch_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -157,6 +164,9 @@ def test_image_processor_properties(self): self.assertTrue(hasattr(image_processing, "image_std")) self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + # The following tests are overridden as PixtralImageProcessor can return images of different sizes + # and thus doesn't support returning batched tensors + def test_call_pil(self): for image_processing_class in self.image_processor_list: # Initialize image_processing @@ -273,6 +283,25 @@ def test_slow_fast_equivalence(self): self.assertTrue(torch.allclose(encoding_slow.pixel_values[0][0], 
encoding_fast.pixel_values[0][0], atol=1e-2)) + @slow + @require_torch_gpu + @require_vision + def test_can_compile_fast_image_processor(self): + if self.fast_image_processing_class is None: + self.skipTest("Skipping compilation test as fast image processor is not defined") + if version.parse(torch.__version__) < version.parse("2.3"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + torch.compiler.reset() + input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8) + image_processor = self.fast_image_processing_class(**self.image_processor_dict) + output_eager = image_processor(input_image, device=torch_device, return_tensors="pt") + + image_processor = torch.compile(image_processor, mode="reduce-overhead") + output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt") + + self.assertTrue(torch.allclose(output_eager.pixel_values[0][0], output_compiled.pixel_values[0][0], atol=1e-4)) + @unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy def test_call_numpy_4_channels(self): pass diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py index 221552175a9..1cb92174df1 100644 --- a/tests/test_image_processing_common.py +++ b/tests/test_image_processing_common.py @@ -23,10 +23,18 @@ import numpy as np import requests +from packaging import version from transformers import AutoImageProcessor, BatchFeature from transformers.image_utils import AnnotationFormat, AnnotionFormat -from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision +from transformers.testing_utils import ( + check_json_file_has_correct_format, + require_torch, + require_torch_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import is_torch_available, is_vision_available @@ -463,6 +471,25 @@ def test_image_processor_preprocess_arguments(self): if not is_tested: 
self.skipTest(reason="No validation found for `preprocess` method") + @slow + @require_torch_gpu + @require_vision + def test_can_compile_fast_image_processor(self): + if self.fast_image_processing_class is None: + self.skipTest("Skipping compilation test as fast image processor is not defined") + if version.parse(torch.__version__) < version.parse("2.3"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + torch.compiler.reset() + input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8) + image_processor = self.fast_image_processing_class(**self.image_processor_dict) + output_eager = image_processor(input_image, device=torch_device, return_tensors="pt") + + image_processor = torch.compile(image_processor, mode="reduce-overhead") + output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt") + + self.assertTrue(torch.allclose(output_eager.pixel_values, output_compiled.pixel_values, atol=1e-4)) + class AnnotationFormatTestMixin: # this mixin adds a test to assert that usages of the