From 892399cdea2d122e27bbed50bc5662f8f382b9bb Mon Sep 17 00:00:00 2001 From: kpaigwar Date: Thu, 9 May 2024 20:28:36 +0000 Subject: [PATCH] #8322: extend ssm_eltwise_mul to handle third case --- .../unit_testing/misc/test_ssm_eltwise_mul.py | 14 +- .../kernels/compute/ssm_eltwise_mul.cpp | 187 ++++++++++-------- .../dataflow/reader_ssm_eltwise_mul.cpp | 101 +++++----- .../multi_core_ssm_eltwise_mul.cpp | 16 +- 4 files changed, 181 insertions(+), 137 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_ssm_eltwise_mul.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_ssm_eltwise_mul.py index 002363400e1..70461c82fc7 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_ssm_eltwise_mul.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_ssm_eltwise_mul.py @@ -36,10 +36,14 @@ def run_ssm_eltwise_mul_test(in0_W, in1_W, dtype, in0_mem_config, in1_mem_config out = tt2torch_tensor(tt_out) # Compute reference on pytorch - if in0_W == latent_size: + if in0_W == latent_size and in1_W == hidden_size: ref_out = B.repeat(1, 1, 1, hidden_size) * X.repeat_interleave(latent_size, dim=-1) - elif in0_W == latent_size * hidden_size: + elif in0_W == latent_size * hidden_size and in1_W == hidden_size: ref_out = B * X.repeat_interleave(latent_size, dim=-1) + elif in0_W == latent_size and in1_W == latent_size * hidden_size: + ref_out = B.repeat(1, 1, 1, hidden_size) * X + else: + raise Exception("Input shapes invalid, use eltwise_mul for same input shapes,", in0_W, in1_W) passing_pcc, output_pcc = comp_pcc(out, ref_out, 0.9999) logger.debug(f"Out passing={passing_pcc}") @@ -78,7 +82,7 @@ def run_ssm_eltwise_mul_test(in0_W, in1_W, dtype, in0_mem_config, in1_mem_config ( (32, 5120), (32 * 5120, 5120), - # (32, 32*5120), # TODO: Enable this test case where in1 is already expanded + (32, 32 * 5120), ), ) def test_ssm_eltwise_mul(in0_W, in1_W, dtype, in0_mem_config, in1_mem_config, out_mem_config, device): @@ -94,8 +98,10 @@ def test_ssm_eltwise_mul_with_program_cache(device, use_program_cache): run_ssm_eltwise_mul_test(in0_W, in1_W, dtype, mem_config, mem_config, mem_config, device) in0_W, in1_W = 32 * 5120, 5120 run_ssm_eltwise_mul_test(in0_W, in1_W, dtype, mem_config, mem_config, mem_config, device) + in0_W, in1_W = 32, 32 * 5120 + run_ssm_eltwise_mul_test(in0_W, in1_W, dtype, mem_config, mem_config, mem_config, device) dummy_shape = [1, 1, 32, 32] py_dummy_tensor = torch.randn(dummy_shape) tt_dummy_tensor = ttl.tensor.Tensor(py_dummy_tensor, dtype).to(ttl.tensor.Layout.TILE).to(device, mem_config) - assert device.num_program_cache_entries() == 2 + assert device.num_program_cache_entries() == 3 diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_eltwise_mul.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_eltwise_mul.cpp index 779bdfcc78a..7ca09899009 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_eltwise_mul.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_eltwise_mul.cpp @@ -23,26 +23,33 @@ void MAIN { constexpr uint32_t num_rows_in_one_tile = 32; - binary_op_init_common(cb_in0_transposed, cb_in1_bcast_row); // TODO: Is there a specific one for bcast mul? + #ifdef REPEAT_INTERLEAVE_IN1 + binary_op_init_common(cb_in0_transposed, cb_in1_bcast_row); // TODO: Is there a specific one for bcast mul? + #else + binary_op_init_common(cb_id_in0, cb_id_in1); + #endif #ifdef REPEAT_IN0 // Transpose in0 cb_wait_front(cb_id_in0, onetile); - tile_regs_acquire(); - tile_regs_wait(); + // No need to transpose in0 if in1 is not repeat_interleaved + #ifdef REPEAT_INTERLEAVE_IN1 + tile_regs_acquire(); + tile_regs_wait(); - transpose_wh_init_short(cb_id_in0); - transpose_wh_tile(cb_id_in0, 0, 0); + transpose_wh_init_short(cb_id_in0); + transpose_wh_tile(cb_id_in0, 0, 0); - cb_reserve_back(cb_in0_transposed, onetile); - pack_tile(0, cb_in0_transposed); + cb_reserve_back(cb_in0_transposed, onetile); + pack_tile(0, cb_in0_transposed); - tile_regs_commit(); - tile_regs_release(); - cb_push_back(cb_in0_transposed, onetile); - cb_pop_front(cb_id_in0, onetile); + tile_regs_commit(); + tile_regs_release(); + cb_push_back(cb_in0_transposed, onetile); + cb_pop_front(cb_id_in0, onetile); - cb_wait_front(cb_in0_transposed, onetile); + cb_wait_front(cb_in0_transposed, onetile); + #endif #endif for (uint32_t in1_block = 0; in1_block < in1_num_blocks; in1_block++) { @@ -51,97 +58,115 @@ void MAIN { tile_regs_acquire(); tile_regs_wait(); - transpose_wh_init_short(cb_id_in1); - transpose_wh_tile(cb_id_in1, 0, 0); + //If input b is not repeat_interleaved, then no need to transpose, bcast row + #ifndef REPEAT_INTERLEAVE_IN1 + mul_tiles_init(cb_id_in0, cb_id_in1); + mul_tiles(cb_id_in0, cb_id_in1, 0, 0, 0); - cb_reserve_back(cb_in1_transposed, onetile); - pack_tile(0, cb_in1_transposed); + cb_reserve_back(cb_id_out, onetile); + pack_tile(0, cb_id_out); - tile_regs_commit(); - tile_regs_release(); - cb_push_back(cb_in1_transposed, onetile); - cb_pop_front(cb_id_in1, onetile); + tile_regs_commit(); + tile_regs_release(); + cb_push_back(cb_id_out, onetile); + cb_pop_front(cb_id_in1, onetile); + #else + transpose_wh_init_short(cb_id_in1); + transpose_wh_tile(cb_id_in1, 0, 0); - // Receive in1 as single rows to bcast mul with in0 - for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; tile_row_id++) { - #ifndef REPEAT_IN0 - // Transpose in0 - cb_wait_front(cb_id_in0, onetile); - tile_regs_acquire(); - tile_regs_wait(); + cb_reserve_back(cb_in1_transposed, onetile); + pack_tile(0, cb_in1_transposed); - transpose_wh_init_short(cb_id_in0); - transpose_wh_tile(cb_id_in0, 0, 0); + tile_regs_commit(); + tile_regs_release(); + cb_push_back(cb_in1_transposed, onetile); + cb_pop_front(cb_id_in1, onetile); - cb_reserve_back(cb_in0_transposed, onetile); - pack_tile(0, cb_in0_transposed); + // Receive in1 as single rows to bcast mul with in0 + for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; tile_row_id++) { + #ifndef REPEAT_IN0 + // Transpose in0 + cb_wait_front(cb_id_in0, onetile); + tile_regs_acquire(); + tile_regs_wait(); - tile_regs_commit(); - tile_regs_release(); - cb_push_back(cb_in0_transposed, onetile); - cb_pop_front(cb_id_in0, onetile); + transpose_wh_init_short(cb_id_in0); + transpose_wh_tile(cb_id_in0, 0, 0); - cb_wait_front(cb_in0_transposed, onetile); - #endif + cb_reserve_back(cb_in0_transposed, onetile); + pack_tile(0, cb_in0_transposed); - cb_wait_front(cb_in1_bcast_row, onetile); - tile_regs_acquire(); - tile_regs_wait(); + tile_regs_commit(); + tile_regs_release(); + cb_push_back(cb_in0_transposed, onetile); + cb_pop_front(cb_id_in0, onetile); - mul_bcast_rows_init_short(cb_in0_transposed, cb_in1_bcast_row); - mul_tiles_bcast_rows(cb_in0_transposed, cb_in1_bcast_row, 0, 0, 0); + cb_wait_front(cb_in0_transposed, onetile); + #endif - cb_reserve_back(cb_out_transposed, onetile); - pack_tile(0, cb_out_transposed); + cb_wait_front(cb_in1_bcast_row, onetile); + tile_regs_acquire(); + tile_regs_wait(); - tile_regs_commit(); - tile_regs_release(); - cb_push_back(cb_out_transposed, onetile); - #ifndef REPEAT_IN0 - cb_pop_front(cb_in0_transposed, onetile); - #endif - cb_pop_front(cb_in1_bcast_row, onetile); - - // Transpose output back - cb_wait_front(cb_out_transposed, onetile); - tile_regs_acquire(); - tile_regs_wait(); + mul_bcast_rows_init_short(cb_in0_transposed, cb_in1_bcast_row); + mul_tiles_bcast_rows(cb_in0_transposed, cb_in1_bcast_row, 0, 0, 0); - transpose_wh_init_short(cb_out_transposed); - transpose_wh_tile(cb_out_transposed, 0, 0); + cb_reserve_back(cb_out_transposed, onetile); + pack_tile(0, cb_out_transposed); - cb_reserve_back(cb_id_out, onetile); - pack_tile(0, cb_id_out); + tile_regs_commit(); + tile_regs_release(); + cb_push_back(cb_out_transposed, onetile); + #ifndef REPEAT_IN0 + cb_pop_front(cb_in0_transposed, onetile); + #endif + cb_pop_front(cb_in1_bcast_row, onetile); + + // Transpose output back + cb_wait_front(cb_out_transposed, onetile); + tile_regs_acquire(); + tile_regs_wait(); - tile_regs_commit(); - tile_regs_release(); - cb_push_back(cb_id_out, onetile); - cb_pop_front(cb_out_transposed, onetile); + transpose_wh_init_short(cb_out_transposed); + transpose_wh_tile(cb_out_transposed, 0, 0); - /* TODO: Transpose directly on tiles in DST; is something like this possible? - cb_reserve_back(cb_id_out, onetile); + cb_reserve_back(cb_id_out, onetile); + pack_tile(0, cb_id_out); - tile_regs_acquire(); - tile_regs_wait(); - mul_bcast_rows_init_short(cb_in0_transposed, cb_in1_bcast_row); - mul_tiles_bcast_rows(cb_in0_transposed, cb_in1_bcast_row, 0, 0, 0); + tile_regs_commit(); + tile_regs_release(); + cb_push_back(cb_id_out, onetile); + cb_pop_front(cb_out_transposed, onetile); - MATH(( llk_math_eltwise_unary_datacopy_init(true, true, cb_id_out) )); - MATH(( llk_math_eltwise_unary_datacopy(0) )); + /* TODO: Transpose directly on tiles in DST; is something like this possible? + cb_reserve_back(cb_id_out, onetile); - pack_tile(0, cb_id_out); + tile_regs_acquire(); + tile_regs_wait(); + mul_bcast_rows_init_short(cb_in0_transposed, cb_in1_bcast_row); + mul_tiles_bcast_rows(cb_in0_transposed, cb_in1_bcast_row, 0, 0, 0); - tile_regs_commit(); - tile_regs_release(); - cb_push_back(cb_id_out, onetile); - cb_pop_front(cb_in1_bcast_row, onetile); - */ - } + MATH(( llk_math_eltwise_unary_datacopy_init(true, true, cb_id_out) )); + MATH(( llk_math_eltwise_unary_datacopy(0) )); + + pack_tile(0, cb_id_out); + + tile_regs_commit(); + tile_regs_release(); + cb_push_back(cb_id_out, onetile); + cb_pop_front(cb_in1_bcast_row, onetile); + */ + } - cb_pop_front(cb_in1_transposed, onetile); + cb_pop_front(cb_in1_transposed, onetile); + #endif } #ifdef REPEAT_IN0 - cb_pop_front(cb_in0_transposed, onetile); + #ifdef REPEAT_INTERLEAVE_IN1 + cb_pop_front(cb_in0_transposed, onetile); + #else + cb_pop_front(cb_id_in0, onetile); + #endif #endif } } diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_eltwise_mul.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_eltwise_mul.cpp index 5f6526877f7..466d94466da 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_eltwise_mul.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_eltwise_mul.cpp @@ -61,54 +61,57 @@ void kernel_main() { noc_async_read_barrier(); cb_push_back(cb_id_in1, onetile); - cb_wait_front(cb_in1_transposed, onetile); - uint64_t cb_in1_transposed_read_ptr = get_noc_addr(get_read_ptr(cb_in1_transposed)); - - // Manually unroll iterating across the tile to eliminate unncessary conditional checking - // First + second face - for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_face; tile_row_id++) { - cb_reserve_back(cb_in1_bcast_row, onetile); - uint32_t cb_in1_bcast_row_write_ptr = get_write_ptr(cb_in1_bcast_row); - - #ifndef REPEAT_IN0 - cb_reserve_back(cb_id_in0, onetile); - l1_write_addr_in0 = get_write_ptr(cb_id_in0); - noc_async_read_tile(i * in0_blocks_per_in1_block + tile_row_id, s0, l1_write_addr_in0); - #endif - noc_async_read(cb_in1_transposed_read_ptr, cb_in1_bcast_row_write_ptr, bfloat16_one_row_in_face_bytes); - noc_async_read(cb_in1_transposed_read_ptr + bfloat16_one_face_bytes, cb_in1_bcast_row_write_ptr + bfloat16_one_face_bytes, bfloat16_one_row_in_face_bytes); - noc_async_read_barrier(); - - #ifndef REPEAT_IN0 - cb_push_back(cb_id_in0, onetile); - #endif - cb_push_back(cb_in1_bcast_row, onetile); - - cb_in1_transposed_read_ptr += bfloat16_one_row_in_face_bytes; - } - - cb_in1_transposed_read_ptr += bfloat16_one_face_bytes; - // Third + fourth face - for (uint32_t tile_row_id = num_rows_in_face; tile_row_id < 2*num_rows_in_face; tile_row_id++) { - cb_reserve_back(cb_in1_bcast_row, onetile); - uint32_t cb_in1_bcast_row_write_ptr = get_write_ptr(cb_in1_bcast_row); - - #ifndef REPEAT_IN0 - cb_reserve_back(cb_id_in0, onetile); - l1_write_addr_in0 = get_write_ptr(cb_id_in0); - noc_async_read_tile(i * in0_blocks_per_in1_block + tile_row_id, s0, l1_write_addr_in0); - #endif - noc_async_read(cb_in1_transposed_read_ptr, cb_in1_bcast_row_write_ptr, bfloat16_one_row_in_face_bytes); - noc_async_read(cb_in1_transposed_read_ptr + bfloat16_one_face_bytes, cb_in1_bcast_row_write_ptr + bfloat16_one_face_bytes, bfloat16_one_row_in_face_bytes); - noc_async_read_barrier(); - - #ifndef REPEAT_IN0 - cb_push_back(cb_id_in0, onetile); - #endif - cb_push_back(cb_in1_bcast_row, onetile); - - cb_in1_transposed_read_ptr += bfloat16_one_row_in_face_bytes; - } - cb_pop_front(cb_in1_transposed, onetile); + #ifdef REPEAT_INTERLEAVE_IN1 + cb_wait_front(cb_in1_transposed, onetile); + uint64_t cb_in1_transposed_read_ptr = get_noc_addr(get_read_ptr(cb_in1_transposed)); + + // Manually unroll iterating across the tile to eliminate unncessary conditional checking + // First + second face + for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_face; tile_row_id++) { + cb_reserve_back(cb_in1_bcast_row, onetile); + uint32_t cb_in1_bcast_row_write_ptr = get_write_ptr(cb_in1_bcast_row); + + #ifndef REPEAT_IN0 + cb_reserve_back(cb_id_in0, onetile); + l1_write_addr_in0 = get_write_ptr(cb_id_in0); + noc_async_read_tile(i * in0_blocks_per_in1_block + tile_row_id, s0, l1_write_addr_in0); + #endif + noc_async_read(cb_in1_transposed_read_ptr, cb_in1_bcast_row_write_ptr, bfloat16_one_row_in_face_bytes); + noc_async_read(cb_in1_transposed_read_ptr + bfloat16_one_face_bytes, cb_in1_bcast_row_write_ptr + bfloat16_one_face_bytes, bfloat16_one_row_in_face_bytes); + noc_async_read_barrier(); + + #ifndef REPEAT_IN0 + cb_push_back(cb_id_in0, onetile); + #endif + cb_push_back(cb_in1_bcast_row, onetile); + + cb_in1_transposed_read_ptr += bfloat16_one_row_in_face_bytes; + } + + cb_in1_transposed_read_ptr += bfloat16_one_face_bytes; + // Third + fourth face + for (uint32_t tile_row_id = num_rows_in_face; tile_row_id < 2*num_rows_in_face; tile_row_id++) { + cb_reserve_back(cb_in1_bcast_row, onetile); + uint32_t cb_in1_bcast_row_write_ptr = get_write_ptr(cb_in1_bcast_row); + + #ifndef REPEAT_IN0 + cb_reserve_back(cb_id_in0, onetile); + l1_write_addr_in0 = get_write_ptr(cb_id_in0); + noc_async_read_tile(i * in0_blocks_per_in1_block + tile_row_id, s0, l1_write_addr_in0); + #endif + noc_async_read(cb_in1_transposed_read_ptr, cb_in1_bcast_row_write_ptr, bfloat16_one_row_in_face_bytes); + noc_async_read(cb_in1_transposed_read_ptr + bfloat16_one_face_bytes, cb_in1_bcast_row_write_ptr + bfloat16_one_face_bytes, bfloat16_one_row_in_face_bytes); + noc_async_read_barrier(); + + #ifndef REPEAT_IN0 + cb_push_back(cb_id_in0, onetile); + #endif + cb_push_back(cb_in1_bcast_row, onetile); + + cb_in1_transposed_read_ptr += bfloat16_one_row_in_face_bytes; + } + cb_pop_front(cb_in1_transposed, onetile); + + #endif } } diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_ssm_eltwise_mul/multi_core_ssm_eltwise_mul.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_ssm_eltwise_mul/multi_core_ssm_eltwise_mul.cpp index 06a3c16df87..86e18bff487 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_ssm_eltwise_mul/multi_core_ssm_eltwise_mul.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_ssm_eltwise_mul/multi_core_ssm_eltwise_mul.cpp @@ -171,7 +171,9 @@ operation::ProgramWithCallbacks multi_core_ssm_eltwise_mul(const Tensor &a, cons g1_numcores, g2_numcores, num_blocks_per_core_group_1, - num_blocks_per_core_group_2 + num_blocks_per_core_group_2, + bshape, + hidden_size ] ( Program& program, @@ -225,8 +227,16 @@ operation::ProgramWithCallbacks multi_core_ssm_eltwise_mul(const Tensor &a, cons all_reader_runtime_args[i][3] = num_blocks_written; all_writer_runtime_args[i][0] = dst_buffer->address(); - all_writer_runtime_args[i][1] = num_blocks_per_core * TILE_WIDTH; - all_writer_runtime_args[i][2] = num_blocks_written * TILE_WIDTH; + + // update writer's num_tiles based on input_b already repeat_interleaved or not + if (bshape[-1] == hidden_size) { + all_writer_runtime_args[i][1] = num_blocks_per_core * TILE_WIDTH; + all_writer_runtime_args[i][2] = num_blocks_written * TILE_WIDTH; + } + else { + all_writer_runtime_args[i][1] = num_blocks_per_core; + all_writer_runtime_args[i][2] = num_blocks_written; + } all_compute_runtime_args[i][0] = num_blocks_per_core;