From 5bd5740c2971559dbc632bdcf7ce1b111ffad798 Mon Sep 17 00:00:00 2001 From: Aswin C Date: Fri, 21 Jun 2024 07:54:49 +0530 Subject: [PATCH 1/2] Vectorize KV-Cache by using `Vec4` --- crates/ratchet-core/src/ops/cache.rs | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/crates/ratchet-core/src/ops/cache.rs b/crates/ratchet-core/src/ops/cache.rs index 5ff00d7b..f0e1a4b1 100644 --- a/crates/ratchet-core/src/ops/cache.rs +++ b/crates/ratchet-core/src/ops/cache.rs @@ -35,9 +35,9 @@ impl Cache { builder: &mut WgslKernelBuilder, _: bool, ) -> Result<(), OperationError> { - builder.register_storage("C", BindingMode::ReadWrite, Array::

::default()); - builder.register_storage("S", BindingMode::ReadOnly, Array::

::default()); - builder.register_storage("D", BindingMode::ReadWrite, Array::

::default()); + builder.register_storage("C", BindingMode::ReadWrite, Array::>::default()); + builder.register_storage("S", BindingMode::ReadOnly, Array::>::default()); + builder.register_storage("D", BindingMode::ReadWrite, Array::>::default()); builder.register_uniform(); Ok(()) @@ -68,29 +68,29 @@ impl Cache { kernel_builder.write_index_to_offset(); kernel_builder.write_main(wgsl! { - //Dispatch 1 thread per output element + //Dispatch 1 thread per output element (vec4) //dst_offset is index into the output buffer (1D) let x_offset = workgroup_id.x * 64u; let dst_offset = (workgroup_id.y * num_workgroups.x * 64u) + x_offset + local_invocation_index; - if (dst_offset >= metadata.dst_numel) { + if (dst_offset >= metadata.dst_numel / 4u) { return; } - //Convert 1D offset into 4D index - var dst_index = offsetToNdIndex(dst_offset, metadata.dst_stride); + // Convert 1D offset into 4D index + var dst_index = offsetToNdIndex(dst_offset * 4u, metadata.dst_stride); let dim = metadata.dim; if (dst_index[dim] < metadata.cum0) { //Inside cache, just copy from cache to DST - let src_offset = ndIndexToOffset(dst_index, metadata.cache_stride); + let src_offset = ndIndexToOffset(dst_index, metadata.cache_stride) / 4u; D[dst_offset] = C[src_offset]; return; } if (dst_index[dim] < metadata.cum1) { //Inside src, copy from src to cache and then to DST - let cache_offset = ndIndexToOffset(dst_index, metadata.cache_stride); + let cache_offset = ndIndexToOffset(dst_index, metadata.cache_stride) / 4u; dst_index[dim] -= metadata.cum0; - let src_offset = ndIndexToOffset(dst_index, metadata.src_stride); + let src_offset = ndIndexToOffset(dst_index, metadata.src_stride) / 4u; let val = S[src_offset]; C[cache_offset] = val; D[dst_offset] = val; @@ -151,7 +151,12 @@ impl MetaOperation for Cache { } fn kernel_element(&self, _dst: &Tensor) -> KernelElement { - KernelElement::Scalar + let numel = self.input.shape().numel(); + if numel % 4 == 0 { + KernelElement::Vec4 + } else { + KernelElement::Scalar + } } fn calculate_dispatch(&self, dst: &Tensor) -> Result { From 18f24fb887635f362c281e461e50e604175af90d Mon Sep 17 00:00:00 2001 From: Aswin C Date: Sun, 23 Jun 2024 18:59:57 +0530 Subject: [PATCH 2/2] Use `Array::

>` while registering storage and add more conditions for `KernelElement` --- crates/ratchet-core/src/ops/cache.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/crates/ratchet-core/src/ops/cache.rs b/crates/ratchet-core/src/ops/cache.rs index f0e1a4b1..126509f3 100644 --- a/crates/ratchet-core/src/ops/cache.rs +++ b/crates/ratchet-core/src/ops/cache.rs @@ -35,9 +35,9 @@ impl Cache { builder: &mut WgslKernelBuilder, _: bool, ) -> Result<(), OperationError> { - builder.register_storage("C", BindingMode::ReadWrite, Array::>::default()); - builder.register_storage("S", BindingMode::ReadOnly, Array::>::default()); - builder.register_storage("D", BindingMode::ReadWrite, Array::>::default()); + builder.register_storage("C", BindingMode::ReadWrite, Array::

::default()); + builder.register_storage("S", BindingMode::ReadOnly, Array::

::default()); + builder.register_storage("D", BindingMode::ReadWrite, Array::

::default()); builder.register_uniform(); Ok(()) @@ -150,10 +150,13 @@ impl MetaOperation for Cache { rvec![&self.cache, &self.source] } - fn kernel_element(&self, _dst: &Tensor) -> KernelElement { - let numel = self.input.shape().numel(); + fn kernel_element(&self, dst: &Tensor) -> KernelElement { + let numel = dst.shape().numel(); + if numel % 4 == 0 { KernelElement::Vec4 + } else if numel % 2 == 0 { + KernelElement::Vec2 } else { KernelElement::Scalar }