[Core]: (2/N) Support prefill only models by Workflow Defined Engine - Prefill only attention #9124

Draft
wants to merge 1 commit into base: main
Empty file added tests/attention/__init__.py
Empty file.
89 changes: 89 additions & 0 deletions tests/attention/prefill_only/test_basic_correctness.py
@@ -0,0 +1,89 @@
import itertools as it

import pytest
import torch
import torch.nn.functional as F

from vllm.attention.layer import Attention
from vllm.attention.prefill_only.abstract import AttentionType
from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend,
_Backend)
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE


def compare_embeddings(embeddings1, embeddings2):
similarities = [
F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
for e1, e2 in zip(embeddings1, embeddings2)
]
return similarities


SEQ_LENS = [1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29]


@pytest.mark.parametrize("head_dim", [64])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("num_kv_heads", [1, 2, 4, 8])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"])
@pytest.mark.parametrize("n_seqs", list(range(1, len(SEQ_LENS))))
def test_basic_correctness(head_dim: int, num_heads: int, num_kv_heads: int,
attn_type: str, dtype: str, n_seqs: int):
assert num_heads % num_kv_heads == 0

torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

attention_impls = AttentionImpls[dtype]

seq_lens = SEQ_LENS[:n_seqs]
batchsize = sum(seq_lens)

query = torch.rand((batchsize, num_heads, head_dim),
dtype=torch_dtype,
device="cuda:0").view((batchsize, -1))
key = torch.rand((batchsize, num_kv_heads, head_dim),
dtype=torch_dtype,
device="cuda:0").view((batchsize, -1))
value = torch.rand((batchsize, num_kv_heads, head_dim),
dtype=torch_dtype,
device="cuda:0").view((batchsize, -1))

impl_outputs_list = []

for attention_impl in attention_impls:
selected_backend = _Backend.backend_name_to_enum(attention_impl)
backend_cls = AttnBackend.get_backend_cls(selected_backend)

attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type)

attn_backend = backend_cls(attn_type_enum)
scaling = head_dim**-0.5

attn = Attention(num_heads,
head_dim,
scale=scaling,
num_kv_heads=num_kv_heads,
attn_backend=attn_backend)

metadata_builder = attn_backend.make_metadata_builder()
attn_metadata = metadata_builder(seq_lens=seq_lens)
attn_metadata = attn_metadata.to("cuda:0")

outputs = attn.forward(query,
key,
value,
kv_cache=None,
attn_metadata=attn_metadata)

impl_outputs_list.append((attention_impl, outputs))

tolerance = 1e-2
for a, b in it.combinations(impl_outputs_list, 2):
similarities = compare_embeddings(a[1], b[1])
all_similarities = torch.stack(similarities)

assert torch.all(
(all_similarities <= 1.0 + tolerance)
& (all_similarities >= 1.0 - tolerance)
), f"{a[0]} vs {b[0]}, not all values are within {tolerance} of 1.0"
54 changes: 54 additions & 0 deletions tests/attention/prefill_only/test_enum_verify.py
@@ -0,0 +1,54 @@
import pytest

from vllm.attention.prefill_only.abstract import (AttentionType,
PrefillOnlyAttentionBackend)
from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend,
_Backend)


def get_attn_backend(attention_impl: str, attn_type: str):
selected_backend = _Backend.backend_name_to_enum(attention_impl)
backend_cls = AttnBackend.get_backend_cls(selected_backend)

attn_type_enum = AttentionType.attn_type_name_to_enum(attn_type)

attn_backend = backend_cls(attn_type_enum)
return attn_backend


@pytest.mark.parametrize("attn_type", ["DECODER", "ENCODER"])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
def test_backend(dtype: str, attn_type: str):
attention_impls = AttentionImpls[dtype]

for attention_impl in attention_impls:
attn_backend = get_attn_backend(attention_impl, attn_type)

assert isinstance(attn_backend, PrefillOnlyAttentionBackend)


@pytest.mark.parametrize("attn_type", ["ENCODER_DECODER"])
@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"])
def test_ENCODER_DECODER_not_supported(dtype: str, attn_type: str):
attention_impls = AttentionImpls[dtype]

for attention_impl in attention_impls:
with pytest.raises(NotImplementedError):
get_attn_backend(attention_impl, attn_type)


def test_not_supported_backend():
attention_impls = ["not_supported_backend", 0, 1.0]

for attention_impl in attention_impls:
with pytest.raises(ValueError):
selected_backend = _Backend.backend_name_to_enum(attention_impl)
AttnBackend.get_backend_cls(selected_backend)


def test_not_supported_attn_type():
attn_types = ["not_supported_attn_type", 0, 1.0]

for attn_type in attn_types:
with pytest.raises(ValueError):
AttentionType.attn_type_name_to_enum(attn_type)
13 changes: 13 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -18,6 +18,19 @@ class AttentionType(Enum):
ENCODER = auto() # Encoder attention between previous layer Q/K/V
ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V

@staticmethod
def attn_type_name_to_enum(attn_type: str) -> "AttentionType":
assert attn_type is not None

attn_type_members = AttentionType.__members__
if attn_type not in attn_type_members:
            raise ValueError(
                f"Invalid attn_type '{attn_type}'. "
                f"Available attn_types: {', '.join(attn_type_members)} "
                "(case-sensitive).")

return AttentionType[attn_type]


class AttentionBackend(ABC):
"""Abstract class for attention backends."""
28 changes: 19 additions & 9 deletions vllm/attention/layer.py
Expand Up @@ -4,7 +4,7 @@
import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
@@ -36,6 +36,7 @@ def __init__(
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
prefix: str = "",
attn_backend: Optional[AttentionBackend] = None,
) -> None:
super().__init__()
if cache_config is not None:
@@ -73,14 +74,18 @@ def __init__(
self.quant_method = quant_method
self.quant_method.create_weights(self)

# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size, blocksparse_params
is not None)
impl_cls = attn_backend.get_impl_cls()
if attn_backend is None:
# During model initialization, the default dtype is set as the model
# weight and activation dtype.

dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
num_heads, head_size, num_kv_heads, sliding_window, dtype,
kv_cache_dtype, block_size, blocksparse_params is not None)()
else:
self.attn_backend = attn_backend

impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap)
@@ -94,6 +99,11 @@ def forward(
attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
if hasattr(self.attn_backend, "attn_type"):
return self.impl.forward(query, key, value, kv_cache,
attn_metadata, self._k_scale,
self._v_scale,
self.attn_backend.attn_type)

return self.impl.forward(query,
key,
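For context (not part of this diff): a minimal sketch of how the new optional attn_backend argument to Attention is meant to be used, following the pattern in tests/attention/prefill_only/test_basic_correctness.py above. It assumes AttentionImpls maps a dtype name to an iterable of prefill-only impl names, as the tests suggest. When attn_backend is None the layer falls back to get_attn_backend() as before; when a prefill-only backend is injected, forward() dispatches with that backend's attn_type.

from vllm.attention.layer import Attention
from vllm.attention.prefill_only.abstract import AttentionType
from vllm.attention.prefill_only.selector import (AttentionImpls, AttnBackend,
                                                  _Backend)

head_dim, num_heads, num_kv_heads = 64, 8, 2

for attention_impl in AttentionImpls["bfloat16"]:
    # Resolve the impl name to a prefill-only backend class and instantiate
    # it with the desired attention type (DECODER or ENCODER).
    selected_backend = _Backend.backend_name_to_enum(attention_impl)
    backend_cls = AttnBackend.get_backend_cls(selected_backend)
    attn_backend = backend_cls(AttentionType.attn_type_name_to_enum("ENCODER"))

    # Injecting the backend bypasses the default get_attn_backend() selection
    # inside Attention.__init__.
    attn = Attention(num_heads,
                     head_dim,
                     scale=head_dim**-0.5,
                     num_kv_heads=num_kv_heads,
                     attn_backend=attn_backend)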
125 changes: 125 additions & 0 deletions vllm/attention/prefill_only/abstract.py
@@ -0,0 +1,125 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.utils import is_pin_memory_available

pin_memory = is_pin_memory_available()


class PrefillOnlyAttentionBackend(ABC):

def __init__(self, attn_type: AttentionType):
if attn_type == AttentionType.ENCODER_DECODER:
            raise NotImplementedError("Encoder/decoder cross-attention "
                                      "is not implemented for "
                                      "PrefillOnlyAttentionBackend")

self._attn_type = attn_type

@property
def attn_type(self) -> AttentionType:
return self._attn_type

@staticmethod
@abstractmethod
def get_name() -> str:
raise NotImplementedError

@staticmethod
@abstractmethod
def get_impl_cls() -> Type["PrefillOnlyAttentionImpl"]:
raise NotImplementedError

@staticmethod
def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]:
return PrefillOnlyAttentionMetadata

@classmethod
def make_metadata(cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)

@staticmethod
def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]:
return PrefillOnlyAttentionMetadataBuilder

@classmethod
def make_metadata_builder(
cls, *args, **kwargs) -> "PrefillOnlyAttentionMetadataBuilder":
return cls.get_builder_cls()(*args, **kwargs)


@dataclass
class PrefillOnlyAttentionMetadata:
max_seq_len: int
seq_lens: List[int]

# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]

def to(self, device, non_blocking=False):
for k, v in self.__dict__.items():
if isinstance(v, torch.Tensor):
self.__dict__[k] = v.to(device, non_blocking=non_blocking)

return self


T = TypeVar("T", bound=PrefillOnlyAttentionMetadata)


class PrefillOnlyAttentionMetadataBuilder(Generic[T]):

def __call__(self, seq_lens: List[int]):
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.long,
pin_memory=pin_memory,
device="cpu")
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device="cpu")
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])

return PrefillOnlyAttentionMetadata(seq_lens=seq_lens,
max_seq_len=max(seq_lens),
seq_start_loc=seq_start_loc)


class PrefillOnlyAttentionImpl(ABC):

@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
raise NotImplementedError

@abstractmethod
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
raise NotImplementedError
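
For reference (not part of this diff): a short sketch of the metadata builder semantics defined above, using only names introduced in this file. The final .to("cuda:0") call assumes a CUDA device is available.

from vllm.attention.prefill_only.abstract import (
    PrefillOnlyAttentionMetadataBuilder)

builder = PrefillOnlyAttentionMetadataBuilder()
attn_metadata = builder(seq_lens=[4, 6])

# Matches the documented example: cumulative sequence lengths [0, 4, 10].
assert attn_metadata.max_seq_len == 6
assert attn_metadata.seq_start_loc.tolist() == [0, 4, 10]

# seq_start_loc is built on CPU (pinned if available) and moved to the
# target device before the forward pass.
attn_metadata = attn_metadata.to("cuda:0")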