diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index e616adbe6798..b2be3f238d0c 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1069,12 +1069,15 @@ class StaticCache(Cache):
             The maximum sequence length with which the model will be used.
         device (`torch.device` or `str`):
             The device on which the cache should be initialized. Should be the same as the layer.
+            However, the recommended way is to not pass any `device`; in that case, the cache will be initialized
+            on the `meta` device by default and moved to the device of the incoming states on the first update.
         dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.
         layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
             Mapping between the layers and its device. This is required when you are manually initializing the cache
             and the model is splitted between differents gpus. You can know which layers mapped to which device by
             checking the associated device_map: `model.hf_device_map`.
+
     Example:

         ```python
@@ -1096,6 +1099,7 @@ class StaticCache(Cache):
     """

     # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
+    @deprecate_kwarg("layer_device_map", version="4.52.0")
     def __init__(
         self,
         config: PretrainedConfig,
@@ -1122,6 +1126,7 @@ def __init__(
         )

         self.dtype = dtype
+        self.device = torch.device(device) if device is not None else torch.device("meta")
         self.num_key_value_heads = (
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
@@ -1136,7 +1141,7 @@ def __init__(
             if layer_device_map is not None:
                 layer_device = layer_device_map[idx]
             else:
-                layer_device = device
+                layer_device = self.device
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             # Notes:
@@ -1181,6 +1186,9 @@ def update(
         """
         cache_position = cache_kwargs.get("cache_position")

+        if self.key_cache[layer_idx].device.type == "meta":
+            self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
+            self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]

@@ -1209,6 +1217,8 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
         # limit the check to the first batch member and head dimension.
         # TODO: deprecate this function in favor of `cache_position`
+        if self.key_cache[layer_idx].device.type == "meta":
+            return 0
         return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

     def get_max_cache_shape(self) -> Optional[int]:
@@ -1217,9 +1227,10 @@ def get_max_cache_shape(self) -> Optional[int]:
     def reset(self):
         """Resets the cache values while preserving the objects"""
         for layer_idx in range(len(self.key_cache)):
-            # In-place ops prevent breaking the static address
-            self.key_cache[layer_idx].zero_()
-            self.value_cache[layer_idx].zero_()
+            if self.key_cache[layer_idx].device.type != "meta":
+                # In-place ops prevent breaking the static address
+                self.key_cache[layer_idx].zero_()
+                self.value_cache[layer_idx].zero_()

     @property
     def batch_size(self):
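Aside (editor's illustration, not part of the patch): the hunks above make `StaticCache` allocate its per-layer buffers on the `meta` device and materialize them lazily inside `update()`. A minimal plain-PyTorch sketch of that pattern, with made-up shapes:

```python
import torch

# Shapes are illustrative only.
max_batch_size, num_heads, max_cache_len, head_dim = 1, 8, 128, 64
cache_shape = (max_batch_size, num_heads, max_cache_len, head_dim)

# 1) Allocate on the meta device: no memory is used, only shape/dtype are recorded.
key_cache = torch.zeros(cache_shape, dtype=torch.float16, device="meta")

# 2) On the first real update, materialize the buffer on the device of the incoming states.
key_states = torch.randn(1, 8, 4, 64, dtype=torch.float16)  # would live on e.g. "cuda:0" in practice
if key_cache.device.type == "meta":
    key_cache = torch.zeros_like(key_cache, device=key_states.device)

# 3) From then on, updates are in-place index copies along the sequence dimension,
#    the same way `StaticCache.update` writes into `k_out` / `v_out`.
cache_position = torch.arange(key_states.shape[-2])
key_cache.index_copy_(2, cache_position, key_states)
print(key_cache.device, key_cache.shape)
```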
@@ -1257,6 +1268,8 @@ class SlidingWindowCache(StaticCache):
             The maximum sequence length with which the model will be used.
         device (`torch.device` or `str`):
             The device on which the cache should be initialized. Should be the same as the layer.
+            However, the recommended way is to not pass any `device`; in that case, the cache will be initialized
+            on the `meta` device by default and moved to the device of the incoming states on the first update.
         dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.
         layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
@@ -1321,8 +1334,15 @@ def update(
         cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor]:
         cache_position = cache_kwargs.get("cache_position")
+
+        if self.key_cache[layer_idx].device.type == "meta":
+            self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
+            self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
+
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
+        key_states = key_states.to(k_out.dtype)
+        value_states = value_states.to(v_out.dtype)

         # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
         if cache_position.shape[0] > self.max_cache_len:
@@ -1365,9 +1385,10 @@ def get_max_cache_shape(self) -> Optional[int]:
     def reset(self):
         for layer_idx in range(len(self.key_cache)):
-            # In-place ops prevent breaking the static address
-            self.key_cache[layer_idx].zero_()
-            self.value_cache[layer_idx].zero_()
+            if self.key_cache[layer_idx].device.type != "meta":
+                # In-place ops prevent breaking the static address
+                self.key_cache[layer_idx].zero_()
+                self.value_cache[layer_idx].zero_()


 class EncoderDecoderCache(Cache):
@@ -1561,8 +1582,10 @@ class HybridCache(Cache):
             smaller batch size is used.
         max_cache_len (`int`):
             The maximum sequence length with which the model will be used.
-        device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
+        device (`torch.device` or `str`, *optional*):
             The device on which the cache should be initialized. Should be the same as the layer.
+            However, the recommended way is to not pass any `device`; in that case, the cache will be initialized
+            on the `meta` device by default and moved to the device of the incoming states on the first update.
         dtype (torch.dtype, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.
         layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
@@ -1590,12 +1613,13 @@ class HybridCache(Cache):
     """

     # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
+ @deprecate_kwarg("layer_device_map", version="4.52.0") def __init__( self, config: PretrainedConfig, batch_size: int = None, max_cache_len: int = None, - device: Union[torch.device, str] = "cpu", + device: Union[torch.device, str] = None, dtype: torch.dtype = torch.float32, max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, @@ -1623,9 +1647,11 @@ def __init__( self.num_key_value_heads = ( config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) + + self.device = torch.device(device) if device is not None else torch.device("meta") layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC self.is_sliding = torch.tensor( - [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device + [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool ) self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] @@ -1640,7 +1666,7 @@ def __init__( if layer_device_map is not None: layer_device = layer_device_map[i] else: - layer_device = device + layer_device = self.device # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape @@ -1696,8 +1722,16 @@ def update( ) -> Tuple[torch.Tensor]: cache_position = cache_kwargs.get("cache_position") sliding_window = cache_kwargs.get("sliding_window") + + if self.key_cache[layer_idx].device.type == "meta": + self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device) + self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device) + k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + if sliding_window: update_fn = self._sliding_update else: @@ -1725,14 +1759,18 @@ def get_seq_length(self, layer_idx: Optional[int] = 0): "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " "Using the `layer_idx` argument is not supported." ) + + if self.key_cache[layer_idx].device.type == "meta": + return 0 return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() def reset(self): """Resets the cache values while preserving the objects""" for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + if self.key_cache[layer_idx].device.type != "meta": + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() @property def batch_size(self): @@ -1757,10 +1795,14 @@ class MambaCache: The default `dtype` to use when initializing the layer. device (`torch.device` or `str`, *optional*): The device on which the cache should be initialized. Should be the same as the layer. + The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta` + device by default, and then moved to input device when updating. Attributes: dtype: (`torch.dtype`): The default `dtype` used to initializing the cache. + device (`torch.device`): + The default device on which the cache was initialized. 
        intermediate_size: (`int`):
            Model's intermediate_size taken from config.
        ssm_state_size: (`int`):
@@ -1809,30 +1851,40 @@ def __init__(
         self.intermediate_size = config.intermediate_size
         self.ssm_state_size = config.state_size
         self.conv_kernel_size = config.conv_kernel
+        self.device = torch.device(device) if device is not None else torch.device("meta")
+
+        self.conv_states: List[torch.Tensor] = []
+        self.ssm_states: List[torch.Tensor] = []
+        for _ in range(config.num_hidden_layers):
+            conv_state: torch.Tensor = torch.zeros(
+                self.max_batch_size,
+                self.intermediate_size,
+                self.conv_kernel_size,
+                device=self.device,
+                dtype=dtype,
+            )
+            ssm_state: torch.Tensor = torch.zeros(
+                self.max_batch_size,
+                self.intermediate_size,
+                self.ssm_state_size,
+                device=self.device,
+                dtype=dtype,
+            )

-        self.conv_states: torch.Tensor = torch.zeros(
-            config.num_hidden_layers,
-            self.max_batch_size,
-            self.intermediate_size,
-            self.conv_kernel_size,
-            device=device,
-            dtype=dtype,
-        )
-        self.ssm_states: torch.Tensor = torch.zeros(
-            config.num_hidden_layers,
-            self.max_batch_size,
-            self.intermediate_size,
-            self.ssm_state_size,
-            device=device,
-            dtype=dtype,
-        )
-
-        torch._dynamo.mark_static_address(self.conv_states)
-        torch._dynamo.mark_static_address(self.ssm_states)
+            torch._dynamo.mark_static_address(conv_state)
+            torch._dynamo.mark_static_address(ssm_state)
+            self.conv_states.append(conv_state)
+            self.ssm_states.append(ssm_state)

     def update_conv_state(
         self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
     ) -> torch.Tensor:
+        if self.conv_states[layer_idx].device.type == "meta":
+            self.conv_states[layer_idx] = torch.zeros_like(
+                self.conv_states[layer_idx],
+                device=new_conv_state.device,
+            )
+
         conv_state = self.conv_states[layer_idx]
         cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
@@ -1843,12 +1895,15 @@ def update_conv_state(
         return self.conv_states[layer_idx]

     def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
-        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
+        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
         return self.ssm_states[layer_idx]

     def reset(self):
-        self.conv_states.zero_()
-        self.ssm_states.zero_()
+        for layer_idx in range(len(self.conv_states)):
+            if self.conv_states[layer_idx].device.type != "meta":
+                # In-place ops prevent breaking the static address
+                self.conv_states[layer_idx].zero_()
+                self.ssm_states[layer_idx].zero_()

     @property
     def batch_size(self):
@@ -1920,6 +1975,7 @@ class OffloadedStaticCache(StaticCache):
         ```
     """

+    @deprecate_kwarg("layer_device_map", version="4.52.0")
     def __init__(
         self,
         config: PretrainedConfig,
@@ -1930,9 +1986,10 @@ def __init__(
         offload_device: Union[str, torch.device] = torch.device("cpu"),
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
+        super(Cache, self).__init__()
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
-        self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
+        self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
         self.offload_device = torch.device(offload_device)
         self.dtype = dtype if dtype is not None else torch.float32
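Usage sketch (editor's illustration, not part of the patch): with these `cache_utils.py` changes, a statically shaped cache can be built without a `device` argument and follows the model on its own. The checkpoint name below is purely illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-3.2-1B"  # illustrative; any causal LM checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("The key to life is", return_tensors="pt").to(model.device)

# No `device` given: per-layer buffers start out on the "meta" device...
past_key_values = StaticCache(
    config=model.config, max_batch_size=1, max_cache_len=64, dtype=model.dtype
)
print(past_key_values.key_cache[0].device)  # meta

# ...and are materialized next to the key/value states during the first forward pass.
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=8, do_sample=False)
print(past_key_values.key_cache[0].device)  # same device as the model inputs
print(tokenizer.decode(out[0], skip_special_tokens=True))
```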
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 655a388cb70d..461d7e121581 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1633,45 +1633,12 @@ def _get_cache(
             # models. May cause trobles with non-text modalities.
             cache_dtype = self.get_output_embeddings().weight.dtype

-        def get_layer_device_map(execution_device_map: Optional[dict] = None):
-            num_hidden_layers = self.config.get_text_config().num_hidden_layers
-            if execution_device_map is None:
-                return None
-            elif len(execution_device_map) == 1 and "" in execution_device_map:
-                return {idx: execution_device_map[""] for idx in range(num_hidden_layers)}
-            layer_device_map = {}
-            for layer in execution_device_map:
-                for idx in range(num_hidden_layers):
-                    if f".{idx}." in f"{layer}.":
-                        layer_device_map[idx] = execution_device_map[layer]
-                        break
-            for idx in range(num_hidden_layers):
-                if idx not in layer_device_map:
-                    raise RuntimeError(f"layer {idx} has not been mapped to a device.")
-            return layer_device_map
-
-        execution_device_map = None
-        # Taken from dispatch_model from accelerate.
-        # This is needed here if we don't want to make changes in accelerate in order to save execution_device
-        # For offloaded case, we need to get the execution device, not just the device where it is offloaded
-        if hasattr(self, "hf_device_map"):
-            if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}:
-                main_device = "cpu"
-            else:
-                main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
-            execution_device_map = {
-                name: main_device if device in ["cpu", "disk"] else device
-                for name, device in self.hf_device_map.items()
-            }
-        layer_device_map = get_layer_device_map(execution_device_map)
-
         cache_kwargs = {
             "config": self.config.get_text_config(),
             "max_batch_size": batch_size,
             "max_cache_len": max_cache_len,
-            "device": device,
             "dtype": cache_dtype,
-            "layer_device_map": layer_device_map,
+            "device": device if cache_implementation == "offloaded_static" else None,
         }
         self._cache = cache_cls(**cache_kwargs)
         if requires_cross_attention_cache:
diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 258017f14180..a0cbc8ba4e78 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -73,6 +73,7 @@ def __init__(self, model: PreTrainedModel):
             batch_size=self.model.generation_config.cache_config.batch_size,
             max_cache_len=self.model.generation_config.cache_config.max_cache_len,
             dtype=self.model.dtype,
+            device=self.model.generation_config.cache_config.device,
         )
         self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
         if self.is_causal:
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index 0b38c89d75a5..15469577fb41 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -582,7 +582,6 @@ def forward(
                 self.config,
                 max_batch_size=batch_size,
                 max_cache_len=seq_len,
-                device=self.device,
                 dtype=inputs_embeds.dtype,
             )

diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py
index 78419e78c08b..7020df27021f 100644
--- a/src/transformers/models/cohere2/modular_cohere2.py
+++ b/src/transformers/models/cohere2/modular_cohere2.py
@@ -461,7 +461,6 @@ def forward(
                 self.config,
                 max_batch_size=batch_size,
                 max_cache_len=seq_len,
-                device=self.device,
                 dtype=inputs_embeds.dtype,
             )

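Sketch (editor's illustration, not part of the patch): Cohere2/Gemma2-style models build a `HybridCache` inside `forward()` when none is supplied, and after this change they no longer pass `device=self.device`. The default `Gemma2Config()` values below are used only for illustration.

```python
import torch
from transformers import Gemma2Config, HybridCache

config = Gemma2Config()  # illustrative default sizes
cache = HybridCache(config=config, max_batch_size=1, max_cache_len=32, dtype=torch.bfloat16)

# Buffers sit on the meta device until the first `update()` call inside the model...
print(cache.key_cache[0].device)  # meta
# ...and the `is_sliding` bookkeeping tensor is now built without a device argument as well.
print(cache.is_sliding.device)    # cpu
```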
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index e64559b26650..fb7e59051a83 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -579,7 +579,6 @@ def forward(
                 self.config,
                 max_batch_size=batch_size,
                 max_cache_len=seq_len,
-                device=self.device,
                 dtype=inputs_embeds.dtype,
             )

diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index 5f21fc6bfffd..53a947eb95b3 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -405,7 +405,6 @@ def forward(
                 self.config,
                 max_batch_size=batch_size,
                 max_cache_len=seq_len,
-                device=self.device,
                 dtype=inputs_embeds.dtype,
             )

diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 664616306d88..8d492ce673da 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -728,22 +728,13 @@ def test_compile_static_cache(self):
         dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)

-        # Static Cache
+        # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used)
         generated_ids = model.generate(
             **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
         )
         static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)

-        # Static Cache + compile
-        model._cache = None  # clear cache object, initialized when we pass `cache_implementation="static"`
-        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
-        generated_ids = model.generate(
-            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
-        )
-        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
-
     @slow
     @require_read_token
     def test_export_static_cache(self):
@@ -795,6 +786,7 @@ def test_export_static_cache(self):
             cache_config={
                 "batch_size": batch_size,
                 "max_cache_len": max_generation_length,
+                "device": device,
             },
         ),
     )
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index cf259fabe302..d361378503fa 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4635,6 +4635,11 @@ def test_flash_attn_2_from_config(self):
                     fa2_correctly_converted = True
                     break

+            fa2_correctly_converted = (
+                fa2_correctly_converted
+                if not model_class._supports_flex_attn
+                else fa2_model.config._attn_implementation == "flash_attention_2"
+            )
             self.assertTrue(fa2_correctly_converted)

             _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
@@ -4653,6 +4658,11 @@ def test_flash_attn_2_from_config(self):
                     fa2_correctly_converted = True
                     break

+            fa2_correctly_converted = (
+                fa2_correctly_converted
+                if not model_class._supports_flex_attn
+                else model_from_pretrained.config._attn_implementation == "flash_attention_2"
+            )
             self.assertFalse(fa2_correctly_converted)

     def _get_custom_4d_mask_test_data(self):
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index 053d2cf6397a..d67b026638e9 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -198,6 +198,7 @@ def test_static_cache_exportability(self):
                 cache_config={
                     "batch_size": batch_size,
"max_cache_len": max_cache_len, + "device": device, }, ), ) @@ -310,11 +311,12 @@ def test_hybrid_cache_n_sequences(self): do_sample=False, max_new_tokens=20, num_return_sequences=2, + num_beams=2, ) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) expected_text = [ - "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", - "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", + "Hello I am doing a project for my school and I am trying to make a program that will allow me to input a", + "Hello I am doing a project for my school and I am trying to make a program that will allow me to use a", ] self.assertListEqual(decoded, expected_text) @@ -380,8 +382,6 @@ def test_sink_cache_iterative_prompts(self): [ ("eager", "static"), ("sdpa", "static"), - ("eager", "offloaded-static"), - ("sdpa", "offloaded-static"), ] ) def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation): @@ -427,8 +427,6 @@ def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_ [ ("eager", "static"), ("sdpa", "static"), - ("eager", "offloaded-static"), - ("sdpa", "offloaded-static"), ] ) def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation): @@ -462,26 +460,6 @@ def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache with self.subTest(f"{attn_implementation}, static, eager"): self.assertListEqual(decoded, EXPECTED_GENERATION) - set_seed(0) - model._forward = model.forward - compiled_forward = torch.compile(model.forward) - - def compiled(func, input_ids, **kwargs): - return func(input_ids, **kwargs) - - def call(input_ids, **kwargs): - if input_ids.shape[-1] == 1: - return compiled(compiled_forward, input_ids, **kwargs) - - return model._forward(input_ids, **kwargs) - - model.forward = call - - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - with self.subTest(f"{attn_implementation}, static, compiled"): - self.assertListEqual(decoded, EXPECTED_GENERATION) - def test_dynamic_cache_extra_left_padding(self): """Tests that adding extra left-padding does not affect the generation with the dynamic cache""" EXPECTED_GENERATION = [ @@ -519,7 +497,6 @@ def test_dynamic_cache_extra_left_padding(self): @parameterized.expand( [ "static", - "offloaded-static", ] ) def test_static_cache_extra_left_padding(self, cache_implementation):