Skip to content

Commit

Permalink
#4686: switch back to matmul tiles for group attn matmul until didt fix
Browse files Browse the repository at this point in the history
  • Loading branch information
yugaoTT committed Feb 24, 2024
1 parent 39d65f4 commit 880cc79
Showing 1 changed file with 31 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ void MAIN {
#ifdef ARCH_GRAYSKULL
mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw);
#else
mm_block_init(cb_in0, cb_in1, cb_intermed0, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w);
// TODO: switch back to matmul block after didt solved
mm_init(cb_in0, cb_in1, cb_intermed0, transpose_hw);
// mm_block_init(cb_in0, cb_in1, cb_intermed0, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w);
#endif

for (uint32_t b = 0; b < batch; b++) {
Expand Down Expand Up @@ -95,19 +97,35 @@ void MAIN {
in0_index_h_offset += in0_block_w;
}
#else
// Compute output sub-block
uint32_t dst_index = 0; // start at 0, each call to matmul_block internally increments dst_index
uint32_t in0_index = in0_index_subblock_offset; // offset into in0 block
uint32_t in1_index = in1_index_subblock_offset; // offset into in1 block
// inner dim that we accumualte is the inner dim of in0/in1, which is in0_block_w
for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) {
// matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst
// accumulation is done by iterating matmul_block across inner dim
// in0_block_w is passed as innder dim (kt) to matmul_block, interally used to stride in0
matmul_block(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w);
in0_index ++; // stride right by 1
in1_index += in1_per_core_w; // to stride down by 1 need to stride by in_per_core_w (should be called in1_block_w)
// TODO: switch back to matmul block after didt solved
uint32_t dst_index = 0;
uint32_t in0_index_h_offset = 0;
for (uint32_t h = 0; h < out_subblock_h; h++) {
for (uint32_t w = 0; w < out_subblock_w; w++) {
uint32_t in1_index_inner_dim_offset = 0;
for (uint32_t inner_dim = 0; inner_dim < in0_block_w; inner_dim++) {
uint32_t in0_index = in0_index_subblock_offset + in0_index_h_offset + inner_dim;
uint32_t in1_index = in1_index_subblock_offset + in1_index_inner_dim_offset + w;
matmul_tiles(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw);
in1_index_inner_dim_offset += in1_per_core_w;
}
dst_index++;
}
in0_index_h_offset += in0_block_w;
}
// // Compute output sub-block
// uint32_t dst_index = 0; // start at 0, each call to matmul_block internally increments dst_index
// uint32_t in0_index = in0_index_subblock_offset; // offset into in0 block
// uint32_t in1_index = in1_index_subblock_offset; // offset into in1 block
// // inner dim that we accumualte is the inner dim of in0/in1, which is in0_block_w
// for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) {
// // matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst
// // accumulation is done by iterating matmul_block across inner dim
// // in0_block_w is passed as innder dim (kt) to matmul_block, interally used to stride in0
// matmul_block(cb_in0, cb_in1, in0_index, in1_index, dst_index, transpose_hw, out_subblock_w, out_subblock_h, in0_block_w);
// in0_index ++; // stride right by 1
// in1_index += in1_per_core_w; // to stride down by 1 need to stride by in_per_core_w (should be called in1_block_w)
// }
#endif

tile_regs_commit();
Expand Down

0 comments on commit 880cc79

Please sign in to comment.