Add end-to-end example for paged attention (#104)
* e2e example for paged attention
1 parent 2e4d04a · commit 232f8c6
Showing 6 changed files with 733 additions and 3 deletions.
latency.py
@@ -0,0 +1,142 @@
""" | ||
Benchmarking the latency of a paged attention layer against a non-paged attention layer. | ||
Command: | ||
python3 latency.py --setting change_max_seq_len | ||
""" | ||

import torch
from torch.nn.attention.flex_attention import (
    create_block_mask,
    noop_mask,
)
from torch._inductor.runtime.benchmarking import benchmarker

from utils import random_init_paged_attention, gen_offset, generate_score_mod

dtype = torch.bfloat16


def benchmark_layer(
    bsz,
    n_heads,
    max_seq_len,
    head_dim,
    paged_attention,
    batch_idx,
    input_pos,
    block_mask,
    score_mod,
    converted_block_mask,
    converted_score_mod,
    dtype=torch.bfloat16,
):
    from model import NonPagedAttentionLayer, PagedAttentionLayer

    # compile model
    non_paged_foo = torch.compile(
        NonPagedAttentionLayer(bsz, n_heads, max_seq_len, head_dim, dtype), fullgraph=True
    )
    paged_foo = torch.compile(
        PagedAttentionLayer(n_heads, head_dim, dtype, paged_attention), fullgraph=True
    )

    with torch.no_grad():
        # randomize a token embedding
        x = torch.randn(bsz, 1, n_heads * head_dim, device="cuda", dtype=dtype)

        # warmup
        for _ in range(10):
            non_paged_foo(batch_idx, input_pos, x, block_mask, score_mod)
            paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod)

        # benchmark
        non_paged_latency = benchmarker.benchmark_gpu(
            lambda: non_paged_foo(batch_idx, input_pos, x, block_mask, score_mod)
        )
        paged_latency = benchmarker.benchmark_gpu(
            lambda: paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod)
        )
        print(
            f"non_paged_latency: {non_paged_latency} ms, paged_latency: {paged_latency} ms, overhead: {round((paged_latency / non_paged_latency - 1.0) * 100, 2)}%"
        )


def benchmark(
    attn_type: str, page_size: int, bsz: int, max_seq_len: int, n_heads: int, head_dim: int
):
    # For the decoding benchmark, we set input_pos to be half of max_seq_len
    input_pos = torch.tensor([max_seq_len // 2] * bsz, device="cuda", dtype=torch.int32).view(
        bsz, 1
    )  # [bsz, 1]
    batch_idx = torch.arange(bsz, device="cuda", dtype=torch.int32)  # [bsz]

    # init paged attention
    n_pages = (max_seq_len + page_size - 1) // page_size * bsz
    paged_attention = random_init_paged_attention(n_pages, page_size, bsz, max_seq_len)

    # Block mask
    if attn_type == "causal":
        mask_mod = gen_offset(
            torch.tensor([max_seq_len // 2] * bsz, device="cuda", dtype=torch.int32)
        )
    else:
        mask_mod = noop_mask
    block_mask = create_block_mask(mask_mod, bsz, 1, 1, max_seq_len, BLOCK_SIZE=page_size)
    converted_block_mask = paged_attention.convert_logical_block_mask(block_mask)

    # Score mod
    score_mod = generate_score_mod(attn_type)
    converted_score_mod = paged_attention.get_score_mod(score_mod)

    benchmark_layer(
        bsz,
        n_heads,
        max_seq_len,
        head_dim,
        paged_attention,
        batch_idx,
        input_pos,
        block_mask,
        score_mod,
        converted_block_mask,
        converted_score_mod,
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--setting", type=str, default="change_max_seq_len")
    args = parser.parse_args()

    if args.setting == "change_max_seq_len":
        max_seq_len_candidates = [2048, 4096, 8192, 16384, 32768]
        bsz_candidates = [32]
        page_size_candidates = [128]
    elif args.setting == "change_bsz":
        max_seq_len_candidates = [8192]
        bsz_candidates = [32, 64, 128]
        page_size_candidates = [128]
    elif args.setting == "change_page_size":
        max_seq_len_candidates = [8192]
        bsz_candidates = [32]
        page_size_candidates = [64, 128, 256]
    else:
        raise NotImplementedError

    n_heads, head_dim = 16, 64

    for attn_type in ["noop", "causal", "rel", "head_bias"]:
        print(f"\nattn_type:{attn_type}")
        for page_size in page_size_candidates:
            print(f"page_size:{page_size}")
            for bsz in bsz_candidates:
                for max_seq_len in max_seq_len_candidates:
                    torch._dynamo.reset()

                    print(
                        f"\nbsz: {bsz}, max_seq_len: {max_seq_len}, head_dim: {head_dim}, n_heads: {n_heads}"
                    )
                    benchmark(attn_type, page_size, bsz, max_seq_len, n_heads, head_dim)
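
Note: latency.py imports random_init_paged_attention, gen_offset, and generate_score_mod from a local utils module that is not shown in this diff. The sketch below is a rough, hypothetical reconstruction of those helpers, assuming the prototype PagedAttention class in torch.nn.attention.experimental._paged_attention; only the call signatures used above come from the diff, while the constructor arguments, the reserve() call, and the concrete mask/score functions are assumptions.

# Hypothetical sketch of utils.py (not part of this commit); assumes the
# prototype PagedAttention manager that ships with PyTorch's FlexAttention.
import torch
from torch.nn.attention.experimental._paged_attention import PagedAttention


def random_init_paged_attention(n_pages, page_size, bsz, max_seq_len):
    # Build a paged KV-cache manager and reserve enough pages for each batch
    # element to hold max_seq_len tokens (assumed behavior of the helper).
    paged_attention = PagedAttention(n_pages, page_size, bsz, device="cuda")
    for b in range(bsz):
        paged_attention.reserve(
            torch.tensor(b, device="cuda"),
            torch.tensor(max_seq_len, device="cuda"),
        )
    return paged_attention


def gen_offset(offset):
    # Decoding-style causal mask: the query token of batch b sits at logical
    # position offset[b], so it may attend to kv positions <= offset[b] (assumed).
    def offset_mask_mod(b, h, q_idx, kv_idx):
        return kv_idx <= offset[b]

    return offset_mask_mod


def generate_score_mod(attn_type):
    # Score modifications matching the attn_type values iterated in latency.py
    # (concrete forms are illustrative, not taken from the diff).
    def noop(score, b, h, q_idx, kv_idx):
        return score

    def relative_bias(score, b, h, q_idx, kv_idx):
        return score + (q_idx - kv_idx)

    def head_bias(score, b, h, q_idx, kv_idx):
        return score + 2 * h

    return {"noop": noop, "causal": noop, "rel": relative_bias, "head_bias": head_bias}[attn_type]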
model.py
@@ -0,0 +1,215 @@
import torch
import math
from torch.nn.attention.flex_attention import BlockMask, flex_attention, _score_mod_signature
from torch import Tensor
from typing import Dict, Optional


class NonPagedAttentionLayer(torch.nn.Module):
    """An attention layer without paged attention, ported from GPT-Fast:
    https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L180-L227
    """

    def __init__(self, bsz, n_heads, max_seq_len, head_dim, dtype, block_size: int = 32768):
        super().__init__()
        self.n_head = n_heads
        self.head_dim = head_dim

        # key, query, value projections for all heads, but in a batch
        total_head_dim = n_heads * head_dim
        self.wqkv = torch.nn.Linear(
            total_head_dim, 3 * total_head_dim, bias=False, device="cuda", dtype=dtype
        )
        self.wo = torch.nn.Linear(
            total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype
        )
        self.k_cache = torch.randn(
            (bsz, n_heads, max_seq_len, head_dim), device="cuda", dtype=dtype
        )
        self.v_cache = torch.randn(
            (bsz, n_heads, max_seq_len, head_dim), device="cuda", dtype=dtype
        )
        self.freqs_cis = precompute_freqs_cis(block_size, self.head_dim, dtype=dtype)

    def forward(
        self,
        batch_idx: Tensor,
        input_pos: Tensor,
        x: Tensor,
        block_mask: BlockMask,
        score_mod: _score_mod_signature,
    ) -> Tensor:
        # input_pos: [B, S], batch_idx: [B], x: [B, S, D]
        B, S, _ = x.shape

        kv_size = self.n_head * self.head_dim
        q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)

        q = q.view(B, S, self.n_head, self.head_dim)
        k = k.view(B, S, self.n_head, self.head_dim)
        v = v.view(B, S, self.n_head, self.head_dim)

        freqs_cis = self.freqs_cis.unsqueeze(0)[
            torch.zeros((B, 1), dtype=torch.int), input_pos
        ]  # [B, S, D//2, 2]

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q = q.transpose(1, 2)
        self.k_cache[batch_idx.view(B, 1), :, input_pos] = k
        self.v_cache[batch_idx.view(B, 1), :, input_pos] = v

        y = flex_attention(
            q, self.k_cache, self.v_cache, block_mask=block_mask, score_mod=score_mod
        )

        y = y.transpose(1, 2).contiguous().view(B, S, -1)

        y = self.wo(y)
        return y


class PagedAttentionLayer(torch.nn.Module):
    """An attention layer with paged attention"""

    def __init__(self, n_heads, head_dim, dtype, paged_attention, block_size: int = 65536):
        super().__init__()
        self.n_head = n_heads
        self.head_dim = head_dim

        # key, query, value projections for all heads, but in a batch
        total_head_dim = n_heads * head_dim
        self.wqkv = torch.nn.Linear(
            total_head_dim, 3 * total_head_dim, bias=False, device="cuda", dtype=dtype
        )
        self.wo = torch.nn.Linear(
            total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype
        )

        # allocate kv cache with batch size=1 for paged attention
        max_cached_seq_len = paged_attention.n_pages * paged_attention.page_size
        self.k_cache_paged = torch.randn(
            1,
            n_heads,
            max_cached_seq_len,
            head_dim,
            device="cuda",
            dtype=dtype,
        )
        self.v_cache_paged = torch.randn(
            1,
            n_heads,
            max_cached_seq_len,
            head_dim,
            device="cuda",
            dtype=dtype,
        )
        self.paged_attention = paged_attention

        self.freqs_cis = precompute_freqs_cis(
            block_size, self.head_dim, dtype=dtype
        )  # [block_size, D//2, 2]

    def forward(
        self,
        batch_idx: Tensor,
        input_pos: Tensor,
        x: Tensor,
        converted_block_mask: BlockMask,
        converted_score_mod: _score_mod_signature,
    ) -> Tensor:
        # input_pos: [B, S], batch_idx: [B], x: [B, S, D]
        B, S, _ = x.shape
        kv_size = self.n_head * self.head_dim
        q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)

        q = q.view(B, S, self.n_head, self.head_dim)
        k = k.view(B, S, self.n_head, self.head_dim)
        v = v.view(B, S, self.n_head, self.head_dim)

        freqs_cis = self.freqs_cis.unsqueeze(0)[
            torch.zeros((B, 1), dtype=torch.int), input_pos
        ]  # [B, S, D//2, 2]

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        # Compared with NonPagedAttentionLayer, this is the only change needed to update the kv cache
        self.paged_attention.assign(
            batch_idx, input_pos, k, v, self.k_cache_paged, self.v_cache_paged
        )

        y = flex_attention(
            q,
            self.k_cache_paged,
            self.v_cache_paged,
            block_mask=converted_block_mask,
            score_mod=converted_score_mod,
        )

        y = y.transpose(1, 2).contiguous().view(B, S, -1)

        y = self.wo(y)
        return y


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    # x: [B, S, H, D], freqs_cis: [B, S, D//2, 2]
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # [B, S, H, D//2, 2]
    freqs_cis = freqs_cis.view(
        xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2
    )  # [B, S, 1, D//2, 2]
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Dict):
    factor = rope_scaling["factor"]
    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    old_context_len = rope_scaling["original_max_position_embeddings"]

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
    seq_len: int,
    n_elem: int,
    base: int = 10000,
    dtype: torch.dtype = torch.bfloat16,
    rope_scaling: Optional[dict] = None,
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    if rope_scaling is not None:
        freqs = apply_rope_scaling(freqs, rope_scaling)
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype, device="cuda")
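
For reference, a minimal shape check of the RoPE helpers above, mirroring the gather that both forward methods perform to turn the [seq_len, D//2, 2] table from precompute_freqs_cis into the per-token [B, S, D//2, 2] tensor that apply_rotary_emb expects (illustrative only, requires a CUDA device):

# Illustrative shape check for the RoPE helpers (not part of the commit).
B, S, H, D = 2, 1, 16, 64
freqs_table = precompute_freqs_cis(1024, D)  # [1024, D//2, 2], bfloat16 on cuda
input_pos = torch.full((B, S), 37, dtype=torch.int, device="cuda")  # decode position per batch
freqs_cis = freqs_table.unsqueeze(0)[
    torch.zeros((B, 1), dtype=torch.int), input_pos
]  # [B, S, D//2, 2]
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
q_rot = apply_rotary_emb(q, freqs_cis)
assert q_rot.shape == (B, S, H, D)  # rotation keeps the [B, S, H, D] layout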