Skip to content

Commit

Permalink
#8322: extend ssm_eltwise_mul to handle third case
Browse files Browse the repository at this point in the history
  • Loading branch information
kpaigwar committed May 10, 2024
1 parent 3c9c641 commit 892399c
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,14 @@ def run_ssm_eltwise_mul_test(in0_W, in1_W, dtype, in0_mem_config, in1_mem_config
out = tt2torch_tensor(tt_out)

# Compute reference on pytorch
if in0_W == latent_size:
if in0_W == latent_size and in1_W == hidden_size:
ref_out = B.repeat(1, 1, 1, hidden_size) * X.repeat_interleave(latent_size, dim=-1)
elif in0_W == latent_size * hidden_size:
elif in0_W == latent_size * hidden_size and in1_W == hidden_size:
ref_out = B * X.repeat_interleave(latent_size, dim=-1)
elif in0_W == latent_size and in1_W == latent_size * hidden_size:
ref_out = B.repeat(1, 1, 1, hidden_size) * X
else:
raise Exception("Input shapes invalid, use eltwise_mul for same input shapes,", in0_W, in1_W)

passing_pcc, output_pcc = comp_pcc(out, ref_out, 0.9999)
logger.debug(f"Out passing={passing_pcc}")
Expand Down Expand Up @@ -78,7 +82,7 @@ def run_ssm_eltwise_mul_test(in0_W, in1_W, dtype, in0_mem_config, in1_mem_config
(
(32, 5120),
(32 * 5120, 5120),
# (32, 32*5120), # TODO: Enable this test case where in1 is already expanded
(32, 32 * 5120),
),
)
def test_ssm_eltwise_mul(in0_W, in1_W, dtype, in0_mem_config, in1_mem_config, out_mem_config, device):
Expand All @@ -94,8 +98,10 @@ def test_ssm_eltwise_mul_with_program_cache(device, use_program_cache):
run_ssm_eltwise_mul_test(in0_W, in1_W, dtype, mem_config, mem_config, mem_config, device)
in0_W, in1_W = 32 * 5120, 5120
run_ssm_eltwise_mul_test(in0_W, in1_W, dtype, mem_config, mem_config, mem_config, device)
in0_W, in1_W = 32, 32 * 5120
run_ssm_eltwise_mul_test(in0_W, in1_W, dtype, mem_config, mem_config, mem_config, device)
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttl.tensor.Tensor(py_dummy_tensor, dtype).to(ttl.tensor.Layout.TILE).to(device, mem_config)

assert device.num_program_cache_entries() == 2
assert device.num_program_cache_entries() == 3
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,33 @@ void MAIN {
constexpr uint32_t num_rows_in_one_tile = 32;


binary_op_init_common(cb_in0_transposed, cb_in1_bcast_row); // TODO: Is there a specific one for bcast mul?
#ifdef REPEAT_INTERLEAVE_IN1
binary_op_init_common(cb_in0_transposed, cb_in1_bcast_row); // TODO: Is there a specific one for bcast mul?
#else
binary_op_init_common(cb_id_in0, cb_id_in1);
#endif

#ifdef REPEAT_IN0
// Transpose in0
cb_wait_front(cb_id_in0, onetile);
tile_regs_acquire();
tile_regs_wait();
// No need to transpose in0 if in1 is not repeat_interleaved
#ifdef REPEAT_INTERLEAVE_IN1
tile_regs_acquire();
tile_regs_wait();

transpose_wh_init_short(cb_id_in0);
transpose_wh_tile(cb_id_in0, 0, 0);
transpose_wh_init_short(cb_id_in0);
transpose_wh_tile(cb_id_in0, 0, 0);

cb_reserve_back(cb_in0_transposed, onetile);
pack_tile(0, cb_in0_transposed);
cb_reserve_back(cb_in0_transposed, onetile);
pack_tile(0, cb_in0_transposed);

tile_regs_commit();
tile_regs_release();
cb_push_back(cb_in0_transposed, onetile);
cb_pop_front(cb_id_in0, onetile);
tile_regs_commit();
tile_regs_release();
cb_push_back(cb_in0_transposed, onetile);
cb_pop_front(cb_id_in0, onetile);

cb_wait_front(cb_in0_transposed, onetile);
cb_wait_front(cb_in0_transposed, onetile);
#endif
#endif

for (uint32_t in1_block = 0; in1_block < in1_num_blocks; in1_block++) {
Expand All @@ -51,97 +58,115 @@ void MAIN {
tile_regs_acquire();
tile_regs_wait();

transpose_wh_init_short(cb_id_in1);
transpose_wh_tile(cb_id_in1, 0, 0);
// If input b (in1) is not repeat_interleaved, there is no need to transpose and bcast rows
#ifndef REPEAT_INTERLEAVE_IN1
mul_tiles_init(cb_id_in0, cb_id_in1);
mul_tiles(cb_id_in0, cb_id_in1, 0, 0, 0);

cb_reserve_back(cb_in1_transposed, onetile);
pack_tile(0, cb_in1_transposed);
cb_reserve_back(cb_id_out, onetile);
pack_tile(0, cb_id_out);

tile_regs_commit();
tile_regs_release();
cb_push_back(cb_in1_transposed, onetile);
cb_pop_front(cb_id_in1, onetile);
tile_regs_commit();
tile_regs_release();
cb_push_back(cb_id_out, onetile);
cb_pop_front(cb_id_in1, onetile);
#else
transpose_wh_init_short(cb_id_in1);
transpose_wh_tile(cb_id_in1, 0, 0);

// Receive in1 as single rows to bcast mul with in0
for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; tile_row_id++) {
#ifndef REPEAT_IN0
// Transpose in0
cb_wait_front(cb_id_in0, onetile);
tile_regs_acquire();
tile_regs_wait();
cb_reserve_back(cb_in1_transposed, onetile);
pack_tile(0, cb_in1_transposed);

transpose_wh_init_short(cb_id_in0);
transpose_wh_tile(cb_id_in0, 0, 0);
tile_regs_commit();
tile_regs_release();
cb_push_back(cb_in1_transposed, onetile);
cb_pop_front(cb_id_in1, onetile);

cb_reserve_back(cb_in0_transposed, onetile);
pack_tile(0, cb_in0_transposed);
// Receive in1 as single rows to bcast mul with in0
for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_one_tile; tile_row_id++) {
#ifndef REPEAT_IN0
// Transpose in0
cb_wait_front(cb_id_in0, onetile);
tile_regs_acquire();
tile_regs_wait();

tile_regs_commit();
tile_regs_release();
cb_push_back(cb_in0_transposed, onetile);
cb_pop_front(cb_id_in0, onetile);
transpose_wh_init_short(cb_id_in0);
transpose_wh_tile(cb_id_in0, 0, 0);

cb_wait_front(cb_in0_transposed, onetile);
#endif
cb_reserve_back(cb_in0_transposed, onetile);
pack_tile(0, cb_in0_transposed);

cb_wait_front(cb_in1_bcast_row, onetile);
tile_regs_acquire();
tile_regs_wait();
tile_regs_commit();
tile_regs_release();
cb_push_back(cb_in0_transposed, onetile);
cb_pop_front(cb_id_in0, onetile);

mul_bcast_rows_init_short(cb_in0_transposed, cb_in1_bcast_row);
mul_tiles_bcast_rows(cb_in0_transposed, cb_in1_bcast_row, 0, 0, 0);
cb_wait_front(cb_in0_transposed, onetile);
#endif

cb_reserve_back(cb_out_transposed, onetile);
pack_tile(0, cb_out_transposed);
cb_wait_front(cb_in1_bcast_row, onetile);
tile_regs_acquire();
tile_regs_wait();

tile_regs_commit();
tile_regs_release();
cb_push_back(cb_out_transposed, onetile);
#ifndef REPEAT_IN0
cb_pop_front(cb_in0_transposed, onetile);
#endif
cb_pop_front(cb_in1_bcast_row, onetile);

// Transpose output back
cb_wait_front(cb_out_transposed, onetile);
tile_regs_acquire();
tile_regs_wait();
mul_bcast_rows_init_short(cb_in0_transposed, cb_in1_bcast_row);
mul_tiles_bcast_rows(cb_in0_transposed, cb_in1_bcast_row, 0, 0, 0);

transpose_wh_init_short(cb_out_transposed);
transpose_wh_tile(cb_out_transposed, 0, 0);
cb_reserve_back(cb_out_transposed, onetile);
pack_tile(0, cb_out_transposed);

cb_reserve_back(cb_id_out, onetile);
pack_tile(0, cb_id_out);
tile_regs_commit();
tile_regs_release();
cb_push_back(cb_out_transposed, onetile);
#ifndef REPEAT_IN0
cb_pop_front(cb_in0_transposed, onetile);
#endif
cb_pop_front(cb_in1_bcast_row, onetile);

// Transpose output back
cb_wait_front(cb_out_transposed, onetile);
tile_regs_acquire();
tile_regs_wait();

tile_regs_commit();
tile_regs_release();
cb_push_back(cb_id_out, onetile);
cb_pop_front(cb_out_transposed, onetile);
transpose_wh_init_short(cb_out_transposed);
transpose_wh_tile(cb_out_transposed, 0, 0);

/* TODO: Transpose directly on tiles in DST; is something like this possible?
cb_reserve_back(cb_id_out, onetile);
cb_reserve_back(cb_id_out, onetile);
pack_tile(0, cb_id_out);

tile_regs_acquire();
tile_regs_wait();
mul_bcast_rows_init_short(cb_in0_transposed, cb_in1_bcast_row);
mul_tiles_bcast_rows(cb_in0_transposed, cb_in1_bcast_row, 0, 0, 0);
tile_regs_commit();
tile_regs_release();
cb_push_back(cb_id_out, onetile);
cb_pop_front(cb_out_transposed, onetile);

MATH(( llk_math_eltwise_unary_datacopy_init<A2D, BroadcastType::NONE, DST_ACCUM_MODE>(true, true, cb_id_out) ));
MATH(( llk_math_eltwise_unary_datacopy<A2D, BroadcastType::NONE, DST_ACCUM_MODE>(0) ));
/* TODO: Transpose directly on tiles in DST; is something like this possible?
cb_reserve_back(cb_id_out, onetile);
pack_tile(0, cb_id_out);
tile_regs_acquire();
tile_regs_wait();
mul_bcast_rows_init_short(cb_in0_transposed, cb_in1_bcast_row);
mul_tiles_bcast_rows(cb_in0_transposed, cb_in1_bcast_row, 0, 0, 0);
tile_regs_commit();
tile_regs_release();
cb_push_back(cb_id_out, onetile);
cb_pop_front(cb_in1_bcast_row, onetile);
*/
}
MATH(( llk_math_eltwise_unary_datacopy_init<A2D, BroadcastType::NONE, DST_ACCUM_MODE>(true, true, cb_id_out) ));
MATH(( llk_math_eltwise_unary_datacopy<A2D, BroadcastType::NONE, DST_ACCUM_MODE>(0) ));
pack_tile(0, cb_id_out);
tile_regs_commit();
tile_regs_release();
cb_push_back(cb_id_out, onetile);
cb_pop_front(cb_in1_bcast_row, onetile);
*/
}

cb_pop_front(cb_in1_transposed, onetile);
cb_pop_front(cb_in1_transposed, onetile);
#endif
}
#ifdef REPEAT_IN0
cb_pop_front(cb_in0_transposed, onetile);
#ifdef REPEAT_INTERLEAVE_IN1
cb_pop_front(cb_in0_transposed, onetile);
#else
cb_pop_front(cb_id_in0, onetile);
#endif
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,54 +61,57 @@ void kernel_main() {
noc_async_read_barrier();
cb_push_back(cb_id_in1, onetile);

cb_wait_front(cb_in1_transposed, onetile);
uint64_t cb_in1_transposed_read_ptr = get_noc_addr(get_read_ptr(cb_in1_transposed));

// Manually unroll iterating across the tile to eliminate unnecessary conditional checking
// First + second face
for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_face; tile_row_id++) {
cb_reserve_back(cb_in1_bcast_row, onetile);
uint32_t cb_in1_bcast_row_write_ptr = get_write_ptr(cb_in1_bcast_row);

#ifndef REPEAT_IN0
cb_reserve_back(cb_id_in0, onetile);
l1_write_addr_in0 = get_write_ptr(cb_id_in0);
noc_async_read_tile(i * in0_blocks_per_in1_block + tile_row_id, s0, l1_write_addr_in0);
#endif
noc_async_read(cb_in1_transposed_read_ptr, cb_in1_bcast_row_write_ptr, bfloat16_one_row_in_face_bytes);
noc_async_read(cb_in1_transposed_read_ptr + bfloat16_one_face_bytes, cb_in1_bcast_row_write_ptr + bfloat16_one_face_bytes, bfloat16_one_row_in_face_bytes);
noc_async_read_barrier();

#ifndef REPEAT_IN0
cb_push_back(cb_id_in0, onetile);
#endif
cb_push_back(cb_in1_bcast_row, onetile);

cb_in1_transposed_read_ptr += bfloat16_one_row_in_face_bytes;
}

cb_in1_transposed_read_ptr += bfloat16_one_face_bytes;
// Third + fourth face
for (uint32_t tile_row_id = num_rows_in_face; tile_row_id < 2*num_rows_in_face; tile_row_id++) {
cb_reserve_back(cb_in1_bcast_row, onetile);
uint32_t cb_in1_bcast_row_write_ptr = get_write_ptr(cb_in1_bcast_row);

#ifndef REPEAT_IN0
cb_reserve_back(cb_id_in0, onetile);
l1_write_addr_in0 = get_write_ptr(cb_id_in0);
noc_async_read_tile(i * in0_blocks_per_in1_block + tile_row_id, s0, l1_write_addr_in0);
#endif
noc_async_read(cb_in1_transposed_read_ptr, cb_in1_bcast_row_write_ptr, bfloat16_one_row_in_face_bytes);
noc_async_read(cb_in1_transposed_read_ptr + bfloat16_one_face_bytes, cb_in1_bcast_row_write_ptr + bfloat16_one_face_bytes, bfloat16_one_row_in_face_bytes);
noc_async_read_barrier();

#ifndef REPEAT_IN0
cb_push_back(cb_id_in0, onetile);
#endif
cb_push_back(cb_in1_bcast_row, onetile);

cb_in1_transposed_read_ptr += bfloat16_one_row_in_face_bytes;
}
cb_pop_front(cb_in1_transposed, onetile);
#ifdef REPEAT_INTERLEAVE_IN1
cb_wait_front(cb_in1_transposed, onetile);
uint64_t cb_in1_transposed_read_ptr = get_noc_addr(get_read_ptr(cb_in1_transposed));

// Manually unroll iterating across the tile to eliminate unnecessary conditional checking
// First + second face
for (uint32_t tile_row_id = 0; tile_row_id < num_rows_in_face; tile_row_id++) {
cb_reserve_back(cb_in1_bcast_row, onetile);
uint32_t cb_in1_bcast_row_write_ptr = get_write_ptr(cb_in1_bcast_row);

#ifndef REPEAT_IN0
cb_reserve_back(cb_id_in0, onetile);
l1_write_addr_in0 = get_write_ptr(cb_id_in0);
noc_async_read_tile(i * in0_blocks_per_in1_block + tile_row_id, s0, l1_write_addr_in0);
#endif
noc_async_read(cb_in1_transposed_read_ptr, cb_in1_bcast_row_write_ptr, bfloat16_one_row_in_face_bytes);
noc_async_read(cb_in1_transposed_read_ptr + bfloat16_one_face_bytes, cb_in1_bcast_row_write_ptr + bfloat16_one_face_bytes, bfloat16_one_row_in_face_bytes);
noc_async_read_barrier();

#ifndef REPEAT_IN0
cb_push_back(cb_id_in0, onetile);
#endif
cb_push_back(cb_in1_bcast_row, onetile);

cb_in1_transposed_read_ptr += bfloat16_one_row_in_face_bytes;
}

cb_in1_transposed_read_ptr += bfloat16_one_face_bytes;
// Third + fourth face
for (uint32_t tile_row_id = num_rows_in_face; tile_row_id < 2*num_rows_in_face; tile_row_id++) {
cb_reserve_back(cb_in1_bcast_row, onetile);
uint32_t cb_in1_bcast_row_write_ptr = get_write_ptr(cb_in1_bcast_row);

#ifndef REPEAT_IN0
cb_reserve_back(cb_id_in0, onetile);
l1_write_addr_in0 = get_write_ptr(cb_id_in0);
noc_async_read_tile(i * in0_blocks_per_in1_block + tile_row_id, s0, l1_write_addr_in0);
#endif
noc_async_read(cb_in1_transposed_read_ptr, cb_in1_bcast_row_write_ptr, bfloat16_one_row_in_face_bytes);
noc_async_read(cb_in1_transposed_read_ptr + bfloat16_one_face_bytes, cb_in1_bcast_row_write_ptr + bfloat16_one_face_bytes, bfloat16_one_row_in_face_bytes);
noc_async_read_barrier();

#ifndef REPEAT_IN0
cb_push_back(cb_id_in0, onetile);
#endif
cb_push_back(cb_in1_bcast_row, onetile);

cb_in1_transposed_read_ptr += bfloat16_one_row_in_face_bytes;
}
cb_pop_front(cb_in1_transposed, onetile);

#endif
}
}
Loading

0 comments on commit 892399c

Please sign in to comment.