Integrate chunked prefill into t3k Llama3-70B (#15921)
1 parent 05886fd, commit 8d01f5d
Showing 15 changed files with 926 additions and 164 deletions.
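For context, chunked prefill splits a long prompt into fixed-size chunks and prefills one chunk at a time, so activation memory stays bounded by the chunk size while the (paged) KV cache still accumulates the full context. The sketch below only illustrates that idea; chunked_prefill and model_forward_chunk are hypothetical names, not the API added by this commit.

# Minimal sketch of chunked prefill (illustrative only; model_forward_chunk is a hypothetical callback).
import torch

def chunked_prefill(tokens: torch.Tensor, chunk_size: int, model_forward_chunk):
    """Prefill a [1, seq_len] prompt in chunk_size pieces and return the last chunk's logits."""
    seq_len = tokens.shape[-1]
    last_logits = None
    for start in range(0, seq_len, chunk_size):
        chunk = tokens[:, start : start + chunk_size]
        # Each call attends to all previously cached positions plus the current chunk;
        # the KV cache is filled at offset `start` as a side effect of the forward pass.
        last_logits = model_forward_chunk(chunk, chunk_start_idx=start)
    return last_logits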
models/demos/t3000/llama2_70b/tests/test_chunked_generation.py (351 additions, 0 deletions)
@@ -0,0 +1,351 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

import pytest
from loguru import logger
import torch
import ttnn
from ttnn import ReplicateTensorToMesh

from models.demos.t3000.llama2_70b.reference.llama.llama import Llama
from models.demos.t3000.llama2_70b.tt.llama_generation import (
    TtLlamaModelForGeneration,
    get_block_size,
    num_blocks_in_seq,
)
from models.demos.t3000.llama2_70b.tt.llama_common import (
    setup_llama_env,
    check_mesh_device,
    comp_pcc,
    load_llama_state_dict,
)
from models.demos.t3000.llama2_70b.demo.demo_continuous_batching_paged_attention import (
    PagedAttentionConfig,
    ModelArgs,
    TTArgs,
)


def run_chunked_prefill_single_user(model_args, tt_args, chunk_size):
    # Set up paged attention config
    paged_attention_config = PagedAttentionConfig()
    bsz = model_args.max_batch_size

    # Create static page table (same as in demo)
    permutation = torch.randperm(paged_attention_config.max_num_blocks)
    reverse_permutation = torch.argsort(permutation)
    static_page_table = reverse_permutation.reshape(bsz, paged_attention_config.max_num_blocks // bsz)
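    # Each row of static_page_table maps one user's logical KV blocks to randomly permuted
    # physical block indices, so the paged cache is exercised with non-contiguous block writes.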

    # Build reference generator
    ref_generator = Llama.build(
        ckpt_dir=model_args.ckpt_dir,
        tokenizer_path=model_args.tokenizer_path,
        max_seq_len=model_args.max_seq_len,
        max_batch_size=model_args.max_batch_size,
        skip_model_load=model_args.skip_model_load,
        n_layers=model_args.num_layers,
    )

    # Load state dict for TT model
    state_dict = load_llama_state_dict(model_args.ckpt_dir, n_layers=model_args.num_layers)

    # Build TT generator with paged attention
    tt_model = TtLlamaModelForGeneration(
        configuration=ref_generator.model.params,
        state_dict=state_dict,
        model_args=model_args,
        tt_args=tt_args,
        paged_attention_config=paged_attention_config,
    )

    # For testing, override the max prefill length
    tt_model.model_config["MAX_PREFILL_SEQ_LEN"] = chunk_size
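    # Capping MAX_PREFILL_SEQ_LEN below the prompt length is what forces the chunked path:
    # prefill is expected to run in chunk_size-sized pieces while the paged KV cache
    # accumulates the full context.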

    # Extract the model's KV cache so that we can pass it to the forward function
    kv_cache = [l.attention.layer_past for l in tt_model.tt_model.layers]

    # Create random input
    seq_len = model_args.max_seq_len
    input_tokens = torch.randint(0, 32000, (1, seq_len), dtype=torch.long)

    # Slice out the relevant part of the page table
    block_size = get_block_size(kv_cache)
    num_blocks = num_blocks_in_seq(seq_len, block_size)
    static_page_table = static_page_table[:, :num_blocks]

    # Run both models
    with torch.no_grad():
        tt_logits = tt_model.prefill_forward_single_user(
            input_tokens,
            start_pos=0,
            user_id=0,
            page_table=static_page_table,
            kv_cache=kv_cache,
        )
        ref_logits = ref_generator.model.forward(input_tokens, start_pos=0)

    # Compare outputs
    does_pass, pcc = comp_pcc(ref_logits, tt_logits, pcc=0.99)
    logger.info(f"PCC between reference and TT model logits: {pcc}")
    assert does_pass, f"Logits PCC {pcc} below threshold of 0.99"

    ref_kv_cache = [[l.attention.cache_k, l.attention.cache_v] for l in ref_generator.model.layers]
    # Compare KV caches
    for layer_idx in range(len(kv_cache)):
        tt_cache = kv_cache[layer_idx]
        ref_cache = ref_kv_cache[layer_idx]

        # Unshuffle the paged cache and view it as an unpaged cache (similar to the paged_update_cache test)
        tt_got_back_shuffled = [
            ttnn.to_torch(kv, mesh_composer=ttnn.ConcatMeshToTensor(tt_args.mesh_device, dim=1)) for kv in tt_cache
        ]
        tt_got_back_unshuffled = [shuffled[reverse_permutation] for shuffled in tt_got_back_shuffled]

        # Reshape to match reference cache dimensions
        max_num_blocks = tt_got_back_shuffled[0].shape[0]
        block_size = tt_got_back_shuffled[0].shape[2]
        num_heads = tt_got_back_shuffled[0].shape[1]
        head_dim = tt_got_back_shuffled[0].shape[3]
        tt_got_back = [
            unshuffled.reshape(1, max_num_blocks, num_heads, block_size, head_dim)
            .transpose(1, 2)
            .reshape(1, num_heads, -1, head_dim)
            for unshuffled in tt_got_back_unshuffled
        ]
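        # Net effect: the paged [max_num_blocks, num_heads, block_size, head_dim] layout is
        # unshuffled and flattened into a contiguous [1, num_heads, max_num_blocks * block_size, head_dim]
        # tensor that can be sliced to seq_len and compared against the reference cache.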

        for i in range(len(tt_got_back)):
            ref_cache_slice = ref_cache[i][:1, :seq_len, :, :].permute(0, 2, 1, 3)
            # Compare caches
            does_pass_cache, pcc_cache = comp_pcc(ref_cache_slice, tt_got_back[i][:, :, :seq_len, :])
            logger.info(f"PCC between reference and TT model KV cache at layer {layer_idx}: {pcc_cache}")
            assert does_pass_cache, f"KV cache PCC {pcc_cache} below threshold at layer {layer_idx}"

    return does_pass, pcc


@torch.no_grad()
@pytest.mark.timeout(240000)
@pytest.mark.parametrize(
    "llama_version",
    ["llama3"],
)
@pytest.mark.parametrize(
    "num_layers",
    [
        1,
    ],
    ids=[
        "1L",
    ],
)
@pytest.mark.parametrize(
    "max_batch_size, max_context_len, chunk_size",
    [(1, 128 * 1024, 32 * 1024), (16, 8 * 1024, 2 * 1024), (32, 2 * 1024, 1 * 1024)],
    ids=["1BSZ", "16BSZ", "32BSZ"],
)
def test_chunked_prefill_single_user(
    t3k_mesh_device, llama_version, num_layers, max_batch_size, max_context_len, chunk_size
):
    """
    This test ensures that chunked prefill, when used by calling `prefill_forward_single_user`,
    matches the reference implementation.
    """
    if max_context_len == 128 * 1024:
        pytest.skip("Skipping test for max_context_len = 128*1024 since reference runs OOM")
    # Set up environment
    model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env(
        llama_version=llama_version,
    )

    # Check device compatibility
    check_mesh_device(t3k_mesh_device, model_config)
    t3k_mesh_device.enable_async(True)

    # Create args
    model_args = ModelArgs(
        implementation="tt",
        llama_version=llama_version,
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_batch_size=max_batch_size,
        num_layers=num_layers,
        max_seq_len=max_context_len,
        max_kv_context_len=max_context_len,
    )

    tt_args = TTArgs(
        mesh_device=t3k_mesh_device,
        n_devices=8,
        cache_path=cache_path,
    )

    # Run test
    does_pass, pcc = run_chunked_prefill_single_user(model_args, tt_args, chunk_size)
    assert does_pass, f"Test failed with PCC {pcc}"


def run_batch_prefill_test(model_args, tt_args, chunk_size, batch):
    # Set up paged attention config
    paged_attention_config = PagedAttentionConfig()
    # Create static page table (same as in demo)
    permutation = torch.randperm(paged_attention_config.max_num_blocks)
    reverse_permutation = torch.argsort(permutation)
    static_page_table = reverse_permutation.reshape(
        model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size
    )
    # Build reference generator
    ref_generator = Llama.build(
        ckpt_dir=model_args.ckpt_dir,
        tokenizer_path=model_args.tokenizer_path,
        max_seq_len=model_args.max_seq_len,
        max_batch_size=model_args.max_batch_size,
        skip_model_load=model_args.skip_model_load,
        n_layers=model_args.num_layers,
    )

    # Load state dict for TT model
    state_dict = load_llama_state_dict(model_args.ckpt_dir, n_layers=model_args.num_layers)

    # Build TT generator with paged attention
    tt_model = TtLlamaModelForGeneration(
        configuration=ref_generator.model.params,
        state_dict=state_dict,
        model_args=model_args,
        tt_args=tt_args,
        paged_attention_config=paged_attention_config,
    )

    # For testing, override the max prefill length
    tt_model.model_config["MAX_PREFILL_SEQ_LEN"] = chunk_size

    # Extract the model's KV cache so that we can pass it to the forward function
    kv_cache = [l.attention.layer_past for l in tt_model.tt_model.layers]

    # Create random input with varying sequence lengths
    max_seq_len = model_args.max_seq_len
    prompt_lens = torch.randint(
        chunk_size, max_seq_len + 1, (batch,)
    )  # Random lengths between chunk_size and max_seq_len
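    # Every prompt is at least chunk_size tokens, so each user fills at least one full chunk;
    # lengths above chunk_size exercise the multi-chunk prefill path.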
    input_tokens = torch.randint(0, 32000, (batch, max_seq_len), dtype=torch.long)
    logger.info(f"Prompt lengths: {prompt_lens}")
    batch_page_table = static_page_table[:batch]
    # Run both models
    with torch.no_grad():
        tt_logits = tt_model.prefill_forward(
            input_tokens,
            start_pos=0,
            page_table=batch_page_table,
            kv_cache=kv_cache,
            prompt_lens=prompt_lens,
        )
        logger.info(f"TT logits shape: {tt_logits.shape}")

        # Run reference model
        batch_logits = ref_generator.model.forward(input_tokens, start_pos=0)
        ref_logits = batch_logits[torch.arange(batch), prompt_lens - 1, :].unsqueeze(1)  # Only keep last token's logits
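        # prefill_forward is expected to return one logits row per user (for the last valid
        # token of each prompt), so the reference logits are gathered at prompt_lens - 1 above
        # to give matching shapes for the PCC check.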
        ref_kv_cache = [[l.attention.cache_k, l.attention.cache_v] for l in ref_generator.model.layers]

    # Compare outputs
    does_pass, pcc = comp_pcc(ref_logits, tt_logits, pcc=0.99)
    logger.info(f"PCC between reference and TT model: {pcc}")
    assert does_pass, f"PCC {pcc} below threshold of 0.99"

    # Compare KV caches
    for layer_idx in range(len(kv_cache)):
        tt_cache = kv_cache[layer_idx]
        ref_cache = ref_kv_cache[layer_idx]

        # Unshuffle the paged cache and view it as an unpaged cache (similar to the paged_update_cache test)
        tt_got_back_shuffled = [
            ttnn.to_torch(kv, mesh_composer=ttnn.ConcatMeshToTensor(tt_args.mesh_device, dim=1)) for kv in tt_cache
        ]
        tt_got_back_unshuffled = [shuffled[reverse_permutation] for shuffled in tt_got_back_shuffled]

        # Reshape to match reference cache dimensions
        max_num_blocks = tt_got_back_shuffled[0].shape[0]
        block_size = tt_got_back_shuffled[0].shape[2]
        num_heads = tt_got_back_shuffled[0].shape[1]
        head_dim = tt_got_back_shuffled[0].shape[3]
        tt_got_back = [
            unshuffled.reshape(
                model_args.max_batch_size, max_num_blocks // model_args.max_batch_size, num_heads, block_size, head_dim
            )
            .transpose(1, 2)
            .reshape(model_args.max_batch_size, num_heads, -1, head_dim)
            for unshuffled in tt_got_back_unshuffled
        ]

        for b in range(batch):
            valid_seq_len = prompt_lens[b]
            logger.info(f"valid seq len: {valid_seq_len}")
            for i in range(len(tt_got_back)):
                logger.info(f"layer {i}, batch {b}")
                logger.info(f"ref cache shape: {ref_cache[i].shape}")
                logger.info(f"tt cache shape: {tt_got_back[i].shape}")
                ref_cache_slice = ref_cache[i][b : b + 1, :valid_seq_len, :, :].permute(0, 2, 1, 3)
                tt_cache_slice = tt_got_back[i][b : b + 1, :, :valid_seq_len, :]
                # Compare caches
                does_pass_cache, pcc_cache = comp_pcc(ref_cache_slice, tt_cache_slice)
                logger.info(f"PCC between reference and TT model KV cache at layer {layer_idx}: {pcc_cache}")
                assert does_pass_cache, f"KV cache PCC {pcc_cache} below threshold at layer {layer_idx}"

    return does_pass, pcc


@torch.no_grad()
@pytest.mark.timeout(240000)
@pytest.mark.parametrize(
    "llama_version",
    ["llama3"],
)
@pytest.mark.parametrize(
    "num_layers",
    [
        1,
    ],
    ids=[
        "1L",
    ],
)
@pytest.mark.parametrize(
    "max_batch_size, max_context_len, chunk_size, batch",
    [(1, 128 * 1024, 32 * 1024, 1), (16, 8 * 1024, 2 * 1024, 4), (32, 2 * 1024, 1 * 1024, 4)],
    ids=["1BSZ", "16BSZ", "32BSZ"],
)
def test_batch_prefill(t3k_mesh_device, llama_version, num_layers, max_batch_size, max_context_len, chunk_size, batch):
    """
    This test ensures that batch prefill matches the reference implementation
    when processing multiple sequences of different lengths.
    """
    if max_context_len == 128 * 1024:
        pytest.skip("Skipping test for max_context_len = 128*1024 since reference runs OOM")
    # Set up environment
    model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env(
        llama_version=llama_version,
    )

    # Check device compatibility
    check_mesh_device(t3k_mesh_device, model_config)
    t3k_mesh_device.enable_async(True)

    # Create args
    model_args = ModelArgs(
        implementation="tt",
        llama_version=llama_version,
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_batch_size=max_batch_size,
        num_layers=num_layers,
        max_seq_len=max_context_len,
        max_kv_context_len=max_context_len,
    )

    tt_args = TTArgs(
        mesh_device=t3k_mesh_device,
        n_devices=8,
        cache_path=cache_path,
    )

    # Run test
    does_pass, pcc = run_batch_prefill_test(model_args, tt_args, chunk_size, batch)
    assert does_pass, f"Test failed with PCC {pcc}"