Skip to content

Commit

Permalink
#5420: Remove dst layout from api calls, all apis are RowMajor
Browse files Browse the repository at this point in the history
  • Loading branch information
rtawfik01 committed Feb 19, 2024
1 parent 3224cce commit 069cbd3
Show file tree
Hide file tree
Showing 18 changed files with 93 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@ add_tiles_bcast

.. doxygenfunction:: add_bcast_cols_init_short(uint32_t icb0 = 0, uint32_t icb1 = 1)
.. doxygenfunction:: add_bcast_rows_init_short(uint32_t icb0 = 0, uint32_t icb1 = 1)
.. doxygenfunction:: add_bcast_rows_init_short_post_matmul(uint32_t icb0 = 0, uint32_t icb1 = 1)
.. doxygenfunction:: add_tiles_bcast(uint32_t icb0, uint32_t icb1, uint32_t itile0, uint32_t itile1, uint32_t idst)
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
copy_tile
=========

.. doxygenfunction:: copy_tile(uint32_t icb, uint32_t itile, uint32_t idst)
.. doxygenfunction:: copy_tile(uint32_t in_cb_id, uint32_t in_tile_index, uint32_t dst_tile_index)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace NAMESPACE {

FORCE_INLINE void reload_from_cb_to_dst(uint32_t in0_cb_id, uint32_t in1_cb_id, uint32_t mm_partials_cb_id, uint32_t out_subblock_num_tiles, uint32_t out_subblock_w, uint32_t out_subblock_h, uint32_t in0_block_w) {
// Reconfigure input
copy_tile_matmul_partials_init_short_with_dt(mm_partials_cb_id);
copy_tile_to_dst_init_short_with_dt(in1_cb_id, mm_partials_cb_id);
cb_wait_front(mm_partials_cb_id, out_subblock_num_tiles);
tile_regs_acquire();

Expand Down Expand Up @@ -216,11 +216,8 @@ void MAIN {
#if defined FP32_DEST_ACC_EN or defined PACKER_L1_ACC
PACK(( pack_reconfig_data_format(out_cb_id) ));
#endif
#ifdef ARCH_GRAYSKULL
add_bcast_rows_init_short_post_matmul();
#else
add_bcast_rows_init_short();
#endif

add_bcast_rows_init_short();
// reconfigure unpacker df for src B
unpack_reconfig_data_format(in1_cb_id, mm_partials_cb_id, in0_cb_id, bias_cb_id);
cb_wait_front(bias_cb_id, in1_per_core_w);
Expand Down Expand Up @@ -268,11 +265,6 @@ void MAIN {
if constexpr(batch > 1) {
// reconfigure init for matmul
mm_block_init_short(in0_cb_id, in1_cb_id, 0, out_subblock_w, out_subblock_h, in0_block_w);
#ifdef ARCH_GRAYSKULL
// reconfigure packer's dest registers to Col Major
PACK(( llk_pack_init<false, false, DstTileFaceLayout::ColMajor>() ));
PACK(( llk_init_packer_dest_offset_registers<SyncHalf,DstTileFaceLayout::ColMajor,false>() ));
#endif
// reconfigure unpacker df for src B
unpack_reconfig_data_format(in1_cb_id, in0_cb_id);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,6 @@
// SliceRange srr1 = SliceRange{.h0 = 1, .h1 = 2, .hs = 8, .w0 = 0, .w1 = 32, .ws = 1};
// SliceRange src = SliceRange{.h0 = 0, .h1 = 32, .hs = 1, .w0 = 0, .w1 = 1, .ws = 1};


// TODO: Uplift these APIs for compute_api.h?
inline void col_major_to_row_major_init() {
#ifdef ARCH_GRAYSKULL
// Configure to RowMajor for tilize (similar to add bcast for bias)
MATH(( llk_math_pack_sync_init<SYNC>() ));
PACK(( llk_pack_dest_init<SYNC, DstTileFaceLayout::RowMajor, false>() ));
#endif
}

inline void row_major_to_col_major_init() {
#ifdef ARCH_GRAYSKULL
// Configure back to ColMajor for matmul
MATH(( llk_math_pack_sync_init<SYNC>() ));
PACK(( llk_pack_dest_init<SYNC, DstTileFaceLayout::ColMajor, false>() ));
#endif
}

inline void tilize_in(
uint32_t in_cb_id,
uint32_t in_subblock_h,
Expand Down Expand Up @@ -185,9 +167,7 @@ void MAIN {
#ifdef PRE_TILIZE
unpack_reconfig_data_format_srca(in1_cb_id, in0_pretilize_cb_id);

col_major_to_row_major_init();
tilize_in(in0_pretilize_cb_id, in0_subblock_h, in0_block_w, in0_num_subblocks, tilized_in0_cb_id);
row_major_to_col_major_init();

// TODO: unpack_reconfig_data_format_srca(in0_pretilize_cb_id, in1_cb_id) doesn't work if in0 is BFLOATB_B and in1 is BFLOAT16
mm_block_init_short(in0_cb_id, in1_cb_id, false, out_subblock_w, out_subblock_h, in0_block_w);
Expand Down Expand Up @@ -217,12 +197,10 @@ void MAIN {
#endif
unpack_reconfig_data_format_srca(in1_cb_id, in0_cb_id);

col_major_to_row_major_init();
tilize_in(in0_cb_id, in0_subblock_h, in0_block_w, in0_num_subblocks_read, tilized_in0_cb_id);
#ifdef SPLIT_READER
tilize_in(in0_cb_second_reader_id, in0_subblock_h, in0_block_w, in0_num_subblocks_read, tilized_in0_cb_id);
#endif
row_major_to_col_major_init();

mm_block_init_short_with_dt(mm_in0_cb_id, in1_cb_id, /*srca_old_operand=*/in0_cb_id, out_subblock_w, out_subblock_h, in0_block_w);
}
Expand All @@ -242,7 +220,7 @@ void MAIN {
for (uint32_t in1_subblock_i = 0; in1_subblock_i < in1_num_subblocks; ++in1_subblock_i) {
if (enable_reload) {
// Reconfigure input
copy_tile_matmul_partials_init_short_with_dt(in1_cb_id, matmul_partials_cb);
copy_tile_to_dst_init_short_with_dt(in1_cb_id, matmul_partials_cb);
cb_wait_front(matmul_partials_cb, out_subblock_num_tiles);
tile_regs_acquire();

Expand Down Expand Up @@ -313,7 +291,7 @@ void MAIN {
// if last block we pack the final result with relu enabled
PACK(( llk_pack_relu_config(ReluType::ZERO_RELU) ));
#endif
add_bcast_rows_init_short_post_matmul();
add_bcast_rows_init_short();
unpack_reconfig_data_format(in1_cb_id, matmul_partials_cb, mm_in0_cb_id, bias_cb_id);
cb_wait_front(bias_cb_id, bias_ntiles_w);
cb_wait_front(matmul_partials_cb, out_block_num_tiles);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
* LLK MATMUL
*************************************************************************/

template <int NUM_FIDELITY_PHASES, DstTileFaceLayout FaceLayout = DstTileFaceLayout::RowMajor>
template <int NUM_FIDELITY_PHASES>
inline void llk_math_matmul_init(
const std::uint32_t operandA /*not used*/,
const std::uint32_t operandB /*not used*/,
Expand All @@ -19,24 +19,24 @@ inline void llk_math_matmul_init(
const std::uint32_t rt_dim = 1,
const std::uint32_t kt_dim = 1) {

_llk_math_matmul_init_<NUM_FIDELITY_PHASES, FaceLayout>(
_llk_math_matmul_init_<NUM_FIDELITY_PHASES, DstTileFaceLayout::RowMajor>(
transpose,
ct_dim,
rt_dim,
kt_dim);
}


template <int NUM_FIDELITY_PHASES, DstTileFaceLayout FaceLayout = DstTileFaceLayout::RowMajor>
template <int NUM_FIDELITY_PHASES>
inline void llk_math_matmul(
uint dst_index,
const uint dst_index,
const bool transpose = false,
const std::uint32_t ct_dim = 1,
const std::uint32_t rt_dim = 1,
const std::uint32_t kt_dim = 1) {
for (std::uint32_t rt=0; rt<rt_dim; rt++) {
for (std::uint32_t ct=0; ct<ct_dim; ct++) {
_llk_math_matmul_<NUM_FIDELITY_PHASES, FaceLayout>(dst_index+rt*ct_dim+ct, transpose);
_llk_math_matmul_<NUM_FIDELITY_PHASES, DstTileFaceLayout::RowMajor>(dst_index+rt*ct_dim+ct, transpose);
}
}
}
24 changes: 12 additions & 12 deletions tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_pack_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
* LLK PACK
*************************************************************************/

template <bool untilize = false, bool zero_output = false, DstTileFaceLayout FaceLayout = DstTileFaceLayout::RowMajor>
template <bool untilize = false, bool zero_output = false>
inline void llk_pack_mop_config(const uint32_t output) {
constexpr bool write_tile_header = false;
_llk_pack_mop_config_<untilize, zero_output, FaceLayout, write_tile_header>();
_llk_pack_mop_config_<untilize, zero_output, DstTileFaceLayout::RowMajor, write_tile_header>();
}

template <bool untilize = false, bool is_fp32_dest_acc_en = false /*not used*/>
Expand Down Expand Up @@ -69,13 +69,13 @@ inline void llk_pack_reduce_hw_configure_disaggregated(std::uint32_t pack_output
llk_pack_reduce_hw_configure<untilize, type, dim, is_fp32_dest_acc_en>(&llk_pack_params);
}

template <bool untilize = false, bool zero_output = false, DstTileFaceLayout FaceLayout = DstTileFaceLayout::RowMajor>
template <bool untilize = false, bool zero_output = false>
inline void llk_pack_init(const std::uint32_t pack_output = 16) {

const std::uint32_t output_id = get_output_id(pack_output);
constexpr bool write_tile_header = false;

_llk_pack_init_<untilize, zero_output, FaceLayout, write_tile_header>();
_llk_pack_init_<untilize, zero_output, DstTileFaceLayout::RowMajor, write_tile_header>();
}

template <bool out_of_order_output, bool untilize>
Expand Down Expand Up @@ -155,14 +155,14 @@ inline void llk_pack_dest_section_done() {
_llk_pack_dest_section_done_<Dst, is_fp32_dest_acc_en>();
}

template <DstSync Dst, DstTileFaceLayout FaceLayout, bool untilize = false>
template <DstSync Dst, bool untilize = false>
inline void llk_init_packer_dest_offset_registers(const std::uint32_t pack_output = 16) {
_llk_init_packer_dest_offset_registers_<Dst, FaceLayout, untilize>();
_llk_init_packer_dest_offset_registers_<Dst, DstTileFaceLayout::RowMajor, untilize>();
}

template <DstSync Dst, DstTileFaceLayout FaceLayout = RowMajor, bool untilize = false, bool is_fp32_dest_acc_en = false /*unused*/>
template <DstSync Dst, bool untilize = false, bool is_fp32_dest_acc_en = false /*unused*/>
inline void llk_pack_dest_init(const std::uint32_t pack_output = 16) {
_llk_pack_dest_init_<Dst, FaceLayout, untilize, is_fp32_dest_acc_en>();
_llk_pack_dest_init_<Dst, DstTileFaceLayout::RowMajor, untilize, is_fp32_dest_acc_en>();
}

template <bool mail2math=true, bool mail2pack=true>
Expand All @@ -183,25 +183,25 @@ inline void llk_pack_debug_dump_seek(std::uint8_t offset) {
_llk_pack_debug_dump_seek_(offset);
}

template <bool is_fp32_dest_acc_en = false /*unused*/, bool is_tile_dim_reconfig_en = false /*unused*/, DstTileFaceLayout FaceLayout = DstTileFaceLayout::RowMajor /*unused*/>
template <bool is_fp32_dest_acc_en = false /*unused*/, bool is_tile_dim_reconfig_en = false /*unused*/>
inline void llk_pack_reconfig_data_format(const std::uint32_t new_output) {
std::uint32_t output_id = get_output_id(new_output);

_llk_pack_reconfig_data_format_<is_fp32_dest_acc_en, is_tile_dim_reconfig_en, FaceLayout>(
_llk_pack_reconfig_data_format_<is_fp32_dest_acc_en, is_tile_dim_reconfig_en, DstTileFaceLayout::RowMajor>(
pack_dst_format[output_id],
cb_interface[output_id].fifo_page_size
);
}

template <bool is_fp32_dest_acc_en = false /*unused*/, bool is_tile_dim_reconfig_en = false /*unused*/, DstTileFaceLayout FaceLayout = DstTileFaceLayout::RowMajor /*unused*/>
template <bool is_fp32_dest_acc_en = false /*unused*/, bool is_tile_dim_reconfig_en = false /*unused*/>
inline void llk_pack_reconfig_data_format(const std::uint32_t old_output, const std::uint32_t new_output) {
std::uint32_t old_output_id = get_output_id(old_output);
std::uint32_t new_output_id = get_output_id(new_output);

if((pack_dst_format[old_output_id] != pack_dst_format[new_output_id])
&& (pack_dst_format[old_output_id] != (uint)DataFormat::Invalid)
&& (pack_dst_format[new_output_id] != (uint)DataFormat::Invalid)) {
llk_pack_reconfig_data_format<is_fp32_dest_acc_en, is_tile_dim_reconfig_en, FaceLayout>(new_output);
llk_pack_reconfig_data_format<is_fp32_dest_acc_en, is_tile_dim_reconfig_en>(new_output);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
* LLK MATMUL
*************************************************************************/

template <int NUM_FIDELITY_PHASES, DstTileFaceLayout FaceLayout = DstTileFaceLayout::ColMajor>
template <int NUM_FIDELITY_PHASES>
inline void llk_math_matmul_init(
const std::uint32_t operandA,
const std::uint32_t operandB,
Expand Down Expand Up @@ -40,12 +40,13 @@ inline void llk_math_matmul_init(
kt_dim);
}

template <int NUM_FIDELITY_PHASES, DstTileFaceLayout FaceLayout = DstTileFaceLayout::ColMajor>
template <int NUM_FIDELITY_PHASES>
inline void llk_math_matmul(
uint dst_index,
const uint dst_index,
const bool transpose = false,
const std::uint32_t ct_dim = 1,
const std::uint32_t rt_dim = 1,
const std::uint32_t kt_dim = 1) {

_llk_math_matmul_<NUM_FIDELITY_PHASES, DstTileFaceLayout::RowMajor>(dst_index, transpose, ct_dim, rt_dim, kt_dim);
}
26 changes: 13 additions & 13 deletions tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_pack_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
* LLK PACK
*************************************************************************/

template <bool untilize = false, bool zero_output = false, DstTileFaceLayout FaceLayout = DstTileFaceLayout::RowMajor>
template <bool untilize = false, bool zero_output = false>
inline void llk_pack_mop_config(const uint32_t output) {

const std::uint32_t output_id = get_output_id(output);
Expand All @@ -31,7 +31,7 @@ inline void llk_pack_mop_config(const uint32_t output) {
const bool partial_face = get_output_partial_face(output_id) && IS_BFP_FORMAT((uint)pack_dst_format[output_id]);
const bool narrow_tile = get_output_narrow_tile(output_id);

_llk_pack_mop_config_<untilize, zero_output, FaceLayout, false>(
_llk_pack_mop_config_<untilize, zero_output, DstTileFaceLayout::RowMajor, false>(
pack_dst_format[output_id],
face_r_dim,
num_faces,
Expand Down Expand Up @@ -99,7 +99,7 @@ inline void llk_pack_reduce_hw_configure_disaggregated(std::uint32_t pack_output
llk_pack_reduce_hw_configure<untilize, type, dim, is_fp32_dest_acc_en>(&llk_pack_params);
}

template <bool untilize = false, bool zero_output = false, DstTileFaceLayout FaceLayout = DstTileFaceLayout::RowMajor>
template <bool untilize = false, bool zero_output = false>
inline void llk_pack_init(const std::uint32_t pack_output = 16) {

const std::uint32_t output_id = get_output_id(pack_output);
Expand All @@ -108,7 +108,7 @@ inline void llk_pack_init(const std::uint32_t pack_output = 16) {
const bool partial_face = get_output_partial_face(output_id);
const bool narrow_tile = get_output_narrow_tile(output_id);

_llk_pack_init_<untilize, zero_output, FaceLayout, false>(
_llk_pack_init_<untilize, zero_output, DstTileFaceLayout::RowMajor, false>(
pack_dst_format[output_id],
face_r_dim,
num_faces,
Expand Down Expand Up @@ -233,26 +233,26 @@ inline void llk_pack_dest_section_done() {
_llk_pack_dest_section_done_<Dst, is_fp32_dest_acc_en>();
}

template <DstSync Dst, DstTileFaceLayout FaceLayout, bool untilize = false>
template <DstSync Dst, bool untilize = false>
inline void llk_init_packer_dest_offset_registers(const std::uint32_t pack_output = 16) {
const std::uint32_t output_id = get_output_id(pack_output);
const std::uint32_t face_r_dim = get_output_face_r_dim(output_id);
const bool narrow_tile = get_output_narrow_tile(output_id);

_llk_init_packer_dest_offset_registers_<Dst, FaceLayout, untilize>(
_llk_init_packer_dest_offset_registers_<Dst, DstTileFaceLayout::RowMajor, untilize>(
face_r_dim,
narrow_tile
);
}

template <DstSync Dst, DstTileFaceLayout FaceLayout = RowMajor, bool untilize = false, bool is_fp32_dest_acc_en = false>
template <DstSync Dst, bool untilize = false, bool is_fp32_dest_acc_en = false>
inline void llk_pack_dest_init(const std::uint32_t pack_output = 16) {

const std::uint32_t output_id = get_output_id(pack_output);
const std::uint32_t face_r_dim = get_output_face_r_dim(output_id);
const bool narrow_tile = get_output_narrow_tile(output_id);

_llk_pack_dest_init_<Dst, FaceLayout, untilize, is_fp32_dest_acc_en>(
_llk_pack_dest_init_<Dst, DstTileFaceLayout::RowMajor, untilize, is_fp32_dest_acc_en>(
face_r_dim,
narrow_tile
);
Expand All @@ -276,7 +276,7 @@ inline void llk_pack_debug_dump_seek(std::uint8_t offset) {
_llk_pack_debug_dump_seek_(offset);
}

template <bool is_fp32_dest_acc_en = false, bool is_tile_dim_reconfig_en = false, DstTileFaceLayout FaceLayout = DstTileFaceLayout::RowMajor>
template <bool is_fp32_dest_acc_en = false, bool is_tile_dim_reconfig_en = false>
inline void llk_pack_reconfig_data_format(const std::uint32_t new_output) {

const std::uint32_t output_id = get_output_id(new_output);
Expand All @@ -285,7 +285,7 @@ inline void llk_pack_reconfig_data_format(const std::uint32_t new_output) {
const bool partial_face = get_output_partial_face(output_id);
const bool narrow_tile = get_output_narrow_tile(output_id);

_llk_pack_reconfig_data_format_<is_fp32_dest_acc_en, is_tile_dim_reconfig_en, FaceLayout>(
_llk_pack_reconfig_data_format_<is_fp32_dest_acc_en, is_tile_dim_reconfig_en, DstTileFaceLayout::RowMajor>(
pack_src_format[output_id],
pack_dst_format[output_id],
cb_interface[output_id].fifo_page_size,
Expand All @@ -296,18 +296,18 @@ inline void llk_pack_reconfig_data_format(const std::uint32_t new_output) {
);
}

template <bool is_fp32_dest_acc_en = false, bool is_tile_dim_reconfig_en = false, DstTileFaceLayout FaceLayout = DstTileFaceLayout::RowMajor>
template <bool is_fp32_dest_acc_en = false, bool is_tile_dim_reconfig_en = false>
inline void llk_pack_reconfig_data_format(const std::uint32_t old_output, const std::uint32_t new_output) {
std::uint32_t old_output_id = get_output_id(old_output);
std::uint32_t new_output_id = get_output_id(new_output);

if((pack_dst_format[old_output_id] != pack_dst_format[new_output_id])
&& (pack_dst_format[old_output_id] != (uint)DataFormat::Invalid)
&& (pack_dst_format[new_output_id] != (uint)DataFormat::Invalid)) {
llk_pack_reconfig_data_format<is_fp32_dest_acc_en, is_tile_dim_reconfig_en, FaceLayout>(new_output);
llk_pack_reconfig_data_format<is_fp32_dest_acc_en, is_tile_dim_reconfig_en>(new_output);
} else if constexpr (is_tile_dim_reconfig_en) {
// Same format but different tile dims
llk_pack_mop_config<false, false, FaceLayout, false>(new_output);
llk_pack_mop_config<false, false>(new_output);
}
}

Expand Down
Loading

0 comments on commit 069cbd3

Please sign in to comment.