add CFG for .generate() (huggingface#24654)
Vermeille authored Aug 6, 2023
1 parent a6e6b1c commit d533465
Showing 5 changed files with 235 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/transformers/generation/__init__.py
@@ -65,6 +65,7 @@
"EncoderNoRepeatNGramLogitsProcessor",
"ExponentialDecayLengthPenalty",
"LogitNormalization",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
]
_import_structure["stopping_criteria"] = [
"MaxNewTokensCriteria",
@@ -188,6 +189,7 @@
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from .stopping_criteria import (
MaxLengthCriteria,
118 changes: 117 additions & 1 deletion src/transformers/generation/logits_process.py
@@ -15,7 +15,7 @@

import inspect
import math
from typing import Callable, Dict, Iterable, List, Tuple, Union
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -1334,3 +1334,119 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")

return scores


class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
r"""Logits processor for Classifier-Free Guidance (CFG). The processors
computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits,
parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with
the `unconditional_ids` branch.
See [the paper](https://arxiv.org/abs/2306.17806) for more information.
Args:
guidance_scale (`float`):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
the last token of the prompt.
unconditional_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, **optional**):
Attention mask for unconditional_ids.
model (`PreTrainedModel`):
The model computing the unconditional scores. Supposedly the same as the one computing the conditional
scores. Both models must use the same tokenizer.
smooth_factor (`float`, **optional**):
The interpolation weight for CFG Rescale. 1 means no rescaling, 0 reduces to the conditional scores without
CFG. Turn it lower if the output degenerates.
use_cache (`bool`, **optional**):
Whether to cache key/values during the negative prompt forward pass.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
>>> out = model.generate(inputs["input_ids"], guidance_scale=1.5)
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
The dragon flew over Paris, France, landing in Lyon, a city of a few million. Dragon-flying was a new form of
transport, and the dragon was the first in Europe.
>>> # with a negative prompt
>>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
>>> out = model.generate(inputs["input_ids"], guidance_scale=2, negative_prompt_ids=neg_inputs["input_ids"])
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
The dragon flew over Paris, France, crashing into Notre Dame Cathedral in the French capital killing at least 127
people and injuring more than 350.
```
"""

def __init__(
self,
guidance_scale: float,
model,
unconditional_ids: Optional[torch.LongTensor] = None,
unconditional_attention_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = True,
):
self.guidance_scale = guidance_scale
self.model = model
self.unconditional_context = {
"input_ids": unconditional_ids,
"attention_mask": unconditional_attention_mask,
"use_cache": use_cache,
"past_key_values": None,
"first_pass": True,
}

def get_unconditional_logits(self, input_ids):
if self.unconditional_context["first_pass"]:
if self.unconditional_context["input_ids"] is None:
self.unconditional_context["input_ids"] = input_ids[:, -1:]
if self.unconditional_context["attention_mask"] is None:
self.unconditional_context["attention_mask"] = torch.ones_like(
self.unconditional_context["input_ids"], dtype=torch.long
)
input_ids = self.unconditional_context["input_ids"]
attention_mask = self.unconditional_context["attention_mask"]
self.unconditional_context["first_pass"] = False
else:
attention_mask = torch.cat(
[
self.unconditional_context["attention_mask"],
torch.ones_like(input_ids[:, -1:], dtype=torch.long),
],
dim=1,
)
if not self.unconditional_context["use_cache"]:
input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1)
else:
input_ids = input_ids[:, -1:]
self.unconditional_context["input_ids"] = input_ids
self.unconditional_context["attention_mask"] = attention_mask

out = self.model(
input_ids,
attention_mask=attention_mask,
use_cache=self.unconditional_context["use_cache"],
past_key_values=self.unconditional_context["past_key_values"],
)
self.unconditional_context["past_key_values"] = out.get("past_key_values", None)

return out.logits

def __call__(self, input_ids, scores):
scores = torch.nn.functional.log_softmax(scores, dim=-1)
if self.guidance_scale == 1:
return scores

logits = self.get_unconditional_logits(input_ids)

unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
return out
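
For reference, a minimal sketch of the score combination implemented in `__call__` above, using toy tensors in place of real model outputs (illustrative only, not part of the diff):

```python
import torch
import torch.nn.functional as F

guidance_scale = 1.5
# Toy next-token logits; in generate() these would come from the conditional
# forward pass and the extra unconditional pass run by the processor.
scores = torch.tensor([[1.0, 1.0, 1.0]])
unconditional_logits = torch.tensor([[1.0, 0.0, 1.5]])

cond = F.log_softmax(scores, dim=-1)
uncond = F.log_softmax(unconditional_logits, dim=-1)

# Same combination as __call__: uncond + guidance_scale * (cond - uncond)
out = guidance_scale * (cond - uncond) + uncond
print(out)
```

With `guidance_scale == 1` this reduces to the conditional log-probabilities, which is why the processor returns early in that case.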
27 changes: 24 additions & 3 deletions src/transformers/generation/utils.py
@@ -38,7 +38,6 @@
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .configuration_utils import GenerationConfig
from .logits_process import (
ClassifierFreeGuidanceLogitsProcessor,
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
@@ -64,6 +63,7 @@
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from .stopping_criteria import (
MaxLengthCriteria,
@@ -893,6 +893,9 @@ def _get_logits_processor(
encoder_input_ids: torch.LongTensor,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
logits_processor: Optional[LogitsProcessorList],
model_kwargs: Optional[Dict[str, Any]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
) -> LogitsProcessorList:
"""
This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
@@ -901,6 +904,16 @@
# instantiate processors list
processors = LogitsProcessorList()

if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
processors.append(
UnbatchedClassifierFreeGuidanceLogitsProcessor(
generation_config.guidance_scale,
self,
unconditional_ids=negative_prompt_ids,
unconditional_attention_mask=negative_prompt_attention_mask,
use_cache=model_kwargs["use_cache"],
)
)
if generation_config.sequence_bias is not None:
processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))

@@ -998,8 +1011,6 @@ def _get_logits_processor(
)
if generation_config.forced_decoder_ids is not None:
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
processors.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
@@ -1251,6 +1262,8 @@ def generate(
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
@@ -1308,6 +1321,11 @@
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
The negative prompt needed for some processors such as CFG. The batch size must match the input batch
size. This is an experimental feature, subject to breaking API changes in future versions.
negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Attention mask for `negative_prompt_ids`.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -1511,6 +1529,9 @@ def generate(
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)

# 9. prepare stopping criteria
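
A minimal usage sketch of the new `generate()` arguments wired up above, modeled on the docstring example; the model name and prompts are illustrative and the generated text will vary:

```python
# Illustrative only: enable CFG via guidance_scale and pass an optional
# negative prompt, routed through _get_logits_processor() above.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
neg = tokenizer(["A very happy event happened,"], return_tensors="pt")

out = model.generate(
    inputs["input_ids"],
    max_new_tokens=32,
    guidance_scale=2.0,  # CFG is active for guidance_scale > 1
    negative_prompt_ids=neg["input_ids"],
    negative_prompt_attention_mask=neg["attention_mask"],
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```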
52 changes: 52 additions & 0 deletions tests/generation/test_logits_process.py
@@ -51,6 +51,7 @@
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)


@@ -743,3 +744,54 @@ def test_normalization(self):
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))

self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))

def test_classifier_free_guidance(self):
class Namespace(dict):
pass

logits_uncond = torch.tensor([[[1.0, 0, 1.5]]])
logits_cond = torch.tensor([[[1.0, 1.0, 1.0]]])

def dummy_model(input_ids, attention_mask, use_cache=True, past_key_values=None):
out = Namespace()
out.logits = logits_uncond
out.past_key_values = None
return out

def lsm(x):
return torch.nn.functional.log_softmax(x, dim=-1)

# explicit unconditional prompt + attention mask
input_ids = torch.LongTensor([[0]])
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(
1.5, dummy_model, input_ids, torch.ones_like(input_ids, dtype=torch.long)
)
out = cfg(input_ids, logits_cond)[0, -1]

res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]

self.assertAlmostEqual(out[0].item(), res[0].item())
self.assertAlmostEqual(out[1].item(), res[1].item())
self.assertAlmostEqual(out[2].item(), res[2].item())

# explicit unconditional prompt
input_ids = torch.LongTensor([[0]])
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model, input_ids)
out = cfg(input_ids, logits_cond)[0, -1]

res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]

self.assertAlmostEqual(out[0].item(), res[0].item())
self.assertAlmostEqual(out[1].item(), res[1].item())
self.assertAlmostEqual(out[2].item(), res[2].item())

# all implicit
input_ids = torch.LongTensor([[0]])
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model)
out = cfg(input_ids, logits_cond)[0, -1]

res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]

self.assertAlmostEqual(out[0].item(), res[0].item())
self.assertAlmostEqual(out[1].item(), res[1].item())
self.assertAlmostEqual(out[2].item(), res[2].item())
40 changes: 40 additions & 0 deletions tests/generation/test_utils.py
@@ -2585,6 +2585,46 @@ def test_constrained_beam_search_mixed_mixin(self):
],
)

@slow
def test_cfg_mixin(self):
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True)
input["input_ids"] = input["input_ids"].to(torch_device)
input["attention_mask"] = input["attention_mask"].to(torch_device)

outputs = model.generate(**input, max_new_tokens=32, guidance_scale=1.5)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

self.assertListEqual(
generated_text,
[
"The dragon flew over Paris, landing in the Rue de la Bastille. The crowd was so excited "
'that they had to leave the city.\n\n"We\'re going to Paris!"\n'
],
)

neg = tokenizer(["France,"], return_tensors="pt", return_attention_mask=True)
neg["input_ids"] = neg["input_ids"].to(torch_device)
neg["attention_mask"] = neg["attention_mask"].to(torch_device)
outputs = model.generate(
**input,
max_new_tokens=32,
guidance_scale=1.5,
negative_prompt_ids=neg["input_ids"],
negative_prompt_attention_mask=neg["attention_mask"],
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

self.assertListEqual(
generated_text,
[
'The dragon flew over Paris, landing on the pavement.\n\n"Paris!"\n\n"Paris!"\n\n"'
'Paris!"\n\n"Paris!"\n\n"Paris!"\n\n'
],
)

@slow
def test_constrained_beam_search_example_translation_mixin(self):
# PT-only test: TF doesn't have constrained beam search