#0: Fix bug where page table CB wasn't popped at end of batch loop
cglagovichTT committed Dec 11, 2024
1 parent aca1b57 commit 19ab7b7
Showing 2 changed files with 119 additions and 55 deletions.
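
In the chunked (paged attention) path, the reader kernel pushes one page-table stick into the page table CB at the top of each iteration of the batch loop; before this fix the matching pop was missing, so the CB accumulated entries across batches. The following sketch is only an illustration of the intended balanced push/pop pattern, written with the dataflow calls that appear in the diff below, not the actual reader kernel:

// Minimal sketch of the balanced page-table CB usage (illustrative, not the full reader kernel).
constexpr uint32_t cb_id_page_table = tt::CBIndex::c_6;  // declared once, outside the batch loop

for (uint32_t nb = local_batch_start; nb < local_batch_end; ++nb) {
    if constexpr (is_chunked) {
        cb_reserve_back(cb_id_page_table, 1);  // claim one stick for this batch's page table
        // ... read this batch's page table entries into the CB ...
        cb_push_back(cb_id_page_table, 1);     // publish it for the rest of the iteration
    }

    // ... process all Q/K/V chunks for batch nb, translating pages via the table ...

    if constexpr (is_chunked) {
        cb_pop_front(cb_id_page_table, 1);     // the fix: release the stick before the next batch
    }
}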
@@ -299,59 +299,6 @@ def test_sdpa_noncausal_unequal_seqlen(device, b, nh, nkv, sq, sk, d, q_chunk_si
run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dtype, sk=sk)


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
# @pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"])
@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b])
@pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"])
@pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"])
@pytest.mark.parametrize("prefill_chunk_size", [1024, 2048])
@pytest.mark.parametrize("page_block_size", [64, 128])
@pytest.mark.parametrize(
"b, nh, nkv, s, d",
([1, 8, 1, 16 * 1024, 128],), # Llama2-70B
)
def test_sdpa_chunked(
device,
b,
nh,
nkv,
s,
d,
q_chunk_size,
k_chunk_size,
prefill_chunk_size,
page_block_size,
q_dtype,
k_dtype,
use_program_cache,
use_high_precision_compute=False,
):
for _ in range(2):
run_test_chunked_sdpa(
device,
b,
nh,
nkv,
s,
d,
q_chunk_size,
k_chunk_size,
prefill_chunk_size,
page_block_size,
q_dtype,
k_dtype,
use_high_precision_compute,
)

# Check number of program cache entries
assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format(
device.num_program_cache_entries()
)


def run_test_chunked_sdpa(
device,
b,
@@ -366,9 +313,10 @@ def run_test_chunked_sdpa(
q_dtype,
k_dtype,
use_high_precision_compute,
grid_size=None,
):
program_config = ttnn.SDPAProgramConfig(
compute_with_storage_grid_size=device.compute_with_storage_grid_size(),
compute_with_storage_grid_size=grid_size or device.compute_with_storage_grid_size(),
q_chunk_size=q_chunk_size,
k_chunk_size=k_chunk_size,
exp_approx_mode=False,
@@ -464,3 +412,115 @@ def unpage_cache(cache):
out_pass, out_pcc = comp_pcc(gt_chunk, tt_back, 0.998)
logger.debug(f"python vs pytorch: {out_pcc}")
assert out_pass


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b])
@pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"])
@pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"])
@pytest.mark.parametrize("prefill_chunk_size", [1024, 2048])
@pytest.mark.parametrize("page_block_size", [64, 128])
@pytest.mark.parametrize(
"b, nh, nkv, s, d",
[
[1, 8, 1, 16 * 1024, 128],
], # Llama2-70B
)
def test_sdpa_chunked(
device,
b,
nh,
nkv,
s,
d,
q_chunk_size,
k_chunk_size,
prefill_chunk_size,
page_block_size,
q_dtype,
k_dtype,
use_program_cache,
use_high_precision_compute=False,
):
for _ in range(2):
run_test_chunked_sdpa(
device,
b,
nh,
nkv,
s,
d,
q_chunk_size,
k_chunk_size,
prefill_chunk_size,
page_block_size,
q_dtype,
k_dtype,
use_high_precision_compute,
)

# Check number of program cache entries
assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format(
device.num_program_cache_entries()
)


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b])
@pytest.mark.parametrize("q_chunk_size", [128])
@pytest.mark.parametrize("k_chunk_size", [128])
@pytest.mark.parametrize("prefill_chunk_size", [1024])
@pytest.mark.parametrize("page_block_size", [64])
@pytest.mark.parametrize(
"b, nh, nkv, s, d",
[
[2, 1, 1, 4096, 128],
], # Llama2-70B
)
def test_sdpa_chunked_iterate_batch(
device,
b,
nh,
nkv,
s,
d,
q_chunk_size,
k_chunk_size,
prefill_chunk_size,
page_block_size,
q_dtype,
k_dtype,
use_program_cache,
use_high_precision_compute=False,
):
"""
This tests chunked prefill where a single core has more than one user to process.
"""
for _ in range(2):
run_test_chunked_sdpa(
device,
b,
nh,
nkv,
s,
d,
q_chunk_size,
k_chunk_size,
prefill_chunk_size,
page_block_size,
q_dtype,
k_dtype,
use_high_precision_compute,
grid_size=(1, 1),
)

# Check number of program cache entries
assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format(
device.num_program_cache_entries()
)
@@ -71,6 +71,7 @@ void kernel_main() {
constexpr uint32_t cb_k_in = tt::CBIndex::c_1;
constexpr uint32_t cb_v_in = tt::CBIndex::c_2;
constexpr uint32_t cb_mask_in = tt::CBIndex::c_3;
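// Page table CB index is declared at kernel scope (rather than inside the batch loop) so the pop at the end of each batch iteration can reference it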
constexpr uint32_t cb_id_page_table = tt::CBIndex::c_6;

constexpr uint32_t onetile = 1;
constexpr uint32_t q_tile_bytes = get_tile_size(cb_q_in);
@@ -109,7 +110,6 @@ void kernel_main() {
for (uint32_t nb = local_batch_start; nb < local_batch_end; ++nb) {
if constexpr (is_chunked) {
// Chunked means that we have paged attention
constexpr uint32_t cb_id_page_table = tt::CBIndex::c_6;
const InterleavedAddrGen<page_table_is_dram> page_table_gen = {
.bank_base_address = page_table_addr, .page_size = page_table_stick_size};
cb_reserve_back(cb_id_page_table, 1);
@@ -306,5 +306,9 @@ void kernel_main() {
}
}
}

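// Pop the page table stick pushed at the top of this batch iteration so the CB does not fill up across batches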
if constexpr (is_chunked) {
cb_pop_front(cb_id_page_table, 1);
}
}
}
