BLIP: enable generation tests (#34174)
* blip2 tests

* instructblips

* copies

* fix slow tests

* fix

* uncomment this

* clean up after rebase

* should be model main input

* fix overwritten tests

* oops len should be multiple of frame number

* style

* fix some tests
zucchini-nlp authored Nov 1, 2024
1 parent 6beb3f1 commit 4cc0813
Showing 8 changed files with 671 additions and 96 deletions.
21 changes: 4 additions & 17 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -2342,24 +2342,11 @@ def generate(
         )
         generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

-        outputs = self.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
-
-        # this is a temporary workaround to be consistent with other generation models and
-        # have BOS as the first token, even though under the hood we are calling LM with embeds
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
         if not self.language_model.config.is_encoder_decoder:
-            bos_tokens = (
-                torch.LongTensor([[self.config.text_config.bos_token_id]])
-                .repeat(batch_size, 1)
-                .to(image_embeds.device)
-            )
-            if not isinstance(outputs, torch.Tensor):
-                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
-            else:
-                outputs = torch.cat([bos_tokens, outputs], dim=-1)
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
         return outputs


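The hunk above replaces the post-hoc BOS concatenation: `generate` now forwards `input_ids` to the language model as part of its inputs, so the returned sequences already start from the prompt ids. A minimal usage sketch follows; the checkpoint, image URL, and prompt are illustrative assumptions, not part of the commit:

```python
# Illustrative only: checkpoint, URL and prompt are assumptions, not taken from the commit.
import requests
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# generate() builds the vision and query-token embeddings internally; after this
# commit it also forwards the tokenized prompt as input_ids to the language model.
inputs = processor(images=image, text="Question: how many cats are there? Answer:", return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```

Decoded outputs should still begin with the prompt text, as before, since the prompt ids are now carried through `generate` itself rather than stitched on afterwards.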
25 changes: 4 additions & 21 deletions src/transformers/models/instructblip/modeling_instructblip.py
@@ -1625,27 +1625,10 @@ def generate(
         )
         generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

-        outputs = self.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
-
-        # this is a temporary workaround to be consistent with other generation models and
-        # have BOS as the first token, even though under the hood we are calling LM with embeds
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
         if not self.language_model.config.is_encoder_decoder:
-            # the InstructBLIP authors used inconsistent tokenizer/model files during training,
-            # with the tokenizer's bos token being set to </s> which has ID=2,
-            # whereas the model's text config has bos token id = 0
-            bos_token_id = (
-                2
-                if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
-                else self.config.text_config.bos_token_id
-            )
-            bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
-            if not isinstance(outputs, torch.Tensor):
-                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
-            else:
-                outputs = torch.cat([bos_tokens, outputs], dim=-1)
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)

        return outputs
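InstructBLIP additionally special-cased the Vicuna-based checkpoints, whose tokenizer BOS id (2) disagrees with the text config's (0); with `input_ids` from the processor forwarded directly, that bookkeeping is no longer needed. A minimal usage sketch, with an assumed checkpoint and prompt that are not taken from the commit:

```python
# Illustrative only: checkpoint, URL and prompt are assumptions, not taken from the commit.
import requests
from PIL import Image
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor

processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "What is unusual about this image?"

inputs = processor(images=image, text=prompt, return_tensors="pt")
# generate() now hands the processor's input_ids to the language model,
# so no BOS token has to be prepended to the returned sequences by hand.
output_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```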
@@ -1660,27 +1660,10 @@ def generate(
         )
         generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

-        outputs = self.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
-
-        # this is a temporary workaround to be consistent with other generation models and
-        # have BOS as the first token, even though under the hood we are calling LM with embeds
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
         if not self.language_model.config.is_encoder_decoder:
-            # the InstructBLIP authors used inconsistent tokenizer/model files during training,
-            # with the tokenizer's bos token being set to </s> which has ID=2,
-            # whereas the model's text config has bos token id = 0
-            bos_token_id = (
-                2
-                if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
-                else self.config.text_config.bos_token_id
-            )
-            bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
-            if not isinstance(outputs, torch.Tensor):
-                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
-            else:
-                outputs = torch.cat([bos_tokens, outputs], dim=-1)
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)

        return outputs
@@ -468,27 +468,10 @@ def generate(
         )
         generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

-        outputs = self.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **generate_kwargs,
-        )
-
-        # this is a temporary workaround to be consistent with other generation models and
-        # have BOS as the first token, even though under the hood we are calling LM with embeds
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
         if not self.language_model.config.is_encoder_decoder:
-            # the InstructBLIP authors used inconsistent tokenizer/model files during training,
-            # with the tokenizer's bos token being set to </s> which has ID=2,
-            # whereas the model's text config has bos token id = 0
-            bos_token_id = (
-                2
-                if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
-                else self.config.text_config.bos_token_id
-            )
-            bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
-            if not isinstance(outputs, torch.Tensor):
-                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
-            else:
-                outputs = torch.cat([bos_tokens, outputs], dim=-1)
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)

        return outputs
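The deleted branches existed because `generate` can return either a plain tensor of token ids or a `ModelOutput` (when `return_dict_in_generate=True`) whose ids live in `.sequences`, and both had to be patched with a BOS column. With `input_ids` passed as the main input, either return form should already carry the prompt prefix. A hedged sketch of the two call styles, reusing `model` and `inputs` from the examples above:

```python
# Hedged sketch: `model` and `inputs` are assumed to exist as in the earlier examples.
# Default: generate() returns a plain LongTensor of token ids.
sequences = model.generate(**inputs, max_new_tokens=10)
print(sequences.shape)  # roughly (batch_size, prompt_length + generated_length)

# With return_dict_in_generate=True a ModelOutput is returned instead; the token
# ids live in .sequences and are already prefixed with the prompt ids.
out = model.generate(**inputs, max_new_tokens=10, return_dict_in_generate=True, output_scores=True)
print(out.sequences.shape)
print(len(out.scores))  # one score tensor per generated token
```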
1 change: 1 addition & 0 deletions tests/generation/test_utils.py
@@ -96,6 +96,7 @@


 class GenerationTesterMixin:
+    input_name = "input_ids"
     model_tester = None
     all_generative_model_classes = ()
     max_new_tokens = 3
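The one-line test change gives `GenerationTesterMixin` an `input_name` attribute defaulting to `"input_ids"`, presumably so a model-specific tester can declare a different main input instead of the mixin always assuming text ids. A hedged sketch of how a subclass might override it; the class name and input key are illustrative, not taken from the PR:

```python
# Illustrative subclass only; class name and input key below are assumptions.
# The import path assumes a transformers source checkout with the test suite importable.
from tests.generation.test_utils import GenerationTesterMixin


class MyAudioModelGenerationTesterMixin(GenerationTesterMixin):
    # Point the shared generation tests at this model's main input instead of
    # the default "input_ids".
    input_name = "input_features"
```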
