From d533465150532b0c5de167b574e59f64c68b1154 Mon Sep 17 00:00:00 2001
From: "Guillaume \"Vermeille\" Sanchez"
Date: Sun, 6 Aug 2023 21:15:24 +0200
Subject: [PATCH] add CFG for .generate() (#24654)

---
 src/transformers/generation/__init__.py       |   2 +
 src/transformers/generation/logits_process.py | 118 +++++++++++++++++-
 src/transformers/generation/utils.py          |  27 +++-
 tests/generation/test_logits_process.py       |  52 ++++++++
 tests/generation/test_utils.py                |  40 ++++++
 5 files changed, 235 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index 0a522e9bb7971f..f0da9f514e7af0 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -65,6 +65,7 @@
         "EncoderNoRepeatNGramLogitsProcessor",
         "ExponentialDecayLengthPenalty",
         "LogitNormalization",
+        "UnbatchedClassifierFreeGuidanceLogitsProcessor",
     ]
     _import_structure["stopping_criteria"] = [
         "MaxNewTokensCriteria",
@@ -188,6 +189,7 @@
         TopKLogitsWarper,
         TopPLogitsWarper,
         TypicalLogitsWarper,
+        UnbatchedClassifierFreeGuidanceLogitsProcessor,
     )
     from .stopping_criteria import (
         MaxLengthCriteria,
diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 6b1093761fb975..f33cef909f0846 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -15,7 +15,7 @@

 import inspect
 import math
-from typing import Callable, Dict, Iterable, List, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -1334,3 +1334,119 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
             scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")

         return scores
+
+
+class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
+    r"""Logits processor for Classifier-Free Guidance (CFG). The processor computes a weighted average of the scores
+    from the prompt-conditional and prompt-unconditional (or negative) logits, parameterized by `guidance_scale`.
+    The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch.
+
+    See [the paper](https://arxiv.org/abs/2306.17806) for more information.
+
+    Args:
+        guidance_scale (`float`):
+            The guidance scale for classifier-free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
+            A higher guidance scale encourages the model to generate samples that are more closely linked to the
+            input prompt, usually at the expense of poorer quality.
+        unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default
+            to the last token of the prompt.
+        unconditional_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Attention mask for `unconditional_ids`.
+        model (`PreTrainedModel`):
+            The model computing the unconditional scores. It should be the same model as the one computing the
+            conditional scores, and both models must use the same tokenizer.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether to cache key/values during the negative prompt forward pass.
+
+
+    Examples:
+
+    ```python
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    >>> inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
+    >>> out = model.generate(inputs["input_ids"], guidance_scale=1.5)
+    >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
+    The dragon flew over Paris, France, landing in Lyon, a city of a few million. Dragon-flying was a new form of
+    transport, and the dragon was the first in Europe.
+
+    >>> # with a negative prompt
+    >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
+    >>> out = model.generate(inputs["input_ids"], guidance_scale=2, negative_prompt_ids=neg_inputs["input_ids"])
+    >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
+    The dragon flew over Paris, France, crashing into Notre Dame Cathedral in the French capital killing at least 127
+    people and injuring more than 350.
+    ```
+    """
+
+    def __init__(
+        self,
+        guidance_scale: float,
+        model,
+        unconditional_ids: Optional[torch.LongTensor] = None,
+        unconditional_attention_mask: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = True,
+    ):
+        self.guidance_scale = guidance_scale
+        self.model = model
+        self.unconditional_context = {
+            "input_ids": unconditional_ids,
+            "attention_mask": unconditional_attention_mask,
+            "use_cache": use_cache,
+            "past_key_values": None,
+            "first_pass": True,
+        }
+
+    def get_unconditional_logits(self, input_ids):
+        if self.unconditional_context["first_pass"]:
+            if self.unconditional_context["input_ids"] is None:
+                self.unconditional_context["input_ids"] = input_ids[:, -1:]
+            if self.unconditional_context["attention_mask"] is None:
+                self.unconditional_context["attention_mask"] = torch.ones_like(
+                    self.unconditional_context["input_ids"], dtype=torch.long
+                )
+            input_ids = self.unconditional_context["input_ids"]
+            attention_mask = self.unconditional_context["attention_mask"]
+            self.unconditional_context["first_pass"] = False
+        else:
+            attention_mask = torch.cat(
+                [
+                    self.unconditional_context["attention_mask"],
+                    torch.ones_like(input_ids[:, -1:], dtype=torch.long),
+                ],
+                dim=1,
+            )
+            if not self.unconditional_context["use_cache"]:
+                input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1)
+            else:
+                input_ids = input_ids[:, -1:]
+            self.unconditional_context["input_ids"] = input_ids
+            self.unconditional_context["attention_mask"] = attention_mask
+
+        out = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            use_cache=self.unconditional_context["use_cache"],
+            past_key_values=self.unconditional_context["past_key_values"],
+        )
+        self.unconditional_context["past_key_values"] = out.get("past_key_values", None)
+
+        return out.logits
+
+    def __call__(self, input_ids, scores):
+        scores = torch.nn.functional.log_softmax(scores, dim=-1)
+        if self.guidance_scale == 1:
+            return scores
+
+        logits = self.get_unconditional_logits(input_ids)
+
+        unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
+        out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
+        return out
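For intuition, the mixing performed in `__call__` above is the standard CFG update applied in log-probability space: both score sets are log-softmaxed and combined as `uncond + guidance_scale * (cond - uncond)`. A minimal, self-contained sketch of that arithmetic follows; the tensor values are illustrative only and mirror the unit test added later in this patch.

```python
import torch
import torch.nn.functional as F

# Dummy next-token scores over a 3-token vocabulary (illustrative values only).
cond_scores = torch.tensor([[1.0, 1.0, 1.0]])    # prompt-conditional logits
uncond_scores = torch.tensor([[1.0, 0.0, 1.5]])  # unconditional / negative-prompt logits
guidance_scale = 1.5

# Same combination as __call__: move to log-probabilities, then push the
# distribution away from the unconditional branch by the guidance scale.
cond_logprobs = F.log_softmax(cond_scores, dim=-1)
uncond_logprobs = F.log_softmax(uncond_scores, dim=-1)
guided = uncond_logprobs + guidance_scale * (cond_logprobs - uncond_logprobs)
print(guided)  # the scores handed on to the next processor / sampler
```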
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 75b2c4f7145e1e..344326ae31bca0 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -38,7 +38,6 @@
 from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
 from .configuration_utils import GenerationConfig
 from .logits_process import (
-    ClassifierFreeGuidanceLogitsProcessor,
     EncoderNoRepeatNGramLogitsProcessor,
     EncoderRepetitionPenaltyLogitsProcessor,
     EpsilonLogitsWarper,
@@ -64,6 +63,7 @@
     TopKLogitsWarper,
     TopPLogitsWarper,
     TypicalLogitsWarper,
+    UnbatchedClassifierFreeGuidanceLogitsProcessor,
 )
 from .stopping_criteria import (
     MaxLengthCriteria,
@@ -893,6 +893,9 @@ def _get_logits_processor(
         encoder_input_ids: torch.LongTensor,
         prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
         logits_processor: Optional[LogitsProcessorList],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
     ) -> LogitsProcessorList:
         """
         This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
@@ -901,6 +904,16 @@ def _get_logits_processor(
         # instantiate processors list
         processors = LogitsProcessorList()

+        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+            processors.append(
+                UnbatchedClassifierFreeGuidanceLogitsProcessor(
+                    generation_config.guidance_scale,
+                    self,
+                    unconditional_ids=negative_prompt_ids,
+                    unconditional_attention_mask=negative_prompt_attention_mask,
+                    use_cache=model_kwargs["use_cache"],
+                )
+            )
         if generation_config.sequence_bias is not None:
             processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))

@@ -998,8 +1011,6 @@ def _get_logits_processor(
             )
         if generation_config.forced_decoder_ids is not None:
             processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
-        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
-            processors.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
         processors = self._merge_criteria_processor_list(processors, logits_processor)
         # `LogitNormalization` should always be the last logit processor, when present
         if generation_config.renormalize_logits is True:
@@ -1251,6 +1262,8 @@ def generate(
         synced_gpus: Optional[bool] = None,
         assistant_model: Optional["PreTrainedModel"] = None,
         streamer: Optional["BaseStreamer"] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
@@ -1308,6 +1321,11 @@ def generate(
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                The negative prompt needed for some processors such as CFG. The batch size must match the input batch
+                size. This is an experimental feature, subject to breaking API changes in future versions.
+            negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Attention mask for `negative_prompt_ids`.
             kwargs (`Dict[str, Any]`, *optional*):
                 Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -1511,6 +1529,9 @@ def generate(
             encoder_input_ids=inputs_tensor,
             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
             logits_processor=logits_processor,
+            model_kwargs=model_kwargs,
+            negative_prompt_ids=negative_prompt_ids,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
         )

         # 9. prepare stopping criteria
diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py
index fed27097a0d350..e161f791caf853 100644
--- a/tests/generation/test_logits_process.py
+++ b/tests/generation/test_logits_process.py
@@ -51,6 +51,7 @@
     TopKLogitsWarper,
     TopPLogitsWarper,
     TypicalLogitsWarper,
+    UnbatchedClassifierFreeGuidanceLogitsProcessor,
 )


@@ -743,3 +744,54 @@ def test_normalization(self):
         self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))

         self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
+
+    def test_classifier_free_guidance(self):
+        class Namespace(dict):
+            pass
+
+        logits_uncond = torch.tensor([[[1.0, 0, 1.5]]])
+        logits_cond = torch.tensor([[[1.0, 1.0, 1.0]]])
+
+        def dummy_model(input_ids, attention_mask, use_cache=True, past_key_values=None):
+            out = Namespace()
+            out.logits = logits_uncond
+            out.past_key_values = None
+            return out
+
+        def lsm(x):
+            return torch.nn.functional.log_softmax(x, dim=-1)
+
+        # explicit unconditional prompt + attention mask
+        input_ids = torch.LongTensor([[0]])
+        cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(
+            1.5, dummy_model, input_ids, torch.ones_like(input_ids, dtype=torch.long)
+        )
+        out = cfg(input_ids, logits_cond)[0, -1]
+
+        res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
+
+        self.assertAlmostEqual(out[0].item(), res[0].item())
+        self.assertAlmostEqual(out[1].item(), res[1].item())
+        self.assertAlmostEqual(out[2].item(), res[2].item())
+
+        # explicit unconditional prompt
+        input_ids = torch.LongTensor([[0]])
+        cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model, input_ids)
+        out = cfg(input_ids, logits_cond)[0, -1]
+
+        res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
+
+        self.assertAlmostEqual(out[0].item(), res[0].item())
+        self.assertAlmostEqual(out[1].item(), res[1].item())
+        self.assertAlmostEqual(out[2].item(), res[2].item())
+
+        # all implicit
+        input_ids = torch.LongTensor([[0]])
+        cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model)
+        out = cfg(input_ids, logits_cond)[0, -1]
+
+        res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
+
+        self.assertAlmostEqual(out[0].item(), res[0].item())
+        self.assertAlmostEqual(out[1].item(), res[1].item())
+        self.assertAlmostEqual(out[2].item(), res[2].item())
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 0f50632c63e41e..0cdc92398ba85a 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2585,6 +2585,46 @@ def test_constrained_beam_search_mixed_mixin(self):
             ],
         )

+    @slow
+    def test_cfg_mixin(self):
+        model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+        input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True)
+        input["input_ids"] = input["input_ids"].to(torch_device)
+        input["attention_mask"] = input["attention_mask"].to(torch_device)
+
+        outputs = model.generate(**input, max_new_tokens=32, guidance_scale=1.5)
+        generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+        self.assertListEqual(
+            generated_text,
+            [
+                "The dragon flew over Paris, landing in the Rue de la Bastille. The crowd was so excited "
+                'that they had to leave the city.\n\n"We\'re going to Paris!"\n'
+            ],
+        )
+
+        neg = tokenizer(["France,"], return_tensors="pt", return_attention_mask=True)
+        neg["input_ids"] = neg["input_ids"].to(torch_device)
+        neg["attention_mask"] = neg["attention_mask"].to(torch_device)
+        outputs = model.generate(
+            **input,
+            max_new_tokens=32,
+            guidance_scale=1.5,
+            negative_prompt_ids=neg["input_ids"],
+            negative_prompt_attention_mask=neg["attention_mask"],
+        )
+        generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+        self.assertListEqual(
+            generated_text,
+            [
+                'The dragon flew over Paris, landing on the pavement.\n\n"Paris!"\n\n"Paris!"\n\n"'
+                'Paris!"\n\n"Paris!"\n\n"Paris!"\n\n'
+            ],
+        )
+
     @slow
     def test_constrained_beam_search_example_translation_mixin(self):
         # PT-only test: TF doesn't have constrained beam search
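With the pieces above in place, CFG is driven entirely through `generate()`: `guidance_scale > 1` enables the new `UnbatchedClassifierFreeGuidanceLogitsProcessor`, and `negative_prompt_ids` / `negative_prompt_attention_mask` optionally replace the default unconditional branch. A minimal usage sketch based on the docstring example and `test_cfg_mixin` above; the checkpoint and decoding settings are illustrative and the generated text will vary.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["The dragon flew over Paris,"], return_tensors="pt")
neg = tokenizer(["France,"], return_tensors="pt")

# guidance_scale > 1 turns CFG on; the negative prompt is optional and
# defaults to the last prompt token when omitted.
out = model.generate(
    **inputs,
    max_new_tokens=32,
    guidance_scale=1.5,
    negative_prompt_ids=neg["input_ids"],
    negative_prompt_attention_mask=neg["attention_mask"],
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```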