Skip to content

Commit

Permalink
Cache: revert DynamicCache init for BC (#33861)
Browse files Browse the repository at this point in the history
* tmp commit

* tmp commit

* make fixup

* missing removal

* fix condition

* fix end-to-end compilation

* if -> elif

* BC

* BC

* use @deprecate_kwarg("num_hidden_layers", version="4.47.0")

* wups the import

* 🥴

---------

Co-authored-by: Arthur Zucker <[email protected]>
  • Loading branch information
gante and ArthurZucker authored Oct 4, 2024
1 parent f92d354 commit 38f9f10
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 56 deletions.
71 changes: 41 additions & 30 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
is_torchdynamo_compiling,
logging,
)
from .utils.deprecation import deprecate_kwarg


if is_hqq_available():
Expand Down Expand Up @@ -361,15 +362,12 @@ class DynamicCache(Cache):
```
"""

@deprecate_kwarg("num_hidden_layers", version="4.47.0")
def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
super().__init__()
if num_hidden_layers is None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
else:
self.key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
self.value_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []

def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
Expand Down Expand Up @@ -425,11 +423,13 @@ def update(

# Update the cache
if len(self.key_cache) <= layer_idx:
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.key_cache), layer_idx):
self.key_cache.append([])
self.value_cache.append([])
self.key_cache.append(key_states)
self.value_cache.append(value_states)
# content on layer cache can be a tensor and checking not tensor causes errors
# so we explicitly check for the empty list
elif self.key_cache[layer_idx] == []:
elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
Expand All @@ -441,9 +441,13 @@ def update(
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []):
return 0
return self.key_cache[layer_idx].shape[-2]
is_empty_layer = (
len(self.key_cache) == 0 # no cache in any layer
or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
return layer_seq_length

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
Expand All @@ -458,12 +462,13 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
return legacy_cache

@classmethod
@deprecate_kwarg("num_hidden_layers", version="4.47.0")
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility."""
cache = cls(num_hidden_layers)
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
Expand All @@ -486,23 +491,27 @@ def crop(self, max_length: int):
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]:
@deprecate_kwarg("num_hidden_layers", version="4.47.0")
def batch_split(
self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
) -> List["DynamicCache"]:
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
out = []
for i in range(0, full_batch_size, split_size):
current_split = DynamicCache(num_hidden_layers)
current_split = DynamicCache()
current_split._seen_tokens = self._seen_tokens
current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
out.append(current_split)
return out

@classmethod
def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int) -> "DynamicCache":
@deprecate_kwarg("num_hidden_layers", version="4.47.0")
def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
cache = cls(num_hidden_layers)
cache = cls()
for idx in range(len(splits[0])):
key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
Expand Down Expand Up @@ -618,7 +627,9 @@ def update(
self._seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
if len(self.key_cache) < layer_idx:
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
elif len(self.key_cache) == layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
self.original_device.append(key_states.device)
Expand Down Expand Up @@ -677,7 +688,9 @@ def update(
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]

if len(self.key_cache) <= layer_idx:
if len(self.key_cache) < layer_idx:
raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")
elif len(self.key_cache) == layer_idx:
self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key))
self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value))
self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
Expand Down Expand Up @@ -1430,12 +1443,12 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:

@classmethod
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "EncoderDecoderCache":
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
cache = cls(
self_attention_cache=DynamicCache(num_hidden_layers),
cross_attention_cache=DynamicCache(num_hidden_layers),
self_attention_cache=DynamicCache(),
cross_attention_cache=DynamicCache(),
)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
Expand Down Expand Up @@ -1493,26 +1506,24 @@ def crop(self, maximum_length: int):
self.check_dynamic_cache(self.crop.__name__)
self.self_attention_cache.crop(maximum_length)

def batch_split(
self, full_batch_size: int, split_size: int, num_hidden_layers: int
) -> "List[EncoderDecoderCache]":
def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
self.check_dynamic_cache(self.batch_split.__name__)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)

out = []
for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
out.append(EncoderDecoderCache(self_attn, cross_attn))
return out

@classmethod
def from_batch_splits(cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int) -> "EncoderDecoderCache":
def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
self_attention_cache = DynamicCache(num_hidden_layers)
cross_attention_cache = DynamicCache(num_hidden_layers)
self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,11 +1697,10 @@ def _prepare_cache_for_generation(
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
else:
num_hidden_layers = self.config.get_text_config().num_hidden_layers
model_kwargs[cache_name] = (
DynamicCache(num_hidden_layers)
DynamicCache()
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers))
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)

def _supports_num_logits_to_keep(self) -> bool:
Expand Down
27 changes: 25 additions & 2 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,13 +1776,13 @@ def test_new_cache_format(self, num_beams, do_sample):
set_seed(seed)
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
set_seed(seed)
num_hidden_layers = config.get_text_config().num_hidden_layers
if config.is_encoder_decoder:
cache_cls = EncoderDecoderCache
past_key_values = cache_cls(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers))
past_key_values = cache_cls(DynamicCache(), DynamicCache())
else:
cache_cls = DynamicCache
past_key_values = cache_cls()

new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict)

# The two sets of generated sequences must match, despite the cache format between forward passes being
Expand Down Expand Up @@ -3725,6 +3725,29 @@ def test_padding_input_contrastive_search_t5(self):
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")

def test_generate_compile_fullgraph_tiny(self):
"""
Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)
NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the
non-slow tests to prevent regressions!
"""
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# compile generate
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")

# compiled generate does NOT accept parameterization except a) model inputs b) a generation config
generation_config = copy.deepcopy(model.generation_config)
generation_config.pad_token_id = model.config.eos_token_id

model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt")
model_inputs = model_inputs.to(model.device)
gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated


@require_torch
class TokenHealingTestCase(unittest.TestCase):
Expand Down
52 changes: 40 additions & 12 deletions tests/models/mllama/test_modeling_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,45 +383,73 @@ def test_assisted_decoding_with_num_logits_to_keep(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_beam_sample_generate_dict_output():
def test_beam_sample_generate_dict_output(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_beam_search_generate_dict_output():
def test_beam_search_generate_dict_output(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_constrained_beam_search_generate_dict_output():
def test_constrained_beam_search_generate_dict_output(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_dola_decoding_sample():
def test_dola_decoding_sample(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_generate_methods_with_num_logits_to_keep():
def test_generate_methods_with_num_logits_to_keep(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_greedy_generate_dict_outputs():
def test_greedy_generate_dict_outputs(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_group_beam_search_generate_dict_output():
def test_group_beam_search_generate_dict_output(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_model_parallel_beam_search():
def test_model_parallel_beam_search(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_new_cache_format_2():
pass
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_0(self):
super().test_new_cache_format_0()

@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_1(self):
super().test_new_cache_format_1()

@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_2(self):
super().test_new_cache_format_2()

@unittest.skip(reason="Failing test, need to fix")
def test_sample_generate_dict_output():
def test_sample_generate_dict_output(self):
pass

def test_generate_text_only_with_cache(self):
"""
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
required cache modifications (because layers are skipped in practice). This test should prevent regressions.
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()

inputs = self._prepare_for_class(inputs_dict, model_class)

input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]

model.generate(input_ids, use_cache=True)


@require_torch
class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase):
Expand Down
14 changes: 5 additions & 9 deletions tests/utils/test_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class CacheTest(unittest.TestCase):
def test_dynamic_cache_retrocompatibility(self):
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
legacy_cache = ()
new_cache = DynamicCache(num_hidden_layers=10)
new_cache = DynamicCache()

# Creates a new cache with 10 layers in both formats
for layer_idx in range(10):
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_dynamic_cache_retrocompatibility(self):
)

# Test 1: We can convert from legacy to new with no changes
from_legacy = DynamicCache.from_legacy_cache(legacy_cache, num_hidden_layers=10)
from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
for layer_idx in range(10):
for key_value_idx in range(2):
self.assertTrue(
Expand All @@ -103,7 +103,7 @@ def test_reorder_cache_retrocompatibility(self):
legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function

legacy_cache = ()
new_cache = DynamicCache(num_hidden_layers=10)
new_cache = DynamicCache()

# Creates a new cache with 10 layers in both formats
for layer_idx in range(10):
Expand Down Expand Up @@ -240,9 +240,7 @@ def test_dynamic_cache_hard(self):
set_seed(0)
gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
set_seed(0)
gen_out = model.generate(
**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache(model.config.num_hidden_layers)
)
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())

decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
Expand Down Expand Up @@ -270,9 +268,7 @@ def test_dynamic_cache_batched(self):
model.device
)

gen_out = model.generate(
**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache(model.config.num_hidden_layers)
)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
self.assertListEqual(decoded, expected_text)
Expand Down

0 comments on commit 38f9f10

Please sign in to comment.