#9510: fixed CI error and cleaned up code
caixunshiren committed Jun 28, 2024
1 parent 913ed37 commit d1f2f8f
Showing 7 changed files with 31 additions and 30 deletions.
@@ -48,7 +48,6 @@ def num_to_corerange(x):


def get_chunk_size(s):
# Got to test this!
if s <= 32:
return 32
if s <= 64:
@@ -169,6 +168,12 @@ def run_test_sdpa_decode(
sharded_in=False,
sharded_out=False,
):
if (
device.compute_with_storage_grid_size().x < grid_size[0]
or device.compute_with_storage_grid_size().y < grid_size[1]
):
pytest.skip("Grid size too large for device")

padded_num_heads = nearest_pow_2(nearest_n(nh, n=32))
torch.manual_seed(1234)

@@ -229,8 +234,6 @@ def run_test_sdpa_decode(
attn_mask[:, :, :, start_idx:] = torch.finfo(torch.float32).min

Q = torch.randn(1, b, padded_num_heads, d)
# Q = torch.eye(padded_num_heads, d).expand(1, b, padded_num_heads, d)
# Q = torch.ones(1, b, padded_num_heads, d) * 1

tt_Q = ttnn.as_tensor(
Q,
@@ -239,17 +242,11 @@
layout=ttnn.TILE_LAYOUT,
memory_config=height_sharded_memcfg if sharded_in else dram_memcfg,
)
# print(f"Q memcfg: {tt_Q.memory_config()}")

tt_attn_mask = ttnn.as_tensor(
attn_mask, device=device, dtype=mask_dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg
)

# logger.info(f"Q shape: {Q.shape}")
# logger.info(f"K shape: {K.shape}")
# logger.info(f"V shape: {V.shape}")
# logger.info(f"attn_mask shape: {attn_mask.shape}")

tt_back = tt_lib.operations.primary.transformers.scaled_dot_product_attention_decode(
tt_Q,
tt_K,
@@ -299,6 +296,12 @@ def run_test_sdpa_decode_single_iter(
sharded_in=False,
sharded_out=False,
):
if (
device.compute_with_storage_grid_size().x < grid_size[0]
or device.compute_with_storage_grid_size().y < grid_size[1]
):
pytest.skip("Grid size too large for device")

padded_num_heads = nearest_pow_2(nearest_n(nh, n=32))
torch.manual_seed(1234)

@@ -335,7 +335,7 @@ void copy_block(uint32_t in_cb, uint32_t out_cb, uint32_t num_tiles) {
cb_pop_front(in_cb, num_tiles);
}

void matmul_blocks(const uint32_t& in0_cb, const uint32_t& in1_cb, const uint32_t& out_cb, const uint32_t& M, const uint32_t& N, const uint32_t& K, const uint32_t& num_blocks, const uint32_t& in0_num_subblocks, const uint32_t& in1_num_subblocks,
void cb_matmul_blocks(const uint32_t& in0_cb, const uint32_t& in1_cb, const uint32_t& out_cb, const uint32_t& M, const uint32_t& N, const uint32_t& K, const uint32_t& num_blocks, const uint32_t& in0_num_subblocks, const uint32_t& in1_num_subblocks,
const uint32_t& in0_block_w, const uint32_t& subblock_h, const uint32_t& subblock_w, const bool& transpose) {
// precondition: in0_cb has M*K produced
// precondition: in1_cb has K*N produced
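
Note: the cb_ prefix makes explicit that the helper operates on circular buffers, as the precondition comments above describe. As a rough sketch of the consume/produce contract such CB helpers follow (illustrative only; this is not the body of cb_matmul_blocks, and the compute step is elided):

void cb_consume_produce_sketch(uint32_t in_cb, uint32_t out_cb, uint32_t num_tiles) {
    cb_wait_front(in_cb, num_tiles);     // precondition: producer has pushed num_tiles into in_cb
    cb_reserve_back(out_cb, num_tiles);  // reserve space for the result tiles
    // ... compute on the tiles at the front of in_cb and pack results into out_cb ...
    cb_push_back(out_cb, num_tiles);     // postcondition: out_cb has num_tiles produced
    cb_pop_front(in_cb, num_tiles);      // release the consumed input tiles
}
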
@@ -455,7 +455,7 @@ void MAIN {
/* QK = Q_CHUNK @ K_CHUNK */
unpack_reconfig_data_format(cb_q_in, cb_k_in); // DEBUG
pack_reconfig_data_format(cb_qk_im);
matmul_blocks(cb_q_in, cb_k_in, cb_qk_im, Sq_chunk_t, Sk_chunk_t, DHt, qk_num_blocks, qk_in0_num_subblocks, qk_in1_num_subblocks, qk_in0_block_w, qk_subblock_h, qk_subblock_w, true /*transpose*/);
cb_matmul_blocks(cb_q_in, cb_k_in, cb_qk_im, Sq_chunk_t, Sk_chunk_t, DHt, qk_num_blocks, qk_in0_num_subblocks, qk_in1_num_subblocks, qk_in0_block_w, qk_subblock_h, qk_subblock_w, true /*transpose*/);

// DPRINT << "[C] D QK 1"<< ENDL();

@@ -498,7 +498,7 @@ void MAIN {
/* OUT_IM = QK @ V_CHUNK */
unpack_reconfig_data_format(cb_qk_im, cb_v_in); // DEBUG
pack_reconfig_data_format(cb_out_im);
matmul_blocks(cb_qk_im, cb_v_in, cb_out_im, Sq_chunk_t, DHt, Sk_chunk_t, out_num_blocks, out_in0_num_subblocks, out_in1_num_subblocks, out_in0_block_w, out_subblock_h, out_subblock_w, false /*transpose*/);
cb_matmul_blocks(cb_qk_im, cb_v_in, cb_out_im, Sq_chunk_t, DHt, Sk_chunk_t, out_num_blocks, out_in0_num_subblocks, out_in1_num_subblocks, out_in0_block_w, out_subblock_h, out_subblock_w, false /*transpose*/);
unpack_reconfig_data_format_srca(cb_out_im);
cb_pop_front(cb_qk_im, qk_chunk_tiles);

@@ -65,12 +65,12 @@ void kernel_main() {
uint32_t barrier_count = 0;

// First, read Q entirely, it could be interleaved or sharded
const uint32_t q_batch_offset = cur_batch * q_chunk_tiles;
const uint32_t q_chunk_tiles_bytes = q_chunk_tiles * q_tile_bytes;
constexpr uint32_t q_batch_offset = cur_batch * q_chunk_tiles;
constexpr uint32_t q_chunk_tiles_bytes = q_chunk_tiles * q_tile_bytes;

if (is_q_sharded){
if constexpr(is_q_sharded){
uint64_t q_read_addr;
if (is_worker){
if constexpr(is_worker){
q_read_addr = get_noc_addr(reduce_core_noc_x, reduce_core_noc_y, q_addr);
} else {
q_read_addr = get_noc_addr(q_addr);
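
Note: is_q_sharded, is_worker, cur_batch, and q_chunk_tiles all arrive as compile-time arguments here, so values derived from them can be constexpr and the branches can use if constexpr; the untaken branch is discarded at compile time instead of being checked at run time. A minimal kernel-side sketch of the pattern (the arg indices and helper functions below are assumptions for illustration, not the actual reader kernel's code):

// Illustrative only: arg indices and helpers are hypothetical stand-ins.
inline void read_sharded_q(uint32_t bytes) { /* hypothetical sharded-Q NOC read */ }
inline void read_interleaved_q(uint32_t bytes) { /* hypothetical interleaved-Q NOC read */ }

void kernel_main() {
    constexpr bool is_q_sharded = get_compile_time_arg_val(0) == 1;
    constexpr uint32_t q_chunk_tiles = get_compile_time_arg_val(1);
    constexpr uint32_t q_tile_bytes = get_compile_time_arg_val(2);

    // Values derived from compile-time args can themselves be constexpr.
    constexpr uint32_t q_chunk_tiles_bytes = q_chunk_tiles * q_tile_bytes;

    if constexpr (is_q_sharded) {
        // Only this branch is compiled into the sharded-Q build; the other branch
        // is dropped before code generation, so there is no runtime check.
        read_sharded_q(q_chunk_tiles_bytes);
    } else {
        read_interleaved_q(q_chunk_tiles_bytes);
    }
}

The writer kernel's early exit in the next file relies on the same mechanism: when k_chunk_start == k_chunk_end is known at compile time, everything after the return is eliminated for that core.
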
@@ -31,7 +31,7 @@ void kernel_main() {
constexpr uint32_t k_chunk_start = get_compile_time_arg_val(14);
constexpr uint32_t k_chunk_end = get_compile_time_arg_val(15);

if (k_chunk_start == k_chunk_end) {
if constexpr(k_chunk_start == k_chunk_end) {
// DPRINT << "[Writer Worker] No computes to be done for this worker" << ENDL();
return; // early exit because no compute needs to be done for this worker
}
@@ -95,8 +95,6 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(

}, compute_kernel_config);

// TT_FATAL(!fp32_dest_acc_en, "fp32_dest_acc_en not supported yet");

auto q_buffer = input_tensor_q.buffer();
auto k_buffer = input_tensor_k.buffer();
auto v_buffer = input_tensor_v.buffer();
@@ -495,7 +493,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
reader_compile_time_args.insert(reader_compile_time_args.end(), {cur_batch, k_chunk_start, k_chunk_end, is_q_sharded, !do_reduce, reduce_core_physical.x, reduce_core_physical.y});
auto reader_kernels_id = CreateKernel(
program,
"tt_eager/tt_dnn/op_library/sdpa/kernels/dataflow/reader_decode_interleaved.cpp",
"tt_eager/tt_dnn/op_library/sdpa/kernels/dataflow/reader_decode_all.cpp",
core,
tt_metal::ReaderDataMovementConfig(
reader_compile_time_args,
@@ -509,7 +507,7 @@
writer_compile_time_args.insert(writer_compile_time_args.end(), {in0_mcast_reducer_semaphore, cur_batch, num_chunks, k_chunk_start, k_chunk_end, is_output_sharded});
writer_kernels_id = CreateKernel(
program,
"tt_eager/tt_dnn/op_library/sdpa/kernels/dataflow/writer_decode_reducer_interleaved.cpp",
"tt_eager/tt_dnn/op_library/sdpa/kernels/dataflow/writer_decode_reducer.cpp",
core,
tt_metal::WriterDataMovementConfig(
writer_compile_time_args,
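
Note: the kernel path changes pair with the constexpr cleanup above. Instead of separate interleaved and sharded kernel variants, a single reader_decode_all.cpp (and writer_decode_reducer.cpp) is compiled per core, and the per-core flags travel as compile-time args so the if constexpr branches pick the right path when the kernel is built. A simplified host-side sketch (illustrative; the real compile-time arg list and the config constructors take more parameters than shown here):

// Simplified sketch: per-core flags become compile-time args for the unified reader kernel.
std::vector<uint32_t> reader_ct_args = {cur_batch, k_chunk_start, k_chunk_end, is_q_sharded};
auto reader_kernels_id = CreateKernel(
    program,
    "tt_eager/tt_dnn/op_library/sdpa/kernels/dataflow/reader_decode_all.cpp",
    core,
    tt_metal::ReaderDataMovementConfig(reader_ct_args));
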
18 changes: 9 additions & 9 deletions tt_eager/tt_dnn/op_library/sdpa/sdpa_op.cpp
@@ -34,17 +34,17 @@ void ScaledDotProductAttention::validate(


for (auto& input_tensor : input_tensors) {
TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to softmax need to be on device!");
TT_FATAL(input_tensor.buffer() != nullptr, "Operands to softmax need to be allocated in buffers on device!");
TT_FATAL((input_tensor.get_layout() == Layout::TILE), "Inputs to softmax must be tilized");
TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to SDPA need to be on device!");
TT_FATAL(input_tensor.buffer() != nullptr, "Operands to SDPA need to be allocated in buffers on device!");
TT_FATAL((input_tensor.get_layout() == Layout::TILE), "Inputs to SDPA must be tilized");
TT_FATAL(
input_tensor.get_dtype() == DataType::BFLOAT16 ||
input_tensor.get_dtype() == DataType::BFLOAT8_B);

}

auto mask = optional_input_tensors.at(0).value();
TT_FATAL(mask.storage_type() == StorageType::DEVICE, "Operands to softmax need to be on device!");
TT_FATAL(mask.storage_type() == StorageType::DEVICE, "Operands to SDPA need to be on device!");
TT_FATAL(input_tensors.at(0).device() == mask.device());
TT_FATAL(mask.get_layout() == Layout::TILE);
TT_FATAL(mask.get_dtype() == DataType::BFLOAT16 || mask.get_dtype() == DataType::BFLOAT8_B);
@@ -241,17 +241,17 @@ void ScaledDotProductAttentionDecode::validate(


for (auto& input_tensor : input_tensors) {
TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to softmax need to be on device!");
TT_FATAL(input_tensor.buffer() != nullptr, "Operands to softmax need to be allocated in buffers on device!");
TT_FATAL((input_tensor.get_layout() == Layout::TILE), "Inputs to softmax must be tilized");
TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to SDPA need to be on device!");
TT_FATAL(input_tensor.buffer() != nullptr, "Operands to SDPA need to be allocated in buffers on device!");
TT_FATAL((input_tensor.get_layout() == Layout::TILE), "Inputs to SDPA must be tilized");
TT_FATAL(
input_tensor.get_dtype() == DataType::BFLOAT16 ||
input_tensor.get_dtype() == DataType::BFLOAT8_B || input_tensor.get_dtype() == DataType::BFLOAT4_B);

}

auto mask = optional_input_tensors.at(0).value();
TT_FATAL(mask.storage_type() == StorageType::DEVICE, "Operands to softmax need to be on device!");
TT_FATAL(mask.storage_type() == StorageType::DEVICE, "Operands to SDPA need to be on device!");
TT_FATAL(input_tensors.at(0).device() == mask.device());
TT_FATAL(mask.get_layout() == Layout::TILE);
TT_FATAL(mask.get_dtype() == DataType::BFLOAT16 || mask.get_dtype() == DataType::BFLOAT8_B || mask.get_dtype() == DataType::BFLOAT4_B);
@@ -269,7 +269,7 @@
TT_FATAL(Q_memcfg.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED);
}
else{
TT_FATAL(input_tensors.at(0).buffer()->buffer_type() == tt_metal::BufferType::DRAM);
TT_FATAL(Q_memcfg.buffer_type == tt_metal::BufferType::DRAM);
}

for (std::size_t i = 1; i < input_tensors.size(); i++) {
