From 6beb3f1691ca13556c7421b2c0503068906137a2 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 1 Nov 2024 08:39:39 +0100 Subject: [PATCH] Blip: get/set input embeddings correctly (#34152) * set-get embeds * add tests * fix tests * remove * return dict True * fix tests * why did i remove this * enabel torchscript tests --- src/transformers/models/blip/modeling_blip.py | 31 ++- .../models/blip/modeling_blip_text.py | 6 + .../models/blip_2/modeling_blip_2.py | 26 +- tests/models/blip/test_modeling_blip.py | 12 +- tests/models/blip_2/test_modeling_blip_2.py | 239 +++++++++++++++++- .../test_modeling_instructblip.py | 2 +- .../test_modeling_instructblipvideo.py | 2 +- tests/test_modeling_common.py | 2 + 8 files changed, 288 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index e7df0578588653..b623d2a8adb17b 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -795,6 +795,12 @@ def __init__(self, config: BlipConfig): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING) def get_text_features( self, @@ -1053,8 +1059,11 @@ def __init__(self, config: BlipConfig): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self) -> nn.Module: - return self.vision_model.embeddings.patch_embedding + def get_input_embeddings(self): + return self.text_decoder.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_decoder.set_input_embeddings(value) @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BlipForConditionalGenerationModelOutput, config_class=BlipVisionConfig) @@ -1117,7 +1126,8 @@ def forward( ) if not return_dict: - outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:] + outputs = (outputs[0], outputs[1]) if labels is not None else (outputs[0],) + outputs += (image_embeds, vision_outputs[0]) + vision_outputs[2:] return tuple(output for output in outputs if output is not None) return BlipForConditionalGenerationModelOutput( @@ -1232,8 +1242,12 @@ def __init__(self, config: BlipConfig): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self) -> nn.Module: - return self.vision_model.embeddings.patch_embedding + def set_input_embeddings(self, value): + self.text_encoder.set_input_embeddings(value) + + def get_input_embeddings(self): + # This will return shared embeddings if they are shared else specific to encoder. 
+ return self.text_encoder.get_input_embeddings() @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig) @@ -1474,8 +1488,11 @@ def __init__(self, config: BlipConfig): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self) -> nn.Module: - return self.vision_model.embeddings.patch_embedding + def get_input_embeddings(self): + return self.text_encoder.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_encoder.set_input_embeddings(value) @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig) diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 5ee7ae21f9d549..97a4f523380bc5 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -817,6 +817,12 @@ def __init__(self, config): self.cls = BlipTextOnlyMLMHead(config) self.label_smoothing = config.label_smoothing + def get_input_embeddings(self): + return self.bert.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.bert.set_input_embeddings(new_embeddings) + def get_output_embeddings(self): return self.cls.predictions.decoder diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index eba82cd1b3c8e4..4c06d85b50df6a 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1768,11 +1768,12 @@ def forward( decoder_attention_mask=decoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, # toggle for easier access to loss/logits below labels=labels, ) - loss = outputs.loss if return_dict else outputs[0] - logits = outputs.logits if return_dict else outputs[1] + loss = outputs.loss + logits = outputs.logits + outputs = outputs.to_tuple() if not return_dict else outputs if not return_dict: output = (logits, vision_outputs, query_outputs, outputs) @@ -1810,6 +1811,12 @@ def __init__(self, config: Blip2Config): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + @add_start_docstrings_to_model_forward(BLIP_2_TEXT_WITH_PROJECTION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2Config) def forward( @@ -2233,11 +2240,12 @@ def forward( decoder_attention_mask=decoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, # toggle for easier access to loss/logits below labels=labels, ) - loss = outputs.loss if return_dict else outputs[0] - logits = outputs.logits if return_dict else outputs[1] + loss = outputs.loss + logits = outputs.logits + outputs = outputs.to_tuple() if not return_dict else outputs if not return_dict: output = (logits, vision_outputs, query_outputs, outputs) @@ -2389,6 +2397,12 @@ def __init__(self, config: Blip2Config): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def 
set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + @add_start_docstrings_to_model_forward(BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2Config) def forward( diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index d60c76393f02bb..d542757cbf879f 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -444,7 +444,7 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): fx_compatible = False test_head_masking = False test_pruning = False - test_resize_embeddings = False + test_resize_embeddings = True test_attention_outputs = False def setUp(self): @@ -738,7 +738,6 @@ def prepare_config_and_inputs_for_common(self): config, input_ids, attention_mask, pixel_values = config_and_inputs inputs_dict = { "input_ids": input_ids, - "labels": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values, } @@ -787,10 +786,10 @@ def prepare_config_and_inputs_for_common(self): config, input_ids, attention_mask, pixel_values = config_and_inputs inputs_dict = { "input_ids": input_ids, - "labels": input_ids, "decoder_input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values, + "labels": input_ids, } return config, inputs_dict @@ -802,7 +801,7 @@ class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase): fx_compatible = False test_head_masking = False test_pruning = False - test_resize_embeddings = False + test_resize_embeddings = True test_attention_outputs = False test_torchscript = False @@ -811,7 +810,6 @@ def setUp(self): def _prepare_inputs_for_vqa(self): _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - inputs_dict["labels"] = inputs_dict["input_ids"] inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"] inputs_dict.pop("return_loss") return inputs_dict @@ -882,7 +880,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): fx_compatible = False test_head_masking = False test_pruning = False - test_resize_embeddings = False + test_resize_embeddings = True test_attention_outputs = False test_torchscript = False @@ -1110,7 +1108,7 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase): fx_compatible = False test_head_masking = False test_pruning = False - test_resize_embeddings = False + test_resize_embeddings = True test_attention_outputs = False test_torchscript = False diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index f2b945ef4451e4..d91adf1bd4104f 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -15,6 +15,7 @@ """Testing suite for the PyTorch BLIP-2 model.""" import inspect +import os import tempfile import unittest @@ -32,7 +33,7 @@ slow, torch_device, ) -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torch_sdpa_available, is_vision_available from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -443,7 +444,6 @@ def prepare_config_and_inputs_for_common(self): "pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask, - "labels": input_ids, } return config, inputs_dict @@ -456,7 +456,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT 
test_pruning = False test_resize_embeddings = False test_attention_outputs = False - test_torchscript = False + test_torchscript = True _is_composite = True def setUp(self): @@ -466,6 +466,116 @@ def test_for_conditional_generation(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs) + def _create_and_check_torchscript(self, config, inputs_dict): + # overwrite because BLIP requires ipnut ids and pixel values as input + if not self.test_torchscript: + self.skipTest(reason="test_torchscript is set to `False`") + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + for model_class in self.all_model_classes: + for attn_implementation in ["eager", "sdpa"]: + if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()): + continue + + configs_no_init._attn_implementation = attn_implementation + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class) + + main_input_name = model_class.main_input_name + + try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + main_input = inputs[main_input_name] + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + decoder_input_ids = inputs["decoder_input_ids"] + decoder_attention_mask = inputs["decoder_attention_mask"] + model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) + traced_model = torch.jit.trace( + model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) + ) + else: + main_input = inputs[main_input_name] + input_ids = inputs["input_ids"] + + if model.config._attn_implementation == "sdpa": + trace_input = {main_input_name: main_input, "input_ids": input_ids} + + if "attention_mask" in inputs: + trace_input["attention_mask"] = inputs["attention_mask"] + else: + self.skipTest(reason="testing SDPA without attention_mask is not supported") + + model(main_input, attention_mask=inputs["attention_mask"]) + # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1. 
+ traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input) + else: + model(main_input, input_ids) + traced_model = torch.jit.trace(model, (main_input, input_ids)) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + if layer_name in loaded_model_state_dict: + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. 
+ # (Even with this call, there are still memory leak by ~0.04MB) + self.clear_torch_jit_class_registry() + @unittest.skip(reason="Hidden_states is tested in individual model tests") def test_hidden_states_output(self): pass @@ -754,7 +864,6 @@ def prepare_config_and_inputs_for_common(self): "attention_mask": attention_mask, "decoder_input_ids": decoder_input_ids, "decoder_attention_mask": decoder_attention_mask, - "labels": labels, } return config, inputs_dict @@ -775,9 +884,9 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi fx_compatible = False test_head_masking = False test_pruning = False - test_resize_embeddings = False + test_resize_embeddings = True test_attention_outputs = False - test_torchscript = False + test_torchscript = True _is_composite = True # TODO: Fix the failed tests @@ -804,6 +913,116 @@ def test_for_conditional_generation(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs) + def _create_and_check_torchscript(self, config, inputs_dict): + # overwrite because BLIP requires ipnut ids and pixel values as input + if not self.test_torchscript: + self.skipTest(reason="test_torchscript is set to `False`") + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + for model_class in self.all_model_classes: + for attn_implementation in ["eager", "sdpa"]: + if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()): + continue + + configs_no_init._attn_implementation = attn_implementation + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class) + + main_input_name = model_class.main_input_name + + try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + main_input = inputs[main_input_name] + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + decoder_input_ids = inputs["decoder_input_ids"] + decoder_attention_mask = inputs["decoder_attention_mask"] + model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) + traced_model = torch.jit.trace( + model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) + ) + else: + main_input = inputs[main_input_name] + input_ids = inputs["input_ids"] + + if model.config._attn_implementation == "sdpa": + trace_input = {main_input_name: main_input, "input_ids": input_ids} + + if "attention_mask" in inputs: + trace_input["attention_mask"] = inputs["attention_mask"] + else: + self.skipTest(reason="testing SDPA without attention_mask is not supported") + + model(main_input, attention_mask=inputs["attention_mask"]) + # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1. 
+ traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input) + else: + model(main_input, input_ids) + traced_model = torch.jit.trace(model, (main_input, input_ids)) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + if layer_name in loaded_model_state_dict: + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. 
+ # (Even with this call, there are still memory leak by ~0.04MB) + self.clear_torch_jit_class_registry() + @unittest.skip(reason="Hidden_states is tested in individual model tests") def test_hidden_states_output(self): pass @@ -942,7 +1161,7 @@ def test_get_text_features(self): def test_get_image_features(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"] + keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] for key in keys_to_pop: inputs_dict.pop(key) @@ -962,7 +1181,7 @@ def test_get_image_features(self): def test_get_qformer_features(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"] + keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] for key in keys_to_pop: inputs_dict.pop(key) @@ -1072,7 +1291,7 @@ class Blip2TextModelWithProjectionTest(ModelTesterMixin, unittest.TestCase): test_pruning = False test_head_masking = False - test_resize_embeddings = False + test_resize_embeddings = True test_attention_outputs = False test_torchscript = False @@ -1396,7 +1615,7 @@ class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): fx_compatible = False test_head_masking = False test_pruning = False - test_resize_embeddings = False + test_resize_embeddings = True test_attention_outputs = False test_torchscript = False diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index 2771dac1e3767e..a9dba06dab823c 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -459,7 +459,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene fx_compatible = False test_head_masking = False test_pruning = False - test_resize_embeddings = False + test_resize_embeddings = True test_attention_outputs = False test_torchscript = False _is_composite = True diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py index 298c7a8d7ff46f..ce25571d29333e 100644 --- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py @@ -479,7 +479,7 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest( fx_compatible = False test_head_masking = False test_pruning = False - test_resize_embeddings = False + test_resize_embeddings = True test_attention_outputs = False test_torchscript = False _is_composite = True diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 96d548972a91cb..13c4d5155be445 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1811,6 +1811,7 @@ def test_resize_tokens_embeddings(self): original_config, inputs_dict, ) = self.model_tester.prepare_config_and_inputs_for_common() + inputs_dict.pop("labels", None) for model_class in self.all_model_classes: config = copy.deepcopy(original_config) @@ -1988,6 +1989,7 @@ def test_resize_embeddings_untied(self): original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() original_config.tie_word_embeddings = False + inputs_dict.pop("labels", None) # if model cannot 
untie embeddings -> leave test if original_config.tie_word_embeddings:
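
For readers trying the change out, the short sketch below (not part of the patch or its test suite) illustrates the behaviour this fix enables: get_input_embeddings/set_input_embeddings on the composite BLIP models now resolve to the text embeddings, so resize_token_embeddings operates on the language-side vocabulary instead of the vision patch embedding. It assumes the public Salesforce/blip-image-captioning-base checkpoint and a transformers build that includes this commit.

from transformers import BlipForConditionalGeneration

# Load a public BLIP captioning checkpoint (downloads weights on first use).
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# After this fix, the returned module is the text model's nn.Embedding,
# not the vision model's patch embedding.
embeddings = model.get_input_embeddings()
old_vocab_size = embeddings.weight.shape[0]

# resize_token_embeddings goes through get/set_input_embeddings, so the
# text embedding matrix grows as expected.
model.resize_token_embeddings(old_vocab_size + 8)
assert model.get_input_embeddings().weight.shape[0] == old_vocab_size + 8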