Blip: get/set input embeddings correctly (#34152)
* set-get embeds

* add tests

* fix tests

* remove

* return dict True

* fix tests

* why did i remove this

* enable torchscript tests
zucchini-nlp authored Nov 1, 2024
1 parent b53e44e commit 6beb3f1
Showing 8 changed files with 288 additions and 32 deletions.
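The thrust of the change: the BLIP composite models previously returned the vision tower's patch embedding from `get_input_embeddings` and had no setter, so generic utilities that operate on token embeddings (such as `resize_token_embeddings`) targeted the wrong module. Both accessors now delegate to the underlying text model. A minimal sketch of what this enables; the checkpoint name and added-token count are illustrative, not from the diff:

```python
from transformers import BlipForConditionalGeneration

model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# get_input_embeddings now returns the text model's token embedding table
# (an nn.Embedding over the vocabulary), not the vision patch embedding.
embeds = model.get_input_embeddings()
print(type(embeds), embeds.weight.shape)

# Because get/set are wired to the text model, resizing the vocabulary
# (e.g., after adding special tokens) updates the correct module.
model.resize_token_embeddings(embeds.weight.shape[0] + 2)
print(model.get_input_embeddings().weight.shape[0])
```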
31 changes: 24 additions & 7 deletions src/transformers/models/blip/modeling_blip.py
```diff
@@ -795,6 +795,12 @@ def __init__(self, config: BlipConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.text_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.text_model.set_input_embeddings(value)
+
     @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
     def get_text_features(
         self,
@@ -1053,8 +1059,11 @@ def __init__(self, config: BlipConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.vision_model.embeddings.patch_embedding
+    def get_input_embeddings(self):
+        return self.text_decoder.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.text_decoder.set_input_embeddings(value)
 
     @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BlipForConditionalGenerationModelOutput, config_class=BlipVisionConfig)
@@ -1117,7 +1126,8 @@ def forward(
         )
 
         if not return_dict:
-            outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
+            outputs = (outputs[0], outputs[1]) if labels is not None else (outputs[0],)
+            outputs += (image_embeds, vision_outputs[0]) + vision_outputs[2:]
             return tuple(output for output in outputs if output is not None)
 
         return BlipForConditionalGenerationModelOutput(
@@ -1232,8 +1242,12 @@ def __init__(self, config: BlipConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.vision_model.embeddings.patch_embedding
+    def set_input_embeddings(self, value):
+        self.text_encoder.set_input_embeddings(value)
+
+    def get_input_embeddings(self):
+        # This will return shared embeddings if they are shared else specific to encoder.
+        return self.text_encoder.get_input_embeddings()
 
     @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
@@ -1474,8 +1488,11 @@ def __init__(self, config: BlipConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.vision_model.embeddings.patch_embedding
+    def get_input_embeddings(self):
+        return self.text_encoder.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.text_encoder.set_input_embeddings(value)
 
     @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
```
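One behavioral fix hides in the `forward` hunk above: with `return_dict=False`, `BlipForConditionalGeneration` used to index `outputs[1]` unconditionally, which is only valid when a loss is present; now the loss entry is prepended only when `labels` are passed. A hedged illustration of the two tuple layouts (the checkpoint name is an assumption, and the zero tensor is a stand-in for real preprocessed pixels):

```python
import torch
from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

pixel_values = torch.zeros(1, 3, 384, 384)  # stand-in for a preprocessed image
input_ids = processor(text="a photo of", return_tensors="pt").input_ids

# Without labels, the legacy tuple now starts directly with the decoder logits.
out = model(pixel_values=pixel_values, input_ids=input_ids, return_dict=False)
logits = out[0]

# With labels, it starts with (loss, logits), as before.
out = model(pixel_values=pixel_values, input_ids=input_ids,
            labels=input_ids, return_dict=False)
loss, logits = out[0], out[1]
```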
6 changes: 6 additions & 0 deletions src/transformers/models/blip/modeling_blip_text.py
```diff
@@ -817,6 +817,12 @@ def __init__(self, config):
         self.cls = BlipTextOnlyMLMHead(config)
         self.label_smoothing = config.label_smoothing
 
+    def get_input_embeddings(self):
+        return self.bert.get_input_embeddings()
+
+    def set_input_embeddings(self, new_embeddings):
+        self.bert.set_input_embeddings(new_embeddings)
+
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
```
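`BlipTextLMHeadModel` already exposed its output embeddings; the new pair routes the input side through `self.bert`, so the standard accessors now reach the word embedding table. A small runnable sketch (tiny config values chosen arbitrarily; the import path is the module's internal one):

```python
import torch.nn as nn
from transformers import BlipTextConfig
from transformers.models.blip.modeling_blip_text import BlipTextLMHeadModel

config = BlipTextConfig(vocab_size=32, hidden_size=16, num_hidden_layers=1,
                        num_attention_heads=2, intermediate_size=32)
model = BlipTextLMHeadModel(config)

# get_input_embeddings resolves through self.bert to the word embedding table.
old = model.get_input_embeddings()
assert old is model.bert.get_input_embeddings()

# set_input_embeddings forwards to the encoder, so a replacement table is
# visible from both the head model and the underlying BERT stack.
new = nn.Embedding(old.num_embeddings + 4, old.embedding_dim)
model.set_input_embeddings(new)
assert model.get_input_embeddings() is new
```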
26 changes: 20 additions & 6 deletions src/transformers/models/blip_2/modeling_blip_2.py
```diff
@@ -1768,11 +1768,12 @@ def forward(
                 decoder_attention_mask=decoder_attention_mask,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
+                return_dict=True,  # toggle for easier access to loss/logits below
                 labels=labels,
             )
-            loss = outputs.loss if return_dict else outputs[0]
-            logits = outputs.logits if return_dict else outputs[1]
+            loss = outputs.loss
+            logits = outputs.logits
+            outputs = outputs.to_tuple() if not return_dict else outputs
 
         if not return_dict:
             output = (logits, vision_outputs, query_outputs, outputs)
@@ -1810,6 +1811,12 @@ def __init__(self, config: Blip2Config):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
     @add_start_docstrings_to_model_forward(BLIP_2_TEXT_WITH_PROJECTION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2Config)
     def forward(
@@ -2233,11 +2240,12 @@ def forward(
                 decoder_attention_mask=decoder_attention_mask,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
+                return_dict=True,  # toggle for easier access to loss/logits below
                 labels=labels,
             )
-            loss = outputs.loss if return_dict else outputs[0]
-            logits = outputs.logits if return_dict else outputs[1]
+            loss = outputs.loss
+            logits = outputs.logits
+            outputs = outputs.to_tuple() if not return_dict else outputs
 
         if not return_dict:
             output = (logits, vision_outputs, query_outputs, outputs)
@@ -2389,6 +2397,12 @@ def __init__(self, config: Blip2Config):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
     @add_start_docstrings_to_model_forward(BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2Config)
     def forward(
```
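Note the recurring pattern in the two BLIP-2 `forward` hunks: the inner language-model call now always requests a `ModelOutput`, so `loss` and `logits` can be read by name, and the result is converted back with `to_tuple()` only when the caller asked for tuples. A stripped-down sketch of the idiom (function and argument names are hypothetical, not from the diff):

```python
def run_language_model(language_model, return_dict, **kwargs):
    # Always ask the inner model for a ModelOutput so loss/logits can be
    # read by attribute, regardless of what the outer caller requested.
    outputs = language_model(**kwargs, return_dict=True)
    loss, logits = outputs.loss, outputs.logits
    # Fall back to the legacy tuple form only for tuple-requesting callers.
    outputs = outputs.to_tuple() if not return_dict else outputs
    return loss, logits, outputs
```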
12 changes: 5 additions & 7 deletions tests/models/blip/test_modeling_blip.py
```diff
@@ -444,7 +444,7 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
-    test_resize_embeddings = False
+    test_resize_embeddings = True
     test_attention_outputs = False
 
     def setUp(self):
@@ -738,7 +738,6 @@ def prepare_config_and_inputs_for_common(self):
         config, input_ids, attention_mask, pixel_values = config_and_inputs
         inputs_dict = {
             "input_ids": input_ids,
-            "labels": input_ids,
             "attention_mask": attention_mask,
             "pixel_values": pixel_values,
         }
@@ -787,10 +786,10 @@ def prepare_config_and_inputs_for_common(self):
         config, input_ids, attention_mask, pixel_values = config_and_inputs
         inputs_dict = {
             "input_ids": input_ids,
-            "labels": input_ids,
             "decoder_input_ids": input_ids,
             "attention_mask": attention_mask,
             "pixel_values": pixel_values,
+            "labels": input_ids,
         }
         return config, inputs_dict
 
@@ -802,7 +801,7 @@ class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
-    test_resize_embeddings = False
+    test_resize_embeddings = True
     test_attention_outputs = False
     test_torchscript = False
 
@@ -811,7 +810,6 @@ def setUp(self):
 
     def _prepare_inputs_for_vqa(self):
         _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        inputs_dict["labels"] = inputs_dict["input_ids"]
         inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
         inputs_dict.pop("return_loss")
         return inputs_dict
@@ -882,7 +880,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
-    test_resize_embeddings = False
+    test_resize_embeddings = True
     test_attention_outputs = False
     test_torchscript = False
 
@@ -1110,7 +1108,7 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
-    test_resize_embeddings = False
+    test_resize_embeddings = True
     test_attention_outputs = False
     test_torchscript = False
 
```
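Flipping `test_resize_embeddings` to `True` opts these suites into `ModelTesterMixin`'s embedding-resize tests, which only make sense now that `get`/`set_input_embeddings` point at token embeddings. A condensed sketch of the invariant those tests assert (the real tests are more thorough; the helper name is hypothetical):

```python
def check_resize_roundtrip(model):
    vocab = model.get_input_embeddings().weight.shape[0]

    # Growing the vocabulary must enlarge the token embedding matrix...
    model.resize_token_embeddings(vocab + 10)
    assert model.get_input_embeddings().weight.shape[0] == vocab + 10

    # ...and shrinking must cut it back down.
    model.resize_token_embeddings(vocab - 5)
    assert model.get_input_embeddings().weight.shape[0] == vocab - 5
```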
(Diffs for the remaining 4 changed files are not shown.)
