diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 0dbf0cc682d7a3..2ee20aea5568a0 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -16,6 +16,7 @@
     is_torchdynamo_compiling,
     logging,
 )
+from .utils.deprecation import deprecate_kwarg
 
 
 if is_hqq_available():
@@ -361,15 +362,12 @@ class DynamicCache(Cache):
         ```
     """
 
+    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
     def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
         super().__init__()
-        if num_hidden_layers is None:
-            self.key_cache: List[torch.Tensor] = []
-            self.value_cache: List[torch.Tensor] = []
-        else:
-            self.key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
-            self.value_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
 
     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
         """
@@ -425,11 +423,13 @@ def update(
 
         # Update the cache
         if len(self.key_cache) <= layer_idx:
+            # There may be skipped layers, fill them with empty lists
+            for _ in range(len(self.key_cache), layer_idx):
+                self.key_cache.append([])
+                self.value_cache.append([])
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
-        # content on layer cache can be a tensor and checking not tensor causes errors
-        # so we explicitly check for the empty list
-        elif self.key_cache[layer_idx] == []:
+        elif len(self.key_cache[layer_idx]) == 0:  # fills previously skipped layers; checking for tensor causes errors
             self.key_cache[layer_idx] = key_states
             self.value_cache[layer_idx] = value_states
         else:
@@ -441,9 +441,13 @@ def update(
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # TODO: deprecate this function in favor of `cache_position`
-        if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []):
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
+        is_empty_layer = (
+            len(self.key_cache) == 0  # no cache in any layer
+            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
+        )
+        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+        return layer_seq_length
 
     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
@@ -458,12 +462,13 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
         return legacy_cache
 
     @classmethod
+    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
     def from_legacy_cache(
         cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
     ) -> "DynamicCache":
         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
         backward compatibility."""
-        cache = cls(num_hidden_layers)
+        cache = cls()
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
                 key_states, value_states = past_key_values[layer_idx]
@@ -486,12 +491,15 @@ def crop(self, max_length: int):
             self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
             self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
 
-    def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]:
+    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
+    def batch_split(
+        self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
+    ) -> List["DynamicCache"]:
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         out = []
         for i in range(0, full_batch_size, split_size):
-            current_split = DynamicCache(num_hidden_layers)
+            current_split = DynamicCache()
             current_split._seen_tokens = self._seen_tokens
             current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
             current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
@@ -499,10 +507,11 @@ def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers:
         return out
 
     @classmethod
-    def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int) -> "DynamicCache":
+    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
+    def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
-        cache = cls(num_hidden_layers)
+        cache = cls()
         for idx in range(len(splits[0])):
             key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
             value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
@@ -618,7 +627,9 @@ def update(
             self._seen_tokens += key_states.shape[-2]
 
         # Update the cache
-        if len(self.key_cache) <= layer_idx:
+        if len(self.key_cache) < layer_idx:
+            raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
+        elif len(self.key_cache) == layer_idx:
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
             self.original_device.append(key_states.device)
@@ -677,7 +688,9 @@ def update(
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
 
-        if len(self.key_cache) <= layer_idx:
+        if len(self.key_cache) < layer_idx:
+            raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")
+        elif len(self.key_cache) == layer_idx:
             self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key))
             self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value))
             self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
@@ -1430,12 +1443,12 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
 
     @classmethod
     def from_legacy_cache(
-        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     ) -> "EncoderDecoderCache":
         """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
         cache = cls(
-            self_attention_cache=DynamicCache(num_hidden_layers),
-            cross_attention_cache=DynamicCache(num_hidden_layers),
+            self_attention_cache=DynamicCache(),
+            cross_attention_cache=DynamicCache(),
         )
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
@@ -1493,14 +1506,12 @@ def crop(self, maximum_length: int):
         self.check_dynamic_cache(self.crop.__name__)
         self.self_attention_cache.crop(maximum_length)
 
-    def batch_split(
-        self, full_batch_size: int, split_size: int, num_hidden_layers: int
-    ) -> "List[EncoderDecoderCache]":
+    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         self.check_dynamic_cache(self.batch_split.__name__)
-        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)
-        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)
+        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
+        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
 
         out = []
         for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
@@ -1508,11 +1519,11 @@ def batch_split(
         return out
 
     @classmethod
-    def from_batch_splits(cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int) -> "EncoderDecoderCache":
+    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
-        self_attention_cache = DynamicCache(num_hidden_layers)
-        cross_attention_cache = DynamicCache(num_hidden_layers)
+        self_attention_cache = DynamicCache()
+        cross_attention_cache = DynamicCache()
         for idx in range(len(splits[0])):
             layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
             layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 43eda333149744..e00d0e41556f8a 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1697,11 +1697,10 @@ def _prepare_cache_for_generation(
         # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
         # keeps copying the cache thus using much more memory
         else:
-            num_hidden_layers = self.config.get_text_config().num_hidden_layers
             model_kwargs[cache_name] = (
-                DynamicCache(num_hidden_layers)
+                DynamicCache()
                 if not requires_cross_attention_cache
-                else EncoderDecoderCache(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers))
+                else EncoderDecoderCache(DynamicCache(), DynamicCache())
             )
 
     def _supports_num_logits_to_keep(self) -> bool:
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 29b31dab50a662..59192be876971b 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1776,13 +1776,13 @@ def test_new_cache_format(self, num_beams, do_sample):
         set_seed(seed)
         legacy_results = model.generate(**generation_kwargs, **inputs_dict)
         set_seed(seed)
-        num_hidden_layers = config.get_text_config().num_hidden_layers
         if config.is_encoder_decoder:
             cache_cls = EncoderDecoderCache
-            past_key_values = cache_cls(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers))
+            past_key_values = cache_cls(DynamicCache(), DynamicCache())
         else:
             cache_cls = DynamicCache
             past_key_values = cache_cls()
+
         new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict)
 
         # The two sets of generated sequences must match, despite the cache format between forward passes being
@@ -3725,6 +3725,29 @@ def test_padding_input_contrastive_search_t5(self):
         self.assertEqual(generated_text_no_padding, generated_text_with_padding)
         self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
 
+    def test_generate_compile_fullgraph_tiny(self):
+        """
+        Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)
+        NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the
+        non-slow tests to prevent regressions!
+        """
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16, device_map="auto"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+
+        # compile generate
+        compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
+
+        # compiled generate does NOT accept parameterization except a) model inputs b) a generation config
+        generation_config = copy.deepcopy(model.generation_config)
+        generation_config.pad_token_id = model.config.eos_token_id
+
+        model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt")
+        model_inputs = model_inputs.to(model.device)
+        gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
+        self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1])  # some text was generated
+
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):
diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py
index f31957d78aa8a9..85e54f707d7d2e 100644
--- a/tests/models/mllama/test_modeling_mllama.py
+++ b/tests/models/mllama/test_modeling_mllama.py
@@ -383,45 +383,73 @@ def test_assisted_decoding_with_num_logits_to_keep(self):
         pass
 
     @unittest.skip(reason="Failing test, need to fix")
-    def test_beam_sample_generate_dict_output():
+    def test_beam_sample_generate_dict_output(self):
         pass
 
     @unittest.skip(reason="Failing test, need to fix")
-    def test_beam_search_generate_dict_output():
+    def test_beam_search_generate_dict_output(self):
         pass
 
     @unittest.skip(reason="Failing test, need to fix")
-    def test_constrained_beam_search_generate_dict_output():
+    def test_constrained_beam_search_generate_dict_output(self):
         pass
 
     @unittest.skip(reason="Failing test, need to fix")
-    def test_dola_decoding_sample():
+    def test_dola_decoding_sample(self):
         pass
 
     @unittest.skip(reason="Failing test, need to fix")
-    def test_generate_methods_with_num_logits_to_keep():
+    def test_generate_methods_with_num_logits_to_keep(self):
        pass
 
     @unittest.skip(reason="Failing test, need to fix")
-    def test_greedy_generate_dict_outputs():
+    def test_greedy_generate_dict_outputs(self):
         pass
 
     @unittest.skip(reason="Failing test, need to fix")
-    def test_group_beam_search_generate_dict_output():
+    def test_group_beam_search_generate_dict_output(self):
         pass
 
     @unittest.skip(reason="Failing test, need to fix")
-    def test_model_parallel_beam_search():
+    def test_model_parallel_beam_search(self):
         pass
 
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_new_cache_format_2():
-        pass
+    @is_flaky()  # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
+    def test_new_cache_format_0(self):
+        super().test_new_cache_format_0()
+
+    @is_flaky()  # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
+    def test_new_cache_format_1(self):
+        super().test_new_cache_format_1()
+
+    @is_flaky()  # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
+    def test_new_cache_format_2(self):
+        super().test_new_cache_format_2()
 
     @unittest.skip(reason="Failing test, need to fix")
-    def test_sample_generate_dict_output():
+    def test_sample_generate_dict_output(self):
         pass
 
+    def test_generate_text_only_with_cache(self):
+        """
+        Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
+        required cache modifications (because layers are skipped in practice). This test should prevent regressions.
+        """
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+            del inputs["pixel_values"]
+
+            model.generate(input_ids, use_cache=True)
+
 
 @require_torch
 class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index 8392ed18b71665..4a6dae67cbc807 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -53,7 +53,7 @@ class CacheTest(unittest.TestCase):
     def test_dynamic_cache_retrocompatibility(self):
         """Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
         legacy_cache = ()
-        new_cache = DynamicCache(num_hidden_layers=10)
+        new_cache = DynamicCache()
 
         # Creates a new cache with 10 layers in both formats
         for layer_idx in range(10):
@@ -83,7 +83,7 @@ def test_dynamic_cache_retrocompatibility(self):
         )
 
         # Test 1: We can convert from legacy to new with no changes
-        from_legacy = DynamicCache.from_legacy_cache(legacy_cache, num_hidden_layers=10)
+        from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
         for layer_idx in range(10):
             for key_value_idx in range(2):
                 self.assertTrue(
@@ -103,7 +103,7 @@ def test_reorder_cache_retrocompatibility(self):
         legacy_reorder_fn = GPT2LMHeadModel._reorder_cache  # An example of a legacy `_reorder_cache` function
 
         legacy_cache = ()
-        new_cache = DynamicCache(num_hidden_layers=10)
+        new_cache = DynamicCache()
 
         # Creates a new cache with 10 layers in both formats
         for layer_idx in range(10):
@@ -240,9 +240,7 @@ def test_dynamic_cache_hard(self):
         set_seed(0)
         gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
         set_seed(0)
-        gen_out = model.generate(
-            **inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache(model.config.num_hidden_layers)
-        )
+        gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
         self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())
 
         decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
@@ -270,9 +268,7 @@ def test_dynamic_cache_batched(self):
             model.device
         )
 
-        gen_out = model.generate(
-            **inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache(model.config.num_hidden_layers)
-        )
+        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
         decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
         expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
         self.assertListEqual(decoded, expected_text)
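The snippet below is not part of the patch; it is a minimal sketch of the skipped-layer behavior the `DynamicCache` changes above introduce, using only the public `DynamicCache.update()` and `get_seq_length()` API shown in the diff. Updating layer 2 without ever updating layer 1 back-fills the skipped layer with an empty list, and `get_seq_length()` reports 0 for it instead of raising, which is what text-only mllama generation relies on.

```python
# Illustration only (not part of the PR): skipped layers in DynamicCache.
# Assumes a transformers version that includes the cache changes above.
import torch
from transformers import DynamicCache

cache = DynamicCache()
key = torch.zeros(1, 2, 4, 8)    # (batch, num_heads, seq_len, head_dim)
value = torch.zeros(1, 2, 4, 8)

cache.update(key, value, layer_idx=0)  # layer 0 cached normally
cache.update(key, value, layer_idx=2)  # layer 1 never runs, e.g. a skipped cross-attention layer

print(cache.get_seq_length(0))  # 4
print(cache.get_seq_length(1))  # 0 -> the skipped layer holds an empty list
print(cache.get_seq_length(2))  # 4
```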