Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[misc] use out argument for flash attention #10822

Merged
merged 7 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,5 +247,6 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
2 changes: 2 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.

Expand Down Expand Up @@ -448,5 +449,6 @@ def forward(
blocksparse_head_sliding_step=self.head_sliding_step,
)

assert output is not None
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
39 changes: 17 additions & 22 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,24 +638,27 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.

Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
output: shape = [num_tokens, num_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
NOTE: It in-place updates the output tensor.
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

assert output is not None, "Output tensor must be provided."

if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
Expand All @@ -679,6 +682,7 @@ def forward(

# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
output = output.view(-1, num_heads, head_size)
if (key is not None) and (value is not None):
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
Expand Down Expand Up @@ -721,13 +725,13 @@ def forward(
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
decode_query = query[num_prefill_query_tokens:]
decode_output = output[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_query_tokens]
prefill_output = output[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens

prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
Expand All @@ -741,7 +745,7 @@ def forward(
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]

prefill_output = flash_attn_varlen_func(
flash_attn_varlen_func(
q=query,
k=key,
v=value,
Expand All @@ -754,14 +758,15 @@ def forward(
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=prefill_output,
)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = flash_attn_varlen_func( # noqa
flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
Expand All @@ -775,6 +780,7 @@ def forward(
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
)

if decode_meta := attn_metadata.decode_metadata:
Expand All @@ -788,7 +794,7 @@ def forward(
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1"
)
decode_output = flash_attn_varlen_func(
flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
Expand All @@ -802,6 +808,7 @@ def forward(
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
out=decode_output,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
Expand All @@ -810,7 +817,7 @@ def forward(
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
decode_output = flash_attn_with_kvcache(
flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
Expand All @@ -821,20 +828,8 @@ def forward(
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
).squeeze(1)

if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_query_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_query_tokens, hidden_size)

assert decode_meta is not None
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)

out=decode_output.unsqueeze(1),
)
return output


Expand Down
4 changes: 4 additions & 0 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,11 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:

# TODO: directly write to output tensor

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/ipex_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.

Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
Expand Down
63 changes: 59 additions & 4 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op


Expand Down Expand Up @@ -103,8 +102,14 @@ def __init__(
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = envs.VLLM_USE_V1 or not (
current_platform.is_cuda_alike() or current_platform.is_cpu())
self.use_direct_call = not current_platform.is_cuda_alike(
) and not current_platform.is_cpu()

# For some attention backends, we allocate an output tensor before
# calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph.
self.use_output = self.backend == _Backend.FLASH_ATTN or \
self.backend == _Backend.FLASH_ATTN_VLLM_V1
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
Expand All @@ -130,6 +135,12 @@ def forward(
self._k_scale,
self._v_scale,
attn_type=attn_type)
elif self.use_output:
output = torch.empty_like(query)
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, kv_cache, attn_type,
self.layer_name)
return output
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, attn_type,
Expand Down Expand Up @@ -183,3 +194,47 @@ def unified_attention_fake(
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)


def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type,
output=output)


def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: str,
layer_name: str,
) -> None:
return


direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
2 changes: 1 addition & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2238,7 +2238,7 @@ class CompilationConfig(BaseModel):
custom_ops: List[str] = Field(default_factory=list)
splitting_ops: List[str] = Field(default_factory=lambda: [
"vllm.unified_attention",
"vllm.unified_v1_flash_attention",
"vllm.unified_attention_with_output",
])

use_inductor: bool = True
Expand Down
Loading