-
Notifications
You must be signed in to change notification settings - Fork 90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support longer sequence lengths in ssm_prefix_scan
#9776
Conversation
ssm_prefix_scan
ssm_prefix_scan
6d4551f
to
1f6f645
Compare
fc80603
to
d62da2c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor comments
@@ -4,10 +4,28 @@ | |||
|
|||
#include "dataflow_api.h" | |||
|
|||
void fill_zeros(uint32_t cb_id) { | |||
constexpr uint32_t num_zeros_reads = 2048 / MEM_ZEROS_SIZE; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should probably make 2048 a constexpr
variable and comment that this is for bfloat16
only. Zeros CB for other data types will be different.
@@ -88,7 +88,12 @@ operation::ProgramWithCallbacks multi_core_ssm_prefix_scan( | |||
const uint32_t cb_zeros_id = tt::CB::c_intermed6; | |||
const auto cb_zeros = create_circular_buffer(cb_zeros_id, 1, intermediary_tile_size, intermediary_format); | |||
|
|||
std::vector<uint32_t> reader_compile_time_args = {cb_a_in_id, cb_bx_in_id}; | |||
const uint32_t cb_h_acc_id = tt::CB::c_intermed7; | |||
const uint32_t num_chunks_per_row = ceil(float(total_tiles_per_row) / 32.0f); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
div_up
might be cleaner for you without having to do these casts. Also use some variable for 32
const.
This change adds support for L > 32 in `ssm_eltwise_mul`. Logically this op can now handle any value of L but values of L > 128 will run out of L1 in `bfloat8` format. This change also fixes #9831.
d62da2c
to
d5da023
Compare
This PR adds support for L > 32 in
ssm_eltwise_mul
. Logically we can handle any value of L but values of L > 128 will run out of L1 inbfloat8
format.This change also addresses issue #9831.