Skip to content

Commit

Permalink
#4686: add fp32 to untilize
Browse files Browse the repository at this point in the history
  • Loading branch information
yugaoTT committed Feb 21, 2024
1 parent d6eb4a8 commit ca2b91d
Show file tree
Hide file tree
Showing 8 changed files with 327 additions and 426 deletions.
639 changes: 293 additions & 346 deletions tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/pack_untilize.h"

#include "compute_kernel_api/bcast.h"
#include "debug/dprint.h"

using std::uint32_t;

Expand Down Expand Up @@ -58,66 +56,38 @@ void MAIN {
constexpr uint32_t in0_num_blocks_w = 1; // TODO: Generalize


mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw);
// mm_block_init(cb_in0, cb_in1, cb_intermed0, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w );

// UNPACK(( DPRINT << num_rows_in_one_tile << ENDL() ));
// UNPACK(( DPRINT << in1_num_blocks << ENDL() ));
// UNPACK(( DPRINT << in1_num_blocks << ENDL() ));
// UNPACK(( DPRINT << in0_block_num_tiles << ENDL() ));

// mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw);
mm_block_init(cb_in0, cb_in1, cb_intermed0, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w );

for (uint32_t b = 0; b < batch; b++) {

for (uint32_t m = 0; m < Mt; m++) { // TODO: Must be 1; generalize to support batch > 32 (ie. Mt > 1)
for (uint32_t in0_block = 0; in0_block < in0_num_blocks_w; in0_block++) { // TODO: Must be 1; generalize to support inner dim blocking
cb_wait_front(cb_in0, in0_block_num_tiles);

// UNPACK(( DPRINT << TSLICE(cb_in0, 0, SliceRange::h0_32_w31()) << ENDL() ));

for (uint32_t in1_block = 0; in1_block < in1_num_blocks; in1_block++) {
uint32_t in0_index_subblock_offset = 0;
for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; tile_row_id++) {
cb_wait_front(cb_in1, in1_block_num_tiles);
cb_pop_front(cb_in1, num_kv_heads_skip);


// UNPACK(( DPRINT << TSLICE(cb_in1, 0, SliceRange::h0_32_w31()) << ENDL() ));

for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) { // TODO: Must be 1; need to review inner dim blocking and the untilizing
uint32_t in1_index_subblock_offset = 0;

tile_regs_acquire();


// Compute output sub-block
// uint32_t dst_index = 0; // start at 0, each call to matmul_block internally increments dst_index
// uint32_t in0_index = in0_index_subblock_offset; // offset into in0 block
// uint32_t in1_index = in1_index_subblock_offset; // offset into in1 block
// // inner dim that we accumualte is the inner dim of in0/in1, which is in0_block_w
// for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) {
// // matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst
// // accumulation is done by iterating matmul_block across inner dim
// // in0_block_w is passed as innder dim (kt) to matmul_block, interally used to stride in0
// matmul_block(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w);
// in0_index ++; // stride right by 1
// in1_index += in1_per_core_w; // to stride down by 1 need to stride by in_per_core_w (should be called in1_block_w)
// }

uint32_t dst_index = 0;
uint32_t in0_index_h_offset = 0;
for (uint32_t h = 0; h < out_subblock_h; h++) {
for (uint32_t w = 0; w < out_subblock_w; w++) {
uint32_t in1_index_inner_dim_offset = 0;
for (uint32_t inner_dim = 0; inner_dim < in0_block_w; inner_dim++) {
uint32_t in0_index = in0_index_subblock_offset + in0_index_h_offset + inner_dim;
uint32_t in1_index = in1_index_subblock_offset + in1_index_inner_dim_offset + w;
matmul_tiles(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw);
in1_index_inner_dim_offset += in1_per_core_w;
}
dst_index++;
}
in0_index_h_offset += in0_block_w;
uint32_t dst_index = 0; // start at 0, each call to matmul_block internally increments dst_index
uint32_t in0_index = in0_index_subblock_offset; // offset into in0 block
uint32_t in1_index = in1_index_subblock_offset; // offset into in1 block
// inner dim that we accumualte is the inner dim of in0/in1, which is in0_block_w
for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) {
// matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst
// accumulation is done by iterating matmul_block across inner dim
// in0_block_w is passed as innder dim (kt) to matmul_block, interally used to stride in0
matmul_block(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w);
in0_index ++; // stride right by 1
in1_index += in1_per_core_w; // to stride down by 1 need to stride by in_per_core_w (should be called in1_block_w)
}

tile_regs_commit();
Expand All @@ -128,34 +98,30 @@ void MAIN {

// TODO: Review inner dim blocking, untilizing, and in1_num_subblocks > 1 (with pack_untilize, can only untilize up to dst num tiles)
// This should normally be inside subblock loop and we pack out out_subblock_num_tiles
// pack_untilize_dst_init_short<intermediate_num_tiles>();
pack_untilize_dst_init_short<intermediate_num_tiles>();
cb_reserve_back(cb_intermed0, intermediate_num_tiles);
tile_regs_wait();
// pack_untilize_dst<intermediate_num_tiles>(cb_intermed0);
pack_tile(0, cb_intermed0);
pack_untilize_dst<intermediate_num_tiles>(cb_intermed0);
pack_untilize_uninit();

// pack_untilize_uninit();
tile_regs_release();
cb_push_back(cb_intermed0, intermediate_num_tiles);


cb_wait_front(cb_intermed0, intermediate_num_tiles);
UNPACK(( DPRINT << intermediate_num_tiles << ENDL() ));
UNPACK(( DPRINT << TSLICE(cb_intermed0, 0, SliceRange::h0_32_w31()) << ENDL() ));
} // 32 tiles loop

in0_index_subblock_offset += in0_subblock_num_tiles;
} // in1_num_blocks loop
} // in0_num_blocks_w

// cb_intermed1 comes from reader; untilized row-major tile
unpack_reconfig_data_format_srca(cb_in1, cb_intermed1);
pack_reconfig_data_format(cb_intermed0, out_cb_id);
cb_wait_front(cb_intermed1, out_num_tiles);

cb_reserve_back(out_cb_id, out_num_tiles);

// tilize CB::intermed1 and write to CB::c_out0
tilize_init_short_with_dt(cb_in1, cb_intermed1, out_num_tiles);
pack_reconfig_data_format(cb_intermed0, out_cb_id);
tilize_init_short(cb_intermed1, out_num_tiles);
tilize_block(cb_intermed1, out_num_tiles, out_cb_id);
cb_push_back(out_cb_id, out_num_tiles);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

#include "dataflow_api.h"

#include "debug/dprint.h"

void kernel_main() {
uint32_t i = 0;

Expand Down Expand Up @@ -33,10 +31,6 @@ void kernel_main() {
uint32_t bfloat16_Nt_bytes = get_arg_val<uint32_t>(i++);
uint32_t bfloat16_last_row_bytes_read = get_arg_val<uint32_t>(i++);

DPRINT << "bfloat16_row_bytes " <<bfloat16_row_bytes <<ENDL();
DPRINT << "bfloat16_Nt_bytes " <<bfloat16_Nt_bytes <<ENDL();


constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
constexpr uint32_t cb_id_out = get_compile_time_arg_val(2);
Expand Down Expand Up @@ -110,8 +104,6 @@ void kernel_main() {

cb_push_back(cb_id_in0, in0_block_w);



cb_reserve_back(cb_id_intermed1, out_num_tiles);
uint32_t cb_intermed1_addr = get_write_ptr(cb_id_intermed1);
for (uint32_t in1_block = 0; in1_block < in1_num_blocks; in1_block++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &a, const Te

tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.dtype());
tt::DataFormat in1_data_format = tt_metal::datatype_to_dataformat_converter(b.dtype());
tt::DataFormat interm_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b;
tt::DataFormat interm_data_format = fp32_dest_acc_en and in0_data_format == tt::DataFormat::Float32 ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b;
tt::DataFormat output_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype());
uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format);
uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
Expand All @@ -68,6 +68,10 @@ operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &a, const Te
log_debug("math_approx_mode: {}", math_approx_mode);
log_debug("fp32_dest_acc_en: {}", fp32_dest_acc_en);
log_debug("packer_l1_acc: {}", packer_l1_acc);
log_debug("in0_data_format: {}", in0_data_format);
log_debug("in1_data_format: {}", in1_data_format);
log_debug("interm_data_format: {}", interm_data_format);
log_debug("output_data_format: {}", output_data_format);

tt_metal::Buffer *src0_buffer = a.buffer();
tt_metal::Buffer *src1_buffer = b.buffer();
Expand Down Expand Up @@ -149,7 +153,7 @@ operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &a, const Te
(uint32_t) src0_is_dram,
(uint32_t) src1_is_dram,
(uint32_t) transpose_hw_bool,
(uint32_t) fp32_dest_acc_en
(uint32_t) (fp32_dest_acc_en and in0_data_format == tt::DataFormat::Float32)
};

bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, co
tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.dtype());
tt::DataFormat in1_data_format = tt_metal::datatype_to_dataformat_converter(b.dtype());
tt::DataFormat interm_data_format = fp32_dest_acc_en and in0_data_format == tt::DataFormat::Float32 ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b;
// interm_data_format=tt::DataFormat::Float16_b;
tt::DataFormat output_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype());
uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format);
uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
Expand Down Expand Up @@ -98,7 +97,6 @@ operation::ProgramWithCallbacks multi_core_group_attn_matmul(const Tensor &a, co
const uint32_t in1_per_core_w = in1_num_subblocks * out_block_w;
const uint32_t in1_block_w_tile_bytes = out_subblock_w * in1_single_tile_size;
uint32_t ONE_ROW_BFLOAT16_BYTES = fp32_dest_acc_en and in0_data_format == tt::DataFormat::Float32 ? 128 : 64;
// ONE_ROW_BFLOAT16_BYTES = 64;
const uint32_t bfloat16_row_bytes = ONE_ROW_BFLOAT16_BYTES * out_block_w; // TODO: Generalize

log_debug("in0_block_w: {}", in0_block_w);
Expand Down
4 changes: 0 additions & 4 deletions tt_eager/tt_lib/csrc/operations/primary/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,6 @@ void py_module(py::module& m_primary) {
const MemoryConfig& out_mem_config,
std::optional<DataType> output_dtype,
std::optional<DeviceComputeKernelConfig> compute_kernel_config
// const MathFidelity math_fidelity,
// const bool fp32_dest_acc_en,
// const bool math_approx_mode,
// const bool packer_l1_acc
) {
return matmul(
input_tensor_a, input_tensor_b, bias, program_config, out_mem_config, output_dtype, compute_kernel_config);
Expand Down
2 changes: 0 additions & 2 deletions tt_metal/hw/inc/debug/dprint_tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ struct SliceRange {
static inline SliceRange hw0_32_4() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 4, .w0 = 0, .w1 = 32, .ws = 4 }; }
// [0, 0:32]
static inline SliceRange h0_w0_32() { return SliceRange{ .h0 = 0, .h1 = 1, .hs = 1, .w0 = 0, .w1 = 32, .ws = 1 }; }
static inline SliceRange h1_w0_32() { return SliceRange{ .h0 = 1, .h1 = 2, .hs = 1, .w0 = 0, .w1 = 32, .ws = 1 }; }
// [0:32, 0]
static inline SliceRange h0_32_w0() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 1, .w0 = 0, .w1 = 1, .ws = 1 }; }
static inline SliceRange h0_32_w31() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 1, .w0 = 31, .w1 = 32, .ws = 1 }; }
// [0:32:1, 1]
static inline SliceRange h0_32_w1() { return SliceRange{ .h0 = 0, .h1 = 32, .hs = 1, .w0 = 1, .w1 = 2, .ws = 1 }; }
// [0:4:1, 0:4:1]
Expand Down
18 changes: 9 additions & 9 deletions tt_metal/include/compute_kernel_api/untilize.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ namespace ckernel {
*/
ALWI void untilize_init(uint32_t icb, uint32_t ocb = 16)
{
MATH(( llk_math_eltwise_unary_datacopy_init<A2D, BroadcastType::NONE>(false /*transpose of faces*/, false /*transpose within 16x16 face*/, icb) ));
MATH(( llk_math_pack_sync_init<SyncHalf>() ));
MATH(( llk_math_eltwise_unary_datacopy_init<A2D, BroadcastType::NONE, DST_ACCUM_MODE>(false /*transpose of faces*/, false /*transpose within 16x16 face*/, icb) ));
MATH(( llk_math_pack_sync_init<SyncHalf, DST_ACCUM_MODE>() ));

PACK(( llk_pack_hw_configure_disaggregated<false>(ocb) ));
PACK(( llk_pack_hw_configure_disaggregated<false, DST_ACCUM_MODE>(ocb) ));
PACK(( llk_pack_init(ocb) ));
PACK(( llk_setup_outputs() ));
PACK(( llk_pack_dest_init<SyncHalf, DstTileFaceLayout::RowMajor, false>() ));
PACK(( llk_pack_dest_init<SyncHalf, DstTileFaceLayout::RowMajor, false, DST_ACCUM_MODE>() ));

UNPACK(( llk_setup_operands() ));
UNPACK(( llk_unpack_untilize_hw_configure_disaggregated(icb) ));
Expand All @@ -38,7 +38,7 @@ ALWI void untilize_init(uint32_t icb, uint32_t ocb = 16)
*/
ALWI void untilize_init_short(uint32_t icb)
{
MATH(( llk_math_eltwise_unary_datacopy_init<A2D, BroadcastType::NONE>(false /*transpose of faces*/, false /*transpose within 16x16 face*/, icb) ));
MATH(( llk_math_eltwise_unary_datacopy_init<A2D, BroadcastType::NONE, DST_ACCUM_MODE>(false /*transpose of faces*/, false /*transpose within 16x16 face*/, icb) ));
UNPACK(( llk_unpack_untilize_init(icb) ));
}

Expand All @@ -55,20 +55,20 @@ ALWI void untilize_block(uint32_t icb, uint32_t block, uint32_t ocb)

// Datacopy
for (int reg_id = 0; reg_id < N; reg_id++) {
MATH(( llk_math_eltwise_unary_datacopy<A2D, BroadcastType::NONE, SyncHalf>(reg_id) ));
MATH(( llk_math_eltwise_unary_datacopy<A2D, BroadcastType::NONE, SyncHalf, DST_ACCUM_MODE>(reg_id) ));
}

MATH(( llk_math_dest_section_done<SYNC>() ));
MATH(( llk_math_dest_section_done<SYNC, DST_ACCUM_MODE>() ));

PACK(( llk_packer_wait_for_math_done() ));

// Datacopy
for (int reg_id = 0; reg_id < N; reg_id++) {
PACK(( llk_pack<false, SYNC, false >(reg_id, ocb) ));
PACK(( llk_pack<false, SYNC, false, DST_ACCUM_MODE >(reg_id, ocb) ));
}

// Release dest
PACK(( llk_pack_dest_section_done<SYNC>() ));
PACK(( llk_pack_dest_section_done<SYNC, DST_ACCUM_MODE>() ));
}
}

Expand Down

0 comments on commit ca2b91d

Please sign in to comment.