
Commit

Allows interweaving of arbitrary kinds of 'attention' layers, like sliding window, reuse prev layer kv cache etc. (#1299)

* [WIP] Allows interweaving of arbitrary kinds of 'attention' layers, like RNN, sliding window etc.

* lint

* applying overrides to blocks rather than just attentions

* add docstring

* minor

* changing yaml specification style

* ..

* fixes

* fix

* fix

* fix

* refactoring

* add warning

* compute only query vector when reusing kv

* refactor

* fixing

* adding test for reusing previous layer kv cache

* adding error messages

* ..

* adding test

* add logging

* adding logging

* minor

* bug fix, adding test

* minor

* addressing some comments

* addressing some comments

* setting absolute value for reuse_kv_layer_idx

* lint

* adding tests for override_block_args

* adding error if reusing kv cache from a mismatched layer

* fixing test

* fixing code, test

* fix

* ..

* refactoring

* fix

* ..

* ..

* ..

* refactoring

* ..

* ..

* ..

* adding test for _get_modules_order_expanded

* fixing test

* fixing test

* lint

* lint

* adding test

* addressing comment

* ..

* fixing test

* changing yaml format

* fix configuration

* fixing test

* allowing repeat at top level

* allowing overriding error

* addressing comments

* lint

* addressing comments

* fix

* ..

* ..

* ..

* ..

* ..

* addressing comment

* fixing test
ShashankMosaicML authored Jun 30, 2024
1 parent 88511f7 commit 8604bba
Showing 6 changed files with 849 additions and 25 deletions.
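
For orientation, the sketch below shows how the new `block_overrides` option introduced in this commit might be passed to `MPTConfig`. The dict structure (`order` plus `overrides`) mirrors the YAML example added to the `configuration_mpt.py` docstring further down; the surrounding model sizes (d_model, n_heads, n_layers) are illustrative assumptions, not part of this commit.

from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# Python equivalent of the YAML example in the configuration_mpt.py docstring below.
# 'reuse_kv_layer_idx' is the relative index of the layer whose kv cache to reuse.
block_overrides = {
    'order': [
        {'name': 'default'},
        {
            'repeat': 2,
            'order': [
                {'name': 'sliding_window_layer'},
                {'name': 'sliding_window_layer_reuse'},
                {'name': 'sliding_window_layer'},
                {'repeat': 2, 'name': 'sliding_window_layer_reuse'},
                {'name': 'reuse_kv_layer'},
            ],
        },
    ],
    'overrides': {
        'sliding_window_layer': {
            'attn_config': {'sliding_window_size': 1024},
        },
        'sliding_window_layer_reuse': {
            'attn_config': {
                'sliding_window_size': 1024,
                'reuse_kv_layer_idx': -1,  # reuse the kv cache of the previous layer
            },
        },
        'reuse_kv_layer': {
            'attn_config': {'reuse_kv_layer_idx': -6},
        },
    },
}

# Illustrative sizes only; n_layers=13 assumes the expanded order
# (1 default block + 2 * 6 overridden blocks) must cover every layer.
config = MPTConfig(
    d_model=768,
    n_heads=12,
    n_layers=13,
    block_overrides=block_overrides,
)

Note that, per the validation added in configuration_mpt.py, reusing a previous layer's kv cache raises NotImplementedError under the torch attention implementation, so a setup like this assumes flash attention.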
93 changes: 73 additions & 20 deletions llmfoundry/models/layers/attention.py
@@ -416,6 +416,7 @@ def __init__(
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
):
super().__init__()

@@ -428,6 +429,7 @@ def __init__(
self.n_heads = n_heads
self.kv_n_heads = kv_n_heads
self.sliding_window_size = sliding_window_size
self.reuse_kv_layer_idx = reuse_kv_layer_idx

self.head_dim = d_model // n_heads

@@ -458,18 +460,29 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = attn_pdrop

self.Wqkv = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
i * self.head_dim
for i in range(1, self.n_heads + 2 * self.kv_n_heads)
]
self.Wqkv._fused = (0, fuse_splits)
if self.reuse_kv_layer_idx is None:
self.Wqkv = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
i * self.head_dim
for i in range(1, self.n_heads + 2 * self.kv_n_heads)
]
self.Wqkv._fused = (0, fuse_splits)
else:
self.Wq = build_fc(
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
self.Wq._fused = (0, fuse_splits)

if self.qk_ln or self.qk_gn:
norm_size = self.head_dim if qk_gn else d_model
@@ -478,13 +491,14 @@ def __init__(
normalized_shape=norm_size,
device=device,
)
if qk_ln:
norm_size = self.head_dim * kv_n_heads
self.k_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)
if self.reuse_kv_layer_idx is None:
if qk_ln:
norm_size = self.head_dim * kv_n_heads
self.k_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)

self.attn_fn = attention_implementations.get(self.attn_impl)

@@ -507,9 +521,14 @@ def forward(
needs_weights: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
prev_layer_key_value: Optional[Tuple[torch.Tensor,
torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
query, key, value = self.get_qkv(x)
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
query, key, value = self.get_qkv(x, **extra_kwargs)

if rotary_emb_w_meta_info is not None:
query, key, value = self._apply_rotary_embeddings(
@@ -546,6 +565,8 @@ def forward(
def get_qkv(
self,
x: torch.Tensor,
prev_layer_key_value: Optional[Tuple[torch.Tensor,
torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes and returns the query, key, and value tensors.
@@ -557,6 +578,27 @@ def get_qkv(
key (torch.Tensor): The key tensor.
value (torch.Tensor): The value tensor.
"""
if self.reuse_kv_layer_idx is not None:
if prev_layer_key_value is None:
raise ValueError(
'prev_layer_key_value is None, cannot reuse_prev_layer_kv.',
)
key, value = prev_layer_key_value

query = self.Wq(x)
if self.clip_qkv:
query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)

if self.qk_ln or self.qk_gn:
# Applying layernorm to qk
q_shape = query.shape
if self.qk_gn:
b, s = query.shape[:2]
query = query.view(b, s, self.n_heads, -1)
dtype = query.dtype
query = self.q_ln(query).to(dtype).view(q_shape)
return query, key, value

qkv = self.Wqkv(x)

if self.clip_qkv:
@@ -591,6 +633,10 @@ def _apply_rotary_embeddings(
key: torch.Tensor,
value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.reuse_kv_layer_idx is not None:
orig_key, orig_value = key, value
key, value = torch.empty_like(key), torch.empty_like(value)

rotary_emb = rotary_emb_w_meta_info['rotary_emb']
seq_len = rotary_emb_w_meta_info['seq_len']
offset_info = rotary_emb_w_meta_info['offset_info']
@@ -602,6 +648,7 @@ def _apply_rotary_embeddings(
value = value.view(bsz, seqlen, -1, self.head_dim)

kv = torch.stack([key, value], dim=2)
# Note: Rotates in place (https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/flash_attn/layers/rotary.py#L429)
query, kv = rotary_emb(
query,
kv,
@@ -652,6 +699,8 @@ def _apply_rotary_embeddings(

query = query.view(bsz, seqlen, -1)
key = key.view(bsz, seqlen, -1)
if self.reuse_kv_layer_idx is not None:
return query, orig_key, orig_value # type: ignore
return query, key, value

def get_implementation_specific_args(
@@ -705,6 +754,7 @@ def __init__(
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
):
super().__init__(
d_model=d_model,
@@ -721,6 +771,7 @@ def __init__(
device=device,
bias=bias,
sliding_window_size=sliding_window_size,
reuse_kv_layer_idx=reuse_kv_layer_idx,
)


@@ -746,6 +797,7 @@ def __init__(
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
):
super().__init__(
d_model=d_model,
Expand All @@ -762,6 +814,7 @@ def __init__(
device=device,
bias=bias,
sliding_window_size=sliding_window_size,
reuse_kv_layer_idx=reuse_kv_layer_idx,
)


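
To summarize the change to GroupedQueryAttention above: when `reuse_kv_layer_idx` is set, the module builds only a query projection (`Wq`) instead of the fused `Wqkv`, and `get_qkv` takes key/value from `prev_layer_key_value` instead of computing them. A simplified, standalone sketch of that control flow follows (not the actual module; clip_qkv, qk_ln/qk_gn, and rotary embeddings are omitted).

from typing import Optional, Tuple

import torch
import torch.nn as nn


class KVReuseAttentionSketch(nn.Module):
    """Simplified sketch of the reuse-kv branch added to GroupedQueryAttention."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        kv_n_heads: int,
        reuse_kv_layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.head_dim = d_model // n_heads
        self.d_model = d_model
        self.kv_dim = kv_n_heads * self.head_dim
        self.reuse_kv_layer_idx = reuse_kv_layer_idx
        if reuse_kv_layer_idx is None:
            # Fused projection producing query, key, and value.
            self.Wqkv = nn.Linear(d_model, d_model + 2 * self.kv_dim)
        else:
            # Only the query projection is needed; key/value come from an earlier layer.
            self.Wq = nn.Linear(d_model, d_model)

    def get_qkv(
        self,
        x: torch.Tensor,
        prev_layer_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if self.reuse_kv_layer_idx is not None:
            if prev_layer_key_value is None:
                raise ValueError('prev_layer_key_value is required when reusing kv.')
            key, value = prev_layer_key_value
            return self.Wq(x), key, value
        qkv = self.Wqkv(x)
        query, key, value = qkv.split([self.d_model, self.kv_dim, self.kv_dim], dim=-1)
        return query, key, value

Note also why `_apply_rotary_embeddings` returns the original reused key/value untouched: they were already rotated by the layer that produced them, and because the flash-attn rotary kernel rotates in place, dummy tensors are substituted so the reused cache is not modified.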
13 changes: 13 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -158,8 +158,13 @@ def forward(
output_attentions: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
prev_layer_key_value: Optional[Tuple[torch.Tensor,
torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
if self.fuse_norm_attn_norm:
x, m, attn_weights, past_key_value = self.norm_attn_norm(
x,
@@ -171,6 +176,7 @@ def forward(
output_attentions=output_attentions,
alibi_slopes=alibi_slopes,
flash_attn_padding_info=flash_attn_padding_info,
**extra_kwargs,
)
else:
a = self.norm_1(x)
@@ -184,6 +190,7 @@ def forward(
needs_weights=output_attentions,
alibi_slopes=alibi_slopes,
flash_attn_padding_info=flash_attn_padding_info,
**extra_kwargs,
)
x = x + self.resid_attn_dropout(b)
m = x
@@ -308,9 +315,14 @@ def forward(
output_attentions: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
prev_layer_key_value: Optional[Tuple[torch.Tensor,
torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
b, attn_weights, past_key_value = self.attn(
a,
past_key_value=past_key_value,
@@ -321,6 +333,7 @@ def forward(
needs_weights=output_attentions,
alibi_slopes=alibi_slopes,
flash_attn_padding_info=flash_attn_padding_info,
**extra_kwargs,
)
x = x + self.resid_attn_dropout(b)
m = x
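
The block changes above only thread an optional `prev_layer_key_value` through to the attention module. The wiring that decides which earlier layer's cache each block receives lives in modeling_mpt.py, which is not shown in this excerpt; below is a hedged sketch of what such a caller loop might look like, assuming blocks return `(hidden_states, attn_weights, key_value)` as MPTBlock does above.

from typing import Dict, List, Tuple

import torch
import torch.nn as nn


def run_blocks(
    blocks: List[nn.Module],
    x: torch.Tensor,
    reuse_map: Dict[int, int],  # block index -> index of the layer whose kv cache to reuse
) -> torch.Tensor:
    """Hypothetical caller loop; not the actual MPTModel forward."""
    layer_kv: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
    for idx, block in enumerate(blocks):
        extra_kwargs = {}
        if idx in reuse_map:
            # Hand this block the kv tensors produced by an earlier layer.
            extra_kwargs['prev_layer_key_value'] = layer_kv[reuse_map[idx]]
        x, _, key_value = block(x, **extra_kwargs)
        if key_value is not None:
            layer_kv[idx] = key_value
    return x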
61 changes: 61 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -20,6 +20,7 @@
ffn_config_defaults,
init_config_defaults,
)
from llmfoundry.utils.warnings import ExperimentalWarning


class MPTConfig(PretrainedConfig):
@@ -48,6 +49,7 @@ def __init__(
fc_type: Union[str, Dict] = 'torch',
tie_word_embeddings: bool = True,
use_pad_tok_in_ffn: bool = True,
block_overrides: Optional[Dict[str, Any]] = None,
**kwargs: Any,
):
"""The MPT configuration class.
@@ -117,6 +119,30 @@ def __init__(
also be a dictionary that specifies the fc layer name and any kwargs for the fc layer.
tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks.
block_overrides (Optional[Dict[str, Any]]): Allows overriding the default block configs for certain layers. This must contain `overrides` and `order`. `order` is a nested list describing the order of the layers. For each kind of layer, specify its configuration in `overrides` (the name `default` refers to a block without any overrides).
For example, to specify the model described in https://research.character.ai/optimizing-inference/, the following config could be used:
block_overrides:
order:
- name: default
- repeat: 2
order:
- name: sliding_window_layer
- name: sliding_window_layer_reuse
- name: sliding_window_layer
- repeat: 2
name: sliding_window_layer_reuse
- name: reuse_kv_layer
overrides:
sliding_window_layer:
attn_config:
sliding_window_size: 1024
sliding_window_layer_reuse:
attn_config:
sliding_window_size: 1024
reuse_kv_layer_idx: -1 # Relative index of the layer whose kv cache to reuse
reuse_kv_layer:
attn_config:
reuse_kv_layer_idx: -6 # Relative index of the layer whose kv cache to reuse
"""
self.d_model = d_model
self.n_heads = n_heads
@@ -145,6 +171,15 @@ def __init__(
init_config_defaults,
)

if 'reuse_kv_layer_idx' in self.attn_config and self.attn_config[
'attn_impl'] == 'torch':
raise NotImplementedError(
'reusing kv cache from a previous layer is not implemented for torch attention.',
)
if block_overrides is not None:
self._validate_block_overrides(block_overrides)
self.block_overrides = block_overrides

if isinstance(fc_type, str):
fc_type = {'name': fc_type}
self.fc_type = fc_type
@@ -169,6 +204,23 @@ def __init__(

self._validate_config()

def _validate_block_overrides(self, block_overrides: Dict[str, Any]):
warnings.warn(ExperimentalWarning('block_overrides'))
if 'order' not in block_overrides:
raise ValueError('`order` should be defined in block_overrides',)
if 'overrides' not in block_overrides:
raise ValueError(
'`overrides` should be defined in block_overrides',
)
for name, override in block_overrides['overrides'].items():
if name == 'default':
raise ValueError('block overrides cannot be named "default".',)
if 'attn_config' in override and 'reuse_kv_layer_idx' in override[
'attn_config'] and self.attn_config['attn_impl'] == 'torch':
raise NotImplementedError(
'reusing kv cache from a previous layer is not implemented for torch attention.',
)

def _set_config_defaults(
self,
config: Dict[str, Any],
Expand Down Expand Up @@ -335,3 +387,12 @@ def _validate_config(self) -> None:
)

self.validate_attention_config()

@property
def allowed_block_overrides(self):
return {
'attn_config': {
'sliding_window_size': None,
'reuse_kv_layer_idx': None,
},
}
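
The validation added above rejects malformed override specs before any blocks are built. A few illustrative failure modes (hypothetical dicts, not taken from the repo's test suite):

# Missing 'order' -> ValueError: '`order` should be defined in block_overrides'
bad_missing_order = {
    'overrides': {'swl': {'attn_config': {'sliding_window_size': 1024}}},
}

# Reserved name -> ValueError: 'block overrides cannot be named "default".'
bad_reserved_name = {
    'order': [{'name': 'default'}],
    'overrides': {'default': {'attn_config': {'sliding_window_size': 1024}}},
}

# kv reuse with attn_impl == 'torch' -> NotImplementedError
bad_torch_reuse = {
    'order': [{'name': 'kv_reuse'}],
    'overrides': {'kv_reuse': {'attn_config': {'reuse_kv_layer_idx': -1}}},
}

Per the `allowed_block_overrides` property, only `attn_config.sliding_window_size` and `attn_config.reuse_kv_layer_idx` can be overridden as of this commit.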
