Skip to content

Commit

Permalink
#0: Support longer sequence legnths in ssm_prefix_scan
Browse files Browse the repository at this point in the history
This change adds support for L > 32 in `ssm_eltwise_mul`. Logically this
op can now handle any value of L but values of L > 128 will run out of
L1 in `bfloat8` format.

This change also fixes #9831.
  • Loading branch information
esmalTT committed Jul 2, 2024
1 parent fa67366 commit 4a8f2da
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,15 @@ def run_ssm_prefix_scan(L: int, E: int, N: int, num_cores: int, dtype, device):
(
(32, 32, 32, 1),
(32, 64, 32, 1),
(64, 32, 32, 1),
(64, 64, 32, 1),
(32, 2560, 32, 32),
(32, 5120, 32, 40),
# (32, 5120, 32, 64) -> 8x8 grid not supported on CI
# (32, 5120, 32, 64), #-> 8x8 grid not supported on CI
# (64, 5120, 32, 64) #-> 8x8 grid not supported on CI
),
)
def test_ssm_reduce(L: int, E: int, N: int, num_cores: int, dtype, device):
def test_ssm_prefix_scan(L: int, E: int, N: int, num_cores: int, dtype, device):
run_ssm_prefix_scan(L, E, N, num_cores, dtype, device)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ constexpr uint32_t cb_out = get_compile_time_arg_val(8);

constexpr uint32_t cb_zeros = get_compile_time_arg_val(9);

constexpr uint32_t cb_h_acc = get_compile_time_arg_val(10);

// This function relies on untilizing NUM_TILES_IN_TILIZED_CHUNK tiles so we pad up to that amount
FORCE_INLINE void pack_block_rows_into_tiles(uint32_t cb_in, uint32_t cb_out, uint32_t num_tiles) {
unpack_reconfig_data_format_srca(cb_in);
Expand Down Expand Up @@ -126,11 +128,6 @@ FORCE_INLINE void copy(uint32_t cb_in, uint32_t cb_out) {
cb_push_back(cb_out, 1);
}

FORCE_INLINE void setup_cb_zeros() {
cb_reserve_back(cb_zeros, 1);
cb_push_back(cb_zeros, 1);
}

FORCE_INLINE void fill_tile_zeros(uint32_t cb_id) { copy(cb_zeros, cb_id); }

FORCE_INLINE void compute_ht(uint32_t cb_a, uint32_t cb_bx, uint32_t cb_out, uint32_t num_tiles) {
Expand All @@ -141,6 +138,8 @@ FORCE_INLINE void compute_ht(uint32_t cb_a, uint32_t cb_bx, uint32_t cb_out, uin
copy(cb_h, cb_out); // TODO: Get rid of this extraneous copy
cb_pop_front(cb_h, 1);
}
copy(cb_h_prev, cb_h_acc); // Store the last row of this tile for the next iteration

// Make sure to remove the last hidden state
cb_wait_front(cb_h_prev, 1);
cb_pop_front(cb_h_prev, 1);
Expand All @@ -158,12 +157,17 @@ void MAIN {
untilize_init(cb_a_in);
binary_op_init_common(cb_a_in, cb_bx_in);

setup_cb_zeros();
// Fill initial hidden states with zeros
for (uint32_t tilized_chunk_idx = 0; tilized_chunk_idx < num_tilize_per_row; tilized_chunk_idx++) {
fill_tile_zeros(cb_h_acc);
}

// For each row of tiles we want to tilize chunks of 32 tiles to pack the rows into tiles
for (uint32_t row_idx = 0; row_idx < total_tiles_per_col; row_idx++) {
for (uint32_t tilized_chunk_idx = 0; tilized_chunk_idx < num_tilize_per_row; tilized_chunk_idx++) {
fill_tile_zeros(cb_h_prev);
// Load the last row from the hidden state above this row
copy(cb_h_acc, cb_h_prev);
cb_pop_front(cb_h_acc, 1);

// If we don't have a full chunk (NUM_TILES_IN_TILIZED_CHUNK tiles) we should figure out how many tiles we
// have left. This only runs 2-3 tiles per shard so no need to unroll.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,28 @@

#include "dataflow_api.h"

void fill_zeros(uint32_t cb_id) {
constexpr uint32_t num_zeros_reads = 2048 / MEM_ZEROS_SIZE;
uint64_t zeros_noc_addr = get_noc_addr(MEM_ZEROS_BASE);
uint32_t write_addr = get_write_ptr(cb_id);
volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(write_addr);

// Fill tile with zeros
for (uint32_t i = 0; i < num_zeros_reads; ++i) {
noc_async_read(zeros_noc_addr, write_addr, MEM_ZEROS_SIZE);
write_addr += MEM_ZEROS_SIZE;
}
noc_async_read_barrier();
}

void kernel_main() {
uint32_t num_tiles_per_core = get_arg_val<uint32_t>(0);
constexpr uint32_t cb_a_in = get_compile_time_arg_val(0);
constexpr uint32_t cb_bx_in = get_compile_time_arg_val(1);
constexpr uint32_t cb_zeros = get_compile_time_arg_val(2);

fill_zeros(cb_zeros);
cb_push_back(cb_zeros, 1);

cb_push_back(cb_a_in, num_tiles_per_core);
cb_push_back(cb_bx_in, num_tiles_per_core);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ operation::ProgramWithCallbacks multi_core_ssm_prefix_scan(
const uint32_t cb_zeros_id = tt::CB::c_intermed6;
const auto cb_zeros = create_circular_buffer(cb_zeros_id, 1, intermediary_tile_size, intermediary_format);

std::vector<uint32_t> reader_compile_time_args = {cb_a_in_id, cb_bx_in_id};
const uint32_t cb_h_acc_id = tt::CB::c_intermed7;
const uint32_t num_chunks_per_row = ceil(float(total_tiles_per_row) / 32.0f);
const auto cb_h_acc =
create_circular_buffer(cb_h_acc_id, num_chunks_per_row, intermediary_tile_size, intermediary_format);

std::vector<uint32_t> reader_compile_time_args = {cb_a_in_id, cb_bx_in_id, cb_zeros_id};
std::vector<uint32_t> writer_compile_time_args = {cb_out_id};
std::vector<uint32_t> compute_compile_time_args = {
cb_a_in_id,
Expand All @@ -100,7 +105,8 @@ operation::ProgramWithCallbacks multi_core_ssm_prefix_scan(
cb_h_id,
cb_tilize_out_id,
cb_out_id,
cb_zeros_id};
cb_zeros_id,
cb_h_acc_id};

auto reader_kernel_id = tt_metal::CreateKernel(
program,
Expand Down

0 comments on commit 4a8f2da

Please sign in to comment.