From 232f8c6954b552707c455266f03f9abc68e2ce35 Mon Sep 17 00:00:00 2001
From: Boyuan Feng
Date: Tue, 14 Jan 2025 17:01:26 -0800
Subject: [PATCH] Add end-to-end example for paged attention (#104)

* e2e example for paged attention
---
 README.md                              |   1 +
 attn_gym/paged_attention/latency.py    | 142 ++++++++++++++++
 attn_gym/paged_attention/model.py      | 215 ++++++++++++++++++++++++
 attn_gym/paged_attention/throughput.py | 220 +++++++++++++++++++++++++
 attn_gym/paged_attention/utils.py      | 152 +++++++++++++++++
 attn_gym/utils.py                      |   6 +-
 6 files changed, 733 insertions(+), 3 deletions(-)
 create mode 100644 attn_gym/paged_attention/latency.py
 create mode 100644 attn_gym/paged_attention/model.py
 create mode 100644 attn_gym/paged_attention/throughput.py
 create mode 100644 attn_gym/paged_attention/utils.py

diff --git a/README.md b/README.md
index 3e0ecbf..3b88427 100644
--- a/README.md
+++ b/README.md
@@ -69,6 +69,7 @@ Attention Gym is organized for easy exploration of attention mechanisms:
 
 - `attn_gym.masks`: Examples creating `BlockMasks`
 - `attn_gym.mods`: Examples creating `score_mods`
+- `attn_gym.paged_attention`: Examples using `PagedAttention`
 - `examples/`: Detailed implementations using FlexAttention
 
 ## 🛠️ Dev
diff --git a/attn_gym/paged_attention/latency.py b/attn_gym/paged_attention/latency.py
new file mode 100644
index 0000000..587f6cd
--- /dev/null
+++ b/attn_gym/paged_attention/latency.py
@@ -0,0 +1,142 @@
+"""
+Benchmarking the latency of a paged attention layer against a non-paged attention layer.
+
+Command:
+    python3 latency.py --setting change_max_seq_len
+"""
+
+import torch
+from torch.nn.attention.flex_attention import (
+    create_block_mask,
+    noop_mask,
+)
+from torch._inductor.runtime.benchmarking import benchmarker
+
+from utils import random_init_paged_attention, gen_offset, generate_score_mod
+
+dtype = torch.bfloat16
+
+
+def benchmark_layer(
+    bsz,
+    n_heads,
+    max_seq_len,
+    head_dim,
+    paged_attention,
+    batch_idx,
+    input_pos,
+    block_mask,
+    score_mod,
+    converted_block_mask,
+    converted_score_mod,
+    dtype=torch.bfloat16,
+):
+    from model import NonPagedAttentionLayer, PagedAttentionLayer
+
+    # compile model
+    non_paged_foo = torch.compile(
+        NonPagedAttentionLayer(bsz, n_heads, max_seq_len, head_dim, dtype), fullgraph=True
+    )
+    paged_foo = torch.compile(
+        PagedAttentionLayer(n_heads, head_dim, dtype, paged_attention), fullgraph=True
+    )
+
+    with torch.no_grad():
+        # randomize a token embedding
+        x = torch.randn(bsz, 1, n_heads * head_dim, device="cuda", dtype=dtype)
+
+        # warmup
+        for _ in range(10):
+            non_paged_foo(batch_idx, input_pos, x, block_mask, score_mod)
+            paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod)
+
+        # benchmark
+        non_paged_latency = benchmarker.benchmark_gpu(
+            lambda: non_paged_foo(batch_idx, input_pos, x, block_mask, score_mod)
+        )
+        paged_latency = benchmarker.benchmark_gpu(
+            lambda: paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod)
+        )
+        print(
+            f"non_paged_latency: {non_paged_latency} ms, paged_latency: {paged_latency} ms, overhead: {round((paged_latency / non_paged_latency - 1.0) * 100, 2)}%"
+        )
+
+
+def benchmark(
+    attn_type: str, page_size: int, bsz: int, max_seq_len: int, n_heads: int, head_dim: int
+):
+    # For the decoding benchmark, we set input_pos to half of max_seq_len
+    input_pos = torch.tensor([max_seq_len // 2] * bsz, device="cuda", dtype=torch.int32).view(
+        bsz, 1
+    )  # [bsz, 1]
+    batch_idx = torch.arange(bsz, device="cuda", dtype=torch.int32)  # [bsz]
+
+    # init paged attention
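+    # Reserve ceil(max_seq_len / page_size) pages per sequence, for each of the bsz sequences.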
+    n_pages = (max_seq_len + page_size - 1) // page_size * bsz
+    paged_attention = random_init_paged_attention(n_pages, page_size, bsz, max_seq_len)
+
+    # Block mask
+    if attn_type == "causal":
+        mask_mod = gen_offset(
+            torch.tensor([max_seq_len // 2] * bsz, device="cuda", dtype=torch.int32)
+        )
+    else:
+        mask_mod = noop_mask
+    block_mask = create_block_mask(mask_mod, bsz, 1, 1, max_seq_len, BLOCK_SIZE=page_size)
+    converted_block_mask = paged_attention.convert_logical_block_mask(block_mask)
+
+    # Score mod
+    score_mod = generate_score_mod(attn_type)
+    converted_score_mod = paged_attention.get_score_mod(score_mod)
+
+    benchmark_layer(
+        bsz,
+        n_heads,
+        max_seq_len,
+        head_dim,
+        paged_attention,
+        batch_idx,
+        input_pos,
+        block_mask,
+        score_mod,
+        converted_block_mask,
+        converted_score_mod,
+    )
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--setting", type=str, default="change_max_seq_len")
+    args = parser.parse_args()
+
+    if args.setting == "change_max_seq_len":
+        max_seq_len_candidates = [2048, 4096, 8192, 16384, 32768]
+        bsz_candidates = [32]
+        page_size_candidates = [128]
+    elif args.setting == "change_bsz":
+        max_seq_len_candidates = [8192]
+        bsz_candidates = [32, 64, 128]
+        page_size_candidates = [128]
+    elif args.setting == "change_page_size":
+        max_seq_len_candidates = [8192]
+        bsz_candidates = [32]
+        page_size_candidates = [64, 128, 256]
+    else:
+        raise NotImplementedError
+
+    n_heads, head_dim = 16, 64
+
+    for attn_type in ["noop", "causal", "rel", "head_bias"]:
+        print(f"\nattn_type:{attn_type}")
+        for page_size in page_size_candidates:
+            print(f"page_size:{page_size}")
+            for bsz in bsz_candidates:
+                for max_seq_len in max_seq_len_candidates:
+                    torch._dynamo.reset()
+
+                    print(
+                        f"\nbsz: {bsz}, max_seq_len: {max_seq_len}, head_dim: {head_dim}, n_heads: {n_heads}"
+                    )
+                    benchmark(attn_type, page_size, bsz, max_seq_len, n_heads, head_dim)
diff --git a/attn_gym/paged_attention/model.py b/attn_gym/paged_attention/model.py
new file mode 100644
index 0000000..0730f7d
--- /dev/null
+++ b/attn_gym/paged_attention/model.py
@@ -0,0 +1,215 @@
+import torch
+import math
+from torch.nn.attention.flex_attention import BlockMask, flex_attention, _score_mod_signature
+from torch import Tensor
+from typing import Dict, Optional
+
+
+class NonPagedAttentionLayer(torch.nn.Module):
+    """An attention layer without paged attention, ported from GPT-Fast:
+    https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L180-L227
+    """
+
+    def __init__(self, bsz, n_heads, max_seq_len, head_dim, dtype, block_size: int = 32768):
+        super().__init__()
+        self.n_head = n_heads
+        self.head_dim = head_dim
+
+        # key, query, value projections for all heads, but in a batch
+        total_head_dim = n_heads * head_dim
+        self.wqkv = torch.nn.Linear(
+            total_head_dim, 3 * total_head_dim, bias=False, device="cuda", dtype=dtype
+        )
+        self.wo = torch.nn.Linear(
+            total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype
+        )
+        self.k_cache = torch.randn(
+            (bsz, n_heads, max_seq_len, head_dim), device="cuda", dtype=dtype
+        )
+        self.v_cache = torch.randn(
+            (bsz, n_heads, max_seq_len, head_dim), device="cuda", dtype=dtype
+        )
+        self.freqs_cis = precompute_freqs_cis(block_size, self.head_dim, dtype=dtype)
+
+    def forward(
+        self,
+        batch_idx: Tensor,
+        input_pos: Tensor,
+        x: Tensor,
+        block_mask: BlockMask,
+        score_mod: _score_mod_signature,
+    ) -> Tensor:
+        # input_pos: [B, S], batch_idx: [B], x: [B, S, D]
+        B, S, _ = x.shape
+
+        kv_size = self.n_head * self.head_dim
+        q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
+
+        q = q.view(B, S, self.n_head, self.head_dim)
+        k = k.view(B, S, self.n_head, self.head_dim)
+        v = v.view(B, S, self.n_head, self.head_dim)
+
+        freqs_cis = self.freqs_cis.unsqueeze(0)[
+            torch.zeros((B, 1), dtype=torch.int), input_pos
+        ]  # [B, S, D//2, 2]
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, freqs_cis)
+
+        q = q.transpose(1, 2)
+        self.k_cache[batch_idx.view(B, 1), :, input_pos] = k
+        self.v_cache[batch_idx.view(B, 1), :, input_pos] = v
+
+        y = flex_attention(
+            q, self.k_cache, self.v_cache, block_mask=block_mask, score_mod=score_mod
+        )
+
+        y = y.transpose(1, 2).contiguous().view(B, S, -1)
+
+        y = self.wo(y)
+        return y
+
+
+class PagedAttentionLayer(torch.nn.Module):
+    """An attention layer with paged attention"""
+
+    def __init__(self, n_heads, head_dim, dtype, paged_attention, block_size: int = 65536):
+        super().__init__()
+        self.n_head = n_heads
+        self.head_dim = head_dim
+
+        # key, query, value projections for all heads, but in a batch
+        total_head_dim = n_heads * head_dim
+        self.wqkv = torch.nn.Linear(
+            total_head_dim, 3 * total_head_dim, bias=False, device="cuda", dtype=dtype
+        )
+        self.wo = torch.nn.Linear(
+            total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype
+        )
+
+        # allocate kv cache with batch size=1 for paged attention
+        max_cached_seq_len = paged_attention.n_pages * paged_attention.page_size
+        self.k_cache_paged = torch.randn(
+            1,
+            n_heads,
+            max_cached_seq_len,
+            head_dim,
+            device="cuda",
+            dtype=dtype,
+        )
+        self.v_cache_paged = torch.randn(
+            1,
+            n_heads,
+            max_cached_seq_len,
+            head_dim,
+            device="cuda",
+            dtype=dtype,
+        )
+        self.paged_attention = paged_attention
+
+        self.freqs_cis = precompute_freqs_cis(
+            block_size, self.head_dim, dtype=dtype
+        )  # [block_size, D//2, 2]
+
+    def forward(
+        self,
+        batch_idx: Tensor,
+        input_pos: Tensor,
+        x: Tensor,
+        converted_block_mask: BlockMask,
+        converted_score_mod: _score_mod_signature,
+    ) -> Tensor:
+        # input_pos: [B, S], batch_idx: [B], x: [B, S, D]
+        B, S, _ = x.shape
+        kv_size = self.n_head * self.head_dim
+        q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
+
+        q = q.view(B, S, self.n_head, self.head_dim)
+        k = k.view(B, S, self.n_head, self.head_dim)
+        v = v.view(B, S, self.n_head, self.head_dim)
+
+        freqs_cis = self.freqs_cis.unsqueeze(0)[
+            torch.zeros((B, 1), dtype=torch.int), input_pos
+        ]  # [B, S, D//2, 2]
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, freqs_cis)
+
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        # Compared with NonPagedAttentionLayer, this is the only change: the kv cache update
+        # goes through paged_attention.assign, which writes k/v into the physical pages
+        # backing (batch_idx, input_pos).
+        self.paged_attention.assign(
+            batch_idx, input_pos, k, v, self.k_cache_paged, self.v_cache_paged
+        )
+
+        y = flex_attention(
+            q,
+            self.k_cache_paged,
+            self.v_cache_paged,
+            block_mask=converted_block_mask,
+            score_mod=converted_score_mod,
+        )
+
+        y = y.transpose(1, 2).contiguous().view(B, S, -1)
+
+        y = self.wo(y)
+        return y
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    # x: [B, S, H, D], freqs_cis: [B, S, D//2, 2]
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # [B, S, H, D//2, 2]
+    freqs_cis = freqs_cis.view(
+        xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2
+    )  # [B, S, 1, D//2, 2]
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
+
+
+def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Dict):
+    factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
+def precompute_freqs_cis(
+    seq_len: int,
+    n_elem: int,
+    base: int = 10000,
+    dtype: torch.dtype = torch.bfloat16,
+    rope_scaling: Optional[dict] = None,
+) -> Tensor:
+    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    if rope_scaling is not None:
+        freqs = apply_rope_scaling(freqs, rope_scaling)
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    return cache.to(dtype=dtype, device="cuda")
diff --git a/attn_gym/paged_attention/throughput.py b/attn_gym/paged_attention/throughput.py
new file mode 100644
index 0000000..14d5ec2
--- /dev/null
+++ b/attn_gym/paged_attention/throughput.py
@@ -0,0 +1,220 @@
+"""
+Benchmarking the throughput of a paged attention layer in terms of the
+maximum batch size that can be served.
+
+The benchmark is based on the prompt and response length distribution
+collected from the OpenOrca dataset (https://huggingface.co/datasets/Open-Orca/OpenOrca),
+which includes ~1M GPT-4 completions and ~3.2M GPT-3.5 completions.
+
+For a fair comparison, we assume a 4GB KV cache memory budget both with and without
+paged attention. We assume bfloat16 as the data type, 4 heads, and 64 embedding dim.
+
+[No Paged Attention] The kv cache requires
+    2 * (2 * b * h * kv_len * d)
+bytes, where the first 2 accounts for the key and value caches, the second 2 is the
+bfloat16 element size, b is batch size, h is number of heads, kv_len is kv cache length,
+and d is embedding dim. Taking the context length of 131072 from llama-3.1, the max
+batch size to serve is 32.
+
+[Paged Attention] The kv cache requires
+    2 * (2 * h * n_pages * page_size * d)
+bytes. Assuming a page size of 128, there can be at most 32768 pages.
+We empirically observe that the max batch size to serve is 2448, which is 76x the
+max batch size without paged attention.
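+
+Concretely, with the assumptions above: without paging, one cached token costs
+2 (k and v) * 2 bytes (bfloat16) * 4 heads * 64 dim = 1024 bytes, so a single
+131072-token sequence needs 128MB and the 4GB budget fits 4096MB / 128MB = 32 sequences.
+With paging, the same 4GB backs 32768 pages * 128 tokens per page (~4.2M cached tokens)
+shared across all requests, so the serveable batch size is bounded by the tokens actually
+in use rather than by the worst-case context length.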
+""" + +import torch +from torch.nn.attention.experimental._paged_attention import PagedAttention +from torch.nn.attention.flex_attention import ( + _identity, + BlockMask, + create_block_mask, +) +from datasets import load_dataset +import random +from collections import deque +from typing import Tuple +from utils import gen_offset, slice_block_mask +from model import PagedAttentionLayer + +create_block_mask = torch.compile(create_block_mask) + + +class Requests: + def __init__(self): + self.data = load_dataset("Open-Orca/OpenOrca")["train"] + + def sample_request(self): + # sample a prompt len and response len from openorca dataset + # to simulate a real world use case + idx = random.randint(0, len(self.data) - 1) + prompt_len = len(self.data[idx]["system_prompt"]) + len(self.data[idx]["question"]) + response_len = len(self.data[idx]["response"]) + return prompt_len, response_len + + +class Server: + def __init__(self, batch_size: int, n_pages: int, page_size: int, n_heads: int, head_dim: int): + self.paged_attention = PagedAttention(n_pages, page_size, batch_size) + + self.model = torch.compile( + PagedAttentionLayer(n_heads, head_dim, torch.bfloat16, self.paged_attention) + ) + + self.batch_size = batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.bsz_watermark = 0 # max batch size served during benchmark + + self.available_batch_idx = list(range(batch_size))[::-1] + self.request_queue = deque([]) + self.batch_idx = [] + self.input_pos = torch.zeros(batch_size, device="cuda", dtype=torch.int64) + self.request_length = torch.tensor( + [float("inf")] * batch_size, device="cuda" + ) # decide whether a request is completed + + self.token_embedding = torch.randn( + (batch_size, 1, n_heads * head_dim), device="cuda", dtype=torch.bfloat16 + ) # [B, 1, n_heads*head_dim] + + self.block_mask = create_block_mask( + lambda b, h, q, kv: q >= kv, batch_size, 1, 64 * 1024, 64 * 1024, BLOCK_SIZE=page_size + ) + + def receive_request(self, prompt_len: int, response_len: int): + # assume we know prompt length and response length in advance. + self.request_queue.append((prompt_len, response_len)) + + def can_schedule(self, request: Tuple[int, int]) -> bool: + return len(self.paged_attention.empty_pages) * self.paged_attention.page_size >= sum( + request + ) + + def prefill_one_token(self, batch_idx: int, prompt_len: int, response_len: int): + # allocate page table + # in practice we don't know response length in advance. A good way is to use a heuristic to estimate response length + # and allocate page table accordingly. We may also allocate pages on the fly. For simplicity, we assume we know + # response length in advance. + self.paged_attention.reserve( + torch.tensor(batch_idx, device="cuda"), + torch.tensor(prompt_len + response_len, device="cuda"), + ) + + # simulate input token embedding + token_embedding = torch.randn( + 1, prompt_len, self.head_dim * self.n_heads, device="cuda", dtype=torch.bfloat16 + ) + + # generate block mask. The same block mask is used for all layers. 
+        new_block_mask = slice_block_mask(self.block_mask, batch_idx, prompt_len, prompt_len)
+        converted_block_mask = self.paged_attention.convert_logical_block_mask(
+            new_block_mask, torch.tensor([batch_idx], device="cuda")
+        )
+        converted_score_mod = self.paged_attention.get_score_mod(_identity)
+
+        prefill_input_pos = torch.arange(prompt_len, device="cuda").view(1, -1)
+        token_embedding = self.model(
+            torch.tensor([batch_idx], device="cuda"),
+            prefill_input_pos,
+            token_embedding,
+            converted_block_mask,
+            converted_score_mod,
+        )
+        return token_embedding
+
+    def prefill(self):
+        while (
+            self.request_queue
+            and self.can_schedule(self.request_queue[0])
+            and self.available_batch_idx
+        ):
+            prompt_len, response_len = self.request_queue.popleft()
+            print(
+                f"serving a new request with prompt_len: {prompt_len}, response_len: {response_len}"
+            )
+            new_batch_idx = self.available_batch_idx.pop()
+            token_embedding = self.prefill_one_token(new_batch_idx, prompt_len, response_len)
+            self.token_embedding[new_batch_idx] = token_embedding[:, -1].view(1, -1)
+
+            self.batch_idx.append(new_batch_idx)
+            self.input_pos[new_batch_idx] = prompt_len
+            self.request_length[new_batch_idx] = prompt_len + response_len
+
+            self.bsz_watermark = max(self.bsz_watermark, len(self.batch_idx))
+
+    def get_decode_mask(self, batch_idx: torch.Tensor, input_pos: torch.Tensor):
+        # batch_idx: [B], input_pos: [B]
+        (B,) = batch_idx.shape
+        input_block_idx = input_pos // self.block_mask.BLOCK_SIZE[0]  # [B]
+        kv_num_blocks = self.block_mask.kv_num_blocks[batch_idx, :, input_block_idx].view(B, 1, 1)
+        kv_indices = self.block_mask.kv_indices[batch_idx, :, input_block_idx].view(B, 1, 1, -1)
+        full_kv_num_blocks, full_kv_indices = None, None
+        if self.block_mask.full_kv_num_blocks is not None:
+            full_kv_num_blocks = self.block_mask.full_kv_num_blocks[
+                batch_idx, :, input_block_idx
+            ].view(B, 1, 1)
+            full_kv_indices = self.block_mask.full_kv_indices[batch_idx, :, input_block_idx].view(
+                B, 1, 1, -1
+            )
+        seq_length = (1, self.block_mask.seq_lengths[1])
+        return BlockMask.from_kv_blocks(
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
+            BLOCK_SIZE=self.block_mask.BLOCK_SIZE,
+            mask_mod=gen_offset(input_pos),
+            seq_lengths=seq_length,
+        )
+
+    def decode(self):
+        B = len(self.batch_idx)
+        batch_idx = torch.tensor(self.batch_idx, device="cuda").view(-1)  # [B]
+        input_pos = self.input_pos[batch_idx]  # [B]
+        mask = self.get_decode_mask(batch_idx, input_pos)
+        converted_block_mask = self.paged_attention.convert_logical_block_mask(mask, batch_idx)
+        converted_score_mod = self.paged_attention.get_score_mod(_identity)
+        self.token_embedding[batch_idx] = self.model(
+            batch_idx,
+            input_pos.view(B, 1),
+            self.token_embedding[batch_idx],
+            converted_block_mask,
+            converted_score_mod,
+        )
+        self.input_pos[batch_idx] += 1
+
+    def clean(self):
+        completed_batch_indices = torch.where(self.input_pos >= self.request_length)[0]
+        self.available_batch_idx += completed_batch_indices.tolist()
+        self.batch_idx = [
+            idx for idx in self.batch_idx if idx not in completed_batch_indices.tolist()
+        ]
+
+        for b in completed_batch_indices:
+            self.paged_attention.erase(torch.tensor([b]))
+
+        self.request_length[completed_batch_indices] = float("inf")
+
+
+if __name__ == "__main__":
+    # serving loop
+    num_requests = 10  # number of serving rounds; each round receives a batch of new requests
+    gap = 3  # get a new request after `gap` number of decoding tokens
+
+    batch_size, n_pages, page_size, n_heads, head_dim = 4096, 32768, 128, 4, 64
+
+    requests = Requests()
+    server = Server(batch_size, n_pages, page_size, n_heads, head_dim)
+
+    with torch.no_grad():
+        for i in range(num_requests):
+            for _ in range(1024):
+                server.receive_request(*requests.sample_request())
+
+            server.prefill()
+            for _ in range(gap):
+                server.decode()
+
+            server.clean()
+
+    print("max batch size served: ", server.bsz_watermark)
diff --git a/attn_gym/paged_attention/utils.py b/attn_gym/paged_attention/utils.py
new file mode 100644
index 0000000..ccc7806
--- /dev/null
+++ b/attn_gym/paged_attention/utils.py
@@ -0,0 +1,152 @@
+import torch
+from torch.nn.attention.experimental._paged_attention import PagedAttention
+from torch.nn.attention.flex_attention import (
+    _identity,
+    BlockMask,
+)
+
+
+def batch_reserve(paged_attention: PagedAttention, target_seq_len: torch.Tensor):
+    """Reserves pages for each sequence in the batch.
+
+    Args:
+        paged_attention: PagedAttention instance.
+        target_seq_len: Tensor of shape (B,) containing the length of each sequence in the batch.
+    """
+    (B,) = target_seq_len.shape
+    for b in range(B):
+        paged_attention.reserve(
+            torch.tensor(b),
+            target_seq_len[b],
+        )
+
+
+def random_init_paged_attention(n_pages: int, page_size: int, bsz: int, max_seq_len: int):
+    """Allocates physical pages across batches in a round-robin fashion to simulate a use case
+    where multiple batches run in parallel. This is for testing and benchmarking only.
+
+    Args:
+        n_pages: Number of pages.
+        page_size: Size of each page.
+        bsz: Batch size.
+        max_seq_len: Maximum sequence length.
+    """
+    paged_attention = PagedAttention(n_pages, page_size, bsz)
+
+    repeat = bsz // 4
+    sequence_lengths = [
+        [max_seq_len // 4, max_seq_len // 2, max_seq_len // 4, max_seq_len // 3] * repeat,
+        [max_seq_len // 4, max_seq_len // 2, max_seq_len // 2, max_seq_len // 2] * repeat,
+        [max_seq_len // 4, max_seq_len // 2, max_seq_len // 2, max_seq_len // 2] * repeat,
+        [max_seq_len // 2, max_seq_len, max_seq_len // 2, max_seq_len] * repeat,
+        [max_seq_len, max_seq_len, max_seq_len, max_seq_len] * repeat,
+    ]
+
+    for seq_len in sequence_lengths:
+        batch_reserve(
+            paged_attention,
+            torch.tensor(seq_len, device="cuda"),
+        )
+
+    return paged_attention
+
+
+def gen_offset(off: torch.Tensor):
+    """Generates an offset mask_mod function.
+
+    Args:
+        off: Offset tensor.
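+
+    The returned mask_mod shifts the query index by a per-batch offset: query token
+    `m + off[b]` may attend to kv position `n` iff `m + off[b] >= n`, i.e. a causal mask
+    offset by `off[b]`. This is used for decoding, where the single query token sits at
+    position `off[b]` in the kv cache.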
+ """ + + def offset(b, h, m, n): + return m + off[b] >= n + + return offset + + +def generate_score_mod(attn_type: str): + """Generates a score modification function. + + Args: + attn_type: Attention type. + """ + + def relative_bias(score, b, h, m, n): + return score + (m - n) + + def head_bias(score, b, h, m, n): + return score + 2 * h + + function_dict = { + "noop": _identity, + "causal": _identity, + "rel": relative_bias, + "head_bias": head_bias, + } + return function_dict[attn_type] + + +def _adjust_num_blocks_and_indices( + num_blocks: torch.Tensor, + indices: torch.Tensor, + batch_idx: int, + new_num_rows: int, + new_num_cols: int, +): + """Adjust the number of blocks and indices based on the new number of rows and columns. + + Args: + num_blocks: KV Num Blocks. + indices: KV indices. + batch_idx: Batch index. + new_num_rows: New number of rows. + new_num_cols: New number of columns. + """ + indices = indices[[batch_idx], :, :new_num_rows, :new_num_cols] + num_blocks = num_blocks[[batch_idx], :, :new_num_rows] + num_blocks = torch.where(num_blocks < new_num_cols, num_blocks, new_num_cols) + num_blocks = torch.sum(indices < num_blocks[:, :, :, None], dim=-1).to(torch.int32) + return num_blocks.clone(), indices.clone() + + +def slice_block_mask( + block_mask: BlockMask, batch_idx: int, new_q_len: int, new_kv_len: int +) -> BlockMask: + """Slice the block mask based on the new query and key/value lengths. + + Args: + block_mask: Block mask. + batch_idx: Batch index. + new_q_len: New query length. + new_kv_len: New key/value length. + """ + new_num_rows = (new_q_len + block_mask.BLOCK_SIZE[0] - 1) // block_mask.BLOCK_SIZE[0] + new_num_cols = (new_kv_len + block_mask.BLOCK_SIZE[1] - 1) // block_mask.BLOCK_SIZE[1] + new_kv_num_blocks, new_kv_indices = _adjust_num_blocks_and_indices( + block_mask.kv_num_blocks, block_mask.kv_indices, batch_idx, new_num_rows, new_num_cols + ) + if block_mask.full_kv_num_blocks is not None: + assert block_mask.full_kv_indices is not None + ( + new_full_kv_num_blocks, + new_full_kv_indices, + ) = _adjust_num_blocks_and_indices( + block_mask.full_kv_num_blocks, + block_mask.full_kv_indices, + batch_idx, + new_num_rows, + new_num_cols, + ) + else: + new_full_kv_num_blocks = None + new_full_kv_indices = None + new_block_mask = block_mask.from_kv_blocks( + new_kv_num_blocks, + new_kv_indices, + new_full_kv_num_blocks, + new_full_kv_indices, + block_mask.BLOCK_SIZE, + block_mask.mask_mod, + ) + new_block_mask.seq_lengths = (new_q_len, new_kv_len) + return new_block_mask diff --git a/attn_gym/utils.py b/attn_gym/utils.py index 69cae71..8b6dc97 100644 --- a/attn_gym/utils.py +++ b/attn_gym/utils.py @@ -100,9 +100,9 @@ def visualize_attention_scores( Returns: None """ - assert ( - score_mod is not None or mask_mod is not None - ), "Must provide either score_mod or mask_mod" + assert score_mod is not None or mask_mod is not None, ( + "Must provide either score_mod or mask_mod" + ) query = query[batch_idx, head_idx, :, :] key = key[batch_idx, head_idx, :, :] scores_viz = create_score_mod(