-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Vectorize KV-Cache by using Vec4
#222
base: master
Are you sure you want to change the base?
Vectorize KV-Cache by using Vec4
#222
Conversation
Code Metrics Report=============================================================================== Language Files Lines Code Comments Blanks =============================================================================== TOML 1 75 63 2 10 ------------------------------------------------------------------------------- Rust 62 13276 11417 185 1674 |- Markdown 34 311 0 244 67 (Total) 13587 11417 429 1741 =============================================================================== Total 63 13351 11480 187 1684 =============================================================================== |
crates/ratchet-core/src/ops/cache.rs
Outdated
builder.register_storage("C", BindingMode::ReadWrite, Array::<P>::default()); | ||
builder.register_storage("S", BindingMode::ReadOnly, Array::<P>::default()); | ||
builder.register_storage("D", BindingMode::ReadWrite, Array::<P>::default()); | ||
builder.register_storage("C", BindingMode::ReadWrite, Array::<vec4<f32>>::default()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will only work for vec4<f32>
!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, just learnt that Vec4<T>
implements WgslPrimitive
for any T. Making it Array::<P>
itself, from what I understand.
|
||
let dim = metadata.dim; | ||
if (dst_index[dim] < metadata.cum0) { | ||
//Inside cache, just copy from cache to DST | ||
let src_offset = ndIndexToOffset(dst_index, metadata.cache_stride); | ||
let src_offset = ndIndexToOffset(dst_index, metadata.cache_stride) / 4u; | ||
D[dst_offset] = C[src_offset]; | ||
return; | ||
} | ||
|
||
if (dst_index[dim] < metadata.cum1) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What will happen to cum1
here if all lengths are / 4
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm 🤔 . I'm still trying to understand this one and what needs to be done about it.
I suppose I need to divide cum1
by 4 in write_metadata
too?
…or `KernelElement`
@officialcjunior Haven't forgotten about this, just working on the refactor 👍🏻 |
Currently, the KV-Cache operation is scalar, this PR attempts to vectorize it.
Fixes #210