diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp index 274ca729f2e..93910da916c 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/transformer_group_attn_matmul.cpp @@ -58,7 +58,9 @@ void MAIN { #ifdef ARCH_GRAYSKULL mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw); #else - mm_block_init(cb_in0, cb_in1, cb_intermed0, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w); + // TODO: switch back to matmul block after didt solved + mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw); + // mm_block_init(cb_in0, cb_in1, cb_intermed0, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w); #endif for (uint32_t b = 0; b < batch; b++) { @@ -95,19 +97,35 @@ void MAIN { in0_index_h_offset += in0_block_w; } #else - // Compute output sub-block - uint32_t dst_index = 0; // start at 0, each call to matmul_block internally increments dst_index - uint32_t in0_index = in0_index_subblock_offset; // offset into in0 block - uint32_t in1_index = in1_index_subblock_offset; // offset into in1 block - // inner dim that we accumualte is the inner dim of in0/in1, which is in0_block_w - for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) { - // matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst - // accumulation is done by iterating matmul_block across inner dim - // in0_block_w is passed as innder dim (kt) to matmul_block, interally used to stride in0 - matmul_block(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w); - in0_index ++; // stride right by 1 - in1_index += in1_per_core_w; // to stride down by 1 need to stride by in_per_core_w (should be called in1_block_w) + // TODO: switch back to matmul 
block after didt solved + uint32_t dst_index = 0; + uint32_t in0_index_h_offset = 0; + for (uint32_t h = 0; h < out_subblock_h; h++) { + for (uint32_t w = 0; w < out_subblock_w; w++) { + uint32_t in1_index_inner_dim_offset = 0; + for (uint32_t inner_dim = 0; inner_dim < in0_block_w; inner_dim++) { + uint32_t in0_index = in0_index_subblock_offset + in0_index_h_offset + inner_dim; + uint32_t in1_index = in1_index_subblock_offset + in1_index_inner_dim_offset + w; + matmul_tiles(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw); + in1_index_inner_dim_offset += in1_per_core_w; + } + dst_index++; + } + in0_index_h_offset += in0_block_w; } + // // Compute output sub-block + // uint32_t dst_index = 0; // start at 0, each call to matmul_block internally increments dst_index + // uint32_t in0_index = in0_index_subblock_offset; // offset into in0 block + // uint32_t in1_index = in1_index_subblock_offset; // offset into in1 block + // // inner dim that we accumulate is the inner dim of in0/in1, which is in0_block_w + // for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) { + // // matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst + // // accumulation is done by iterating matmul_block across inner dim + // // in0_block_w is passed as inner dim (kt) to matmul_block, internally used to stride in0 + // matmul_block(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w); + // in0_index ++; // stride right by 1 + // in1_index += in1_per_core_w; // to stride down by 1 need to stride by in1_per_core_w (should be called in1_block_w) + // } #endif tile_regs_commit();