Automatic compilation in generate: do not rely on inner function (#34923)

* compiled forward in PreTrainedModel

* update

* style

* update name

* trigger CIs

* Add way to use custom compile args

* style

* switch parameterization to generation_config

* Add to inits

* Update configuration_utils.py

* inits

* style

* docs

* style

* Update configuration_utils.py

* back without dataclass for repo consistency

* Update configuration_utils.py

* style

* style

* style once again

* add config serialization

* update

* true dataclass

* trigger CIs

* merge compile methods + remove serialization of compile config
Cyrilvallez authored Dec 3, 2024
1 parent f9c7e60 commit ee37bf0
Showing 6 changed files with 99 additions and 12 deletions.
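
Taken together, the changes below expose `torch.compile` behavior through the generation config. A minimal usage sketch, assuming a CUDA device and the `google/gemma-2-2b` checkpoint that the new docstring example also uses (any decoder-only checkpoint works the same way):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b").cuda()

# Compile arguments now live on the generation config instead of being hard-coded
# inside `_sample`; automatic compilation only kicks in with a static cache.
generation_config = GenerationConfig(
    do_sample=False,
    max_new_tokens=300,
    cache_implementation="static",
    compile_config=CompileConfig(dynamic=True),
)

inputs = tokenizer("Hello there, how", return_tensors="pt").to(model.device)
output = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(output, skip_special_tokens=True)[0])
```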
6 changes: 6 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -436,3 +436,9 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] SynthIDTextWatermarkDetector
- __call__

## Compile Utils

[[autodoc]] CompileConfig
- __call__

3 changes: 2 additions & 1 deletion src/transformers/__init__.py
@@ -122,6 +122,7 @@
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [],
"generation": [
"CompileConfig",
"GenerationConfig",
"TextIteratorStreamer",
"TextStreamer",
@@ -4981,7 +4982,7 @@
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

# Generation
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
from .generation import CompileConfig, GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
from .hf_argparser import HfArgumentParser

# Integrations
2 changes: 2 additions & 0 deletions src/transformers/generation/__init__.py
@@ -20,6 +20,7 @@
_import_structure = {
"configuration_utils": [
"BaseWatermarkingConfig",
"CompileConfig",
"GenerationConfig",
"GenerationMode",
"SynthIDTextWatermarkingConfig",
@@ -192,6 +193,7 @@
if TYPE_CHECKING:
from .configuration_utils import (
BaseWatermarkingConfig,
CompileConfig,
GenerationConfig,
GenerationMode,
SynthIDTextWatermarkingConfig,
69 changes: 67 additions & 2 deletions src/transformers/generation/configuration_utils.py
@@ -20,7 +20,7 @@
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass, is_dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from .. import __version__
from ..configuration_utils import PretrainedConfig
@@ -371,6 +371,12 @@ class GenerationConfig(PushToHubMixin):
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
> Parameters related to performances and compilation
compile_config (CompileConfig, *optional*):
If using a static cache, this controls how `generate` will `compile` the forward pass for performance
gains.
> Wild card
generation_kwargs:
@@ -474,6 +480,9 @@ def __init__(self, **kwargs):
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
self.target_lookbehind = kwargs.pop("target_lookbehind", 10)

# Performances
self.compile_config = kwargs.pop("compile_config", CompileConfig())

# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})

@@ -794,7 +803,13 @@ def validate(self, is_init=False):
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()

# 7. other incorrect combinations
# 7. performances arguments
if not isinstance(self.compile_config, CompileConfig):
raise ValueError(
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
)

# 8. other incorrect combinations
if self.return_dict_in_generate is not True:
for extra_output_flag in self.extra_output_flags:
if getattr(self, extra_output_flag) is True:
@@ -1175,6 +1190,8 @@ def to_dict(self) -> Dict[str, Any]:
del output["_commit_hash"]
if "_original_object_hash" in output:
del output["_original_object_hash"]
if "compile_config" in output:
del output["compile_config"]

# Transformers version when serializing this file
output["transformers_version"] = __version__
@@ -1559,3 +1576,51 @@ def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
skip_first_ngram_calls=self.skip_first_ngram_calls,
debug_mode=self.debug_mode,
)


@dataclass
class CompileConfig(object):
"""
Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.
Args:
fullgraph (`bool`, *optional*, defaults to `True`):
If `True`, requires that the whole forward be capturable in a single graph.
dynamic (`bool` or `None`, *optional*):
Whether to try to use dynamic shape graphs.
backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
Backend to be used.
mode (`str`, *optional*, defaults to `"reduce-overhead"`):
Controls balance between performance and overhead.
options (`dict`, *optional*):
A dictionary of options to pass to the backend.
Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig
>>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b')
>>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b').cuda()
>>> # Automatic compile configuration, used with static cache
>>> compile_config = CompileConfig(dynamic=True)
>>> # Generation with static cache and compile config
>>> input = tokenizer.encode("Hello there, how", return_tensors="pt").cuda()
>>> output = model.generate(
... input, do_sample=False, max_new_tokens=300, cache_implementation="static", compile_config=compile_config
... )
>>> output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
```
"""

fullgraph: bool = True
dynamic: Optional[bool] = None
backend: Union[str, Callable] = "inductor"
mode: str = "reduce-overhead"
options: Optional[dict] = None

def to_dict(self) -> Dict[str, Any]:
"""Serializes this instance to a Python dictionary."""
return copy.deepcopy(self.__dict__)
14 changes: 6 additions & 8 deletions src/transformers/generation/utils.py
@@ -3230,16 +3230,14 @@ def _sample(
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

def model_forward(model, *args, **kwargs):
return model.forward(*args, **kwargs)

model_forward = self.__call__
if isinstance(model_kwargs.get("past_key_values"), StaticCache):
if self.device.type == "cuda":
logger.warning_once("Using `torch.compile`.")
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
model_forward = self.get_compiled_call(generation_config.compile_config)

i = 0
is_prefill = True
while self._has_unfinished_sequences(
this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):
@@ -3250,11 +3248,11 @@ def model_forward(model, *args, **kwargs):
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

if i == 0:
if is_prefill:
outputs = self(**model_inputs, return_dict=True)
i += 1
is_prefill = False
else:
outputs = model_forward(self, return_dict=True, **model_inputs)
outputs = model_forward(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
17 changes: 16 additions & 1 deletion src/transformers/modeling_utils.py
@@ -43,7 +43,7 @@
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, GenerationMixin
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401
@@ -5094,6 +5094,21 @@ def loss_function(self):
loss_type = "ForCausalLM"
return LOSS_MAPPING[loss_type]

def get_compiled_call(self, compile_config: CompileConfig):
"""Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
(where we want the speed-ups of compiled version with static shapes)."""
# Only reset it if not present or different from previous config
default_config = getattr(self.generation_config, "compile_config", CompileConfig())
if (
not hasattr(self, "_compiled_call")
or getattr(self, "_last_compile_config", default_config) != compile_config
):
self._last_compile_config = compile_config
self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
return self._compiled_call


PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:
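
The new `get_compiled_call` helper caches the compiled callable on the model and only re-runs `torch.compile` when a different `CompileConfig` is passed (dataclass equality decides). A small sketch of that behavior, assuming the same checkpoint as in the example above; `torch.compile` wraps lazily, so nothing is compiled until the returned callable is actually executed:

```python
from transformers import AutoModelForCausalLM, CompileConfig

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

config = CompileConfig()  # defaults: fullgraph=True, backend="inductor", mode="reduce-overhead"
first = model.get_compiled_call(config)
second = model.get_compiled_call(config)  # same config -> the cached compiled call is reused
assert first is second

third = model.get_compiled_call(CompileConfig(dynamic=True))  # new config -> recompiled
assert third is not second
```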
