From dc8b6eaeeeb59dd3089b478cc09b577f2c62a297 Mon Sep 17 00:00:00 2001 From: Duc-Viet Hoang Date: Fri, 20 Sep 2024 22:52:08 +0700 Subject: [PATCH] Fix contrastive search to correctly handle input with padding (#33507) * fix: handle padding in contrastive search for decoder-only models * fix: handle padding in contrastive search for encoder-decoder models * tests: move padding contrastive test to test_util, add t5 test * fix: handle if model_kwargs["decoder_attention_mask"] is None * refactor: improve padding input contrastive search generation tests * chore: _ranking_fast to use LongTensor for cosine_matrix_mask --- src/transformers/generation/utils.py | 24 ++++- tests/generation/test_utils.py | 135 +++++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2ba330101e0..2fe92d3e3ed 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2604,6 +2604,15 @@ def _contrastive_search( unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + # Create cosine_matrix_mask based on the attention_mask + cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long) + if self.config.is_encoder_decoder: + if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None: + cosine_matrix_mask = model_kwargs["decoder_attention_mask"] + else: + cosine_matrix_mask = model_kwargs["attention_mask"] + cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0) + this_peer_finished = False while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): @@ -2771,7 +2780,12 @@ def _contrastive_search( # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't # introduce (noticeable) slowdowns on single-device runs. - selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) + selected_idx = _ranking_fast( + context_hidden, next_hidden, top_k_probs, cosine_matrix_mask, penalty_alpha, top_k + ) + cosine_matrix_mask = torch.cat( + [cosine_matrix_mask, cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1))], dim=-1 + ) selected_idx = selected_idx.to("cpu") # This will be used instead of the previous inneficient torch.stack(torch.split()) @@ -4283,6 +4297,7 @@ def _ranking_fast( context_hidden: torch.FloatTensor, next_hidden: torch.FloatTensor, next_top_k_probs: torch.FloatTensor, + cosine_matrix_mask: torch.LongTensor, alpha: float, beam_width: int, ) -> torch.FloatTensor: @@ -4294,6 +4309,13 @@ def _ranking_fast( norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] + + # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions) + # Using a large negative value for masked positions + cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype) + cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min + cosine_matrix = cosine_matrix + cosine_matrix_mask + degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K] next_top_k_probs = next_top_k_probs.view(-1) # [B*K] contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e5a83beac86..2f8e60c7915 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -44,6 +44,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoModelForCausalLM, @@ -59,6 +60,7 @@ GPT2Tokenizer, ImageGPTForCausalImageModeling, SpeechEncoderDecoderModel, + T5ForConditionalGeneration, ) from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( @@ -3644,6 +3646,139 @@ def test_init_static_cache_multi_gpu(self): value_cache_1 = results.past_key_values.value_cache[1] self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + @slow + def test_padding_input_contrastive_search_gpt2(self): + # Load the pre-trained GPT-2 model and tokenizer + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2") + model.to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True) + + # Set the tokenizer to left-pad the sequences + tokenizer.padding_side = "left" + + # Define the PAD token as the EOS token + tokenizer.pad_token = tokenizer.eos_token + model.generation_config.pad_token_id = model.generation_config.eos_token_id + + # Define the input prompt + prompt_text = "The whispered legends of the haunted mansion spoke" + + # Tokenize the input prompt + encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True) + input_ids = encoded_prompt.input_ids.to(torch_device) + attention_mask = encoded_prompt.attention_mask.to(torch_device) + + # Define the contrastive search params + penalty_alpha = 0.6 + top_k = 4 + + # Define the padding length to add to the input IDs and attention mask + padding_length = 10 + + # Generate text without padding + outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + do_sample=False, + penalty_alpha=penalty_alpha, + top_k=top_k, + max_new_tokens=64, + ) + generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Pad the input IDs and attention mask on the left + padded_input_ids = F.pad( + input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id + ) + padded_attention_mask = F.pad(attention_mask, (padding_length, 0), "constant", value=0) + + # Generate text with padded inputs + outputs_with_padding = model.generate( + input_ids=padded_input_ids, + attention_mask=padded_attention_mask, + do_sample=False, + penalty_alpha=penalty_alpha, + top_k=top_k, + max_new_tokens=64, + ) + generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True) + + # Assert that the generated texts are identical for padded and non-padded inputs + self.assertEqual(generated_text_no_padding, generated_text_with_padding) + self.assertEqual( + generated_text_with_padding, + 'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling ' + 'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been ' + 'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea', + ) + + @slow + def test_padding_input_contrastive_search_t5(self): + # Load the pre-trained T5 model and tokenizer + model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + model.to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True) + + # Define the input prompt + prompt_text = "translate English to German: I need to finish this task before the end of the day." + + # Tokenize the input prompt + encoded_prompt = tokenizer(prompt_text, return_tensors="pt") + input_ids = encoded_prompt.input_ids.to(torch_device) + attention_mask = encoded_prompt.attention_mask.to(torch_device) + + # Define the decoder prompt + decoder_prompt_text = "Ich muss diese Aufgabe" + encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt") + decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device) + decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device) + + # Define the contrastive search params + penalty_alpha = 0.6 + top_k = 4 + + # Generate text without padding + outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + do_sample=False, + penalty_alpha=penalty_alpha, + top_k=top_k, + max_new_tokens=64, + ) + generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Define the padding length to add to the input IDs and attention mask + padding_length = 10 + + # Pad the decoder input IDs and attention mask on the left + padded_decoder_input_ids = F.pad( + decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id + ) + padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0) + # Since the decoder_start_token_id is the same as the pad_token_id, + # the last padded token represents the decoder start token. + # Set the attention mask for the decoder_start_token_id to True (1). + padded_decoder_attention_mask[:, padding_length - 1] = 1 + # Generate text with padded inputs + outputs_with_padding = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=padded_decoder_input_ids, + decoder_attention_mask=padded_decoder_attention_mask, + do_sample=False, + penalty_alpha=penalty_alpha, + top_k=top_k, + max_new_tokens=64, + ) + generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True) + + # Assert that the generated texts are identical for padded and non-padded inputs + self.assertEqual(generated_text_no_padding, generated_text_with_padding) + self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.") + @require_torch class TokenHealingTestCase(unittest.TestCase):