#0: Implement op hashing now that a struct member will be used as a runtime arg. Refactor such that chunk offset is passed in terms of Q chunks.
cglagovichTT committed Dec 11, 2024
1 parent d018e2a commit aca1b57
Showing 7 changed files with 61 additions and 99 deletions.
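For context before the diff, here is a minimal standalone C++ sketch (hypothetical names, not the actual ttnn API) of the refactor described in the commit message: the host converts the chunk start position of a chunked prefill into an offset measured in Q chunks and passes it as a runtime argument, so the kernels only shift their q_chunk index instead of adjusting q_low_idx/q_high_idx by a tile offset.

// Sketch only: assumes the chunk start is a multiple of q_chunk_size.
#include <cassert>
#include <cstdint>
#include <optional>

struct ChunkedPrefillParams {
    std::optional<uint32_t> chunk_start_idx;  // sequence position where this prefill chunk begins
    uint32_t q_chunk_size;                    // Q rows handled per kernel chunk
};

// Host side: the runtime arg is expressed in Q chunks, not sequence tiles.
inline uint32_t chunked_q_chunk_offset(const ChunkedPrefillParams& p) {
    const uint32_t start = p.chunk_start_idx.value_or(0);
    assert(start % p.q_chunk_size == 0 && "chunk start must align to q_chunk_size");
    return start / p.q_chunk_size;
}

// Kernel side: folding the offset into q_chunk lets q_low_idx, q_high_idx and the
// causal bound on k_chunk fall out of the usual per-chunk arithmetic unchanged.
inline uint32_t effective_q_chunk(uint32_t q_chunk, uint32_t offset, bool is_chunked) {
    return is_chunked ? q_chunk + offset : q_chunk;
}

With this convention, the per-index adjustment the compute and reader kernels previously did (q_low_idx = chunk_start_t + q_low_idx, and likewise for q_high_idx) collapses to a single shift of q_chunk, as shown in the kernel hunks below.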
@@ -73,7 +73,7 @@ def run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype
tt_back = ttnn.transformer.scaled_dot_product_attention(
tt_Q, tt_K, tt_V, is_causal=True, program_config=program_config, compute_kernel_config=compute_kernel_config
)
tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
tt_back = ttnn.to_torch(tt_back)

K_repeated = torch.cat([K[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1) # b, nh, d, S
V_repeated = torch.cat([V[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1) # b, nh, d, S
@@ -238,7 +238,7 @@ def run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dt
program_config=program_config,
compute_kernel_config=compute_kernel_config,
)
tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
tt_back = ttnn.to_torch(tt_back)

if nkv > 1 and nkv != nh:
assert nh % nkv == 0
@@ -305,19 +305,13 @@ def test_sdpa_noncausal_unequal_seqlen(device, b, nh, nkv, sq, sk, d, q_chunk_si
# @pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"])
@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b])
# @pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"])
# @pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"])
@pytest.mark.parametrize("q_chunk_size", [64, 128])
@pytest.mark.parametrize("k_chunk_size", [64, 128])
@pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"])
@pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"])
@pytest.mark.parametrize("prefill_chunk_size", [1024, 2048])
@pytest.mark.parametrize("page_block_size", [64, 128])
@pytest.mark.parametrize(
"b, nh, nkv, s, d",
(
[1, 1, 1, 16 * 1024, 32], # Llama2-70B
# [1, 16, 1, 2048, 64], # Falcon-40B
# [1, 71, 1, 2048, 64], # Falcon-7B
),
([1, 8, 1, 16 * 1024, 128],), # Llama2-70B
)
def test_sdpa_chunked(
device,
@@ -332,8 +326,8 @@ def test_sdpa_chunked(
page_block_size,
q_dtype,
k_dtype,
use_program_cache,
use_high_precision_compute=False,
use_program_cache=False,
):
for _ in range(2):
run_test_chunked_sdpa(
@@ -352,6 +346,11 @@ def test_sdpa_chunked(
use_high_precision_compute,
)

# Check that both iterations reused a single program cache entry
assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format(
device.num_program_cache_entries()
)


def run_test_chunked_sdpa(
device,
@@ -369,7 +368,7 @@ def run_test_chunked_sdpa(
use_high_precision_compute,
):
program_config = ttnn.SDPAProgramConfig(
compute_with_storage_grid_size=(1, 1),
compute_with_storage_grid_size=device.compute_with_storage_grid_size(),
q_chunk_size=q_chunk_size,
k_chunk_size=k_chunk_size,
exp_approx_mode=False,
@@ -438,24 +437,13 @@ def unpage_cache(cache):
return paged_cache_back

# Check that we can convert from normal to paged to normal
assert torch.allclose(unpage_cache(page_cache(K)), K)
assert torch.allclose(unpage_cache(page_cache(V)), V)
assert torch.allclose(unpage_cache(page_cache(K)), K), "K is not equal to unpage_cache(page_cache(K))"
assert torch.allclose(unpage_cache(page_cache(V)), V), "V is not equal to unpage_cache(page_cache(V))"

tt_paged_K = ttnn.Tensor(page_cache(K), k_dtype).to(ttnn.TILE_LAYOUT).to(device)
tt_paged_V = ttnn.Tensor(page_cache(V), k_dtype).to(ttnn.TILE_LAYOUT).to(device)
page_table_tt = ttnn.Tensor(page_table, ttnn.int32).to(device)

# tt_K = ttnn.Tensor(K, k_dtype).to(ttnn.TILE_LAYOUT).to(device)
# tt_V = ttnn.Tensor(V, k_dtype).to(ttnn.TILE_LAYOUT).to(device)

# # TODO: Chunk Q
# tt_Q = ttnn.Tensor(Q, q_dtype).to(ttnn.TILE_LAYOUT).to(device)

# tt_back = ttnn.transformer.scaled_dot_product_attention(
# tt_Q, tt_K, tt_V, is_causal=True, program_config=program_config, compute_kernel_config=compute_kernel_config
# )
# tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()

for chunk_idx in range(num_prefill_chunks):
# Chunk Q
Q_chunk = Q[:, :, chunk_idx * prefill_chunk_size : (chunk_idx + 1) * prefill_chunk_size]
@@ -472,13 +460,7 @@ def unpage_cache(cache):
compute_kernel_config=compute_kernel_config,
)
tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()

# tt_chunk = tt_back[:, :, chunk_idx * prefill_chunk_size : (chunk_idx + 1) * prefill_chunk_size]
gt_chunk = gt[:, :, chunk_idx * prefill_chunk_size : (chunk_idx + 1) * prefill_chunk_size]
out_pass, out_pcc = comp_pcc(gt_chunk, tt_back, 0.994)
out_pass, out_pcc = comp_pcc(gt_chunk, tt_back, 0.998)
logger.debug(f"python vs pytorch: {out_pcc}")
assert out_pass

# out_pass, out_pcc = comp_pcc(gt, tt_back, 0.994)
# logger.debug(f"python vs pytorch: {out_pcc}")
# assert out_pass
@@ -15,7 +15,6 @@
#include "compute_kernel_api/tile_move_copy.h"
#include "compute_kernel_api/matmul.h"
#include "compute_kernel_api/reduce.h"
#include "debug/dprint.h"

namespace NAMESPACE {
template <uint32_t in0, uint32_t in1, uint32_t num_tiles>
@@ -370,7 +369,7 @@ void MAIN {
const uint32_t local_nh_end = get_arg_val<uint32_t>(4);
const uint32_t local_q_start = get_arg_val<uint32_t>(5);
const uint32_t local_q_end = get_arg_val<uint32_t>(6);
const uint32_t chunk_start_t = get_arg_val<uint32_t>(7);
const uint32_t chunked_q_chunk_offset = get_arg_val<uint32_t>(7);

const uint32_t q_chunks_per_core = local_q_end - local_q_start;

@@ -416,6 +415,9 @@
#endif

// Get Q chunk
if constexpr (is_chunked) {
q_chunk = chunked_q_chunk_offset + q_chunk;
}
uint32_t q_low_idx =
q_chunk * Sq_chunk_t; // This is the sequence index of the first tile of this chunk
uint32_t q_high_idx;
@@ -424,19 +426,12 @@
} else {
q_high_idx = Skt;
}
if constexpr (is_chunked) {
q_low_idx = chunk_start_t + q_low_idx;
q_high_idx = chunk_start_t + q_high_idx;
}
// UNPACK(DPRINT << "UNPACK: q_low_idx: " << q_low_idx << ENDL();)
// UNPACK(DPRINT << "UNPACK: q_high_idx: " << q_high_idx << ENDL();)
cb_wait_front(cb_q_in, q_chunk_tiles);

// loop while k_low < q_high
for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) {
const uint32_t k_low_idx = k_chunk * Sk_chunk_t;
const uint32_t k_high_idx = k_low_idx + Sk_chunk_t;
// UNPACK(DPRINT << "UNPACK: k_chunk: " << k_chunk << ENDL();)

/* QK = Q_CHUNK @ K_CHUNK */
pack_reconfig_data_format(cb_qk_im);
@@ -455,8 +450,6 @@
qk_subblock_w,
true /*transpose*/);

// UNPACK(DPRINT << "UNPACK: done with qk_im" << ENDL();)

/* QK *= SCALE */
mul_block_bcast_scalar_inplace<cb_qk_im, cb_scale_in, qk_chunk_tiles>();

@@ -467,7 +460,6 @@
// Due to loop bounds, we should never have k_low >= q_high. Can simplify this conditional check
if constexpr (is_causal) {
if (!(q_low_idx >= k_high_idx)) {
// UNPACK(DPRINT << "UNPACK: adding mask" << ENDL();)
/* QK += MASK */
reconfig_data_format(cb_qk_im, cb_mask_in);
add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles);
@@ -479,7 +471,6 @@
add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles);
}
}
// UNPACK(DPRINT << "UNPACK: done with mask" << ENDL();)

reconfig_data_format(cb_qk_im, cb_identity_scale_in);
reduce_c<
@@ -524,7 +515,7 @@
out_subblock_h,
out_subblock_w,
false /*transpose*/);
// UNPACK(DPRINT << "UNPACK: done with out_im" << ENDL();)

reconfig_data_format_srca(cb_out_im);
cb_pop_front(cb_qk_im, qk_chunk_tiles);

@@ -570,8 +561,5 @@
}
}
}
// UNPACK(DPRINT << "UNPACK: done" << ENDL();)
// MATH(DPRINT << "MATH: done" << ENDL();)
// PACK(DPRINT << "PACK: done" << ENDL();)
}
} // namespace NAMESPACE
@@ -4,7 +4,6 @@

#include <stdint.h>
#include "dataflow_api.h"
#include "debug/dprint.h"

template <uint32_t tile_bytes, uint32_t num_readers>
constexpr uint32_t get_barrier_read_threshold() {
@@ -58,7 +57,7 @@ void kernel_main() {
const uint32_t local_nh_end = get_arg_val<uint32_t>(argidx++);
const uint32_t local_q_start = get_arg_val<uint32_t>(argidx++);
const uint32_t local_q_end = get_arg_val<uint32_t>(argidx++);
const uint32_t chunk_start_t = get_arg_val<uint32_t>(argidx++);
const uint32_t chunked_q_chunk_offset = get_arg_val<uint32_t>(argidx++);

const uint32_t q_chunks_per_core = local_q_end - local_q_start;

@@ -126,7 +125,6 @@ void kernel_main() {
const uint32_t mask_batch_offset = nb * Sqt * Skt;
for (uint32_t nq = local_nh_start; nq < local_nh_end; ++nq) {
for (uint32_t q_iter = 0; q_iter < q_chunks_per_core; ++q_iter) {
// DPRINT << "q_iter: " << q_iter << ENDL();
uint32_t q_chunk;
#if defined BALANCED_Q_PARALLEL
uint32_t q_chunk_div_2 = q_chunks_per_core / 2;
@@ -163,8 +161,9 @@

cb_push_back(cb_q_in, q_chunk_tiles);

// DPRINT << "q_chunk: " << q_chunk << "\n";

if constexpr (is_chunked) {
q_chunk = chunked_q_chunk_offset + q_chunk;
}
uint32_t q_low_idx =
q_chunk * Sq_chunk_t; // This is the sequence index of the first tile of this chunk
uint32_t q_high_idx;
@@ -173,14 +172,6 @@
} else {
q_high_idx = Skt;
}
if constexpr (is_chunked) {
// Add the chunk offset to the low and high indices
q_low_idx = chunk_start_t + q_low_idx;
q_high_idx = chunk_start_t + q_high_idx;
}

// DPRINT << "q_low_idx: " << q_low_idx << ENDL();
// DPRINT << "q_high_idx: " << q_high_idx << ENDL();

const uint32_t kv_head = nq / q_heads_per_kv;
const uint32_t kv_head_offset = kv_head * Skt * DHt;
@@ -190,7 +181,6 @@
const uint32_t k_low_idx = k_chunk * Sk_chunk_t;
const uint32_t k_high_idx = k_low_idx + Sk_chunk_t;
const uint32_t k_start_tile_id = kv_batch_offset + kv_head_offset + k_chunk * Sk_chunk_t * DHt;
// DPRINT << "READER: k_chunk: " << k_chunk << ENDL();

if constexpr (is_chunked) {
// Use page table to read K chunk
@@ -240,7 +230,6 @@
cb_push_back(cb_k_in, k_chunk_tiles);
}

// DPRINT << "READER: done with k_chunk: " << k_chunk << ENDL();
if constexpr (use_provided_mask) {
// Finding the diagonal is harder now that q_chunk_size and k_chunk_size can differ
// Q-range = [q_low, q_high)
@@ -314,11 +303,8 @@
noc_async_read_barrier();
cb_push_back(cb_v_in, k_chunk_tiles);
}

// DPRINT << "READER: done with v_chunk: " << k_chunk << ENDL();
}
}
}
}
// DPRINT << "READER: done" << ENDL();
}
@@ -50,7 +50,6 @@ void fill_diagonal_tile(uint32_t cb_id, uint32_t tile_id, uint32_t partial_val)

fill_tile<tile_bytes>(cb_id, tile_id, 0);

// DPRINT << "Fill partial tile" << ENDL();
const uint16_t datum_val = partial_val >> 16;
volatile tt_l1_ptr uint16_t* uint16_ptr =
reinterpret_cast<volatile tt_l1_ptr uint16_t*>(get_write_ptr(cb_id) + tile_id * tile_bytes);
@@ -215,12 +214,6 @@ void kernel_main() {
q_chunk * Sq_chunk_t; // This is the sequence index of the first tile of this chunk
uint32_t q_high_idx = q_low_idx + Sq_chunk_t;

// if constexpr (is_chunked) {
// q_low_idx = chunk_start_t + q_low_idx;
// q_high_idx = chunk_start_t + q_high_idx;
// q_chunk = chunk_start_t_in_q_chunks + q_chunk;
// }

for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) {
const uint32_t k_low_idx = k_chunk * Sk_chunk_t;
const uint32_t k_high_idx = k_low_idx + Sk_chunk_t;
@@ -233,7 +226,6 @@
if (!(q_low_idx >= k_high_idx)) {
generate_mask<cb_mask_in>(Sq_chunk_t, Sk_chunk_t, q_chunk, k_chunk);
}
// DPRINT << "WRITER: done with k_chunk: " << k_chunk << ENDL();
}
}

15 changes: 15 additions & 0 deletions ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp
@@ -259,4 +259,19 @@ operation::ProgramWithCallbacks ScaledDotProductAttention::create_program(
this->program_config);
}

operation::Hash ScaledDotProductAttention::compute_program_hash(
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors) const {
bool is_chunked_prefill = this->chunk_start_idx.has_value();
return operation::hash_operation<ScaledDotProductAttention>(
this->scale,
this->output_mem_config,
this->program_config,
this->is_causal,
is_chunked_prefill,
this->compute_kernel_config,
input_tensors,
optional_input_tensors);
}

} // namespace ttnn::operations::transformer
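As a rough illustration (a hypothetical helper, not the real operation::hash_operation), the point of hashing on the presence of chunk_start_idx rather than on its value is that the chunk start is now a runtime argument: only whether chunked prefill is enabled changes the compiled program, so every chunk of a prefill maps to the same cache entry, which is what the num_program_cache_entries() == 1 assertion in the test above verifies.

// Sketch only: combine the fields that actually affect program compilation.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>

struct SdpaHashFields {
    std::optional<uint32_t> chunk_start_idx;  // runtime arg; its value must not affect the hash
    bool is_causal;
};

inline std::size_t hash_combine(std::size_t seed, std::size_t v) {
    return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

inline std::size_t sdpa_program_hash(const SdpaHashFields& op) {
    // Calls that differ only in their chunk start hash identically and reuse one program.
    std::size_t h = std::hash<bool>{}(op.chunk_start_idx.has_value());
    return hash_combine(h, std::hash<bool>{}(op.is_causal));
}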
4 changes: 4 additions & 0 deletions ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp
@@ -31,6 +31,10 @@ struct ScaledDotProductAttention {
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
std::vector<Tensor>& output_tensors) const;

operation::Hash compute_program_hash(
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
};

} // namespace ttnn::operations::transformer