Support longer sequence lengths in ssm_prefix_scan #9776

Merged: esmalTT merged 1 commit into main from esmal/prefix-scan-support-larger-seq on Jul 2, 2024


esmalTT
Contributor

@esmalTT esmalTT commented Jun 27, 2024

This PR adds support for sequence lengths L > 32 in ssm_eltwise_mul. Logically the op can handle any value of L, but values of L > 128 will run out of L1 memory in bfloat8 format.

This change also addresses issue #9831.

@esmalTT esmalTT self-assigned this Jun 27, 2024
@esmalTT esmalTT changed the title Support longer sequence legnths in ssm_prefix_scan Support longer sequence lengths in ssm_prefix_scan Jun 27, 2024
@esmalTT esmalTT added the mamba label Jun 27, 2024
@esmalTT esmalTT force-pushed the esmal/prefix-scan-support-larger-seq branch 4 times, most recently from 6d4551f to 1f6f645 Compare June 30, 2024 12:24
@esmalTT esmalTT requested review from kpaigwar and TT-BrianLiu June 30, 2024 12:54
@esmalTT esmalTT marked this pull request as ready for review June 30, 2024 12:54
@esmalTT esmalTT force-pushed the esmal/prefix-scan-support-larger-seq branch 3 times, most recently from fc80603 to d62da2c Compare June 30, 2024 14:32
Contributor

@TT-BrianLiu TT-BrianLiu left a comment

Minor comments

@@ -4,10 +4,28 @@

#include "dataflow_api.h"

void fill_zeros(uint32_t cb_id) {
    constexpr uint32_t num_zeros_reads = 2048 / MEM_ZEROS_SIZE;
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably make 2048 a constexpr variable and comment that this is for bfloat16 only. Zeros CB for other data types will be different.
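A minimal sketch of what this suggestion could look like, assuming 2048 is the byte size of one 32x32 bfloat16 tile (32 * 32 elements at 2 bytes each). The names `bfloat16_tile_size_bytes`, `mem_zeros_size_bytes` and `compute_num_zeros_reads` are hypothetical, and the value 512 is only a stand-in for `MEM_ZEROS_SIZE`, which in the kernel comes from the device headers.

```cpp
#include <cstdint>

// Assumption: stand-in for MEM_ZEROS_SIZE so the sketch compiles on its own.
// The real value is defined by the device headers, not here.
constexpr uint32_t mem_zeros_size_bytes = 512;

// Bytes in one 32x32 bfloat16 tile (2 bytes per element).
// This is valid for bfloat16 only; a zeros CB in another data format
// has a different tile size and needs a different constant.
constexpr uint32_t bfloat16_tile_size_bytes = 32 * 32 * 2;

// Number of reads from the zeros region needed to fill one tile.
constexpr uint32_t compute_num_zeros_reads(uint32_t tile_size_bytes, uint32_t zeros_size_bytes) {
    return tile_size_bytes / zeros_size_bytes;
}

static_assert(bfloat16_tile_size_bytes == 2048, "matches the literal used in fill_zeros");
```

With the constant named, the line in `fill_zeros` would read `bfloat16_tile_size_bytes / MEM_ZEROS_SIZE`, which makes the data-format dependency visible at the call site.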

@@ -88,7 +88,12 @@ operation::ProgramWithCallbacks multi_core_ssm_prefix_scan(
const uint32_t cb_zeros_id = tt::CB::c_intermed6;
const auto cb_zeros = create_circular_buffer(cb_zeros_id, 1, intermediary_tile_size, intermediary_format);

std::vector<uint32_t> reader_compile_time_args = {cb_a_in_id, cb_bx_in_id};
const uint32_t cb_h_acc_id = tt::CB::c_intermed7;
const uint32_t num_chunks_per_row = ceil(float(total_tiles_per_row) / 32.0f);
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

div_up might be cleaner for you without having to do these casts. Also use some variable for 32 const.
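A rough sketch of the integer-only version, for illustration. The `div_up` below is a local stand-in written for this example, not necessarily the signature of the helper in the codebase, and `tiles_per_chunk` is a hypothetical name for the literal 32.

```cpp
#include <cstdint>

// Hypothetical named constant replacing the bare 32 in the original line.
constexpr uint32_t tiles_per_chunk = 32;

// Integer ceiling division. Written as quotient plus a remainder check
// so that it cannot overflow for large values of a.
constexpr uint32_t div_up(uint32_t a, uint32_t b) {
    return a / b + (a % b != 0 ? 1 : 0);
}

// Same result as ceil(float(total_tiles_per_row) / 32.0f), without the casts.
constexpr uint32_t compute_num_chunks_per_row(uint32_t total_tiles_per_row) {
    return div_up(total_tiles_per_row, tiles_per_chunk);
}

static_assert(compute_num_chunks_per_row(32) == 1, "an exact multiple needs no extra chunk");
static_assert(compute_num_chunks_per_row(33) == 2, "a partial chunk rounds up");
```

Staying in integer arithmetic also avoids the precision loss that `float` introduces once tile counts exceed 2^24, although that is unlikely to matter at the sizes discussed here.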

This change adds support for L > 32 in `ssm_eltwise_mul`. Logically this
op can now handle any value of L but values of L > 128 will run out of
L1 in `bfloat8` format.

This change also fixes #9831.
@esmalTT esmalTT force-pushed the esmal/prefix-scan-support-larger-seq branch from d62da2c to d5da023 Compare July 2, 2024 17:30
@esmalTT esmalTT merged commit 4a8f2da into main Jul 2, 2024
5 checks passed
@esmalTT esmalTT deleted the esmal/prefix-scan-support-larger-seq branch July 2, 2024 17:33