Create chunked-prefill mode in SDPA op (#15907)
cglagovichTT authored Dec 12, 2024
1 parent 8926ed0 commit 52cc437
Showing 11 changed files with 844 additions and 131 deletions.
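
For orientation, the new op is exercised by the tests below roughly as follows. This is a minimal sketch distilled from run_test_chunked_sdpa later in this diff, not part of the commit itself: it assumes a ttnn device, the paged K/V cache, page_table_tt, and the program/compute configs have already been built as shown in the test.

# Sketch only: names mirror run_test_chunked_sdpa below; setup of the paged cache,
# page_table_tt, program_config and compute_kernel_config is omitted here.
for chunk_idx in range(num_prefill_chunks):
    chunk_start_idx = chunk_idx * prefill_chunk_size  # absolute start of this chunk in the sequence
    Q_chunk = Q[:, :, chunk_start_idx : chunk_start_idx + prefill_chunk_size]
    tt_Q_chunk = ttnn.Tensor(Q_chunk, q_dtype).to(ttnn.TILE_LAYOUT).to(device)
    tt_out_chunk = ttnn.transformer.chunked_scaled_dot_product_attention(
        tt_Q_chunk,
        tt_paged_K,       # paged K cache: [max_num_blocks, nkv, page_block_size, d]
        tt_paged_V,       # paged V cache, same layout
        page_table_tt,    # [b, max_num_blocks_per_seq] int32, virtual block -> physical block id
        chunk_start_idx,  # tells the kernel where this Q chunk sits in the full sequence
        program_config=program_config,
        compute_kernel_config=compute_kernel_config,
    )

The chunk_start_idx argument is what distinguishes the chunked entry point from the existing scaled_dot_product_attention: the kernel rebases each Q chunk into global sequence coordinates (see the kernel change further down) so causal masking against the paged K/V lines up with the chunk's true position.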
@@ -73,7 +73,7 @@ def run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype
tt_back = ttnn.transformer.scaled_dot_product_attention(
tt_Q, tt_K, tt_V, is_causal=True, program_config=program_config, compute_kernel_config=compute_kernel_config
)
-tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
+tt_back = ttnn.to_torch(tt_back)

K_repeated = torch.cat([K[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1)  # b, nh, s, d
V_repeated = torch.cat([V[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1)  # b, nh, s, d
@@ -238,7 +238,7 @@ def run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dt
program_config=program_config,
compute_kernel_config=compute_kernel_config,
)
-tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
+tt_back = ttnn.to_torch(tt_back)

if nkv > 1 and nkv != nh:
assert nh % nkv == 0
@@ -297,3 +297,230 @@ def test_sdpa_noncausal_unequal_seqlen(device, b, nh, nkv, sq, sk, d, q_chunk_si
pytest.skip("s must be divisible by q_chunk_size and k_chunk_size")
ttnn.device.DisablePersistentKernelCache()
run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dtype, sk=sk)


def run_test_chunked_sdpa(
device,
b,
nh,
nkv,
s,
d,
q_chunk_size,
k_chunk_size,
prefill_chunk_size,
page_block_size,
q_dtype,
k_dtype,
use_high_precision_compute,
grid_size=None,
):
program_config = ttnn.SDPAProgramConfig(
compute_with_storage_grid_size=grid_size or device.compute_with_storage_grid_size(),
q_chunk_size=q_chunk_size,
k_chunk_size=k_chunk_size,
exp_approx_mode=False,
)

if use_high_precision_compute:
compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi4,
math_approx_mode=False,
fp32_dest_acc_en=True,
packer_l1_acc=False,
)
else:
compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi2,
math_approx_mode=True,
fp32_dest_acc_en=False,
packer_l1_acc=False,
)

Q = fa_rand(b, nh, s, d)
K = fa_rand(b, nkv, s, d)
V = fa_rand(b, nkv, s, d)
K_repeated = torch.cat([K[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1)  # b, nh, s, d
V_repeated = torch.cat([V[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1)  # b, nh, s, d
gt = torch.nn.functional.scaled_dot_product_attention(Q, K_repeated, V_repeated, is_causal=True)

# Log the shapes of all inputs for debugging
logger.debug(f"Q: {Q.shape}")
logger.debug(f"K: {K.shape}")
logger.debug(f"V: {V.shape}")

assert s % prefill_chunk_size == 0, "s must be divisible by prefill_chunk_size"
assert prefill_chunk_size % page_block_size == 0, "prefill_chunk_size must be divisible by page_block_size"
num_prefill_chunks = s // prefill_chunk_size
# Prepare K, V paged for TT
max_num_blocks_per_seq = s // page_block_size
assert max_num_blocks_per_seq * page_block_size == s
max_num_blocks = b * max_num_blocks_per_seq
assert max_num_blocks * page_block_size == b * s

# Shuffle paged KV cache according to some random page_table
permutation = torch.randperm(max_num_blocks)
reverse_permutation = torch.argsort(permutation)
# page_table is the reverse permutation from shuffled -> unshuffled, and is used to map
# a virtual block to the physical block id.
page_table = reverse_permutation.reshape(b, max_num_blocks_per_seq)

def page_cache(cache):
paged_cache = (
cache.reshape(b, nkv, max_num_blocks_per_seq, page_block_size, d)
.transpose(1, 2)
.reshape(max_num_blocks, nkv, page_block_size, d)
)

shuffled_page_cache = paged_cache[permutation]
return shuffled_page_cache

def unpage_cache(cache):
unshuffled_page_cache = cache[reverse_permutation]
paged_cache_back = (
unshuffled_page_cache.reshape(b, nkv, max_num_blocks_per_seq, page_block_size, d)
.transpose(1, 2)
.reshape(b, nkv, s, d)
)
return paged_cache_back

# Check that we can convert from normal to paged to normal
assert torch.allclose(unpage_cache(page_cache(K)), K), "K is not equal to unpage_cache(page_cache(K))"
assert torch.allclose(unpage_cache(page_cache(V)), V), "V is not equal to unpage_cache(page_cache(V))"

tt_paged_K = ttnn.Tensor(page_cache(K), k_dtype).to(ttnn.TILE_LAYOUT).to(device)
tt_paged_V = ttnn.Tensor(page_cache(V), k_dtype).to(ttnn.TILE_LAYOUT).to(device)
page_table_tt = ttnn.Tensor(page_table, ttnn.int32).to(device)

for chunk_idx in range(num_prefill_chunks):
# Chunk Q
Q_chunk = Q[:, :, chunk_idx * prefill_chunk_size : (chunk_idx + 1) * prefill_chunk_size]
tt_Q_chunk = ttnn.Tensor(Q_chunk, q_dtype).to(ttnn.TILE_LAYOUT).to(device)
chunk_start_idx = chunk_idx * prefill_chunk_size

tt_back = ttnn.transformer.chunked_scaled_dot_product_attention(
tt_Q_chunk,
tt_paged_K,
tt_paged_V,
page_table_tt,
chunk_start_idx,
program_config=program_config,
compute_kernel_config=compute_kernel_config,
)
tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
gt_chunk = gt[:, :, chunk_idx * prefill_chunk_size : (chunk_idx + 1) * prefill_chunk_size]
out_pass, out_pcc = comp_pcc(gt_chunk, tt_back, 0.998)
logger.debug(f"python vs pytorch: {out_pcc}")
assert out_pass


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b])
@pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"])
@pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"])
@pytest.mark.parametrize("prefill_chunk_size", [1024, 2048])
@pytest.mark.parametrize("page_block_size", [64, 128])
@pytest.mark.parametrize(
"b, nh, nkv, s, d",
[
[1, 8, 1, 16 * 1024, 128],
], # Llama2-70B
)
def test_sdpa_chunked(
device,
b,
nh,
nkv,
s,
d,
q_chunk_size,
k_chunk_size,
prefill_chunk_size,
page_block_size,
q_dtype,
k_dtype,
use_program_cache,
use_high_precision_compute=False,
):
for _ in range(2):
run_test_chunked_sdpa(
device,
b,
nh,
nkv,
s,
d,
q_chunk_size,
k_chunk_size,
prefill_chunk_size,
page_block_size,
q_dtype,
k_dtype,
use_high_precision_compute,
)

# Verify program cache reuse: all chunked SDPA invocations should share a single compiled program
assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format(
device.num_program_cache_entries()
)


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b])
@pytest.mark.parametrize("q_chunk_size", [128])
@pytest.mark.parametrize("k_chunk_size", [128])
@pytest.mark.parametrize("prefill_chunk_size", [1024])
@pytest.mark.parametrize("page_block_size", [64])
@pytest.mark.parametrize(
"b, nh, nkv, s, d",
[
[2, 1, 1, 4096, 128],
], # Llama2-70B
)
def test_sdpa_chunked_iterate_batch(
device,
b,
nh,
nkv,
s,
d,
q_chunk_size,
k_chunk_size,
prefill_chunk_size,
page_block_size,
q_dtype,
k_dtype,
use_program_cache,
use_high_precision_compute=False,
):
"""
This tests chunked prefill where a single core has more than one user to process.
"""
for _ in range(2):
run_test_chunked_sdpa(
device,
b,
nh,
nkv,
s,
d,
q_chunk_size,
k_chunk_size,
prefill_chunk_size,
page_block_size,
q_dtype,
k_dtype,
use_high_precision_compute,
grid_size=(1, 1),
)

# Verify program cache reuse: all chunked SDPA invocations should share a single compiled program
assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format(
device.num_program_cache_entries()
)
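
The page_table bookkeeping in run_test_chunked_sdpa above (permutation, argsort, page_cache/unpage_cache) is the part most worth internalizing. The following is a self-contained sketch of the same virtual-to-physical block mapping in plain PyTorch with toy sizes; it mirrors the test's logic but is illustrative only and not part of the commit.

# Standalone illustration of the paged-KV shuffle used in the test above (toy sizes).
import torch

b, nkv, s, d = 1, 1, 8, 2
page_block_size = 2
max_num_blocks_per_seq = s // page_block_size  # 4 blocks per sequence
max_num_blocks = b * max_num_blocks_per_seq

cache = torch.arange(b * nkv * s * d, dtype=torch.float32).reshape(b, nkv, s, d)

permutation = torch.randperm(max_num_blocks)      # permutation[j]: virtual block stored at physical slot j
reverse_permutation = torch.argsort(permutation)  # reverse_permutation[v]: physical slot holding virtual block v
page_table = reverse_permutation.reshape(b, max_num_blocks_per_seq)

# Break the contiguous cache into blocks and store them in shuffled (physical) order.
paged = (
    cache.reshape(b, nkv, max_num_blocks_per_seq, page_block_size, d)
    .transpose(1, 2)
    .reshape(max_num_blocks, nkv, page_block_size, d)
)[permutation]

# Gathering physical blocks through page_table recovers the original contiguous cache,
# which is exactly what the unpage_cache round-trip check in the test asserts.
recovered = (
    paged[page_table.flatten()]
    .reshape(b, max_num_blocks_per_seq, nkv, page_block_size, d)
    .transpose(1, 2)
    .reshape(b, nkv, s, d)
)
assert torch.equal(recovered, cache)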
@@ -360,6 +360,7 @@ void MAIN {

constexpr uint32_t is_causal = get_compile_time_arg_val(22) == 1;
constexpr uint32_t use_provided_mask = get_compile_time_arg_val(23) == 1;
+constexpr uint32_t is_chunked = get_compile_time_arg_val(24) == 1;

const uint32_t core_id = get_arg_val<uint32_t>(0);
const uint32_t local_batch_start = get_arg_val<uint32_t>(1);
@@ -368,6 +369,7 @@
const uint32_t local_nh_end = get_arg_val<uint32_t>(4);
const uint32_t local_q_start = get_arg_val<uint32_t>(5);
const uint32_t local_q_end = get_arg_val<uint32_t>(6);
+const uint32_t chunked_q_chunk_offset = get_arg_val<uint32_t>(7);

const uint32_t q_chunks_per_core = local_q_end - local_q_start;

@@ -413,7 +415,10 @@
#endif

// Get Q chunk
-const uint32_t q_low_idx =
+if constexpr (is_chunked) {
+    q_chunk = chunked_q_chunk_offset + q_chunk;
+}
+uint32_t q_low_idx =
q_chunk * Sq_chunk_t; // This is the sequence index of the first tile of this chunk
uint32_t q_high_idx;
if constexpr (is_causal) {
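
The rebasing above is the essential kernel-side change for chunked prefill: q_chunk indexes Q chunks within the current prefill chunk, and chunked_q_chunk_offset shifts it into whole-sequence coordinates before the causal bounds are derived. A small Python model of that index arithmetic follows; it is illustrative only, and it assumes the host passes chunk_start_idx divided by the Q chunk size (in rows) as chunked_q_chunk_offset, which this excerpt does not show.

# Python model of the tile-index rebasing above (not kernel code).
TILE_HEIGHT = 32

def q_tile_window(local_q_chunk, chunked_q_chunk_offset, Sq_chunk_t):
    # Rebase the core-local Q chunk index into global sequence coordinates.
    q_chunk = chunked_q_chunk_offset + local_q_chunk
    q_low_idx = q_chunk * Sq_chunk_t  # first Q tile of this chunk in the full sequence
    # Plausible causal upper bound (the exact expression is truncated in this excerpt):
    q_high_idx = q_low_idx + Sq_chunk_t
    return q_low_idx, q_high_idx

# Example: prefill chunk starting at row 2048, q_chunk_size = 128 rows
# -> Sq_chunk_t = 128 // TILE_HEIGHT = 4 tiles, offset = 2048 // 128 = 16 Q chunks.
print(q_tile_window(local_q_chunk=0, chunked_q_chunk_offset=16, Sq_chunk_t=4))  # (64, 68)

With is_chunked false, the offset is zero and the indexing reduces to the existing non-chunked causal SDPA path.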
@@ -510,6 +515,7 @@
out_subblock_h,
out_subblock_w,
false /*transpose*/);

reconfig_data_format_srca(cb_out_im);
cb_pop_front(cb_qk_im, qk_chunk_tiles);
