-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Vectorize KV-Cache by using Vec4
#222
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,9 +35,9 @@ impl Cache { | |
builder: &mut WgslKernelBuilder, | ||
_: bool, | ||
) -> Result<(), OperationError> { | ||
builder.register_storage("C", BindingMode::ReadWrite, Array::<P>::default()); | ||
builder.register_storage("S", BindingMode::ReadOnly, Array::<P>::default()); | ||
builder.register_storage("D", BindingMode::ReadWrite, Array::<P>::default()); | ||
builder.register_storage("C", BindingMode::ReadWrite, Array::<vec4<f32>>::default()); | ||
builder.register_storage("S", BindingMode::ReadOnly, Array::<vec4<f32>>::default()); | ||
builder.register_storage("D", BindingMode::ReadWrite, Array::<vec4<f32>>::default()); | ||
|
||
builder.register_uniform(); | ||
Ok(()) | ||
|
@@ -68,29 +68,29 @@ impl Cache { | |
kernel_builder.write_index_to_offset(); | ||
|
||
kernel_builder.write_main(wgsl! { | ||
//Dispatch 1 thread per output element | ||
//Dispatch 1 thread per output element (vec4<f32>) | ||
//dst_offset is index into the output buffer (1D) | ||
let x_offset = workgroup_id.x * 64u; | ||
let dst_offset = (workgroup_id.y * num_workgroups.x * 64u) + x_offset + local_invocation_index; | ||
if (dst_offset >= metadata.dst_numel) { | ||
if (dst_offset >= metadata.dst_numel / 4u) { | ||
return; | ||
} | ||
//Convert 1D offset into 4D index | ||
var dst_index = offsetToNdIndex(dst_offset, metadata.dst_stride); | ||
// Convert 1D offset into 4D index | ||
var dst_index = offsetToNdIndex(dst_offset * 4u, metadata.dst_stride); | ||
|
||
let dim = metadata.dim; | ||
if (dst_index[dim] < metadata.cum0) { | ||
//Inside cache, just copy from cache to DST | ||
let src_offset = ndIndexToOffset(dst_index, metadata.cache_stride); | ||
let src_offset = ndIndexToOffset(dst_index, metadata.cache_stride) / 4u; | ||
D[dst_offset] = C[src_offset]; | ||
return; | ||
} | ||
|
||
if (dst_index[dim] < metadata.cum1) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What will happen to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm 🤔 . I'm still trying to understand this one and what needs to be done about it. I suppose I need to divide |
||
//Inside src, copy from src to cache and then to DST | ||
let cache_offset = ndIndexToOffset(dst_index, metadata.cache_stride); | ||
let cache_offset = ndIndexToOffset(dst_index, metadata.cache_stride) / 4u; | ||
dst_index[dim] -= metadata.cum0; | ||
let src_offset = ndIndexToOffset(dst_index, metadata.src_stride); | ||
let src_offset = ndIndexToOffset(dst_index, metadata.src_stride) / 4u; | ||
let val = S[src_offset]; | ||
C[cache_offset] = val; | ||
D[dst_offset] = val; | ||
|
@@ -151,7 +151,12 @@ impl MetaOperation for Cache { | |
} | ||
|
||
fn kernel_element(&self, _dst: &Tensor) -> KernelElement { | ||
KernelElement::Scalar | ||
let numel = self.input.shape().numel(); | ||
if numel % 4 == 0 { | ||
FL33TW00D marked this conversation as resolved.
Show resolved
Hide resolved
|
||
KernelElement::Vec4 | ||
} else { | ||
KernelElement::Scalar | ||
} | ||
} | ||
|
||
fn calculate_dispatch(&self, dst: &Tensor) -> Result<Workload, OperationError> { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will only work for `vec4<f32>`!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, just learnt that `Vec4<T>` implements `WgslPrimitive` for any `T`. Making it `Array::<P>` itself, from what I understand.