Skip to content

Commit

Permalink
#12314: Modified ROPE op to support long contexts (#12315)
Browse files Browse the repository at this point in the history
* #0: Modified ROPE op to reduce buffering load for super long context lengths. Smaller inputs use the old implementation. Larger inputs use a new impl which reloads sin/cos for each row, and only double buffers one row for input, cos, and sin

(cherry picked from commit 614c021)

* #0: update rotary embedding test to include long seqlen
  • Loading branch information
cglagovichTT authored Sep 9, 2024
1 parent bb1eda9 commit fba4a2d
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def run_test_rotary_embedding_llama(
(1, 4096),
(1, 8192),
(1, 16384),
(1, 128 * 1024),
),
ids=(
"prefill_32",
Expand All @@ -200,6 +201,7 @@ def run_test_rotary_embedding_llama(
"prefill_4k",
"prefill_8k",
"prefill_16k",
"prefill_128k",
),
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -230,6 +232,9 @@ def test_rotary_embedding_llama(
if compute_grid_size.x < 8 or compute_grid_size.y < 8:
pytest.skip(f"Requires grid size of at least {(8, 8)} to run")

if seq_len == 128 * 1024 and (n_heads, n_kv_heads, head_dim) != (8, 1, 128):
pytest.skip("Only testing for (8, 1, 128) due to time constraints")

max_seq_len = max(4096, seq_len)

run_test_rotary_embedding_llama(devices, batch, seq_len, pcc, n_heads, n_kv_heads, head_dim, max_seq_len, datatype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,20 @@ void MAIN {
uint32_t in1_index = 0;
uint32_t interm_index = 0;

#if RELOAD_IMPL == 0
cb_wait_front(sin_cb, sin_cos_cb_size_in_tiles);
cb_wait_front(cos_cb, sin_cos_cb_size_in_tiles);
#endif

uint32_t sin_cos_row_cnt = 0;

for (uint32_t i = 0; i < num_rows_per_core; ++i) {
// input cb wait and reserve
cb_wait_front(in_cb, Wt);
#if RELOAD_IMPL == 1
cb_wait_front(sin_cb, Wt);
cb_wait_front(cos_cb, Wt);
#endif

cb_reserve_back(rotated_in_interm_cb, Wt);
cb_reserve_back(sin_interm_cb, Wt);
Expand Down Expand Up @@ -85,6 +91,10 @@ void MAIN {
REL();
cb_push_back(cos_interm_cb, Wt);
cb_pop_front(in_cb, Wt); // Done with input
#if RELOAD_IMPL == 1
cb_pop_front(sin_cb, Wt);
cb_pop_front(cos_cb, Wt);
#endif

cb_wait_front(cos_interm_cb, Wt);
cb_wait_front(sin_interm_cb, Wt);
Expand All @@ -100,15 +110,21 @@ void MAIN {
cb_pop_front(cos_interm_cb, Wt);
cb_pop_front(sin_interm_cb, Wt);

#if RELOAD_IMPL == 0
// no-reload needs to increment this counter
// Used a sin/cos row
sin_cos_row_cnt++;
// Loop back to the beginning of the sin/cos rows
if (sin_cos_row_cnt == num_sin_cos_rows_per_core) {
sin_cos_row_cnt = 0;
}
#endif
}

#if RELOAD_IMPL == 0
cb_pop_front(sin_cb, sin_cos_cb_size_in_tiles);
cb_pop_front(cos_cb, sin_cos_cb_size_in_tiles);
#endif


// Done with the transformation matrix, so remove from CB
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ void kernel_main() {
Ht = 4
Wt = 4
*/
#if RELOAD_IMPL == 0
cb_reserve_back(sin_cb_id, sin_cos_cb_size_in_tiles);
cb_reserve_back(cos_cb_id, sin_cos_cb_size_in_tiles);
uint32_t sin_l1_write_addr = get_write_ptr(sin_cb_id);
uint32_t cos_l1_write_addr = get_write_ptr(cos_cb_id);

#endif

// To make sure the sin/cos row are read only once
uint32_t sin_cos_row_cnt = 0;
Expand All @@ -99,6 +100,12 @@ void kernel_main() {
uint32_t input_row_cnt = 0;

for (uint32_t i = 0; i < num_rows_per_core; ++i) {
#if RELOAD_IMPL == 1
cb_reserve_back(sin_cb_id, Wt);
cb_reserve_back(cos_cb_id, Wt);
uint32_t sin_l1_write_addr = get_write_ptr(sin_cb_id);
uint32_t cos_l1_write_addr = get_write_ptr(cos_cb_id);
#endif
cb_reserve_back(input_cb_id, Wt);
uint32_t input_l1_write_addr = get_write_ptr(input_cb_id);
for (uint32_t j = 0; j < Wt; ++j) {
Expand All @@ -124,6 +131,10 @@ void kernel_main() {
noc_async_read_barrier();
cb_push_back(input_cb_id, Wt);
input_row_cnt++;
#if RELOAD_IMPL == 1
cb_push_back(sin_cb_id, Wt);
cb_push_back(cos_cb_id, Wt);
#else

if (!done_sin_cos) {
cb_push_back(sin_cb_id, Wt);
Expand All @@ -136,9 +147,11 @@ void kernel_main() {
done_sin_cos = true;
}
}
#endif
// Update input_curr_idx to stride the correct amount to the next row
if (input_row_cnt % num_sin_cos_rows_per_core == 0) {
input_curr_idx += (Ht - num_sin_cos_rows_per_core) * Wt;
cos_sin_curr_idx = cos_sin_start_idx; // For reload case, reset cos_sin_curr_idx
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,26 @@ operation::ProgramWithCallbacks rotary_embedding_llama_multi_core(

num_rows_per_core = num_rows_per_core_group_1; // Will always find equal split
uint32_t num_sin_cos_rows_per_core = std::max((uint32_t) 1, (uint32_t) (Ht / num_cores));
uint32_t num_cos_sin_tiles = 2 * Wt * num_sin_cos_rows_per_core;

uint32_t input_cb_num_tiles = num_sin_cos_rows_per_core * num_input_tiles;

const bool use_reload_impl = num_rows_per_core > 8;
if (use_reload_impl) {
// Do reload implementation of kernel to reduce buffer sizes
// Only size CBs to double buffer Wt tiles for all inputs
input_cb_num_tiles = num_input_tiles;
num_cos_sin_tiles = num_input_tiles;
}


uint32_t input_cb_index = CB::c_in0;
tt_metal::CircularBufferConfig cb_input_config =
tt_metal::CircularBufferConfig(
num_sin_cos_rows_per_core * num_input_tiles * input_single_tile_size, {{input_cb_index, input_cb_data_format}})
input_cb_num_tiles * input_single_tile_size, {{input_cb_index, input_cb_data_format}})
.set_page_size(input_cb_index, input_single_tile_size);
auto cb_input = tt_metal::CreateCircularBuffer(program, all_cores, cb_input_config);

uint32_t num_cos_sin_tiles = 2 * Wt * num_sin_cos_rows_per_core;

uint32_t cos_cb_index = CB::c_in1;
tt_metal::CircularBufferConfig cb_cos_config =
tt_metal::CircularBufferConfig(num_cos_sin_tiles * cos_single_tile_size, {{cos_cb_index, cos_cb_data_format}})
Expand All @@ -116,10 +126,10 @@ operation::ProgramWithCallbacks rotary_embedding_llama_multi_core(
auto cb_sin = tt_metal::CreateCircularBuffer(program, all_cores, cb_sin_config);

uint32_t trans_mat_cb_index = CB::c_in3;
// We only take one tile of trans_mat, doubled buffered
uint32_t num_trans_mat_tiles = 2;
// We only take one tile of trans_mat
uint32_t num_trans_mat_tiles = 1;
tt_metal::CircularBufferConfig cb_trans_mat_config =
tt_metal::CircularBufferConfig(num_input_tiles * trans_mat_single_tile_size, {{trans_mat_cb_index, trans_mat_cb_data_format}})
tt_metal::CircularBufferConfig(num_trans_mat_tiles * trans_mat_single_tile_size, {{trans_mat_cb_index, trans_mat_cb_data_format}})
.set_page_size(trans_mat_cb_index, trans_mat_single_tile_size);
auto cb_trans_mat = tt_metal::CreateCircularBuffer(program, all_cores, cb_trans_mat_config);

Expand Down Expand Up @@ -153,6 +163,7 @@ operation::ProgramWithCallbacks rotary_embedding_llama_multi_core(
auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);

std::map<string, string> kernel_defines;
kernel_defines["RELOAD_IMPL"] = use_reload_impl ? "1" : "0";

auto src_buffer = input.buffer();
auto cos_buffer = cos.buffer();
Expand Down

0 comments on commit fba4a2d

Please sign in to comment.