DO NOT MERGE: Enabling Flash attention for alibi. #815

Closed
2 changes: 1 addition & 1 deletion Dockerfile
@@ -10,4 +10,4 @@ ARG DEP_GROUPS
 RUN git clone -b main https://github.com/mosaicml/llm-foundry.git
 RUN pip install --no-cache-dir "./llm-foundry${DEP_GROUPS}"
 RUN pip uninstall -y llm-foundry
-RUN rm -rf llm-foundry
+RUN rm -rf llm-foundry
35 changes: 27 additions & 8 deletions llmfoundry/models/layers/attention.py
@@ -256,9 +256,6 @@ def flash_attn_fn(
 
     past_key_value = (key, value)
 
-    if attn_bias is not None:
-        raise NotImplementedError(f'attn_bias not implemented for flash attn.')
-
     batch_size, seqlen = query.shape[:2]
 
     if attention_mask_in_length is None:
@@ -333,6 +330,24 @@ def flash_attn_fn(
             softmax_scale=softmax_scale,
             causal=reset_is_causal,
             return_attn_probs=needs_weights)
+    elif is_flash_v2_installed(
+            v2_version='2.3.6'
+    ):  # TODO: Change to 2.3.7 and do not merge before 2.3.7 is released!
+        output_unpad = flash_attn_interface.flash_attn_varlen_func(
+            q=query_unpad,
+            k=key_unpad,
+            v=value_unpad,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            causal=reset_is_causal,
+            return_attn_probs=needs_weights,
+            window_size=(sliding_window_size, sliding_window_size),
+            alibi_slopes=attn_bias,
+        )
     elif is_flash_v2_installed():
         output_unpad = flash_attn_interface.flash_attn_varlen_func(
             q=query_unpad,
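With the attn_bias guard removed above, the attn_bias argument reaching flash_attn_fn now carries the per-head ALiBi slopes, and the new branch forwards them to the kernel as alibi_slopes. For readers unfamiliar with that interface, here is a rough standalone sketch of the call, assuming a flash-attn 2 build whose flash_attn_varlen_func accepts alibi_slopes (the 2.3.7+ release the TODO is waiting for); the sizes and the slope formula are illustrative, not taken from the PR.

# Minimal sketch (not the PR's code): varlen flash attention with per-head
# ALiBi slopes instead of a materialized (S x S) additive bias.
import torch
from flash_attn import flash_attn_varlen_func  # assumes an alibi-capable flash-attn 2 build

n_heads, head_dim = 8, 64
seq_lens = [5, 3]                     # two variable-length sequences packed together
total_tokens = sum(seq_lens)

q = torch.randn(total_tokens, n_heads, head_dim, device='cuda', dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# cu_seqlens marks the sequence boundaries in the packed (unpadded) layout.
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device='cuda')

# Standard ALiBi slopes for a power-of-two head count: 2^(-8 * h / n_heads).
slopes = 2.0 ** (-8.0 * torch.arange(1, n_heads + 1, device='cuda', dtype=torch.float32) / n_heads)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seq_lens), max_seqlen_k=max(seq_lens),
    causal=True,
    alibi_slopes=slopes,              # (n_heads,) or (batch, n_heads), fp32
)

The point of the new path is that no (batch, heads, S, S) bias tensor is ever materialized; the kernel reconstructs the distance-dependent bias from the slopes on the fly.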
@@ -586,7 +601,7 @@ def forward(
         rotary_emb_w_meta_info: Optional[dict] = None,
         is_causal: bool = True,
         needs_weights: bool = False,
-        attention_mask_in_length: Optional[torch.Tensor] = None,
+        attention_mask_type: str = 'boolean',
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
             torch.Tensor, torch.Tensor]]]:
         qkv = self.Wqkv(x)
@@ -603,7 +618,7 @@
             dim=2,
         )
 
-        key_padding_mask = attention_mask
+        key_padding_mask = attention_mask if attention_mask_type == 'boolean' else None
 
         if self.qk_ln:
             # Applying layernorm to qk
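Because forward now receives a single attention_mask whose meaning depends on attention_mask_type, a toy illustration of the two layouts being distinguished may help (the 'in_length' convention follows gen_attention_mask_in_length's documented format; the numbers are made up):

# Toy illustration (assumptions, not the PR's code) of the two mask formats
# the forward pass now distinguishes via `attention_mask_type`:
import torch

# 'boolean': a (batch, seqlen) padding mask -- True where a real token sits.
boolean_mask = torch.tensor([
    [1, 1, 1, 1, 1, 0],   # 5 real tokens, 1 pad
    [1, 1, 1, 0, 0, 0],   # 3 real tokens, 3 pads
], dtype=torch.bool)

# 'in_length': a (batch, seqlen) tensor listing the lengths of the sequences
# packed into each row, with the remaining entries left as 0. Row 0 packs two
# sequences of lengths 2 and 3; row 1 holds a single length-3 sequence.
in_length_mask = torch.tensor([
    [2, 3, 0, 0, 0, 0],
    [3, 0, 0, 0, 0, 0],
])

# Only the boolean form can serve as a key_padding_mask, which is why the diff
# sets key_padding_mask to None in the 'in_length' case.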
@@ -653,9 +668,13 @@ def forward(
         extra_attn_kwargs = {}
         if self.attn_impl == 'flash':
             extra_attn_kwargs = {
-                'attention_mask_in_length': attention_mask_in_length,
-                'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
-                'sliding_window_size': self.sliding_window_size,
+                'attention_mask_in_length':
+                    attention_mask
+                    if attention_mask_type == 'in_length' else None,
+                'should_repeat_kv_for_gqa':
+                    not is_flash_v2_installed(),
+                'sliding_window_size':
+                    self.sliding_window_size,
             }
 
         context, attn_weights, past_key_value = self.attn_fn(
10 changes: 6 additions & 4 deletions llmfoundry/models/layers/blocks.py
@@ -119,10 +119,10 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attn_bias: Optional[torch.Tensor] = None,
         rotary_emb_w_meta_info: Optional[Dict] = None,
-        attention_mask: Optional[torch.ByteTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         is_causal: bool = True,
         output_attentions: bool = False,
-        attention_mask_in_length: Optional[torch.Tensor] = None,
+        attention_mask_type: str = 'boolean',
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
             torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
@@ -134,7 +134,7 @@
             attention_mask=attention_mask,
             is_causal=is_causal,
             needs_weights=output_attentions,
-            attention_mask_in_length=attention_mask_in_length,
+            attention_mask_type=attention_mask_type,
         )
         x = x + self.resid_attn_dropout(b)
         m = x
@@ -144,7 +144,9 @@
         indices = None
         if not self.use_pad_tok_in_ffn:
             assert unpad_input is not None
-            m, indices, _, _ = unpad_input(m, attention_mask)
+            m, indices, _, _ = unpad_input(
+                m, attention_mask
+            )  # TODO: Handle the case where attention_mask is attention_mask_in_length
         n = self.ffn(m)
         if not self.use_pad_tok_in_ffn:
             assert pad_input is not None
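For context on the TODO: when use_pad_tok_in_ffn is False, pad tokens are stripped before the FFN and scattered back afterwards, which only makes sense while attention_mask is a boolean padding mask. A rough sketch of that round trip, with call signatures assumed from flash_attn.bert_padding:

# Sketch of the unpad/pad round trip that `use_pad_tok_in_ffn=False` relies on
# (helper signatures assumed from flash_attn.bert_padding at the pinned version).
import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch, seqlen, d_model = 2, 6, 16
m = torch.randn(batch, seqlen, d_model)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0],
                               [1, 1, 1, 0, 0, 0]], dtype=torch.bool)

# Drop pad tokens before the FFN: (batch, seqlen, d) -> (total_tokens, d).
m_unpad, indices, _, _ = unpad_input(m, attention_mask)
n_unpad = m_unpad * 2.0          # stand-in for self.ffn(m)

# Scatter the FFN output back into the padded layout.
n = pad_input(n_unpad, indices, batch, seqlen)
assert n.shape == (batch, seqlen, d_model)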
11 changes: 7 additions & 4 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -220,11 +220,14 @@ def _validate_config(self) -> None:
                 'attn_impl'] not in ['torch', 'triton']:
             raise NotImplementedError(
                 'prefix_lm only implemented with torch and triton attention.')
-        if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in [
-                'torch', 'triton'
-        ]:
+        if self.attn_config['alibi'] and not (
+                self.attn_config['attn_impl'] in ['torch', 'triton'] or
+                (self.attn_config['attn_impl'] == 'flash' and
+                 is_flash_v2_installed(v2_version='v2.3.6'))
+        ):  # TODO: Change to 2.3.7 and do not merge before 2.3.7 is released!
             raise NotImplementedError(
-                'alibi only implemented with torch and triton attention.')
+                'alibi only implemented with torch, triton, and flash (v2.3.7 or higher) attention.'
+            )
         if self.attn_config['attn_uses_sequence_id'] and not (
                 self.attn_config['attn_impl'] in ['torch', 'triton'] or
                 (self.attn_config['attn_impl'] == 'flash' and
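The relaxed check keys off is_flash_v2_installed(v2_version=...), llm-foundry's helper for comparing the installed flash-attn version. A rough approximation of that gate (not the helper's actual code), just to make the validation logic concrete:

# Hedged sketch of the kind of version gate is_flash_v2_installed(v2_version=...)
# performs; the real helper lives in llmfoundry.models.layers.attention.
from packaging import version

def flash_attn_at_least(min_version: str) -> bool:
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) >= version.parse(min_version)

# Mirrors the validation above: alibi + flash is only allowed when the installed
# flash-attn is new enough to accept alibi_slopes.
allow_alibi_flash = flash_attn_at_least('2.3.6')  # TODO in the PR: bump to 2.3.7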
22 changes: 20 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -47,7 +47,7 @@
 
 from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY,
                                                 attn_bias_shape,
-                                                build_attn_bias)
+                                                build_attn_bias, gen_slopes)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
@@ -216,6 +216,13 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
     return attention_mask_in_length
 
 
+def gen_alibi_slopes(batch_size: int, n_heads: int, alibi_bias_max: int,
+                     device: torch.device) -> torch.Tensor:
+    return gen_slopes(n_heads=n_heads,
+                      alibi_bias_max=alibi_bias_max,
+                      device=device).squeeze(dim=(2, 3)).expand(batch_size, -1)
+
+
 def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor,
                       max_seq_len: int) -> torch.Tensor:
     seq_len = sequence_id.shape[-1]
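gen_alibi_slopes reuses the existing gen_slopes helper, which returns slopes shaped (1, n_heads, 1, 1) for the torch/triton bias, and simply squeezes and expands them to the (batch_size, n_heads) layout flash-attn expects. A self-contained sketch of the resulting values for a power-of-two head count (gen_slopes also handles non-power-of-two head counts, omitted here):

# Hedged sketch of the slope values gen_slopes produces: standard ALiBi slopes,
# 2^(-alibi_bias_max * i / n_heads) for i = 1..n_heads when n_heads is a power
# of two; gen_alibi_slopes just flattens them to (batch_size, n_heads).
import torch

def alibi_slopes_sketch(batch_size: int, n_heads: int, alibi_bias_max: int = 8,
                        device: str = 'cpu') -> torch.Tensor:
    i = torch.arange(1, n_heads + 1, dtype=torch.float32, device=device)
    slopes = 2.0 ** (-alibi_bias_max * i / n_heads)      # shape (n_heads,)
    return slopes.expand(batch_size, -1)                 # shape (batch, n_heads)

print(alibi_slopes_sketch(batch_size=2, n_heads=4))
# tensor([[0.2500, 0.0625, 0.0156, 0.0039],
#         [0.2500, 0.0625, 0.0156, 0.0039]])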
@@ -607,6 +614,17 @@ def forward(
             attn_uses_sequence_id=self.attn_uses_sequence_id,
             attn_impl=self.attn_impl,
             attention_mask=attention_mask)
+
+        alibi_slopes = None
+        if self.alibi and self.attn_impl == 'flash':
+            alibi_slopes = gen_alibi_slopes(batch_size=x.shape[0],
+                                            n_heads=self.config.n_heads,
+                                            alibi_bias_max=self.alibi_bias_max,
+                                            device=x.device)
+
+        attention_mask = attention_mask_in_length if attention_mask_in_length is not None else attention_mask
+        attention_mask_type = 'in_length' if attention_mask_in_length is not None else 'boolean'
+        attn_bias = alibi_slopes if alibi_slopes is not None else attn_bias
         # initialize the past key values cache if it should be used
         presents = () if use_cache else None
         if use_cache and past_key_values is None:
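The forward pass overloads two existing variables instead of adding new ones: attention_mask may now carry the in-length tensor and attn_bias may carry the ALiBi slopes, with attention_mask_type recording which interpretation applies downstream. An illustrative contrast of the two attn_bias payloads (shapes are assumptions for a typical causal-ALiBi config, not taken from the PR):

# Illustrative contrast (assumptions, not the PR's code) of the two payloads
# `attn_bias` now carries into the blocks:
import torch

batch, n_heads, seq_len = 2, 8, 16

# torch/triton path: an additive bias broadcast against the attention scores,
# e.g. one bias value per (head, key position) in the causal-alibi case.
dense_attn_bias = torch.zeros(1, n_heads, 1, seq_len)

# flash path with alibi: only the per-head slopes, expanded to (batch, n_heads);
# the distance-dependent bias is materialized inside the flash-attn kernel.
alibi_slopes = torch.rand(n_heads).expand(batch, -1)

print(dense_attn_bias.shape, alibi_slopes.shape)
# torch.Size([1, 8, 1, 16]) torch.Size([2, 8])

Overloading keeps the block and attention signatures small, at the cost of arguments whose meaning depends on runtime flags.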
@@ -629,7 +647,7 @@
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
                 output_attentions=bool(output_attentions),
-                attention_mask_in_length=attention_mask_in_length,
+                attention_mask_type=attention_mask_type,
             )
             if presents is not None:
                 presents += (present,)
2 changes: 1 addition & 1 deletion setup.py
@@ -98,7 +98,7 @@
     'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
 ]
 extra_deps['gpu-flash2'] = [
-    'flash-attn==2.3.6',
+    'flash-attn==2.3.6',  # TODO: Change to 2.3.7
     'mosaicml-turbo==0.0.4',
 ]
