Init cache on meta device (#35164)

* init cache on meta device

* offloaded static + enable tests

* tests weren't running before  :(

* update

* fix mamba

* fix copies

* update

* address comments and fix tests

* fix copies

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

* update

* mamba fix

---------

Co-authored-by: Arthur <[email protected]>
zucchini-nlp and ArthurZucker authored Jan 22, 2025
1 parent 870e2c8 commit 373e50e
Showing 10 changed files with 112 additions and 112 deletions.
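Before the per-file diff, a minimal sketch of the behavior this commit introduces, using a tiny illustrative `LlamaConfig` (the parameter values are arbitrary; the `StaticCache` keyword names are taken from the diff below): when no `device` is passed, the cache buffers start on the `meta` device and are only materialized, on the device of the incoming key/value states, at the first `update()` call.

```python
import torch
from transformers import LlamaConfig, StaticCache

# Tiny illustrative config: values are arbitrary, StaticCache only reads
# shape-related fields (heads, head_dim, num_hidden_layers) from it.
config = LlamaConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2)

# No `device` argument: every layer's key/value buffer is allocated on `meta`,
# i.e. shape and dtype only, no real storage yet.
cache = StaticCache(config, max_batch_size=1, max_cache_len=16, dtype=torch.float16)
print(cache.key_cache[0].device)  # meta

# First update for layer 0: the buffers are materialized on the device of the
# incoming key/value states (CPU here; the model's device in practice).
k = v = torch.zeros(1, 2, 4, 32, dtype=torch.float16)  # (batch, kv_heads, seq_len, head_dim)
cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": torch.arange(4)})
print(cache.key_cache[0].device)  # cpu
```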
131 changes: 94 additions & 37 deletions src/transformers/cache_utils.py
@@ -1069,12 +1069,15 @@ class StaticCache(Cache):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way, however, is to not indicate any `device`; in that case, the cache will be initialized on the `meta`
device by default and then moved to the input device when updating.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]`, `optional`):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
You can check which layers are mapped to which device by inspecting the associated device_map: `model.hf_device_map`.
Example:
```python
@@ -1096,6 +1099,7 @@ class StaticCache(Cache):
"""

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,
@@ -1122,6 +1126,7 @@ def __init__(
)

self.dtype = dtype
self.device = torch.device(device) if device is not None else torch.device("meta")
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
@@ -1136,7 +1141,7 @@ def __init__(
if layer_device_map is not None:
layer_device = layer_device_map[idx]
else:
layer_device = device
layer_device = self.device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Notes:
@@ -1181,6 +1186,9 @@ def update(
"""

cache_position = cache_kwargs.get("cache_position")
if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)

k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
@@ -1209,6 +1217,8 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
if self.key_cache[layer_idx].device.type == "meta":
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

def get_max_cache_shape(self) -> Optional[int]:
@@ -1217,9 +1227,10 @@ def get_max_cache_shape(self) -> Optional[int]:
def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def batch_size(self):
@@ -1257,6 +1268,8 @@ class SlidingWindowCache(StaticCache):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way, however, is to not indicate any `device`; in that case, the cache will be initialized on the `meta`
device by default and then moved to the input device when updating.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]`, `optional`):
@@ -1321,8 +1334,15 @@ def update(
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")

if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)

k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)

# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
if cache_position.shape[0] > self.max_cache_len:
@@ -1365,9 +1385,10 @@ def get_max_cache_shape(self) -> Optional[int]:

def reset(self):
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()


class EncoderDecoderCache(Cache):
@@ -1561,8 +1582,10 @@ class HybridCache(Cache):
smaller batch size is used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way, however, is to not indicate any `device`; in that case, the cache will be initialized on the `meta`
device by default and then moved to the input device when updating.
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]`, `optional`):
@@ -1590,12 +1613,13 @@ class HybridCache(Cache):
"""

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: Union[torch.device, str] = "cpu",
device: Union[torch.device, str] = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1623,9 +1647,11 @@ def __init__(
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)

self.device = torch.device(device) if device is not None else torch.device("meta")
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
@@ -1640,7 +1666,7 @@ def __init__(
if layer_device_map is not None:
layer_device = layer_device_map[i]
else:
layer_device = device
layer_device = self.device
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
@@ -1696,8 +1722,16 @@ def update(
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")

if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)

k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)

if sliding_window:
update_fn = self._sliding_update
else:
@@ -1725,14 +1759,18 @@ def get_seq_length(self, layer_idx: Optional[int] = 0):
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
"Using the `layer_idx` argument is not supported."
)

if self.key_cache[layer_idx].device.type == "meta":
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def batch_size(self):
@@ -1757,10 +1795,14 @@ class MambaCache:
The default `dtype` to use when initializing the layer.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way, however, is to not indicate any `device`; in that case, the cache will be initialized on the `meta`
device by default and then moved to the input device when updating.
Attributes:
dtype: (`torch.dtype`):
The default `dtype` used to initialize the cache.
device (`torch.device`):
The default device on which the cache was initialized.
intermediate_size: (`int`):
Model's intermediate_size taken from config.
ssm_state_size: (`int`):
@@ -1809,30 +1851,40 @@ def __init__(
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.device = torch.device(device) if device is not None else torch.device("meta")

self.conv_states: List[torch.Tensor] = []
self.ssm_states: List[torch.Tensor] = []
for _ in range(config.num_hidden_layers):
conv_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=self.device,
dtype=dtype,
)
ssm_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=self.device,
dtype=dtype,
)

self.conv_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
dtype=dtype,
)
self.ssm_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
dtype=dtype,
)

torch._dynamo.mark_static_address(self.conv_states)
torch._dynamo.mark_static_address(self.ssm_states)
torch._dynamo.mark_static_address(conv_state)
torch._dynamo.mark_static_address(ssm_state)
self.conv_states.append(conv_state)
self.ssm_states.append(ssm_state)

def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
if self.conv_states[layer_idx].device.type == "meta":
self.conv_states[layer_idx] = torch.zeros_like(
self.conv_states[layer_idx],
device=new_conv_state.device,
)

conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

@@ -1843,12 +1895,15 @@ def update_conv_state(
return self.conv_states[layer_idx]

def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
return self.ssm_states[layer_idx]

def reset(self):
self.conv_states.zero_()
self.ssm_states.zero_()
for layer_idx in range(len(self.conv_states)):
if self.conv_states[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()

@property
def batch_size(self):
@@ -1920,6 +1975,7 @@ class OffloadedStaticCache(StaticCache):
```
"""

@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,
@@ -1930,9 +1986,10 @@ def __init__(
offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super(Cache, self).__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32

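The lazy materialization added to the `update()` methods above (and, in the same spirit, to `MambaCache.update_conv_state()`) boils down to one pattern: a buffer allocated on `meta` is swapped for real zeros on the device of the incoming states the first time it is touched. A standalone sketch of that pattern, with illustrative names and independent of the cache classes:

```python
import torch

def materialize_if_meta(buffer: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
    """Swap a meta-device placeholder for real zeros on `like`'s device."""
    if buffer.device.type == "meta":
        # zeros_like keeps the shape and dtype, but the device override
        # allocates real storage where the incoming states live.
        buffer = torch.zeros_like(buffer, device=like.device)
    return buffer

# Placeholder: shape and dtype are known, no storage is allocated yet.
kv_buffer = torch.zeros(1, 8, 128, 64, dtype=torch.float16, device="meta")
# Incoming key states, e.g. produced by a layer running on CPU or GPU.
key_states = torch.randn(1, 8, 4, 64, dtype=torch.float16)

kv_buffer = materialize_if_meta(kv_buffer, key_states)
print(kv_buffer.device, kv_buffer.shape)  # cpu torch.Size([1, 8, 128, 64])
```

This is also why `reset()` and `get_seq_length()` now skip layers that are still on `meta`: until the first update there is nothing to zero out or count.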
35 changes: 1 addition & 34 deletions src/transformers/generation/utils.py
@@ -1633,45 +1633,12 @@ def _get_cache(
# models. May cause troubles with non-text modalities.
cache_dtype = self.get_output_embeddings().weight.dtype

def get_layer_device_map(execution_device_map: Optional[dict] = None):
num_hidden_layers = self.config.get_text_config().num_hidden_layers
if execution_device_map is None:
return None
elif len(execution_device_map) == 1 and "" in execution_device_map:
return {idx: execution_device_map[""] for idx in range(num_hidden_layers)}
layer_device_map = {}
for layer in execution_device_map:
for idx in range(num_hidden_layers):
if f".{idx}." in f"{layer}.":
layer_device_map[idx] = execution_device_map[layer]
break
for idx in range(num_hidden_layers):
if idx not in layer_device_map:
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
return layer_device_map

execution_device_map = None
# Taken from dispatch_model from accelerate.
# This is needed here if we don't want to make changes in accelerate in order to save execution_device
# For offloaded case, we need to get the execution device, not just the device where it is offloaded
if hasattr(self, "hf_device_map"):
if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}:
main_device = "cpu"
else:
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
execution_device_map = {
name: main_device if device in ["cpu", "disk"] else device
for name, device in self.hf_device_map.items()
}
layer_device_map = get_layer_device_map(execution_device_map)

cache_kwargs = {
"config": self.config.get_text_config(),
"max_batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype,
"layer_device_map": layer_device_map,
"device": device if cache_implementation == "offloaded_static" else None,
}
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
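The `get_layer_device_map` plumbing removed above is no longer needed: for a model sharded with a `device_map`, each layer's cache buffers now materialize on that layer's own execution device at its first `update()`, and only the `offloaded_static` implementation still receives an explicit device. A hedged end-to-end sketch (the checkpoint is illustrative and any model supporting the static cache works; `device_map="auto"` assumes `accelerate` is installed):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Qwen/Qwen2.5-0.5B"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(name, device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(name)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# No layer_device_map plumbing: the static cache starts on `meta` and each
# layer's buffers land on that layer's execution device at its first update.
out = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```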
1 change: 1 addition & 0 deletions src/transformers/integrations/executorch.py
@@ -73,6 +73,7 @@ def __init__(self, model: PreTrainedModel):
batch_size=self.model.generation_config.cache_config.batch_size,
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
dtype=self.model.dtype,
device=self.model.generation_config.cache_config.device,
)
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
if self.is_causal:
1 change: 0 additions & 1 deletion src/transformers/models/cohere2/modeling_cohere2.py
@@ -582,7 +582,6 @@ def forward(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,
)

1 change: 0 additions & 1 deletion src/transformers/models/cohere2/modular_cohere2.py
@@ -461,7 +461,6 @@ def forward(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,
)

1 change: 0 additions & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
@@ -579,7 +579,6 @@ def forward(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,
)

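Cohere2 and Gemma2 build a `HybridCache` inside `forward()` when none is passed; after this commit they can do so without passing `self.device`, because the cache starts on `meta`. A minimal user-side equivalent, using a tiny illustrative `Gemma2Config` (the values are arbitrary but mutually consistent):

```python
import torch
from transformers import Gemma2Config, HybridCache

# Tiny illustrative config; HybridCache only needs the shape-related fields.
config = Gemma2Config(hidden_size=128, num_hidden_layers=4, num_attention_heads=4,
                      num_key_value_heads=2, head_dim=32, sliding_window=64)

# Mirrors the simplified model-side call above: no `device` argument.
cache = HybridCache(config, max_batch_size=1, max_cache_len=64, dtype=torch.bfloat16)
print(cache.key_cache[0].device)  # meta, until the first update() for that layer
```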
