From 83a9e883e4cfe28344253afdf697a82c7ba9b76f Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Wed, 15 May 2024 15:10:22 -0400 Subject: [PATCH 01/75] Mistral.rs Squash Changes (#4) * Offset it * Freeze * Offset it * Offset it * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Remove debugs * Polish it up * Polish it up * Clippy * Remove test file * Add config for if neox * Fix bug * Fix bug * Cast cache type on rust side * Cast types * To dtype * Drop temp * Update casting * Update casting * Update casting * Create dtype in bf16 * Check type * Debug * Check dtype * Check dtype * Check dtype * Check dtype * Check dtype * Check dtype * Check dtype * Check dtype * Check dtype * Debug * Debug * Debug * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Reseting * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Remove debug * Debug * Debug * Remove debug * Remove debug * Debug * Remove debug * Debug * Remove debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Try to use 3dim rotemb fused * Try to use 3dim rotemb fused * Remove contig and debug * Check handling * Cleanup * Fix * Remove prints * Lower block dim * Use fused layernorm * Pass batch size * Simplify internal API * Simplify internal API * Try slow * Try candle layer norm * Try candle layer norm * Fix dep of candle layer norm * Reshape input for rank 2 * Reshape input for rank 2 * Fix ref * Code style * Make dep optional * Ensure contig * Ensure contig * Ensure contig * Debug contig dmmv error * Debug contig dmmv error * Debug contig dmmv error * Debug contig dmmv error * Try other method * Try other method * Try other method * Try other method * Try other method * Use typestate to optimize * Use typestate to optimize * Fixes * Fixes * Fixes * Fixes * Fixes * Debug via using slow rmsnorm * Debug via using slow rope * Remove debug * More debugging * Remove debug * Remove debug * Remove debug * Add better error enum * Fix diff marker * Fix some things * Fix some things * Fix some things * Fix dummy backends * Re add from storage noop * Fix removed kvconcat custom op * Fix erroneous feature gate * 
Complete metal backend refactoring * Check if calling * Check if calling * Update default for force dmmv * Load atomic * Debug * Use mmvq * Update * Add the empty functions * Add rope new_partial function * Make variant of qmatmul pub * Make variant of qmatmul pub * Add the varbuilder set_device function * Only link stdc++ if target has msvc * Only link stdc++ if target has msvc * Only link stdc++ if target has msvc * Only link stdc++ if target has msvc * Handle case of device mapping * Handle case of device mapping * Add getter * Fix * Fix * Support nvcc flags in flash attn * Support nvcc flags in flash attn * Support nvcc flags in flash attn * Support nvcc flags in flash attn * Support nvcc flags in flash attn * Fixes * Fixes * Fix the tests * Fix the tests --- .vscode/settings.json | 3 +- README.md | 2 + candle-core/src/cuda_backend/mod.rs | 160 +++++++++ candle-core/src/device.rs | 8 +- candle-core/src/lib.rs | 2 +- candle-core/src/tensor.rs | 61 ++++ .../examples/mamba-minimal/model.rs | 10 +- .../examples/quantized-phi/main.rs | 13 +- candle-flash-attn/build.rs | 23 +- candle-kernels/src/cuda_utils.cuh | 29 ++ candle-kernels/src/fused_layer_norm.cu | 329 +++++++++++++++++ candle-kernels/src/fused_rms_norm.cu | 82 +++++ candle-kernels/src/fused_rope.cu | 231 ++++++++++++ candle-kernels/src/kvconcat.cu | 53 +++ candle-kernels/src/lib.rs | 3 + candle-nn/Cargo.toml | 3 +- candle-nn/src/layer_norm.rs | 222 ++++++++++-- candle-nn/src/lib.rs | 8 +- candle-nn/src/ops.rs | 27 ++ candle-nn/src/rope.rs | 330 ++++++++++++++++++ candle-nn/src/var_builder.rs | 20 +- candle-transformers/src/models/chatglm.rs | 6 +- candle-transformers/src/models/llama2_c.rs | 22 +- candle-transformers/src/models/mamba.rs | 10 +- candle-transformers/src/models/metavoice.rs | 23 +- candle-transformers/src/models/mod.rs | 1 - .../src/models/quantized_phi3.rs | 301 ---------------- .../src/models/with_tracing.rs | 8 +- candle-wasm-examples/llama2-c/src/model.rs | 35 +- 29 files changed, 1624 insertions(+), 401 deletions(-) create mode 100644 candle-kernels/src/fused_layer_norm.cu create mode 100644 candle-kernels/src/fused_rms_norm.cu create mode 100644 candle-kernels/src/fused_rope.cu create mode 100644 candle-kernels/src/kvconcat.cu create mode 100644 candle-nn/src/rope.rs delete mode 100644 candle-transformers/src/models/quantized_phi3.rs diff --git a/.vscode/settings.json b/.vscode/settings.json index b2dbd68012..b7345f2ca6 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,6 @@ "candle-pyo3" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + //"rust-analyzer.cargo.features": ["cuda"], } \ No newline at end of file diff --git a/README.md b/README.md index 5644d81813..a5fbe7d7d8 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ [![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core) ![License](https://img.shields.io/crates/l/candle-core.svg) +**This is an optimized implementation by Eric Buehler.** + Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use.
Try our online demos: [whisper](https://huggingface.co/spaces/lmz/candle-whisper), diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 1ea9beafe7..864906ab68 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -2008,3 +2008,163 @@ unsafe fn gemm_strided_batched_bf16( sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP, ) } + +pub struct KVConcat { + pub concat_dim: usize, +} +impl crate::CustomOp2 for KVConcat { + fn name(&self) -> &'static str { + "kvconcat" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + crate::bail!("no cpu support for kvconcat") + } + + fn cuda_fwd( + &self, + ltensor: &CudaStorage, + ltensor_l: &Layout, + rtensor: &CudaStorage, + rtensor_l: &Layout, + ) -> Result<(CudaStorage, Shape)> { + assert!(self.concat_dim == 2 || self.concat_dim == 0); //must be in the dim of sequence len + let dev = <ensor.device; + let elem_count = ltensor_l.shape().elem_count() + rtensor_l.shape().elem_count(); + let dims_l = ltensor_l.shape().dims(); + let dims_r = rtensor_l.shape().dims(); + let dim_size = dims_l.len(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + + let chunk_l = if dim_size > 3 { + dims_l[0] * dims_l[1] + } else { + dims_l[0] + }; + let chunk_r = if dim_size > 3 { + dims_r[0] * dims_r[1] + } else { + dims_r[0] + }; + let lstride = if dim_size > 3 { + dims_l[2] * dims_l[3] + } else { + dims_l[1] * dims_l[2] + }; + let rstride = if dim_size > 3 { + dims_r[2] * dims_r[3] + } else { + dims_r[1] * dims_r[2] + }; + + let slice = match (<ensor.slice, &rtensor.slice) { + (CudaStorageSlice::BF16(left_), CudaStorageSlice::BF16(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_bf16", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::BF16(out) + } + (CudaStorageSlice::F32(left_), CudaStorageSlice::F32(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f32", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F32(out) + } + (CudaStorageSlice::F16(left_), CudaStorageSlice::F16(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f16", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F16(out) + } + (CudaStorageSlice::F64(left_), CudaStorageSlice::F64(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f64", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F64(out) + } + (CudaStorageSlice::U8(left_), CudaStorageSlice::U8(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? 
}; + let func = dev.get_or_load_func("kvconcat_u8", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U8(out) + } + _ => Err(CudaError::InternalError("dtype mismatch in kvconcat op"))?, + }; + + let mut lshape: Vec = ltensor_l.shape().dims().to_vec(); + if self.concat_dim == 0 { + lshape[0] += rtensor_l.shape().dims()[0]; + } else { + if dim_size > 3 { + lshape[2] += rtensor_l.shape().dims()[2]; + } else { + lshape[1] += rtensor_l.shape().dims()[1]; + } + } + + let device = dev.clone(); + Ok(( + CudaStorage { + slice: slice, + device, + }, + lshape.into(), + )) + } +} diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 1cd26167f5..7f2dbc411c 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -325,12 +325,12 @@ impl Device { Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())), Device::Cuda(device) => { let storage = array.to_cpu_storage(); - let storage = device.storage_from_cpu_storage_owned(storage)?; + let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Cuda(storage)) } Device::Metal(device) => { let storage = array.to_cpu_storage(); - let storage = device.storage_from_cpu_storage_owned(storage)?; + let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Metal(storage)) } } @@ -341,12 +341,12 @@ impl Device { Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))), Device::Cuda(device) => { let storage = S::to_cpu_storage_owned(data); - let storage = device.storage_from_cpu_storage_owned(storage)?; + let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Cuda(storage)) } Device::Metal(device) => { let storage = S::to_cpu_storage_owned(data); - let storage = device.storage_from_cpu_storage_owned(storage)?; + let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Metal(storage)) } } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 90e3644dd5..8a46377b59 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -85,7 +85,7 @@ pub use layout::Layout; pub use shape::{Shape, D}; pub use storage::Storage; pub use strided_index::{StridedBlocks, StridedIndex}; -pub use tensor::{Tensor, TensorId}; +pub use tensor::{from_storage_no_op, Tensor, TensorId}; pub use variable::Var; #[cfg(feature = "cuda")] diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index dd1b44b0a0..baaa288bb5 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -176,6 +176,22 @@ pub(crate) fn from_storage>( Tensor(Arc::new(tensor_)) } +/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides. This has a BackpropOp:none(). +pub fn from_storage_no_op>(storage: Storage, shape: S, is_variable: bool) -> Tensor { + let dtype = storage.dtype(); + let device = storage.device(); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: Arc::new(RwLock::new(storage)), + layout: Layout::contiguous(shape), + op: BackpropOp::none(), + is_variable, + dtype, + device, + }; + Tensor(Arc::new(tensor_)) +} + impl Tensor { pub(crate) fn ones_impl>( shape: S, @@ -256,6 +272,51 @@ impl Tensor { Tensor::zeros(self.shape(), self.dtype(), self.device()) } + // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from + // the variable module. 
+ pub(crate) unsafe fn empty_impl>( + shape: S, + dtype: DType, + device: &Device, + is_variable: bool, + ) -> Result { + let none = BackpropOp::none(); + let shape = shape.into(); + let storage = device.alloc_uninit(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) + } + + /// Creates a new tensor filled with uninitialized memory. + /// + /// # Safety + /// This returns uninitialized memory. + /// + /// ```rust + /// use candle_core::{Tensor, DType, Device}; + /// let a = unsafe { Tensor::empty((2, 3), DType::F32, &Device::Cpu)? }; + /// // a == b + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub unsafe fn empty>(shape: S, dtype: DType, device: &Device) -> Result { + Self::empty_impl(shape, dtype, device, false) + } + + /// Creates a new tensor filled with uninitialized memory of the same shape, dtype, and device as the other + /// tensor. + /// + /// # Safety + /// This returns uninitialized memory. + /// + /// ```rust + /// use candle_core::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = unsafe { a.empty_like()? }; + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub unsafe fn empty_like(&self) -> Result { + Tensor::empty(self.shape(), self.dtype(), self.device()) + } + pub(crate) fn rand_impl, T: crate::FloatDType>( lo: T, up: T, diff --git a/candle-examples/examples/mamba-minimal/model.rs b/candle-examples/examples/mamba-minimal/model.rs index 4a0a345d17..b8fa01a51a 100644 --- a/candle-examples/examples/mamba-minimal/model.rs +++ b/candle-examples/examples/mamba-minimal/model.rs @@ -2,7 +2,7 @@ /// https://github.com/johnma2006/mamba-minimal/blob/master/model.py /// Simple, minimal implementation of Mamba in one file of PyTorch. use candle::{IndexOp, Module, Result, Tensor, D}; -use candle_nn::{RmsNorm, VarBuilder}; +use candle_nn::{layer_norm::RmsNormNonQuantized, RmsNorm, VarBuilder}; use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear}; @@ -144,12 +144,12 @@ impl Module for MambaBlock { #[derive(Clone, Debug)] pub struct ResidualBlock { mixer: MambaBlock, - norm: RmsNorm, + norm: RmsNorm, } impl ResidualBlock { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?; + let norm = candle_nn::rms_norm_non_quant(cfg.d_model, 1e-5, vb.pp("norm"))?; let mixer = MambaBlock::new(cfg, vb.pp("mixer"))?; Ok(Self { mixer, norm }) } @@ -166,7 +166,7 @@ impl Module for ResidualBlock { pub struct Model { embedding: candle_nn::Embedding, layers: Vec, - norm_f: RmsNorm, + norm_f: RmsNorm, lm_head: Linear, } @@ -179,7 +179,7 @@ impl Model { let layer = ResidualBlock::new(cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?; + let norm_f = candle_nn::rms_norm_non_quant(cfg.d_model, 1e-5, vb.pp("norm_f"))?; let lm_head = Linear::from_weights(embedding.embeddings().clone(), None); Ok(Self { embedding, diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs index e22118445e..f17cec69c9 100644 --- a/candle-examples/examples/quantized-phi/main.rs +++ b/candle-examples/examples/quantized-phi/main.rs @@ -15,7 +15,6 @@ use candle_transformers::generation::{LogitsProcessor, Sampling}; use candle_examples::token_output_stream::TokenOutputStream; use candle_transformers::models::quantized_llama::ModelWeights as Phi3b; use candle_transformers::models::quantized_phi::ModelWeights as Phi2; -use 
candle_transformers::models::quantized_phi3::ModelWeights as Phi3; const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. "; @@ -23,8 +22,6 @@ const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. " enum Which { #[value(name = "phi-2")] Phi2, - #[value(name = "phi-3")] - Phi3, /// Alternative implementation of phi-3, based on llama. #[value(name = "phi-3b")] Phi3b, @@ -100,7 +97,7 @@ impl Args { let api = hf_hub::api::sync::Api::new()?; let repo = match self.which { Which::Phi2 => "microsoft/phi-2", - Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct", + Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct", }; let api = api.model(repo.to_string()); api.get("tokenizer.json")? @@ -115,11 +112,6 @@ impl Args { None => { let (repo, filename, revision) = match self.which { Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf", "main"), - Which::Phi3 => ( - "microsoft/Phi-3-mini-4k-instruct-gguf", - "Phi-3-mini-4k-instruct-q4.gguf", - "main", - ), Which::Phi3b => ( "microsoft/Phi-3-mini-4k-instruct-gguf", "Phi-3-mini-4k-instruct-q4.gguf", @@ -153,7 +145,6 @@ fn format_size(size_in_bytes: usize) -> String { enum Model { Phi2(Phi2), - Phi3(Phi3), Phi3b(Phi3b), } @@ -161,7 +152,6 @@ impl Model { fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result { match self { Self::Phi2(m) => m.forward(xs, pos), - Self::Phi3(m) => m.forward(xs, pos), Self::Phi3b(m) => m.forward(xs, pos), } } @@ -213,7 +203,6 @@ fn main() -> anyhow::Result<()> { ); match args.which { Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?), - Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?), Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?), } }; diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index 4002770be4..b64dd1cb68 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -4,6 +4,8 @@ use anyhow::{Context, Result}; use std::path::PathBuf; +const CUDA_NVCC_FLAGS: Option<&'static str> = option_env!("CUDA_NVCC_FLAGS"); + const KERNEL_FILES: [&str; 17] = [ "kernels/flash_api.cu", "kernels/flash_fwd_hdim128_fp16_sm80.cu", @@ -56,7 +58,7 @@ fn main() -> Result<()> { }; let kernels = KERNEL_FILES.iter().collect(); - let builder = bindgen_cuda::Builder::default() + let mut builder = bindgen_cuda::Builder::default() .kernel_paths(kernels) .out_dir(build_dir.clone()) .arg("-std=c++17") @@ -71,13 +73,30 @@ fn main() -> Result<()> { .arg("--use_fast_math") .arg("--verbose"); + // https://github.com/EricLBuehler/mistral.rs/issues/286 + // https://github.com/huggingface/candle-flash-attn-v1/pull/2 + if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { + builder = builder.arg("--compiler-options"); + builder = builder.arg(cuda_nvcc_flags_env); + } + let out_file = build_dir.join("libflashattention.a"); builder.build_lib(out_file); println!("cargo:rustc-link-search={}", build_dir.display()); println!("cargo:rustc-link-lib=flashattention"); println!("cargo:rustc-link-lib=dylib=cudart"); - println!("cargo:rustc-link-lib=dylib=stdc++"); + // https://github.com/denoland/rusty_v8/blob/20b2989186d1ecdf4c291d0706ff9eb1baaf2cfd/build.rs#L602 + let target = std::env::var("TARGET").unwrap(); + if target.contains("msvc") { + // nothing to link to + } else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") { + println!("cargo:rustc-link-lib=dylib=c++"); + } else if target.contains("android") { + 
println!("cargo:rustc-link-lib=dylib=c++_shared"); + } else { + println!("cargo:rustc-link-lib=dylib=stdc++"); + } Ok(()) } diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index 2673b8aaf1..f7a2506d0e 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -115,6 +115,35 @@ __device__ void chunk_sum( } } +__device__ __forceinline__ int GetBlockNum(void) { + return (gridDim.x * gridDim.y * gridDim.z); +} + +__device__ __forceinline__ int GetBlockIdx(void) { + return (blockIdx.z * (gridDim.x * gridDim.y) + blockIdx.y * gridDim.x + + blockIdx.x); +} + +__device__ __forceinline__ int GetThreadNumEachBlock(void) { + return (blockDim.x * blockDim.y * blockDim.z); +} + +__device__ __forceinline__ int GetThreadNum(void) { + return GetBlockNum() * GetThreadNumEachBlock(); +} + +__device__ __forceinline__ int GetThreadIdxInBlock(void) { + return threadIdx.z * (blockDim.x * blockDim.y) + + threadIdx.y * blockDim.x + threadIdx.x; +} + +__device__ __forceinline__ int GetThreadIdx(void) { + int blockIdx = GetBlockIdx(); + int threadNumEachBlock = GetThreadNumEachBlock(); + + return blockIdx * threadNumEachBlock + GetThreadIdxInBlock(); +} + __device__ __forceinline__ bool isnang(float a) { return isnan(a); } __device__ __forceinline__ bool isnang(double a) { return isnan(a); } __device__ __forceinline__ float recipg(float a) { return 1.0 / a; } diff --git a/candle-kernels/src/fused_layer_norm.cu b/candle-kernels/src/fused_layer_norm.cu new file mode 100644 index 0000000000..cea64c519b --- /dev/null +++ b/candle-kernels/src/fused_layer_norm.cu @@ -0,0 +1,329 @@ +// Based on https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/multihead_attn/layer_norm.cuh#L243 +// Modified Eric Buehler 2024 + +#include "cuda_fp16.h" +#include +#include + +#if __CUDA_ARCH__ >= 800 +#include +#endif + +template +__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) { + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template +__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, + U &mu, U &sigma2, U &count) { + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA * mu + nB * muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +// https://github.com/pytorch/pytorch/blob/7fe0cc53e903e515e86b4a350614011c66e3b32d/aten/src/ATen/cuda/DeviceUtils.cuh#L50 +template +__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return __shfl_sync(mask, value, srcLane, width); +#else + return __shfl(value, srcLane, width); +#endif +} + +template +__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, + const int n2, const int i1, U &mu, U &sigma2, + U *buf) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
+ // + // compute variance and mean over n2 + U count = U(0); + mu = U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T *lvals = vals + i1 * n2; + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l + k]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + U muB = WARP_SHFL(mu, srcLaneB); + U countB = WARP_SHFL(count, srcLaneB); + U sigma2B = WARP_SHFL(sigma2, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U *ubuf = (U *)buf; + U *ibuf = (U *)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y] = mu; + ubuf[2 * wrt_y + 1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U muB = ubuf[2 * threadIdx.y]; + U sigma2B = ubuf[2 * threadIdx.y + 1]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1] / U(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / U(n2), 0); + } + } +} + +template <> +__device__ void cuWelfordMuSigma2(const __half *__restrict__ vals, + const int n1, const int n2, const int i1, + float &mu, float &sigma2, float *buf) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + float count = 0.0f; + mu = float(0); + sigma2 = float(0); + + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const __half *lvals = vals + i1 * n2; + int l = 8 * thrx; + if ((((size_t)lvals) & 3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. 
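+    // Vectorized path: load pairs of halves as __half2, widen to float2, and feed both
+    // lanes into the Welford accumulator (8 values per thread per iteration).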
+ for (; l + 7 < n2; l += 8 * numx) { + for (int k = 0; k < 8; k += 2) { + float2 curr = __half22float2(*((__half2 *)(lvals + l + k))); + cuWelfordOnlineSum(curr.x, mu, sigma2, count); + cuWelfordOnlineSum(curr.y, mu, sigma2, count); + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + float muB = WARP_SHFL(mu, srcLaneB); + float countB = WARP_SHFL(count, srcLaneB); + float sigma2B = WARP_SHFL(sigma2, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + float *ubuf = (float *)buf; + float *ibuf = (float *)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y] = mu; + ubuf[2 * wrt_y + 1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float muB = ubuf[2 * threadIdx.y]; + float sigma2B = ubuf[2 * threadIdx.y + 1]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1] / float(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / float(n2), 0); + } + } +} + +template __device__ U rsqrt(U v) { return U(1) / sqrt(v); } +template <> __device__ float rsqrt(float v) { return rsqrtf(v); } +template <> __device__ double rsqrt(double v) { return rsqrt(v); } +template <> __device__ __half rsqrt(__half v) { return rsqrt(v); } +#if __CUDA_ARCH__ >= 800 +template <> __device__ __nv_bfloat16 rsqrt(__nv_bfloat16 v) { return rsqrt(v); } +#endif + +// This is the un-specialized struct. Note that we prevent instantiation of +// this struct by putting an undefined symbol in the function body so it won't +// compile. 
+// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template struct SharedMemory; +template <> struct SharedMemory { + __device__ float *getPointer() { + extern __shared__ float s_float[]; + return s_float; + } +}; + +template <> struct SharedMemory<__half> { + __device__ __half *getPointer() { + extern __shared__ __half s_half[]; + return s_half; + } +}; + +#if __CUDA_ARCH__ >= 800 +template <> struct SharedMemory<__nv_bfloat16> { + __device__ __nv_bfloat16 *getPointer() { + extern __shared__ __nv_bfloat16 s_bf[]; + return s_bf; + } +}; +#endif + +template +__device__ void +cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean, + U *__restrict__ invvar, const T *__restrict__ vals, + const int n1, const int n2, const U epsilon, + const T *__restrict__ gamma, const T *__restrict__ beta) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U *buf = shared.getPointer(); + U mu, sigma2; + cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf); + const T *lvals = vals + i1 * n2; + T *ovals = output_vals + i1 * n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && beta != NULL) { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } + } else { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + ovals[i] = static_cast(c_invvar * (curr - mu)); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean[i1] = mu; + invvar[i1] = c_invvar; + } + } +} + +extern "C" __global__ void layernorm_f16(__half *__restrict__ output_vals, __half *__restrict__ mean, + __half *__restrict__ invvar, const __half *__restrict__ vals, + const int n1, const int n2, const __half epsilon, + const __half *__restrict__ gamma, const __half *__restrict__ beta) { + cuApplyLayerNorm(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta); +} + +extern "C" __global__ void layernorm_f32(float *__restrict__ output_vals, float *__restrict__ mean, + float *__restrict__ invvar, const float *__restrict__ vals, + const int n1, const int n2, const float epsilon, + const float *__restrict__ gamma, const float *__restrict__ beta) { + cuApplyLayerNorm(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta); +} + +#if __CUDA_ARCH__ >= 800 +#include +extern "C" __global__ void layernorm_bf16(__nv_bfloat16 *__restrict__ output_vals, __nv_bfloat16 *__restrict__ mean, + __nv_bfloat16 *__restrict__ invvar, const __nv_bfloat16 *__restrict__ vals, + const int n1, const int n2, const __nv_bfloat16 epsilon, + const __nv_bfloat16 *__restrict__ gamma, const __nv_bfloat16 *__restrict__ beta) { + cuApplyLayerNorm(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta); +} +#endif diff --git a/candle-kernels/src/fused_rms_norm.cu b/candle-kernels/src/fused_rms_norm.cu new file mode 100644 index 0000000000..f012e002ad --- /dev/null +++ b/candle-kernels/src/fused_rms_norm.cu @@ -0,0 +1,82 @@ +#include "cuda_fp16.h" +#include + +#define WARP_SIZE 32 + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) 
+#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) +#endif + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) + val += VLLM_SHFL_XOR_SYNC(val, mask); + return val; +} + +__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) { + return warp_size - 1; +} + +__inline__ __device__ constexpr int _calculateWidShift(int warp_size) { + return 5 + (warp_size >> 6); +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[WARP_SIZE]; + constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE); + constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE); + int lane = threadIdx.x & LANE_MASK; + int wid = threadIdx.x >> WID_SHIFT; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + +#define RMS_NORM_OP(FN_NAME, TYPENAME)\ +extern "C" __global__ void FN_NAME(\ + TYPENAME* __restrict__ out,\ + const TYPENAME* __restrict__ input,\ + const TYPENAME* __restrict__ weight,\ + const float epsilon,\ + const int num_tokens,\ + const int hidden_size) {\ + __shared__ float s_variance;\ + float variance = 0.0f;\ + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {\ + const float x = (float) input[blockIdx.x * hidden_size + idx];\ + variance += x * x;\ + }\ + variance = blockReduceSum(variance);\ + if (threadIdx.x == 0) {\ + s_variance = rsqrtf(variance / hidden_size + epsilon);\ + }\ + __syncthreads();\ + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {\ + float x = (float) input[blockIdx.x * hidden_size + idx];\ + out[blockIdx.x * hidden_size + idx] = ((TYPENAME) (x * s_variance)) * weight[idx];\ + }\ +}\ + +RMS_NORM_OP(rms_norm_f32, float) +RMS_NORM_OP(rms_norm_f16, __half) + +#if __CUDA_ARCH__ >= 800 +#include +RMS_NORM_OP(rms_norm_bf16, __nv_bfloat16) +#endif \ No newline at end of file diff --git a/candle-kernels/src/fused_rope.cu b/candle-kernels/src/fused_rope.cu new file mode 100644 index 0000000000..9f7873cca7 --- /dev/null +++ b/candle-kernels/src/fused_rope.cu @@ -0,0 +1,231 @@ +#include "cuda_fp16.h" + +#ifndef USE_ROCM + #define LDG(arg) __ldg(arg) +#else + #define LDG(arg) *arg +#endif + +template +inline __device__ void apply_token_rotary_embedding( + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) +{ + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = LDG(cos_ptr + x_index); + sin = LDG(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. 
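+    // Interleaved layout: element 2*i is rotated against element 2*i+1, and the pair
+    // shares the cos/sin value at index i of the cache row.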
+ x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = LDG(cos_ptr + x_index / 2); + sin = LDG(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* cache_ptr, + const int head_size, + const int num_heads, + const int num_kv_heads, + const int rot_dim, + const int token_idx, + const int64_t query_stride, + const int64_t key_stride) +{ + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(query + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(key + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } +} + +extern "C" __global__ void rotary_embedding_kernel_f32( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + float* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + float* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const float* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_f16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __half* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __half* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __half* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. 
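+  // Each row of cos_sin_cache holds rot_dim values: embed_dim = rot_dim / 2 cosines
+  // followed by embed_dim sines, indexed by the token's position id.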
+ const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __half* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__half, false>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_f64( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + double* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + double* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const double* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const double* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + + + + +extern "C" __global__ void rotary_embedding_kernel_neox_f32( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + float* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + float* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const float* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_neox_f16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __half* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __half* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __half* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. 
+ const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __half* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__half, true>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_neox_f64( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + double* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + double* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const double* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const double* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +#if __CUDA_ARCH__ >= 800 +#include +extern "C" __global__ void rotary_embedding_kernel_bf16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __nv_bfloat16* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __nv_bfloat16* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __nv_bfloat16* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __nv_bfloat16* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__nv_bfloat16, false>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_neox_bf16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __nv_bfloat16* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __nv_bfloat16* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __nv_bfloat16* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. 
+ const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __nv_bfloat16* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__nv_bfloat16, true>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} +#endif \ No newline at end of file diff --git a/candle-kernels/src/kvconcat.cu b/candle-kernels/src/kvconcat.cu new file mode 100644 index 0000000000..7c78d3abe7 --- /dev/null +++ b/candle-kernels/src/kvconcat.cu @@ -0,0 +1,53 @@ +#include "cuda_utils.cuh" +#include + +template +__device__ __forceinline__ void kvconcat_dim0_kernel(T *ltensor, T* rtensor, T *out, + const size_t chunk_l, const size_t chunk_r, const size_t lstride, const size_t rstride) { + size_t idx = GetThreadIdx(); + if (idx < chunk_l * lstride) { + out[idx] = ltensor[idx]; + } else { + out[idx] = rtensor[idx - chunk_l * lstride]; + } +} +template +__device__ __forceinline__ void kvconcat_dim2_kernel(T *ltensor, T* rtensor, T *out, + const size_t chunk_l, const size_t chunk_r, const size_t lstride, const size_t rstride) { + int thread_id = GetThreadIdx(); + int out_stride = lstride + rstride; + int idx = thread_id / out_stride; + int j = thread_id % out_stride; + T* pLeft = ltensor + idx * lstride; + T* pRight = rtensor + idx * rstride; + T* pOut = out + idx * out_stride; + if (idx < chunk_l) { + if (j < lstride) + pOut[j] = pLeft[j]; + else + pOut[j] = pRight[j - lstride]; + } +} + +#define KVCONCAT_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME(TYPENAME *ltensor, TYPENAME* rtensor, TYPENAME *out, const size_t concat_dim,\ + const size_t chunk_l, const size_t chunk_r, const size_t lstride, const size_t rstride) {\ + if (concat_dim == 2)\ + kvconcat_dim2_kernel(ltensor, rtensor, out, chunk_l, chunk_r, lstride, rstride);\ + else if (concat_dim == 0) {\ + if (blockIdx.x == 0 && threadIdx.x ==0) \ + kvconcat_dim0_kernel(ltensor, rtensor, out, chunk_l, chunk_r, lstride, rstride);\ + }\ +}\ + +KVCONCAT_OP(uint8_t, kvconcat_u8) +KVCONCAT_OP(double, kvconcat_f64) +KVCONCAT_OP(float, kvconcat_f32) + +#if __CUDA_ARCH__ >= 530 +KVCONCAT_OP(__half, kvconcat_f16) +#endif + +#if __CUDA_ARCH__ >= 800 +KVCONCAT_OP(__nv_bfloat16, kvconcat_bf16) +#endif \ No newline at end of file diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs index 1c73d6b774..74c6d3d6bb 100644 --- a/candle-kernels/src/lib.rs +++ b/candle-kernels/src/lib.rs @@ -9,3 +9,6 @@ pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx")); pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx")); pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx")); pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx")); +pub const FUSED_RMS_NORM: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rms_norm.ptx")); +pub const FUSED_ROPE: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rope.ptx")); +pub const KVCONCAT: &str = include_str!(concat!(env!("OUT_DIR"), "/kvconcat.ptx")); diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 9f0d56bdea..b878d6d936 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -21,6 +21,7 @@ safetensors = { workspace = true } serde = { workspace = true } metal = { workspace = true, optional = true } candle-metal-kernels = { workspace = true, optional = true } +candle-layer-norm = { git = "https://github.com/EricLBuehler/candle-layer-norm.git", version = "0.0.1", optional = true } [dev-dependencies] anyhow = { workspace = true } 
@@ -31,7 +32,7 @@ criterion = { workspace = true } [features] default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] -cuda = ["candle/cuda"] +cuda = ["candle/cuda", "dep:candle-layer-norm"] mkl = ["dep:intel-mkl-src", "candle/mkl"] metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index 23d0c01b09..85814a417a 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -28,8 +28,31 @@ //! ``` //! //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450 + +use std::marker::PhantomData; +#[cfg(feature = "cuda")] +use std::{ + mem, + sync::{Arc, Mutex}, +}; + +#[cfg(feature = "cuda")] +use candle::cuda_backend::{ + cudarc::driver::{sys, DeviceRepr, LaunchAsync, LaunchConfig}, + kernel_name, kernels, CudaDType, WrapErr, +}; + +#[cfg(feature = "cuda")] +use candle::{ + backend::BackendStorage, from_storage_no_op, CudaDevice, CudaStorage, Device, Storage, + WithDType, +}; + use candle::{DType, Module, Result, Tensor, D}; +#[cfg(feature = "cuda")] +static MAX_GRID_Y: Mutex> = Mutex::new(None); + #[derive(Debug, Clone, Copy, PartialEq)] pub struct LayerNormConfig { pub eps: f64, @@ -63,7 +86,7 @@ impl From for LayerNormConfig { #[derive(Clone, Debug)] pub struct LayerNorm { weight: Tensor, - bias: Option, + bias: Tensor, remove_mean: bool, eps: f64, } @@ -72,7 +95,7 @@ impl LayerNorm { pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self { Self { weight, - bias: Some(bias), + bias, remove_mean: true, eps, } @@ -80,8 +103,8 @@ impl LayerNorm { pub fn new_no_bias(weight: Tensor, eps: f64) -> Self { Self { - weight, - bias: None, + weight: weight.clone(), + bias: Tensor::zeros_like(&weight).unwrap(), remove_mean: true, eps, } @@ -89,8 +112,8 @@ impl LayerNorm { pub fn rms_norm(weight: Tensor, eps: f64) -> Self { Self { - weight, - bias: None, + weight: weight.clone(), + bias: Tensor::zeros_like(&weight).unwrap(), remove_mean: false, eps, } @@ -100,8 +123,8 @@ impl LayerNorm { &self.weight } - pub fn bias(&self) -> Option<&Tensor> { - self.bias.as_ref() + pub fn bias(&self) -> &Tensor { + &self.bias } } @@ -123,10 +146,7 @@ impl Module for LayerNorm { let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?; - match &self.bias { - None => Ok(x), - Some(bias) => x.broadcast_add(bias), - } + x.broadcast_add(&self.bias) } } @@ -143,47 +163,193 @@ pub fn layer_norm>( None }; Ok(LayerNorm { - weight, - bias, + weight: weight.clone(), + bias: bias.unwrap_or(Tensor::zeros_like(&weight)?), remove_mean: config.remove_mean, eps: config.eps, }) } +// This whole non quantized/quantized RmsNorm is a hack. It seems like quantized works without this impl, but it is slower. +#[derive(Clone, Debug)] +pub struct RmsNormQuantized; +#[derive(Clone, Debug)] +pub struct RmsNormNonQuantized; + /// RmsNorm is a specialized version of the LayerNorm module. 
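+/// Unlike a full layer norm it does not subtract the mean:
+/// `y = x / sqrt(mean(x^2) + eps) * weight`.
+/// The type parameter selects the quantized (`RmsNormQuantized`) or
+/// non-quantized (`RmsNormNonQuantized`) code path.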
#[derive(Clone, Debug)] -pub struct RmsNorm(LayerNorm); +pub struct RmsNorm { + inner: LayerNorm, + _ghost: PhantomData, +} -impl RmsNorm { +impl RmsNorm { pub fn new(weight: Tensor, eps: f64) -> Self { - Self(LayerNorm::rms_norm(weight, eps)) + Self { + inner: LayerNorm::rms_norm(weight, eps), + _ghost: PhantomData, + } + } +} + +impl RmsNorm { + pub fn new(weight: Tensor, eps: f64) -> Self { + Self { + inner: LayerNorm::rms_norm(weight, eps), + _ghost: PhantomData, + } + } + + #[cfg(feature = "cuda")] + fn dtype_execute_rmsnorm( + &self, + dev: &CudaDevice, + eps_converter: F, + x_storage: &CudaStorage, + weight_storage: &CudaStorage, + x: &Tensor, + ) -> Result + where + F: FnOnce(f64) -> T, + { + assert!(x.layout().is_contiguous()); + let hidden_size = *x.dims().last().unwrap(); + let elem_count = x.elem_count(); + let num_tokens = elem_count / hidden_size; + let out = unsafe { dev.alloc::(elem_count) }.w()?; + + let cfg = LaunchConfig { + grid_dim: (num_tokens as u32, 1, 1), + block_dim: (u32::min(hidden_size as u32, 1024), 1, 1), + shared_mem_bytes: 0, + }; + + let func = dev.get_or_load_func(&kernel_name::("rms_norm"), kernels::FUSED_RMS_NORM)?; + + let params = ( + &out, + x_storage.as_cuda_slice::()?, + weight_storage.as_cuda_slice::()?, + eps_converter(self.inner.eps), + num_tokens as i32, + hidden_size as i32, + ); + unsafe { func.launch(cfg, params) }.w()?; + + Ok(from_storage_no_op( + Storage::Cuda(CudaStorage::wrap_cuda_slice(out, dev.clone())), + x.shape(), + false, + )) } + #[cfg(feature = "cuda")] + fn fused_rmsnorm(&self, x: &Tensor, dev: &CudaDevice) -> Result { + match ( + &*x.storage_and_layout().0, + &*self.inner.weight().storage_and_layout().0, + ) { + (Storage::Cuda(x_storage), Storage::Cuda(weight_storage)) => { + match (x_storage.dtype(), weight_storage.dtype()) { + (DType::BF16, DType::BF16) => self.dtype_execute_rmsnorm::( + dev, + |x| half::bf16::from_f64(x), + &x_storage, + &weight_storage, + x, + ), + (DType::F16, DType::F16) => self.dtype_execute_rmsnorm::( + dev, + |x| half::f16::from_f64(x), + &x_storage, + &weight_storage, + x, + ), + (DType::F32, DType::F32) => self.dtype_execute_rmsnorm::( + dev, + |x| x as f32, + &x_storage, + &weight_storage, + x, + ), + _ => candle::bail!("DType mismatch in fused rmsnorm."), + } + } + _ => unreachable!(), + } + } +} + +impl RmsNorm { pub fn into_inner(self) -> LayerNorm { - self.0 + self.inner } + pub fn inner(&self) -> &LayerNorm { + &self.inner + } +} - /// Faster variant of the forward kernel, this can only be used on contiguous tensors though. 
- pub fn forward_diff(&self, xs: &Tensor) -> Result { - self.0.forward(xs) +impl Module for RmsNorm { + fn forward(&self, xs: &Tensor) -> Result { + #[cfg(feature = "cuda")] + { + let (bs, s, h) = xs.dims3()?; + let xs = xs.reshape((bs * s, h))?; + let res = + candle_layer_norm::rms_norm(&xs, self.inner.weight(), None, self.inner.eps as f32)?; + res.reshape((bs, s, h)) + } + #[cfg(not(feature = "cuda"))] + { + self.inner.forward(xs) + } } } -impl Module for RmsNorm { +impl Module for RmsNorm { fn forward(&self, xs: &Tensor) -> Result { - if xs.is_contiguous() { - crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32) - } else { - self.0.forward(xs) + #[cfg(feature = "cuda")] + match (xs.dtype(), xs.device()) { + (DType::BF16, Device::Cuda(dev)) + | (DType::F32, Device::Cuda(dev)) + | (DType::F16, Device::Cuda(dev)) => return self.fused_rmsnorm(xs, &dev), + _ => return self.inner.forward(xs), + } + #[cfg(not(feature = "cuda"))] + { + self.inner.forward(xs) } } } -pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result { +pub fn rms_norm_non_quant( + size: usize, + eps: f64, + vb: crate::VarBuilder, +) -> Result> { + let config = LayerNormConfig { + eps, + remove_mean: false, + affine: false, + }; + Ok(RmsNorm { + inner: layer_norm(size, config, vb)?, + _ghost: PhantomData, + }) +} + +pub fn rms_norm_quant( + size: usize, + eps: f64, + vb: crate::VarBuilder, +) -> Result> { let config = LayerNormConfig { eps, remove_mean: false, affine: false, }; - Ok(RmsNorm(layer_norm(size, config, vb)?)) + Ok(RmsNorm { + inner: layer_norm(size, config, vb)?, + _ghost: PhantomData, + }) } diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 5c0fbb376d..482898e9ca 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -12,6 +12,7 @@ pub mod loss; pub mod ops; pub mod optim; pub mod rnn; +pub mod rope; pub mod rotary_emb; pub mod sequential; pub mod var_builder; @@ -28,11 +29,14 @@ pub use embedding::{embedding, Embedding}; pub use func::{func, func_t, Func, FuncT}; pub use group_norm::{group_norm, GroupNorm}; pub use init::Init; -pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; +pub use layer_norm::{ + layer_norm, rms_norm_non_quant, rms_norm_quant, LayerNorm, LayerNormConfig, RmsNorm, +}; pub use linear::{linear, linear_b, linear_no_bias, Linear}; -pub use ops::Dropout; +pub use ops::{kvconcat, Dropout}; pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD}; pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN}; +pub use rope::RotaryEmbedding; pub use sequential::{seq, Sequential}; pub use var_builder::VarBuilder; pub use var_map::VarMap; diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index eabc95d81e..abe42699bc 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -678,3 +678,30 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result { n => candle::bail!("replication-pad with a size of {n} is not supported"), } } + +#[cfg(feature = "cuda")] +pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: usize) -> Result { + if !ltensor.device().is_cuda() { + return Tensor::cat(&[ltensor, &rtensor], concat_dim as usize)?.contiguous(); + } + use candle::cuda_backend::KVConcat; + let op = KVConcat { concat_dim }; + //inputs for kvconcat must be contiguous tensors + if ltensor.is_contiguous() && rtensor.is_contiguous() { + ltensor.apply_op2(&rtensor, op) + } else if ltensor.is_contiguous() { + ltensor.apply_op2(&rtensor.contiguous()?, op) + } else if rtensor.is_contiguous() { + let ltensor = 
ltensor.contiguous()?; + ltensor.apply_op2(&rtensor, op) + } else { + let ltensor = ltensor.contiguous()?; + let rtensor = rtensor.contiguous()?; + ltensor.apply_op2(&rtensor, op) + } +} + +#[cfg(not(feature = "cuda"))] +pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: i32) -> Result { + Tensor::cat(&[ltensor, rtensor], concat_dim as usize)?.contiguous() +} diff --git a/candle-nn/src/rope.rs b/candle-nn/src/rope.rs new file mode 100644 index 0000000000..f405ec02a4 --- /dev/null +++ b/candle-nn/src/rope.rs @@ -0,0 +1,330 @@ +use std::iter::zip; + +#[allow(unused_imports)] +use candle::{ + backend::BackendStorage, CudaDevice, CudaStorage, DType, Device, IndexOp, Module, Result, + Storage, Tensor, WithDType, D, +}; + +#[cfg(feature = "cuda")] +use candle::cuda_backend::{ + cudarc::driver::{ + CudaFunction, CudaStream, DeviceRepr, DriverError, LaunchAsync, LaunchConfig, + }, + kernel_name, kernels, CudaDType, +}; + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct RotaryEmbedding { + cos: Tensor, + sin: Tensor, + head_size: usize, + cache: Tensor, + is_gpt_neox: bool, +} + +impl RotaryEmbedding { + pub fn new( + base: f32, + head_dim: usize, + max_position_embeddings: usize, + device: &Device, + is_gpt_neox: bool, + dtype: DType, + ) -> Result { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta_len = theta.len(); + let theta = Tensor::from_vec(theta, (1, theta_len), device)?.to_dtype(DType::F32)?; + let idx_theta = Tensor::arange(0, max_position_embeddings as u32, device)? + .to_dtype(DType::F32)? + .reshape((max_position_embeddings, 1))? + .matmul(&theta)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok(Self { + head_size: head_dim, + cos: if is_gpt_neox { + Tensor::cat( + &[cos.clone().to_dtype(dtype)?, cos.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + cos.clone().to_dtype(dtype)? + }, + sin: if is_gpt_neox { + Tensor::cat( + &[sin.clone().to_dtype(dtype)?, sin.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + sin.clone().to_dtype(dtype)? + }, + cache: Tensor::cat(&[cos.clone(), sin.clone()], D::Minus1)? + .contiguous()? + .to_dtype(dtype)?, + is_gpt_neox, + }) + } + + pub fn new_partial( + base: f32, + head_dim: usize, + rot_dim: usize, + max_position_embeddings: usize, + device: &Device, + is_gpt_neox: bool, + dtype: DType, + ) -> Result { + let theta: Vec<_> = (0..rot_dim) + .step_by(2) + .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32)) + .collect(); + let theta_len = theta.len(); + let theta = Tensor::from_vec(theta, (1, theta_len), device)?.to_dtype(DType::F32)?; + let idx_theta = Tensor::arange(0, max_position_embeddings as u32, device)? + .to_dtype(DType::F32)? + .reshape((max_position_embeddings, 1))? + .matmul(&theta)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok(Self { + head_size: head_dim, + cos: if is_gpt_neox { + Tensor::cat( + &[cos.clone().to_dtype(dtype)?, cos.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + cos.clone().to_dtype(dtype)? + }, + sin: if is_gpt_neox { + Tensor::cat( + &[sin.clone().to_dtype(dtype)?, sin.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + sin.clone().to_dtype(dtype)? + }, + cache: Tensor::cat(&[cos.clone(), sin.clone()], D::Minus1)? + .contiguous()? 
+ .to_dtype(dtype)?, + is_gpt_neox, + }) + } + + #[cfg(feature = "cuda")] + fn execute_dtype( + &self, + dev: &CudaDevice, + q_storage: &CudaStorage, + k_storage: &CudaStorage, + q: &Tensor, + k: &Tensor, + cache_storage: &CudaStorage, + pos_storage: &CudaStorage, + ) -> Result<()> { + use candle::cuda_backend::WrapErr; + + let num_tokens = q.dim(0)?; + let rot_dim = self.cache.dim(1)?; + let num_heads = q.dim(1)?; + let num_kv_heads = k.dim(1)?; + let q_stride = q.stride()[0]; + let k_stride = k.stride()[0]; + + let func = dev.get_or_load_func( + &if self.is_gpt_neox { + kernel_name::("rotary_embedding_kernel_neox") + } else { + kernel_name::("rotary_embedding_kernel") + }, + kernels::FUSED_ROPE, + )?; + + let cfg = LaunchConfig { + grid_dim: (num_tokens as u32, 1, 1), + block_dim: (512.min((num_heads * rot_dim / 2) as u32), 1, 1), + shared_mem_bytes: 0, + }; + + let params = ( + pos_storage.as_cuda_slice::()?, + q_storage.as_cuda_slice::()?, + k_storage.as_cuda_slice::()?, + cache_storage.as_cuda_slice::()?, + rot_dim as i32, + q_stride as i64, + k_stride as i64, + num_heads as i32, + num_kv_heads as i32, + self.head_size as i32, + ); + unsafe { func.launch(cfg, params) }.w()?; + + Ok(()) + } + + #[cfg(feature = "cuda")] + fn fused_rope( + &self, + dev: &CudaDevice, + positions: &Tensor, + q: &Tensor, + k: &Tensor, + ) -> Result<()> { + let cache_type = self.cache.dtype(); + match ( + &*q.storage_and_layout().0, + &*k.storage_and_layout().0, + &*self.cache.storage_and_layout().0, + &*positions.storage_and_layout().0, + ) { + ( + Storage::Cuda(q_storage), + Storage::Cuda(k_storage), + Storage::Cuda(cache_storage), + Storage::Cuda(pos_storage), + ) => { + return match (q.dtype(), k.dtype(), cache_type) { + (DType::BF16, DType::BF16, DType::BF16) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + (DType::F16, DType::F16, DType::F16) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + (DType::F32, DType::F32, DType::F32) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + (DType::F64, DType::F64, DType::F64) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + _ => candle::bail!( + "DType mismatch in fused RotaryEmbedding q={:?}, k={:?}, cache={:?}", + q.dtype(), + k.dtype(), + cache_type + ), + } + } + _ => unreachable!(), + }; + } + + /// This may modify the tensors in place! + #[allow(unused_variables)] + pub fn forward( + &self, + positions: &[usize], + positions_kernel: &Tensor, + q: &mut Tensor, + k: &mut Tensor, + b_sz: usize, + ) -> Result<()> { + match (q.device(), k.device()) { + #[cfg(feature = "cuda")] + (Device::Cuda(dev), Device::Cuda(_)) => { + self.fused_rope(dev, positions_kernel, &*q, &*k)?; + } + + _ => { + *q = self.apply_rotary_emb(&*q, positions, b_sz)?; + *k = self.apply_rotary_emb(&*k, positions, b_sz)?; + } + }; + Ok(()) + } + + fn apply_rotary_emb( + &self, + x: &Tensor, + seqlen_offsets: &[usize], + b_sz: usize, + ) -> Result { + let (b_sz_seq_len, h, n_embd) = x.dims3()?; + let x = x + .reshape((b_sz, b_sz_seq_len / b_sz, h, n_embd))? 
+ .transpose(1, 2)?; + + fn rotate_half(xs: &Tensor) -> Result { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) + } + let (b_sz, n_head, seq_len, _n_embd) = x.dims4()?; + if self.is_gpt_neox { + let mut embeds = Vec::new(); + for (b, seqlen_offset) in zip(0..b_sz, seqlen_offsets) { + let cos = self.cos.narrow(0, *seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, *seqlen_offset, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let x_b = x.i(b)?.unsqueeze(0)?; + let embed = (x_b.broadcast_mul(&cos)? + rotate_half(&x_b)?.broadcast_mul(&sin)?)?; + embeds.push(embed); + } + Tensor::cat(&embeds, 0) + } else { + let mut ropes = Vec::new(); + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + for (b, seqlen_offset) in zip(0..b_sz, seqlen_offsets) { + let cos = self.cos.narrow(0, *seqlen_offset, seq_len)?.reshape(( + seq_len, + n_embd / 2, + 1, + ))?; + let sin = self.sin.narrow(0, *seqlen_offset, seq_len)?.reshape(( + seq_len, + n_embd / 2, + 1, + ))?; + let cos = cos.broadcast_as((1, 1, seq_len, n_embd / 2, 1))?; + let sin = sin.broadcast_as((1, 1, seq_len, n_embd / 2, 1))?; + // This mimics the llama.cpp behavior. + // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 + // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. + // The resulting y0 and y1 are also interleaved with: + // y0 = x0*cos - x1*sin + // y1 = x0*sin + x1*cos + let x_b = x.i(b)?.unsqueeze(0)?; + let x_b = x_b.reshape((1, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x_b.narrow(D::Minus1, 0, 1)?; + let x1 = x_b.narrow(D::Minus1, 1, 1)?; + let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; + let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; + let rope = Tensor::cat(&[y0, y1], D::Minus1)?; + let rope = rope.flatten_from(D::Minus2)?; + ropes.push(rope); + } + Tensor::cat(&ropes, 0) + } + } +} diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 68bd6f058f..b51d055ca0 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -32,7 +32,7 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> { pub type VarBuilder<'a> = VarBuilderArgs<'a, Box>; struct TensorData { - backend: B, + backend: Arc, pub dtype: DType, pub device: Device, } @@ -94,7 +94,7 @@ impl<'a> Backend for Box { impl<'a, B: Backend> VarBuilderArgs<'a, B> { pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self { let data = TensorData { - backend, + backend: Arc::new(backend), dtype, device: dev.clone(), }; @@ -199,6 +199,18 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { .backend .get(s.into(), &path, hints, dtype, &self.data.device) } + + /// Set the device of the VarBuilder. 
+ pub fn set_device(self, device: Device) -> Self { + Self { + data: Arc::new(TensorData { + backend: self.data.backend.clone(), + dtype: self.data.dtype, + device, + }), + ..self + } + } } struct Zeros; @@ -434,7 +446,7 @@ impl<'a> VarBuilder<'a> { device: Device, ) -> Self { let data = TensorData { - backend, + backend: Arc::new(backend), dtype, device, }; @@ -535,7 +547,7 @@ impl<'a> VarBuilder<'a> { let backend = Rename::new(self, renamer); let backend: Box = Box::new(backend); let data = TensorData { - backend, + backend: Arc::new(backend), dtype, device, }; diff --git a/candle-transformers/src/models/chatglm.rs b/candle-transformers/src/models/chatglm.rs index 0686b34ef3..da093a7c17 100644 --- a/candle-transformers/src/models/chatglm.rs +++ b/candle-transformers/src/models/chatglm.rs @@ -374,7 +374,7 @@ struct Block { impl Block { fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result { let input_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("input_layernorm"), @@ -388,7 +388,7 @@ impl Block { )? }; let post_attention_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("post_attention_layernorm"), @@ -460,7 +460,7 @@ impl Transformer { } let final_layernorm = if cfg.post_layer_norm { let ln = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("final_layernorm"), diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index bba8b66607..ac9bed2850 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -1,6 +1,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::layer_norm::RmsNormNonQuantized; use candle_nn::linear_no_bias as linear; -use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; +use candle_nn::{embedding, rms_norm_non_quant, Embedding, Linear, Module, RmsNorm, VarBuilder}; use std::collections::HashMap; #[derive(Debug, Clone)] @@ -282,14 +283,19 @@ impl Mlp { #[derive(Debug, Clone)] struct Block { - rms_1: RmsNorm, + rms_1: RmsNorm, attn: CausalSelfAttention, - rms_2: RmsNorm, + rms_2: RmsNorm, mlp: Mlp, } impl Block { - fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { + fn new( + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, + ) -> Self { Self { rms_1, attn, @@ -316,9 +322,9 @@ impl Block { fn load(vb: VarBuilder, cfg: &Config) -> Result { let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?; let mlp = Mlp::load(vb.pp("mlp"), cfg)?; - let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; + let input_layernorm = rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; let post_attention_layernorm = - rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?; + rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?; Ok(Self::new( input_layernorm, attn, @@ -332,7 +338,7 @@ impl Block { pub struct Llama { wte: Embedding, blocks: Vec, - ln_f: RmsNorm, + ln_f: RmsNorm, lm_head: Linear, pub config: Config, } @@ -352,7 +358,7 @@ impl Llama { pub fn load(vb: VarBuilder, cfg: Config) -> Result { let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?; let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; - let ln_f 
= rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; + let ln_f = rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.n_layers) .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap()) .collect(); diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs index a75ee87a6e..c8d9bf1e2e 100644 --- a/candle-transformers/src/models/mamba.rs +++ b/candle-transformers/src/models/mamba.rs @@ -2,7 +2,7 @@ /// This is based on: https://github.com/LaurentMazare/mamba.rs use crate::models::with_tracing::{linear, linear_no_bias, Linear}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{RmsNorm, VarBuilder}; +use candle_nn::{layer_norm::RmsNormNonQuantized, RmsNorm, VarBuilder}; const D_CONV: usize = 4; const D_STATE: usize = 16; @@ -155,12 +155,12 @@ impl MambaBlock { #[derive(Clone, Debug)] pub struct ResidualBlock { mixer: MambaBlock, - norm: RmsNorm, + norm: RmsNorm, } impl ResidualBlock { pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result { - let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?; + let norm = candle_nn::rms_norm_non_quant(cfg.d_model, 1e-5, vb.pp("norm"))?; let mixer = MambaBlock::new(layer_index, cfg, vb.pp("mixer"))?; Ok(Self { mixer, norm }) } @@ -175,7 +175,7 @@ impl ResidualBlock { pub struct Model { embedding: candle_nn::Embedding, layers: Vec, - norm_f: RmsNorm, + norm_f: RmsNorm, lm_head: Linear, dtype: DType, } @@ -189,7 +189,7 @@ impl Model { let layer = ResidualBlock::new(layer_idx, cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?; + let norm_f = candle_nn::rms_norm_non_quant(cfg.d_model, 1e-5, vb.pp("norm_f"))?; let lm_head = Linear::from_weights(embedding.embeddings().clone(), None); Ok(Self { embedding, diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs index 43de594f9d..ec382711cd 100644 --- a/candle-transformers/src/models/metavoice.rs +++ b/candle-transformers/src/models/metavoice.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; +use candle_nn::{embedding, linear_b, rms_norm_non_quant, Embedding, Linear, RmsNorm, VarBuilder}; // Equivalent to torch.repeat_interleave pub(crate) fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result { @@ -328,6 +328,8 @@ pub mod tokenizers { } pub mod gpt { + use candle_nn::layer_norm::RmsNormNonQuantized; + use super::*; #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] @@ -350,7 +352,7 @@ pub mod gpt { } enum Norm { - RMSNorm(candle_nn::RmsNorm), + RMSNorm(candle_nn::RmsNorm), LayerNorm(candle_nn::LayerNorm), } @@ -400,7 +402,7 @@ pub mod gpt { fn new(cfg: &Config, vb: VarBuilder) -> Result { match cfg.norm_type { NormType::RMSNorm => { - let rms_norm = candle_nn::rms_norm(cfg.n_embd, cfg.rmsnorm_eps, vb)?; + let rms_norm = candle_nn::rms_norm_non_quant(cfg.n_embd, cfg.rmsnorm_eps, vb)?; Ok(Self::RMSNorm(rms_norm)) } NormType::LayerNorm => { @@ -666,6 +668,8 @@ pub mod gpt { } pub mod transformer { + use candle_nn::layer_norm::RmsNormNonQuantized; + use super::*; #[derive(Debug, Clone, serde::Deserialize)] @@ -833,8 +837,8 @@ pub mod transformer { struct Block { attention: Attention, feed_forward: FeedForward, - ffn_norm: RmsNorm, - attention_norm: RmsNorm, + ffn_norm: RmsNorm, + attention_norm: 
RmsNorm, span: tracing::Span, } @@ -842,8 +846,9 @@ pub mod transformer { fn new(cfg: &Config, vb: VarBuilder) -> Result { let attention = Attention::new(cfg, vb.pp("attention"))?; let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?; - let ffn_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?; - let attention_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?; + let ffn_norm = rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?; + let attention_norm = + rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?; Ok(Self { attention, feed_forward, @@ -871,7 +876,7 @@ pub mod transformer { pos_embeddings: Embedding, speaker_cond_pos: Linear, layers: Vec, - norm: RmsNorm, + norm: RmsNorm, output: Linear, spk_cond_mask: Tensor, span: tracing::Span, @@ -893,7 +898,7 @@ pub mod transformer { let layer = Block::new(cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("norm"))?; + let norm = rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("norm"))?; let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?; let dtype = vb.dtype(); let spk_cond_mask = Tensor::cat( diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index de2430a293..02f8415887 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -40,7 +40,6 @@ pub mod quantized_mixformer; pub mod quantized_moondream; pub mod quantized_mpt; pub mod quantized_phi; -pub mod quantized_phi3; pub mod quantized_recurrent_gemma; pub mod quantized_rwkv_v5; pub mod quantized_rwkv_v6; diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs deleted file mode 100644 index ef404ca0e1..0000000000 --- a/candle-transformers/src/models/quantized_phi3.rs +++ /dev/null @@ -1,301 +0,0 @@ -use std::collections::HashMap; - -use candle::quantized::gguf_file; -use candle::quantized::QTensor; -use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{Embedding, RmsNorm}; - -pub const MAX_SEQ_LEN: usize = 4096; - -#[derive(Debug, Clone)] -struct QLinear { - inner: candle::quantized::QMatMul, - span: tracing::Span, -} - -impl QLinear { - fn new( - ct: &gguf_file::Content, - r: &mut R, - name: &str, - device: &Device, - ) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); - let w = ct.tensor(r, &format!("{name}.weight"), device)?; - let inner = candle::quantized::QMatMul::from_qtensor(w)?; - Ok(Self { inner, span }) - } -} - -impl Module for QLinear { - fn forward(&self, xs: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(xs) - } -} - -#[derive(Debug, Clone)] -struct Mlp { - ffn_up: QLinear, - ffn_down: QLinear, - i_size: usize, -} - -impl Module for Mlp { - fn forward(&self, xs: &Tensor) -> Result { - let up_states = xs.apply(&self.ffn_up)?; - let gate = up_states.narrow(D::Minus1, 0, self.i_size)?; - let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?; - let up_states = (up_states * gate.silu()?)?; - up_states.apply(&self.ffn_down) - } -} - -fn rms_norm(w: QTensor, eps: f64) -> Result { - let w = w.dequantize(&w.device())?; - let rms = RmsNorm::new(w, eps); - Ok(rms) -} - -#[derive(Debug, Clone)] -struct LayerWeights { - attn_qkv: QLinear, - attn_output: QLinear, - attn_norm: RmsNorm, - ffn_norm: RmsNorm, - mlp: Mlp, - n_head: usize, - n_kv_head: usize, - head_dim: usize, - cos: Tensor, - sin: Tensor, - neg_inf: Tensor, - 
kv_cache: Option<(Tensor, Tensor)>, - span_attn: tracing::Span, - span_rot: tracing::Span, -} - -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result { - let shape = mask.shape(); - let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; - Ok(m) -} - -impl LayerWeights { - fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result { - let _enter = self.span_rot.enter(); - let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?; - let cos = self.cos.narrow(0, index_pos, seq_len)?; - let sin = self.sin.narrow(0, index_pos, seq_len)?; - candle_nn::rotary_emb::rope(&xs.contiguous()?, &cos, &sin) - } - - fn forward_attn( - &mut self, - x: &Tensor, - mask: Option<&Tensor>, - index_pos: usize, - ) -> Result { - let _enter = self.span_attn.enter(); - let (b_sz, seq_len, n_embd) = x.dims3()?; - let qkv = self.attn_qkv.forward(x)?; - - let query_pos = self.n_head * self.head_dim; - let q = qkv.narrow(D::Minus1, 0, query_pos)?; - let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?; - let v = qkv.narrow( - D::Minus1, - query_pos + self.n_kv_head * self.head_dim, - self.n_kv_head * self.head_dim, - )?; - - let q = q - .reshape((b_sz, seq_len, self.n_head, self.head_dim))? - .transpose(1, 2)?; - let k = k - .reshape((b_sz, seq_len, self.n_head, self.head_dim))? - .transpose(1, 2)?; - let v = v - .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? - .transpose(1, 2)?; - - let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?; - let k = self.apply_rotary_emb(&k, index_pos)?; - - let (k, v) = match &self.kv_cache { - None => (k.contiguous()?, v.contiguous()?), - Some((k_cache, v_cache)) => { - if index_pos == 0 { - (k.contiguous()?, v.contiguous()?) - } else { - let k = Tensor::cat(&[k_cache, &k], 2)?; - let v = Tensor::cat(&[v_cache, &v], 2)?; - (k.contiguous()?, v.contiguous()?) - } - } - }; - self.kv_cache = Some((k.clone(), v.clone())); - - let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; - let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?; - - let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let att = match mask { - None => att, - Some(mask) => { - let mask = mask.broadcast_as(att.shape())?; - masked_fill(&att, &mask, &self.neg_inf)? - } - }; - let att = candle_nn::ops::softmax_last_dim(&att)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - let y = att.matmul(&v.contiguous()?)?; - let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; - let y = self.attn_output.forward(&y)?; - Ok(y) - } -} - -#[derive(Debug, Clone)] -pub struct ModelWeights { - tok_embeddings: Embedding, - layers: Vec, - output_norm: RmsNorm, - output: QLinear, - masks: HashMap, - span: tracing::Span, - span_output: tracing::Span, -} - -fn precomput_freqs_cis( - head_dim: usize, - freq_base: f32, - device: &Device, -) -> Result<(Tensor, Tensor)> { - let theta: Vec<_> = (0..head_dim) - .step_by(2) - .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) - .collect(); - let theta = Tensor::new(theta.as_slice(), device)?; - let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? - .to_dtype(DType::F32)? - .reshape((MAX_SEQ_LEN, 1))? 
- .matmul(&theta.reshape((1, theta.elem_count()))?)?; - let cos = idx_theta.cos()?; - let sin = idx_theta.sin()?; - Ok((cos, sin)) -} - -impl ModelWeights { - pub fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, - device: &Device, - ) -> Result { - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; - - // Parameter extraction from metadata. - let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("phi3.block_count")?.to_u32()? as usize; - let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize; - let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize; - let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize; - let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64; - let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?; - let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; - - let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; - let tok_embeddings = tok_embeddings.dequantize(device)?; - let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?; - let output = QLinear::new(&ct, reader, "output", device)?; - let mut layers = Vec::with_capacity(block_count); - for layer_idx in 0..block_count { - let prefix = format!("blk.{layer_idx}"); - let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?; - let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?; - let mlp = Mlp { - ffn_up, - ffn_down, - i_size, - }; - let attn_norm = rms_norm( - ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?, - rms_eps, - )?; - let ffn_norm = rms_norm( - ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?, - rms_eps, - )?; - let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); - let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); - layers.push(LayerWeights { - attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?, - attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?, - attn_norm, - ffn_norm, - mlp, - n_head: head_count, - n_kv_head: head_count_kv, - head_dim: embedding_length / head_count, - cos: cos.clone(), - sin: sin.clone(), - neg_inf: neg_inf.clone(), - kv_cache: None, - span_attn, - span_rot, - }) - } - let span = tracing::span!(tracing::Level::TRACE, "model"); - let span_output = tracing::span!(tracing::Level::TRACE, "output"); - Ok(Self { - tok_embeddings: Embedding::new(tok_embeddings, embedding_length), - layers, - output_norm, - output, - masks: HashMap::new(), - span, - span_output, - }) - } - - fn mask(&mut self, t: usize, device: &Device) -> Result { - if let Some(mask) = self.masks.get(&t) { - Ok(mask.clone()) - } else { - let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) - .collect(); - let mask = Tensor::from_slice(&mask, (t, t), device)?; - self.masks.insert(t, mask.clone()); - Ok(mask) - } - } - - pub fn forward(&mut self, xs: &Tensor, index_pos: usize) -> Result { - let (_b_sz, seq_len) = xs.dims2()?; - let mask = if seq_len == 1 { - None - } else { - Some(self.mask(seq_len, xs.device())?) 
- }; - let _enter = self.span.enter(); - let mut xs = self.tok_embeddings.forward(xs)?; - for layer in self.layers.iter_mut() { - let residual = &xs; - let ys = xs.apply(&layer.attn_norm)?; - let ys = layer.forward_attn(&ys, mask.as_ref(), index_pos)?; - let ys = (ys + residual)?; - let residual = &ys; - let ys = ys.apply(&layer.ffn_norm)?; - let ys = layer.mlp.forward(&ys)?; - xs = (ys + residual)? - } - let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?; - let _enter = self.span_output.enter(); - self.output.forward(&xs) - } -} diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index f4706c7e95..29bbf637e6 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -1,5 +1,5 @@ use candle::{Module, Result, Tensor}; -use candle_nn::VarBuilder; +use candle_nn::{layer_norm::RmsNormNonQuantized, VarBuilder}; #[derive(Debug, Clone)] pub struct Embedding { @@ -170,20 +170,20 @@ pub fn layer_norm>( #[derive(Debug, Clone)] pub struct RmsNorm { - inner: candle_nn::RmsNorm, + inner: candle_nn::RmsNorm, span: tracing::Span, } impl RmsNorm { pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result { let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let inner = candle_nn::rms_norm(size, eps, vb)?; + let inner = candle_nn::rms_norm_non_quant(size, eps, vb)?; Ok(Self { inner, span }) } pub fn forward_diff(&self, x: &Tensor) -> Result { let _enter = self.span.enter(); - self.inner.forward_diff(x) + self.inner.forward(x) } } diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs index ab9333d27e..f640bbd08e 100644 --- a/candle-wasm-examples/llama2-c/src/model.rs +++ b/candle-wasm-examples/llama2-c/src/model.rs @@ -1,6 +1,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::layer_norm::RmsNormNonQuantized; use candle_nn::{ - embedding, linear_no_bias as linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder, + embedding, linear_no_bias as linear, Embedding, Linear, Module, RmsNorm, VarBuilder, }; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -201,14 +202,19 @@ impl Mlp { } struct Block { - rms_1: RmsNorm, + rms_1: RmsNorm, attn: CausalSelfAttention, - rms_2: RmsNorm, + rms_2: RmsNorm, mlp: Mlp, } impl Block { - fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { + fn new( + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, + ) -> Self { Self { rms_1, attn, @@ -229,9 +235,13 @@ impl Block { fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result { let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?; let mlp = Mlp::load(vb.pp("mlp"), cfg)?; - let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; - let post_attention_layernorm = - rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?; + let input_layernorm = + candle_nn::rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = candle_nn::rms_norm_non_quant( + cfg.dim, + cfg.norm_eps, + vb.pp("post_attention_layernorm"), + )?; Ok(Self::new( input_layernorm, attn, @@ -244,12 +254,17 @@ impl Block { pub struct Llama { wte: Embedding, blocks: Vec, - ln_f: RmsNorm, + ln_f: RmsNorm, lm_head: Linear, } impl Llama { - fn new(wte: Embedding, blocks: Vec, ln_f: RmsNorm, lm_head: Linear) -> Self { + fn new( + wte: Embedding, + blocks: Vec, + ln_f: RmsNorm, + 
lm_head: Linear, + ) -> Self { Self { wte, blocks, @@ -273,7 +288,7 @@ impl Llama { pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result { let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?; let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; - let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; + let norm = candle_nn::rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.n_layers) .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap()) .collect(); From 5892fac1ebb77bfb1ad4e02b0c3fb8312693db1e Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Thu, 16 May 2024 16:39:22 -0400 Subject: [PATCH 02/75] fix issue with cuda header file for A10G (#5) --- candle-flash-attn/kernels/flash_fwd_launch_template.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h index 66ab6206db..e1d0a503cf 100644 --- a/candle-flash-attn/kernels/flash_fwd_launch_template.h +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -60,7 +60,11 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + if constexpr(!Is_dropout) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } }); }); } From ea49ea28e4dcfb4836be148a3fbbe4d2914528d7 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Sun, 19 May 2024 10:32:16 -0400 Subject: [PATCH 03/75] Remove candle-layer-norm (#6) * Support flash-attn in quantized phi3. (#2194) * Use flash-attn in gemma. (#2195) * Use flash-attn in gemma. * Fix flash-attn for head dim 256. * Remove candle-layer-norm --------- Co-authored-by: Laurent Mazare --- candle-examples/examples/gemma/main.rs | 5 +- .../examples/quantized-phi/main.rs | 3 + .../kernels/flash_fwd_launch_template.h | 4 ++ candle-flash-attn/src/lib.rs | 4 +- candle-nn/Cargo.toml | 3 +- candle-nn/src/layer_norm.rs | 13 +--- candle-transformers/src/models/gemma.rs | 62 +++++++++++++------ 7 files changed, 60 insertions(+), 34 deletions(-) diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs index a5f7d5917d..31c5561842 100644 --- a/candle-examples/examples/gemma/main.rs +++ b/candle-examples/examples/gemma/main.rs @@ -193,6 +193,9 @@ struct Args { /// The model to use. #[arg(long, default_value = "2b")] which: Which, + + #[arg(long)] + use_flash_attn: bool, } fn main() -> Result<()> { @@ -270,7 +273,7 @@ fn main() -> Result<()> { DType::F32 }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; - let model = Model::new(&config, vb)?; + let model = Model::new(args.use_flash_attn, &config, vb)?; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs index f17cec69c9..9ab024c20f 100644 --- a/candle-examples/examples/quantized-phi/main.rs +++ b/candle-examples/examples/quantized-phi/main.rs @@ -87,6 +87,9 @@ struct Args { /// The model size to use. 
#[arg(long, default_value = "phi-3b")] which: Which, + + #[arg(long)] + use_flash_attn: bool, } impl Args { diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h index e1d0a503cf..759e84dddb 100644 --- a/candle-flash-attn/kernels/flash_fwd_launch_template.h +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } // int ctas_per_sm; // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 21a06b5ecf..f171a9868f 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -139,7 +139,9 @@ impl FlashAttn { let elem_count = out_shape.elem_count(); let dst = unsafe { dev.alloc::(elem_count) }.w()?; - let softmax_lse = dev.alloc_zeros::(b_sz * num_heads * seqlen_q).w()?; + let softmax_lse = dev + .alloc_zeros::(b_sz * 128 * num_heads * seqlen_q) + .w()?; let is_bf16 = if is_bf16 { 1 } else { 0 }; diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index b878d6d936..9f0d56bdea 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -21,7 +21,6 @@ safetensors = { workspace = true } serde = { workspace = true } metal = { workspace = true, optional = true } candle-metal-kernels = { workspace = true, optional = true } -candle-layer-norm = { git = "https://github.com/EricLBuehler/candle-layer-norm.git", version = "0.0.1", optional = true } [dev-dependencies] anyhow = { workspace = true } @@ -32,7 +31,7 @@ criterion = { workspace = true } [features] default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] -cuda = ["candle/cuda", "dep:candle-layer-norm"] +cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index 85814a417a..9fc6328e8d 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -291,18 +291,7 @@ impl RmsNorm { impl Module for RmsNorm { fn forward(&self, xs: &Tensor) -> Result { - #[cfg(feature = "cuda")] - { - let (bs, s, h) = xs.dims3()?; - let xs = xs.reshape((bs * s, h))?; - let res = - candle_layer_norm::rms_norm(&xs, self.inner.weight(), None, self.inner.eps as f32)?; - res.reshape((bs, s, h)) - } - #[cfg(not(feature = "cuda"))] - { - self.inner.forward(xs) - } + self.inner.forward(xs) } } diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index 3bde88b4c0..1cfef59eba 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -73,13 +73,6 @@ struct RotaryEmbedding { cos: Tensor, } -fn rotate_half(xs: &Tensor) -> Result { - let last_dim = xs.dim(D::Minus1)?; - let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; - let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; - Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) -} - impl RotaryEmbedding { 
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { let dim = cfg.head_dim; @@ -94,7 +87,6 @@ impl RotaryEmbedding { .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, @@ -110,10 +102,8 @@ impl RotaryEmbedding { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; - let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; - let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; Ok((q_embed, k_embed)) } } @@ -163,10 +153,16 @@ struct Attention { head_dim: usize, rotary_emb: Arc, kv_cache: Option<(Tensor, Tensor)>, + use_flash_attn: bool, } impl Attention { - fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + fn new( + rotary_emb: Arc, + use_flash_attn: bool, + cfg: &Config, + vb: VarBuilder, + ) -> Result { let hidden_sz = cfg.hidden_size; let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; @@ -188,6 +184,7 @@ impl Attention { head_dim, rotary_emb, kv_cache: None, + use_flash_attn, }) } @@ -231,7 +228,14 @@ impl Attention { let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; - let attn_output = { + let attn_output = if self.use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = query_states.transpose(1, 2)?; + let k = key_states.transpose(1, 2)?; + let v = value_states.transpose(1, 2)?; + let scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)? + } else { let scale = 1f64 / f64::sqrt(self.head_dim as f64); let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? 
* scale)?; @@ -253,6 +257,22 @@ impl Attention { } } +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { + unimplemented!("compile with '--features flash-attn'") +} + #[derive(Debug, Clone)] struct DecoderLayer { self_attn: Attention, @@ -262,8 +282,13 @@ struct DecoderLayer { } impl DecoderLayer { - fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { - let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + fn new( + rotary_emb: Arc, + use_flash_attn: bool, + cfg: &Config, + vb: VarBuilder, + ) -> Result { + let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?; let mlp = MLP::new(cfg, vb.pp("mlp"))?; let input_layernorm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; @@ -312,7 +337,7 @@ pub struct Model { } impl Model { - pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result { let vb_m = vb.pp("model"); let embed_tokens = candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; @@ -320,7 +345,8 @@ impl Model { let mut layers = Vec::with_capacity(cfg.num_hidden_layers); let vb_l = vb_m.pp("layers"); for layer_idx in 0..cfg.num_hidden_layers { - let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + let layer = + DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?; layers.push(layer) } let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; From faa943579ea0555edfb44804d5a3bc62ff9b751c Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Mon, 3 Jun 2024 08:50:30 -0400 Subject: [PATCH 04/75] Add a set_dtype method --- candle-nn/src/var_builder.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index d4ff447cc3..fdaa80ae63 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -211,6 +211,18 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { ..self } } + + /// Set the dtype of the VarBuilder. 
+ pub fn set_dtype(self, dtype: DType) -> Self { + Self { + data: Arc::new(TensorData { + backend: self.data.backend.clone(), + dtype: dtype, + device: self.data.device.clone(), + }), + ..self + } + } } struct Zeros; From 696acaa43d23519428670d5c2669ee12be9e5b83 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Sun, 9 Jun 2024 11:38:10 -0400 Subject: [PATCH 05/75] Add more capability to slice_assign (#7) --- candle-core/src/lib.rs | 1 + candle-core/src/tensor.rs | 294 --------------------- candle-core/src/tensor_indexing.rs | 379 ++++++++++++++++++++++++++++ candle-core/tests/indexing_tests.rs | 113 ++++++++- 4 files changed, 484 insertions(+), 303 deletions(-) create mode 100644 candle-core/src/tensor_indexing.rs diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 8a46377b59..8a3a58a582 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -68,6 +68,7 @@ mod storage; mod strided_index; mod tensor; mod tensor_cat; +mod tensor_indexing; pub mod test_utils; pub mod utils; mod variable; diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index baaa288bb5..fb690018f8 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1335,244 +1335,6 @@ impl Tensor { self.index_select(ids, 0) } - pub fn scatter_add(&self, indexes: &Self, source: &Self, dim: D) -> Result { - let dim = dim.to_index(self.shape(), "scatter-add")?; - let source_dims = source.dims(); - let self_dims = self.dims(); - let mismatch = if source_dims.len() != self_dims.len() { - true - } else { - let mut mismatch = false; - for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { - if i != dim && d1 != d2 { - mismatch = true; - break; - } - } - mismatch - }; - if mismatch { - Err(Error::ShapeMismatchBinaryOp { - op: "scatter-add (self, src)", - lhs: self.shape().clone(), - rhs: source.shape().clone(), - } - .bt())? - } - if indexes.dims() != source.dims() { - Err(Error::ShapeMismatchBinaryOp { - op: "scatter-add (indexes, src)", - lhs: indexes.shape().clone(), - rhs: source.shape().clone(), - } - .bt())? - } - let storage = self.storage().scatter_add( - self.layout(), - &indexes.storage(), - indexes.layout(), - &source.storage(), - source.layout(), - dim, - )?; - let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { - Op::ScatterAdd(t1, t2, t3, dim) - }); - Ok(from_storage(storage, self.shape(), op, false)) - } - - /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension. - pub fn slice_scatter(&self, src: &Self, dim: D, start: usize) -> Result { - let dim = dim.to_index(self.shape(), "slice-scatter")?; - if dim == 0 { - self.slice_scatter0(src, start) - } else { - // TODO: Maybe we want to add a more efficient implementation at some point. - self.transpose(0, dim)? - .slice_scatter0(&src.transpose(0, dim)?, start)? - .transpose(0, dim) - } - } - - /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension. - pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result { - if self.dtype() != src.dtype() { - Err(Error::DTypeMismatchBinaryOp { - lhs: self.dtype(), - rhs: src.dtype(), - op: "slice-scatter", - } - .bt())? - } - if self.device().location() != src.device.location() { - Err(Error::DeviceMismatchBinaryOp { - lhs: self.device().location(), - rhs: src.device().location(), - op: "slice-scatter", - } - .bt())? 
- } - if self.rank() != src.rank() { - Err(Error::UnexpectedNumberOfDims { - expected: self.rank(), - got: src.rank(), - shape: src.shape().clone(), - } - .bt())? - } - let shape_ok = - self.dims() - .iter() - .zip(src.dims().iter()) - .enumerate() - .all(|(dim_idx, (&d1, &d2))| { - if 0 == dim_idx { - d2 + start <= d1 - } else { - d1 == d2 - } - }); - if !shape_ok { - Err(Error::ShapeMismatchBinaryOp { - op: "slice-scatter (self, src)", - lhs: self.shape().clone(), - rhs: src.shape().clone(), - } - .bt())? - } - let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? }; - self.storage() - .copy_strided_src(&mut storage, 0, self.layout())?; - let offset = start * src.dims()[1..].iter().product::(); - src.storage() - .copy_strided_src(&mut storage, offset, src.layout())?; - let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start)); - Ok(from_storage(storage, self.shape(), op, false)) - } - - /// Accumulate element from `source` at indexes `indexes` and add them to `self`. - pub fn index_add(&self, indexes: &Self, source: &Self, dim: D) -> Result { - let dim = dim.to_index(self.shape(), "index-add")?; - let source_dims = source.dims(); - let self_dims = self.dims(); - let mismatch = if source_dims.len() != self_dims.len() { - true - } else { - let mut mismatch = false; - for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { - if i != dim && d1 != d2 { - mismatch = true; - break; - } - } - mismatch - }; - if mismatch { - Err(Error::ShapeMismatchBinaryOp { - op: "index-add (self, source)", - lhs: self.shape().clone(), - rhs: source.shape().clone(), - } - .bt())? - } - // The number of element in indexes must match the dimension on which the add is - // performed on the source tensor (and the index values from `indexes` are taken from - // the target tensor self) - let indexes_len = indexes.dims1()?; - if source_dims[dim] != indexes_len { - Err(Error::ShapeMismatchBinaryOp { - op: "index-add (ids, source))", - lhs: indexes.shape().clone(), - rhs: source.shape().clone(), - } - .bt())? - } - let storage = self.storage().index_add( - self.layout(), - &indexes.storage(), - indexes.layout(), - &source.storage(), - source.layout(), - dim, - )?; - let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { - Op::IndexAdd(t1, t2, t3, dim) - }); - Ok(from_storage(storage, self.shape(), op, false)) - } - - /// Gather values across the target dimension. - /// - /// # Arguments - /// - /// * `self` - The input tensor. - /// * `indexes` - The indices of elements to gather, this should have the same shape as `self` - /// but can have a different number of elements on the target dimension. - /// * `dim` - the target dimension. - /// - /// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on - /// dimension `dim` by the values in `indexes`. - pub fn gather(&self, indexes: &Self, dim: D) -> Result { - let dim = dim.to_index(self.shape(), "gather")?; - let self_dims = self.dims(); - let indexes_dims = indexes.dims(); - let mismatch = if indexes_dims.len() != self_dims.len() { - true - } else { - let mut mismatch = false; - for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() { - if i != dim && d1 != d2 { - mismatch = true; - break; - } - } - mismatch - }; - if mismatch { - Err(Error::ShapeMismatchBinaryOp { - op: "gather", - lhs: self.shape().clone(), - rhs: indexes.shape().clone(), - } - .bt())? 
- } - let storage = - self.storage() - .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?; - let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim)); - Ok(from_storage(storage, indexes.shape(), op, false)) - } - - /// Select values for the input tensor at the target indexes across the specified dimension. - /// - /// The `indexes` is argument is an int tensor with a single dimension. - /// The output has the same number of dimension as the `self` input. The target dimension of - /// the output has length the length of `indexes` and the values are taken from `self` using - /// the index from `indexes`. Other dimensions have the same number of elements as the input - /// tensor. - pub fn index_select(&self, indexes: &Self, dim: D) -> Result { - let dim = dim.to_index(self.shape(), "index-select")?; - let indexes_len = match indexes.dims() { - [l] => *l, - _ => Err(Error::ShapeMismatchBinaryOp { - lhs: self.shape().clone(), - rhs: indexes.shape().clone(), - op: "index-select", - } - .bt())?, - }; - let storage = self.storage().index_select( - &indexes.storage(), - self.layout(), - indexes.layout(), - dim, - )?; - let mut dims = self.dims().to_vec(); - dims[dim] = indexes_len; - let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim)); - Ok(from_storage(storage, dims, op, false)) - } - /// Returns an iterator over position of the elements in the storage when ranging over the /// index tuples in lexicographic order. pub fn strided_index(&self) -> crate::StridedIndex { @@ -2443,62 +2205,6 @@ impl Tensor { } } - /// Returns a copy of `self` where the values within `ranges` have been replaced with the - /// content of `src`. - pub fn slice_assign>( - &self, - ranges: &[D], - src: &Tensor, - ) -> Result { - let src_dims = src.dims(); - let self_dims = self.dims(); - if self_dims.len() != src_dims.len() { - bail!( - "slice-assign requires input with the same rank {} <> {}", - self_dims.len(), - src_dims.len() - ) - } - if self_dims.len() != ranges.len() { - bail!( - "slice-assign requires input with the same rank as there are ranges {} <> {}", - self_dims.len(), - ranges.len() - ) - } - let mut src = src.clone(); - let mut mask = Self::ones(src.shape(), DType::U8, src.device())?; - for (i, range) in ranges.iter().enumerate() { - let start_included = match range.start_bound() { - std::ops::Bound::Unbounded => 0, - std::ops::Bound::Included(v) => *v, - std::ops::Bound::Excluded(v) => *v + 1, - }; - let end_excluded = match range.end_bound() { - std::ops::Bound::Unbounded => self_dims[i], - std::ops::Bound::Included(v) => *v + 1, - std::ops::Bound::Excluded(v) => *v, - }; - if end_excluded <= start_included { - bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}") - } - if self_dims[i] < end_excluded { - bail!( - "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}", - self_dims[i] - ) - } - if end_excluded - start_included != src_dims[i] { - bail!( - "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i] - ) - } - src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?; - mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)? - } - mask.where_cond(/* on_true= */ &src, /* on_false= */ self) - } - /// Returns log(sum(exp(tensor), dim)). 
     pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
         let exp = self.exp()?;
diff --git a/candle-core/src/tensor_indexing.rs b/candle-core/src/tensor_indexing.rs
new file mode 100644
index 0000000000..140876456b
--- /dev/null
+++ b/candle-core/src/tensor_indexing.rs
@@ -0,0 +1,379 @@
+use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
+
+use crate::{
+    bail,
+    op::{BackpropOp, Op},
+    shape::Dim,
+    tensor::from_storage,
+    DType, Error, Result, Tensor,
+};
+
+/// Specialization of `std::ops::RangeBounds` for `usize` to allow trait objects.
+pub trait RangeBound {
+    fn start_bound(&self) -> std::ops::Bound<usize>;
+    fn end_bound(&self) -> std::ops::Bound<usize>;
+}
+
+macro_rules! range_bound {
+    ($name:ident) => {
+        impl RangeBound for $name<usize> {
+            fn end_bound(&self) -> std::ops::Bound<usize> {
+                <Self as std::ops::RangeBounds<usize>>::end_bound(&self).cloned()
+            }
+            fn start_bound(&self) -> std::ops::Bound<usize> {
+                <Self as std::ops::RangeBounds<usize>>::start_bound(&self).cloned()
+            }
+        }
+    };
+    // Use the marker to designate no generics
+    ($name:ident, $marker:expr) => {
+        impl RangeBound for $name {
+            fn end_bound(&self) -> std::ops::Bound<usize> {
+                <Self as std::ops::RangeBounds<usize>>::end_bound(&self).cloned()
+            }
+            fn start_bound(&self) -> std::ops::Bound<usize> {
+                <Self as std::ops::RangeBounds<usize>>::start_bound(&self).cloned()
+            }
+        }
+    };
+    // Use the marker to designate no generics
+    ($name:ty) => {
+        impl RangeBound for $name {
+            fn end_bound(&self) -> std::ops::Bound<usize> {
+                <Self as std::ops::RangeBounds<usize>>::end_bound(&self).cloned()
+            }
+            fn start_bound(&self) -> std::ops::Bound<usize> {
+                <Self as std::ops::RangeBounds<usize>>::start_bound(&self).cloned()
+            }
+        }
+    };
+}
+
+range_bound!(Range);
+range_bound!(RangeFrom);
+range_bound!(RangeFull, ());
+range_bound!(RangeInclusive);
+range_bound!(RangeTo);
+range_bound!(RangeToInclusive);
+range_bound!((std::ops::Bound<usize>, std::ops::Bound<usize>));
+
+impl RangeBound for usize {
+    fn end_bound(&self) -> std::ops::Bound<usize> {
+        std::ops::Bound::Excluded(self + 1)
+    }
+    fn start_bound(&self) -> std::ops::Bound<usize> {
+        std::ops::Bound::Included(*self)
+    }
+}
+
+impl Tensor {
+    /// Returns a copy of `self` where the values within `ranges` have been replaced with the
+    /// content of `src`. This is analogous to slice assignment in `torch`.
+ /// + /// # Example + /// ```rust + /// use candle_core::{Device, Tensor}; + /// + /// let dev = Device::Cpu; + /// let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + /// let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + /// let out = tensor.slice_assign(&[&(..3), &(3..5)], &src)?; + /// assert_eq!( + /// out.to_vec2::()?, + /// &[ + /// [0, 1, 2, 100, 101], + /// [5, 6, 7, 102, 103], + /// [10, 11, 12, 104, 105], + /// [15, 16, 17, 18, 19] + /// ] + /// ); + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub fn slice_assign(&self, ranges: &[&dyn RangeBound], src: &Tensor) -> Result { + let src_dims = src.dims(); + let self_dims = self.dims(); + if self_dims.len() != src_dims.len() { + bail!( + "slice-assign requires input with the same rank {} <> {}", + self_dims.len(), + src_dims.len() + ) + } + if self_dims.len() != ranges.len() { + bail!( + "slice-assign requires input with the same rank as there are ranges {} <> {}", + self_dims.len(), + ranges.len() + ) + } + let mut src = src.clone(); + let mut mask = Self::ones(src.shape(), DType::U8, src.device())?; + for (i, range) in ranges.iter().enumerate() { + let start_included = match range.start_bound() { + std::ops::Bound::Unbounded => 0, + std::ops::Bound::Included(v) => v, + std::ops::Bound::Excluded(v) => v + 1, + }; + let end_excluded = match range.end_bound() { + std::ops::Bound::Unbounded => self_dims[i], + std::ops::Bound::Included(v) => v + 1, + std::ops::Bound::Excluded(v) => v, + }; + if end_excluded <= start_included { + bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}") + } + if self_dims[i] < end_excluded { + bail!( + "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}", + self_dims[i] + ) + } + if end_excluded - start_included != src_dims[i] { + bail!( + "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i] + ) + } + src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?; + mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)? + } + mask.where_cond(/* on_true= */ &src, /* on_false= */ self) + } + + pub fn scatter_add(&self, indexes: &Self, source: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "scatter-add")?; + let source_dims = source.dims(); + let self_dims = self.dims(); + let mismatch = if source_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "scatter-add (self, src)", + lhs: self.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + if indexes.dims() != source.dims() { + Err(Error::ShapeMismatchBinaryOp { + op: "scatter-add (indexes, src)", + lhs: indexes.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + let storage = self.storage().scatter_add( + self.layout(), + &indexes.storage(), + indexes.layout(), + &source.storage(), + source.layout(), + dim, + )?; + let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { + Op::ScatterAdd(t1, t2, t3, dim) + }); + Ok(from_storage(storage, self.shape(), op, false)) + } + + /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension. 
+ pub fn slice_scatter(&self, src: &Self, dim: D, start: usize) -> Result { + let dim = dim.to_index(self.shape(), "slice-scatter")?; + if dim == 0 { + self.slice_scatter0(src, start) + } else { + // TODO: Maybe we want to add a more efficient implementation at some point. + self.transpose(0, dim)? + .slice_scatter0(&src.transpose(0, dim)?, start)? + .transpose(0, dim) + } + } + + /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension. + pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result { + if self.dtype() != src.dtype() { + Err(Error::DTypeMismatchBinaryOp { + lhs: self.dtype(), + rhs: src.dtype(), + op: "slice-scatter", + } + .bt())? + } + if self.device().location() != src.device().location() { + Err(Error::DeviceMismatchBinaryOp { + lhs: self.device().location(), + rhs: src.device().location(), + op: "slice-scatter", + } + .bt())? + } + if self.rank() != src.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: self.rank(), + got: src.rank(), + shape: src.shape().clone(), + } + .bt())? + } + let shape_ok = + self.dims() + .iter() + .zip(src.dims().iter()) + .enumerate() + .all(|(dim_idx, (&d1, &d2))| { + if 0 == dim_idx { + d2 + start <= d1 + } else { + d1 == d2 + } + }); + if !shape_ok { + Err(Error::ShapeMismatchBinaryOp { + op: "slice-scatter (self, src)", + lhs: self.shape().clone(), + rhs: src.shape().clone(), + } + .bt())? + } + let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? }; + self.storage() + .copy_strided_src(&mut storage, 0, self.layout())?; + let offset = start * src.dims()[1..].iter().product::(); + src.storage() + .copy_strided_src(&mut storage, offset, src.layout())?; + let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start)); + Ok(from_storage(storage, self.shape(), op, false)) + } + + /// Accumulate element from `source` at indexes `indexes` and add them to `self`. + pub fn index_add(&self, indexes: &Self, source: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "index-add")?; + let source_dims = source.dims(); + let self_dims = self.dims(); + let mismatch = if source_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "index-add (self, source)", + lhs: self.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + // The number of element in indexes must match the dimension on which the add is + // performed on the source tensor (and the index values from `indexes` are taken from + // the target tensor self) + let indexes_len = indexes.dims1()?; + if source_dims[dim] != indexes_len { + Err(Error::ShapeMismatchBinaryOp { + op: "index-add (ids, source))", + lhs: indexes.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + let storage = self.storage().index_add( + self.layout(), + &indexes.storage(), + indexes.layout(), + &source.storage(), + source.layout(), + dim, + )?; + let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { + Op::IndexAdd(t1, t2, t3, dim) + }); + Ok(from_storage(storage, self.shape(), op, false)) + } + + /// Gather values across the target dimension. + /// + /// # Arguments + /// + /// * `self` - The input tensor. 
+ /// * `indexes` - The indices of elements to gather, this should have the same shape as `self` + /// but can have a different number of elements on the target dimension. + /// * `dim` - the target dimension. + /// + /// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on + /// dimension `dim` by the values in `indexes`. + pub fn gather(&self, indexes: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "gather")?; + let self_dims = self.dims(); + let indexes_dims = indexes.dims(); + let mismatch = if indexes_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "gather", + lhs: self.shape().clone(), + rhs: indexes.shape().clone(), + } + .bt())? + } + let storage = + self.storage() + .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?; + let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim)); + Ok(from_storage(storage, indexes.shape(), op, false)) + } + + /// Select values for the input tensor at the target indexes across the specified dimension. + /// + /// The `indexes` is argument is an int tensor with a single dimension. + /// The output has the same number of dimension as the `self` input. The target dimension of + /// the output has length the length of `indexes` and the values are taken from `self` using + /// the index from `indexes`. Other dimensions have the same number of elements as the input + /// tensor. + pub fn index_select(&self, indexes: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "index-select")?; + let indexes_len = match indexes.dims() { + [l] => *l, + _ => Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: indexes.shape().clone(), + op: "index-select", + } + .bt())?, + }; + let storage = self.storage().index_select( + &indexes.storage(), + self.layout(), + indexes.layout(), + dim, + )?; + let mut dims = self.dims().to_vec(); + dims[dim] = indexes_len; + let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim)); + Ok(from_storage(storage, dims, op, false)) + } +} diff --git a/candle-core/tests/indexing_tests.rs b/candle-core/tests/indexing_tests.rs index 047205a31f..417d54a41f 100644 --- a/candle-core/tests/indexing_tests.rs +++ b/candle-core/tests/indexing_tests.rs @@ -93,28 +93,123 @@ fn index_3d() -> Result<()> { } #[test] -fn slice_assign() -> Result<()> { +fn slice_assign_range() -> Result<()> { let dev = Device::Cpu; let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; - let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?; - let out = tensor.slice_assign(&[1..4, 3..5], &src)?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(1..4), &(3..5)], &src)?; assert_eq!( out.to_vec2::()?, &[ [0, 1, 2, 3, 4], - [5, 6, 7, 0, 1], - [10, 11, 12, 2, 3], - [15, 16, 17, 4, 5] + [5, 6, 7, 100, 101], + [10, 11, 12, 102, 103], + [15, 16, 17, 104, 105] ] ); - let out = tensor.slice_assign(&[0..3, 0..2], &src)?; + let out = tensor.slice_assign(&[&(0..3), &(0..2)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [100, 101, 2, 3, 4], + [102, 103, 7, 8, 9], + [104, 105, 12, 13, 14], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_to() -> Result<()> { + let dev = Device::Cpu; + + let tensor 
= Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(..3), &(3..5)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 100, 101], + [5, 6, 7, 102, 103], + [10, 11, 12, 104, 105], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_from() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(1..), &(0..2)], &src)?; assert_eq!( out.to_vec2::()?, &[ [0, 1, 2, 3, 4], - [2, 3, 7, 8, 9], - [4, 5, 12, 13, 14], + [100, 101, 7, 8, 9], + [102, 103, 12, 13, 14], + [104, 105, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_to_incl() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(..=2), &(1..3)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 100, 101, 3, 4], + [5, 102, 103, 8, 9], + [10, 104, 105, 13, 14], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_full() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 4) + 100, &dev)?.reshape((4, 2))?; + let out = tensor.slice_assign(&[&(..), &(3..5)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 100, 101], + [5, 6, 7, 102, 103], + [10, 11, 12, 104, 105], + [15, 16, 17, 106, 107] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_exact() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, 2 + 100, &dev)?.reshape((1, 2))?; + let out = tensor.slice_assign(&[&0, &(3..5)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 100, 101], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], [15, 16, 17, 18, 19] ] ); From 0936406f36644e26519660efd0be1dca9add2c8d Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Sun, 9 Jun 2024 13:17:10 -0400 Subject: [PATCH 06/75] Implement unfold (#8) * Add unfold * Format --- candle-core/src/tensor.rs | 43 +++++++++++++++++++++++++++++++ candle-core/tests/tensor_tests.rs | 12 +++++++++ candle-nn/src/var_builder.rs | 2 +- 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index fb690018f8..403357f691 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2221,6 +2221,49 @@ impl Tensor { pub fn broadcast_pow(&self, rhs: &Tensor) -> Result { rhs.broadcast_mul(&self.log()?)?.exp() } + + /// Returns a view of which contains all slices of size `size` from self tensor in the dimension + /// `dim` and stepped by `step`. 
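+ ///
+ /// For a `(3, 2)` tensor holding `[[0, 1], [2, 3], [4, 5]]`, `unfold(0, 2, 1)` returns a
+ /// `(2, 2, 2)` view equal to `[[[0, 2], [1, 3]], [[2, 4], [3, 5]]]`. No data is copied: only
+ /// the layout (sizes and strides) changes, following the PyTorch implementation linked below.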
+ pub fn unfold(&self, dim: D, size: usize, step: usize) -> Result { + // https://github.com/pytorch/pytorch/blob/75b0720a97ac5d82e8a7a1a6ae7c5f7a87d7183d/aten/src/ATen/native/TensorShape.cpp#L3785-L3804 + let mut sizes = self.dims().to_vec(); + let mut strides = self.stride().to_vec(); + + let dim = dim.to_index(self.shape(), "unfold")?; + + let max_len = if self.dims().is_empty() { + 1 + } else { + sizes[dim] + }; + if size > max_len { + bail!( + "unsqueeze: maximum size for tensor at dimension {dim} is {max_len} but size is {size}" + ) + } + sizes.push(size); + strides.push(if self.dims().is_empty() { + 1 + } else { + strides[dim] + }); + + if !self.dims().is_empty() { + sizes[dim] = ((sizes[dim] as f32 - size as f32) / step as f32 + 1.) as usize; + strides[dim] *= step; + } + + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: Layout::new(sizes.into(), strides, self.layout.start_offset()), + op: BackpropOp::new1(self, Op::Reshape), + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } } macro_rules! bin_trait { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index cd5f4ca148..975f40ac95 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1345,3 +1345,15 @@ fn pow() -> Result<()> { ); Ok(()) } + +#[test] +fn unfold() -> Result<()> { + let x = Tensor::arange(0i64, 3 * 2, &Device::Cpu)?.reshape((3, 2))?; + let unfolded = x.unfold(0, 2, 1)?; + dbg!(&unfolded); + assert_eq!( + unfolded.to_vec3::()?, + vec![[[0i64, 2], [1, 3]], [[2, 4], [3, 5]]] + ); + Ok(()) +} diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index fdaa80ae63..2d288be0de 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -217,7 +217,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { Self { data: Arc::new(TensorData { backend: self.data.backend.clone(), - dtype: dtype, + dtype, device: self.data.device.clone(), }), ..self From f52e2347b6237d19ffd7af26315f543c22f9f286 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Tue, 11 Jun 2024 13:45:02 -0400 Subject: [PATCH 07/75] Bump cudarc to 0.11.5 (#10) --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 810cb51aed..3e5db36598 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ candle-onnx = { path = "./candle-onnx", version = "0.6.0" } candle-transformers = { path = "./candle-transformers", version = "0.6.0" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } -cudarc = { version = "0.11.4", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } +cudarc = { version = "0.11.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.3.0" From bb8f6f0e0b8abf07226c662a8088d89e6bec6583 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Sat, 29 Jun 2024 18:42:18 -0400 Subject: [PATCH 08/75] Add QTensor::quantize_onto (#12) * Add the quantize_onto api * Take ref * Clippy * Format * Add error checking --- candle-core/src/cpu_backend/mod.rs | 3 +- 
candle-core/src/cpu_backend/utils.rs | 20 ++++++++--- candle-core/src/quantized/cuda.rs | 17 ++++++++++ candle-core/src/quantized/dummy_cuda.rs | 4 +++ candle-core/src/quantized/dummy_metal.rs | 4 +++ candle-core/src/quantized/metal.rs | 17 ++++++++++ candle-core/src/quantized/mod.rs | 42 +++++++++++++++++++++++- candle-transformers/src/models/vgg.rs | 9 +++-- 8 files changed, 104 insertions(+), 12 deletions(-) diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 18b73e9b60..58773c8020 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -121,7 +121,8 @@ impl ReduceIndex { let dst_len = src_l.shape().elem_count() / reduce_dim_size; let mut dst: Vec = Vec::with_capacity(dst_len); let dst_to_set = dst.spare_capacity_mut(); - let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) }; + let dst_to_set = + unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(dst_to_set) }; match src_l.contiguous_offsets() { Some((o1, o2)) => { let src = &src[o1..o2]; diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index af25a2aff9..3e0c69b4f7 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -174,7 +174,9 @@ pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [ (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => { let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) + }; f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set); // SAFETY: values are all set by f_vec. unsafe { ys.set_len(el_count) }; @@ -185,7 +187,9 @@ pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [ let rhs = &rhs[ob.start..ob.start + ob.len]; let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) + }; let mut dst_i = 0; for src_i in (o_l1..o_l2).step_by(ob.len) { f_vec( @@ -224,7 +228,9 @@ pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [ let lhs = &lhs[ob.start..ob.start + ob.len]; let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) + }; let mut dst_i = 0; for src_i in (o_r1..o_r2).step_by(ob.len) { f_vec( @@ -311,7 +317,9 @@ pub fn unary_map_vec U, FV: FnMut(&[T], &mut [U crate::StridedBlocks::SingleBlock { start_offset, len } => { let mut ys: Vec = Vec::with_capacity(len); let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(ys_to_set) + }; f_vec(&vs[start_offset..start_offset + len], ys_to_set); // SAFETY: values are all set by f_vec. 
unsafe { ys.set_len(len) }; @@ -333,7 +341,9 @@ pub fn unary_map_vec U, FV: FnMut(&[T], &mut [U } else { let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(ys_to_set) + }; let mut dst_index = 0; for src_index in block_start_index { let vs = &vs[src_index..src_index + block_len]; diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 8e4884b28d..6318d673cf 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -449,6 +449,23 @@ impl QCudaStorage { Ok(()) } + pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> { + // Run the quantization on cpu. + let src_len = src.as_slice::()?.len(); + let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float(src.as_slice::()?)?; + } else { + unreachable!() + } + + let data = qcpu_storage.data()?; + let data = self.device.htod_sync_copy(data.as_ref()).w()?; + self.data = data; + Ok(()) + } + pub fn storage_size_in_bytes(&self) -> usize { self.data.len() } diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs index ca7b812084..69daad3cc4 100644 --- a/candle-core/src/quantized/dummy_cuda.rs +++ b/candle-core/src/quantized/dummy_cuda.rs @@ -32,6 +32,10 @@ impl QCudaStorage { Err(Error::NotCompiledWithCudaSupport) } + pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + pub fn storage_size_in_bytes(&self) -> usize { 0 } diff --git a/candle-core/src/quantized/dummy_metal.rs b/candle-core/src/quantized/dummy_metal.rs index 520d0ed49a..fc51214c19 100644 --- a/candle-core/src/quantized/dummy_metal.rs +++ b/candle-core/src/quantized/dummy_metal.rs @@ -28,6 +28,10 @@ impl QMetalStorage { Err(Error::NotCompiledWithMetalSupport) } + pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + pub fn storage_size_in_bytes(&self) -> usize { 0 } diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index f7f5b68ac2..5f61749b36 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -126,6 +126,23 @@ impl QMetalStorage { Ok(()) } + pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> { + // Quantization only happens on CPU for now. 
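+ // The source is quantized into a temporary CPU `QStorage` via `from_float`, and the raw
+ // quantized bytes are then uploaded into a fresh Metal buffer at the end of this function.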
+ let elem_count = src.as_slice::()?.len(); + let src = crate::Storage::Cpu(src); + let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float(src.as_slice::()?)?; + } else { + unreachable!() + } + + let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?; + self.buffer = buffer; + Ok(()) + } + pub fn storage_size_in_bytes(&self) -> usize { self.buffer.length() as usize } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index d852d50410..ff00f36389 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -101,7 +101,19 @@ impl QStorage { } (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?, (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?, - _ => crate::bail!("Invalid dequantize storage locations do not match"), + _ => crate::bail!("Invalid quantize storage locations do not match"), + } + Ok(()) + } + + fn quantize_onto(&mut self, src: &Storage) -> Result<()> { + match (self, src) { + (QStorage::Cpu(storage), Storage::Cpu(src)) => { + storage.from_float(src.as_slice::()?)?; + } + (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?, + (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?, + _ => crate::bail!("Invalid quantize source storage locations: not on cpu"), } Ok(()) } @@ -341,6 +353,34 @@ impl QTensor { }) } + /// Quantize `src` (currently on the CPU) to a QTensor on `dev` + pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result { + if !src.device().is_cpu() { + crate::bail!( + "`quantize_onto` expects a `src` to be on the cpu, got {:?}.", + src.device() + ) + } + let shape = src.shape(); + let block_size = dtype.block_size(); + check_shape(shape, block_size)?; + let src = src.to_dtype(crate::DType::F32)?.flatten_all()?; + let elem_count = shape.elem_count(); + if elem_count % block_size != 0 { + crate::bail!( + "tensor size ({shape:?}) is not divisible by block size {}", + block_size + ) + } + // storage is on the `dev`, src is on `cpu` + let mut storage = dev.qzeros(elem_count, dtype)?; + storage.quantize_onto(&src.storage())?; + Ok(Self { + storage, + shape: shape.clone(), + }) + } + pub fn dtype(&self) -> GgmlDType { self.storage.dtype() } diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs index a20b5e3725..7c8dad510e 100644 --- a/candle-transformers/src/models/vgg.rs +++ b/candle-transformers/src/models/vgg.rs @@ -54,18 +54,17 @@ impl ModuleT for Vgg<'_> { fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result> { let layers = convs .iter() - .enumerate() - .map(|(_, &(in_c, out_c, name))| { + .map(|(in_c, out_c, name)| { candle_nn::conv2d( - in_c, - out_c, + *in_c, + *out_c, 3, candle_nn::Conv2dConfig { stride: 1, padding: 1, ..Default::default() }, - vb.pp(name), + vb.pp(*name), ) }) .collect::>>()?; From 5b04d9638339068521240c123e81f5fbcbeece33 Mon Sep 17 00:00:00 2001 From: shua Date: Wed, 12 Jun 2024 08:15:32 +0200 Subject: [PATCH 09/75] implement Slice op (#2260) --- candle-onnx/src/eval.rs | 80 +++++++++++++++++++++++ candle-onnx/tests/ops.rs | 135 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index f52e6c5cca..10a3b9377b 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option { 
DataType::Float16 => Some(DType::F16), DataType::Float => Some(DType::F32), DataType::Double => Some(DType::F64), + DataType::Bool => Some(DType::U8), _ => None, } } @@ -1053,6 +1054,85 @@ fn simple_eval_( ), } } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice + "Slice" => { + let data = get(&node.input[0])?; + let starts = get(&node.input[1])?; + let ends = get(&node.input[2])?; + let default_axes; + let default_steps; + let axes: &Tensor; + let steps: &Tensor; + // If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted, + // they are set to [1, ..., 1] of length len(starts) + match node.input.len() { + 3 => { + let len = starts.dims()[0]; + default_axes = Some(Tensor::arange(0, len as i64, starts.device())?); + axes = default_axes.as_ref().unwrap(); + default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?); + steps = default_steps.as_ref().unwrap(); + } + 4 => { + let len = starts.dims()[0]; + axes = get(&node.input[3])?; + default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?); + steps = default_steps.as_ref().unwrap(); + } + 5 => { + steps = get(&node.input[4])?; + axes = get(&node.input[3])?; + } + _ => bail!( + "Slice node is invalid, expected 3-5 inputs, got {}: {:?}", + node.input.len(), + node + ), + } + + let mut out = data.clone(); + for (i, axis) in axes.to_vec1::()?.into_iter().enumerate() { + // All negative elements of axes are made non-negative by + // adding r to them, where r = rank(input). + let axis = if axis < 0 { + axis + data.rank() as i64 + } else { + axis + } as usize; + + let data_dim = data.dims()[axis] as i64; + let mut s = starts.get(i)?.to_scalar::()?; + let mut e = ends.get(i)?.to_scalar::()?; + // All negative values in starts[i] and ends[i] have + // dims[axes[i]] added to them, where dims are the + // dimensions of input. + if s < 0 { + s += data_dim; + } + if e < 0 { + e += data_dim; + } + + let p = steps.get(i)?.to_scalar::()?; + // starts[i] is clamped into the range [0, dims[axes[i]]] + // for positive stepping and [0, dims[axes[i]]-1] for + // negative stepping. + // for positive stepping ends[axes[i]] is clamped to + // [0, dims[axes[i]]], while for negative stepping it is + // clamped to [-1, dims[axes[i]]-1]. + if p >= 0 { + s = s.clamp(0, data_dim); + e = e.clamp(0, data_dim); + } else { + s = s.clamp(0, data_dim - 1); + e = e.clamp(-1, data_dim - 1); + } + + let indexes = Tensor::arange_step(s, e, p, data.device())?; + out = out.index_select(&indexes, axis)? + } + values.insert(node.output[0].clone(), out); + } // https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13 // TODO: This version is only compatible with ReduceMean V13 and below. 
"ReduceMean" => { diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index b4299af1bc..82d38aa490 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -3272,3 +3272,138 @@ fn test_pad() -> Result<()> { assert_eq!(actual.to_vec2::()?, expected.to_vec2::()?); Ok(()) } + +#[test] +fn test_slice() -> Result<()> { + let model = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Slice".to_string(), + input: vec![ + "data".to_string(), + "starts".to_string(), + "ends".to_string(), + "axes".to_string(), + "steps".to_string(), + ], + output: vec!["result".to_string()], + ..NodeProto::default() + }], + input: ["data", "starts", "ends", "axes", "steps"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + output: ["result"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + ..GraphProto::default() + })); + + /* + data = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + axes = [0, 1] + starts = [1, 0] + ends = [2, 3] + steps = [1, 2] + result = [ + [5, 7], + ] + */ + + let outputs = candle_onnx::simple_eval( + &model, + HashMap::from_iter([ + ( + "data".to_string(), + Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?, + ), + ( + "starts".to_string(), + Tensor::from_vec(vec![1i64, 0], (2,), &Device::Cpu)?, + ), + ( + "ends".to_string(), + Tensor::from_vec(vec![2i64, 3], (2,), &Device::Cpu)?, + ), + ( + "axes".to_string(), + Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?, + ), + ( + "steps".to_string(), + Tensor::from_vec(vec![1i64, 2], (2,), &Device::Cpu)?, + ), + ]), + )?; + let actual = outputs.get("result").unwrap().to_vec2::()?; + assert_eq!(actual, vec![vec![5i64, 7]]); + + /* + data = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + starts = [0, 1] + ends = [-1, 1000] + result = [ + [2, 3, 4], + ] + */ + let model = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Slice".to_string(), + input: vec!["data".to_string(), "starts".to_string(), "ends".to_string()], + output: vec!["result".to_string()], + ..NodeProto::default() + }], + input: ["data", "starts", "ends"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + output: ["result"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + ..GraphProto::default() + })); + let outputs = candle_onnx::simple_eval( + &model, + HashMap::from_iter([ + ( + "data".to_string(), + Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?, + ), + ( + "starts".to_string(), + Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?, + ), + ( + "ends".to_string(), + Tensor::from_vec(vec![-1i64, 1000], (2,), &Device::Cpu)?, + ), + ]), + )?; + let actual = outputs.get("result").unwrap().to_vec2::()?; + assert_eq!(actual, vec![vec![2i64, 3, 4]]); + + Ok(()) +} From f7095bb324df586ae900ea404e64ef9a003f14ac Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 18 Jun 2024 23:46:58 +0200 Subject: [PATCH 10/75] Fix the fast bf16 gemm cublas kernels. (#2274) * Use flash-attn in gemma. * Fix for the fast bf16 cublas gemm. * Fix some clippy lints. * Fix another lint. * Proper clippy fix. 
--- candle-core/examples/cuda_basics.rs | 5 ++++- candle-core/src/cuda_backend/mod.rs | 8 +++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index 00e937cb88..9af1b006e3 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -9,8 +9,10 @@ use candle_core::{Device, Tensor}; fn main() -> Result<()> { let device = Device::new_cuda(0)?; - let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?; + let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)? + .to_dtype(candle_core::DType::BF16)?; candle_core::cuda::set_gemm_reduced_precision_f32(false); + candle_core::cuda::set_gemm_reduced_precision_bf16(false); let _x1 = x.matmul(&x)?; drop(_x1); let start_time = std::time::Instant::now(); @@ -19,6 +21,7 @@ fn main() -> Result<()> { println!("fp32: {:?}", start_time.elapsed()); drop(_x1); candle_core::cuda::set_gemm_reduced_precision_f32(true); + candle_core::cuda::set_gemm_reduced_precision_bf16(true); let _x1 = x.matmul(&x)?; drop(_x1); let start_time = std::time::Instant::now(); diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 542fcc96af..eb70da71cd 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -2035,15 +2035,13 @@ unsafe fn gemm_strided_batched_bf16( let alpha_f32: f32 = cfg.gemm.alpha.to_f32(); let beta_f32: f32 = cfg.gemm.beta.to_f32(); - let alpha = f16::from_f32(alpha_f32); - let beta = f16::from_f32(beta_f32); // The type for alpha and beta depends on the computeType. // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() { ( - sys::cublasComputeType_t::CUBLAS_COMPUTE_16F, - (&alpha) as *const f16 as *const _, - (&beta) as *const f16 as *const _, + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF, + (&alpha_f32) as *const f32 as *const _, + (&beta_f32) as *const f32 as *const _, ) } else { ( From b55b360dbed872a891c711386c805ca1d671b3c2 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 22 Jun 2024 23:21:20 +0200 Subject: [PATCH 11/75] Fix a bug in the metal implemtation of col2im1d. (#2284) --- candle-core/src/metal_backend/mod.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 06f6cd3715..fa83692df7 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -848,7 +848,6 @@ impl BackendStorage for MetalStorage { .device .new_buffer(dst_el, self.dtype, "conv_transpose1d")?; - let command_buffer = self.device.command_buffer()?; let name = match self.dtype { DType::F32 => "col2im1d_f32", DType::U32 => "col2im1d_u32", @@ -869,6 +868,12 @@ impl BackendStorage for MetalStorage { &kernel_l_mm, )? }; + // It is important for the command buffer to be obtained *after* the matmul + // kernel has run, otherwise we might use a command-buffer that has been commited + // already resulting in the following error. 
+ // _status < MTLCommandBufferStatusCommitted > + // -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:] + let command_buffer = self.device.command_buffer()?; candle_metal_kernels::call_col2im1d( &self.device.device, &command_buffer, From 08e93a6c5277119d9422bd447de41c1faf1ba5b7 Mon Sep 17 00:00:00 2001 From: Jeroen Vlek Date: Mon, 24 Jun 2024 19:12:52 +0200 Subject: [PATCH 12/75] Depth Anything v2 (#2279) * define structs * construct ResidualConvUnit * forward() for ResidualConvUnit * implement FeatureFusionBlock * implement Scratch * implement DPTHead * add identity module * implement forward for DTPHead * add get_intermediate_layers to DinoVisionTransformer * implement DepthAnythingV2 * some minor tweaks * fix compile errors * fix var builder prefixes * setup initial example * use fixed patch size of 37 (518 / 14) * debugged until output * print min and max values * add some dynamism to the output location * scale input image * extract prep function * extract output path function * normalize image with magic mean and std * add spectral coloring * squeeze in the right place * make enterpolation optional * use bail instead of panic * omit unnecessary Shape call * remove empty curly braces * use bail instead of assert * use vb and pp * remove closures * extract config object * Apply rustfmt. * Fix some clippy lints. * More lints. * Use the array methods. --------- Co-authored-by: laurent --- candle-examples/Cargo.toml | 7 + .../examples/depth_anything_v2/README.md | 13 + .../examples/depth_anything_v2/color_map.rs | 50 ++ .../examples/depth_anything_v2/main.rs | 187 ++++++ candle-nn/src/ops.rs | 23 +- .../src/models/depth_anything_v2.rs | 553 ++++++++++++++++++ candle-transformers/src/models/dinov2.rs | 78 +++ candle-transformers/src/models/mod.rs | 1 + 8 files changed, 911 insertions(+), 1 deletion(-) create mode 100644 candle-examples/examples/depth_anything_v2/README.md create mode 100644 candle-examples/examples/depth_anything_v2/color_map.rs create mode 100644 candle-examples/examples/depth_anything_v2/main.rs create mode 100644 candle-transformers/src/models/depth_anything_v2.rs diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 5b90f140c2..fa5c620a48 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -25,6 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] } image = { workspace = true } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } +palette = { version = "0.7.6", optional = true } +enterpolation = { version = "0.2.1", optional = true} pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true } rayon = { workspace = true } rubato = { version = "0.15.0", optional = true } @@ -65,6 +67,7 @@ onnx = ["candle-onnx"] metal = ["candle/metal", "candle-nn/metal"] microphone = ["cpal"] encodec = ["cpal", "symphonia", "rubato"] +depth_anything_v2 = ["palette", "enterpolation"] [[example]] name = "llama_multiprocess" @@ -101,3 +104,7 @@ required-features = ["candle-datasets"] [[example]] name = "encodec" required-features = ["encodec"] + +[[example]] +name = "depth_anything_v2" +required-features = ["depth_anything_v2"] diff --git a/candle-examples/examples/depth_anything_v2/README.md b/candle-examples/examples/depth_anything_v2/README.md new file mode 100644 index 0000000000..163b398b89 --- /dev/null +++ b/candle-examples/examples/depth_anything_v2/README.md @@ -0,0 +1,13 @@ +# candle-dinov2 + +[Depth Anything V2] is a model for Monocular Depth Estimation (MDE, i.e. 
just using a single image) which +builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer. + +This example first instantiates the DINOv2 model and then proceeds to create DepthAnythingV2 and run it. + +## Running an example with color map and CUDA + +```bash +cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg +``` + diff --git a/candle-examples/examples/depth_anything_v2/color_map.rs b/candle-examples/examples/depth_anything_v2/color_map.rs new file mode 100644 index 0000000000..94be326fc5 --- /dev/null +++ b/candle-examples/examples/depth_anything_v2/color_map.rs @@ -0,0 +1,50 @@ +use enterpolation::linear::ConstEquidistantLinear; +use enterpolation::Generator; +use palette::LinSrgb; + +use candle::Tensor; + +pub struct SpectralRColormap { + gradient: ConstEquidistantLinear, +} + +impl SpectralRColormap { + pub(crate) fn new() -> Self { + // Define a colormap similar to 'Spectral_r' by specifying key colors. + // got the colors from ChatGPT-4o + let gradient = ConstEquidistantLinear::::equidistant_unchecked([ + LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue + LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue + LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan + LinSrgb::new(0.6706, 0.8667, 0.6431), // Green + LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow + LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange + LinSrgb::new(0.9922, 0.6824, 0.3804), // Red + LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red + LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple + ]); + Self { gradient } + } + + fn get_color(&self, value: f32) -> LinSrgb { + self.gradient.gen(value) + } + + pub fn gray2color(&self, gray: &Tensor) -> candle::Result { + println!("Gray: {:?}", gray.dims()); + let gray_values: Vec = gray.flatten_all()?.to_vec1()?; + let rgb_values: Vec = gray_values + .iter() + .map(|g| self.get_color(*g)) + .flat_map(|rgb| [rgb.red, rgb.green, rgb.blue]) + .collect(); + + let [.., height, width] = gray.dims() else { + candle::bail!("Not enough dims!") + }; + + let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?; + + color.permute((2, 0, 1)) + } +} diff --git a/candle-examples/examples/depth_anything_v2/main.rs b/candle-examples/examples/depth_anything_v2/main.rs new file mode 100644 index 0000000000..ef337ebab4 --- /dev/null +++ b/candle-examples/examples/depth_anything_v2/main.rs @@ -0,0 +1,187 @@ +//! Depth Anything V2 +//! 
https://huggingface.co/spaces/depth-anything/Depth-Anything-V2 + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use std::ffi::OsString; +use std::path::PathBuf; + +use clap::Parser; + +use candle::DType::{F32, U8}; +use candle::{DType, Device, Module, Result, Tensor}; +use candle_examples::{load_image, load_image_and_resize, save_image}; +use candle_nn::VarBuilder; +use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config}; +use candle_transformers::models::dinov2; + +use crate::color_map::SpectralRColormap; + +mod color_map; + +// taken these from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207 +const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406]; +const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225]; + +const DINO_IMG_SIZE: usize = 518; + +#[derive(Parser)] +struct Args { + #[arg(long)] + dinov2_model: Option, + + #[arg(long)] + depth_anything_v2_model: Option, + + #[arg(long)] + image: PathBuf, + + #[arg(long)] + output_dir: Option, + + #[arg(long)] + cpu: bool, + + #[arg(long)] + color_map: bool, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + let device = candle_examples::device(args.cpu)?; + + let dinov2_model_file = match args.dinov2_model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("lmz/candle-dino-v2".into()); + api.get("dinov2_vits14.safetensors")? + } + Some(dinov2_model) => dinov2_model, + }; + println!("Using file {:?}", dinov2_model_file); + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? }; + let dinov2 = dinov2::vit_small(vb)?; + println!("DinoV2 model built"); + + let depth_anything_model_file = match args.depth_anything_v2_model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into()); + api.get("depth_anything_v2_vits.safetensors")? + } + Some(depth_anything_model) => depth_anything_model, + }; + println!("Using file {:?}", depth_anything_model_file); + + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)? 
+ }; + + let config = DepthAnythingV2Config::vit_small(); + let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?; + + let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?; + + println!("Loaded image {image:?}"); + + let depth = depth_anything.forward(&image)?; + + println!("Got predictions {:?}", depth.shape()); + + let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?; + + let output_path = full_output_path(&args.image, &args.output_dir); + println!("Saving image to {}", output_path.to_string_lossy()); + save_image(&output_image, output_path)?; + + Ok(()) +} + +fn full_output_path(image_path: &PathBuf, output_dir: &Option) -> PathBuf { + let input_file_name = image_path.file_name().unwrap(); + let mut output_file_name = OsString::from("depth_"); + output_file_name.push(input_file_name); + let mut output_path = match output_dir { + None => image_path.parent().unwrap().to_path_buf(), + Some(output_path) => output_path.clone(), + }; + output_path.push(output_file_name); + + output_path +} + +fn load_and_prep_image( + image_path: &PathBuf, + device: &Device, +) -> anyhow::Result<(usize, usize, Tensor)> { + let (_original_image, original_height, original_width) = load_image(&image_path, None)?; + + let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)? + .unsqueeze(0)? + .to_dtype(F32)? + .to_device(&device)?; + + let max_pixel_val = Tensor::try_from(255.0f32)? + .to_device(&device)? + .broadcast_as(image.shape())?; + let image = (image / max_pixel_val)?; + let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?; + + Ok((original_height, original_width, image)) +} + +fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result { + let mean_tensor = + Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?; + let std_tensor = + Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?; + image.sub(&mean_tensor)?.div(&std_tensor) +} + +fn post_process_image( + image: &Tensor, + original_height: usize, + original_width: usize, + color_map: bool, +) -> Result { + let out = image.interpolate2d(original_height, original_width)?; + let out = scale_image(&out)?; + + let out = if color_map { + let spectral_r = SpectralRColormap::new(); + spectral_r.gray2color(&out)? + } else { + let rgb_slice = [&out, &out, &out]; + Tensor::cat(&rgb_slice, 0)?.squeeze(1)? + }; + + let max_pixel_val = Tensor::try_from(255.0f32)? + .to_device(out.device())? + .broadcast_as(out.shape())?; + let out = (out * max_pixel_val)?; + + out.to_dtype(U8) +} + +fn scale_image(depth: &Tensor) -> Result { + let flat_values: Vec = depth.flatten_all()?.to_vec1()?; + + let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap(); + let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap(); + + let min_val_tensor = Tensor::try_from(*min_val)? + .to_device(depth.device())? + .broadcast_as(depth.shape())?; + let depth = (depth - min_val_tensor)?; + + let range = max_val - min_val; + let range_tensor = Tensor::try_from(range)? + .to_device(depth.device())? 
+ .broadcast_as(depth.shape())?; + + depth / range_tensor +} diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index e54e619545..beb771aaf9 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1,4 +1,4 @@ -use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D}; +use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D}; use rayon::prelude::*; /// Applies the softmax function to the input tensor, rescaling the element so that elements on @@ -953,3 +953,24 @@ pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: usize) -> Result pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: i32) -> Result { Tensor::cat(&[ltensor, rtensor], concat_dim as usize)?.contiguous() } + +#[derive(Clone, Debug)] +pub struct Identity; + +impl Identity { + pub fn new() -> Identity { + Self + } +} + +impl Default for Identity { + fn default() -> Self { + Self + } +} + +impl Module for Identity { + fn forward(&self, xs: &Tensor) -> Result { + Ok(xs.clone()) + } +} diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs new file mode 100644 index 0000000000..9eee6d1130 --- /dev/null +++ b/candle-transformers/src/models/depth_anything_v2.rs @@ -0,0 +1,553 @@ +use candle::D::Minus1; +use candle::{Module, Result, Tensor}; +use candle_nn::ops::Identity; +use candle_nn::{ + batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm, + BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder, +}; + +use crate::models::dinov2::DinoVisionTransformer; + +pub struct DepthAnythingV2Config { + out_channel_sizes: [usize; 4], + in_channel_size: usize, // embed_dim in the Dino model + num_features: usize, + use_batch_norm: bool, + use_class_token: bool, + layer_ids_vits: Vec, + input_image_size: usize, + target_patch_size: usize, +} + +impl DepthAnythingV2Config { + #[allow(clippy::too_many_arguments)] + pub fn new( + out_channel_sizes: [usize; 4], + in_channel_size: usize, + num_features: usize, + use_batch_norm: bool, + use_class_token: bool, + layer_ids_vits: Vec, + input_image_size: usize, + target_patch_size: usize, + ) -> Self { + Self { + out_channel_sizes, + in_channel_size, + num_features, + use_batch_norm, + use_class_token, + layer_ids_vits, + input_image_size, + target_patch_size, + } + } + + pub fn vit_small() -> Self { + Self { + out_channel_sizes: [48, 96, 192, 384], + in_channel_size: 384, + num_features: 64, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![2, 5, 8, 11], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } + + pub fn vit_base() -> Self { + Self { + out_channel_sizes: [96, 192, 384, 768], + in_channel_size: 768, + num_features: 128, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![2, 5, 8, 11], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } + + pub fn vit_large() -> Self { + Self { + out_channel_sizes: [256, 512, 1024, 1024], + in_channel_size: 1024, + num_features: 256, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![4, 11, 17, 23], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } + + pub fn vit_giant() -> Self { + Self { + out_channel_sizes: [1536, 1536, 1536, 1536], + in_channel_size: 1536, + num_features: 384, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![9, 19, 29, 39], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } +} + +pub struct ResidualConvUnit 
{ + activation: Activation, + conv1: Conv2d, + conv2: Conv2d, + batch_norm1: Option, + batch_norm2: Option, +} + +impl ResidualConvUnit { + pub fn new( + conf: &DepthAnythingV2Config, + activation: Activation, + vb: VarBuilder, + ) -> Result { + const KERNEL_SIZE: usize = 3; + let conv_cfg = Conv2dConfig { + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + }; + let conv1 = conv2d( + conf.num_features, + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("conv1"), + )?; + let conv2 = conv2d( + conf.num_features, + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("conv2"), + )?; + + let (batch_norm1, batch_norm2) = match conf.use_batch_norm { + true => { + let batch_norm_cfg = BatchNormConfig { + eps: 1e-05, + remove_mean: false, + affine: true, + momentum: 0.1, + }; + ( + Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?), + Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?), + ) + } + false => (None, None), + }; + + Ok(Self { + activation, + conv1, + conv2, + batch_norm1, + batch_norm2, + }) + } +} + +impl Module for ResidualConvUnit { + fn forward(&self, xs: &Tensor) -> Result { + let out = self.activation.forward(xs)?; + let out = self.conv1.forward(&out)?; + let out = if let Some(batch_norm1) = &self.batch_norm1 { + batch_norm1.forward_train(&out)? + } else { + out + }; + + let out = self.activation.forward(&out)?; + let out = self.conv2.forward(&out)?; + let out = if let Some(batch_norm2) = &self.batch_norm2 { + batch_norm2.forward_train(&out)? + } else { + out + }; + + out + xs + } +} + +pub struct FeatureFusionBlock { + res_conv_unit1: ResidualConvUnit, + res_conv_unit2: ResidualConvUnit, + output_conv: Conv2d, + target_patch_size: usize, +} + +impl FeatureFusionBlock { + pub fn new( + conf: &DepthAnythingV2Config, + target_patch_size: usize, + activation: Activation, + vb: VarBuilder, + ) -> Result { + const KERNEL_SIZE: usize = 1; + let conv_cfg = Conv2dConfig { + padding: 0, + stride: 1, + dilation: 1, + groups: 1, + }; + let output_conv = conv2d( + conf.num_features, + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("out_conv"), + )?; + let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?; + let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?; + + Ok(Self { + res_conv_unit1, + res_conv_unit2, + output_conv, + target_patch_size, + }) + } +} + +impl Module for FeatureFusionBlock { + fn forward(&self, xs: &Tensor) -> Result { + let out = self.res_conv_unit2.forward(xs)?; + let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?; + + self.output_conv.forward(&out) + } +} + +pub struct Scratch { + layer1_rn: Conv2d, + layer2_rn: Conv2d, + layer3_rn: Conv2d, + layer4_rn: Conv2d, + refine_net1: FeatureFusionBlock, + refine_net2: FeatureFusionBlock, + refine_net3: FeatureFusionBlock, + refine_net4: FeatureFusionBlock, + output_conv1: Conv2d, + output_conv2: Sequential, +} + +impl Scratch { + pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result { + const KERNEL_SIZE: usize = 3; + let conv_cfg = Conv2dConfig { + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + }; + + let layer1_rn = conv2d_no_bias( + conf.out_channel_sizes[0], + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("layer1_rn"), + )?; + let layer2_rn = conv2d_no_bias( + conf.out_channel_sizes[1], + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("layer2_rn"), + )?; + let layer3_rn = conv2d_no_bias( + conf.out_channel_sizes[2], + conf.num_features, + KERNEL_SIZE, 
+ conv_cfg, + vb.pp("layer3_rn"), + )?; + let layer4_rn = conv2d_no_bias( + conf.out_channel_sizes[3], + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("layer4_rn"), + )?; + + let refine_net1 = FeatureFusionBlock::new( + conf, + conf.target_patch_size * 8, + Activation::Relu, + vb.pp("refinenet1"), + )?; + let refine_net2 = FeatureFusionBlock::new( + conf, + conf.target_patch_size * 4, + Activation::Relu, + vb.pp("refinenet2"), + )?; + let refine_net3 = FeatureFusionBlock::new( + conf, + conf.target_patch_size * 2, + Activation::Relu, + vb.pp("refinenet3"), + )?; + let refine_net4 = FeatureFusionBlock::new( + conf, + conf.target_patch_size, + Activation::Relu, + vb.pp("refinenet4"), + )?; + + let conv_cfg = Conv2dConfig { + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + }; + let output_conv1 = conv2d( + conf.num_features, + conf.num_features / 2, + KERNEL_SIZE, + conv_cfg, + vb.pp("output_conv1"), + )?; + + let output_conv2 = seq(); + const HEAD_FEATURES_2: usize = 32; + const OUT_CHANNELS_2: usize = 1; + const KERNEL_SIZE_2: usize = 1; + let output_conv2 = output_conv2.add(conv2d( + conf.num_features / 2, + HEAD_FEATURES_2, + KERNEL_SIZE, + conv_cfg, + vb.pp("output_conv2").pp("0"), + )?); + let output_conv2 = output_conv2 + .add(Activation::Relu) + .add(conv2d( + HEAD_FEATURES_2, + OUT_CHANNELS_2, + KERNEL_SIZE_2, + conv_cfg, + vb.pp("output_conv2").pp("2"), + )?) + .add(Activation::Relu); + + Ok(Self { + layer1_rn, + layer2_rn, + layer3_rn, + layer4_rn, + refine_net1, + refine_net2, + refine_net3, + refine_net4, + output_conv1, + output_conv2, + }) + } +} + +const NUM_CHANNELS: usize = 4; + +pub struct DPTHead<'a> { + conf: &'a DepthAnythingV2Config, + projections: Vec, + resize_layers: Vec>, + readout_projections: Vec, + scratch: Scratch, +} + +impl<'a> DPTHead<'a> { + pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result { + let mut projections: Vec = Vec::with_capacity(conf.out_channel_sizes.len()); + for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() { + projections.push(conv2d( + conf.in_channel_size, + *out_channel_size, + 1, + Default::default(), + vb.pp("projects").pp(conv_index.to_string()), + )?); + } + + let resize_layers: Vec> = vec![ + Box::new(conv_transpose2d( + conf.out_channel_sizes[0], + conf.out_channel_sizes[0], + 4, + ConvTranspose2dConfig { + padding: 0, + stride: 4, + dilation: 1, + output_padding: 0, + }, + vb.pp("resize_layers").pp("0"), + )?), + Box::new(conv_transpose2d( + conf.out_channel_sizes[1], + conf.out_channel_sizes[1], + 2, + ConvTranspose2dConfig { + padding: 0, + stride: 2, + dilation: 1, + output_padding: 0, + }, + vb.pp("resize_layers").pp("1"), + )?), + Box::new(Identity::new()), + Box::new(conv2d( + conf.out_channel_sizes[3], + conf.out_channel_sizes[3], + 3, + Conv2dConfig { + padding: 1, + stride: 2, + dilation: 1, + groups: 1, + }, + vb.pp("resize_layers").pp("3"), + )?), + ]; + + let readout_projections = if conf.use_class_token { + let rop = Vec::with_capacity(NUM_CHANNELS); + for rop_index in 0..NUM_CHANNELS { + seq() + .add(linear( + 2 * conf.in_channel_size, + conf.in_channel_size, + vb.pp("readout_projects").pp(rop_index.to_string()), + )?) 
+ .add(Activation::Gelu); + } + rop + } else { + vec![] + }; + + let scratch = Scratch::new(conf, vb.pp("scratch"))?; + + Ok(Self { + conf, + projections, + resize_layers, + readout_projections, + scratch, + }) + } +} + +impl Module for DPTHead<'_> { + fn forward(&self, xs: &Tensor) -> Result { + let mut out: Vec = Vec::with_capacity(NUM_CHANNELS); + for i in 0..NUM_CHANNELS { + let x = if self.conf.use_class_token { + let x = xs.get(i)?.get(0)?; + let class_token = xs.get(i)?.get(1)?; + let readout = class_token.unsqueeze(1)?.expand(x.shape())?; + let to_cat = [x, readout]; + let cat = Tensor::cat(&to_cat, Minus1)?; + self.readout_projections[i].forward(&cat)? + } else { + xs.get(i)? + }; + let x_dims = x.dims(); + + let x = x.permute((0, 2, 1))?.reshape(( + x_dims[0], + x_dims[x_dims.len() - 1], + self.conf.target_patch_size, + self.conf.target_patch_size, + ))?; + let x = self.projections[i].forward(&x)?; + + let x = self.resize_layers[i].forward(&x)?; + out.push(x); + } + + let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?; + let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?; + let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?; + let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?; + + let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?; + + let res3_out = self + .scratch + .refine_net3 + .res_conv_unit1 + .forward(&layer_3_rn)?; + let res3_out = path4.add(&res3_out)?; + let path3 = self.scratch.refine_net3.forward(&res3_out)?; + + let res2_out = self + .scratch + .refine_net2 + .res_conv_unit1 + .forward(&layer_2_rn)?; + let res2_out = path3.add(&res2_out)?; + let path2 = self.scratch.refine_net2.forward(&res2_out)?; + + let res1_out = self + .scratch + .refine_net1 + .res_conv_unit1 + .forward(&layer_1_rn)?; + let res1_out = path2.add(&res1_out)?; + let path1 = self.scratch.refine_net1.forward(&res1_out)?; + + let out = self.scratch.output_conv1.forward(&path1)?; + + let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?; + + self.scratch.output_conv2.forward(&out) + } +} + +pub struct DepthAnythingV2<'a> { + pretrained: &'a DinoVisionTransformer, + depth_head: DPTHead<'a>, + conf: &'a DepthAnythingV2Config, +} + +impl<'a> DepthAnythingV2<'a> { + pub fn new( + pretrained: &'a DinoVisionTransformer, + conf: &'a DepthAnythingV2Config, + vb: VarBuilder, + ) -> Result { + let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?; + + Ok(Self { + pretrained, + depth_head, + conf, + }) + } +} + +impl<'a> Module for DepthAnythingV2<'a> { + fn forward(&self, xs: &Tensor) -> Result { + let features = self.pretrained.get_intermediate_layers( + xs, + &self.conf.layer_ids_vits, + false, + false, + true, + )?; + let depth = self.depth_head.forward(&features)?; + + depth.relu() + } +} diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs index 757aa88ac4..00e501ce0d 100644 --- a/candle-transformers/src/models/dinov2.rs +++ b/candle-transformers/src/models/dinov2.rs @@ -258,6 +258,84 @@ impl DinoVisionTransformer { let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?; &xs + &self.interpolate_pos_encoding(&xs, w, h)? 
} + + fn get_intermediate_layers_not_chunked( + &self, + xs: &Tensor, + blocks_to_take: &[usize], + ) -> Result> { + let mut xs = self.prepare_tokens_with_mask(xs)?; + let mut output = Vec::new(); + for (i, blk) in self.blocks.iter().enumerate() { + xs = blk.forward(&xs)?; + if blocks_to_take.contains(&i) { + output.push(xs.clone()); + } + } + if output.len() != blocks_to_take.len() { + candle::bail!( + "only {} / {} blocks found", + output.len(), + blocks_to_take.len() + ); + } + Ok(output) + } + + pub fn get_intermediate_layers( + &self, + xs: &Tensor, + blocks_to_take: &[usize], + reshape: bool, + return_class_token: bool, + norm: bool, + ) -> Result { + let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?; + let outputs = if norm { + outputs + .iter() + .map(|out| self.norm.forward(out)) + .collect::>>()? + } else { + outputs + }; + let class_tokens = outputs + .iter() + .map(|out| out.i((.., 0))) + .collect::>>()?; + let outputs = outputs + .iter() + .map(|out| out.i((.., 1..))) + .collect::>>()?; + + let outputs = if reshape { + let (b, _c, w, h) = xs.dims4()?; + let patch_size = self.patch_embed.patch_size.0; + let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size)); + outputs + .iter() + .map(|out| { + out.reshape((b, w / patch_size, h / patch_size, num_channels))? + .transpose(2, 3)? + .transpose(1, 2) + }) + .collect::>>()? + } else { + outputs + }; + + let outputs = if return_class_token { + outputs + .iter() + .zip(class_tokens.iter()) + .map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1)) + .collect::>>()? + } else { + outputs + }; + + Tensor::stack(&outputs[..], 0) + } } impl Module for DinoVisionTransformer { diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 99a2b3e940..3b0fa834f5 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -6,6 +6,7 @@ pub mod chatglm; pub mod clip; pub mod convmixer; pub mod convnext; +pub mod depth_anything_v2; pub mod dinov2; pub mod distilbert; pub mod efficientnet; From 5df1ae26727d14b5237a306a8e6c47f89ba3dd4d Mon Sep 17 00:00:00 2001 From: "drCathieSo.eth" Date: Sat, 29 Jun 2024 03:40:31 +0800 Subject: [PATCH 13/75] Adding Gemm and ArgMax operators to candle-onnx (#2231) * feat(gemm): implement Gemm operator in candle-onnx * feat(onnx): Add support for ArgMax operator in candle-onnx * Apply rustfmt. * Remove argmax as it was already present. 
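For reference, the ONNX `Gemm` operator computes `Y = alpha * op(A) * op(B) + beta * C`, where `op(X)` transposes `X` when the corresponding `transA`/`transB` attribute is 1, and both `alpha` and `beta` default to 1.0. The hunk below maps this directly onto candle ops: `broadcast_mul` for the two scalings, `broadcast_matmul` for the product, and `broadcast_add` for the `C` term. A standalone sketch of the same computation with plain candle calls (illustrative only, not taken from the patch):

```rust
use candle::{Device, Result, Tensor};

// Reference Gemm with alpha = 2.0, beta = 0.5 and no transposes,
// i.e. Y = 2.0 * (A @ B) + 0.5 * C on small 2x2 inputs.
fn gemm_reference() -> Result<Tensor> {
    let dev = Device::Cpu;
    let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &dev)?;
    let b = Tensor::new(&[[5f32, 6.], [7., 8.]], &dev)?;
    let c = Tensor::new(&[[1f32, 1.], [1., 1.]], &dev)?;
    (a.matmul(&b)? * 2.0)? + (c * 0.5)?
}
```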
--------- Co-authored-by: Laurent --- candle-onnx/src/eval.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 10a3b9377b..f7203b36f7 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -1274,6 +1274,30 @@ fn simple_eval_( let output = candle_nn::ops::leaky_relu(input, alpha.into())?; values.insert(node.output[0].clone(), output); } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm + "Gemm" => { + let a = get(&node.input[0])?; + let b = get(&node.input[1])?; + let c = get(&node.input[2])?; + + let alpha = get_attr_opt::(node, "alpha")?.copied().unwrap_or(1.0); + let beta = get_attr_opt::(node, "beta")?.copied().unwrap_or(1.0); + + let alpha = Tensor::full(alpha, a.shape(), &Device::Cpu)?; + let beta = Tensor::full(beta, c.shape(), &Device::Cpu)?; + + let trans_a = get_attr_opt::(node, "transA")?.copied().unwrap_or(0); + let trans_b = get_attr_opt::(node, "transB")?.copied().unwrap_or(0); + + let a = if trans_a == 0 { a.clone() } else { a.t()? }; + let b = if trans_b == 0 { b.clone() } else { b.t()? }; + + let output = a + .broadcast_mul(&alpha)? + .broadcast_matmul(&b)? + .broadcast_add(&c.broadcast_mul(&beta)?)?; + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } From 0bb678ce10eeb85861a2176ae9c6ac59ef4fc823 Mon Sep 17 00:00:00 2001 From: v-espitalier <125037408+v-espitalier@users.noreply.github.com> Date: Sat, 29 Jun 2024 11:49:15 +0200 Subject: [PATCH 14/75] Add DINOv2Reg4 + PlantCLEF2024 (#2293) * Add: DINOv2Reg4 with PlantCLEF2024 weights and example ( See https://arxiv.org/abs/2309.16588 and https://zenodo.org/records/10848263 ) * Remove extra files + update README to download them + remove extra lines * minor fix (README remove extra spaces) * minor fix (README: Fix image url) * Modif: Add back interpolate_pos_encoding() + fix when no interpolation + remove extra comments + Update README ( source image changed and so the predictions ) * Fix: Improve code lisibility with '$ cargo clippy' and '$ cargo fmt' * Another clippy fix. --------- Co-authored-by: x-VEspit Co-authored-by: laurent --- candle-examples/examples/dinov2reg4/README.md | 25 ++ candle-examples/examples/dinov2reg4/main.rs | 70 +++++ candle-examples/src/imagenet.rs | 18 ++ candle-transformers/src/models/dinov2reg4.rs | 281 ++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + 5 files changed, 395 insertions(+) create mode 100644 candle-examples/examples/dinov2reg4/README.md create mode 100644 candle-examples/examples/dinov2reg4/main.rs create mode 100644 candle-transformers/src/models/dinov2reg4.rs diff --git a/candle-examples/examples/dinov2reg4/README.md b/candle-examples/examples/dinov2reg4/README.md new file mode 100644 index 0000000000..ac86ca6911 --- /dev/null +++ b/candle-examples/examples/dinov2reg4/README.md @@ -0,0 +1,25 @@ +# candle-dinov2-reg4 + +[DINOv2-reg4](https://arxiv.org/abs/2309.16588) is the lastest version of DINOv2 with registers. +In this example, it is used as an plant species classifier: the model returns the +probability for the image to belong to each of the 7806 PlantCLEF2024 categories. 
+ +## Running some example + +```bash +# Download classes names and a plant picture to identify +curl https://huggingface.co/vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights/raw/main/species_id_mapping.txt --output candle-examples/examples/dinov2reg4/species_id_mapping.txt +curl https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c --output candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg + +# Perform inference +cargo run --example dinov2reg4 --release -- --image candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg + +> Orchis simia Lam. : 45.55% +> Orchis × bergonii Nanteuil: 9.80% +> Orchis italica Poir. : 9.66% +> Orchis × angusticruris Franch.: 2.76% +> Orchis × bivonae Tod. : 2.54% + +``` + +![Orchis Simia](https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c) diff --git a/candle-examples/examples/dinov2reg4/main.rs b/candle-examples/examples/dinov2reg4/main.rs new file mode 100644 index 0000000000..15270517c5 --- /dev/null +++ b/candle-examples/examples/dinov2reg4/main.rs @@ -0,0 +1,70 @@ +//! DINOv2 reg4 finetuned on PlantCLEF 2024 +//! https://arxiv.org/abs/2309.16588 +//! https://huggingface.co/spaces/BVRA/PlantCLEF2024 +//! https://zenodo.org/records/10848263 + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::Parser; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::dinov2reg4; + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = candle_examples::imagenet::load_image518(args.image)?.to_device(&device)?; + println!("loaded image {image:?}"); + + let f_species_id_mapping = "candle-examples/examples/dinov2reg4/species_id_mapping.txt"; + let classes: Vec = std::fs::read_to_string(f_species_id_mapping) + .expect("missing classes file") + .split('\n') + .map(|s| s.to_string()) + .collect(); + + let model_file = match args.model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = + api.model("vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights".into()); + api.get( + "vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all.safetensors", + )? + } + Some(model) => model.into(), + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = dinov2reg4::vit_base(vb)?; + println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::()?; + let mut prs = prs.iter().enumerate().collect::>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!("{:24}: {:.2}%", classes[category_idx], 100. * pr); + } + Ok(()) +} diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs index cefbd71bbe..781dcd4fc3 100644 --- a/candle-examples/src/imagenet.rs +++ b/candle-examples/src/imagenet.rs @@ -17,6 +17,24 @@ pub fn load_image224>(p: P) -> Result { .broadcast_div(&std) } +/// Loads an image from disk using the image crate, this returns a tensor with shape +/// (3, 518, 518). imagenet normalization is applied. 
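+/// The image is first resized to 518x518 with `resize_to_fill` and a triangle filter.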
+/// The model dinov2 reg4 analyzes images with dimensions 3x518x518 (resulting in 37x37 transformer tokens). +pub fn load_image518>(p: P) -> Result { + let img = image::io::Reader::open(p)? + .decode() + .map_err(candle::Error::wrap)? + .resize_to_fill(518, 518, image::imageops::FilterType::Triangle); + let img = img.to_rgb8(); + let data = img.into_raw(); + let data = Tensor::from_vec(data, (518, 518, 3), &Device::Cpu)?.permute((2, 0, 1))?; + let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?; + (data.to_dtype(candle::DType::F32)? / 255.)? + .broadcast_sub(&mean)? + .broadcast_div(&std) +} + pub const CLASS_COUNT: i64 = 1000; pub const CLASSES: [&str; 1000] = [ diff --git a/candle-transformers/src/models/dinov2reg4.rs b/candle-transformers/src/models/dinov2reg4.rs new file mode 100644 index 0000000000..6bbe2e2410 --- /dev/null +++ b/candle-transformers/src/models/dinov2reg4.rs @@ -0,0 +1,281 @@ +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +const IMG_SIZE: usize = 518; +const PATCH_SIZE: usize = 14; +const NUM_CLASSES: usize = 7806; // PlantCLEF2024 DINOv2 (https://zenodo.org/records/10848263) + +fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result { + if bias { + candle_nn::linear(in_dim, out_dim, vb) + } else { + candle_nn::linear_no_bias(in_dim, out_dim, vb) + } +} + +#[derive(Debug)] +struct Attention { + qkv: Linear, + proj: Linear, + num_heads: usize, + scale: f64, +} + +impl Attention { + fn new( + vb: VarBuilder, + dim: usize, + num_heads: usize, + qkv_bias: bool, + proj_bias: bool, + ) -> Result { + let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; + let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?; + let scale = 1. / ((dim / num_heads) as f64).sqrt(); + Ok(Self { + qkv, + proj, + num_heads, + scale, + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result { + let (b, n, c) = xs.dims3()?; + let qkv = self + .qkv + .forward(xs)? + .reshape((b, n, 3, self.num_heads, c / self.num_heads))? + .transpose(1, 2)? // 02134 + .transpose(0, 1)? // 20134 + .transpose(2, 3)?; // 20314 + let q = (qkv.i(0)? 
* self.scale)?; + let k = qkv.i(1)?.contiguous()?; + let v = qkv.i(2)?.contiguous()?; + let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?; + let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?; + self.proj.forward(&attn) + } +} + +#[derive(Debug)] +struct LayerScale { + gamma: Tensor, +} + +impl LayerScale { + fn new(vb: VarBuilder, dim: usize) -> Result { + let gamma = vb.get(dim, "gamma")?; + Ok(Self { gamma }) + } +} + +impl Module for LayerScale { + fn forward(&self, xs: &Tensor) -> Result { + xs.broadcast_mul(&self.gamma) + } +} + +#[derive(Debug)] +struct Mlp { + fc1: Linear, + fc2: Linear, +} + +impl Mlp { + fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result { + let out_features = in_features; + let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?; + let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?; + Ok(Self { fc1, fc2 }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.fc1.forward(xs)?.gelu()?; + self.fc2.forward(&xs) + } +} + +#[derive(Debug)] +struct Block { + norm1: LayerNorm, + attn: Attention, + ls1: LayerScale, + norm2: LayerNorm, + mlp: Mlp, + ls2: LayerScale, +} + +impl Block { + fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result { + let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?; + let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?; + let ls1 = LayerScale::new(vb.pp("ls1"), dim)?; + let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?; + let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?; + let ls2 = LayerScale::new(vb.pp("ls2"), dim)?; + Ok(Self { + norm1, + attn, + ls1, + norm2, + mlp, + ls2, + }) + } +} + +impl Module for Block { + fn forward(&self, xs: &Tensor) -> Result { + let residual = xs; + let xs = self + .ls1 + .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = self + .ls2 + .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?; + xs + residual + } +} + +#[derive(Debug)] +struct PatchEmbed { + proj: candle_nn::Conv2d, + patch_size: (usize, usize), + num_patches: usize, +} + +impl PatchEmbed { + fn new( + vb: VarBuilder, + img_size: usize, + patch_size: usize, + in_chans: usize, + embed_dim: usize, + ) -> Result { + let config = candle_nn::Conv2dConfig { + stride: patch_size, + ..Default::default() + }; + let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?; + let num_patches = (img_size / patch_size) * (img_size / patch_size); + Ok(Self { + proj, + patch_size: (patch_size, patch_size), + num_patches, + }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result { + let (_b, _c, h, w) = xs.dims4()?; + let (patch_h, patch_w) = self.patch_size; + if (h % patch_h) != 0 { + candle::bail!("image height {h} is not a multiple of patch height {patch_h}") + } + if (w % patch_w) != 0 { + candle::bail!("image width {w} is not a multiple of patch width {patch_w}") + } + let xs = self.proj.forward(xs)?; + let (b, c, h, w) = xs.dims4()?; + // flatten embeddings. 
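+        // (b, c, h, w) -> (b, h*w, c): one embedding vector per patch token.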
+ xs.reshape((b, c, h * w))?.transpose(1, 2) + } +} + +#[derive(Debug)] +pub struct DinoVisionTransformer { + patch_embed: PatchEmbed, + cls_token: Tensor, + reg_token: Tensor, + pos_embed: Tensor, + blocks: Vec, + norm: LayerNorm, + head: Linear, +} + +impl DinoVisionTransformer { + pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result { + let patch_embed = + PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?; + let cls_token = vb.get((1, 1, embed_dim), "cls_token")?; + let reg_token = vb.get((1, 4, embed_dim), "reg_token")?; + let pos_embed = vb.get((1, patch_embed.num_patches, embed_dim), "pos_embed")?; + let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?; + let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?; + let vb_b = vb.pp("blocks"); + let blocks = (0..depth) + .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) + .collect::>>()?; + Ok(Self { + patch_embed, + cls_token, + reg_token, + pos_embed, + blocks, + norm, + head, + }) + } + + fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result { + let npatch = xs.dim(1)? - 1; + let n = self.pos_embed.dim(1)? - 1; + let sqrt_n = (n as f64).sqrt(); + if npatch == n && w == h { + return Ok(self.pos_embed.clone()); + } + let patch_pos_embed = &self.pos_embed; + let dim = xs.dim(D::Minus1)?; + let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1); + let patch_pos_embed = patch_pos_embed + .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))? + .transpose(2, 3)? + .transpose(1, 2)?; + // This uses bicubic interpolation in the original implementation. + let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?; + let el_count = patch_pos_embed.shape().elem_count(); + patch_pos_embed + .transpose(1, 2)? + .transpose(2, 3)? + .reshape((1, el_count / dim, dim)) + } + + fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result { + let (_b, _nc, w, h) = xs.dims4()?; + if (w != IMG_SIZE) || (h != IMG_SIZE) { + panic!("Error: The input tensor should have the shape: Bx3x518x518."); + } + let xs = self.patch_embed.forward(xs)?; + let xs = (&xs + &self.interpolate_pos_encoding(&xs, w, h)?)?; + let xs = Tensor::cat(&[&self.cls_token, &self.reg_token, &xs], 1)?; + Ok(xs) + } +} + +impl Module for DinoVisionTransformer { + fn forward(&self, xs: &Tensor) -> Result { + let mut xs = self.prepare_tokens_with_mask(xs)?; + for blk in self.blocks.iter() { + xs = blk.forward(&xs)? 
+ } + let xs = self.norm.forward(&xs)?; + let xs_norm_clstoken = xs.i((.., 0))?; + self.head.forward(&xs_norm_clstoken) + } +} + +pub fn vit_small(vb: VarBuilder) -> Result { + DinoVisionTransformer::new(vb, 12, 384, 6) +} + +pub fn vit_base(vb: VarBuilder) -> Result { + DinoVisionTransformer::new(vb, 12, 768, 12) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 3b0fa834f5..dfdabfd5fc 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -8,6 +8,7 @@ pub mod convmixer; pub mod convnext; pub mod depth_anything_v2; pub mod dinov2; +pub mod dinov2reg4; pub mod distilbert; pub mod efficientnet; pub mod efficientvit; From b438cba3f8742519d6f2154209e096c123fc3eb9 Mon Sep 17 00:00:00 2001 From: Czxck001 <10724409+Czxck001@users.noreply.github.com> Date: Sat, 29 Jun 2024 12:34:42 -0700 Subject: [PATCH 15/75] make up for the missing last token output of phi2 example (#2299) --- candle-examples/examples/phi/main.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 1cfeb443a2..1a0d9aca53 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -114,6 +114,10 @@ impl TextGeneration { tokens.push(next_token); generated_tokens += 1; if next_token == eos_token { + if let Some(t) = self.tokenizer.decode_rest()? { + print!("{t}"); + std::io::stdout().flush()?; + } break; } if let Some(t) = self.tokenizer.next_token(next_token)? { From b7a3e3449a7928332ce5db8eb84b28b9133da4d2 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Sun, 30 Jun 2024 02:34:15 -0400 Subject: [PATCH 16/75] Patch metal function --- candle-core/src/quantized/metal.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index 5f61749b36..f23d6e15df 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -129,7 +129,6 @@ impl QMetalStorage { pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> { // Quantization only happens on CPU for now. 
let elem_count = src.as_slice::()?.len(); - let src = crate::Storage::Cpu(src); let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?; if let QStorage::Cpu(storage) = &mut qcpu_storage { From 9e09d7f3ef36bb5b30b6aa963992392271d27f49 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Fri, 26 Jul 2024 11:14:22 -0400 Subject: [PATCH 17/75] Expose cublas handle --- .vscode/settings.json | 2 +- candle-core/src/cuda_backend/device.rs | 4 ++++ candle-kernels/src/lib.rs | 7 ++++--- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index b7345f2ca6..0230093c95 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,5 +8,5 @@ ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, - //"rust-analyzer.cargo.features": ["cuda"], + "rust-analyzer.cargo.features": ["cuda"], } \ No newline at end of file diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 0aa58cacde..352bae9442 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -47,6 +47,10 @@ impl std::ops::Deref for CudaDevice { } impl CudaDevice { + pub fn cublas_handle(&self) -> &cudarc::cublas::CudaBlas { + &*self.blas + } + pub fn cuda_device(&self) -> Arc { self.device.clone() } diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs index 74c6d3d6bb..cec1b1e2d4 100644 --- a/candle-kernels/src/lib.rs +++ b/candle-kernels/src/lib.rs @@ -3,12 +3,13 @@ pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx")); pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx")); pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx")); pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx")); +pub const FUSED_LAYER_NORM: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_layer_norm.ptx")); +pub const FUSED_RMS_NORM: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rms_norm.ptx")); +pub const FUSED_ROPE: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rope.ptx")); pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx")); +pub const KVCONCAT: &str = include_str!(concat!(env!("OUT_DIR"), "/kvconcat.ptx")); pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx")); pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx")); pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx")); pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx")); pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx")); -pub const FUSED_RMS_NORM: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rms_norm.ptx")); -pub const FUSED_ROPE: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rope.ptx")); -pub const KVCONCAT: &str = include_str!(concat!(env!("OUT_DIR"), "/kvconcat.ptx")); From 1a4876734375d39a2b4428e075ae9a9e41759ae5 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 4 Aug 2024 15:02:19 -0400 Subject: [PATCH 18/75] Add sdpa function with cublaslt --- .vscode/settings.json | 3 +- candle-nn/Cargo.toml | 2 + candle-nn/src/attention.rs | 123 +++++ candle-nn/src/cublaslt/api.rs | 942 +++++++++++++++++++++++++++++++++ candle-nn/src/cublaslt/mod.rs | 100 ++++ candle-nn/src/lib.rs | 3 + candle-transformers/Cargo.toml | 2 +- 7 files changed, 1173 insertions(+), 2 deletions(-) create mode 100644 candle-nn/src/attention.rs create mode 100644 
candle-nn/src/cublaslt/api.rs create mode 100644 candle-nn/src/cublaslt/mod.rs diff --git a/.vscode/settings.json b/.vscode/settings.json index b2dbd68012..646783a968 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,6 @@ "candle-pyo3" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "rust-analyzer.cargo.features": ["cuda"] } \ No newline at end of file diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 9f0d56bdea..570edb48be 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -21,6 +21,7 @@ safetensors = { workspace = true } serde = { workspace = true } metal = { workspace = true, optional = true } candle-metal-kernels = { workspace = true, optional = true } +candle-flash-attn = { workspace = true, optional = true } [dev-dependencies] anyhow = { workspace = true } @@ -34,6 +35,7 @@ accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] +flash-attn = ["cuda", "dep:candle-flash-attn"] [[bench]] name = "bench_main" diff --git a/candle-nn/src/attention.rs b/candle-nn/src/attention.rs new file mode 100644 index 0000000000..490e1c4eb5 --- /dev/null +++ b/candle-nn/src/attention.rs @@ -0,0 +1,123 @@ +use candle::{Device, Result, Tensor}; + +use crate::cublaslt::{setup_cublas_lt_wrapper, CUBLASLT_HANDLE}; + +#[cfg(feature = "flash-attn")] +pub fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +pub fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { + unimplemented!("Compile with '--features flash-attn'") +} + +/// Computes softmax(QK^T*sqrt(d_k))V +fn naive_sdpa( + q: &Tensor, + k: &Tensor, + v: &Tensor, + head_dim: usize, + mask: Option<&Tensor>, +) -> Result { + let att = (&q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (head_dim as f64).sqrt())?; + + let att = match mask { + Some(m) => att.broadcast_add(m)?, + None => att, + }; + let att = crate::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?) +} + +/// Computes softmax(QK^T*sqrt(d_k))V +/// +/// The attention implementation is automatically accelerated and dispatched as follows: +/// 1) If `use_flash_attn == true`, use a Flash Attention V2 kernel +/// 2) If using CUDA, it will attempt to use cuBLASlt an optimized version +/// 3) Otherwise, use the "naive" SDPA implementation - just matmuls and elementwise operations. +/// +/// Note that there may be minute differences in output because floating point operations are not associative. +#[allow(unused_variables, clippy::too_many_arguments)] +pub fn scaled_dot_product_attention( + q: &Tensor, + k: &Tensor, + v: &Tensor, + n_attn_heads: usize, + head_dim: usize, + mask: Option<&Tensor>, + use_flash_attn: bool, + b_sz: usize, + seq_len: usize, +) -> Result { + if use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + let softmax_scale = 1f32 / (head_dim as f32).sqrt(); + return flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2); + } + + // Initializiation is behind a LazyLock. So, the first call will be slightly slower. 
+ // No cost to the other calls. + setup_cublas_lt_wrapper(); + + if let (Device::Cuda(_), Some(cublaslt)) = (q.device(), *CUBLASLT_HANDLE.lock().unwrap()) { + #[cfg(feature = "cuda")] + { + // cuBLASLt batch matmul implementation requires inputs to be dims3 + let k = k.flatten(0, 1)?; + let q = q.flatten(0, 1)?; + let v = v.flatten(0, 1)?; + let attention_bias = mask.map(|mask| mask.flatten(0, 1)).transpose()?; + + // If attention_bias is set, we fuse the add by giving it as the output matrix + // and setting beta to 1.0 + let beta = match attention_bias.is_some() { + true => Some(1.0), + false => None, + }; + + // Batch matrix multiplication + // Fuse softmax scale and attention_bias add + let attention_scores = cublaslt.batch_matmul( + &k, + &q, + attention_bias.as_ref(), + Some((1.0 / (head_dim as f64).sqrt()) as f32), + beta, + None, + None, + )?; + let attention_probs = crate::ops::softmax_last_dim(&attention_scores)?; + + let context_layer = cublaslt.batch_matmul( + &v.t()?.contiguous()?, + &attention_probs, + // We save one allocation + Some(&q), + None, + None, + None, + None, + )?; + + // Reshape to dims4 + context_layer.reshape((b_sz, n_attn_heads, seq_len, head_dim)) + } + #[cfg(not(feature = "cuda"))] + { + candle::bail!("`cuda` feature is not enabled") + } + } else { + naive_sdpa(q, k, v, head_dim, mask) + } +} diff --git a/candle-nn/src/cublaslt/api.rs b/candle-nn/src/cublaslt/api.rs new file mode 100644 index 0000000000..918ab31ef7 --- /dev/null +++ b/candle-nn/src/cublaslt/api.rs @@ -0,0 +1,942 @@ +//! This module inspired from: +//! +//! https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs + +pub use candle::cuda_backend::cudarc::cublaslt::Activation; +use std::ffi::c_int; + +use candle::backend::BackendStorage; +use candle::cuda_backend::WrapErr; +use candle::{CpuStorage, Device, Layout, Result, Shape, Storage, Tensor}; +use half::{bf16, f16}; +use std::sync::Arc; + +use candle::cuda_backend::cudarc::cublaslt::{CudaBlasLT, Matmul, MatmulConfig}; + +#[derive(Debug, Clone)] +pub struct CublasLt(Arc); + +impl CublasLt { + pub fn new(device: &Device) -> Result { + let dev = match device { + Device::Cuda(d) => d, + _ => candle::bail!("`device` must be a `cuda` device"), + }; + + let inner = CudaBlasLT::new(dev.cuda_device()).unwrap(); + + Ok(Self(Arc::new(inner))) + } +} + +pub struct CublasLTMatmul { + pub cublaslt: Arc, + pub act: Option, + pub c: Option, + pub alpha: Option, + pub beta: Option, +} + +impl CublasLTMatmul { + pub fn fwd_f16( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (m, k) = a_l.shape().dims2()?; + + let (n, b_1) = b_l.shape().dims2()?; + + if b_1 != k { + candle::bail!("This layer only supports TN layout"); + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? 
!= m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let mut out = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + if c_l.shape().dims2()? != (n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + c.clone() + } else { + // Allocate out tensor + unsafe { dev.alloc::(out_shape.elem_count()).w()? } + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: None, + stride_b: None, + stride_c: None, + stride_bias: None, + batch_size: None, + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } + + pub fn fwd_bf16( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (m, k) = a_l.shape().dims2()?; + + let (n, b_1) = b_l.shape().dims2()?; + + if b_1 != k { + candle::bail!("This layer only supports TN layout"); + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? != m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let mut out = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + if c_l.shape().dims2()? != (n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + c.clone() + } else { + // Allocate out tensor + unsafe { dev.alloc::(out_shape.elem_count()).w()? 
} + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: None, + stride_b: None, + stride_c: None, + stride_bias: None, + batch_size: None, + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } + + pub fn fwd_f32( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (m, k) = a_l.shape().dims2()?; + + let (n, b_1) = b_l.shape().dims2()?; + + if b_1 != k { + candle::bail!("This layer only supports TN layout"); + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? != m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let mut out = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + if c_l.shape().dims2()? != (n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + c.clone() + } else { + // Allocate out tensor + unsafe { dev.alloc::(out_shape.elem_count()).w()? 
} + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: None, + stride_b: None, + stride_c: None, + stride_bias: None, + batch_size: None, + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } +} + +impl candle::CustomOp2 for CublasLTMatmul { + fn name(&self) -> &'static str { + "cublaslt-matmul" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for cublaslt-matmul") + } + + fn cuda_fwd( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match a.dtype() { + candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), + candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), + candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), + dt => candle::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"), + } + } +} + +impl candle::CustomOp3 for CublasLTMatmul { + fn name(&self) -> &'static str { + "cublaslt-matmul-add" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for cublaslt-matmul") + } + + fn cuda_fwd( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: &candle::CudaStorage, + bias_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match a.dtype() { + candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), + dt => candle::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"), + } + } +} + +/// Fused matmul + add + Relu/Gelu activation using CublasLt +/// +/// # Arguments +/// +/// * `a` - Input tensor of size MxK +/// * `b` - Input tensor of size NxK +/// * `out` - Optional Output tensor of size NxK. +/// If set and beta != 0, will be added to the end result of A*B before `act` +/// * `alpha` - Optional scaling factor for A*B +/// * `beta` - Optional scaling factor for C +/// * `bias` - Optional bias tensor of size M +/// * `act` - Optional Gelu or Relu activation. 
If set, will be added to the end result +/// * `cublaslt` - CublasLt handle +/// +/// The resulting tensor is of shape NxM +#[allow(clippy::too_many_arguments)] +pub fn fused_matmul( + a: &Tensor, + b: &Tensor, + out: Option<&Tensor>, + alpha: Option, + beta: Option, + bias: Option<&Tensor>, + act: Option, + cublaslt: CublasLt, +) -> Result { + let op = CublasLTMatmul { + act, + cublaslt: cublaslt.0, + c: out.cloned(), + alpha, + beta, + }; + + if let Some(bias) = bias { + a.apply_op3(b, bias, op) + } else { + a.apply_op2(b, op) + } +} + +pub struct CublasLTBatchMatmul { + pub cublaslt: Arc, + pub act: Option, + pub c: Option, + pub alpha: Option, + pub beta: Option, +} + +impl CublasLTBatchMatmul { + pub fn fwd_f16( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (batch_size, m, k) = a_l.shape().dims3()?; + let (b_0, n, b_2) = b_l.shape().dims3()?; + + if b_2 != k { + candle::bail!("This layer only supports TN layout"); + } + + if b_0 != batch_size { + candle::bail!("`b` must have the same batch size as `a`") + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((batch_size, n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? != m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let (mut out, stride_c) = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + + if c_l.shape().dims3()? != (batch_size, n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + // Set beta to 0.0 if it is not set + (c.clone(), c_l.stride()[0]) + } else { + // Allocate out tensor + ( + unsafe { dev.alloc::(out_shape.elem_count()).w()? 
}, + (n * m), + ) + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: Some(a_l.stride()[0] as i64), + stride_b: Some(b_l.stride()[0] as i64), + stride_c: Some(stride_c as i64), + stride_bias: None, + batch_size: Some(c_int::try_from(batch_size)?), + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } + + pub fn fwd_bf16( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (batch_size, m, k) = a_l.shape().dims3()?; + let (b_0, n, b_2) = b_l.shape().dims3()?; + + if b_2 != k { + candle::bail!("This layer only supports TN layout"); + } + + if b_0 != batch_size { + candle::bail!("`b` must have the same batch size as `a`") + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((batch_size, n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? != m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let (mut out, stride_c) = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + + if c_l.shape().dims3()? != (batch_size, n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + // Set beta to 0.0 if it is not set + (c.clone(), c_l.stride()[0]) + } else { + // Allocate out tensor + ( + unsafe { dev.alloc::(out_shape.elem_count()).w()? 
}, + (n * m), + ) + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: Some(a_l.stride()[0] as i64), + stride_b: Some(b_l.stride()[0] as i64), + stride_c: Some(stride_c as i64), + stride_bias: None, + batch_size: Some(c_int::try_from(batch_size)?), + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } + + pub fn fwd_f32( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (batch_size, m, k) = a_l.shape().dims3()?; + let (b_0, n, b_2) = b_l.shape().dims3()?; + + if b_2 != k { + candle::bail!("This layer only supports TN layout"); + } + + if b_0 != batch_size { + candle::bail!("`b` must have the same batch size as `a`") + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((batch_size, n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? != m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let (mut out, stride_c) = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + + if c_l.shape().dims3()? != (batch_size, n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + // Set beta to 0.0 if it is not set + (c.clone(), c_l.stride()[0]) + } else { + // Allocate out tensor + ( + unsafe { dev.alloc::(out_shape.elem_count()).w()? 
}, + (n * m), + ) + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: Some(a_l.stride()[0] as i64), + stride_b: Some(b_l.stride()[0] as i64), + stride_c: Some(stride_c as i64), + stride_bias: None, + batch_size: Some(c_int::try_from(batch_size)?), + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } +} + +impl candle::CustomOp2 for CublasLTBatchMatmul { + fn name(&self) -> &'static str { + "cublaslt-batch-matmul" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for cublaslt-batch-matmul") + } + + fn cuda_fwd( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match a.dtype() { + candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), + candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), + candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), + dt => { + candle::bail!("cublaslt-batch-matmul is only supported for f16/bf16/f32 ({dt:?})") + } + } + } +} + +impl candle::CustomOp3 for CublasLTBatchMatmul { + fn name(&self) -> &'static str { + "cublaslt-batch-matmul-add" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for cublaslt-batch-matmul-add") + } + + fn cuda_fwd( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: &candle::CudaStorage, + bias_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match a.dtype() { + candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), + dt => candle::bail!( + "cublaslt-batch-matmul-add is only supported for f16/bf16/f32 ({dt:?})" + ), + } + } +} + +/// Fused batch matmul + add + Relu/Gelu activation using CublasLt +/// +/// # Arguments +/// +/// * `a` - Input tensor of size BxMxK +/// * `b` - Input tensor of size BxNxK +/// * `out` - Optional Output tensor of size BxNxK. +/// If set and beta != 0, will be added to the end result of A*B before `act` +/// * `alpha` - Optional scaling factor for A*B +/// * `beta` - Optional scaling factor for C +/// * `bias` - Optional bias tensor of size M +/// * `act` - Optional Gelu or Relu activation. 
If set, will be added to the end result +/// * `cublaslt` - CublasLt handle +/// +/// The resulting tensor is of shape NxM +#[allow(clippy::too_many_arguments)] +pub fn fused_batch_matmul( + a: &Tensor, + b: &Tensor, + out: Option<&Tensor>, + alpha: Option, + beta: Option, + bias: Option<&Tensor>, + act: Option, + cublaslt: CublasLt, +) -> Result { + let op = CublasLTBatchMatmul { + act, + cublaslt: cublaslt.0, + c: out.cloned(), + alpha, + beta, + }; + + if let Some(bias) = bias { + a.apply_op3(b, bias, op) + } else { + a.apply_op2(b, op) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use candle::{DType, Device}; + + fn to_vec2_round(t: Tensor, digits: i32) -> Result>> { + let b = 10f32.powi(digits); + let t = t.to_vec2::()?; + let t = t + .iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect(); + Ok(t) + } + + fn to_vec3_round(t: Tensor, digits: i32) -> Result>>> { + let b = 10f32.powi(digits); + let t = t.to_vec3::()?; + let t = t + .iter() + .map(|t| { + t.iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect() + }) + .collect(); + Ok(t) + } + + #[test] + fn test_fused_matmul() -> Result<()> { + let device = Device::new_cuda(0)?; + + let a = Tensor::randn(0., 1., (8, 4), &device)?.to_dtype(DType::F32)?; + let b = Tensor::randn(0., 1., (2, 4), &device)?.to_dtype(DType::F32)?; + let bias = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + + let cublaslt = CublasLt::new(&device)?; + + let res = fused_matmul(&a, &b, None, None, None, Some(&bias), None, cublaslt)?; + let expected = (b.matmul(&a.t()?)? + bias.broadcast_left(2)?)?; + + assert_eq!( + to_vec2_round(res.to_dtype(DType::F32)?, 4)?, + to_vec2_round(expected.to_dtype(DType::F32)?, 4)? + ); + Ok(()) + } + + #[test] + fn test_fused_batch_matmul() -> Result<()> { + let device = Device::new_cuda(0)?; + + let a = Tensor::randn(0., 1., (3, 8, 4), &device)?.to_dtype(DType::F32)?; + let b = Tensor::randn(0., 1., (3, 2, 4), &device)?.to_dtype(DType::F32)?; + let c = Tensor::randn(0., 1., (3, 2, 8), &device)?.to_dtype(DType::F32)?; + let bias = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + + let cublaslt = CublasLt::new(&device)?; + + let res = fused_batch_matmul( + &a, + &b, + Some(&c), + None, + Some(1.0), + Some(&bias), + None, + cublaslt, + )?; + let expected = (b.matmul(&a.t()?)?.add(&c)? + bias.broadcast_left((3, 2))?)?; + + assert_eq!( + to_vec3_round(res.to_dtype(DType::F32)?, 4)?, + to_vec3_round(expected.to_dtype(DType::F32)?, 4)? + ); + Ok(()) + } +} diff --git a/candle-nn/src/cublaslt/mod.rs b/candle-nn/src/cublaslt/mod.rs new file mode 100644 index 0000000000..461556722d --- /dev/null +++ b/candle-nn/src/cublaslt/mod.rs @@ -0,0 +1,100 @@ +//! This module inspired from: +//! +//! https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs + +#![allow(unused_variables, unused_imports, dead_code)] + +use crate::Activation as CandleActivation; +use candle::{Device, Result, Tensor}; +use std::sync::{LazyLock, Mutex, Once}; + +#[cfg(feature = "cuda")] +mod api; + +#[cfg(feature = "cuda")] +use api::{fused_batch_matmul, fused_matmul, Activation, CublasLt}; + +static INIT: Once = Once::new(); +static mut CUBLASLT: Option = None; +pub(crate) static CUBLASLT_HANDLE: LazyLock>> = + LazyLock::new(|| Mutex::new(None)); + +/// Internal function to initialize the cublaslt handle wrapper, behind a LazyLock so initialization occurs +/// only once. 
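+/// Subsequent calls skip the initialization and only refresh the cached handle in `CUBLASLT_HANDLE`.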
+pub(crate) fn setup_cublas_lt_wrapper() { + unsafe { + INIT.call_once(|| { + #[cfg(not(feature = "cuda"))] + { + CUBLASLT = None; + } + + #[cfg(feature = "cuda")] + { + // Check if we can call the driver + // Then check if we can create a device + // Then check that the device is CUDA + use candle::cuda_backend::cudarc::driver; + CUBLASLT = driver::result::init() + .ok() + .and_then(|_| Device::cuda_if_available(0).ok()) + .and_then(|device| match device { + Device::Cuda(_) => Some(CublasLtWrapper { + cublaslt: CublasLt::new(&device).unwrap(), + }), + _ => None, + }); + } + }); + let cublaslt: Option<&'static CublasLtWrapper> = CUBLASLT.as_ref(); + *CUBLASLT_HANDLE.lock().unwrap() = cublaslt; + } +} + +#[derive(Debug, Clone)] +pub struct CublasLtWrapper { + #[cfg(feature = "cuda")] + pub cublaslt: CublasLt, +} + +impl CublasLtWrapper { + #[allow(clippy::too_many_arguments)] + pub fn batch_matmul( + &self, + a: &Tensor, + b: &Tensor, + out: Option<&Tensor>, + alpha: Option, + beta: Option, + bias: Option<&Tensor>, + act: Option, + ) -> Result { + #[cfg(feature = "cuda")] + { + let inner_act = act.map(|a| match a { + CandleActivation::Relu => Activation::Relu, + CandleActivation::Gelu => Activation::Gelu, + _ => unreachable!("Unsupported activation in cublaslt matmul"), + }); + let mut result = fused_batch_matmul( + a, + b, + out, + alpha, + beta, + bias, + inner_act, + self.cublaslt.clone(), + )?; + + if Some(CandleActivation::Swiglu) == act { + result = crate::ops::swiglu(&result)?; + } + Ok(result) + } + #[cfg(not(feature = "cuda"))] + { + candle::bail!("`cuda` feature is not enabled") + } + } +} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index fcac58308c..34793cd8af 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -1,6 +1,8 @@ pub mod activation; +pub mod attention; pub mod batch_norm; pub mod conv; +pub mod cublaslt; pub mod embedding; pub mod encoding; pub mod func; @@ -19,6 +21,7 @@ pub mod var_builder; pub mod var_map; pub use activation::{prelu, Activation, PReLU}; +pub use attention::scaled_dot_product_attention; pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig}; pub use conv::{ conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias, diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 6589b4b146..94d3f51fd9 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -29,6 +29,6 @@ tracing = { workspace = true } default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda"] -flash-attn = ["cuda", "dep:candle-flash-attn"] +flash-attn = ["cuda", "dep:candle-flash-attn", "candle-nn/flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] metal = ["candle/metal", "candle-nn/metal"] From 7bbcf00f14f3681007bdf1c5e084e016a3b85926 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 4 Aug 2024 15:04:04 -0400 Subject: [PATCH 19/75] Update docs --- candle-nn/src/attention.rs | 4 ++-- candle-nn/src/cublaslt/api.rs | 2 +- candle-nn/src/cublaslt/mod.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/candle-nn/src/attention.rs b/candle-nn/src/attention.rs index 490e1c4eb5..108f2bbc37 100644 --- a/candle-nn/src/attention.rs +++ b/candle-nn/src/attention.rs @@ -37,13 +37,13 @@ fn naive_sdpa( att.matmul(&v.contiguous()?) } -/// Computes softmax(QK^T*sqrt(d_k))V +/// Computes (softmax(QK^T*sqrt(d_k)) + M)V. 
`M` is the attention mask, and is a bias (0 for unmasked, -inf for masked). /// /// The attention implementation is automatically accelerated and dispatched as follows: /// 1) If `use_flash_attn == true`, use a Flash Attention V2 kernel /// 2) If using CUDA, it will attempt to use cuBLASlt an optimized version /// 3) Otherwise, use the "naive" SDPA implementation - just matmuls and elementwise operations. -/// +/// /// Note that there may be minute differences in output because floating point operations are not associative. #[allow(unused_variables, clippy::too_many_arguments)] pub fn scaled_dot_product_attention( diff --git a/candle-nn/src/cublaslt/api.rs b/candle-nn/src/cublaslt/api.rs index 918ab31ef7..b14eecf91f 100644 --- a/candle-nn/src/cublaslt/api.rs +++ b/candle-nn/src/cublaslt/api.rs @@ -1,5 +1,5 @@ //! This module inspired from: -//! +//! //! https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs pub use candle::cuda_backend::cudarc::cublaslt::Activation; diff --git a/candle-nn/src/cublaslt/mod.rs b/candle-nn/src/cublaslt/mod.rs index 461556722d..e46aaabf67 100644 --- a/candle-nn/src/cublaslt/mod.rs +++ b/candle-nn/src/cublaslt/mod.rs @@ -1,5 +1,5 @@ //! This module inspired from: -//! +//! //! https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs #![allow(unused_variables, unused_imports, dead_code)] From 1bf7101c20ccceed68220cac37e05cdcc4c8997b Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 4 Aug 2024 17:36:50 -0400 Subject: [PATCH 20/75] Add matmul_bias_and_scale --- candle-core/src/backend.rs | 11 + candle-core/src/cpu_backend/mod.rs | 394 ++++++++++++++++++++++++- candle-core/src/cpu_backend/utils.rs | 41 +++ candle-core/src/cuda_backend/mod.rs | 70 +++++ candle-core/src/dummy_cuda_backend.rs | 13 + candle-core/src/dummy_metal_backend.rs | 13 + candle-core/src/error.rs | 8 + candle-core/src/metal_backend/mod.rs | 57 ++++ candle-core/src/storage.rs | 34 +++ candle-core/src/tensor.rs | 70 +++++ candle-metal-kernels/src/lib.rs | 4 +- candle-metal-kernels/src/tests.rs | 2 + 12 files changed, 714 insertions(+), 3 deletions(-) diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index afe3e40754..d419031b8e 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -97,6 +97,17 @@ pub trait BackendStorage: Sized { _: &Layout, ) -> Result; + fn matmul_bias_and_scale( + &self, + _: &Self, + _: &mut Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + _: &Layout, + ) -> Result<()>; + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>; #[allow(clippy::too_many_arguments)] diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 58773c8020..6205d834e9 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -1,3 +1,5 @@ +use std::ops::Deref; + use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; @@ -6,7 +8,7 @@ use rayon::prelude::*; mod utils; pub use utils::{ - binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8, + binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8, Map3, }; const USE_IM2COL_CONV1D: bool = true; @@ -1529,6 +1531,383 @@ impl Map2 for MatMul { } } 
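+// `MatMulWithBias` appears to back the new `matmul_bias_and_scale` backend method on
+// the CPU: `c` acts as both the additive bias and the output buffer, so the intended
+// result is `c <- s * (lhs @ rhs) + c`, i.e. a GEMM with alpha = s and beta = 1. This
+// mirrors the alpha/beta fusion used by the cuBLASLt paths earlier in this series.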
+struct MatMulWithBias(MatMul); + +impl Deref for MatMulWithBias { + type Target = MatMul; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Map3 for MatMulWithBias { + const OP: &'static str = "mat_mul_c"; + + #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + c: &mut [T], + c_l: &Layout, + s: Option, + ) -> Result<()> { + use gemm::{gemm, Parallelism}; + + match T::DTYPE { + DType::F16 | DType::F32 | DType::F64 => {} + _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?, + } + + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rank = lhs_stride.len(); + let lhs_cs = lhs_stride[rank - 1]; + let lhs_rs = lhs_stride[rank - 2]; + + let rhs_cs = rhs_stride[rank - 1]; + let rhs_rs = rhs_stride[rank - 2]; + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let dst_shape: Shape = (m, n).into(); + let dst_strides = dst_shape.stride_contiguous(); + let dst_rs = dst_strides[0]; + let dst_cs = dst_strides[1]; + + let num_threads = crate::utils::get_num_threads(); + let parallelism = if num_threads > 1 { + Parallelism::Rayon(num_threads) + } else { + Parallelism::None + }; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != b * m * n { + crate::bail!("`c` end offset must be {}", b * m * n) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + let alpha = T::from_f64(s.unwrap_or(1.0)); + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + gemm( + /* m: usize = */ m, + /* n: usize = */ n, + /* k: usize = */ k, + /* dst: *mut T = */ dst_p.as_mut_ptr(), + /* dst_cs: isize = */ dst_cs as isize, + /* dst_rs: isize = */ dst_rs as isize, + /* read_dst: bool = */ false, + /* lhs: *const T = */ lhs_p.as_ptr(), + /* lhs_cs: isize = */ lhs_cs as isize, + /* lhs_rs: isize = */ lhs_rs as isize, + /* rhs: *const T = */ rhs_p.as_ptr(), + /* rhs_cs: isize = */ rhs_cs as isize, + /* rhs_rs: isize = */ rhs_rs as isize, + /* alpha: T = */ alpha, + /* beta: T = */ T::one(), + /* conj_dst: bool = */ false, + /* conj_lhs: bool = */ false, + /* conj_rhs: bool = */ false, + parallelism, + ) + } + } + Ok(()) + } + + #[cfg(feature = "accelerate")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + c: &mut [T], + c_l: &Layout, + s: Option, + ) -> Result> { + let (b, m, n, k) = self.0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? 
+ }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? + }; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != b * m * n { + crate::bail!("`c` end offset must be {}", b * m * n) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + match T::DTYPE { + DType::F16 => { + crate::bail!("the accelerate backend does not support f16 matmul") + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(dst) + } + + #[cfg(feature = "mkl")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + c: &mut [T], + c_l: &Layout, + s: Option, + ) -> Result> { + let (b, m, n, k) = self.0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? 
+ }; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != b * m * n { + crate::bail!("`c` end offset must be {}", b * m * n) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + match T::DTYPE { + DType::F16 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f16; + let b = lhs_p.as_ptr() as *const f16; + let c = dst_p.as_mut_ptr() as *mut f16; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::hgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ f16::from_f64(s.unwrap_or(1.)), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ f16::ONE, + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) as f32, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(dst) + } +} + fn elu(v: T, alpha: T) -> T { if v.is_sign_positive() { v @@ -2433,6 +2812,19 @@ impl BackendStorage for CpuStorage { MatMul(bmnk).map(self, lhs_l, rhs, rhs_l) } + fn matmul_bias_and_scale( + &self, + rhs: &Self, + c: &mut Self, + s: Option, + bmnk: (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + c_l: &Layout, + ) -> Result<()> { + MatMulWithBias(MatMul(bmnk)).map(self, lhs_l, rhs, rhs_l, c, c_l, s) + } + fn device(&self) -> &Self::Device { &CpuDevice } diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 3e0c69b4f7..7c3429ff42 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -58,6 +58,47 @@ pub trait Map2 { } } +pub trait Map3 { + const OP: &'static str; + fn f( + &self, + v1: &[T], + l1: &Layout, + v2: &[T], + l2: &Layout, + v3: &mut [T], + l3: &Layout, + s: Option, + ) -> Result<()>; + + fn map( + &self, + v1: &C, + l1: &Layout, + v2: &C, + l2: &Layout, + v3: &mut C, + l3: &Layout, + s: 
Option, + ) -> Result<()> { + match (v1, v2, v3) { + (C::U8(v1), C::U8(v2), C::U8(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::U32(v1), C::U32(v2), C::U32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::I64(v1), C::I64(v2), C::I64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::BF16(v1), C::BF16(v2), C::BF16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F16(v1), C::F16(v2), C::F16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F32(v1), C::F32(v2), C::F32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F64(v1), C::F64(v2), C::F64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt()), + } + } +} + pub trait Map2U8 { const OP: &'static str; fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 7edad3d409..eb6b2aac15 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1709,6 +1709,76 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } + fn matmul_bias_and_scale( + &self, + rhs: &Self, + c: &mut Self, + s: Option, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + c_l: &Layout, + ) -> Result<()> { + let elem_count = b * m * n; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != elem_count { + crate::bail!("`c` end offset must be {}", elem_count) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + match (&self.slice, &rhs.slice, &mut c.slice) { + ( + CudaStorageSlice::BF16(lhs), + CudaStorageSlice::BF16(rhs), + CudaStorageSlice::BF16(c), + ) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config( + bf16::from_f64(s.unwrap_or(1.0)), + bf16::ZERO, + (b, m, n, k), + lhs_l, + rhs_l, + )?; + unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, c) }.w()?; + } + (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs), CudaStorageSlice::F16(c)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config( + f16::from_f64(s.unwrap_or(1.0)), + f16::ZERO, + (b, m, n, k), + lhs_l, + rhs_l, + )?; + unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, c) }.w()?; + } + (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs), CudaStorageSlice::F32(c)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(s.unwrap_or(1.0) as f32, 0., (b, m, n, k), lhs_l, rhs_l)?; + unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, c) }.w()?; + } + (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs), CudaStorageSlice::F64(c)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(s.unwrap_or(1.0), 0., (b, m, n, k), lhs_l, rhs_l)?; + unsafe { self.device.blas.gemm_strided_batched(cfg, rhs, lhs, c) }.w()?; + } + _ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?, + }; + Ok(()) + } + fn copy2d( &self, dst: &mut Self, diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 68eef1efed..2c4d016f7c 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -150,6 +150,19 @@ impl 
crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + fn matmul_bias_and_scale( + &self, + _: &Self, + _: &mut Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + _: &Layout, + ) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index a1c2394d49..db4b29667d 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -162,6 +162,19 @@ impl crate::backend::BackendStorage for MetalStorage { Err(Error::NotCompiledWithMetalSupport) } + fn matmul_bias_and_scale( + &self, + _: &Self, + _: &mut Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + _: &Layout, + ) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { Err(Error::NotCompiledWithMetalSupport) } diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index e7112e2e61..2aacfd54ef 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -100,6 +100,14 @@ pub enum Error { op: &'static str, }, + #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}, c: {c:?}")] + DeviceMismatchBinaryOp3 { + lhs: DeviceLocation, + rhs: DeviceLocation, + c: DeviceLocation, + op: &'static str, + }, + // === Op Specific Errors === #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")] NarrowInvalidArgs { diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 09d5fd49cd..3d00d0cb2c 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1430,6 +1430,8 @@ impl BackendStorage for MetalStorage { rhs_l.start_offset() * rhs.dtype.size_in_bytes(), &rhs.buffer, &buffer, + 1., + 0., ) .map_err(MetalError::from)?; Ok(Self::new( @@ -1440,6 +1442,61 @@ impl BackendStorage for MetalStorage { )) } + fn matmul_bias_and_scale( + &self, + rhs: &Self, + c: &mut Self, + s: Option, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + c_l: &Layout, + ) -> Result<()> { + let name = match self.dtype { + DType::F32 => "sgemm", + DType::F16 => "hgemm", + DType::BF16 => "bgemm", + dtype => { + return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) + } + }; + + let elem_count = b * m * n; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != elem_count { + crate::bail!("`c` end offset must be {}", elem_count) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("matmul"); + candle_metal_kernels::call_gemm( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &self.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &c.buffer, + s.unwrap_or(1.) 
as f32, + 1., + ) + .map_err(MetalError::from)?; + Ok(()) + } + fn copy2d( &self, dst: &mut Self, diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 8a0637e304..383fbca174 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -736,6 +736,40 @@ impl Storage { } } + pub(crate) fn matmul_bias_and_scale( + &self, + rhs: &Self, + c: &mut Self, + s: Option, + bmnk: (usize, usize, usize, usize), + lhs_layout: &Layout, + rhs_layout: &Layout, + c_layout: &Layout, + ) -> Result<()> { + self.same_device(rhs, "matmul_bias_and_scale")?; + self.same_dtype(rhs, "matmul_bias_and_scale")?; + self.same_device(c, "matmul_bias_and_scale")?; + self.same_dtype(c, "matmul_bias_and_scale")?; + match (self, rhs, c) { + (Self::Cpu(lhs), Self::Cpu(rhs), Self::Cpu(c)) => { + lhs.matmul_bias_and_scale(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + } + (Self::Cuda(lhs), Self::Cuda(rhs), Self::Cuda(c)) => { + lhs.matmul_bias_and_scale(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + } + (Self::Metal(lhs), Self::Metal(rhs), Self::Metal(c)) => { + lhs.matmul_bias_and_scale(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + } + (lhs, rhs, c) => Err(Error::DeviceMismatchBinaryOp3 { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + c: c.device().location(), + op: "matmul_bias_and_scale", + } + .bt()), + } + } + // self, the source can be strided whereas dst is contiguous. pub(crate) fn copy_strided_src( &self, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e8b026057e..eb5197a889 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1226,6 +1226,76 @@ impl Tensor { } } + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. The result is scaled + /// and then added to the output tensor, the bias tensor `c`. + /// + /// This is incompatible with gradient tracking. No gradients will be tracked on this operation. + /// + /// # Arguments + /// + /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`. + /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`. + /// * `c` - A tensor with dimensions `b1, b2, ..., bi, m, n`, into which the result is accumulated and added to. + /// * `scale` - Factor to multiply `self` x `rhs` by + pub fn matmul_bias_and_scale( + &self, + rhs: &Self, + c: &mut Self, + scale: Option, + ) -> Result<()> { + let a_dims = self.shape().dims(); + let b_dims = rhs.shape().dims(); + + let dim = a_dims.len(); + + if dim < 2 || b_dims.len() != dim { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul", + } + .bt())? + } + + let m = a_dims[dim - 2]; + let k = a_dims[dim - 1]; + let k2 = b_dims[dim - 2]; + let n = b_dims[dim - 1]; + + let exp_c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); + if exp_c_shape.elem_count() == 0 || k == 0 { + bail!("Expected `c` to have more than one element, got 0."); + } + if exp_c_shape != c.shape().clone() { + Err(Error::UnexpectedShape { + msg: "`c` has an unexpected shape.".to_string(), + expected: exp_c_shape, + got: c.shape().clone(), + })? + } + + let batching: usize = a_dims[..dim - 2].iter().product(); + let batching_b: usize = b_dims[..dim - 2].iter().product(); + if k != k2 || batching != batching_b { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul", + } + .bt())? 
+ } + + self.storage().matmul_bias_and_scale( + &rhs.storage(), + &mut c.storage_mut(), + scale, + (batching, m, n, k), + self.layout(), + rhs.layout(), + c.layout(), + ) + } + /// Returns a tensor with the same shape as the input tensor, the values are taken from /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the /// input tensor is equal to zero. diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 743b9fe2b3..a97d327468 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1465,6 +1465,8 @@ pub fn call_gemm( rhs_offset: usize, rhs_buffer: &Buffer, output: &Buffer, + alpha: f32, + beta: f32, ) -> Result<(), MetalKernelError> { assert!(rhs_stride.len() >= 2); assert!(lhs_stride.len() >= 2); @@ -1499,8 +1501,6 @@ pub fn call_gemm( })?; }; let d_trans = false; - let alpha = 1.0f32; - let beta = 0.0f32; let batched = b > 1; let fused_activation = false; let fused_bias = false; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 30c454af3f..0a37632bd1 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1086,6 +1086,8 @@ fn run_gemm( rhs_offset, &rhs, &output, + 1., + 0., ) .unwrap(); command_buffer.commit(); From d6d3d1890d90dee209c64f42140330524f42ddcc Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 4 Aug 2024 17:37:33 -0400 Subject: [PATCH 21/75] Rename --- candle-core/src/backend.rs | 2 +- candle-core/src/cpu_backend/mod.rs | 2 +- candle-core/src/cuda_backend/mod.rs | 2 +- candle-core/src/dummy_cuda_backend.rs | 2 +- candle-core/src/dummy_metal_backend.rs | 2 +- candle-core/src/metal_backend/mod.rs | 2 +- candle-core/src/storage.rs | 18 +++++++++--------- candle-core/src/tensor.rs | 9 ++------- 8 files changed, 17 insertions(+), 22 deletions(-) diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index d419031b8e..fe646230c3 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -97,7 +97,7 @@ pub trait BackendStorage: Sized { _: &Layout, ) -> Result; - fn matmul_bias_and_scale( + fn matmul_with_beta( &self, _: &Self, _: &mut Self, diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 6205d834e9..a010191862 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -2812,7 +2812,7 @@ impl BackendStorage for CpuStorage { MatMul(bmnk).map(self, lhs_l, rhs, rhs_l) } - fn matmul_bias_and_scale( + fn matmul_with_beta( &self, rhs: &Self, c: &mut Self, diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index eb6b2aac15..ccb160e60f 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1709,7 +1709,7 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } - fn matmul_bias_and_scale( + fn matmul_with_beta( &self, rhs: &Self, c: &mut Self, diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 2c4d016f7c..c69afe03b6 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -150,7 +150,7 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - fn matmul_bias_and_scale( + fn matmul_with_beta( &self, _: &Self, _: &mut Self, diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index db4b29667d..82704bd87b 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ 
b/candle-core/src/dummy_metal_backend.rs @@ -162,7 +162,7 @@ impl crate::backend::BackendStorage for MetalStorage { Err(Error::NotCompiledWithMetalSupport) } - fn matmul_bias_and_scale( + fn matmul_with_beta( &self, _: &Self, _: &mut Self, diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 3d00d0cb2c..8029cbd203 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1442,7 +1442,7 @@ impl BackendStorage for MetalStorage { )) } - fn matmul_bias_and_scale( + fn matmul_with_beta( &self, rhs: &Self, c: &mut Self, diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 383fbca174..a80f9c3b6f 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -736,7 +736,7 @@ impl Storage { } } - pub(crate) fn matmul_bias_and_scale( + pub(crate) fn matmul_with_beta( &self, rhs: &Self, c: &mut Self, @@ -746,25 +746,25 @@ impl Storage { rhs_layout: &Layout, c_layout: &Layout, ) -> Result<()> { - self.same_device(rhs, "matmul_bias_and_scale")?; - self.same_dtype(rhs, "matmul_bias_and_scale")?; - self.same_device(c, "matmul_bias_and_scale")?; - self.same_dtype(c, "matmul_bias_and_scale")?; + self.same_device(rhs, "matmul_with_beta")?; + self.same_dtype(rhs, "matmul_with_beta")?; + self.same_device(c, "matmul_with_beta")?; + self.same_dtype(c, "matmul_with_beta")?; match (self, rhs, c) { (Self::Cpu(lhs), Self::Cpu(rhs), Self::Cpu(c)) => { - lhs.matmul_bias_and_scale(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + lhs.matmul_with_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) } (Self::Cuda(lhs), Self::Cuda(rhs), Self::Cuda(c)) => { - lhs.matmul_bias_and_scale(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + lhs.matmul_with_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) } (Self::Metal(lhs), Self::Metal(rhs), Self::Metal(c)) => { - lhs.matmul_bias_and_scale(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + lhs.matmul_with_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) } (lhs, rhs, c) => Err(Error::DeviceMismatchBinaryOp3 { lhs: lhs.device().location(), rhs: rhs.device().location(), c: c.device().location(), - op: "matmul_bias_and_scale", + op: "matmul_with_beta", } .bt()), } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index eb5197a889..f52ddded95 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1237,12 +1237,7 @@ impl Tensor { /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`. /// * `c` - A tensor with dimensions `b1, b2, ..., bi, m, n`, into which the result is accumulated and added to. /// * `scale` - Factor to multiply `self` x `rhs` by - pub fn matmul_bias_and_scale( - &self, - rhs: &Self, - c: &mut Self, - scale: Option, - ) -> Result<()> { + pub fn matmul_with_beta(&self, rhs: &Self, c: &mut Self, scale: Option) -> Result<()> { let a_dims = self.shape().dims(); let b_dims = rhs.shape().dims(); @@ -1285,7 +1280,7 @@ impl Tensor { .bt())? 
} - self.storage().matmul_bias_and_scale( + self.storage().matmul_with_beta( &rhs.storage(), &mut c.storage_mut(), scale, From e20d85a783a3187b28011a41e7769b6f0e430c84 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 4 Aug 2024 18:24:56 -0400 Subject: [PATCH 22/75] Add a simple test and fix for cpu --- candle-core/src/backend.rs | 1 + candle-core/src/cpu_backend/mod.rs | 6 +++--- candle-core/src/cpu_backend/utils.rs | 2 ++ candle-core/src/cuda_backend/mod.rs | 8 ++++---- candle-core/src/storage.rs | 1 + candle-core/tests/matmul_tests.rs | 19 +++++++++++++++++++ 6 files changed, 30 insertions(+), 7 deletions(-) diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index fe646230c3..2c93deb623 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -97,6 +97,7 @@ pub trait BackendStorage: Sized { _: &Layout, ) -> Result; + #[allow(clippy::too_many_arguments)] fn matmul_with_beta( &self, _: &Self, diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index a010191862..0e2c23316d 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -1615,15 +1615,15 @@ impl Map3 for MatMulWithBias { /* dst: *mut T = */ dst_p.as_mut_ptr(), /* dst_cs: isize = */ dst_cs as isize, /* dst_rs: isize = */ dst_rs as isize, - /* read_dst: bool = */ false, + /* read_dst: bool = */ true, /* lhs: *const T = */ lhs_p.as_ptr(), /* lhs_cs: isize = */ lhs_cs as isize, /* lhs_rs: isize = */ lhs_rs as isize, /* rhs: *const T = */ rhs_p.as_ptr(), /* rhs_cs: isize = */ rhs_cs as isize, /* rhs_rs: isize = */ rhs_rs as isize, - /* alpha: T = */ alpha, - /* beta: T = */ T::one(), + /* alpha: T = */ T::one(), + /* beta: T = */ alpha, /* conj_dst: bool = */ false, /* conj_lhs: bool = */ false, /* conj_rhs: bool = */ false, diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 7c3429ff42..1566f6ebfd 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -60,6 +60,7 @@ pub trait Map2 { pub trait Map3 { const OP: &'static str; + #[allow(clippy::too_many_arguments)] fn f( &self, v1: &[T], @@ -71,6 +72,7 @@ pub trait Map3 { s: Option, ) -> Result<()>; + #[allow(clippy::too_many_arguments)] fn map( &self, v1: &C, diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index ccb160e60f..8cf2fd0805 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1743,7 +1743,7 @@ impl BackendStorage for CudaStorage { let rhs = &rhs.slice(rhs_l.start_offset()..); let cfg = gemm_config( bf16::from_f64(s.unwrap_or(1.0)), - bf16::ZERO, + bf16::ONE, (b, m, n, k), lhs_l, rhs_l, @@ -1755,7 +1755,7 @@ impl BackendStorage for CudaStorage { let rhs = &rhs.slice(rhs_l.start_offset()..); let cfg = gemm_config( f16::from_f64(s.unwrap_or(1.0)), - f16::ZERO, + f16::ONE, (b, m, n, k), lhs_l, rhs_l, @@ -1765,13 +1765,13 @@ impl BackendStorage for CudaStorage { (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs), CudaStorageSlice::F32(c)) => { let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); - let cfg = gemm_config(s.unwrap_or(1.0) as f32, 0., (b, m, n, k), lhs_l, rhs_l)?; + let cfg = gemm_config(s.unwrap_or(1.0) as f32, 1., (b, m, n, k), lhs_l, rhs_l)?; unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, c) }.w()?; } (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs), CudaStorageSlice::F64(c)) => { let lhs = &lhs.slice(lhs_l.start_offset()..); 
let rhs = &rhs.slice(rhs_l.start_offset()..); - let cfg = gemm_config(s.unwrap_or(1.0), 0., (b, m, n, k), lhs_l, rhs_l)?; + let cfg = gemm_config(s.unwrap_or(1.0), 1., (b, m, n, k), lhs_l, rhs_l)?; unsafe { self.device.blas.gemm_strided_batched(cfg, rhs, lhs, c) }.w()?; } _ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?, diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index a80f9c3b6f..9aaf975f3d 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -736,6 +736,7 @@ impl Storage { } } + #[allow(clippy::too_many_arguments)] pub(crate) fn matmul_with_beta( &self, rhs: &Self, diff --git a/candle-core/tests/matmul_tests.rs b/candle-core/tests/matmul_tests.rs index c1c16401a8..f445a7f1bb 100644 --- a/candle-core/tests/matmul_tests.rs +++ b/candle-core/tests/matmul_tests.rs @@ -109,7 +109,26 @@ fn mm_layout(device: &Device) -> Result<()> { Ok(()) } +fn matmul_beta(device: &Device) -> Result<()> { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 1.0, 1.0, 1.0]; + let mut c = Tensor::from_slice(&data, (2, 2), device)?; + + a.matmul_with_beta(&b, &mut c, None)?; + assert_eq!(c.to_vec2::()?, &[[8.0f32, 11.0], [16.0, 23.0]]); + Ok(()) +} + test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal); +test_device!( + matmul_beta, + matmul_beta_cpu, + matmul_beta_gpu, + matmul_beta_metal +); test_device!( matmul_bf16, matmul_bf16_cpu, From 8d2f32ab41eb069d5307e53b82b905913c1f12c7 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 4 Aug 2024 18:38:52 -0400 Subject: [PATCH 23/75] Update sdpa function --- candle-core/src/backend.rs | 2 +- candle-core/src/cpu_backend/mod.rs | 2 +- candle-core/src/cuda_backend/mod.rs | 2 +- candle-core/src/dummy_cuda_backend.rs | 2 +- candle-core/src/dummy_metal_backend.rs | 2 +- candle-core/src/metal_backend/mod.rs | 2 +- candle-core/src/storage.rs | 18 +- candle-core/src/tensor.rs | 16 +- candle-core/tests/matmul_tests.rs | 2 +- candle-nn/src/attention.rs | 93 +-- candle-nn/src/cublaslt/api.rs | 942 ------------------------- candle-nn/src/cublaslt/mod.rs | 100 --- candle-nn/src/lib.rs | 1 - 13 files changed, 46 insertions(+), 1138 deletions(-) delete mode 100644 candle-nn/src/cublaslt/api.rs delete mode 100644 candle-nn/src/cublaslt/mod.rs diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 2c93deb623..b1e67921a9 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -98,7 +98,7 @@ pub trait BackendStorage: Sized { ) -> Result; #[allow(clippy::too_many_arguments)] - fn matmul_with_beta( + fn matmul_with_alpha_beta( &self, _: &Self, _: &mut Self, diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 0e2c23316d..466ab5195b 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -2812,7 +2812,7 @@ impl BackendStorage for CpuStorage { MatMul(bmnk).map(self, lhs_l, rhs, rhs_l) } - fn matmul_with_beta( + fn matmul_with_alpha_beta( &self, rhs: &Self, c: &mut Self, diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 8cf2fd0805..856972adae 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1709,7 +1709,7 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } - fn matmul_with_beta( + fn matmul_with_alpha_beta( &self, rhs: &Self, c: &mut 
Self, diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index c69afe03b6..3035a02976 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -150,7 +150,7 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - fn matmul_with_beta( + fn matmul_with_alpha_beta( &self, _: &Self, _: &mut Self, diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index 82704bd87b..dd19ec5cc7 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -162,7 +162,7 @@ impl crate::backend::BackendStorage for MetalStorage { Err(Error::NotCompiledWithMetalSupport) } - fn matmul_with_beta( + fn matmul_with_alpha_beta( &self, _: &Self, _: &mut Self, diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 8029cbd203..b02edac825 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1442,7 +1442,7 @@ impl BackendStorage for MetalStorage { )) } - fn matmul_with_beta( + fn matmul_with_alpha_beta( &self, rhs: &Self, c: &mut Self, diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 9aaf975f3d..fe635c40f9 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -737,7 +737,7 @@ impl Storage { } #[allow(clippy::too_many_arguments)] - pub(crate) fn matmul_with_beta( + pub(crate) fn matmul_with_alpha_beta( &self, rhs: &Self, c: &mut Self, @@ -747,25 +747,25 @@ impl Storage { rhs_layout: &Layout, c_layout: &Layout, ) -> Result<()> { - self.same_device(rhs, "matmul_with_beta")?; - self.same_dtype(rhs, "matmul_with_beta")?; - self.same_device(c, "matmul_with_beta")?; - self.same_dtype(c, "matmul_with_beta")?; + self.same_device(rhs, "matmul_with_alpha_beta")?; + self.same_dtype(rhs, "matmul_with_alpha_beta")?; + self.same_device(c, "matmul_with_alpha_beta")?; + self.same_dtype(c, "matmul_with_alpha_beta")?; match (self, rhs, c) { (Self::Cpu(lhs), Self::Cpu(rhs), Self::Cpu(c)) => { - lhs.matmul_with_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + lhs.matmul_with_alpha_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) } (Self::Cuda(lhs), Self::Cuda(rhs), Self::Cuda(c)) => { - lhs.matmul_with_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + lhs.matmul_with_alpha_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) } (Self::Metal(lhs), Self::Metal(rhs), Self::Metal(c)) => { - lhs.matmul_with_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + lhs.matmul_with_alpha_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) } (lhs, rhs, c) => Err(Error::DeviceMismatchBinaryOp3 { lhs: lhs.device().location(), rhs: rhs.device().location(), c: c.device().location(), - op: "matmul_with_beta", + op: "matmul_with_alpha_beta", } .bt()), } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f52ddded95..c65f1f5846 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1229,6 +1229,13 @@ impl Tensor { /// Returns the matrix-multiplication of the input tensor with the other provided tensor. The result is scaled /// and then added to the output tensor, the bias tensor `c`. /// + /// If `scale` is None, then the output is as follows: + /// `c := c + axb` + /// + /// Else: + /// `c := c + scale * (axb)` + /// + /// /// This is incompatible with gradient tracking. No gradients will be tracked on this operation. 
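// A usage sketch of the accumulate behaviour documented above, mirroring the
// `matmul_beta` test added earlier in this series (values are taken from that test;
// `Device`, `Result` and `Tensor` from candle-core are assumed to be in scope):
fn matmul_with_alpha_beta_example(device: &Device) -> Result<()> {
    let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], (2, 2), device)?;
    let b = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], (2, 2), device)?;
    let mut c = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0], (2, 2), device)?;
    // With `scale == None`, `c` is updated in place to `c + a @ b`.
    a.matmul_with_alpha_beta(&b, &mut c, None)?;
    // a @ b == [[7, 10], [15, 22]], so c is now [[8, 11], [16, 23]].
    assert_eq!(c.to_vec2::<f32>()?, &[[8.0f32, 11.0], [16.0, 23.0]]);
    Ok(())
}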
/// /// # Arguments @@ -1237,7 +1244,12 @@ impl Tensor { /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`. /// * `c` - A tensor with dimensions `b1, b2, ..., bi, m, n`, into which the result is accumulated and added to. /// * `scale` - Factor to multiply `self` x `rhs` by - pub fn matmul_with_beta(&self, rhs: &Self, c: &mut Self, scale: Option) -> Result<()> { + pub fn matmul_with_alpha_beta( + &self, + rhs: &Self, + c: &mut Self, + scale: Option, + ) -> Result<()> { let a_dims = self.shape().dims(); let b_dims = rhs.shape().dims(); @@ -1280,7 +1292,7 @@ impl Tensor { .bt())? } - self.storage().matmul_with_beta( + self.storage().matmul_with_alpha_beta( &rhs.storage(), &mut c.storage_mut(), scale, diff --git a/candle-core/tests/matmul_tests.rs b/candle-core/tests/matmul_tests.rs index f445a7f1bb..c2af6012be 100644 --- a/candle-core/tests/matmul_tests.rs +++ b/candle-core/tests/matmul_tests.rs @@ -117,7 +117,7 @@ fn matmul_beta(device: &Device) -> Result<()> { let data = vec![1.0f32, 1.0, 1.0, 1.0]; let mut c = Tensor::from_slice(&data, (2, 2), device)?; - a.matmul_with_beta(&b, &mut c, None)?; + a.matmul_with_alpha_beta(&b, &mut c, None)?; assert_eq!(c.to_vec2::()?, &[[8.0f32, 11.0], [16.0, 23.0]]); Ok(()) } diff --git a/candle-nn/src/attention.rs b/candle-nn/src/attention.rs index 108f2bbc37..525f3d9eb2 100644 --- a/candle-nn/src/attention.rs +++ b/candle-nn/src/attention.rs @@ -1,6 +1,4 @@ -use candle::{Device, Result, Tensor}; - -use crate::cublaslt::{setup_cublas_lt_wrapper, CUBLASLT_HANDLE}; +use candle::{Result, Tensor}; #[cfg(feature = "flash-attn")] pub fn flash_attn( @@ -18,31 +16,11 @@ pub fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result unimplemented!("Compile with '--features flash-attn'") } -/// Computes softmax(QK^T*sqrt(d_k))V -fn naive_sdpa( - q: &Tensor, - k: &Tensor, - v: &Tensor, - head_dim: usize, - mask: Option<&Tensor>, -) -> Result { - let att = (&q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (head_dim as f64).sqrt())?; - - let att = match mask { - Some(m) => att.broadcast_add(m)?, - None => att, - }; - let att = crate::ops::softmax_last_dim(&att)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - att.matmul(&v.contiguous()?) -} - /// Computes (softmax(QK^T*sqrt(d_k)) + M)V. `M` is the attention mask, and is a bias (0 for unmasked, -inf for masked). /// /// The attention implementation is automatically accelerated and dispatched as follows: /// 1) If `use_flash_attn == true`, use a Flash Attention V2 kernel -/// 2) If using CUDA, it will attempt to use cuBLASlt an optimized version -/// 3) Otherwise, use the "naive" SDPA implementation - just matmuls and elementwise operations. +/// 2) Otherwise, use SDPA with fusion of softmax scale and attention bias application /// /// Note that there may be minute differences in output because floating point operations are not associative. #[allow(unused_variables, clippy::too_many_arguments)] @@ -52,7 +30,7 @@ pub fn scaled_dot_product_attention( v: &Tensor, n_attn_heads: usize, head_dim: usize, - mask: Option<&Tensor>, + mask: Option, use_flash_attn: bool, b_sz: usize, seq_len: usize, @@ -66,58 +44,19 @@ pub fn scaled_dot_product_attention( return flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2); } - // Initializiation is behind a LazyLock. So, the first call will be slightly slower. - // No cost to the other calls. 
- setup_cublas_lt_wrapper(); - - if let (Device::Cuda(_), Some(cublaslt)) = (q.device(), *CUBLASLT_HANDLE.lock().unwrap()) { - #[cfg(feature = "cuda")] - { - // cuBLASLt batch matmul implementation requires inputs to be dims3 - let k = k.flatten(0, 1)?; - let q = q.flatten(0, 1)?; - let v = v.flatten(0, 1)?; - let attention_bias = mask.map(|mask| mask.flatten(0, 1)).transpose()?; - - // If attention_bias is set, we fuse the add by giving it as the output matrix - // and setting beta to 1.0 - let beta = match attention_bias.is_some() { - true => Some(1.0), - false => None, - }; - - // Batch matrix multiplication - // Fuse softmax scale and attention_bias add - let attention_scores = cublaslt.batch_matmul( - &k, - &q, - attention_bias.as_ref(), - Some((1.0 / (head_dim as f64).sqrt()) as f32), - beta, - None, - None, - )?; - let attention_probs = crate::ops::softmax_last_dim(&attention_scores)?; - - let context_layer = cublaslt.batch_matmul( - &v.t()?.contiguous()?, - &attention_probs, - // We save one allocation - Some(&q), - None, - None, - None, - None, + let att = match mask { + Some(mut m) => { + q.contiguous()?.matmul_with_alpha_beta( + &k.t()?.contiguous()?, + &mut m, + Some(1. / (head_dim as f64).sqrt()), )?; - - // Reshape to dims4 - context_layer.reshape((b_sz, n_attn_heads, seq_len, head_dim)) + m } - #[cfg(not(feature = "cuda"))] - { - candle::bail!("`cuda` feature is not enabled") - } - } else { - naive_sdpa(q, k, v, head_dim, mask) - } + None => (&q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (head_dim as f64).sqrt())?, + }; + + let att = crate::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?) } diff --git a/candle-nn/src/cublaslt/api.rs b/candle-nn/src/cublaslt/api.rs deleted file mode 100644 index b14eecf91f..0000000000 --- a/candle-nn/src/cublaslt/api.rs +++ /dev/null @@ -1,942 +0,0 @@ -//! This module inspired from: -//! -//! 
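// The new masked branch above folds the softmax scale and the additive attention bias
// into a single `matmul_with_alpha_beta` call: after it returns, `m` holds
// `mask + (q @ k^T) / sqrt(head_dim)`. A rough unfused equivalent, as a sketch that
// ignores the contiguity handling in the real code:
//
//     let scores = (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (head_dim as f64).sqrt())?;
//     let att = scores.broadcast_add(&mask)?;
//
// Because the mask doubles as the accumulator, the function now takes it by value
// (`mask: Option<Tensor>`) instead of by reference.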
https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs - -pub use candle::cuda_backend::cudarc::cublaslt::Activation; -use std::ffi::c_int; - -use candle::backend::BackendStorage; -use candle::cuda_backend::WrapErr; -use candle::{CpuStorage, Device, Layout, Result, Shape, Storage, Tensor}; -use half::{bf16, f16}; -use std::sync::Arc; - -use candle::cuda_backend::cudarc::cublaslt::{CudaBlasLT, Matmul, MatmulConfig}; - -#[derive(Debug, Clone)] -pub struct CublasLt(Arc); - -impl CublasLt { - pub fn new(device: &Device) -> Result { - let dev = match device { - Device::Cuda(d) => d, - _ => candle::bail!("`device` must be a `cuda` device"), - }; - - let inner = CudaBlasLT::new(dev.cuda_device()).unwrap(); - - Ok(Self(Arc::new(inner))) - } -} - -pub struct CublasLTMatmul { - pub cublaslt: Arc, - pub act: Option, - pub c: Option, - pub alpha: Option, - pub beta: Option, -} - -impl CublasLTMatmul { - pub fn fwd_f16( - &self, - a: &candle::CudaStorage, - a_l: &Layout, - b: &candle::CudaStorage, - b_l: &Layout, - bias: Option<&candle::CudaStorage>, - bias_l: Option<&Layout>, - ) -> Result<(candle::CudaStorage, Shape)> { - let dev = a.device(); - - // Assume TN - let (m, k) = a_l.shape().dims2()?; - - let (n, b_1) = b_l.shape().dims2()?; - - if b_1 != k { - candle::bail!("This layer only supports TN layout"); - } - - let lda = k; - let ldb = k; - let ldc = m; - - let out_shape = Shape::from((n, m)); - - let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); - let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); - - let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { - if bias_l.shape().dims1()? != m { - candle::bail!("Bias does not have the correct shape"); - } - - Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) - } else { - None - }; - - let mut out = if let Some(c) = &self.c { - let (c, c_l) = c.storage_and_layout(); - let c = match &*c { - Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle::bail!("`c` must be a cuda tensor"), - }; - match c_l.contiguous_offsets() { - Some((o1, o2)) => { - if o1 != 0 { - candle::bail!("`c` start offset must be 0"); - } - if o2 != out_shape.elem_count() { - candle::bail!("`c` end offset must be {}", out_shape.elem_count()) - } - } - None => candle::bail!("`c` has to be contiguous"), - }; - if c_l.shape().dims2()? != (n, m) { - candle::bail!("`c` does not have the correct shape"); - } - - c.clone() - } else { - // Allocate out tensor - unsafe { dev.alloc::(out_shape.elem_count()).w()? 
} - }; - - let config = MatmulConfig { - transa: true, - transb: false, - m: m as u64, - n: n as u64, - k: k as u64, - alpha: self.alpha.unwrap_or(1.0), - lda: lda as i64, - ldb: ldb as i64, - beta: self.beta.unwrap_or(0.0), - ldc: ldc as i64, - stride_a: None, - stride_b: None, - stride_c: None, - stride_bias: None, - batch_size: None, - }; - - unsafe { - self.cublaslt - .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle::Error::Cuda(Box::new(e)))?; - } - - let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); - - Ok((out, out_shape)) - } - - pub fn fwd_bf16( - &self, - a: &candle::CudaStorage, - a_l: &Layout, - b: &candle::CudaStorage, - b_l: &Layout, - bias: Option<&candle::CudaStorage>, - bias_l: Option<&Layout>, - ) -> Result<(candle::CudaStorage, Shape)> { - let dev = a.device(); - - // Assume TN - let (m, k) = a_l.shape().dims2()?; - - let (n, b_1) = b_l.shape().dims2()?; - - if b_1 != k { - candle::bail!("This layer only supports TN layout"); - } - - let lda = k; - let ldb = k; - let ldc = m; - - let out_shape = Shape::from((n, m)); - - let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); - let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); - - let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { - if bias_l.shape().dims1()? != m { - candle::bail!("Bias does not have the correct shape"); - } - - Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) - } else { - None - }; - - let mut out = if let Some(c) = &self.c { - let (c, c_l) = c.storage_and_layout(); - let c = match &*c { - Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle::bail!("`c` must be a cuda tensor"), - }; - match c_l.contiguous_offsets() { - Some((o1, o2)) => { - if o1 != 0 { - candle::bail!("`c` start offset must be 0"); - } - if o2 != out_shape.elem_count() { - candle::bail!("`c` end offset must be {}", out_shape.elem_count()) - } - } - None => candle::bail!("`c` has to be contiguous"), - }; - if c_l.shape().dims2()? != (n, m) { - candle::bail!("`c` does not have the correct shape"); - } - - c.clone() - } else { - // Allocate out tensor - unsafe { dev.alloc::(out_shape.elem_count()).w()? } - }; - - let config = MatmulConfig { - transa: true, - transb: false, - m: m as u64, - n: n as u64, - k: k as u64, - alpha: self.alpha.unwrap_or(1.0), - lda: lda as i64, - ldb: ldb as i64, - beta: self.beta.unwrap_or(0.0), - ldc: ldc as i64, - stride_a: None, - stride_b: None, - stride_c: None, - stride_bias: None, - batch_size: None, - }; - - unsafe { - self.cublaslt - .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle::Error::Cuda(Box::new(e)))?; - } - - let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); - - Ok((out, out_shape)) - } - - pub fn fwd_f32( - &self, - a: &candle::CudaStorage, - a_l: &Layout, - b: &candle::CudaStorage, - b_l: &Layout, - bias: Option<&candle::CudaStorage>, - bias_l: Option<&Layout>, - ) -> Result<(candle::CudaStorage, Shape)> { - let dev = a.device(); - - // Assume TN - let (m, k) = a_l.shape().dims2()?; - - let (n, b_1) = b_l.shape().dims2()?; - - if b_1 != k { - candle::bail!("This layer only supports TN layout"); - } - - let lda = k; - let ldb = k; - let ldc = m; - - let out_shape = Shape::from((n, m)); - - let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); - let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); - - let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { - if bias_l.shape().dims1()? 
!= m { - candle::bail!("Bias does not have the correct shape"); - } - - Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) - } else { - None - }; - - let mut out = if let Some(c) = &self.c { - let (c, c_l) = c.storage_and_layout(); - let c = match &*c { - Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle::bail!("`c` must be a cuda tensor"), - }; - match c_l.contiguous_offsets() { - Some((o1, o2)) => { - if o1 != 0 { - candle::bail!("`c` start offset must be 0"); - } - if o2 != out_shape.elem_count() { - candle::bail!("`c` end offset must be {}", out_shape.elem_count()) - } - } - None => candle::bail!("`c` has to be contiguous"), - }; - if c_l.shape().dims2()? != (n, m) { - candle::bail!("`c` does not have the correct shape"); - } - - c.clone() - } else { - // Allocate out tensor - unsafe { dev.alloc::(out_shape.elem_count()).w()? } - }; - - let config = MatmulConfig { - transa: true, - transb: false, - m: m as u64, - n: n as u64, - k: k as u64, - alpha: self.alpha.unwrap_or(1.0), - lda: lda as i64, - ldb: ldb as i64, - beta: self.beta.unwrap_or(0.0), - ldc: ldc as i64, - stride_a: None, - stride_b: None, - stride_c: None, - stride_bias: None, - batch_size: None, - }; - - unsafe { - self.cublaslt - .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle::Error::Cuda(Box::new(e)))?; - } - - let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); - - Ok((out, out_shape)) - } -} - -impl candle::CustomOp2 for CublasLTMatmul { - fn name(&self) -> &'static str { - "cublaslt-matmul" - } - - fn cpu_fwd( - &self, - _: &CpuStorage, - _: &Layout, - _: &CpuStorage, - _: &Layout, - ) -> Result<(CpuStorage, Shape)> { - candle::bail!("no cpu support for cublaslt-matmul") - } - - fn cuda_fwd( - &self, - a: &candle::CudaStorage, - a_l: &Layout, - b: &candle::CudaStorage, - b_l: &Layout, - ) -> Result<(candle::CudaStorage, Shape)> { - match a.dtype() { - candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), - candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), - candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), - dt => candle::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"), - } - } -} - -impl candle::CustomOp3 for CublasLTMatmul { - fn name(&self) -> &'static str { - "cublaslt-matmul-add" - } - - fn cpu_fwd( - &self, - _: &CpuStorage, - _: &Layout, - _: &CpuStorage, - _: &Layout, - _: &CpuStorage, - _: &Layout, - ) -> Result<(CpuStorage, Shape)> { - candle::bail!("no cpu support for cublaslt-matmul") - } - - fn cuda_fwd( - &self, - a: &candle::CudaStorage, - a_l: &Layout, - b: &candle::CudaStorage, - b_l: &Layout, - bias: &candle::CudaStorage, - bias_l: &Layout, - ) -> Result<(candle::CudaStorage, Shape)> { - match a.dtype() { - candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), - candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), - candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), - dt => candle::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"), - } - } -} - -/// Fused matmul + add + Relu/Gelu activation using CublasLt -/// -/// # Arguments -/// -/// * `a` - Input tensor of size MxK -/// * `b` - Input tensor of size NxK -/// * `out` - Optional Output tensor of size NxK. 
-/// If set and beta != 0, will be added to the end result of A*B before `act` -/// * `alpha` - Optional scaling factor for A*B -/// * `beta` - Optional scaling factor for C -/// * `bias` - Optional bias tensor of size M -/// * `act` - Optional Gelu or Relu activation. If set, will be added to the end result -/// * `cublaslt` - CublasLt handle -/// -/// The resulting tensor is of shape NxM -#[allow(clippy::too_many_arguments)] -pub fn fused_matmul( - a: &Tensor, - b: &Tensor, - out: Option<&Tensor>, - alpha: Option, - beta: Option, - bias: Option<&Tensor>, - act: Option, - cublaslt: CublasLt, -) -> Result { - let op = CublasLTMatmul { - act, - cublaslt: cublaslt.0, - c: out.cloned(), - alpha, - beta, - }; - - if let Some(bias) = bias { - a.apply_op3(b, bias, op) - } else { - a.apply_op2(b, op) - } -} - -pub struct CublasLTBatchMatmul { - pub cublaslt: Arc, - pub act: Option, - pub c: Option, - pub alpha: Option, - pub beta: Option, -} - -impl CublasLTBatchMatmul { - pub fn fwd_f16( - &self, - a: &candle::CudaStorage, - a_l: &Layout, - b: &candle::CudaStorage, - b_l: &Layout, - bias: Option<&candle::CudaStorage>, - bias_l: Option<&Layout>, - ) -> Result<(candle::CudaStorage, Shape)> { - let dev = a.device(); - - // Assume TN - let (batch_size, m, k) = a_l.shape().dims3()?; - let (b_0, n, b_2) = b_l.shape().dims3()?; - - if b_2 != k { - candle::bail!("This layer only supports TN layout"); - } - - if b_0 != batch_size { - candle::bail!("`b` must have the same batch size as `a`") - } - - let lda = k; - let ldb = k; - let ldc = m; - - let out_shape = Shape::from((batch_size, n, m)); - - let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); - let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); - - let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { - if bias_l.shape().dims1()? != m { - candle::bail!("Bias does not have the correct shape"); - } - - Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) - } else { - None - }; - - let (mut out, stride_c) = if let Some(c) = &self.c { - let (c, c_l) = c.storage_and_layout(); - let c = match &*c { - Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle::bail!("`c` must be a cuda tensor"), - }; - match c_l.contiguous_offsets() { - Some((o1, o2)) => { - if o1 != 0 { - candle::bail!("`c` start offset must be 0"); - } - if o2 != out_shape.elem_count() { - candle::bail!("`c` end offset must be {}", out_shape.elem_count()) - } - } - None => candle::bail!("`c` has to be contiguous"), - }; - - if c_l.shape().dims3()? != (batch_size, n, m) { - candle::bail!("`c` does not have the correct shape"); - } - - // Set beta to 0.0 if it is not set - (c.clone(), c_l.stride()[0]) - } else { - // Allocate out tensor - ( - unsafe { dev.alloc::(out_shape.elem_count()).w()? 
}, - (n * m), - ) - }; - - let config = MatmulConfig { - transa: true, - transb: false, - m: m as u64, - n: n as u64, - k: k as u64, - alpha: self.alpha.unwrap_or(1.0), - lda: lda as i64, - ldb: ldb as i64, - beta: self.beta.unwrap_or(0.0), - ldc: ldc as i64, - stride_a: Some(a_l.stride()[0] as i64), - stride_b: Some(b_l.stride()[0] as i64), - stride_c: Some(stride_c as i64), - stride_bias: None, - batch_size: Some(c_int::try_from(batch_size)?), - }; - - unsafe { - self.cublaslt - .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle::Error::Cuda(Box::new(e)))?; - } - - let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); - - Ok((out, out_shape)) - } - - pub fn fwd_bf16( - &self, - a: &candle::CudaStorage, - a_l: &Layout, - b: &candle::CudaStorage, - b_l: &Layout, - bias: Option<&candle::CudaStorage>, - bias_l: Option<&Layout>, - ) -> Result<(candle::CudaStorage, Shape)> { - let dev = a.device(); - - // Assume TN - let (batch_size, m, k) = a_l.shape().dims3()?; - let (b_0, n, b_2) = b_l.shape().dims3()?; - - if b_2 != k { - candle::bail!("This layer only supports TN layout"); - } - - if b_0 != batch_size { - candle::bail!("`b` must have the same batch size as `a`") - } - - let lda = k; - let ldb = k; - let ldc = m; - - let out_shape = Shape::from((batch_size, n, m)); - - let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); - let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); - - let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { - if bias_l.shape().dims1()? != m { - candle::bail!("Bias does not have the correct shape"); - } - - Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) - } else { - None - }; - - let (mut out, stride_c) = if let Some(c) = &self.c { - let (c, c_l) = c.storage_and_layout(); - let c = match &*c { - Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle::bail!("`c` must be a cuda tensor"), - }; - match c_l.contiguous_offsets() { - Some((o1, o2)) => { - if o1 != 0 { - candle::bail!("`c` start offset must be 0"); - } - if o2 != out_shape.elem_count() { - candle::bail!("`c` end offset must be {}", out_shape.elem_count()) - } - } - None => candle::bail!("`c` has to be contiguous"), - }; - - if c_l.shape().dims3()? != (batch_size, n, m) { - candle::bail!("`c` does not have the correct shape"); - } - - // Set beta to 0.0 if it is not set - (c.clone(), c_l.stride()[0]) - } else { - // Allocate out tensor - ( - unsafe { dev.alloc::(out_shape.elem_count()).w()? 
}, - (n * m), - ) - }; - - let config = MatmulConfig { - transa: true, - transb: false, - m: m as u64, - n: n as u64, - k: k as u64, - alpha: self.alpha.unwrap_or(1.0), - lda: lda as i64, - ldb: ldb as i64, - beta: self.beta.unwrap_or(0.0), - ldc: ldc as i64, - stride_a: Some(a_l.stride()[0] as i64), - stride_b: Some(b_l.stride()[0] as i64), - stride_c: Some(stride_c as i64), - stride_bias: None, - batch_size: Some(c_int::try_from(batch_size)?), - }; - - unsafe { - self.cublaslt - .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle::Error::Cuda(Box::new(e)))?; - } - - let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); - - Ok((out, out_shape)) - } - - pub fn fwd_f32( - &self, - a: &candle::CudaStorage, - a_l: &Layout, - b: &candle::CudaStorage, - b_l: &Layout, - bias: Option<&candle::CudaStorage>, - bias_l: Option<&Layout>, - ) -> Result<(candle::CudaStorage, Shape)> { - let dev = a.device(); - - // Assume TN - let (batch_size, m, k) = a_l.shape().dims3()?; - let (b_0, n, b_2) = b_l.shape().dims3()?; - - if b_2 != k { - candle::bail!("This layer only supports TN layout"); - } - - if b_0 != batch_size { - candle::bail!("`b` must have the same batch size as `a`") - } - - let lda = k; - let ldb = k; - let ldc = m; - - let out_shape = Shape::from((batch_size, n, m)); - - let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); - let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); - - let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { - if bias_l.shape().dims1()? != m { - candle::bail!("Bias does not have the correct shape"); - } - - Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) - } else { - None - }; - - let (mut out, stride_c) = if let Some(c) = &self.c { - let (c, c_l) = c.storage_and_layout(); - let c = match &*c { - Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle::bail!("`c` must be a cuda tensor"), - }; - match c_l.contiguous_offsets() { - Some((o1, o2)) => { - if o1 != 0 { - candle::bail!("`c` start offset must be 0"); - } - if o2 != out_shape.elem_count() { - candle::bail!("`c` end offset must be {}", out_shape.elem_count()) - } - } - None => candle::bail!("`c` has to be contiguous"), - }; - - if c_l.shape().dims3()? != (batch_size, n, m) { - candle::bail!("`c` does not have the correct shape"); - } - - // Set beta to 0.0 if it is not set - (c.clone(), c_l.stride()[0]) - } else { - // Allocate out tensor - ( - unsafe { dev.alloc::(out_shape.elem_count()).w()? 
}, - (n * m), - ) - }; - - let config = MatmulConfig { - transa: true, - transb: false, - m: m as u64, - n: n as u64, - k: k as u64, - alpha: self.alpha.unwrap_or(1.0), - lda: lda as i64, - ldb: ldb as i64, - beta: self.beta.unwrap_or(0.0), - ldc: ldc as i64, - stride_a: Some(a_l.stride()[0] as i64), - stride_b: Some(b_l.stride()[0] as i64), - stride_c: Some(stride_c as i64), - stride_bias: None, - batch_size: Some(c_int::try_from(batch_size)?), - }; - - unsafe { - self.cublaslt - .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle::Error::Cuda(Box::new(e)))?; - } - - let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); - - Ok((out, out_shape)) - } -} - -impl candle::CustomOp2 for CublasLTBatchMatmul { - fn name(&self) -> &'static str { - "cublaslt-batch-matmul" - } - - fn cpu_fwd( - &self, - _: &CpuStorage, - _: &Layout, - _: &CpuStorage, - _: &Layout, - ) -> Result<(CpuStorage, Shape)> { - candle::bail!("no cpu support for cublaslt-batch-matmul") - } - - fn cuda_fwd( - &self, - a: &candle::CudaStorage, - a_l: &Layout, - b: &candle::CudaStorage, - b_l: &Layout, - ) -> Result<(candle::CudaStorage, Shape)> { - match a.dtype() { - candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), - candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), - candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), - dt => { - candle::bail!("cublaslt-batch-matmul is only supported for f16/bf16/f32 ({dt:?})") - } - } - } -} - -impl candle::CustomOp3 for CublasLTBatchMatmul { - fn name(&self) -> &'static str { - "cublaslt-batch-matmul-add" - } - - fn cpu_fwd( - &self, - _: &CpuStorage, - _: &Layout, - _: &CpuStorage, - _: &Layout, - _: &CpuStorage, - _: &Layout, - ) -> Result<(CpuStorage, Shape)> { - candle::bail!("no cpu support for cublaslt-batch-matmul-add") - } - - fn cuda_fwd( - &self, - a: &candle::CudaStorage, - a_l: &Layout, - b: &candle::CudaStorage, - b_l: &Layout, - bias: &candle::CudaStorage, - bias_l: &Layout, - ) -> Result<(candle::CudaStorage, Shape)> { - match a.dtype() { - candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), - candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), - candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), - dt => candle::bail!( - "cublaslt-batch-matmul-add is only supported for f16/bf16/f32 ({dt:?})" - ), - } - } -} - -/// Fused batch matmul + add + Relu/Gelu activation using CublasLt -/// -/// # Arguments -/// -/// * `a` - Input tensor of size BxMxK -/// * `b` - Input tensor of size BxNxK -/// * `out` - Optional Output tensor of size BxNxK. -/// If set and beta != 0, will be added to the end result of A*B before `act` -/// * `alpha` - Optional scaling factor for A*B -/// * `beta` - Optional scaling factor for C -/// * `bias` - Optional bias tensor of size M -/// * `act` - Optional Gelu or Relu activation. 
If set, will be added to the end result -/// * `cublaslt` - CublasLt handle -/// -/// The resulting tensor is of shape NxM -#[allow(clippy::too_many_arguments)] -pub fn fused_batch_matmul( - a: &Tensor, - b: &Tensor, - out: Option<&Tensor>, - alpha: Option, - beta: Option, - bias: Option<&Tensor>, - act: Option, - cublaslt: CublasLt, -) -> Result { - let op = CublasLTBatchMatmul { - act, - cublaslt: cublaslt.0, - c: out.cloned(), - alpha, - beta, - }; - - if let Some(bias) = bias { - a.apply_op3(b, bias, op) - } else { - a.apply_op2(b, op) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use candle::{DType, Device}; - - fn to_vec2_round(t: Tensor, digits: i32) -> Result>> { - let b = 10f32.powi(digits); - let t = t.to_vec2::()?; - let t = t - .iter() - .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) - .collect(); - Ok(t) - } - - fn to_vec3_round(t: Tensor, digits: i32) -> Result>>> { - let b = 10f32.powi(digits); - let t = t.to_vec3::()?; - let t = t - .iter() - .map(|t| { - t.iter() - .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) - .collect() - }) - .collect(); - Ok(t) - } - - #[test] - fn test_fused_matmul() -> Result<()> { - let device = Device::new_cuda(0)?; - - let a = Tensor::randn(0., 1., (8, 4), &device)?.to_dtype(DType::F32)?; - let b = Tensor::randn(0., 1., (2, 4), &device)?.to_dtype(DType::F32)?; - let bias = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; - - let cublaslt = CublasLt::new(&device)?; - - let res = fused_matmul(&a, &b, None, None, None, Some(&bias), None, cublaslt)?; - let expected = (b.matmul(&a.t()?)? + bias.broadcast_left(2)?)?; - - assert_eq!( - to_vec2_round(res.to_dtype(DType::F32)?, 4)?, - to_vec2_round(expected.to_dtype(DType::F32)?, 4)? - ); - Ok(()) - } - - #[test] - fn test_fused_batch_matmul() -> Result<()> { - let device = Device::new_cuda(0)?; - - let a = Tensor::randn(0., 1., (3, 8, 4), &device)?.to_dtype(DType::F32)?; - let b = Tensor::randn(0., 1., (3, 2, 4), &device)?.to_dtype(DType::F32)?; - let c = Tensor::randn(0., 1., (3, 2, 8), &device)?.to_dtype(DType::F32)?; - let bias = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; - - let cublaslt = CublasLt::new(&device)?; - - let res = fused_batch_matmul( - &a, - &b, - Some(&c), - None, - Some(1.0), - Some(&bias), - None, - cublaslt, - )?; - let expected = (b.matmul(&a.t()?)?.add(&c)? + bias.broadcast_left((3, 2))?)?; - - assert_eq!( - to_vec3_round(res.to_dtype(DType::F32)?, 4)?, - to_vec3_round(expected.to_dtype(DType::F32)?, 4)? - ); - Ok(()) - } -} diff --git a/candle-nn/src/cublaslt/mod.rs b/candle-nn/src/cublaslt/mod.rs deleted file mode 100644 index e46aaabf67..0000000000 --- a/candle-nn/src/cublaslt/mod.rs +++ /dev/null @@ -1,100 +0,0 @@ -//! This module inspired from: -//! -//! https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs - -#![allow(unused_variables, unused_imports, dead_code)] - -use crate::Activation as CandleActivation; -use candle::{Device, Result, Tensor}; -use std::sync::{LazyLock, Mutex, Once}; - -#[cfg(feature = "cuda")] -mod api; - -#[cfg(feature = "cuda")] -use api::{fused_batch_matmul, fused_matmul, Activation, CublasLt}; - -static INIT: Once = Once::new(); -static mut CUBLASLT: Option = None; -pub(crate) static CUBLASLT_HANDLE: LazyLock>> = - LazyLock::new(|| Mutex::new(None)); - -/// Internal function to initialize the cublaslt handle wrapper, behind a LazyLock so initialization occurs -/// only once. 
-pub(crate) fn setup_cublas_lt_wrapper() { - unsafe { - INIT.call_once(|| { - #[cfg(not(feature = "cuda"))] - { - CUBLASLT = None; - } - - #[cfg(feature = "cuda")] - { - // Check if we can call the driver - // Then check if we can create a device - // Then check that the device is CUDA - use candle::cuda_backend::cudarc::driver; - CUBLASLT = driver::result::init() - .ok() - .and_then(|_| Device::cuda_if_available(0).ok()) - .and_then(|device| match device { - Device::Cuda(_) => Some(CublasLtWrapper { - cublaslt: CublasLt::new(&device).unwrap(), - }), - _ => None, - }); - } - }); - let cublaslt: Option<&'static CublasLtWrapper> = CUBLASLT.as_ref(); - *CUBLASLT_HANDLE.lock().unwrap() = cublaslt; - } -} - -#[derive(Debug, Clone)] -pub struct CublasLtWrapper { - #[cfg(feature = "cuda")] - pub cublaslt: CublasLt, -} - -impl CublasLtWrapper { - #[allow(clippy::too_many_arguments)] - pub fn batch_matmul( - &self, - a: &Tensor, - b: &Tensor, - out: Option<&Tensor>, - alpha: Option, - beta: Option, - bias: Option<&Tensor>, - act: Option, - ) -> Result { - #[cfg(feature = "cuda")] - { - let inner_act = act.map(|a| match a { - CandleActivation::Relu => Activation::Relu, - CandleActivation::Gelu => Activation::Gelu, - _ => unreachable!("Unsupported activation in cublaslt matmul"), - }); - let mut result = fused_batch_matmul( - a, - b, - out, - alpha, - beta, - bias, - inner_act, - self.cublaslt.clone(), - )?; - - if Some(CandleActivation::Swiglu) == act { - result = crate::ops::swiglu(&result)?; - } - Ok(result) - } - #[cfg(not(feature = "cuda"))] - { - candle::bail!("`cuda` feature is not enabled") - } - } -} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 34793cd8af..a4f5943ee3 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -2,7 +2,6 @@ pub mod activation; pub mod attention; pub mod batch_norm; pub mod conv; -pub mod cublaslt; pub mod embedding; pub mod encoding; pub mod func; From 9f144d63af4ec441b8a39a6ba16a555a1379f1c4 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 4 Aug 2024 19:31:50 -0400 Subject: [PATCH 24/75] Add matmul_alpha --- candle-core/src/backend.rs | 10 + candle-core/src/cpu_backend/mod.rs | 360 ++++++++++++++++++++++++- candle-core/src/cpu_backend/utils.rs | 34 +++ candle-core/src/cuda_backend/mod.rs | 69 +++++ candle-core/src/dummy_cuda_backend.rs | 11 + candle-core/src/dummy_metal_backend.rs | 11 + candle-core/src/error.rs | 8 + candle-core/src/metal_backend/mod.rs | 46 ++++ candle-core/src/storage.rs | 32 +++ candle-core/src/tensor.rs | 95 ++++++- candle-core/tests/matmul_tests.rs | 37 ++- 11 files changed, 700 insertions(+), 13 deletions(-) diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index b1e67921a9..6465bcdd27 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -109,6 +109,16 @@ pub trait BackendStorage: Sized { _: &Layout, ) -> Result<()>; + #[allow(clippy::too_many_arguments)] + fn matmul_with_alpha( + &self, + _: &Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + ) -> Result; + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>; #[allow(clippy::too_many_arguments)] diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 466ab5195b..95adbd2a59 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -8,7 +8,8 @@ use rayon::prelude::*; mod utils; pub use utils::{ - binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8, Map3, + 
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2Alpha, Map2U8, + Map3, }; const USE_IM2COL_CONV1D: bool = true; @@ -1542,7 +1543,7 @@ impl Deref for MatMulWithBias { } impl Map3 for MatMulWithBias { - const OP: &'static str = "mat_mul_c"; + const OP: &'static str = "mat_mul_ac"; #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))] fn f( @@ -1644,7 +1645,7 @@ impl Map3 for MatMulWithBias { c: &mut [T], c_l: &Layout, s: Option, - ) -> Result> { + ) -> Result<()> { let (b, m, n, k) = self.0; let lhs = &lhs[lhs_l.start_offset()..]; let rhs = &rhs[rhs_l.start_offset()..]; @@ -1754,7 +1755,7 @@ impl Map3 for MatMulWithBias { } dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, } - Ok(dst) + Ok(()) } #[cfg(feature = "mkl")] @@ -1767,7 +1768,7 @@ impl Map3 for MatMulWithBias { c: &mut [T], c_l: &Layout, s: Option, - ) -> Result> { + ) -> Result<()> { let (b, m, n, k) = self.0; let lhs = &lhs[lhs_l.start_offset()..]; let rhs = &rhs[rhs_l.start_offset()..]; @@ -1811,6 +1812,344 @@ impl Map3 for MatMulWithBias { None => crate::bail!("`c` has to be contiguous"), }; + match T::DTYPE { + DType::F16 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f16; + let b = lhs_p.as_ptr() as *const f16; + let c = dst_p.as_mut_ptr() as *mut f16; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::hgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ f16::from_f64(s.unwrap_or(1.)), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ f16::ONE, + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) 
as f32, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(()) + } +} + +struct MatMulWithAlpha(MatMul); + +impl Deref for MatMulWithAlpha { + type Target = MatMul; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Map2Alpha for MatMulWithAlpha { + const OP: &'static str = "mat_mul_a"; + + #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + s: Option, + ) -> Result> { + use gemm::{gemm, Parallelism}; + + match T::DTYPE { + DType::F16 | DType::F32 | DType::F64 => {} + _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?, + } + + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rank = lhs_stride.len(); + let lhs_cs = lhs_stride[rank - 1]; + let lhs_rs = lhs_stride[rank - 2]; + + let rhs_cs = rhs_stride[rank - 1]; + let rhs_rs = rhs_stride[rank - 2]; + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let dst_shape: Shape = (m, n).into(); + let dst_strides = dst_shape.stride_contiguous(); + let dst_rs = dst_strides[0]; + let dst_cs = dst_strides[1]; + + let mut dst = vec![T::zero(); b * m * n]; + let num_threads = crate::utils::get_num_threads(); + let parallelism = if num_threads > 1 { + Parallelism::Rayon(num_threads) + } else { + Parallelism::None + }; + + let alpha = T::from_f64(s.unwrap_or(1.0)); + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + gemm( + /* m: usize = */ m, + /* n: usize = */ n, + /* k: usize = */ k, + /* dst: *mut T = */ dst_p.as_mut_ptr(), + /* dst_cs: isize = */ dst_cs as isize, + /* dst_rs: isize = */ dst_rs as isize, + /* read_dst: bool = */ true, + /* lhs: *const T = */ lhs_p.as_ptr(), + /* lhs_cs: isize = */ lhs_cs as isize, + /* lhs_rs: isize = */ lhs_rs as isize, + /* rhs: *const T = */ rhs_p.as_ptr(), + /* rhs_cs: isize = */ rhs_cs as isize, + /* rhs_rs: isize = */ rhs_rs as isize, + /* alpha: T = */ T::one(), + /* beta: T = */ alpha, + /* conj_dst: bool = */ false, + /* conj_lhs: bool = */ false, + /* conj_rhs: bool = */ false, + parallelism, + ) + } + } + Ok(dst) + } + + #[cfg(feature = "accelerate")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + s: Option, + ) -> Result> { + let (b, m, n, k) = self.0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = 
self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? + }; + + let mut dst = vec![T::zero(); b * m * n]; + match T::DTYPE { + DType::F16 => { + crate::bail!("the accelerate backend does not support f16 matmul") + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(dst) + } + + #[cfg(feature = "mkl")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + s: Option, + ) -> Result> { + let (b, m, n, k) = self.0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? 
+ }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? + }; + + let mut dst = vec![T::zero(); b * m * n]; match T::DTYPE { DType::F16 => { for step in 0..b { @@ -2825,6 +3164,17 @@ impl BackendStorage for CpuStorage { MatMulWithBias(MatMul(bmnk)).map(self, lhs_l, rhs, rhs_l, c, c_l, s) } + fn matmul_with_alpha( + &self, + rhs: &Self, + s: Option, + bmnk: (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + MatMulWithAlpha(MatMul(bmnk)).map(self, lhs_l, rhs, rhs_l, s) + } + fn device(&self) -> &Self::Device { &CpuDevice } diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 1566f6ebfd..9d38729145 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -83,6 +83,7 @@ pub trait Map3 { l3: &Layout, s: Option, ) -> Result<()> { + let v3d = v3.dtype(); match (v1, v2, v3) { (C::U8(v1), C::U8(v2), C::U8(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), (C::U32(v1), C::U32(v2), C::U32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), @@ -91,6 +92,39 @@ pub trait Map3 { (C::F16(v1), C::F16(v2), C::F16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), (C::F32(v1), C::F32(v2), C::F32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), (C::F64(v1), C::F64(v2), C::F64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + _ => Err(Error::DTypeMismatchBinaryOp3 { + lhs: v1.dtype(), + rhs: v2.dtype(), + c: v3d, + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map2Alpha { + const OP: &'static str; + #[allow(clippy::too_many_arguments)] + fn f( + &self, + v1: &[T], + l1: &Layout, + v2: &[T], + l2: &Layout, + s: Option, + ) -> Result>; + + #[allow(clippy::too_many_arguments)] + fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout, s: Option) -> Result { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2, s)?)), + (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2, s)?)), + (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2, s)?)), + (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2, s)?)), + (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2, s)?)), + (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2, s)?)), + (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2, s)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 856972adae..954cedc041 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1779,6 +1779,75 @@ impl BackendStorage for CudaStorage { Ok(()) } + fn matmul_with_alpha( + &self, + rhs: &Self, + scale: Option, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let elem_count = b * m * n; + let dev = &self.device; + let slice = match (&self.slice, &rhs.slice) { + (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config( + bf16::from_f64(scale.unwrap_or(1.)), + bf16::ZERO, + (b, m, n, k), + lhs_l, + rhs_l, + )?; + let mut out = unsafe { dev.alloc::(elem_count) }.w()?; + unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) } + .w()?; + 
CudaStorageSlice::BF16(out) + } + (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config( + f16::from_f64(scale.unwrap_or(1.)), + f16::ZERO, + (b, m, n, k), + lhs_l, + rhs_l, + )?; + let mut out = unsafe { dev.alloc::(elem_count) }.w()?; + unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) } + .w()?; + CudaStorageSlice::F16(out) + } + (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(scale.unwrap_or(1.) as f32, 0., (b, m, n, k), lhs_l, rhs_l)?; + let mut out = unsafe { dev.alloc::(elem_count) }.w()?; + unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) } + .w()?; + CudaStorageSlice::F32(out) + } + (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(scale.unwrap_or(1.), 0., (b, m, n, k), lhs_l, rhs_l)?; + let mut out = unsafe { dev.alloc::(elem_count) }.w()?; + unsafe { + self.device + .blas + .gemm_strided_batched(cfg, rhs, lhs, &mut out) + } + .w()?; + CudaStorageSlice::F64(out) + } + _ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?, + }; + let device = dev.clone(); + Ok(Self { slice, device }) + } + fn copy2d( &self, dst: &mut Self, diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 3035a02976..353498b68c 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -163,6 +163,17 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + fn matmul_with_alpha( + &self, + _: &Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + ) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index dd19ec5cc7..07a63d135e 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -175,6 +175,17 @@ impl crate::backend::BackendStorage for MetalStorage { Err(Error::NotCompiledWithMetalSupport) } + fn matmul_with_alpha( + &self, + _: &Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { Err(Error::NotCompiledWithMetalSupport) } diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 2aacfd54ef..67b391451b 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -26,6 +26,14 @@ pub enum Error { op: &'static str, }, + #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}, c: {rhs:?}")] + DTypeMismatchBinaryOp3 { + lhs: DType, + rhs: DType, + c: DType, + op: &'static str, + }, + #[error("unsupported dtype {0:?} for op {1}")] UnsupportedDTypeForOp(DType, &'static str), diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index b02edac825..039bb62d95 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1398,6 +1398,7 @@ impl BackendStorage for MetalStorage { 
.map_err(MetalError::from)?; Ok(acc) } + fn matmul( &self, rhs: &Self, @@ -1497,6 +1498,51 @@ impl BackendStorage for MetalStorage { Ok(()) } + fn matmul_with_alpha( + &self, + rhs: &Self, + s: Option, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?; + let name = match self.dtype { + DType::F32 => "sgemm", + DType::F16 => "hgemm", + DType::BF16 => "bgemm", + dtype => { + return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) + } + }; + + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("matmul"); + candle_metal_kernels::call_gemm( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &self.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &buffer, + s.unwrap_or(1.) as f32, + 0., + ) + .map_err(MetalError::from)?; + Ok(Self::new( + buffer, + self.device.clone(), + b * m * n, + self.dtype(), + )) + } + fn copy2d( &self, dst: &mut Self, diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index fe635c40f9..c9d9b5c3ec 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -771,6 +771,38 @@ impl Storage { } } + pub(crate) fn matmul_with_alpha( + &self, + rhs: &Self, + s: Option, + bmnk: (usize, usize, usize, usize), + lhs_layout: &Layout, + rhs_layout: &Layout, + ) -> Result { + self.same_device(rhs, "matmul_with_alpha")?; + self.same_dtype(rhs, "matmul_with_alpha")?; + match (self, rhs) { + (Self::Cpu(lhs), Self::Cpu(rhs)) => { + let storage = lhs.matmul_with_alpha(rhs, s, bmnk, lhs_layout, rhs_layout)?; + Ok(Self::Cpu(storage)) + } + (Self::Cuda(lhs), Self::Cuda(rhs)) => { + let storage = lhs.matmul_with_alpha(rhs, s, bmnk, lhs_layout, rhs_layout)?; + Ok(Self::Cuda(storage)) + } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.matmul_with_alpha(rhs, s, bmnk, lhs_layout, rhs_layout)?; + Ok(Self::Metal(storage)) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "matmul", + } + .bt()), + } + } + // self, the source can be strided whereas dst is contiguous. pub(crate) fn copy_strided_src( &self, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index c65f1f5846..49f396c6f2 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1235,8 +1235,9 @@ impl Tensor { /// Else: /// `c := c + scale * (axb)` /// - /// - /// This is incompatible with gradient tracking. No gradients will be tracked on this operation. + /// This function is faster than a matmul followed by some scaling multiply because the scaling is fused in the GEMM kernel. + /// This is incompatible with gradient tracking. No gradients will be tracked on this operation. However, this also means + /// there is an allocation saved as the output is in `c`. /// /// # Arguments /// @@ -1287,7 +1288,7 @@ impl Tensor { Err(Error::ShapeMismatchBinaryOp { lhs: self.shape().clone(), rhs: rhs.shape().clone(), - op: "matmul", + op: "matmul_with_alpha_beta", } .bt())? } @@ -1303,6 +1304,94 @@ impl Tensor { ) } + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. The result is scaled. + /// + /// This function is faster than a matmul followed by some scaling multiply because the scaling is fused in the GEMM kernel. 
+ /// + /// The output is as follows: + /// `scale * (axb)` + /// + /// + /// This is incompatible with gradient tracking. No gradients will be tracked on this operation. + /// + /// # Arguments + /// + /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`. + /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`. + /// * `scale` - Factor to multiply `self` x `rhs` by. + pub fn matmul_with_alpha(&self, rhs: &Self, scale: Option) -> Result { + let a_dims = self.shape().dims(); + let b_dims = rhs.shape().dims(); + + let dim = a_dims.len(); + + if dim < 2 || b_dims.len() != dim { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul", + } + .bt())? + } + + let m = a_dims[dim - 2]; + let k = a_dims[dim - 1]; + let k2 = b_dims[dim - 2]; + let n = b_dims[dim - 1]; + + let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); + if c_shape.elem_count() == 0 || k == 0 { + return Tensor::zeros(c_shape, self.dtype(), self.device()); + } + let batching: usize = a_dims[..dim - 2].iter().product(); + let batching_b: usize = b_dims[..dim - 2].iter().product(); + if k != k2 || batching != batching_b { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul_with_alpha", + } + .bt())? + } + + let storage = self.storage().matmul_with_alpha( + &rhs.storage(), + scale, + (batching, m, n, k), + self.layout(), + rhs.layout(), + )?; + let op = BackpropOp::new2(self, rhs, Op::Matmul); + Ok(from_storage(storage, c_shape, op, false)) + } + + /// Matrix-multiplication with broadcasting support and fused scaling. + /// + /// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as + /// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has + /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`. + pub fn broadcast_matmul_with_alpha(&self, rhs: &Self, scale: Option) -> Result { + let lhs = self; + let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?; + let l_broadcast = l_shape != *lhs.shape(); + let r_broadcast = r_shape != *rhs.shape(); + // TODO: Avoid concretising the broadcasted matrixes via contiguous. + match (l_broadcast, r_broadcast) { + (true, true) => lhs + .broadcast_as(&l_shape)? + .contiguous()? + .matmul_with_alpha(&rhs.broadcast_as(&r_shape)?.contiguous()?, scale), + (false, true) => { + lhs.matmul_with_alpha(&rhs.broadcast_as(&r_shape)?.contiguous()?, scale) + } + (true, false) => lhs + .broadcast_as(&l_shape)? + .contiguous()? + .matmul_with_alpha(rhs, scale), + (false, false) => lhs.matmul_with_alpha(rhs, scale), + } + } + /// Returns a tensor with the same shape as the input tensor, the values are taken from /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the /// input tensor is equal to zero. 
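For reference, a minimal usage sketch of the two scaled-GEMM entry points added above, mirroring the tests in the following diff. The standalone `candle_core` crate path and the `scaled_matmul_demo` wrapper are illustrative only; the point is that the scale is fused into the GEMM instead of being applied as a separate elementwise multiply afterwards.

use candle_core::{Device, Result, Tensor};

fn scaled_matmul_demo() -> Result<()> {
    let device = Device::Cpu;
    let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], (2, 2), &device)?;
    let b = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], (2, 2), &device)?;

    // scale * (a x b), computed in a single GEMM call.
    let scaled = a.matmul_with_alpha(&b, Some(2.))?;
    assert_eq!(scaled.to_vec2::<f32>()?, &[[14.0f32, 20.0], [30.0, 44.0]]);

    // c := c + scale * (a x b); `c` is written in place, saving an allocation.
    let mut c = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0], (2, 2), &device)?;
    a.matmul_with_alpha_beta(&b, &mut c, Some(2.))?;
    assert_eq!(c.to_vec2::<f32>()?, &[[15.0f32, 21.0], [31.0, 45.0]]);
    Ok(())
}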
diff --git a/candle-core/tests/matmul_tests.rs b/candle-core/tests/matmul_tests.rs index c2af6012be..edca8e1561 100644 --- a/candle-core/tests/matmul_tests.rs +++ b/candle-core/tests/matmul_tests.rs @@ -109,7 +109,7 @@ fn mm_layout(device: &Device) -> Result<()> { Ok(()) } -fn matmul_beta(device: &Device) -> Result<()> { +fn matmul_alpha_beta(device: &Device) -> Result<()> { let data = vec![1.0f32, 2.0, 3.0, 4.0]; let a = Tensor::from_slice(&data, (2, 2), device)?; let data = vec![1.0f32, 2.0, 3.0, 4.0]; @@ -119,15 +119,42 @@ fn matmul_beta(device: &Device) -> Result<()> { a.matmul_with_alpha_beta(&b, &mut c, None)?; assert_eq!(c.to_vec2::()?, &[[8.0f32, 11.0], [16.0, 23.0]]); + + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 1.0, 1.0, 1.0]; + let mut c = Tensor::from_slice(&data, (2, 2), device)?; + + a.matmul_with_alpha_beta(&b, &mut c, Some(2.))?; + assert_eq!(c.to_vec2::()?, &[[15.0f32, 21.0], [31.0, 45.0]]); + Ok(()) +} + +fn matmul_alpha(device: &Device) -> Result<()> { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + + let c = a.matmul_with_alpha(&b, Some(2.))?; + assert_eq!(c.to_vec2::()?, &[[14.0f32, 20.0], [30.0, 44.0]]); Ok(()) } test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal); test_device!( - matmul_beta, - matmul_beta_cpu, - matmul_beta_gpu, - matmul_beta_metal + matmul_alpha_beta, + matmul_alpha_beta_cpu, + matmul_alpha_beta_gpu, + matmul_alpha_beta_metal +); +test_device!( + matmul_alpha, + matmul_alpha_cpu, + matmul_alpha_gpu, + matmul_alpha_metal ); test_device!( matmul_bf16, From c830f2632b1aba9fb950ee5c378281f48cfd3d1c Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 4 Aug 2024 19:34:36 -0400 Subject: [PATCH 25/75] Use matmul_with_alpha in sdpa --- candle-nn/src/attention.rs | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/candle-nn/src/attention.rs b/candle-nn/src/attention.rs index 525f3d9eb2..c594faffe8 100644 --- a/candle-nn/src/attention.rs +++ b/candle-nn/src/attention.rs @@ -28,11 +28,9 @@ pub fn scaled_dot_product_attention( q: &Tensor, k: &Tensor, v: &Tensor, - n_attn_heads: usize, - head_dim: usize, + scale: f64, mask: Option, use_flash_attn: bool, - b_sz: usize, seq_len: usize, ) -> Result { if use_flash_attn { @@ -40,20 +38,18 @@ pub fn scaled_dot_product_attention( let q = q.transpose(1, 2)?; let k = k.transpose(1, 2)?; let v = v.transpose(1, 2)?; - let softmax_scale = 1f32 / (head_dim as f32).sqrt(); - return flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2); + return flash_attn(&q, &k, &v, scale as f32, seq_len > 1)?.transpose(1, 2); } let att = match mask { Some(mut m) => { - q.contiguous()?.matmul_with_alpha_beta( - &k.t()?.contiguous()?, - &mut m, - Some(1. / (head_dim as f64).sqrt()), - )?; + q.contiguous()? + .matmul_with_alpha_beta(&k.t()?.contiguous()?, &mut m, Some(scale))?; m } - None => (&q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (head_dim as f64).sqrt())?, + None => q + .contiguous()? 
+ .matmul_with_alpha(&k.t()?.contiguous()?, Some(scale))?, }; let att = crate::ops::softmax_last_dim(&att)?; From 86d0876bf6bc6fd1ba826ae25c17ed7434180727 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 4 Aug 2024 20:13:45 -0400 Subject: [PATCH 26/75] Add it to mistral --- candle-nn/src/attention.rs | 13 +++++-- candle-transformers/src/models/mistral.rs | 45 ++++++----------------- 2 files changed, 20 insertions(+), 38 deletions(-) diff --git a/candle-nn/src/attention.rs b/candle-nn/src/attention.rs index c594faffe8..4a5c733710 100644 --- a/candle-nn/src/attention.rs +++ b/candle-nn/src/attention.rs @@ -42,10 +42,15 @@ pub fn scaled_dot_product_attention( } let att = match mask { - Some(mut m) => { - q.contiguous()? - .matmul_with_alpha_beta(&k.t()?.contiguous()?, &mut m, Some(scale))?; - m + Some(mask) => { + let (b, n, s, _h) = q.dims4()?; + let mut mask_and_output = mask.broadcast_as((b, n, s, s))?.contiguous()?; + q.contiguous()?.matmul_with_alpha_beta( + &k.t()?.contiguous()?, + &mut mask_and_output, + Some(scale), + )?; + mask_and_output } None => q .contiguous()? diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index 1cb55f9e61..e04cf72f24 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -1,7 +1,7 @@ use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm}; /// Mistral LLM, https://github.com/mistralai/mistral-src use candle::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::{Activation, VarBuilder}; +use candle_nn::{scaled_dot_product_attention, Activation, VarBuilder}; use std::sync::Arc; fn default_use_flash_attn() -> bool { @@ -157,22 +157,6 @@ impl Module for MLP { } } -#[cfg(feature = "flash-attn")] -fn flash_attn( - q: &Tensor, - k: &Tensor, - v: &Tensor, - softmax_scale: f32, - causal: bool, -) -> Result { - candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) -} - -#[cfg(not(feature = "flash-attn"))] -fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { - unimplemented!("compile with '--features flash-attn'") -} - #[derive(Debug, Clone)] struct Attention { q_proj: Linear, @@ -257,24 +241,17 @@ impl Attention { let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?; let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?; - let attn_output = if self.use_flash_attn { - // flash-attn expects (b_sz, seq_len, nheads, head_dim) - let q = query_states.transpose(1, 2)?; - let k = key_states.transpose(1, 2)?; - let v = value_states.transpose(1, 2)?; - let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); - flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)? - } else { - let scale = 1f64 / f64::sqrt(self.head_dim as f64); - let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + let scale = 1. / (self.head_dim as f64).sqrt(); + let attn_output = scaled_dot_product_attention( + &query_states, + &key_states, + &value_states, + scale, + attention_mask.cloned(), + self.use_flash_attn, + q_len, + )?; - let attn_weights = match attention_mask { - None => attn_weights, - Some(mask) => attn_weights.broadcast_add(mask)?, - }; - let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; - attn_weights.matmul(&value_states)? - }; attn_output .transpose(1, 2)? .reshape((b_sz, q_len, self.hidden_size))? 
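A minimal sketch of calling the new fused attention helper directly, using the (batch, heads, seq_len, head_dim) layout the mistral code above passes in. The `candle_core`/`candle_nn` crate paths, the `sdpa_demo` wrapper, and the random inputs are illustrative; note that at this point in the series the mask argument is taken by value and the following patch changes it to `Option<&Tensor>`, but passing `None` compiles either way.

use candle_core::{Device, Result, Tensor};
use candle_nn::scaled_dot_product_attention;

fn sdpa_demo() -> Result<()> {
    let device = Device::Cpu;
    let (b, h, seq, hd) = (1usize, 8, 16, 64);
    let q = Tensor::randn(0f32, 1., (b, h, seq, hd), &device)?;
    let k = Tensor::randn(0f32, 1., (b, h, seq, hd), &device)?;
    let v = Tensor::randn(0f32, 1., (b, h, seq, hd), &device)?;
    let scale = 1. / (hd as f64).sqrt();
    // No mask and no flash-attn: this exercises the fused matmul_with_alpha path.
    let out = scaled_dot_product_attention(&q, &k, &v, scale, None, false, seq)?;
    assert_eq!(out.dims4()?, (b, h, seq, hd));
    Ok(())
}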
From 8d8889c7771408f85e226b5a7bd6f3a2dd6cad5a Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 4 Aug 2024 20:31:48 -0400 Subject: [PATCH 27/75] Add it to q llama --- candle-nn/src/attention.rs | 2 +- candle-transformers/src/models/mistral.rs | 2 +- .../src/models/quantized_llama.rs | 34 +++++-------------- 3 files changed, 11 insertions(+), 27 deletions(-) diff --git a/candle-nn/src/attention.rs b/candle-nn/src/attention.rs index 4a5c733710..5b8de4388d 100644 --- a/candle-nn/src/attention.rs +++ b/candle-nn/src/attention.rs @@ -29,7 +29,7 @@ pub fn scaled_dot_product_attention( k: &Tensor, v: &Tensor, scale: f64, - mask: Option, + mask: Option<&Tensor>, use_flash_attn: bool, seq_len: usize, ) -> Result { diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index e04cf72f24..f7b70e6cbf 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -247,7 +247,7 @@ impl Attention { &key_states, &value_states, scale, - attention_mask.cloned(), + attention_mask, self.use_flash_attn, q_len, )?; diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 6b326fbe92..544bf8a456 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -4,7 +4,7 @@ use crate::quantized_nn::RmsNorm; use candle::quantized::QTensor; use candle::quantized::{ggml_file, gguf_file}; use candle::{DType, Device, IndexOp, Result, Tensor}; -use candle_nn::{Embedding, Module}; +use candle_nn::{scaled_dot_product_attention, Embedding, Module}; pub const MAX_SEQ_LEN: usize = 4096; @@ -138,19 +138,12 @@ struct LayerWeights { head_dim: usize, cos: Tensor, sin: Tensor, - neg_inf: Tensor, kv_cache: Option<(Tensor, Tensor)>, span_attn: tracing::Span, span_rot: tracing::Span, span_mlp: tracing::Span, } -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result { - let shape = mask.shape(); - let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; - Ok(m) -} - impl LayerWeights { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result { let _enter = self.span_rot.enter(); @@ -209,17 +202,10 @@ impl LayerWeights { let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?; - let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let att = match mask { - None => att, - Some(mask) => { - let mask = mask.broadcast_as(att.shape())?; - masked_fill(&att, &mask, &self.neg_inf)? - } - }; - let att = candle_nn::ops::softmax_last_dim(&att)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - let y = att.matmul(&v.contiguous()?)?; + let scale = 1. 
/ (self.head_dim as f64).sqrt(); + + let y = scaled_dot_product_attention(&q, &k, &v, scale, mask, false, seq_len)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; let y = self.attention_wo.forward(&y)?; Ok(y) @@ -260,7 +246,6 @@ impl ModelWeights { pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result { let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?; - let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?; let tok_embeddings = ct.remove("tok_embeddings.weight")?; let tok_embeddings = tok_embeddings.dequantize(&ct.device)?; let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?; @@ -300,7 +285,6 @@ impl ModelWeights { head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, cos: cos.clone(), sin: sin.clone(), - neg_inf: neg_inf.clone(), kv_cache: None, span_attn, span_rot, @@ -349,7 +333,6 @@ impl ModelWeights { .and_then(|m| m.to_f32()) .unwrap_or(10000f32); let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?; - let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; @@ -420,7 +403,6 @@ impl ModelWeights { head_dim: embedding_length / head_count, cos: cos.clone(), sin: sin.clone(), - neg_inf: neg_inf.clone(), kv_cache: None, span_attn, span_rot, @@ -445,9 +427,11 @@ impl ModelWeights { Ok(mask.clone()) } else { let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .flat_map(|i| (0..t).map(move |j| if i < j { f32::NEG_INFINITY } else { 0.0 })) .collect(); - let mask = Tensor::from_slice(&mask, (t, t), device)?; + let mask = Tensor::from_slice(&mask, (t, t), device)? + .expand((1, 1, t, t))? 
+ .to_dtype(DType::F32)?; self.masks.insert(t, mask.clone()); Ok(mask) } From d18eb13d55b4521e2d2c76be1230ecd10733c78b Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 5 Aug 2024 10:28:20 -0400 Subject: [PATCH 28/75] Add attention benches --- candle-core/benches/benchmarks/mod.rs | 4 +- candle-nn/benches/bench_main.rs | 7 +- candle-nn/benches/benchmarks/attention.rs | 111 ++++++++++++++++++++++ candle-nn/benches/benchmarks/mod.rs | 5 +- 4 files changed, 122 insertions(+), 5 deletions(-) create mode 100644 candle-nn/benches/benchmarks/attention.rs diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 579c5f3f0b..d52659045c 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -6,7 +6,7 @@ pub(crate) mod random; pub(crate) mod unary; pub(crate) mod where_cond; -use candle_core::{Device, Result}; +use candle_core::{cuda::WrapErr, Device, Result}; pub(crate) trait BenchDevice { fn sync(&self) -> Result<()>; @@ -20,7 +20,7 @@ impl BenchDevice for Device { Device::Cpu => Ok(()), Device::Cuda(device) => { #[cfg(feature = "cuda")] - return Ok(device.synchronize()?); + return Ok(device.synchronize().w()?); #[cfg(not(feature = "cuda"))] panic!("Cuda device without cuda feature enabled: {:?}", device) } diff --git a/candle-nn/benches/bench_main.rs b/candle-nn/benches/bench_main.rs index 4db1d35c0a..727479b5c9 100644 --- a/candle-nn/benches/bench_main.rs +++ b/candle-nn/benches/bench_main.rs @@ -1,4 +1,9 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::layer_norm::benches, benchmarks::conv::benches); +criterion_main!( + benchmarks::layer_norm::benches, + benchmarks::conv::benches, + benchmarks::attention::benches_fast, + benchmarks::attention::benches_naive +); diff --git a/candle-nn/benches/benchmarks/attention.rs b/candle-nn/benches/benchmarks/attention.rs new file mode 100644 index 0000000000..8aa479d319 --- /dev/null +++ b/candle-nn/benches/benchmarks/attention.rs @@ -0,0 +1,111 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle::{DType, Device, Tensor}; +use candle_nn::scaled_dot_product_attention; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run_attention(q: &Tensor, k: &Tensor, v: &Tensor, m: &Tensor, s: f64) { + let att = (q + .contiguous() + .unwrap() + .matmul(&k.t().unwrap().contiguous().unwrap()) + .unwrap() + / s) + .unwrap(); + + let att = att.broadcast_add(m).unwrap(); + + let att = candle_nn::ops::softmax_last_dim(&att).unwrap(); + // Convert to contiguous as matmul doesn't support strided vs for now. 
+ att.matmul(&v.contiguous().unwrap()).unwrap(); +} + +fn run_bench_naive(c: &mut Criterion, device: &Device) { + let b = 4; + let seq = 1024; + let heads = 32; + let hd = 128; + + let dtype = DType::F32; + let q = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let k = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let v = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let m = Tensor::zeros((b, heads, seq, seq), dtype, device).unwrap(); + + let flops = b * seq * heads * hd; + + let mut group = c.benchmark_group(device.bench_name("attention_naive")); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run_attention( + black_box(&q), + black_box(&k), + black_box(&v), + black_box(&m), + 0.3, + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark_naive(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_bench_naive(c, &device); + } +} + +fn run_bench_fast(c: &mut Criterion, device: &Device) { + let b = 4; + let seq = 1024; + let heads = 32; + let hd = 128; + + let dtype = DType::F32; + let q = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let k = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let v = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let m = Tensor::zeros((b, heads, seq, seq), dtype, device).unwrap(); + + let flops = b * seq * heads * hd; + + let mut group = c.benchmark_group(device.bench_name("attention_fast")); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + let _ = scaled_dot_product_attention( + black_box(&q), + black_box(&k), + black_box(&v), + 0.3, + Some(black_box(&m)), + false, + seq, + ) + .unwrap(); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark_fast(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_bench_fast(c, &device); + } +} + +criterion_group!(benches_naive, criterion_benchmark_naive); +criterion_group!(benches_fast, criterion_benchmark_fast); diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs index 30a6ab6a2b..2ef61c70a0 100644 --- a/candle-nn/benches/benchmarks/mod.rs +++ b/candle-nn/benches/benchmarks/mod.rs @@ -1,7 +1,8 @@ +pub(crate) mod attention; pub(crate) mod conv; pub(crate) mod layer_norm; -use candle::{Device, Result}; +use candle::{cuda::WrapErr, Device, Result}; pub(crate) trait BenchDevice { fn sync(&self) -> Result<()>; @@ -15,7 +16,7 @@ impl BenchDevice for Device { Device::Cpu => Ok(()), Device::Cuda(device) => { #[cfg(feature = "cuda")] - return Ok(device.synchronize()?); + return Ok(device.synchronize().w()?); #[cfg(not(feature = "cuda"))] panic!("Cuda device without cuda feature enabled: {:?}", device) } From d71b7d78396a944817876c56f1677bd17633234d Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 5 Aug 2024 11:55:02 -0400 Subject: [PATCH 29/75] Fixes --- candle-core/src/cpu_backend/mod.rs | 10 +++++----- candle-nn/benches/benchmarks/mod.rs | 7 +++++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 95adbd2a59..3c4c47305a 100644 --- 
a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -1646,7 +1646,7 @@ impl Map3 for MatMulWithBias { c_l: &Layout, s: Option, ) -> Result<()> { - let (b, m, n, k) = self.0; + let (b, m, n, k) = self.0 .0; let lhs = &lhs[lhs_l.start_offset()..]; let rhs = &rhs[rhs_l.start_offset()..]; @@ -1769,7 +1769,7 @@ impl Map3 for MatMulWithBias { c_l: &Layout, s: Option, ) -> Result<()> { - let (b, m, n, k) = self.0; + let (b, m, n, k) = self.0 .0; let lhs = &lhs[lhs_l.start_offset()..]; let rhs = &rhs[rhs_l.start_offset()..]; @@ -1877,7 +1877,7 @@ impl Map3 for MatMulWithBias { for step in 0..b { let lhs_p = &lhs[step * a_skip..]; let rhs_p = &rhs[step * b_skip..]; - let dst_p = &mut dst[step * c_skip..]; + let dst_p = &mut c[step * c_skip..]; unsafe { let a = rhs_p.as_ptr() as *const f64; let b = lhs_p.as_ptr() as *const f64; @@ -2008,7 +2008,7 @@ impl Map2Alpha for MatMulWithAlpha { rhs_l: &Layout, s: Option, ) -> Result> { - let (b, m, n, k) = self.0; + let (b, m, n, k) = self.0 .0; let lhs = &lhs[lhs_l.start_offset()..]; let rhs = &rhs[rhs_l.start_offset()..]; @@ -2118,7 +2118,7 @@ impl Map2Alpha for MatMulWithAlpha { rhs_l: &Layout, s: Option, ) -> Result> { - let (b, m, n, k) = self.0; + let (b, m, n, k) = self.0 .0; let lhs = &lhs[lhs_l.start_offset()..]; let rhs = &rhs[rhs_l.start_offset()..]; diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs index 2ef61c70a0..8c60df2ee5 100644 --- a/candle-nn/benches/benchmarks/mod.rs +++ b/candle-nn/benches/benchmarks/mod.rs @@ -2,7 +2,7 @@ pub(crate) mod attention; pub(crate) mod conv; pub(crate) mod layer_norm; -use candle::{cuda::WrapErr, Device, Result}; +use candle::{Device, Result}; pub(crate) trait BenchDevice { fn sync(&self) -> Result<()>; @@ -16,7 +16,10 @@ impl BenchDevice for Device { Device::Cpu => Ok(()), Device::Cuda(device) => { #[cfg(feature = "cuda")] - return Ok(device.synchronize().w()?); + { + use candle::cuda::WrapErr; + return Ok(device.synchronize().w()?); + } #[cfg(not(feature = "cuda"))] panic!("Cuda device without cuda feature enabled: {:?}", device) } From 27ca77e51c14fa9cc6c740119df2ec9f7e4da4bb Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 7 Aug 2024 17:06:13 -0400 Subject: [PATCH 30/75] Simplify things a bit --- candle-core/src/backend.rs | 8 --- candle-core/src/cpu_backend/mod.rs | 21 +++----- candle-core/src/cuda_backend/mod.rs | 67 +++----------------------- candle-core/src/dummy_cuda_backend.rs | 10 ---- candle-core/src/dummy_metal_backend.rs | 10 ---- candle-core/src/metal_backend/mod.rs | 55 +++------------------ candle-core/src/quantized/cuda.rs | 2 +- candle-core/src/storage.rs | 31 ------------ candle-core/src/tensor.rs | 3 +- 9 files changed, 21 insertions(+), 186 deletions(-) diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 6465bcdd27..40fd63f1ec 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -89,14 +89,6 @@ pub trait BackendStorage: Sized { _: usize, ) -> Result; - fn matmul( - &self, - _: &Self, - _: (usize, usize, usize, usize), - _: &Layout, - _: &Layout, - ) -> Result; - #[allow(clippy::too_many_arguments)] fn matmul_with_alpha_beta( &self, diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 3c4c47305a..65afc477f5 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -2938,7 +2938,7 @@ impl BackendStorage for CpuStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) 
.transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = unsafe { @@ -2949,7 +2949,7 @@ impl BackendStorage for CpuStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?; let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; @@ -2990,8 +2990,9 @@ impl BackendStorage for CpuStorage { vec![0, k_size * c_out, 1], kernel_l.start_offset(), ); - self.matmul( + self.matmul_with_alpha( kernel, + None, ( b_size, /* m */ l_in, @@ -3040,7 +3041,7 @@ impl BackendStorage for CpuStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = unsafe { @@ -3051,7 +3052,7 @@ impl BackendStorage for CpuStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, h_out, w_out, params.c_out)) .transpose(1, 2)? @@ -3141,16 +3142,6 @@ impl BackendStorage for CpuStorage { } } - fn matmul( - &self, - rhs: &Self, - bmnk: (usize, usize, usize, usize), - lhs_l: &Layout, - rhs_l: &Layout, - ) -> Result { - MatMul(bmnk).map(self, lhs_l, rhs, rhs_l) - } - fn matmul_with_alpha_beta( &self, rhs: &Self, diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 954cedc041..69ce0d2533 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1365,7 +1365,7 @@ impl BackendStorage for CudaStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = unsafe { @@ -1376,7 +1376,7 @@ impl BackendStorage for CudaStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; @@ -1420,8 +1420,9 @@ impl BackendStorage for CudaStorage { vec![0, k_size * c_out, 1], kernel_l.start_offset(), ); - self.matmul( + self.matmul_with_alpha( kernel, + None, ( b_size, /* m */ l_in, @@ -1479,7 +1480,7 @@ impl BackendStorage for CudaStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? 
+ col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = unsafe { @@ -1490,7 +1491,7 @@ impl BackendStorage for CudaStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, h_out, w_out, n)) .transpose(1, 2)? @@ -1653,62 +1654,6 @@ impl BackendStorage for CudaStorage { Ok(acc) } - fn matmul( - &self, - rhs: &Self, - (b, m, n, k): (usize, usize, usize, usize), - lhs_l: &Layout, - rhs_l: &Layout, - ) -> Result { - let elem_count = b * m * n; - let dev = &self.device; - let slice = match (&self.slice, &rhs.slice) { - (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?; - let mut out = unsafe { dev.alloc::(elem_count) }.w()?; - unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) } - .w()?; - CudaStorageSlice::BF16(out) - } - (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?; - let mut out = unsafe { dev.alloc::(elem_count) }.w()?; - unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) } - .w()?; - CudaStorageSlice::F16(out) - } - (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; - let mut out = unsafe { dev.alloc::(elem_count) }.w()?; - unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) } - .w()?; - CudaStorageSlice::F32(out) - } - (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; - let mut out = unsafe { dev.alloc::(elem_count) }.w()?; - unsafe { - self.device - .blas - .gemm_strided_batched(cfg, rhs, lhs, &mut out) - } - .w()?; - CudaStorageSlice::F64(out) - } - _ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?, - }; - let device = dev.clone(); - Ok(Self { slice, device }) - } - fn matmul_with_alpha_beta( &self, rhs: &Self, diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 353498b68c..26d9b2f629 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -140,16 +140,6 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - fn matmul( - &self, - _: &Self, - _: (usize, usize, usize, usize), - _: &Layout, - _: &Layout, - ) -> Result { - Err(Error::NotCompiledWithCudaSupport) - } - fn matmul_with_alpha_beta( &self, _: &Self, diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index 07a63d135e..2ec89f97a5 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -152,16 +152,6 @@ impl crate::backend::BackendStorage for MetalStorage { 
Err(Error::NotCompiledWithMetalSupport) } - fn matmul( - &self, - _: &Self, - _: (usize, usize, usize, usize), - _: &Layout, - _: &Layout, - ) -> Result { - Err(Error::NotCompiledWithMetalSupport) - } - fn matmul_with_alpha_beta( &self, _: &Self, diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 039bb62d95..22a363a77e 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -804,7 +804,7 @@ impl BackendStorage for MetalStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; @@ -812,7 +812,7 @@ impl BackendStorage for MetalStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; @@ -863,8 +863,9 @@ impl BackendStorage for MetalStorage { vec![0, k_size * c_out, 1], k_layout.start_offset(), ); - self.matmul( + self.matmul_with_alpha( k, + None, (b_size, l_in, c_out * k_size, c_in), &layout.transpose(1, 2)?, &kernel_l_mm, @@ -995,7 +996,7 @@ impl BackendStorage for MetalStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; @@ -1003,7 +1004,7 @@ impl BackendStorage for MetalStorage { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; - col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, h_out, w_out, n)) .transpose(1, 2)? 
@@ -1399,50 +1400,6 @@ impl BackendStorage for MetalStorage { Ok(acc) } - fn matmul( - &self, - rhs: &Self, - (b, m, n, k): (usize, usize, usize, usize), - lhs_l: &Layout, - rhs_l: &Layout, - ) -> Result { - let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?; - let name = match self.dtype { - DType::F32 => "sgemm", - DType::F16 => "hgemm", - DType::BF16 => "bgemm", - dtype => { - return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) - } - }; - - let command_buffer = self.device.command_buffer()?; - command_buffer.set_label("matmul"); - candle_metal_kernels::call_gemm( - &self.device.device, - &command_buffer, - &self.device.kernels, - name, - (b, m, n, k), - lhs_l.stride(), - lhs_l.start_offset() * self.dtype.size_in_bytes(), - &self.buffer, - rhs_l.stride(), - rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &rhs.buffer, - &buffer, - 1., - 0., - ) - .map_err(MetalError::from)?; - Ok(Self::new( - buffer, - self.device.clone(), - b * m * n, - self.dtype(), - )) - } - fn matmul_with_alpha_beta( &self, rhs: &Self, diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 8e4884b28d..d7229fbda6 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -538,7 +538,7 @@ impl QCudaStorage { let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) { let data_f32 = self.dequantize(n * k)?; let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?; - storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)? + storage.matmul_with_alpha(&data_f32, None, (b, m, n, k), layout, &rhs_l)? } else { let storage = storage.as_cuda_slice::()?; let storage = match layout.contiguous_offsets() { diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index c9d9b5c3ec..8ff1cbf82a 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -705,37 +705,6 @@ impl Storage { } } - pub(crate) fn matmul( - &self, - rhs: &Self, - bmnk: (usize, usize, usize, usize), - lhs_layout: &Layout, - rhs_layout: &Layout, - ) -> Result { - self.same_device(rhs, "matmul")?; - self.same_dtype(rhs, "matmul")?; - match (self, rhs) { - (Self::Cpu(lhs), Self::Cpu(rhs)) => { - let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; - Ok(Self::Cpu(storage)) - } - (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; - Ok(Self::Cuda(storage)) - } - (Self::Metal(lhs), Self::Metal(rhs)) => { - let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; - Ok(Self::Metal(storage)) - } - (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { - lhs: lhs.device().location(), - rhs: rhs.device().location(), - op: "matmul", - } - .bt()), - } - } - #[allow(clippy::too_many_arguments)] pub(crate) fn matmul_with_alpha_beta( &self, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 49f396c6f2..2624c5aa5e 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1194,8 +1194,9 @@ impl Tensor { .bt())? 
} - let storage = self.storage().matmul( + let storage = self.storage().matmul_with_alpha( &rhs.storage(), + None, (batching, m, n, k), self.layout(), rhs.layout(), From 7ad6494dc04024532dfdacdb2f9a09908c272a6b Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Fri, 9 Aug 2024 13:54:32 -0400 Subject: [PATCH 31/75] Mistral.rs GPTQ dev PR (#14) * Add i32 dtype for cpu and cuda, with kernels * Fix cuda i32 * Fix cpu i32 * Add cuda map impls for i32 * Start to add to metal * Add the kernels * Oops * Fix dtype cast in safetensors * Oops * Oops * Add bf16 to i32 and vice versa casts --- .vscode/settings.json | 4 +- candle-core/src/convert.rs | 5 ++ candle-core/src/cpu/kernels.rs | 11 +++ candle-core/src/cpu_backend/mod.rs | 114 +++++++++++++++++++++++- candle-core/src/cpu_backend/utils.rs | 2 + candle-core/src/cuda_backend/device.rs | 32 ++++++- candle-core/src/cuda_backend/mod.rs | 53 ++++++++++- candle-core/src/cuda_backend/utils.rs | 2 + candle-core/src/display.rs | 7 ++ candle-core/src/dtype.rs | 19 +++- candle-core/src/metal_backend/mod.rs | 69 ++++++++++++++ candle-core/src/npy.rs | 6 ++ candle-core/src/op.rs | 56 ++++++++++++ candle-core/src/safetensors.rs | 8 +- candle-core/src/sort.rs | 1 + candle-core/tests/tensor_tests.rs | 6 +- candle-kernels/src/affine.cu | 1 + candle-kernels/src/binary.cu | 12 +++ candle-kernels/src/cast.cu | 18 ++++ candle-kernels/src/cuda_utils.cuh | 2 + candle-kernels/src/fill.cu | 2 + candle-kernels/src/indexing.cu | 46 ++++++++++ candle-kernels/src/reduce.cu | 1 + candle-kernels/src/sort.cu | 1 + candle-kernels/src/ternary.cu | 10 +++ candle-kernels/src/unary.cu | 1 + candle-metal-kernels/src/binary.metal | 6 +- candle-metal-kernels/src/cast.metal | 16 ++++ candle-metal-kernels/src/indexing.metal | 22 +++++ candle-metal-kernels/src/lib.rs | 7 ++ candle-metal-kernels/src/reduce.metal | 6 ++ candle-metal-kernels/src/sort.metal | 1 + candle-metal-kernels/src/ternary.metal | 14 +++ candle-metal-kernels/src/unary.metal | 3 + candle-pyo3/src/lib.rs | 2 + 35 files changed, 548 insertions(+), 18 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 646783a968..e510b688c4 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,5 +8,7 @@ ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, - "rust-analyzer.cargo.features": ["cuda"] + "rust-analyzer.cargo.features": [ + "cuda" + ], } \ No newline at end of file diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs index 5ea5612a7c..b29ff346f6 100644 --- a/candle-core/src/convert.rs +++ b/candle-core/src/convert.rs @@ -130,6 +130,11 @@ impl Tensor { f.write_u32::(v)? } } + DType::I32 => { + for v in vs.to_vec1::()? { + f.write_i32::(v)? + } + } DType::I64 => { for v in vs.to_vec1::()? { f.write_i64::(v)? 
diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs index 527646d62b..fe0e241622 100644 --- a/candle-core/src/cpu/kernels.rs +++ b/candle-core/src/cpu/kernels.rs @@ -144,6 +144,17 @@ impl VecOps for u32 { ::max(self, other) } } +impl VecOps for i32 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} impl VecOps for i64 { #[inline(always)] fn min(self, other: Self) -> Self { diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 3c4c47305a..2bb8991c0c 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -22,6 +22,7 @@ const USE_IM2COL_CONV2D: bool = true; pub enum CpuStorage { U8(Vec), U32(Vec), + I32(Vec), I64(Vec), BF16(Vec), F16(Vec), @@ -33,6 +34,7 @@ pub enum CpuStorage { pub enum CpuStorageRef<'a> { U8(&'a [u8]), U32(&'a [u32]), + I32(&'a [i32]), I64(&'a [i64]), BF16(&'a [bf16]), F16(&'a [f16]), @@ -2285,6 +2287,17 @@ impl CpuStorage { .concat(); Self::U32(storages) } + Self::I32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I32(storages) + } Self::I64(_) => { let storages = storages .iter() @@ -2352,6 +2365,7 @@ impl BackendStorage for CpuStorage { match self { Self::U8(_) => DType::U8, Self::U32(_) => DType::U32, + Self::I32(_) => DType::I32, Self::I64(_) => DType::I64, Self::BF16(_) => DType::BF16, Self::F16(_) => DType::F16, @@ -2371,6 +2385,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); Ok(Self::BF16(data)) } + (Self::I32(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } (Self::I64(storage), DType::BF16) => { let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); Ok(Self::BF16(data)) @@ -2399,6 +2417,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) } + (Self::I32(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } (Self::I64(storage), DType::F16) => { let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) @@ -2427,6 +2449,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } + (Self::I32(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } (Self::I64(storage), DType::F32) => { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) @@ -2471,6 +2497,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) } + (Self::I32(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } (Self::I64(storage), DType::U8) => { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) @@ -2483,6 +2513,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::U32(data)) } + (Self::I32(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } (Self::I64(storage), DType::U32) => { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) @@ -2503,6 +2537,38 @@ 
impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } + (Self::U8(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::U32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::I32(data)) + } + (Self::I64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::BF16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::F64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } (Self::U8(storage), DType::I64) => { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) @@ -2511,6 +2577,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) } + (Self::I32(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } (Self::I64(storage), DType::I64) => { let data = unary_map(storage, layout, |v| v); Ok(Self::I64(data)) @@ -2539,6 +2609,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) } + (Self::I32(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } (Self::I64(storage), DType::F64) => { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) @@ -2674,6 +2748,7 @@ impl BackendStorage for CpuStorage { } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), } } @@ -2699,6 +2774,7 @@ impl BackendStorage for CpuStorage { } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), } } @@ -2749,6 +2825,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, B::u32); Ok(Self::U32(data)) } + Self::I32(storage) => { + let data = unary_map(storage, layout, B::i32); + Ok(Self::I32(data)) + } Self::I64(storage) => { let data = unary_map(storage, layout, B::i64); Ok(Self::I64(data)) @@ -2803,6 +2883,14 @@ impl BackendStorage for CpuStorage { }; Ok(Self::U32(data)) } + (Self::I32(lhs), Self::I32(rhs)) => { + let data = if B::I32_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i32, B::i32_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::i32) + }; + Ok(Self::I32(data)) + } (Self::I64(lhs), Self::I64(rhs)) => { let data = if B::I64_VEC { binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec) @@ -2846,6 +2934,9 @@ impl BackendStorage for CpuStorage { (Self::U32(src), Self::U32(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, 
dst_o) } + (Self::I32(src), Self::I32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } (Self::I64(src), Self::I64(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } @@ -2877,6 +2968,7 @@ impl BackendStorage for CpuStorage { match (self, dst) { (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), @@ -2906,6 +2998,7 @@ impl BackendStorage for CpuStorage { match self { Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")), } @@ -3075,6 +3168,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + Self::I32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()), } @@ -3084,6 +3178,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l), + Self::I32(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()), } @@ -3101,6 +3196,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + Self::I32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()), } @@ -3130,6 +3226,13 @@ impl BackendStorage for CpuStorage { }; IndexAdd { ids, dim }.map(self, l, src, src_l) } + Self::I32(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } Self::I64(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], @@ -3225,7 +3328,7 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()) } DType::BF16 => { @@ -3271,7 +3374,7 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) } DType::BF16 => { @@ -3330,6 +3433,11 @@ impl BackendDevice for CpuDevice { 
v.set_len(elem_count); CpuStorage::U32(v) } + DType::I32 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I32(v) + } DType::I64 => { let mut v = Vec::with_capacity(elem_count); v.set_len(elem_count); @@ -3364,6 +3472,7 @@ impl BackendDevice for CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![1u8; elem_count]), DType::U32 => CpuStorage::U32(vec![1u32; elem_count]), + DType::I32 => CpuStorage::I32(vec![1i32; elem_count]), DType::I64 => CpuStorage::I64(vec![1i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]), @@ -3378,6 +3487,7 @@ impl BackendDevice for CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![0u8; elem_count]), DType::U32 => CpuStorage::U32(vec![0u32; elem_count]), + DType::I32 => CpuStorage::I32(vec![0i32; elem_count]), DType::I64 => CpuStorage::I64(vec![0i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]), diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 9d38729145..297ccd3de6 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -10,6 +10,7 @@ pub trait Map1 { match vs { C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)), C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)), + C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)), C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)), C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)), C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), @@ -26,6 +27,7 @@ pub trait Map1Any { match vs { C::U8(vs) => Ok(self.f(vs, layout, C::U8)?), C::U32(vs) => Ok(self.f(vs, layout, C::U32)?), + C::I32(vs) => Ok(self.f(vs, layout, C::I32)?), C::I64(vs) => Ok(self.f(vs, layout, C::I64)?), C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?), C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 352bae9442..9e0b64067b 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -79,6 +79,14 @@ impl CudaDevice { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U32(data) } + DType::I32 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_i32", kernels::FILL)?; + let params = (&data, v as i32, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I32(data) + } DType::I64 => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count) }.w()?; @@ -192,6 +200,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::U32(data) } + DType::I32 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::I32(data) + } DType::I64 => { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::I64(data) @@ -225,7 +237,7 @@ impl BackendDevice for CudaDevice { let slice = match dtype { // TODO: Add support for F16 and BF16 though this is likely to require some upstream // cudarc changes. 
- DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { + DType::U8 | DType::U32 | DType::I64 | DType::I32 | DType::F16 | DType::BF16 => { Err(CudaError::UnsupportedDtype { dtype, op: "rand_uniform", @@ -269,7 +281,7 @@ impl BackendDevice for CudaDevice { elem_count }; let slice = match dtype { - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { + DType::U8 | DType::U32 | DType::I32 | DType::I64 | DType::F16 | DType::BF16 => { Err(CudaError::UnsupportedDtype { dtype, op: "rand_normal", @@ -311,6 +323,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::U32(data) } + DType::I32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::I32(data) + } DType::I64 => { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::I64(data) @@ -348,6 +364,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorageRef::I32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I32(data) + } CpuStorageRef::I64(storage) => { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::I64(data) @@ -385,6 +405,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorage::I32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I32(data) + } CpuStorage::I64(storage) => { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::I64(data) @@ -422,6 +446,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorage::I32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I32(data) + } CpuStorage::I64(storage) => { let data = self.htod_copy(storage).w()?; CudaStorageSlice::I64(data) diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 827e22e797..aedcbdd7cb 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -47,6 +47,7 @@ impl SlicePtrOrNull { pub enum CudaStorageSlice { U8(CudaSlice), U32(CudaSlice), + I32(CudaSlice), I64(CudaSlice), BF16(CudaSlice), F16(CudaSlice), @@ -361,11 +362,14 @@ impl<'a> Map1 for IndexSelect<'a> { CudaStorageSlice::U8(slice) => { ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr()) } + CudaStorageSlice::I32(slice) => { + ("is_i32", *slice.slice(ids_l.start_offset()..).device_ptr()) + } CudaStorageSlice::I64(slice) => { ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr()) } _ => Err(CudaError::UnexpectedDType { - msg: "index_select ids should be u8 or u32", + msg: "index_select ids should be u8/u32/i32/i64", expected: DType::U32, got: self.0.dtype(), }) @@ -425,11 +429,14 @@ impl<'a> Map1 for Gather<'a> { ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr()) } CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I32(slice) => { + ("gather_i32", *slice.slice(ids_o1..ids_o2).device_ptr()) + } CudaStorageSlice::I64(slice) => { ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr()) } _ => Err(CudaError::UnexpectedDType { - msg: "gather ids should be u8/u32/i64", + msg: "gather ids should be u8/u32/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -475,10 +482,11 @@ impl<'a> Map2InPlace for IndexAdd<'a> { }; let (name, ids) = match &ids.slice { CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), + 
CudaStorageSlice::I32(slice) => ("ia_i32", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), _ => Err(CudaError::UnexpectedDType { - msg: "index-add ids should be u8/u32/i64", + msg: "index-add ids should be u8/u32/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -523,10 +531,11 @@ impl<'a> Map2InPlace for ScatterAdd<'a> { }; let (name, ids) = match &ids.slice { CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I32(slice) => ("sa_i32", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), _ => Err(CudaError::UnexpectedDType { - msg: "scatter-add ids should be u8/u32/i64", + msg: "scatter-add ids should be u8/u32/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -865,6 +874,10 @@ impl<'a> Map2 for WhereCond<'a> { let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); (ptr, "where_u32") } + CudaStorageSlice::I32(slice) => { + let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); + (ptr, "where_i32") + } CudaStorageSlice::I64(slice) => { let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); (ptr, "where_i64") @@ -1024,6 +1037,7 @@ macro_rules! cuda_dtype { } cuda_dtype!(u8, U8); cuda_dtype!(u32, U32); +cuda_dtype!(i32, I32); cuda_dtype!(i64, I64); cuda_dtype!(f16, F16); cuda_dtype!(bf16, BF16); @@ -1146,6 +1160,7 @@ impl BackendStorage for CudaStorage { match self.slice { CudaStorageSlice::U8(_) => DType::U8, CudaStorageSlice::U32(_) => DType::U32, + CudaStorageSlice::I32(_) => DType::I32, CudaStorageSlice::I64(_) => DType::I64, CudaStorageSlice::BF16(_) => DType::BF16, CudaStorageSlice::F16(_) => DType::F16, @@ -1172,6 +1187,7 @@ impl BackendStorage for CudaStorage { let inp = match &self.slice { CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::I32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), @@ -1195,6 +1211,12 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U32(out) } + DType::I32 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I32(out) + } DType::I64 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); @@ -1291,6 +1313,11 @@ impl BackendStorage for CudaStorage { let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::U32(cpu_storage)) } + CudaStorageSlice::I32(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::I32(cpu_storage)) + } CudaStorageSlice::I64(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; @@ -1557,6 +1584,7 @@ impl BackendStorage for CudaStorage { S::F64(out) } (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?, + (S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv2d does not support i32"))?, (S::I64(_), 
S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?, _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?, }; @@ -1879,6 +1907,11 @@ impl BackendStorage for CudaStorage { *d.slice(dst_o..).device_ptr(), "copy2d_u32", ), + (S::I32(s), S::I32(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_i32", + ), (S::I64(s), S::I64(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), @@ -1985,6 +2018,18 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()? } } + (CudaStorageSlice::I32(src), CudaStorageSlice::I32(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_i32", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index c1210727ad..ae009b26ab 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -19,6 +19,7 @@ pub trait Map1 { let out = match s { S::U8(s) => S::U8(self.f(s, d, l)?), S::U32(s) => S::U32(self.f(s, d, l)?), + S::I32(s) => S::I32(self.f(s, d, l)?), S::I64(s) => S::I64(self.f(s, d, l)?), S::BF16(s) => S::BF16(self.f(s, d, l)?), S::F16(s) => S::F16(self.f(s, d, l)?), @@ -136,6 +137,7 @@ pub trait Map1Any { let out = match s { S::U8(s) => self.f(s, d, l, S::U8)?, S::U32(s) => self.f(s, d, l, S::U32)?, + S::I32(s) => self.f(s, d, l, S::I32)?, S::I64(s) => self.f(s, d, l, S::I64)?, S::BF16(s) => self.f(s, d, l, S::BF16)?, S::F16(s) => self.f(s, d, l, S::F16)?, diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 7e6e3cf8f1..5fb370b696 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -55,6 +55,7 @@ impl std::fmt::Debug for Tensor { match self.dtype() { DType::U8 => self.fmt_dt::(f), DType::U32 => self.fmt_dt::(f), + DType::I32 => self.fmt_dt::(f), DType::I64 => self.fmt_dt::(f), DType::BF16 => self.fmt_dt::(f), DType::F16 => self.fmt_dt::(f), @@ -463,6 +464,12 @@ impl std::fmt::Display for Tensor { tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; writeln!(f)?; } + DType::I32 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } DType::I64 => { let tf: IntFormatter = IntFormatter::new(); let max_w = tf.max_width(&to_display); diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index de6cddc3a3..c6a0800b24 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -10,6 +10,8 @@ pub enum DType { U8, // Unsigned 32 bits integer. U32, + // Signed 32 bits integer. + I32, // Signed 64 bits integer. I64, // Brain floating-point using half precision (16 bits). 
@@ -39,6 +41,7 @@ impl std::str::FromStr for DType { match s { "u8" => Ok(Self::U8), "u32" => Ok(Self::U32), + "i32" => Ok(Self::I32), "i64" => Ok(Self::I64), "bf16" => Ok(Self::BF16), "f16" => Ok(Self::F16), @@ -55,6 +58,7 @@ impl DType { match self { Self::U8 => "u8", Self::U32 => "u32", + Self::I32 => "i32", Self::I64 => "i64", Self::BF16 => "bf16", Self::F16 => "f16", @@ -68,6 +72,7 @@ impl DType { match self { Self::U8 => 1, Self::U32 => 4, + Self::I32 => 4, Self::I64 => 8, Self::BF16 => 2, Self::F16 => 2, @@ -78,14 +83,14 @@ impl DType { pub fn is_int(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I64 => true, + Self::U8 | Self::U32 | Self::I32 | Self::I64 => true, Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false, } } pub fn is_float(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I64 => false, + Self::U8 | Self::U32 | Self::I32 | Self::I64 => false, Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true, } } @@ -169,6 +174,7 @@ use half::{bf16, f16}; with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64); +with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64); with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64); with_dtype!(f16, F16, f16::from_f64, f16::to_f64); with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); @@ -180,6 +186,15 @@ pub trait IntDType: WithDType { fn as_usize(&self) -> usize; } +impl IntDType for i32 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + impl IntDType for i64 { fn is_true(&self) -> bool { *self != 0 diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 039bb62d95..cdcec45f75 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -96,6 +96,7 @@ impl BackendStorage for MetalStorage { match self.dtype { DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)), DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)), + DType::I32 => Ok(CpuStorage::I32(self.to_cpu()?)), DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)), DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)), DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)), @@ -304,6 +305,11 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false), (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true), (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true), + (ReduceOp::Sum, DType::I32) => ("fast_sum_i32_strided", false, false), + (ReduceOp::Min, DType::I32) => ("fast_min_i32_strided", true, false), + (ReduceOp::Max, DType::I32) => ("fast_max_i32_strided", true, false), + (ReduceOp::ArgMin, DType::I32) => ("fast_argmin_i32_strided", true, true), + (ReduceOp::ArgMax, DType::I32) => ("fast_argmax_i32_strided", true, true), (ReduceOp::Sum, DType::I64) => ("fast_sum_i64_strided", false, false), (ReduceOp::Min, DType::I64) => ("fast_min_i64_strided", true, false), (ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false), @@ -363,21 +369,30 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::BF16) => "cast_u32_bf16", (DType::U32, DType::F16) => "cast_u32_f16", (DType::U32, DType::F32) => "cast_u32_f32", + (DType::U32, DType::I32) => "cast_u32_i32", (DType::U32, DType::I64) => "cast_u32_i64", (DType::U32, DType::U8) => "cast_u32_u8", (DType::U8, DType::BF16) => "cast_u8_bf16", (DType::U8, DType::F16) => "cast_u8_f16", (DType::U8, DType::F32) => "cast_u8_f32", + 
(DType::U8, DType::I32) => "cast_u8_i32", (DType::U8, DType::I64) => "cast_u8_i64", (DType::U8, DType::U32) => "cast_u8_u32", (DType::F32, DType::BF16) => "cast_f32_bf16", (DType::F32, DType::F16) => "cast_f32_f16", + (DType::F32, DType::I32) => "cast_f32_i32", (DType::F32, DType::I64) => "cast_f32_i64", (DType::F32, DType::U32) => "cast_f32_u32", (DType::F32, DType::U8) => "cast_f32_u8", + (DType::I32, DType::BF16) => "cast_i32_bf16", + (DType::I32, DType::F16) => "cast_i32_f16", + (DType::I32, DType::F32) => "cast_i32_f32", + (DType::I32, DType::U32) => "cast_i32_u32", + (DType::I32, DType::U8) => "cast_i32_u8", + (DType::I64, DType::BF16) => "cast_i64_bf16", (DType::I64, DType::F16) => "cast_i64_f16", (DType::I64, DType::F32) => "cast_i64_f32", @@ -386,12 +401,14 @@ impl BackendStorage for MetalStorage { (DType::F16, DType::BF16) => "cast_f16_bf16", (DType::F16, DType::F32) => "cast_f16_f32", + (DType::F16, DType::I32) => "cast_f16_i32", (DType::F16, DType::I64) => "cast_f16_i64", (DType::F16, DType::U32) => "cast_f16_u32", (DType::F16, DType::U8) => "cast_f16_u8", (DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F32) => "cast_bf16_f32", + (DType::BF16, DType::I32) => "cast_bf16_i32", (DType::BF16, DType::I64) => "cast_bf16_i64", (DType::BF16, DType::U32) => "cast_bf16_u32", (DType::BF16, DType::U8) => "cast_bf16_u8", @@ -414,12 +431,15 @@ impl BackendStorage for MetalStorage { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32_strided", (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U32, DType::I32) => "cast_u32_i32_strided", (DType::U32, DType::I64) => "cast_u32_i64_strided", (DType::U8, DType::U32) => "cast_u8_u32_strided", (DType::U8, DType::F32) => "cast_u8_f32_strided", + (DType::U8, DType::I32) => "cast_u8_i32_strided", (DType::U8, DType::I64) => "cast_u8_i64_strided", (DType::F32, DType::F16) => "cast_f32_f16_strided", (DType::F16, DType::F32) => "cast_f16_f32_strided", + (DType::I32, DType::F32) => "cast_i32_f32_strided", (DType::I64, DType::F32) => "cast_i64_f32_strided", (DType::F32, DType::BF16) => "cast_f32_bf16_strided", (DType::BF16, DType::F32) => "cast_bf16_f32_strided", @@ -514,6 +534,7 @@ impl BackendStorage for MetalStorage { ("usign", DType::F16) => contiguous_tiled::sign::HALF, ("usign", DType::F32) => contiguous_tiled::sign::FLOAT, ("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT, + ("usign", DType::I32) => contiguous_tiled::sign::I32, ("usign", DType::I64) => contiguous_tiled::sign::I64, (name, dtype) => { crate::bail!( @@ -592,6 +613,7 @@ impl BackendStorage for MetalStorage { ("usign", DType::F16) => contiguous::sign::HALF, ("usign", DType::F32) => contiguous::sign::FLOAT, ("usign", DType::BF16) => contiguous::sign::BFLOAT, + ("usign", DType::I32) => contiguous::sign::I32, ("usign", DType::I64) => contiguous::sign::I64, (name, dtype) => { crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") @@ -723,6 +745,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "where_u32_f32", (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, DType::F16) => "where_u8_f16", + (DType::U8, DType::I32) => "where_u8_i32", (DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::U32) => "where_u8_u32", (DType::U8, DType::U8) => "where_u8_u8", @@ -1259,6 +1282,9 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "sa_u32_f32", (DType::U32, DType::F16) => "sa_u32_f16", (DType::U32, DType::BF16) => "sa_u32_bf16", + (DType::I32, DType::F32) => 
"sa_i32_f32", + (DType::I32, DType::F16) => "sa_i32_f16", + (DType::I32, DType::BF16) => "sa_i32_bf16", (DType::I64, DType::F32) => "sa_i64_f32", (DType::I64, DType::F16) => "sa_i64_f16", (DType::I64, DType::BF16) => "sa_i64_bf16", @@ -1307,6 +1333,10 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::BF16) => "is_u32_bf16", + (DType::I32, DType::F32) => "is_i32_f32", + (DType::I32, DType::F16) => "is_i32_f16", + (DType::I32, DType::BF16) => "is_i32_bf16", + (DType::I64, DType::F32) => "is_i64_f32", (DType::I64, DType::F16) => "is_i64_f16", (DType::I64, DType::BF16) => "is_i64_bf16", @@ -1352,9 +1382,18 @@ impl BackendStorage for MetalStorage { return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt()); }; let name = match (ids.dtype, self.dtype) { + (DType::I32, DType::BF16) => "ia_i32_bf16", + (DType::I32, DType::F16) => "ia_i32_f16", + (DType::I32, DType::F32) => "ia_i32_f32", + (DType::I32, DType::I32) => "ia_i32_i32", + (DType::I32, DType::I64) => "ia_i32_i64", + (DType::I32, DType::U32) => "ia_i32_u32", + (DType::I32, DType::U8) => "ia_i32_u8", + (DType::I64, DType::BF16) => "ia_i64_bf16", (DType::I64, DType::F16) => "ia_i64_f16", (DType::I64, DType::F32) => "ia_i64_f32", + (DType::I64, DType::I32) => "ia_i64_i32", (DType::I64, DType::I64) => "ia_i64_i64", (DType::I64, DType::U32) => "ia_i64_u32", (DType::I64, DType::U8) => "ia_i64_u8", @@ -1362,6 +1401,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::BF16) => "ia_u32_bf16", (DType::U32, DType::F16) => "ia_u32_f16", (DType::U32, DType::F32) => "ia_u32_f32", + (DType::U32, DType::I32) => "ia_u32_i32", (DType::U32, DType::I64) => "ia_u32_i64", (DType::U32, DType::U32) => "ia_u32_u32", (DType::U32, DType::U8) => "ia_u32_u8", @@ -1369,6 +1409,7 @@ impl BackendStorage for MetalStorage { (DType::U8, DType::BF16) => "ia_u8_bf16", (DType::U8, DType::F16) => "ia_u8_f16", (DType::U8, DType::F32) => "ia_u8_f32", + (DType::U8, DType::I32) => "ia_u8_i32", (DType::U8, DType::I64) => "ia_u8_i64", (DType::U8, DType::U32) => "ia_u8_u32", (DType::U8, DType::U8) => "ia_u8_u8", @@ -1579,6 +1620,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::copy2d::FLOAT, DType::F16 => candle_metal_kernels::copy2d::HALF, DType::BF16 => candle_metal_kernels::copy2d::BFLOAT, + DType::I32 => candle_metal_kernels::copy2d::I32, DType::I64 => candle_metal_kernels::copy2d::I64, DType::U32 => candle_metal_kernels::copy2d::U32, DType::U8 => candle_metal_kernels::copy2d::U8, @@ -1625,6 +1667,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::I32 => candle_metal_kernels::unary::strided::copy::I32, DType::I64 => candle_metal_kernels::unary::strided::copy::I64, DType::U32 => candle_metal_kernels::unary::strided::copy::U32, DType::U8 => candle_metal_kernels::unary::strided::copy::U8, @@ -1716,6 +1759,17 @@ impl MetalStorage { ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8), ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8), + ("add", DType::I32) => (contiguous::add::I32, self.dtype), + ("sub", DType::I32) => (contiguous::sub::I32, self.dtype), + ("mul", DType::I32) => (contiguous::mul::I32, self.dtype), + ("div", DType::I32) => (contiguous::div::I32, self.dtype), + ("eq", DType::I32) => (contiguous::eq::I32, DType::U8), + ("ne", DType::I32) => 
(contiguous::ne::I32, DType::U8), + ("le", DType::I32) => (contiguous::le::I32, DType::U8), + ("lt", DType::I32) => (contiguous::lt::I32, DType::U8), + ("ge", DType::I32) => (contiguous::ge::I32, DType::U8), + ("gt", DType::I32) => (contiguous::gt::I32, DType::U8), + ("add", DType::I64) => (contiguous::add::I64, self.dtype), ("sub", DType::I64) => (contiguous::sub::I64, self.dtype), ("mul", DType::I64) => (contiguous::mul::I64, self.dtype), @@ -1809,6 +1863,19 @@ impl MetalStorage { ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8), ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8), + ("badd", DType::I32) => (strided::add::I32, self.dtype), + ("bsub", DType::I32) => (strided::sub::I32, self.dtype), + ("bmul", DType::I32) => (strided::mul::I32, self.dtype), + ("bdiv", DType::I32) => (strided::div::I32, self.dtype), + ("bminimum", DType::I32) => (strided::min::I32, self.dtype), + ("bmaximum", DType::I32) => (strided::max::I32, self.dtype), + ("eq", DType::I32) => (strided::eq::I32, DType::U8), + ("ne", DType::I32) => (strided::ne::I32, DType::U8), + ("le", DType::I32) => (strided::le::I32, DType::U8), + ("lt", DType::I32) => (strided::lt::I32, DType::U8), + ("ge", DType::I32) => (strided::ge::I32, DType::U8), + ("gt", DType::I32) => (strided::gt::I32, DType::U8), + ("badd", DType::I64) => (strided::add::I64, self.dtype), ("bsub", DType::I64) => (strided::sub::I64, self.dtype), ("bmul", DType::I64) => (strided::mul::I64, self.dtype), @@ -1964,6 +2031,7 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match T::cpu_storage_ref(s) { CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), @@ -1977,6 +2045,7 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match storage { CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 83e4f6527f..b321a619f8 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -85,6 +85,7 @@ impl Header { DType::F16 => "f2", DType::F32 => "f4", DType::F64 => "f8", + DType::I32 => "i4", DType::I64 => "i8", DType::U32 => "u4", DType::U8 => "u1", @@ -234,6 +235,11 @@ impl Tensor { reader.read_u32_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::I32 => { + let mut data_t = vec![0i32; elem_count]; + reader.read_i32_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } DType::I64 => { let mut data_t = vec![0i64; elem_count]; reader.read_i64_into::(&mut data_t)?; diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 49ba44be89..75931ee2fe 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -189,6 +189,7 @@ pub trait 
UnaryOpT { fn f64(v1: f64) -> f64; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; + fn i32(v1: i32) -> i32; fn i64(v1: i64) -> i64; // There is no very good way to represent optional function in traits so we go for an explicit @@ -213,6 +214,7 @@ pub trait BinaryOpT { fn f64(v1: f64, v2: f64) -> f64; fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; + fn i32(v1: i32, v2: i32) -> i32; fn i64(v1: i64, v2: i64) -> i64; const BF16_VEC: bool = false; @@ -229,6 +231,8 @@ pub trait BinaryOpT { fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {} const I64_VEC: bool = false; fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {} + const I32_VEC: bool = false; + fn i32_vec(_xs1: &[i32], _xs2: &[i32], _ys: &mut [i32]) {} } pub(crate) struct Add; @@ -288,6 +292,10 @@ macro_rules! bin_op { $e(v1, v2) } #[inline(always)] + fn i32(v1: i32, v2: i32) -> i32 { + $e(v1, v2) + } + #[inline(always)] fn i64(v1: i64, v2: i64) -> i64 { $e(v1, v2) } @@ -379,6 +387,10 @@ macro_rules! unary_op { fn i64(_: i64) -> i64 { todo!("no unary function for i64") } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } } }; @@ -415,6 +427,10 @@ macro_rules! unary_op { fn i64(_: i64) -> i64 { todo!("no unary function for i64") } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } #[cfg(feature = "mkl")] const F32_VEC: bool = true; @@ -514,6 +530,10 @@ impl UnaryOpT for Gelu { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } const KERNEL: &'static str = "ugelu"; #[cfg(feature = "mkl")] @@ -587,6 +607,10 @@ impl UnaryOpT for Erf { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } } /// Silu operation @@ -621,6 +645,10 @@ impl UnaryOpT for Silu { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } const KERNEL: &'static str = "usilu"; #[cfg(feature = "mkl")] @@ -692,6 +720,10 @@ impl UnaryOpT for Abs { fn i64(v: i64) -> i64 { v.abs() } + #[inline(always)] + fn i32(v: i32) -> i32 { + v.abs() + } } impl UnaryOpT for Ceil { @@ -726,6 +758,10 @@ impl UnaryOpT for Ceil { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } } impl UnaryOpT for Floor { @@ -760,6 +796,10 @@ impl UnaryOpT for Floor { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } } impl UnaryOpT for Round { @@ -794,6 +834,10 @@ impl UnaryOpT for Round { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } } impl UnaryOpT for GeluErf { @@ -828,6 +872,10 @@ impl UnaryOpT for GeluErf { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } } impl UnaryOpT for Relu { @@ -862,6 +910,10 @@ impl UnaryOpT for Relu { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } } /// `BackpropOp` is a wrapper around `Option`. 
The main goal is to ensure that dependencies are @@ -960,4 +1012,8 @@ impl UnaryOpT for Sign { fn i64(v: i64) -> i64 { (v > 0) as i64 - (v < 0) as i64 } + #[inline(always)] + fn i32(v: i32) -> i32 { + (v > 0) as i32 - (v < 0) as i32 + } } diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 5ea1f192b3..162928ec7d 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -11,6 +11,7 @@ impl From for st::Dtype { DType::U8 => st::Dtype::U8, DType::U32 => st::Dtype::U32, DType::I64 => st::Dtype::I64, + DType::I32 => st::Dtype::I32, DType::BF16 => st::Dtype::BF16, DType::F16 => st::Dtype::F16, DType::F32 => st::Dtype::F32, @@ -187,6 +188,7 @@ impl Tensor { match dtype { DType::U8 => convert_slice::(data, shape, device), DType::U32 => convert_slice::(data, shape, device), + DType::I32 => convert_slice::(data, shape, device), DType::I64 => convert_slice::(data, shape, device), DType::BF16 => convert_slice::(data, shape, device), DType::F16 => convert_slice::(data, shape, device), @@ -204,10 +206,7 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { convert_with_cast_::(view, device, conv) } st::Dtype::U32 => convert_::(view, device), - st::Dtype::I32 => { - let conv = |x| Ok(i64::from(x)); - convert_with_cast_::(view, device, conv) - } + st::Dtype::I32 => convert_::(view, device), st::Dtype::I64 => convert_::(view, device), st::Dtype::BF16 => convert_::(view, device), st::Dtype::F16 => convert_::(view, device), @@ -223,6 +222,7 @@ fn convert_back(tensor: &Tensor) -> Result> { match tensor.dtype() { DType::U8 => Ok(convert_back_::(tensor.to_vec1()?)), DType::U32 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::I64 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::BF16 => Ok(convert_back_::(tensor.to_vec1()?)), diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 614a37fe65..92ad1d5adc 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -65,6 +65,7 @@ impl crate::CustomOp1 for ArgSort { let sort_indexes = match storage { crate::CpuStorage::U8(vs) => self.asort(vs, layout), crate::CpuStorage::U32(vs) => self.asort(vs, layout), + crate::CpuStorage::I32(vs) => self.asort(vs, layout), crate::CpuStorage::I64(vs) => self.asort(vs, layout), crate::CpuStorage::BF16(vs) => self.asort(vs, layout), crate::CpuStorage::F16(vs) => self.asort(vs, layout), diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 6dd21cc1ab..bff8f36042 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -17,6 +17,10 @@ fn ones(device: &Device) -> Result<()> { Tensor::ones((2, 3), DType::U32, device)?.to_vec2::()?, [[1, 1, 1], [1, 1, 1]], ); + assert_eq!( + Tensor::ones((2, 3), DType::I32, device)?.to_vec2::()?, + [[1, 1, 1], [1, 1, 1]], + ); assert_eq!( Tensor::ones((2, 3), DType::I64, device)?.to_vec2::()?, [[1, 1, 1], [1, 1, 1]], @@ -805,7 +809,7 @@ fn index_select(device: &Device) -> Result<()> { [9.0, 10.0, 11.0] ] ); - for dtype in [DType::U8, DType::U32, DType::I64] { + for dtype in [DType::U8, DType::U32, DType::I32, DType::I64] { let ids = ids.to_dtype(dtype)?; let hs = t.index_select(&ids, 1)?; assert_eq!( diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu index 540d0819f5..c3ff5b8753 100644 --- a/candle-kernels/src/affine.cu +++ b/candle-kernels/src/affine.cu @@ -40,4 +40,5 @@ AFFINE_OP(float, affine_f32) 
AFFINE_OP(double, affine_f64) AFFINE_OP(uint8_t, affine_u8) AFFINE_OP(uint32_t, affine_u32) +AFFINE_OP(int32_t, affine_i32) AFFINE_OP(int64_t, affine_i64) diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu index d44e3b20ee..f534fc76ad 100644 --- a/candle-kernels/src/binary.cu +++ b/candle-kernels/src/binary.cu @@ -35,65 +35,77 @@ BINARY_OP(float, badd_f32, x + y) BINARY_OP(double, badd_f64, x + y); BINARY_OP(uint8_t, badd_u8, x + y); BINARY_OP(uint32_t, badd_u32, x + y); +BINARY_OP(int32_t, badd_i32, x + y); BINARY_OP(int64_t, badd_i64, x + y); BINARY_OP(float, bdiv_f32, x / y) BINARY_OP(double, bdiv_f64, x / y); BINARY_OP(uint8_t, bdiv_u8, x / y); BINARY_OP(uint32_t, bdiv_u32, x / y); +BINARY_OP(int32_t, bdiv_i32, x / y); BINARY_OP(int64_t, bdiv_i64, x / y); BINARY_OP(float, bmul_f32, x * y) BINARY_OP(double, bmul_f64, x * y); BINARY_OP(uint8_t, bmul_u8, x * y); BINARY_OP(uint32_t, bmul_u32, x * y); +BINARY_OP(int32_t, bmul_i32, x * y); BINARY_OP(int64_t, bmul_i64, x * y); BINARY_OP(float, bsub_f32, x - y) BINARY_OP(double, bsub_f64, x - y); BINARY_OP(uint8_t, bsub_u8, x - y); BINARY_OP(uint32_t, bsub_u32, x - y); +BINARY_OP(int32_t, bsub_i32, x - y); BINARY_OP(int64_t, bsub_i64, x - y); BINARY_OP(float, bminimum_f32, ming(x, y)); BINARY_OP(double, bminimum_f64, ming(x, y)); BINARY_OP(uint8_t, bminimum_u8, ming(x, y)); BINARY_OP(uint32_t, bminimum_u32, ming(x, y)); +BINARY_OP(int32_t, bminimum_i32, ming(x, y)); BINARY_OP(int64_t, bminimum_i64, ming(x, y)); BINARY_OP(float, bmaximum_f32, maxg(x, y)); BINARY_OP(double, bmaximum_f64, maxg(x, y)); BINARY_OP(uint8_t, bmaximum_u8, maxg(x, y)); BINARY_OP(uint32_t, bmaximum_u32, maxg(x, y)); +BINARY_OP(int32_t, bmaximum_i32, maxg(x, y)); BINARY_OP(int64_t, bmaximum_i64, maxg(x, y)); BINARY_OP_OUT(float, uint8_t, eq_f32, x == y) BINARY_OP_OUT(double, uint8_t, eq_f64, x == y) BINARY_OP_OUT(uint8_t, uint8_t, eq_u8, x == y) BINARY_OP_OUT(uint32_t, uint8_t, eq_u32, x == y) +BINARY_OP_OUT(int32_t, uint8_t, eq_i32, x == y) BINARY_OP_OUT(int64_t, uint8_t, eq_i64, x == y) BINARY_OP_OUT(float, uint8_t, ne_f32, x != y) BINARY_OP_OUT(double, uint8_t, ne_f64, x != y) BINARY_OP_OUT(uint8_t, uint8_t, ne_u8, x != y) BINARY_OP_OUT(uint32_t, uint8_t, ne_u32, x != y) +BINARY_OP_OUT(int32_t, uint8_t, ne_i32, x != y) BINARY_OP_OUT(int64_t, uint8_t, ne_i64, x != y) BINARY_OP_OUT(float, uint8_t, lt_f32, x < y) BINARY_OP_OUT(double, uint8_t, lt_f64, x < y) BINARY_OP_OUT(uint8_t, uint8_t, lt_u8, x < y) BINARY_OP_OUT(uint32_t, uint8_t, lt_u32, x < y) +BINARY_OP_OUT(int32_t, uint8_t, lt_i32, x < y) BINARY_OP_OUT(int64_t, uint8_t, lt_i64, x < y) BINARY_OP_OUT(float, uint8_t, le_f32, x <= y) BINARY_OP_OUT(double, uint8_t, le_f64, x <= y) BINARY_OP_OUT(uint8_t, uint8_t, le_u8, x <= y) BINARY_OP_OUT(uint32_t, uint8_t, le_u32, x <= y) +BINARY_OP_OUT(int32_t, uint8_t, le_i32, x <= y) BINARY_OP_OUT(int64_t, uint8_t, le_i64, x <= y) BINARY_OP_OUT(float, uint8_t, gt_f32, x > y) BINARY_OP_OUT(double, uint8_t, gt_f64, x > y) BINARY_OP_OUT(uint8_t, uint8_t, gt_u8, x > y) BINARY_OP_OUT(uint32_t, uint8_t, gt_u32, x > y) +BINARY_OP_OUT(int32_t, uint8_t, gt_i32, x > y) BINARY_OP_OUT(int64_t, uint8_t, gt_i64, x > y) BINARY_OP_OUT(float, uint8_t, ge_f32, x >= y) BINARY_OP_OUT(double, uint8_t, ge_f64, x >= y) BINARY_OP_OUT(uint8_t, uint8_t, ge_u8, x >= y) BINARY_OP_OUT(uint32_t, uint8_t, ge_u32, x >= y) +BINARY_OP_OUT(int32_t, uint8_t, ge_i32, x >= y) BINARY_OP_OUT(int64_t, uint8_t, ge_i64, x >= y) diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu 
index 90f5e7ba48..f92ac0cbf9 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -83,6 +83,8 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16) CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) +CAST_THROUGH_OP(int32_t, __nv_bfloat16, float, cast_i32_bf16) +CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32) #else #include #if CUDA_VERSION >= 11000 @@ -94,6 +96,8 @@ CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16) CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16) +CAST_THROUGH_OP(int32_t, __nv_bfloat16, float, cast_i32_bf16) +CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32) #endif #endif @@ -108,34 +112,48 @@ CAST_OP(uint8_t, __half, cast_u8_f16 ) CAST_OP(uint32_t, __half, cast_u32_f16) CAST_OP(float, __half, cast_f32_f16) CAST_OP(double, __half, cast_f64_f16) +CAST_OP(int32_t, __half, cast_i32_f16 ) +CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32) #endif CAST_OP(uint32_t, uint32_t, cast_u32_u32) CAST_OP(uint32_t, uint8_t, cast_u32_u8 ) CAST_OP(uint32_t, int64_t, cast_u32_i64 ) +CAST_OP(uint32_t, int32_t, cast_u32_i32 ) CAST_OP(uint32_t, float, cast_u32_f32) CAST_OP(uint32_t, double, cast_u32_f64) CAST_OP(uint8_t, uint32_t, cast_u8_u32) CAST_OP(uint8_t, uint8_t, cast_u8_u8 ) +CAST_OP(uint8_t, int32_t, cast_u8_i32 ) CAST_OP(uint8_t, int64_t, cast_u8_i64 ) CAST_OP(uint8_t, float, cast_u8_f32) CAST_OP(uint8_t, double, cast_u8_f64) CAST_OP(int64_t, uint32_t, cast_i64_u32) CAST_OP(int64_t, uint8_t, cast_i64_u8 ) +CAST_OP(int64_t, int32_t, cast_i64_i32 ) CAST_OP(int64_t, int64_t, cast_i64_i64 ) CAST_OP(int64_t, float, cast_i64_f32) CAST_OP(int64_t, double, cast_i64_f64) +CAST_OP(int32_t, uint32_t, cast_i32_u32) +CAST_OP(int32_t, uint8_t, cast_i32_u8 ) +CAST_OP(int32_t, int64_t, cast_i32_i64 ) +CAST_OP(int32_t, int32_t, cast_i32_i32 ) +CAST_OP(int32_t, float, cast_i32_f32) +CAST_OP(int32_t, double, cast_i32_f64) + CAST_OP(float, uint8_t, cast_f32_u8 ) CAST_OP(float, uint32_t, cast_f32_u32) +CAST_OP(float, int32_t, cast_f32_i32 ) CAST_OP(float, int64_t, cast_f32_i64 ) CAST_OP(float, float, cast_f32_f32) CAST_OP(float, double, cast_f32_f64) CAST_OP(double, uint8_t, cast_f64_u8 ) CAST_OP(double, uint32_t, cast_f64_u32) +CAST_OP(double, int32_t, cast_f64_i32 ) CAST_OP(double, int64_t, cast_f64_i64 ) CAST_OP(double, float, cast_f64_f32) CAST_OP(double, double, cast_f64_f64) diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index f7a2506d0e..08aa2b089a 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -181,6 +181,8 @@ __device__ __forceinline__ double absg(double a) { return fabs(a); } __device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); } __device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); } +__device__ __forceinline__ int32_t ming(int32_t a, int32_t b) { return min(a, b); } +__device__ __forceinline__ int32_t maxg(int32_t a, int32_t b) { return max(a, b); } __device__ __forceinline__ int64_t ming(int64_t a, int64_t b) { return min(a, b); } __device__ __forceinline__ int64_t maxg(int64_t a, int64_t b) { return max(a, b); } __device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); } diff --git 
a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu index ca448d989f..42bfddfd9f 100644 --- a/candle-kernels/src/fill.cu +++ b/candle-kernels/src/fill.cu @@ -9,6 +9,7 @@ __device__ void fill_with(T *buf, T value, const size_t numel) { } extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_i32(int32_t *buf, int32_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); } @@ -34,6 +35,7 @@ COPY2D_OP(float, copy2d_f32) COPY2D_OP(double, copy2d_f64) COPY2D_OP(uint8_t, copy2d_u8) COPY2D_OP(uint32_t, copy2d_u32) +COPY2D_OP(int32_t, copy2d_i32) COPY2D_OP(int64_t, copy2d_i64) #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index 8af2954d13..2f3df4de1b 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -147,44 +147,61 @@ extern "C" __global__ void FN_NAME( \ #if __CUDA_ARCH__ >= 800 +IS_OP(__nv_bfloat16, int32_t, is_i32_bf16) IS_OP(__nv_bfloat16, int64_t, is_i64_bf16) IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16) IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16) +GATHER_OP(__nv_bfloat16, int32_t, gather_i32_bf16) GATHER_OP(__nv_bfloat16, int64_t, gather_i64_bf16) GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16) GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16) +IA_OP(__nv_bfloat16, int32_t, ia_i32_bf16) IA_OP(__nv_bfloat16, int64_t, ia_i64_bf16) IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16) IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16) +SA_OP(__nv_bfloat16, int32_t, sa_i32_bf16) SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16) SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16) SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16) #endif #if __CUDA_ARCH__ >= 530 +IS_OP(__half, int32_t, is_i32_f16) IS_OP(__half, int64_t, is_i64_f16) IS_OP(__half, uint32_t, is_u32_f16) IS_OP(__half, uint8_t, is_u8_f16) +GATHER_OP(__half, int32_t, gather_i32_f16) GATHER_OP(__half, int64_t, gather_i64_f16) GATHER_OP(__half, uint32_t, gather_u32_f16) GATHER_OP(__half, uint8_t, gather_u8_f16) +IA_OP(__half, int32_t, ia_i32_f16) IA_OP(__half, int64_t, ia_i64_f16) IA_OP(__half, uint32_t, ia_u32_f16) IA_OP(__half, uint8_t, ia_u8_f16) +SA_OP(__half, int32_t, sa_i32_f16) SA_OP(__half, int64_t, sa_i64_f16) SA_OP(__half, uint32_t, sa_u32_f16) SA_OP(__half, uint8_t, sa_u8_f16) #endif +IS_OP(float, int32_t, is_i32_f32) +IS_OP(double, int32_t, is_i32_f64) +IS_OP(uint8_t, int32_t, is_i32_u8) +IS_OP(uint32_t, int32_t, is_i32_u32) +IS_OP(int32_t, int32_t, is_i32_i32) +IS_OP(int64_t, int32_t, is_i32_i64) + IS_OP(float, int64_t, is_i64_f32) IS_OP(double, int64_t, is_i64_f64) IS_OP(uint8_t, int64_t, is_i64_u8) IS_OP(uint32_t, int64_t, is_i64_u32) IS_OP(int64_t, int64_t, is_i64_i64) +IS_OP(int32_t, int64_t, is_i64_i32) IS_OP(float, uint32_t, is_u32_f32) IS_OP(double, uint32_t, is_u32_f64) IS_OP(uint8_t, uint32_t, is_u32_u8) +IS_OP(int32_t, uint32_t, is_u32_i32) IS_OP(int64_t, uint32_t, is_u32_i64) IS_OP(uint32_t, uint32_t, is_u32_u32) @@ -192,17 +209,27 @@ IS_OP(float, uint8_t, is_u8_f32) IS_OP(double, uint8_t, is_u8_f64) 
IS_OP(uint8_t, uint8_t, is_u8_u8) IS_OP(uint32_t, uint8_t, is_u8_u32) +IS_OP(int32_t, uint8_t, is_u8_i32) IS_OP(int64_t, uint8_t, is_u8_i64) +GATHER_OP(float, int32_t, gather_i32_f32) +GATHER_OP(double, int32_t, gather_i32_f64) +GATHER_OP(uint8_t, int32_t, gather_i32_u8) +GATHER_OP(uint32_t, int32_t, gather_i32_u32) +GATHER_OP(int32_t, int32_t, gather_i32_i32) +GATHER_OP(int64_t, int32_t, gather_i32_i64) + GATHER_OP(float, int64_t, gather_i64_f32) GATHER_OP(double, int64_t, gather_i64_f64) GATHER_OP(uint8_t, int64_t, gather_i64_u8) GATHER_OP(uint32_t, int64_t, gather_i64_u32) GATHER_OP(int64_t, int64_t, gather_i64_i64) +GATHER_OP(int32_t, int64_t, gather_i64_i32) GATHER_OP(float, uint32_t, gather_u32_f32) GATHER_OP(double, uint32_t, gather_u32_f64) GATHER_OP(uint8_t, uint32_t, gather_u32_u8) +GATHER_OP(int32_t, uint32_t, gather_u32_i32) GATHER_OP(int64_t, uint32_t, gather_u32_i64) GATHER_OP(uint32_t, uint32_t, gather_u32_u32) @@ -210,17 +237,26 @@ GATHER_OP(float, uint8_t, gather_u8_f32) GATHER_OP(double, uint8_t, gather_u8_f64) GATHER_OP(uint8_t, uint8_t, gather_u8_u8) GATHER_OP(uint32_t, uint8_t, gather_u8_u32) +GATHER_OP(int32_t, uint8_t, gather_u8_i32) GATHER_OP(int64_t, uint8_t, gather_u8_i64) +IA_OP(float, int32_t, ia_i32_f32) +IA_OP(double, int32_t, ia_i32_f64) +IA_OP(uint8_t, int32_t, ia_i32_u8) +IA_OP(int32_t, int32_t, ia_i32_i32) +IA_OP(uint32_t, int32_t, ia_i32_u32) + IA_OP(float, int64_t, ia_i64_f32) IA_OP(double, int64_t, ia_i64_f64) IA_OP(uint8_t, int64_t, ia_i64_u8) IA_OP(int64_t, int64_t, ia_i64_i64) IA_OP(uint32_t, int64_t, ia_i64_u32) +IA_OP(int32_t, int64_t, ia_i64_i32) IA_OP(float, uint32_t, ia_u32_f32) IA_OP(double, uint32_t, ia_u32_f64) IA_OP(uint8_t, uint32_t, ia_u32_u8) +IA_OP(int32_t, uint32_t, ia_u32_i32) IA_OP(int64_t, uint32_t, ia_u32_i64) IA_OP(uint32_t, uint32_t, ia_u32_u32) @@ -228,17 +264,26 @@ IA_OP(float, uint8_t, ia_u8_f32) IA_OP(double, uint8_t, ia_u8_f64) IA_OP(uint8_t, uint8_t, ia_u8_u8) IA_OP(uint32_t, uint8_t, ia_u8_u32) +IA_OP(int32_t, uint8_t, ia_u8_i32) IA_OP(int64_t, uint8_t, ia_u8_i64) +SA_OP(float, int32_t, sa_i32_f32) +SA_OP(double, int32_t, sa_i32_f64) +SA_OP(uint8_t, int32_t, sa_i32_u8) +SA_OP(int32_t, int32_t, sa_i32_i32) +SA_OP(uint32_t, int32_t, sa_i32_u32) + SA_OP(float, int64_t, sa_i64_f32) SA_OP(double, int64_t, sa_i64_f64) SA_OP(uint8_t, int64_t, sa_i64_u8) +SA_OP(int32_t, int64_t, sa_i64_i32) SA_OP(int64_t, int64_t, sa_i64_i64) SA_OP(uint32_t, int64_t, sa_i64_u32) SA_OP(float, uint32_t, sa_u32_f32) SA_OP(double, uint32_t, sa_u32_f64) SA_OP(uint8_t, uint32_t, sa_u32_u8) +SA_OP(int32_t, uint32_t, sa_u32_i32) SA_OP(int64_t, uint32_t, sa_u32_i64) SA_OP(uint32_t, uint32_t, sa_u32_u32) @@ -246,4 +291,5 @@ SA_OP(float, uint8_t, sa_u8_f32) SA_OP(double, uint8_t, sa_u8_f64) SA_OP(uint8_t, uint8_t, sa_u8_u8) SA_OP(uint32_t, uint8_t, sa_u8_u32) +SA_OP(int32_t, uint8_t, sa_u8_i32) SA_OP(int64_t, uint8_t, sa_u8_i64) diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index aaac24a146..9a1354a8dc 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -606,5 +606,6 @@ ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64) FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32) FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64) FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32) +FAST_OP(int32_t, fast_min_i32, fast_max_i32, fast_argmin_i32, fast_argmax_i32, fast_sum_i32) FAST_OP(int64_t, 
fast_min_i64, fast_max_i64, fast_argmin_i64, fast_argmax_i64, fast_sum_i64) FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8) diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu index 08f1f9fc29..7fecf8413e 100644 --- a/candle-kernels/src/sort.cu +++ b/candle-kernels/src/sort.cu @@ -85,4 +85,5 @@ ASORT_OP(float, f32) ASORT_OP(double, f64) ASORT_OP(uint8_t, u8) ASORT_OP(uint32_t, u32) +ASORT_OP(int32_t, i32) ASORT_OP(int64_t, i64) diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu index aaa8a881fb..4617c08fbe 100644 --- a/candle-kernels/src/ternary.cu +++ b/candle-kernels/src/ternary.cu @@ -33,17 +33,25 @@ extern "C" __global__ void FN_NAME( \ } \ #if __CUDA_ARCH__ >= 800 +WHERE_OP(__nv_bfloat16, int32_t, where_i32_bf16) WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16) WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16) WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16) #endif #if __CUDA_ARCH__ >= 530 +WHERE_OP(__half, int32_t, where_i32_f16) WHERE_OP(__half, int64_t, where_i64_f16) WHERE_OP(__half, uint32_t, where_u32_f16) WHERE_OP(__half, uint8_t, where_u8_f16) #endif +WHERE_OP(float, int32_t, where_i32_f32) +WHERE_OP(double, int32_t, where_i32_f64) +WHERE_OP(uint8_t, int32_t, where_i32_u8) +WHERE_OP(uint32_t, int32_t, where_i32_u32) +WHERE_OP(int64_t, int32_t, where_i32_i64) + WHERE_OP(float, int64_t, where_i64_f32) WHERE_OP(double, int64_t, where_i64_f64) WHERE_OP(uint8_t, int64_t, where_i64_u8) @@ -54,10 +62,12 @@ WHERE_OP(float, uint32_t, where_u32_f32) WHERE_OP(double, uint32_t, where_u32_f64) WHERE_OP(uint8_t, uint32_t, where_u32_u8) WHERE_OP(uint32_t, uint32_t, where_u32_u32) +WHERE_OP(int32_t, uint32_t, where_u32_i32) WHERE_OP(int64_t, uint32_t, where_u32_i64) WHERE_OP(float, uint8_t, where_u8_f32) WHERE_OP(double, uint8_t, where_u8_f64) WHERE_OP(uint8_t, uint8_t, where_u8_u8) WHERE_OP(uint32_t, uint8_t, where_u8_u32) +WHERE_OP(int32_t, uint8_t, where_u8_i32) WHERE_OP(int64_t, uint8_t, where_u8_i64) diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index c82a88375d..21d3d995c0 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -153,6 +153,7 @@ UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x)) UNARY_OP(uint8_t, ucopy_u8, x) UNARY_OP(uint32_t, ucopy_u32, x) +UNARY_OP(int32_t, ucopy_i32, x) UNARY_OP(int64_t, ucopy_i64, x) UNARY_OP(float, ucopy_f32, x) UNARY_OP(double, ucopy_f64, x) diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index e83498e40d..a9b8129c3a 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -58,13 +58,15 @@ kernel void FN_NAME_STRIDED( \ BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \ -BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int32_t, int32_t, NAME##_i32, NAME##_i32_strided); #define BINARY_OP_OUT(NAME, FN) \ BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \ -BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int32_t, uint8_t, NAME##_i32, NAME##_i32_strided); #define INT64_BINARY_OP(NAME, FN) \ BINARY(FN, int64_t,
int64_t, NAME##_i64, NAME##_i64_strided); diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 2af3fdceb0..c8122ccf0a 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -76,6 +76,7 @@ kernel void FN_NAME_STRIDED( \ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half) +CAST(cast_u32_i32, cast_u32_i32_strided, uint32_t, int32_t) #if __METAL_VERSION__ >= 220 CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) #endif @@ -87,6 +88,7 @@ CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half) +CAST(cast_u8_i32, cast_u8_i32_strided, uint8_t, int32_t) #if __METAL_VERSION__ >= 220 CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) #endif @@ -98,6 +100,7 @@ CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) CAST(cast_f16_f32, cast_f16_f32_strided, half, float) CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t) CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t) +CAST(cast_f16_i32, cast_f16_i32_strided, half, int32_t) CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t) #if defined(__HAVE_BFLOAT__) CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) @@ -107,15 +110,27 @@ CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t) CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t) +CAST(cast_i64_i32, cast_i64_i32_strided, int64_t, int32_t) CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half) #if defined(__HAVE_BFLOAT__) CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float) #endif +// i32 +CAST(cast_i32_f32, cast_i32_f32_strided, int32_t, float) +CAST(cast_i32_u8, cast_i32_u8_strided, int32_t, uint8_t) +CAST(cast_i32_u32, cast_i32_u32_strided, int32_t, uint32_t) +CAST(cast_i32_i64, cast_i32_i64_strided, int32_t, int64_t) +CAST(cast_i32_f16, cast_i32_f16_strided, int32_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i32_bf16, cast_i32_bf16_strided, int32_t, bfloat, float) +#endif + // f32 CAST(cast_f32_f16, cast_f32_f16_strided, float, half) CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t) CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t) +CAST(cast_f32_i32, cast_f32_i32_strided, float, int32_t) CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t) #if defined(__HAVE_BFLOAT__) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) @@ -124,6 +139,7 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) // bf16 #if defined(__HAVE_BFLOAT__) CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) +CAST(cast_bf16_i32, cast_bf16_i32_strided, bfloat, int32_t) CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 9eee97ca0a..eaa78d7b73 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -193,6 +193,12 @@ INDEX_OP(is_i64_f16, int64_t, half) INDEX_OP(is_i64_bf16, int64_t, bfloat) #endif +INDEX_OP(is_i32_f32, int32_t,
float) +INDEX_OP(is_i32_f16, int32_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_i32_bf16, int32_t, bfloat) +#endif + INDEX_OP(is_u32_f32, uint32_t, float) INDEX_OP(is_u32_f16, uint32_t, half) #if defined(__HAVE_BFLOAT__) @@ -213,9 +219,11 @@ GATHER_OP(gather_u32_bf16, uint, bfloat) SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) SCATTER_ADD_OP(sa_u8_f32, uint8_t, float) +SCATTER_ADD_OP(sa_i32_f32, int32_t, float) SCATTER_ADD_OP(sa_i64_f32, int64_t, float) SCATTER_ADD_OP(sa_u32_f16, uint32_t, half) SCATTER_ADD_OP(sa_u8_f16, uint8_t, half) +SCATTER_ADD_OP(sa_i32_f16, int32_t, half) SCATTER_ADD_OP(sa_i64_f16, int64_t, half) #if defined(__HAVE_BFLOAT__) SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat) @@ -226,6 +234,7 @@ SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat) // i64 INDEX_ADD_OP(ia_i64_f16, int64_t, half) INDEX_ADD_OP(ia_i64_f32, int64_t, float) +INDEX_ADD_OP(ia_i64_i32, int64_t, int32_t) INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t) INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t) INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) @@ -233,9 +242,21 @@ INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) #endif +// i64 +INDEX_ADD_OP(ia_i32_f16, int32_t, half) +INDEX_ADD_OP(ia_i32_f32, int32_t, float) +INDEX_ADD_OP(ia_i32_i64, int32_t, int64_t) +INDEX_ADD_OP(ia_i32_i32, int32_t, int32_t) +INDEX_ADD_OP(ia_i32_u32, int32_t, uint32_t) +INDEX_ADD_OP(ia_i32_u8, int32_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_i32_bf16, int32_t, bfloat) +#endif + // u32 INDEX_ADD_OP(ia_u32_f16, uint32_t, half) INDEX_ADD_OP(ia_u32_f32, uint32_t, float) +INDEX_ADD_OP(ia_u32_i32, uint32_t, int32_t) INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t) INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t) INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t) @@ -246,6 +267,7 @@ INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) // u8 INDEX_ADD_OP(ia_u8_f16, uint8_t, half) INDEX_ADD_OP(ia_u8_f32, uint8_t, float) +INDEX_ADD_OP(ia_u8_i32, uint8_t, int32_t) INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t) INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a97d327468..d6e6dd69b8 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -46,6 +46,7 @@ pub mod copy2d { pub const HALF: Kernel = Kernel("copy2d_f16"); pub const BFLOAT: Kernel = Kernel("copy2d_bf16"); pub const I64: Kernel = Kernel("copy2d_i64"); + pub const I32: Kernel = Kernel("copy2d_i32"); pub const U32: Kernel = Kernel("copy2d_u32"); pub const U8: Kernel = Kernel("copy2d_u8"); } @@ -62,6 +63,7 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8")); } @@ -72,6 +74,7 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel("copy_f16"); pub const BFLOAT: Kernel = Kernel("copy_bf16"); pub const I64: Kernel = Kernel("copy_i64"); + pub const I32: Kernel = Kernel("copy_i32"); pub const U32: Kernel = Kernel("copy_u32"); pub const U8: Kernel = Kernel("copy_u8"); } @@ -86,6 +89,7 @@ macro_rules! 
ops{ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_tiled")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled")); } @@ -96,6 +100,7 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel("copy_f16_tiled"); pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled"); pub const I64: Kernel = Kernel("copy_i64_tiled"); + pub const I32: Kernel = Kernel("copy_i32_tiled"); pub const U32: Kernel = Kernel("copy_u32_tiled"); pub const U8: Kernel = Kernel("copy_u8_tiled"); } @@ -110,6 +115,7 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_strided")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided")); } @@ -120,6 +126,7 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel("copy_f16_strided"); pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); pub const I64: Kernel = Kernel("copy_i64_strided"); + pub const I32: Kernel = Kernel("copy_i32_strided"); pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U8: Kernel = Kernel("copy_u8_strided"); } diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index e009ca1d6a..484fa0a1b1 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -602,6 +602,12 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) #endif +REDUCE(x + y, fast_sum_i32_strided, int32_t, 0) +REDUCE(MIN(x, y), fast_min_i32_strided, int32_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i32_strided, int32_t, INT_MIN) +ARGMIN(fast_argmin_i32_strided, int32_t, INT_MAX) +ARGMAX(fast_argmax_i32_strided, int32_t, INT_MIN) + #if defined(__HAVE_BFLOAT__) REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x + y, fast_sum_bf16_strided, half, 0) diff --git a/candle-metal-kernels/src/sort.metal b/candle-metal-kernels/src/sort.metal index d71ab82234..b7cf71bb58 100644 --- a/candle-metal-kernels/src/sort.metal +++ b/candle-metal-kernels/src/sort.metal @@ -88,6 +88,7 @@ ARGSORT(float, f32) ARGSORT(half, f16) ARGSORT(uint8_t, u8) ARGSORT(uint32_t, u32) +ARGSORT(int32_t, i32) #if __METAL_VERSION__ >= 220 ARGSORT(int64_t, i64) diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index fe04f2378f..0e043332fe 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -75,11 +75,25 @@ WHERE_OP(float, int64_t, where_i64_f32) WHERE_OP(uint8_t, int64_t, where_i64_u8) WHERE_OP(uint32_t, int64_t, where_i64_u32) WHERE_OP(int64_t, int64_t, where_i64_i64) +WHERE_OP(int32_t, int64_t, where_i64_i32) #if defined(__HAVE_BFLOAT__) WHERE_OP(bfloat, int64_t, where_i64_bf16) #endif #endif +WHERE_OP(int32_t, uint8_t, where_u8_i32) +WHERE_OP(int32_t, uint32_t, where_u32_i32) + +WHERE_OP(half, int32_t, where_i32_f16) +WHERE_OP(float, int32_t, where_i32_f32) +WHERE_OP(uint8_t, int32_t, where_i32_u8) +WHERE_OP(uint32_t, int32_t,
where_i32_u32) +WHERE_OP(int64_t, int32_t, where_i32_i64) +WHERE_OP(int32_t, int32_t, where_i32_i32) +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, int32_t, where_i32_bf16) +#endif + #if defined(__HAVE_BFLOAT__) WHERE_OP(bfloat, uint8_t, where_u8_bf16) WHERE_OP(bfloat, uint32_t, where_u32_bf16) diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index a82bfdbdd6..0c5a2736ee 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -169,6 +169,9 @@ UNARY(id, int64_t, copy_i64, copy_i64_strided) COPY2D(copy2d_i64, int64_t) #endif +UNARY(id, int32_t, copy_i32, copy_i32_strided) +COPY2D(copy2d_i32, int32_t) + #if defined(__HAVE_BFLOAT__) BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 0da2c70028..55b5542ed8 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -151,6 +151,7 @@ macro_rules! pydtype { }; } +pydtype!(i32, |v| v); pydtype!(i64, |v| v); pydtype!(u8, |v| v); pydtype!(u32, |v| v); @@ -200,6 +201,7 @@ trait MapDType { match t.dtype() { DType::U8 => self.f::(t), DType::U32 => self.f::(t), + DType::I32 => self.f::(t), DType::I64 => self.f::(t), DType::BF16 => self.f::(t), DType::F16 => self.f::(t), From 6f0e1908754bfc5c11e53534f2934b0a94e1b0f1 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 14 Aug 2024 07:11:36 -0400 Subject: [PATCH 32/75] Fix on metal --- candle-core/src/sort.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 92ad1d5adc..9d9fd59634 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -150,6 +150,7 @@ impl crate::CustomOp1 for ArgSort { DType::U8 => "asort_asc_u8", DType::U32 => "asort_asc_u32", DType::I64 => "asort_asc_i64", + DType::I32 => "asort_asc_i32", } } else { match storage.dtype() { @@ -160,6 +161,7 @@ impl crate::CustomOp1 for ArgSort { DType::U8 => "asort_desc_u8", DType::U32 => "asort_desc_u32", DType::I64 => "asort_desc_i64", + DType::I32 => "asort_desc_i32", } } }; From ec55f5805b0e57a8a76ba398dc38cc1e806794ad Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 4 Aug 2024 07:14:33 +0100 Subject: [PATCH 33/75] Add the flux model for image generation. (#2390) * Add the flux autoencoder. * Add the encoder down-blocks. * Upsampling in the decoder. * Sketch the flow matching model. * More flux model. * Add some of the positional embeddings. * Add the rope embeddings. * Add the sampling functions. * Add the flux example. * Fix the T5 bits. * Proper T5 tokenizer. * Clip encoder path fix. * Get the clip embeddings. * No configurable weights in layer norm. * More weights related fixes. * Yet another shape fix. * DType fix. * Fix a couple more shape issues. * DType fixes. * Fix the latent dims. * Fix more shape issues. * Autoencoder fixes. * Get some generations out. * Bugfix. * T5 padding. * Clippy fix. * Add the decode only mode. * Fix. * More fixes. * Finally get some generations to work. * Add readme. 
--- candle-examples/examples/flux/README.md | 19 + .../examples/flux/assets/flux-robot.jpg | Bin 0 -> 92230 bytes candle-examples/examples/flux/main.rs | 182 ++++++ .../src/models/flux/autoencoder.rs | 440 +++++++++++++ candle-transformers/src/models/flux/mod.rs | 3 + candle-transformers/src/models/flux/model.rs | 582 ++++++++++++++++++ .../src/models/flux/sampling.rs | 119 ++++ candle-transformers/src/models/mod.rs | 1 + 8 files changed, 1346 insertions(+) create mode 100644 candle-examples/examples/flux/README.md create mode 100644 candle-examples/examples/flux/assets/flux-robot.jpg create mode 100644 candle-examples/examples/flux/main.rs create mode 100644 candle-transformers/src/models/flux/autoencoder.rs create mode 100644 candle-transformers/src/models/flux/mod.rs create mode 100644 candle-transformers/src/models/flux/model.rs create mode 100644 candle-transformers/src/models/flux/sampling.rs diff --git a/candle-examples/examples/flux/README.md b/candle-examples/examples/flux/README.md new file mode 100644 index 0000000000..528f058e38 --- /dev/null +++ b/candle-examples/examples/flux/README.md @@ -0,0 +1,19 @@ +# candle-flux: image generation with latent rectified flow transformers + +![rusty robot holding a candle](./assets/flux-robot.jpg) + +Flux is a 12B rectified flow transformer capable of generating images from text +descriptions, +[huggingface](https://huggingface.co/black-forest-labs/FLUX.1-schnell), +[github](https://github.com/black-forest-labs/flux), +[blog post](https://blackforestlabs.ai/announcing-black-forest-labs/). + + +## Running the model + +```bash +cargo run --features cuda --example flux -r -- \ + --height 1024 --width 1024 + --prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k" +``` + diff --git a/candle-examples/examples/flux/assets/flux-robot.jpg b/candle-examples/examples/flux/assets/flux-robot.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f715743346d1a8dd6ab45bea8136921890df0f3f GIT binary patch literal 92230 zcmbTe1zZ&E7e6|?OLvEKDM(0nNvAX-0@BhFf`Xz82uer^3J8caf=GvnbP7@ef`XJF zScIa0_nFzn_kDl=dq1ChFAKx$?9A-UIp1^6_nhZB^Jn7EEJS0Vqpt&@&=7XE>4a^ z?Cjk9Vn?`nMflj+1?7Z9#3iJqrMZqMsK`qyi%CgK!Y4tIl9E!8Q?O7{vPkl<^GN=G z{`u1mae?b3L?J&A{s)D|5D*d(laP{;gB|K=AT$btK@(sI2?+?m-r?YLh=7)mjz{7+ z5xtocF>e5aWJJn+5TF)_2S9uhcwL{R9cw2Z8ryn>>pmbQ+r zp1y&FrIqzbn^U$fu5RvUJv_aF&IgBFxELB1buBvP`ij|9{=t-y8eqzNR1w3<|&qLkp=vA>Mf(n58MC z)XB35So55huoOc)LEsH!H-v`YfCaNO)SO2+0Eu~^aGwmLCX2y#hC}d9QF_Gg@T$Kw zl=ttwSqm`{@KJ)&&?=E&A=sat64I-SF(^*%!st0KVOfmPd@h!}lAbsGJh4-dl=`(i zG?ok!gz|_n>eNKZJt+@tKA%c7h>17KRuN|OKc*eI;4?Z*Yd=*+(`_@BDv2d?6cd7c z^9bK+x5y5m*Yk)8ivlLnUCMks=|!yMjpBe&J1Y`hLT4yC6eBZ;H#yORFg* zqilolMz8aOg$YmF3yd-8c_N3WAR|*A2_rO?4$B1okn2EW8B8(QLP<-z(FSu7{!^ts zh4Dhf^_nRm0!;g3%595W+vmI#uJ`gZz$e$)yEk=DyjlJ9Z}TN8bM40MqJC9#?FR5q ziJIx%X!G^nW0TLyf`|oQxP#p<=kMK}N_1Rg7u_ohBD_w1)yuA}>Kk#raANBY=ct{? 
zv6OY~)>amkvS|@cZDw*3ry~EV7-cj6s!m&Ho}jqV+leDGD4mtt3rYZA-i&F z+To_PV|RcwLeJZYsJIkP*yQk)`IIx6AC2kL%i$T&bel*!ZoZ&uX8er(Wv|B9OIAXp z?G<0giZ!(;8O>cMZgf}$*mx%9bT}0&iV2grl=xSTa4<8*Wb6UqN*5``wfacgSyRTu zM@Kp3&pF9vi7`L^%aIDH&6D%cPdotC!(LWeX6S9S3hjZGk zQjU%Y4bZOlr+u zvFC^1g>6%F?cFFaK=9^{c;NL(lOk3BMf$ap;Y9xhc2Scev4HQXUEb^FOEPrR)<4%z zH6IeZlOI_dD9SsmX%<*6OwSuhWa^pDg@zn%{L&M|TOl)3Pm1 zvFa{znyzPfw?)a68K_V$dz1bMz2!uCZh(D>AZ##pLnw7}b@(e7pgO3I#1`?qU=x`j z$^*PcfyjZbU~L0y6$O`p^VgGjWlE`25k-(uVsLt#{l}35hyEDL91O9QLNn?_tB}zU z8de_oN}xASY;z$h%aIig?)juG6EpB&>|?c!6`jD%DJkj_#DW1);kjCGK8tc%bMQz43r1}9OiiYGa) z-zc*SVdvfkPW-~;Ug;7!Y*b6+6yg?3jMu&kE|F6)?TjKmR>tH<=a}bJojZn-4swl(f6HbZMN0-8WyO@t{E%TbQ3BZ z@JhMs#%f!qC0U+TcbrzbT`pLhGbzrnm9HNRte&bzb)kquA_gakB1S3~F_ zbR1+$Aecv(M~LVU)+4O0Jm3@OfNFtE|EdMTdt96`;shRq_uZ?IoGA+E3}G1`X#noX<3w32CZ@)cs9R;yFd4l{OL zLuXVfVwo}`fuk}Yl5P+pyfOJvFt#wimCw_2LQKaE!~xH&dxCb#s|p%(%VxdIGy^7S zC;2@0R=&s0EC0HDKR|1=S!4d%D7f+sk19KFL%5^u8MN zqyr1R+b03icwCG$rPQwr2kE(83%hN5dBMAao&OyL^M*lC1*itz5fL|{0-PRzVzE+2h#tTK?-G%a``5l8 z_Y5*hZlFs1v8W*2{)7YEARL>*4wea}`Pz}4Gd*%X>b2AK#V~7*9;;O!bJrCCLv6cs zx3a3wYOj6(=izPq1qg}P)3Z=GCG%*7amOhkX{B3$4lraoao*4Wd?Lww2rNur>))mY zhMG6b(#(MM^m3H*&x7RiTl6Z0C{k?9(dwQg!4f!6BLERZsMFRCv(opRXcK-~uTNUT zl=Yx}I%4^O9{8eXtIhgNgGzvU_-{P>SY>xc>0P4Pw5IGO01KzooX8KMX=}o9T<_Aa z2h6SdNX&g?84rV^-<9#6vehjIO1bcWKjK{(o%MwWybDxg8ocfgXa+_hqC$c)_|9@% z%RDY9*iI+sA(|0qO8)Eofct~j;TM1rplB#{D%heB_1|9;)e%vn1eKz@sVe~U`MF6Z!%vZ@}fx)`;7 z2C|n+ap+c7ZA?|0j1r3V?z?yto3SyEUFF85vC>ufo`7 zdSk0z<4)GH2{F!KcG!$6n5kvs2|G3ft3IO*M3~w60g{sS0WE`p)(I%4UGzdRF!;jM zhrFkSpQh{6OxiD$>)lP30_-#M!8ILYNDjs}gA63EP$)%KfoJf9!1FOB!#F~G9f()9 zL_lG9n_>rE8D~-O?tpb7uLA`mw!jJ}4F53}SY7zQ2J~{^>|pQt*R*id#11yegHtl# zoH0&mdO(Nlha!@w4O*08+8LT=BgT^q{Qljt8*-*-n4W<9PjLlCWl%F(2aHU{Q2X$S z+EMwmLcJqDih5O@jJuw(7c8;7;oKb6?$>%F|HZ1;M1u8kNw=7j_2*mIop{=rB-`ug zv_7HWFTjBo8Z|m>DHjFNg_pm(BN#;Dcl{wa{E#1TW#kzNGrva`n8n+IznCQ3wfhEt zvCI>5W^Nh`E20a%!Fe^Tp6)acFtZBd?kwBF=)l8SIk8P`#|~oEoB$I?7paiV1?K)D z5V;IQl#=L%t3E3mf*bpy2(0dq(^`+}BO%NAkOwLT zOV0qM0Y0wI%BV?U4t0ow_plia!`gv010t6xPXLf*N1g@cXmASIMA3e zdqltV^gY|l`AaN*Dl6Zwi?_+Bn%Ozs202%;(MVOy=lqt#tB!)wSP)%XhFNAkpc!qf z{iu!8)8Q&R6LA2q@#J5?Dc;O|(4y6u5ks#YlLh$+LT}Yc8$gx;dD_Ci66&n-#S#+g zq;X)CNC;~V72^~^*a3+NqF^l$;r>PE5yP^FUv{RbICz1mfdN{DlV_aj04;>s2NTZa z2?E=}1v;=NE=0lZ3o%xNKjMrF4q}Luan2DIdy4_W9MnZ&eu8&>2od}zo9r0i2I>Dr zFDT@fHOsmDRvI8_atjo>DdvJH+Aj|S$mK`qy*kk^$aoVvmq!4vilEPxIMO^r(+6Yd zgVB_u5ZDh+H82r{60@L+klZE>-7Q7QSpc4S?JyRgwa^09l!x(!#Cjpqv`SP>e1b-r z9Z_1ICfNw;i3f&@QG>!1?Hmycix0>Uk&LMz4LF8@!p;=^A%YzK1T0{w0V);@hX?S& zFhKA=2MZkRffy0l;U&wGN6-PWAe-Ro0H_TB=>j{z3IT%5orjKySI0Xv*f{cG&xa6Z zAZ#G@17ZPguo-6$0O1H-$2%@iegV({!uni@cF!Pv3^FJZJP!wC97(5ugzlDM@0P)G z3Wg4Q0LkEk6l^Yl&ja9ng%#A2MIZ;C9UKL|7$=B^i1;{lLHHJ(Y1JutBWTl$d4R~+ z7@g^TF*I?Y@PT&5Qdj_tMFClv(*S}43lRMQ*{FvI2cQr+1LC&vh(z#6BnZK>Igp?; zmYDcU}&$x*m9@eVWvj8gz>urKUn;GHN0X<+%8!-8{CkIr)*f-@YDP9RA0fDk^0 zkIcXY!03f79yk(PI5ok6@qZ#X5)0_1TZ~{b41|R2CuKwczX4l7ZuWP9C<8*ORv@dL z^`x*Hf{gzjgEI#ZurOrm_-i1_f}D!crhIGwXcj1du?8%?FU01Fec}n1pg_E7z=HIZ z%>xxOglYf+z***jAYo|&$*h^8qv4&Y-s%A?gQ62oD8RyjJ0Nd}0KYL)kn9jQ@(x)A zCxgGpgT^`Lb02@F#?DG(-1@K*hQJ49zz>Qo3`ZItm zuyT-M0)>!f5ghiw94eewt< zjR+=T?Sb>&^`ObfbBZHi$#Zs&AQEJOaf|b%up|MWgx>(!;T+`w*kHE^3kzh&U=0fk z@IN2HTk$e0Mw00RBZj4e9E-y-{y$+-xsLhBDX>~fPuXt5Gul`5ab}%`;WH+Z{dA`4)C}_;{b?YOA43G z{v$2;ME@f|{3d#7ltBi<9}oe7>=Pd;>87xO#_s#XD&R2FL>W zPGG;_tta+xaQ%;e`@gx!zv4ykh@4W)16>KufV(wte~6c3MO+;Ythi7N5(0HNA|MPD z9+5MHr3eqpAqjDoBaA5Sf8+zh04HWRq8#uABb*^1A0XfUZ_ydRJDLJ@)4>o&qPYNY zKWJdIK-CDy&@c%=Y?3q!$Qm-t(x%jP5jd zy;&?<)_GwqawN`I;7)H0sk@`3k&ifWr7`Ew8EX;k59A7`Q3Agz*>(;oyqwDttt|1l 
z*#G{=V%bK@8oxq4xeTjkCh>1G@%^-&*Q%Yz_X?&T` z!fbMgtJVcEcb|$2s=B-&h#@N}DUs1n{a_!nlkZX=#r~t4Gdsw#&~?_^_F~MEEVqFT zRSW0d^^&mij9%*d|T9=LadhWK`t%hCk!G;r84MbwH zZ!l4koopB9C`nc)-3g>#F@NLSJoP&1w4s&xC-1AQ9BE(nP2#LKINn$aS{Z%#um^p5 zCIDUG=+0X?mWtcQ*M-X`jgy-PhNjRe~Ya-7`PdU+!NlSvR}ew(S}p z?lrOZt8DY?ot?3=!>%j(^yR`7<-*i1Sr@bd>yAlxI0@2Z0?VO590_n2uu%gow!=zC zQV0!r3+`hiuo%ImRy>a2AVvlWk&xyQjj)*0q_rRnk`xZeNa#92-z9f8_Ecw@~Y!H0YVN73+UfSbmH3e=jwD0dW0#R1;_13qMw6r~OhK+6A! zp8*SGI>h%c?C)jX#QM1q?1aoYn&r;A?iaoh(EY+WPM2l*IVj~hdx#y@TdFMI(5L5; zkjNFYgyj;o{l}HpRkmpTs6Oaw?v@3nJxCRpv-<-wafF{WkA@y+2{k{wsu_wAyx z3$c=C7=2=tgQDJMiM8DpX(lQ@ZaSY>ST+BiA*lOGugFaqg~Qp^=>-W3E*ECZRuXsT z)oc~+@_K4}_}Z@eJmGNtC>wF_vDAKe*roB8=)On+w^iyv=>o`57wH%+ng@~l%6``NQ+Z7Jo7suA*k~VLutE%^&J}uF+n`BUA z=kPMMt74^l^3d3@9J`^<&L61DmG3*lUe&YsHTxvXs$l!6t#Y&pRc+U;_c`x}vmOT8 zALF4aaQ*{bIN5l4`}d_sA73AOld9+$+}IBU~2r$Q=#^< z3McpOmS5iY<}SB#)x3}TRO;|x|F`;B?I#Lh-hBdv0o%%7nNN>iqVNqQdFwM3(un#4 z6=(T0?V-x8-nhtJyQ70?SCrs6W0uII*5?0^7F;8?+{(HC9jEvGFd?U=r2fWSp#rro z`WUW^a&r@xBSektUa@RQXodYLLN*34h%F#qmS8{}Y{h+qpb|g~VH(G?z_1v@HJ)O` z1v6lUu%>W;1>P9Eki-Isx5$S7;`;x|1OSr%&4V)G{L5&F8^$?FEjSMn7~uu3G>c%L z1sV5emL{=-tE|yPFwZ)OjP8irg_%ep_W)q6z`XH%0rib?xO9v6)*)p>i>BWWWfv3_ zHjaeslCz&=$SrwoZXKy5`!#b}FD===irT@8t=hsmR+_9jyvBigu#P?W_l(M8>+7sW z)JEZ6FYoVeE@Xcdf{OGUm42Ix^2pJ?OELJE-x%b*1!2qx89a4N~Lo$%b|Nj<+?#{UHYb2s@D-C-@Acj7=P=_ zj31j;*3?>S#1@1+94r*b+G&m-PjIFFt|BDcXVmxj=2UP?P}8mCAblyz^muO{*)}%G z?DexdKS^6YeM%-GE;5&Np|w>|9lB^~y!#buqoBUbskl|lTR&P`RL5&vIxI+0o{*9l zw34XwIz+xh!pFBTE99&>1E!9L+ckvm?Tbo}rSO-9v3AYL8T}EV7;hTV-Ibu{b3L5x zAsg@bBn&ex>2piCa=n@-n)xQO!o98Zwu=^;{y-jIFkgK}EF4$ziyGtoM8s--FqH|@ zewTS5$TyS@C5}B65ctBE^-x|dIQQn+i+8ro>^j9O73`KhKMN0E+esfhRG}c?>@${$ z2QKi|NMe9XAYgX;uPgx}h5*Yj;=pVWJjTTVrv|Pu6L@ii2Y~1RfCs#O9s#@+2@L;t z3&frNWey<9<1r5QMNo@305~oHR*r``nqr9HJaEIvaR_MfliG+7A^}on9E)ZEypR}9 z&JKXx{Ff`@2N*!Afqdb~7frv#GD8EJ9;IU%bWKCDuhivsymZj}W_BwvP_L3GTtFsY z?6jLeLM7?6MDE3vk&RNBGPV|1Ns+LNn<8B*C}pax8r`fl_2lmM*~2OqsK0wf4}>V> zouBE{QCOC)G{2!OZfZI2N7B8jQX3b`Hd0l4TH3Vhn9cNV$&1DPFiWACiE+!dN^aBb zN4Hw~Q1`a={y=2mT5p;({R~}Ge$-sGEe`UF{yymjU12;q)bnw`TU%s^dvH+j%!HT`ui~z2{ID zXEe8^bjoe4FqSs6(5Wgc;bgFa0C{$>^Yq}!ukvnRm&pClA*Lq&=%f{Un+q4SSvovz zKGxb3<3J8TillHbO5tP&$&oX0$pOF}ehpv^VJTp?37&@Funs&toM^#yM37a&IXXgd z0RDlS2QOp=@12ovLJ*Ar5f5M~^)?vYe}RwGd%!921cs&;9`%9C3h*2xF%OnH;IiS; zWe1+bOyY;GksJumY``5MrrCe>q(cKpWkk^EmepbkpF1%*44Y+kV~bx!BH2 zT>H*dU!5?C~lM@qFD?+I=3hH2ld*Gy)<^-8|vNZHMb*xL%W$D156-4Ar=+&eSw z*PnJg0fl!dU%j_q?{sEPHvf@uPI!e``OKg&^=7P!)bA=yq4^2g#NW|@wjbY}=wDLS z{Gdy(@9U~#y6)NNC%oS~pQz{>7Vyb_)Z>hGfb}C^+UuVAtBNK4alKN5y*7LLuTZ-| z_SJ>&VmXa3hbx*`OKu#0TFY~)QDT9wOykP6Z(3y&g7fnh?&mHlwIAEAAKt<|U3L|f zvpjOMFt@At68lNUFf~uk2F{w)>bKIe4Rk1(&JL)iM`$XbWnr5jWhDESo%MyGwWey3 zTQdfgWH+2A-s(-4T!qem%;V_dUMjtlmzDEWhiHfIg5Ur}gnoMkgwib=wCE&O@w|3g zVoSyL&GAO9bW%4+ge|fm;D-KTvp81$g_MU-{j#_Zyys2}N%*ZF?E zVKnfg_*^0*YG8i4l;@5-t9)nZr#HovZbA87zH_C!_o$`WCO>T|Q3y8NhGHy3QbGJ?W7wBY-F_)O z+H=OY6+Syt9-f$Nkn&y0z0NtgADnKV7VMk`{wVy>|j>b%^#I| zmQR$5IH*qWMRtdFeXoqA)!!6w)J}i-vJ<_$+O{in?8f7Jr2MQ@*WXI#F7#X~PG&LH zU#dIJek}0wO_p$5BT-93Jl3LSd3!|uS#L--jVZf7ij$iHP4CGsSd0AzG!<& z-7Xp`WmlsbNB7iNMWWwKJjW=|Q26%ESGzyt{bp0wsKe(OMb4R12AM?})LtqHp?F{* zD8C`@PuxH~QdUB;@q|{ck1f~f?77en9IiAmr(=_P3-TXrnmH6HhSSbFS4$KzD@Zk| z5;WF!<#H}LTx%qwuBx0E(%X4F`Aq+buExzj(6xsNqBWGMLD3(ne};RV)u)Q_l<%?s z81O5+@1t!se`E5p&x5N^vadM)f%rZ-1l!ni*~e67nd`cVqgZDBo4Z%FL$jYz+Z!6@ zJgZQ1uW5{8`JQjtHH30l6!73%_BPn?xg&Ew8I)(>@)#iRf6WNw13aMEMk2_PiGw?S zd;s|Wzp(=$S@?(qp90}TNYDWb$R7Ei+8MP}9K5}K>wY4sax=Y{P2ht`f1Tl=NJ;ta4rh-+rEhz?w^qd` znB{Nk{DB_L?Ywc|RJ3cO7Z&C$`dZK$F!?heSpWVXNdKuogo?Gq(3Q;H^5G~z1p1x1 
zze$FtjDE7Ul2Sn2D92CV>Kja$gMB`79+^FP^^ME61TrbWP|PnVH;szjx{G zgz=icH@h5lhtnF^FUUFU?H17XwsT}!@JFls+?vl?@ITr7uH^nr|KDFwT=UNln5z#pR(((2S8UZ37W84i!nOQsSjYFyP|#xFx{F6x&1dgjxiO=Z;vg&VPIqFCFU)64iiV(^{)djL9ZkT^ zcU>tqkzu()<-Kp7izHVQpD!sHVf;8`%Qx87`}*v7mOzcI$8B-JmP7&v)%4+4hiJ@- zuf6DO%f$u_*eh+Y%m0?I3w`y(-%&$cQu3MNn@d+#zU|)Z?Ml8_<0SBdJ%y9{joSW$ z*+<^@oYIcB#y**CN}7~T6V9gl1F1`T_u7uD=KDJEwVzp6d|LM0$mta$(Ye-tqRrjam^h6A)IC?5d?gWLd@cOWMFaQ7{ zZmt8e3Ue;SNDXACVsWrHuD5}K!XG(O2p&uUYrt^okX3ld93DS`T|Rsw32koqb2rSF zewbWW3tJBYGCTS?UMA*0{Vr@@QvIO)=KTlYq$O)~(T{+UgqDfQH|N zM;fJ7;!7&xGc80e@|WtYN{rp8_{C;u?b-g~@tI;N3wqYiCL{ubSJTFmoSBwQ=By|Z zCJ1Ju-Z%b?n=Sjs5$ov^6?0~7kdc^!@A1qZ2qn5t^epb!jM3MC&7J##_Y78;7f!zo zYSRK&N%qB4?`N;f;t6Xx{%3Eg=2OE0%+|J4SU)^_n|p5l?#P0;gZ|D9t{INO`O}7a zHX9$ctx9JG`NsH#=-a8kRG#iL%T9haxVSr56kwH;RQl5V$)?g5GLIK7aWhMC+>Y^% z({kecUtE&r`1syGyg0@dsG54$;ZWuO+47t^Rp>+aCvRsc;Hm303&Loj;=zg`Mnw zu>3#o5GN0yjv6HqBqu~U!lQ@)+hDF$2Vg0Hntl5)m-$YvfGgMcUXnK}omU?Qj9>Y-?l|vtQsL*F%W+ZV>?jMm)hD8Q$FdED z>LcQ*Ff;{?4wp9W^)~9}&h)-_K0ys;yJI<%j|Me<8-OAg+XE^bot}*QOnSr#3_fJy z=<)3mc;M=5*nFA7PJV&5KJ1Fjt%-MvS*2IcOEz`CQW4T>y_BlsU!&$Pt~NMlxh>!q ztV(}aut-Ut7~*ON7S!707zc}iW67X^Z zVjSyFVjOyOiJkg}vI$CG_0{SuyGmcegL~U2f8HJ0m@i!}Kjo8j8g;vOcSPX}xu?1J zulH}=Z$@Y9`jK|^t)z;Fo0O?;ExsvDs<^M4?h$q~d#KQ8d`(t$<#OlFh3whe1ZYLn ztv5tX{DYuLs&CYZXng^?RiEFM^Fr!-;c~iw{>n?1irnjVGS@^tsjef><_4nO;rMPbQD_L>|8gxjHK2^@y zdFI-k_JGHXCEE>EZz}G6Y$z0QQD*uyy0vgy_2(mDdh6cq-xM6=UW4jG$yYtz$uH9$ zU0&7SPiGM*cXOxvVJ=rbtehgq0!G2Xi~>@SKf>Gr_X2eLANLH;K|*M`5SS}~RMX+z z;i5m1$|Bh>5d+{dz)-;44nRm)Oi@PYWI}N|@CU}skiY@@SHY<7!Q>$jNZ=EQ;eukzTvjIzI-1lJEW5JbI5t=dUD}QE#K@IJ&Icd+rNFe=OEM?5=J-Q zUlF?Z2Z|0BR>)j?GoY<3=+AN0;VRjeM4n*zl9+SOckWi6b!rT84!F}`t(M6?{%-Pu z!t;_xtv)Nc3{ReC9v)lrSBqg^tj-RZ)43E}>D5$JJ*i9T`}*w^KhwtBi4M9p8iKHUPtYS)7biv$nbK8gi?b?WFghBbfIOI!7J2^ z^{j#2*+za90e6_D?nKvXvOk$kr()7l7T@k#RebY^tkG8L<-_B0p1d0ud*|lg^1q=l zG3A$6`!OD{WA*#g5BYJToNW%bg7-h82ilt(>W(aD#tQs_==->aPO8{xB<7fC#L%AC zwO|NXx!vP?j&tHxv}Wi0n@2&@zux2sH7nJQU1P9HX`XALtnbhB|BlLbk$1TG>z2c9 z(n}9>bsYB=ZbsR6WeY&kL+C@m4&c-k&LRQ%Imkm10)&G;13V&j9Do>}kBO&7;1djj zS7eF~nBxQtN*zEPH+>V#0YiGYffcj}fo>gO7O)W@CV}Vcf8T@6C?psl)0crY;hMm} zraW{6_>clb0dNWrcvc`9Om%}z1UJhTWZYnuO$4=Ol%B`;_q4|sFP2ywOMV_?ATiC( z^~ugW>FK5Ot5cB*AC}EVU+Fk$4ypv~%M^s@idguqHHZG-24lK?VgPZ56XRq*glf== zLE3!*%^0nqA&X|!eh+M;MTc-?@uV}0>_SQlhnj7#xsu>87uo)fj{eL0N^t@8eb(m| z-e%q)JTYJQW2vUBqQXkx^R1~{F+DxgC%pC60>^*eWYMv@v>2h5V>DqQ5u8~Z?#XT4 zJ?CP@`f2^sEp5Td)9F^vU6VAbd>h97Us>y~f8f5YbbO!lc>!kX=_&{VUE%Vo_WJip zy`pnv)3bDXSaxDtI@K)Vl05phJ=Y`!uQ8hziu{3|^54*(Y5F|aZ=%jztQ4ta5%^f> zIOC&pV_yUkXH3#ut<1|4AXP`yXOoA9K2r>DL?>S`Hs0|+=0~QHq8+KnHujcFyo-I| zm@3x#^~A=i;rXC_`ES0T6S%AEJjOq3yC_^7%QXGkyv7z8SpU&1vVL-T>~_BN+4e)0 z%pVgPRc&ihOMXO=$u4BO28M@%G$ZO-mTpmWWpl7x-7_KSx1o$_Typ1qg@bFsWkfOEYV{c z4Oy~L#&JY<1Pqd`uFH3>Z6x}u6;ahqO{eQfjJMa+BKY+@?=)x4cPo!ot9u_@sxzS%Dwz z>k?&e_7xjM{(Aa7DYiw+upE_ zq!P8khhST4U9aN8*Zh)-RcwEtkP<(sWN5i4NOnpo)%vg}#jpB|>wX_^uw`7nU#0&C zQXM;4!#f`D_Nm#izPY(ctmt6|fyHQy8CEsZPsK#K-Z}Ddbp-uUdx7@16GZR&vQwq9 zgFfl};-GU_Pnn*4C)!SCs@iJRBb*K;>kl4J%m69JyX;5z7ssNLj*Qk_rTLUqw4Jw2 z&)(N~{vJ1b@`ESUxfvgxS;-eBNwgZ*)_%SD^xoIFj{-7Q5KnclgbI*!@3OEt~9^GefGqiN}BJ^n>JSl=GGNC0~=Pdw-(O+dQ)yV zryFhmd#%n}leQSlT5Sg~WJ4GW?CcSDk0bgpUV(ciKzc&JDo9EJnS!YlAT%@p@qfV$ zc0?u(JL4KK@USPucTD`9Q*0{U0SSo1ll@8cK!pY-QV4HyhAqO80SjjcwjDg6C7=l$ z1LnqevUx&J;2s0eU;42zV@55 zU~5+}onf0p*&=aS;KH2mm9eX##V6+UckZ8&2>|)(fa^;4%FT-oD;XVGhbOsRoyKe& z_jT@f6a<^87TmTkQAsdE_#9}GSMjgrY z&t4rO&Zt2pb;er`-zZn|?nZGv^l4DYU24d+oY4_Kt-Nm5!#Or!|1sD!w&;?7p|XF@ zTAQERz3-K5XM%=G*u2b7QTRp+d@m3&fzH1Qr=*foG-RT9;eCT|Ekuy8XQw2izbDgh_IJUXEygznQ91qI9~hPR 
z9fmI(GX%Zu;`dsS?r2nwwj${|vMa26-P2)Nx?IQTP2D$XZeupm*o~=~HEYS?v$xNF z$_~g5S$b%8M>E&_=C;UN?=XjZ?4M=Vqk7)^7j-B1OSSMXM7*Lt;rZCyRe{&uYi4QO z_rv9y&TDg5pCq{c`c6cACdh7hM6)Kojy04x-H+BLk-;@ser1Qu_oItmzTvGs#Ue%>IcnZlKd3^KwG6t0-gO`8I|41xgOL?t{h z$J&1!9&CM_ke58boK_gt{B}3C>~_b<$+LW4J4eFz3&M1NHQg3aT?@RbwddvNY4tHo z)vKS7ATc{!=tklP7QYdwK$CZNk8*H!e(QK%Ot~C+S<{1@+_Kqh^ldN8AL#uN;?0Vg z!#4|Jjg~Wd2#OUL=Y!``Mk7`8M&<7wy(5qq)X2JneYuoxb|`1_uwgNAltIguOr7CN z&V1G?&2Sa!pKnxoEvlG7zUek@moz_X9?#)a1p_*3oL@IUEP;h_b!CMWrM7Q~&vqv_G9#_2% zEKTW(yCW+5YUkAw|BpunwtS1HM-zrHpvZ}#0W;%ZFMp8jBJN$oRJ@o6q3*4)qtA#; z1|v&oT$1Yy_tDjXY&Qb5$-+I2A8>tl0HJMe%+Mr2qM1KsnjOZC872GLA8B2(rc9%1 zGZ^9_8$#2<{W^yQY2YR$3LGOs!!3ysF|gDBYXLY)g?~%QQP0fs1a$Vnr?UfXYQPl! zHmJp85yXR}I*#Blemgxw3%(Dkodpg$c&j{U*NuY~_jwJRzZQC_Efx(Y_1KIWjo)72 zqU4>t?_l@`QW2~E16|Jh10CY2TKyd2UeRtq0G|Wgw~I7t{)#{r>C@9?9D382Q|mnI zHtH{$b+05XpVH?zzpXHye!10UnpxRzDZlgS%yrViE0nrC1_c6%D+P`%*+~X>PApZG zpAE9+2nbnEPkMe*rcXr9<#DcpXdii*w$f3Z)&6Nzcd3mhoAP-PAKk*AG0g8CYfqeW zd@*nRHPHQ-1opbgH=++$tG|Frnl5VNC96hEE@-COH0iuy%!JO-mIB4D9>L1&Jwg^xo2)qxRrIT^bhpYhy$7- zh)?V%K^LsvqCQ9Uftp3TGmzvDxo7S6pz5=guOG9*JE>2G ztHZK3H-4}`t~;SoY@ab_Ns$_k{i4y@Q>|>ow6?aZP+;P@fA8`Wu^hcb`ah6J+SZZ1 zf|ZF&n_K=(RwgH#=jN+(M`jW`cPi)mK$meHpdBMH1pz)6&~QKnK&L0nG=Kt-Mg}5L zG>lldMWr=@2!;&Y7YrB#uF*4@j1U9VG+PILW5i&G2`^>MIimmKl3-H1olR&Ofu~tO zOs^QzA@c|uCj;lIVVGFlz&lhN#?ZcbB5WSyxOT;Z7CYowXP_6PON6<-%OB9f7 zK|-DVLi>1UbV~Z5(G*8&0%T4ZXk} z&SsFjxLMAmy4|zS(CmD&{$8M=uWD~cyzQyer>pk`^R3h}7z7KCn4b$V^15+I8w!Iy)8X`a+AZbH9);E8a4Q>jI$3_b_Q3##*aL7D9`r71)IVV}O z02}Z(VIl?|QP!6~=lK0&la2CYF(&%HeDL%ud+_wD8vG}d+O02f>+@Cc>groCJo z2D`uT5@Vu1HWd$^FLwnzUyk!kJu^8|DL+yM2UHa|o(cLDVW0?r>LADrX<*N6giHtk zh>whj1}+#tUK|eV^H&a+?RSpYT&3*$1ASE6^LnK-8h@4Q542#xH@te{g?0&@mtU}q zyAwiM>ad9`g^8~&#P2z`_h+9?rc$#-`Gpddf1swbwb;c`&}%#e7`TYEu~ zE>4_t$vlrPU?QYUnosUX6Hx}`#>TGNLvDe_C*mhkZ?rKuwn)g!D=c%AHp!e133)+F z=2f#A>?rxT-xCBIiEB!;ZY?0Wqd z6%AZLyCUpYog*eB^h8*|li{4xTU%+>Z9Hwom}uf0rQ^gf2B5hWG#(}&^b7-=2mD9E z2s~H-cKWd1!BYs~zD_tZfRhA#c!0^1><|+Yfq)g`vIqDPXABgkLU1d1gdQaiI4D7b;;}mT1Q|_1 zk{(zXSP3JI55N;5?f#(KT|q6D_T1NY>ZGj;Y6EP$NBO;t&b5WgT?|{4-|_32Z!o*s zdLvr%5$s2h$^&cA1;*&YH$hCa>e25#dRjK~m@4|taQ+PK6nK_GyeHFe?yGe3O_jUK z%MPPHk9KuS?KK0~j`K(L@%WQPW8~*Kgrz)Y_{iL&i}*t?5=W&B@OI~z=CWQjARq^$ zt~3@qhls+?My*{IJ{Gh7#{9b8?Qici-ucYg>#RYcjQQ0AV%{HvWe+*TJAOAToI0)H zigFd#f6DoD)hpGWB=B6|o}_jmc#1>}-lio$^C?zBokSawV}PPPv9Ry`!x6F(4jn`2=t2w@6zMSxWa{!{ zMDn7-2^VeeraOI*GR!DqQr7UYi(7E%XqFLkJL=N$(6H3JkZ3B&Lr%}Lv1FZe`}p=% zFNL&gY9Jw7XG;Npf1N_S* zh%HFpcN*jaypmuO8C`m3^4E0H+sW(c0w@D?8|%?q0{5ejRvKRe;S4;J1K(zkp-hWx zUC(nnI=zUf6&aI8y8OM$QVBjE1paduI6vx<3?-jgkY{N(} zq@A6_sudi{6r~jU90Z25ju6Ty`AI4_LtI}4)P~axAJi#*7uG~oSTQ_#oKIz_M2Pt& zjXpUsaJSu-qnX#E%! z=#PD|9qtX~bd279Dw2s|?YvTNmgN)sI*oC%Or`E*$oIGXPd=m@DE^8R_v%lQ{PyL# z&v~VI9uIyYW$g)tS0mYu=jT72GSPexy?7}*Y^oreko`AVPwfNgodz+@f(~AyBiyM{ zjbJ9tDa8gwCPtaSqP=U@!=*Pa7oOdE$Rk{kTftx-_s)dFtKU}9g#4n`#XIYo>wOq? 
zhPCfcTU(ZM3m_Ub#isJiR|ZFlex7afoVzRZFm?9TaNm;D+T>fvhTQy>|OCideyGCyc<+W-OE@_x-c=ZhANEFa{ zTn_pm^aom8W#s5?u)S_}w~G?<=xc8Mr3dzdlnZ@gyj0OHGGD7V{jPf~-2+Qmk9?ytutBrc(r&oLJj>Z5p~&2|?r`q~#_I%y z6KnDUhHiq1-#c_Aym(nUWVi#gj9JRfB?EXoI=l+q1XCoPA4(mejrOY4)L<+Y1{8}W z?>HPF$Ow6_N0|3&R-2nTV-IPY3ncJn3Eb?E0E7W#^?+UyF;Kj!j7kROEPO2%85jf( zmK02@0L9b{LT+3x9e{f#0rId4ByNGy*Gv6%JtF1tZ_j(gxOMkG>E-R8Dpr#Cxff>2 zOa9yJQqenG8k2jNPSyX5u(uAYqHEj6Hzf@s-O`P8mq>RvD&0t@NH@|T-JOz>N_V#i zNJ)1J0{6Ez-p~6z@Av!TcO08Fd-lw%J;!FQYh8Jsg)-!o6usZG%uW_J``0;9c{w{N zH>#;@{(-zwCmk)b^(w{99W(mf$FYJw>o{wyb(iU^PLo^tBJ2!-&z5~tn`T^Fb zycsT<*VOs0OrJZvbhSD8XP+BF3gh5cO4*|!2T>ud#OyX_<24UMvTs+BT)&Q4X1s?G zTn)+IP`8r#o+{4m9m{*GZ}&AGCr56>Z5>A^B}qlM2@xs33XnPCY-ZW+yoj1(KKwX0 zFw$$vY?A-Ao;Um;ogie@wJ=7X%Q>BEVG>~7jrdd;XFt3vU8l4qUc z8T-%cj)#DT?eMW(nS-z7)_AYU1tRx&5fpZb`abtfxlm`svI%O%rAcAuGq?yoISmQS zzPuLwjW+vwiq=;DQzipojj3e!LN~Nd*|#TYehXv z+plnQ8`72^>yNU%xDqf#VG;0fM-BzGzfLru$0HlXVGv-2yy;4_gOkl_U;Fm7yJ~SY z37?*$!^Lrram-V^?>kNq)aiu-E-}aufF&NlIXp&Jp>~KLni>JHlnm5RpU{&A+QEJP z>zzFMAdwMWcEEf}$UUg-XBwuk!lra+K(Wt77VsS)-Scdwm=me<*u?OqXgduQ9!vNXwD>aMmm1jYc zL?R?csW%_c(JC2|U}lmqq2yqkOqpT`|JchRQS=;mV=yu%ouy_J7ZE68Y_k;XAtiDD zlOq8!GdO`33>9$Jzh(#6z<>@rV8;Ou1BmM1`jF>9wwj=x`{VHFdV|4_Lk8j`dniU` zv+heK&ocbo%D(HV+fhtN!cXB)pxnz6%qIi%Gb-Gxj{i*j`ExPw(7?-3@NVZTW!tnF zN-bx1b{f8I;KjvlL~HD7bl-SiGR=&^D(;&)a!@J)c#0@ z$M|j1CrYGOCr=%oH_h_I6str<0GB$u`nRfMNpZ_W*2>~?%$;}N=co;S>IZ3DZPOt( zeM+}Kf}HTuANf0$+@di5foL~SV*iqe+j(NTa#j~gs7Xk$adWXAiY5IlG0b#=%HY6_ zip0EmYcq9k%4M^P&C)0JyR?u!V)psnE@M2SMk$B#0e9bVy&a9hA%fo* zbXn5k&Ti%HrjpMUGLg(kVEM){9egZZS!&X{_Z0DBkTp%8-7P=S@bT<5-WxayAxhV- z;nlbj09*DKXQ>+P@sP$R? zy4>!3;90OW9Q1rH|C4M@lmjLIZW0}B(s!*;3E_%L0>3Vve;_uN1fp5K)rDM(G`;e< zMBHEJKMkr<3X})H*gn&CydLPhT)i&eG8Y-7m7B-Q32_`+jcdKwpNn?Wu8}CML`v1b zH7Dp6<-6NU{GM@gYB|M#8BR&2hU%k4b3ZfGf%gx@EcEonH&sKfJ!i{tj&=-b7bm^B zRI-w48(+*gZ)B#-CR5S1Lrrf2y={hk!xc+`1^dm})W9oSueb1BeFI}Q>MamwIciUp$i}4pzw2pabY~;* zB<=8VgL?6Sl!Fa`$2Lx>fG%WEwDp`v{}tGyu(7?p&v8(m2;zhggFPfOr|4w;|&;q|KBl zo(RC4)6q!=45EY@FY;3{<3f{H5}POP)A5Ji$rMoHHhh_OX*yt8h>RR>kIoeq#qM9VwC{^GRA!H4 z%V8hF|X*KYw9K+%yqLaqe`EB8A$z89F=_@k@(RrRV-6$SSZEwcb zuRf(a$BmH>;n0zI$15 z-KsACry1KTIDMcE^Tfl?%skvh1F*ImaP1T$i1M3AyskGXRG24 z;_HMN`af-}rtX*Jnpg?>Ob1f(@yO1f(GmH?SKvg=Eo1qCaCSv8S&D{QA`JfR{p*WRPZ1S!7-$#qm%8YK8r_!MwIYZ z9fYE{SOue0H38$B{C8vYGW=XxX}b*AM4tbKy>@8G|rJp!qEYFIXIn;Z2KR`?ClJv zMt^RXEXA0+Np>+enbem(F2x7$DI%5y<=S6Lj@jxDJ|$18KK_U-it2mCa}bLxDn~q# z=Ge?QWGd#||GHtEu5@2D;+_)navITVA-v{RFe;t^a-p$PGc5)wO z^+;MeKc-K+alm7h1D>qa|2$mKhe#V0xO4INb`%RL^%0*}>!IS9H1zqfzDhF~Rc-?s zhyUDa*`%XbZ=m(G5-gusN}CJz3}6|}9Q>S8#PqnTvg-hVO@A3&$nh4Dk(@dkM zT{RCP%e=ErzOmy^aKg&d(OF%QXfg+B%O+YX)BHmv{I_s-;}c#JEQ6YDtyTs+QYSD} z)o|HU{+DE*(Bkrr9>}+KIbCberp+*gHlk-(Iu9&l3~o2}(pUde(?M1%bqq~+OiNKn)=y0a4LZT?_l1wX z*GLPE8>=yXSap5u$wDF+r-IwpGs+yZ9pjM}ee(sg~nesdy>CxR`7k^ zN*^Peq~uph2%$Y6q9xff%pKD}@0+at8MEk%$f>|K2YKA>i3~IhpYo4nJ-QehvA)@& zg}Rj;0b(0j3&ISLx>C!q=i7)ls|O?HHcuRC6A+f?Hto_2blDf4iS}r>qZ?Per?F$K zE2@oi%Xg6JQ2F*k(CdTfR5K%;^J3SiQ;}S5WV2u(4>ICh?v88Mu@gT&qDWE&H3$k_ z(k0rTY~{uxO9eL*?yUOi{c!HDsu5)1!2+0G%ARQH|ZGum7YF&JxSmoPLTk? 
zBEc&t0CMz;TF?&R@i_ExBW?u#C2p~kEmtMGjyOYtQBxUhmvJf&nmVXAco_k*bM=p9 zK>(5v1waO|TY>$&SkM8VV+14%u#kXSXSSS|&@i=wUh=Sv+oCpOhgMR*QK5B}GQAzh zqU16d6k>HU8bqkl=Q5+rUNPhj|v1D1^PE1_n0H1W$SI^^qPjgAR z@WRqijbeT7AAT-m^7E^(*DSJs;?RnLS>gC_bKS(1>h1x*$V`f@w8x5C;n$aD)Q$YZ zZ=J%1TD<&STdojn&j@bn)@ z0jqEOm`Gv-ml1LOClXV(hoq>lR2~PC&#;<9VWga8>%_ecf0)9xWZYf=lUC)~!kp%4 z^Ac#ee~KtB{lR-H70c@rGlo&W*H+f{mse}7IfNvqyAh+AGVwo;P3{v^xn17Ly|2YH z!|06&aA;NXgv*E(fIFU#P?=z2_0b)9E#9x4ga}`jaUvS_@H+dWyX2qZG405!irxCB z)x}jifr{<5V;RxT(>q??Y3JRf4xqnV(_^}b;u zZ+I2r{1wLj8q-Yq1i>mgatcf^(EBdf7_fkiPp!+trQd zU30Y)LkZPJ*;nhRr8=<*#%8;!vL!TDs(ClbkI0GuTT|>Jw2#WbmR!w3%_atJJd; zJZ-hA4s%*P@G=K6cOlBlqL3o4_gn$_u_R5clx4h{mB}aUbq{mLSV3Q@h|YCa!Pfu7s!I3B+ru}gMRlH> z>Fo(eQJnxQo6I4HX{lmmtjx9_ce8D(KukNW(&A^ZDv z`+v^UoD$U*!nz6yEofzKW#7IsT&Y0}+|JOFtKEs4WG1OBH+{N;ccw6OrnGq>VzHk# zr4tf2sthTlGL@jXUjcQ{?_RUdw>zKzRG$gPMHpg_Gj%=3S@YtEM~VAMb$70GoRg-w z_sY{{Zrk!^n86g15o3QtD7?3ZnyOhjm=n7sQsd1^gZgk?xb zoQAx(aVjOS5OG+T5@SyqSd#Mu?G%ST11VLalAEQKrU%{=DD!aC5wAaDLhn`O0S`N9 zOn6-Q{#uV>P#k&$I217!irlWk#Ly`g1d3pGKx)HHL(wsl7z(F0X%)S!YCIGCQ#p}8 z?E7A#tP{cO>cr6j#hC-eo@L2^_0ZhhQWWY^M+OClVe>Y)00F- zy6NKtwlRM0r{P8v!P0xAip7~Z%Fkzq-BH`0>F%yd^^<$F><`#ul)h!MF;n-Tc}8*0 zy!&xiBefARGd*3PHKtG@_~G(fUvn8*73prtmyhp=`<`Qx1RYsC<3$m?EHB2pH0Mi; zV4+V-mfQ??cXD%3ChQnB6DA9_Lmd)GRwqR!*9#eIqGpbJIMn~e|H1D`F5L=T*9Vi~8_ zTF59r{oul;wKdDY-i3ZSpI!Y6Y<=AyOmaf$oxd1}XFvCoMss*nSt+R|1hT8B%gK`Y zmoda>qMsc!g$ZoTdWhz#8>Vt@+`=q?4cL_1ocF#t4*|0Kf*)ClS$UIkOe<`pJ~ zqDWHnfz>6B5*S~Wd765C&NXe%t#<|9EN-3`k~gb&Je-2M`SqgWuu}W5er)|Xr^?~Q z`>%EzA4IfztzrrfW=x|_INH+qry435mPs3)OLVoU^B05-luNOoQC2?fi&F@pQH52Z z=t^M0#HGo#0yk2WMlF{o0ienTO%yqFQi^0A*PX|V9O%df0O9~x3}BOWL($+-(3vnc z0;g(ZLvFQ5yIt{6L-v^Qak!aq9o?rs-vs)Am>mkqA9YfYypPjsI0@$cALBGja>y8sU@-SW9DO3QuE3AX?bQ^A84+^lxh* z6#4kCZ2c{z_}hp5JjWRHv_4gbW)yM7=;CqQ={{fPSe~-j=s!%IvfOW{xDC=SOx#Zs zta)DcpI59D>rL2o^yo`z49?RqI!y2_AmUsA9l?U38uK=McE3>(b zgvl`)(YWzSS7A^%WAWfP|Bu70OKsZCNJ7IsGq`suI9~VkjLehM*drFYG)XNR7-IKK z$XQuryn4xtW3>DpbK)ZVUuZlpOSjx6mM5b!7g>L>&UK6(SnMYN8M@qO;H*ky$z_D? zp&xVXfIBI1s^tl?KhU8HBK`~qm$U(2(tmbYWhJj(#iWl{B_5Xj{4JiE;{=K0?+=>e z_@!-ER1UgWyzC6uwLZT^hjS}&Dvr!&FQ_b^S;#!tEf}wK3z3gZUwvc>@8mTeVbl^S z=6mk+L3Q_L4*L@)Z-AR&Zc}2PQ{$9l74)mp1OkwYb(3`0wB_rxM z;(WXxV={e$-sJYPs>#9jDk)EuoRAM}z9v>hqf9dK-EPdasBRBYC)W^Cd<`WyXEE38 zZ?lUi-%IU(EwWF1a;9GP+g)v9@aEd{t#`|v3a5pI2fO!d#$juP zbxD|4jDF;nU@s~Op|zw|v1DLho#B``n4%ZAgh4I}GqDQeRCpy@L`wnA0chn=tW1E{ z|1bK?V|y$pM{)+RO^iy!(78m|gIzAKC)FNkgu15Pq%6f+&99METUScywM z2Q4c=*;4?6p6c62LW;kfhl(>ufEvI)2SvZ!IpZ%$fGZvpm1++P#Q|!eB>lhz|9&?B zHEf1*usr^7ttEiA1)O0x1DM<%p+$^5rZpc1Pb!xpewa6rNxVPfypCDryqT$!vpzWW zM*1*Y_KMv2NBRPWqU>ZV>z+1x)7vD(MsF>_x?)p{itzybt?kn?%U~w#yqYLMUwkTI zN12j-`Cu;EaP*col#Ng{DWB3(!XX62CBhJCvRQWyN5+ig!9B(}v^tbcOu2C4S0-9> zTxY|a^s6U6)&hwmHymb*zT2f`#!~U`qbuhbF?V%2CCgYFB@ZzB2tTv$M<>P!BOOV9 zYKkH$S^k+N+Er2JRl5@PGB&!$OA8nC4Vn$YP=cRh=hZ(D{tEq?WgLaF;d-*oD6kK= zSZlwnU_8mpSm>UxJGgr0IDc8e$Xhch$K=!#9wPceJ;Uy(eE0Xf@DHg<0Ahf=`GqQ0 z=bmr(!CTAG{+HYwF}zWn8!VD2mCMDwqr?vd%6i1`(Xlmb?E1~yGrsG4dbMlGt197K zvvM)N5UpI-Q?wtf&W1EO6@nAcG(jJ{kk%Eb2sLS1Yu33Se7H=;+qeD z-&-eEs4={jdKw^w;lqBb?M9}}e~LAcx9szfv=k#CQ&+I@M{e0$@Zb)Ue6Es2bG~Iv zx>NhG=y$L-^*eP!$VT>g+7E0D4_Z&aF|9;n=tRhMucficDC2-g;!1eOBxl8krP&~? z>^O0Wc}R*kwePm&(YJ9VFjwP6<|ApvDebD7VxMCKlzF!+6U08RgE$zEz&@|j(H%UK z_Jzi=t-CswDi;-cBxD5zLfrCP5)rd#i_{UbL1SqaFV_a#hpBs9ZX`W*JX%s(uQ>XY zh7)(0cI|Sv&{ALZ4DT^9OK9wJ4HYxVl1OATM+qb2uC88LURQhMwKi3&(Qysv_S#0f zNxiP5qV2;kHsR?5oPhEs?}ywd*S5e75{>YLPKu2J%AO?;J@B5<70C8HD#IMp@yb{wFo*5ltAr8s${6cIfD)>8T3f{e;olI zkamycbvgi&#ws#9urqp{*VUBuKkq#2b>4S@!Ju<%5={2>LWh+u{TMpKpZ~voz)UB! 
zsD&E-cwzorX237S%Z6y1kXvSabTG}M@Pj$tL4WYSTZf-u;;-P};RBZB|L3$;sIK_m zr~l`&$1U`@1Ah*yr$x``n?r#DJr{d4x5#9;2QM<^O*SUFa)HuviS2rihjAn4_e-1) zdsH#wcYib#GjZ~L6#f+dq=$=AlWDH8W1*gTacFDw4@6felIUJ%Lv*f`P&Sh;lEt`o z{-`oXrAAy62U9FkoyJ7Zsf7))RnUruXBLen^LtQqDr-ym6^|4;6PpmxB;h61S*MWm zZ}nyj>XNGbu%$nTT&a!|pCcV(k`d|n$Yw&m6ZcaZtabd%>_?;w{Y<59iQE*|g$Ad~ zDZ#FrD4xBP0FO}s@(R_mIFnBI7>650OWx6-SG7`0lFnSdfDNCl?U@ZqXmzL?FP0}u z^3Yij8KUsG^>*CO^b^~`%*tpEN6-N9QL>)fc@|$2^HU+Yl-Z9dxF45gV|3Tv??`NMz-5fCFs-bUuF~} z%OUT}v#U9$Fp=WKdzWa}PZxEx5fRQkY|85DB>iP29rj~!8QBKOTMa33C-JhQ*cU@1 z-uCc+VywT)9vl1A3$P^<%Jexb9ojdIS2xIc?Pv>mM}<6%Pzix|9yjpSYnD=6rugH* zT1b4s7*#Cfgmh$6j8TwzfXg_YXy^Pqu(il$N==JziP*G}x+ot8E;vbfI!;eA!z0Xc zT(rr=sM-K$wbSMI9m)Lrr1kepZWbim*tiJ>7eNZ`qick2zLIpVe zpkl+IQ_u^B@@PzT;{?C~8mZtBUzjg3h`Z18(ReU!KPzuyui#Q(9Q8Agg7kIFc@@dC z1~=!84^-aV$0U@zC?ArzKRm%&I=|U^qKj$Ve={E*9GoCLcZd@aL*s=d&Acf-rBzBG z3;05(IBbpLE2((;2pJYa`ckbc6noBnb{Eotd>J_-W=hllJEThnrYQ47BC-LgX|206bpV4{$@4?4C zr%<*mSCYJcPXkn<|G7d0y!&_N$LYZkbY23!Hxt|n^eehNXTZh(&yeMchw=9BX~)o8 z764!JU!Q@GVtlMji2tu104l-A;in8M@$SO^Rvv)Ej~6H#MW9pw)iqEq75Ml*g8!vO z5PE-#8SqY_Mt)5(PyuK0aWf(G%Q}-*)_>uAN)?V;BjRVCFo};aa^OL8H~r>~q9d#O zIK}&HXB~U@`niyXt6>|NN|Wamj!}OS2Jb!Ym>8_bQdo6T6g*Az2pZkKq1TVs+niZr zG=s(}^&cblh?~4bQ&=-wm4Z^;ZSDlhI7*J!Q@j9I>vw;WJj9i)x=XHcBPs5Lqer^?D$HsNUq_1zD7 zO#6QzV<~Y$ay&fxj$<5{#h`A5aNZAH$MqNDC9Z-a9V8XY0YmU!P)n-ZXd>*aiWg@E zY5cN)x5i7nGB%WhF-sKfJnmLARZ=HX;Vm7B+jcY)QD}9D`RyupMqo6i*oUvP{y!&% zzs*#!ynvNZ`REr_Jukdpltsq*@%u01nppgy7`@_SZ<4RSVZDZs)dWSNeMb2i>Td0xBGa!Zhy4ZV55%p)8&;f~bp& z%AKKu^*;4_-ENPx7j(UPc^KJNZsg+5_crq}UGIGqX~**AQ(cu2hHW=ol2Q0=csGyygkoiY)WKMgD%hA^xc8nclU`s&H<>&Bke(bYX4O zBW2v3p>OrYN0*$%m@Cu{faBCk?bpZXy(A>v{X&VY#DUaui6}El%2`>$W~@);HR&d( z(=wqHg?kmRV!JL5InBGmt$2dW>f{%0_Pep3o`qQ@Z( z`tEUn-Vl^LbPMx;57;AI&>I9p&^bEvCjK)R33Wqn_dO_Vy^ zVN(`WLom}eDMz(889iK!md&GisnH3a;j(?JEqvf5$IG~)U>>`jbyF%iY=eu2_&ijN zGa;HuKI1HW{{dn8Yv-B=NtCQhpW7`N{snoZ%S&A{SF6|3z&vNK(jUD4qk0o!q^l*! 
zsRhjP|NRw&ihv$iC=%eWuqa_`GhPEj2|A9kFd|`V9exp+t`KASi_xa#w_?$a>+vk* z#TgyA#lIXy7OQYnWS2SZeM}D2Klk#fPI9cT{mz=wF+0bK$rVQns5L%H>F#SDs|s%n zZT73T{$xjMFNzP}68Qa7aee&Oll}qXyG2ZK6sc4gew>+7%h(0`HpPR(X^!Ju=gu>| zPbYGSbcY?om08`Tak0Cjq&)gbl_64X$4 z0eHqy(2HU;uYYQciH<`XEyTkZf, + + /// The width in pixels of the generated image. + #[arg(long)] + width: Option, + + #[arg(long)] + decode_only: Option, +} + +fn run(args: Args) -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let Args { + prompt, + cpu, + height, + width, + tracing, + decode_only, + } = args; + let width = width.unwrap_or(1360); + let height = height.unwrap_or(768); + + let _guard = if tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let api = hf_hub::api::sync::Api::new()?; + let bf_repo = api.repo(hf_hub::Repo::model( + "black-forest-labs/FLUX.1-schnell".to_string(), + )); + let device = candle_examples::device(cpu)?; + let dtype = device.bf16_default_to_f32(); + let img = match decode_only { + None => { + let t5_emb = { + let repo = api.repo(hf_hub::Repo::with_revision( + "google/t5-v1_1-xxl".to_string(), + hf_hub::RepoType::Model, + "refs/pr/2".to_string(), + )); + let model_file = repo.get("model.safetensors")?; + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; + let config_filename = repo.get("config.json")?; + let config = std::fs::read_to_string(config_filename)?; + let config: t5::Config = serde_json::from_str(&config)?; + let mut model = t5::T5EncoderModel::load(vb, &config)?; + let tokenizer_filename = api + .model("lmz/mt5-tokenizers".to_string()) + .get("t5-v1_1-xxl.tokenizer.json")?; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let mut tokens = tokenizer + .encode(prompt.as_str(), true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + tokens.resize(256, 0); + let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; + println!("{input_token_ids}"); + model.forward(&input_token_ids)? + }; + println!("T5\n{t5_emb}"); + let clip_emb = { + let repo = api.repo(hf_hub::Repo::model( + "openai/clip-vit-large-patch14".to_string(), + )); + let model_file = repo.get("model.safetensors")?; + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; + // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json + let config = clip::text_model::ClipTextConfig { + vocab_size: 49408, + projection_dim: 768, + activation: clip::text_model::Activation::QuickGelu, + intermediate_size: 3072, + embed_dim: 768, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 12, + num_attention_heads: 12, + }; + let model = + clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config)?; + let tokenizer_filename = repo.get("tokenizer.json")?; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let tokens = tokenizer + .encode(prompt.as_str(), true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; + println!("{input_token_ids}"); + model.forward(&input_token_ids)? + }; + println!("CLIP\n{clip_emb}"); + let img = { + let model_file = bf_repo.get("flux1-schnell.sft")?; + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? 
}; +
+ let cfg = flux::model::Config::schnell(); + let model = flux::model::Flux::new(&cfg, vb)?; + + let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?; + let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?; + println!("{state:?}"); + let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell + println!("{timesteps:?}"); + flux::sampling::denoise( + &model, + &state.img, + &state.img_ids, + &state.txt, + &state.txt_ids, + &state.vec, + &timesteps, + 4., + )? + }; + flux::sampling::unpack(&img, height, width)? + } + Some(file) => { + let mut st = candle::safetensors::load(file, &device)?; + st.remove("img").unwrap().to_dtype(dtype)? + } + }; + println!("latent img\n{img}"); +
+ let img = { + let model_file = bf_repo.get("ae.sft")?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; + let cfg = flux::autoencoder::Config::schnell(); + let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?; + model.decode(&img)? + }; + println!("img\n{img}"); + let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?; + candle_examples::save_image(&img.i(0)?, "out.jpg")?; + Ok(()) +} + +fn main() -> Result<()> { + let args = Args::parse(); + run(args) +} diff --git a/candle-transformers/src/models/flux/autoencoder.rs b/candle-transformers/src/models/flux/autoencoder.rs new file mode 100644 index 0000000000..8c2aebbdc4 --- /dev/null +++ b/candle-transformers/src/models/flux/autoencoder.rs @@ -0,0 +1,440 @@ +use candle::{Result, Tensor, D}; +use candle_nn::{conv2d, group_norm, Conv2d, GroupNorm, VarBuilder}; + +// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/modules/autoencoder.py#L9 +#[derive(Debug, Clone)] +pub struct Config { + pub resolution: usize, + pub in_channels: usize, + pub ch: usize, + pub out_ch: usize, + pub ch_mult: Vec<usize>, + pub num_res_blocks: usize, + pub z_channels: usize, + pub scale_factor: f64, + pub shift_factor: f64, +} + +impl Config { + // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L47 + pub fn dev() -> Self { + Self { + resolution: 256, + in_channels: 3, + ch: 128, + out_ch: 3, + ch_mult: vec![1, 2, 4, 4], + num_res_blocks: 2, + z_channels: 16, + scale_factor: 0.3611, + shift_factor: 0.1159, + } + } + + // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L79 + pub fn schnell() -> Self { + Self { + resolution: 256, + in_channels: 3, + ch: 128, + out_ch: 3, + ch_mult: vec![1, 2, 4, 4], + num_res_blocks: 2, + z_channels: 16, + scale_factor: 0.3611, + shift_factor: 0.1159, + } + } +} + +fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> { + let dim = q.dim(D::Minus1)?; + let scale_factor = 1.0 / (dim as f64).sqrt(); + let attn_weights = (q.matmul(&k.t()?)?
* scale_factor)?; + candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v) +} + +#[derive(Debug, Clone)] +struct AttnBlock { + q: Conv2d, + k: Conv2d, + v: Conv2d, + proj_out: Conv2d, + norm: GroupNorm, +} + +impl AttnBlock { + fn new(in_c: usize, vb: VarBuilder) -> Result { + let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?; + let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?; + let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?; + let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?; + let norm = group_norm(32, in_c, 1e-6, vb.pp("norm"))?; + Ok(Self { + q, + k, + v, + proj_out, + norm, + }) + } +} + +impl candle::Module for AttnBlock { + fn forward(&self, xs: &Tensor) -> Result { + let init_xs = xs; + let xs = xs.apply(&self.norm)?; + let q = xs.apply(&self.q)?; + let k = xs.apply(&self.k)?; + let v = xs.apply(&self.v)?; + let (b, c, h, w) = q.dims4()?; + let q = q.flatten_from(2)?.t()?.unsqueeze(1)?; + let k = k.flatten_from(2)?.t()?.unsqueeze(1)?; + let v = v.flatten_from(2)?.t()?.unsqueeze(1)?; + let xs = scaled_dot_product_attention(&q, &k, &v)?; + let xs = xs.squeeze(1)?.t()?.reshape((b, c, h, w))?; + xs.apply(&self.proj_out)? + init_xs + } +} + +#[derive(Debug, Clone)] +struct ResnetBlock { + norm1: GroupNorm, + conv1: Conv2d, + norm2: GroupNorm, + conv2: Conv2d, + nin_shortcut: Option, +} + +impl ResnetBlock { + fn new(in_c: usize, out_c: usize, vb: VarBuilder) -> Result { + let conv_cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let norm1 = group_norm(32, in_c, 1e-6, vb.pp("norm1"))?; + let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?; + let norm2 = group_norm(32, out_c, 1e-6, vb.pp("norm2"))?; + let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?; + let nin_shortcut = if in_c == out_c { + None + } else { + Some(conv2d( + in_c, + out_c, + 1, + Default::default(), + vb.pp("nin_shortcut"), + )?) + }; + Ok(Self { + norm1, + conv1, + norm2, + conv2, + nin_shortcut, + }) + } +} + +impl candle::Module for ResnetBlock { + fn forward(&self, xs: &Tensor) -> Result { + let h = xs + .apply(&self.norm1)? + .apply(&candle_nn::Activation::Swish)? + .apply(&self.conv1)? + .apply(&self.norm2)? + .apply(&candle_nn::Activation::Swish)? + .apply(&self.conv2)?; + match self.nin_shortcut.as_ref() { + None => xs + h, + Some(c) => xs.apply(c)? 
+ h, + } + } +} + +#[derive(Debug, Clone)] +struct Downsample { + conv: Conv2d, +} + +impl Downsample { + fn new(in_c: usize, vb: VarBuilder) -> Result { + let conv_cfg = candle_nn::Conv2dConfig { + stride: 2, + ..Default::default() + }; + let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?; + Ok(Self { conv }) + } +} + +impl candle::Module for Downsample { + fn forward(&self, xs: &Tensor) -> Result { + let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?; + let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?; + xs.apply(&self.conv) + } +} + +#[derive(Debug, Clone)] +struct Upsample { + conv: Conv2d, +} + +impl Upsample { + fn new(in_c: usize, vb: VarBuilder) -> Result { + let conv_cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?; + Ok(Self { conv }) + } +} + +impl candle::Module for Upsample { + fn forward(&self, xs: &Tensor) -> Result { + let (_, _, h, w) = xs.dims4()?; + xs.upsample_nearest2d(h * 2, w * 2)?.apply(&self.conv) + } +} + +#[derive(Debug, Clone)] +struct DownBlock { + block: Vec, + downsample: Option, +} + +#[derive(Debug, Clone)] +pub struct Encoder { + conv_in: Conv2d, + mid_block_1: ResnetBlock, + mid_attn_1: AttnBlock, + mid_block_2: ResnetBlock, + norm_out: GroupNorm, + conv_out: Conv2d, + down: Vec, +} + +impl Encoder { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let conv_cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let mut block_in = cfg.ch; + let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?; + + let mut down = Vec::with_capacity(cfg.ch_mult.len()); + let vb_d = vb.pp("down"); + for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate() { + let mut block = Vec::with_capacity(cfg.num_res_blocks); + let vb_d = vb_d.pp(i_level); + let vb_b = vb_d.pp("block"); + let in_ch_mult = if i_level == 0 { + 1 + } else { + cfg.ch_mult[i_level - 1] + }; + block_in = cfg.ch * in_ch_mult; + let block_out = cfg.ch * ch_mult; + for i_block in 0..cfg.num_res_blocks { + let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?; + block.push(b); + block_in = block_out; + } + let downsample = if i_level != cfg.ch_mult.len() - 1 { + Some(Downsample::new(block_in, vb_d.pp("downsample"))?) + } else { + None + }; + let block = DownBlock { block, downsample }; + down.push(block) + } + + let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?; + let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?; + let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?; + let conv_out = conv2d(block_in, 2 * cfg.z_channels, 3, conv_cfg, vb.pp("conv_out"))?; + let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?; + Ok(Self { + conv_in, + mid_block_1, + mid_attn_1, + mid_block_2, + norm_out, + conv_out, + down, + }) + } +} + +impl candle_nn::Module for Encoder { + fn forward(&self, xs: &Tensor) -> Result { + let mut h = xs.apply(&self.conv_in)?; + for block in self.down.iter() { + for b in block.block.iter() { + h = h.apply(b)? + } + if let Some(ds) = block.downsample.as_ref() { + h = h.apply(ds)? + } + } + h.apply(&self.mid_block_1)? + .apply(&self.mid_attn_1)? + .apply(&self.mid_block_2)? + .apply(&self.norm_out)? + .apply(&candle_nn::Activation::Swish)? 
+ .apply(&self.conv_out) + } +} + +#[derive(Debug, Clone)] +struct UpBlock { + block: Vec, + upsample: Option, +} + +#[derive(Debug, Clone)] +pub struct Decoder { + conv_in: Conv2d, + mid_block_1: ResnetBlock, + mid_attn_1: AttnBlock, + mid_block_2: ResnetBlock, + norm_out: GroupNorm, + conv_out: Conv2d, + up: Vec, +} + +impl Decoder { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let conv_cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let mut block_in = cfg.ch * cfg.ch_mult.last().unwrap_or(&1); + let conv_in = conv2d(cfg.z_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?; + let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?; + let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?; + let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?; + + let mut up = Vec::with_capacity(cfg.ch_mult.len()); + let vb_u = vb.pp("up"); + for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate().rev() { + let block_out = cfg.ch * ch_mult; + let vb_u = vb_u.pp(i_level); + let vb_b = vb_u.pp("block"); + let mut block = Vec::with_capacity(cfg.num_res_blocks + 1); + for i_block in 0..=cfg.num_res_blocks { + let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?; + block.push(b); + block_in = block_out; + } + let upsample = if i_level != 0 { + Some(Upsample::new(block_in, vb_u.pp("upsample"))?) + } else { + None + }; + let block = UpBlock { block, upsample }; + up.push(block) + } + up.reverse(); + + let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?; + let conv_out = conv2d(block_in, cfg.out_ch, 3, conv_cfg, vb.pp("conv_out"))?; + Ok(Self { + conv_in, + mid_block_1, + mid_attn_1, + mid_block_2, + norm_out, + conv_out, + up, + }) + } +} + +impl candle_nn::Module for Decoder { + fn forward(&self, xs: &Tensor) -> Result { + let h = xs.apply(&self.conv_in)?; + let mut h = h + .apply(&self.mid_block_1)? + .apply(&self.mid_attn_1)? + .apply(&self.mid_block_2)?; + for block in self.up.iter().rev() { + for b in block.block.iter() { + h = h.apply(b)? + } + if let Some(us) = block.upsample.as_ref() { + h = h.apply(us)? + } + } + h.apply(&self.norm_out)? + .apply(&candle_nn::Activation::Swish)? + .apply(&self.conv_out) + } +} + +#[derive(Debug, Clone)] +pub struct DiagonalGaussian { + sample: bool, + chunk_dim: usize, +} + +impl DiagonalGaussian { + pub fn new(sample: bool, chunk_dim: usize) -> Result { + Ok(Self { sample, chunk_dim }) + } +} + +impl candle_nn::Module for DiagonalGaussian { + fn forward(&self, xs: &Tensor) -> Result { + let chunks = xs.chunk(2, self.chunk_dim)?; + if self.sample { + let std = (&chunks[1] * 0.5)?.exp()?; + &chunks[0] + (std * chunks[0].randn_like(0., 1.))? + } else { + Ok(chunks[0].clone()) + } + } +} + +#[derive(Debug, Clone)] +pub struct AutoEncoder { + encoder: Encoder, + decoder: Decoder, + reg: DiagonalGaussian, + shift_factor: f64, + scale_factor: f64, +} + +impl AutoEncoder { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let encoder = Encoder::new(cfg, vb.pp("encoder"))?; + let decoder = Decoder::new(cfg, vb.pp("decoder"))?; + let reg = DiagonalGaussian::new(true, 1)?; + Ok(Self { + encoder, + decoder, + reg, + scale_factor: cfg.scale_factor, + shift_factor: cfg.shift_factor, + }) + } + + pub fn encode(&self, xs: &Tensor) -> Result { + let z = xs.apply(&self.encoder)?.apply(&self.reg)?; + (z - self.shift_factor)? * self.scale_factor + } + pub fn decode(&self, xs: &Tensor) -> Result { + let xs = ((xs / self.scale_factor)? 
+ self.shift_factor)?; + xs.apply(&self.decoder) + } +} + +impl candle::Module for AutoEncoder { + fn forward(&self, xs: &Tensor) -> Result { + self.decode(&self.encode(xs)?) + } +} diff --git a/candle-transformers/src/models/flux/mod.rs b/candle-transformers/src/models/flux/mod.rs new file mode 100644 index 0000000000..763fa90da1 --- /dev/null +++ b/candle-transformers/src/models/flux/mod.rs @@ -0,0 +1,3 @@ +pub mod autoencoder; +pub mod model; +pub mod sampling; diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs new file mode 100644 index 0000000000..aa00077e66 --- /dev/null +++ b/candle-transformers/src/models/flux/model.rs @@ -0,0 +1,582 @@ +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder}; + +// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12 +#[derive(Debug, Clone)] +pub struct Config { + pub in_channels: usize, + pub vec_in_dim: usize, + pub context_in_dim: usize, + pub hidden_size: usize, + pub mlp_ratio: f64, + pub num_heads: usize, + pub depth: usize, + pub depth_single_blocks: usize, + pub axes_dim: Vec, + pub theta: usize, + pub qkv_bias: bool, + pub guidance_embed: bool, +} + +impl Config { + // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32 + pub fn dev() -> Self { + Self { + in_channels: 64, + vec_in_dim: 768, + context_in_dim: 4096, + hidden_size: 3072, + mlp_ratio: 4.0, + num_heads: 24, + depth: 19, + depth_single_blocks: 38, + axes_dim: vec![16, 56, 56], + theta: 10_000, + qkv_bias: true, + guidance_embed: true, + } + } + + // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L64 + pub fn schnell() -> Self { + Self { + in_channels: 64, + vec_in_dim: 768, + context_in_dim: 4096, + hidden_size: 3072, + mlp_ratio: 4.0, + num_heads: 24, + depth: 19, + depth_single_blocks: 38, + axes_dim: vec![16, 56, 56], + theta: 10_000, + qkv_bias: true, + guidance_embed: false, + } + } +} + +fn layer_norm(dim: usize, vb: VarBuilder) -> Result { + let ws = Tensor::ones(dim, vb.dtype(), vb.device())?; + Ok(LayerNorm::new_no_bias(ws, 1e-6)) +} + +fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result { + let dim = q.dim(D::Minus1)?; + let scale_factor = 1.0 / (dim as f64).sqrt(); + let mut batch_dims = q.dims().to_vec(); + batch_dims.pop(); + batch_dims.pop(); + let q = q.flatten_to(batch_dims.len() - 1)?; + let k = k.flatten_to(batch_dims.len() - 1)?; + let v = v.flatten_to(batch_dims.len() - 1)?; + let attn_weights = (q.matmul(&k.t()?)? 
* scale_factor)?; + let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?; + batch_dims.push(attn_scores.dim(D::Minus2)?); + batch_dims.push(attn_scores.dim(D::Minus1)?); + attn_scores.reshape(batch_dims) +} + +fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result { + if dim % 2 == 1 { + candle::bail!("dim {dim} is odd") + } + let dev = pos.device(); + let theta = theta as f64; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?; + let inv_freq = inv_freq.to_dtype(pos.dtype())?; + let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?; + let cos = freqs.cos()?; + let sin = freqs.sin()?; + let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?; + let (b, n, d, _ij) = out.dims4()?; + out.reshape((b, n, d, 2, 2)) +} + +fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result { + let dims = x.dims(); + let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; + let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?; + let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?; + (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec()) +} + +fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result { + let q = apply_rope(q, pe)?.contiguous()?; + let k = apply_rope(k, pe)?.contiguous()?; + let x = scaled_dot_product_attention(&q, &k, v)?; + x.transpose(1, 2)?.flatten_from(2) +} + +fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result { + const TIME_FACTOR: f64 = 1000.; + const MAX_PERIOD: f64 = 10000.; + if dim % 2 == 1 { + candle::bail!("{dim} is odd") + } + let dev = t.device(); + let half = dim / 2; + let t = (t * TIME_FACTOR)?; + let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?; + let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?; + let args = t + .unsqueeze(1)? + .to_dtype(candle::DType::F32)? 
+ .broadcast_mul(&freqs.unsqueeze(0)?)?; + let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?; + Ok(emb) +} + +#[derive(Debug, Clone)] +pub struct EmbedNd { + #[allow(unused)] + dim: usize, + theta: usize, + axes_dim: Vec, +} + +impl EmbedNd { + fn new(dim: usize, theta: usize, axes_dim: Vec) -> Self { + Self { + dim, + theta, + axes_dim, + } + } +} + +impl candle::Module for EmbedNd { + fn forward(&self, ids: &Tensor) -> Result { + let n_axes = ids.dim(D::Minus1)?; + let mut emb = Vec::with_capacity(n_axes); + for idx in 0..n_axes { + let r = rope( + &ids.get_on_dim(D::Minus1, idx)?, + self.axes_dim[idx], + self.theta, + )?; + emb.push(r) + } + let emb = Tensor::cat(&emb, 2)?; + emb.unsqueeze(1) + } +} + +#[derive(Debug, Clone)] +pub struct MlpEmbedder { + in_layer: Linear, + out_layer: Linear, +} + +impl MlpEmbedder { + fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result { + let in_layer = candle_nn::linear(in_sz, h_sz, vb.pp("in_layer"))?; + let out_layer = candle_nn::linear(h_sz, h_sz, vb.pp("out_layer"))?; + Ok(Self { + in_layer, + out_layer, + }) + } +} + +impl candle::Module for MlpEmbedder { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer) + } +} + +#[derive(Debug, Clone)] +pub struct QkNorm { + query_norm: RmsNorm, + key_norm: RmsNorm, +} + +impl QkNorm { + fn new(dim: usize, vb: VarBuilder) -> Result { + let query_norm = vb.get(dim, "query_norm.scale")?; + let query_norm = RmsNorm::new(query_norm, 1e-6); + let key_norm = vb.get(dim, "key_norm.scale")?; + let key_norm = RmsNorm::new(key_norm, 1e-6); + Ok(Self { + query_norm, + key_norm, + }) + } +} + +#[derive(Debug, Clone)] +pub struct Modulation { + lin: Linear, + multiplier: usize, +} + +impl Modulation { + fn new(dim: usize, double: bool, vb: VarBuilder) -> Result { + let multiplier = if double { 6 } else { 3 }; + let lin = candle_nn::linear(dim, multiplier * dim, vb.pp("lin"))?; + Ok(Self { lin, multiplier }) + } + + fn forward(&self, vec_: &Tensor) -> Result> { + vec_.silu()? + .apply(&self.lin)? + .unsqueeze(1)? 
+ .chunk(self.multiplier, D::Minus1) + } +} + +#[derive(Debug, Clone)] +pub struct SelfAttention { + qkv: Linear, + norm: QkNorm, + proj: Linear, + num_heads: usize, +} + +impl SelfAttention { + fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result { + let head_dim = dim / num_heads; + let qkv = candle_nn::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?; + let norm = QkNorm::new(head_dim, vb.pp("norm"))?; + let proj = candle_nn::linear(dim, dim, vb.pp("proj"))?; + Ok(Self { + qkv, + norm, + proj, + num_heads, + }) + } + + fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { + let qkv = xs.apply(&self.qkv)?; + let (b, l, _khd) = qkv.dims3()?; + let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?; + let q = qkv.i((.., .., 0))?.transpose(1, 2)?; + let k = qkv.i((.., .., 1))?.transpose(1, 2)?; + let v = qkv.i((.., .., 2))?.transpose(1, 2)?; + let q = q.apply(&self.norm.query_norm)?; + let k = k.apply(&self.norm.key_norm)?; + Ok((q, k, v)) + } + + #[allow(unused)] + fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result { + let (q, k, v) = self.qkv(xs)?; + attention(&q, &k, &v, pe)?.apply(&self.proj) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + lin1: Linear, + lin2: Linear, +} + +impl Mlp { + fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result { + let lin1 = candle_nn::linear(in_sz, mlp_sz, vb.pp("0"))?; + let lin2 = candle_nn::linear(mlp_sz, in_sz, vb.pp("2"))?; + Ok(Self { lin1, lin2 }) + } +} + +impl candle::Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2) + } +} + +#[derive(Debug, Clone)] +pub struct DoubleStreamBlock { + img_mod: Modulation, + img_norm1: LayerNorm, + img_attn: SelfAttention, + img_norm2: LayerNorm, + img_mlp: Mlp, + txt_mod: Modulation, + txt_norm1: LayerNorm, + txt_attn: SelfAttention, + txt_norm2: LayerNorm, + txt_mlp: Mlp, +} + +impl DoubleStreamBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let h_sz = cfg.hidden_size; + let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize; + let img_mod = Modulation::new(h_sz, true, vb.pp("img_mod"))?; + let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?; + let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?; + let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?; + let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?; + let txt_mod = Modulation::new(h_sz, true, vb.pp("txt_mod"))?; + let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?; + let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?; + let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?; + let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?; + Ok(Self { + img_mod, + img_norm1, + img_attn, + img_norm2, + img_mlp, + txt_mod, + txt_norm1, + txt_attn, + txt_norm2, + txt_mlp, + }) + } + + fn forward( + &self, + img: &Tensor, + txt: &Tensor, + vec_: &Tensor, + pe: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let img_mod = self.img_mod.forward(vec_)?; // shift, scale, gate + let txt_mod = self.txt_mod.forward(vec_)?; // shift, scale, gate + let img_modulated = img.apply(&self.img_norm1)?; + let img_modulated = img_modulated + .broadcast_mul(&(&img_mod[1] + 1.)?)? + .broadcast_add(&img_mod[0])?; + let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?; + + let txt_modulated = txt.apply(&self.txt_norm1)?; + let txt_modulated = txt_modulated + .broadcast_mul(&(&txt_mod[1] + 1.)?)? 
+ .broadcast_add(&txt_mod[0])?; + let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?; + + let q = Tensor::cat(&[txt_q, img_q], 2)?; + let k = Tensor::cat(&[txt_k, img_k], 2)?; + let v = Tensor::cat(&[txt_v, img_v], 2)?; + + let attn = attention(&q, &k, &v, pe)?; + let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?; + let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?; + + let img = (img + + img_attn + .apply(&self.img_attn.proj)? + .broadcast_mul(&img_mod[2]))?; + let img = (&img + + &img_mod[5].broadcast_mul( + &img.apply(&self.img_norm2)? + .broadcast_mul(&(&img_mod[4] + 1.0)?)? + .broadcast_add(&img_mod[3])? + .apply(&self.img_mlp)?, + )?)?; + + let txt = (txt + + txt_attn + .apply(&self.txt_attn.proj)? + .broadcast_mul(&txt_mod[2]))?; + let txt = (&txt + + &txt_mod[5].broadcast_mul( + &txt.apply(&self.txt_norm2)? + .broadcast_mul(&(&txt_mod[4] + 1.0)?)? + .broadcast_add(&txt_mod[3])? + .apply(&self.txt_mlp)?, + )?)?; + + Ok((img, txt)) + } +} + +#[derive(Debug, Clone)] +pub struct SingleStreamBlock { + linear1: Linear, + linear2: Linear, + norm: QkNorm, + pre_norm: LayerNorm, + modulation: Modulation, + h_sz: usize, + mlp_sz: usize, + num_heads: usize, +} + +impl SingleStreamBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let h_sz = cfg.hidden_size; + let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize; + let head_dim = h_sz / cfg.num_heads; + let linear1 = candle_nn::linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?; + let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?; + let norm = QkNorm::new(head_dim, vb.pp("norm"))?; + let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?; + let modulation = Modulation::new(h_sz, false, vb.pp("modulation"))?; + Ok(Self { + linear1, + linear2, + norm, + pre_norm, + modulation, + h_sz, + mlp_sz, + num_heads: cfg.num_heads, + }) + } + + fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result { + let mod_ = self.modulation.forward(vec_)?; + let (shift, scale, gate) = (&mod_[0], &mod_[1], &mod_[2]); + let x_mod = xs + .apply(&self.pre_norm)? + .broadcast_mul(&(scale + 1.0)?)? + .broadcast_add(shift)?; + let x_mod = x_mod.apply(&self.linear1)?; + let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?; + let (b, l, _khd) = qkv.dims3()?; + let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?; + let q = qkv.i((.., .., 0))?.transpose(1, 2)?; + let k = qkv.i((.., .., 1))?.transpose(1, 2)?; + let v = qkv.i((.., .., 2))?.transpose(1, 2)?; + let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?; + let q = q.apply(&self.norm.query_norm)?; + let k = k.apply(&self.norm.key_norm)?; + let attn = attention(&q, &k, &v, pe)?; + let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?; + xs + gate.broadcast_mul(&output) + } +} + +#[derive(Debug, Clone)] +pub struct LastLayer { + norm_final: LayerNorm, + linear: Linear, + ada_ln_modulation: Linear, +} + +impl LastLayer { + fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result { + let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?; + let linear = candle_nn::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?; + let ada_ln_modulation = candle_nn::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?; + Ok(Self { + norm_final, + linear, + ada_ln_modulation, + }) + } + + fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result { + let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?; + let (shift, scale) = (&chunks[0], &chunks[1]); + let xs = xs + .apply(&self.norm_final)? 
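+            // adaLN modulation: `scale` and `shift` are predicted from the conditioning vector.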
+ .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)? + .broadcast_add(&shift.unsqueeze(1)?)?; + xs.apply(&self.linear) + } +} + +#[derive(Debug, Clone)] +pub struct Flux { + img_in: Linear, + txt_in: Linear, + time_in: MlpEmbedder, + vector_in: MlpEmbedder, + guidance_in: Option, + pe_embedder: EmbedNd, + double_blocks: Vec, + single_blocks: Vec, + final_layer: LastLayer, +} + +impl Flux { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let img_in = candle_nn::linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?; + let txt_in = candle_nn::linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?; + let mut double_blocks = Vec::with_capacity(cfg.depth); + let vb_d = vb.pp("double_blocks"); + for idx in 0..cfg.depth { + let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?; + double_blocks.push(db) + } + let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks); + let vb_s = vb.pp("single_blocks"); + for idx in 0..cfg.depth_single_blocks { + let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?; + single_blocks.push(sb) + } + let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?; + let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?; + let guidance_in = if cfg.guidance_embed { + let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?; + Some(mlp) + } else { + None + }; + let final_layer = + LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?; + let pe_dim = cfg.hidden_size / cfg.num_heads; + let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec()); + Ok(Self { + img_in, + txt_in, + time_in, + vector_in, + guidance_in, + pe_embedder, + double_blocks, + single_blocks, + final_layer, + }) + } + + #[allow(clippy::too_many_arguments)] + pub fn forward( + &self, + img: &Tensor, + img_ids: &Tensor, + txt: &Tensor, + txt_ids: &Tensor, + timesteps: &Tensor, + y: &Tensor, + guidance: Option<&Tensor>, + ) -> Result { + if txt.rank() != 3 { + candle::bail!("unexpected shape for txt {:?}", txt.shape()) + } + if img.rank() != 3 { + candle::bail!("unexpected shape for img {:?}", img.shape()) + } + let dtype = img.dtype(); + let pe = { + let ids = Tensor::cat(&[txt_ids, img_ids], 1)?; + ids.apply(&self.pe_embedder)? + }; + let mut txt = txt.apply(&self.txt_in)?; + let mut img = img.apply(&self.img_in)?; + let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?; + let vec_ = match (self.guidance_in.as_ref(), guidance) { + (Some(g_in), Some(guidance)) => { + (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))? + } + _ => vec_, + }; + let vec_ = (vec_ + y.apply(&self.vector_in))?; + + // Double blocks + for block in self.double_blocks.iter() { + (img, txt) = block.forward(&img, &txt, &vec_, &pe)? 
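+            // Inside each double block the txt and img q/k/v are concatenated, so the two
+            // streams attend to each other while keeping separate projections and MLPs.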
+ } + // Single blocks + let mut img = Tensor::cat(&[&txt, &img], 1)?; + for block in self.single_blocks.iter() { + img = block.forward(&img, &vec_, &pe)?; + } + let img = img.i((.., txt.dim(1)?..))?; + self.final_layer.forward(&img, &vec_) + } +} diff --git a/candle-transformers/src/models/flux/sampling.rs b/candle-transformers/src/models/flux/sampling.rs new file mode 100644 index 0000000000..89b9a95382 --- /dev/null +++ b/candle-transformers/src/models/flux/sampling.rs @@ -0,0 +1,119 @@ +use candle::{Device, Result, Tensor}; + +pub fn get_noise( + num_samples: usize, + height: usize, + width: usize, + device: &Device, +) -> Result { + let height = (height + 15) / 16 * 2; + let width = (width + 15) / 16 * 2; + Tensor::randn(0f32, 1., (num_samples, 16, height, width), device) +} + +#[derive(Debug, Clone)] +pub struct State { + pub img: Tensor, + pub img_ids: Tensor, + pub txt: Tensor, + pub txt_ids: Tensor, + pub vec: Tensor, +} + +impl State { + pub fn new(t5_emb: &Tensor, clip_emb: &Tensor, img: &Tensor) -> Result { + let dtype = img.dtype(); + let (bs, c, h, w) = img.dims4()?; + let dev = img.device(); + let img = img.reshape((bs, c, h / 2, 2, w / 2, 2))?; // (b, c, h, ph, w, pw) + let img = img.permute((0, 2, 4, 1, 3, 5))?; // (b, h, w, c, ph, pw) + let img = img.reshape((bs, h / 2 * w / 2, c * 4))?; + let img_ids = Tensor::stack( + &[ + Tensor::full(0u32, (h / 2, w / 2), dev)?, + Tensor::arange(0u32, h as u32 / 2, dev)? + .reshape(((), 1))? + .broadcast_as((h / 2, w / 2))?, + Tensor::arange(0u32, w as u32 / 2, dev)? + .reshape((1, ()))? + .broadcast_as((h / 2, w / 2))?, + ], + 2, + )? + .to_dtype(dtype)?; + let img_ids = img_ids.reshape((1, h / 2 * w / 2, 3))?; + let img_ids = img_ids.repeat((bs, 1, 1))?; + let txt = t5_emb.repeat(bs)?; + let txt_ids = Tensor::zeros((bs, txt.dim(1)?, 3), dtype, dev)?; + let vec = clip_emb.repeat(bs)?; + Ok(Self { + img, + img_ids, + txt, + txt_ids, + vec, + }) + } +} + +fn time_shift(mu: f64, sigma: f64, t: f64) -> f64 { + let e = mu.exp(); + e / (e + (1. / t - 1.).powf(sigma)) +} + +/// `shift` is a triple `(image_seq_len, base_shift, max_shift)`. +pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec { + let timesteps: Vec = (0..=num_steps) + .map(|v| v as f64 / num_steps as f64) + .rev() + .collect(); + match shift { + None => timesteps, + Some((image_seq_len, y1, y2)) => { + let (x1, x2) = (256., 4096.); + let m = (y2 - y1) / (x2 - x1); + let b = y1 - m * x1; + let mu = m * image_seq_len as f64 + b; + timesteps + .into_iter() + .map(|v| time_shift(mu, 1., v)) + .collect() + } + } +} + +pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result { + let (b, _h_w, c_ph_pw) = xs.dims3()?; + let height = (height + 15) / 16; + let width = (width + 15) / 16; + xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw) + .permute((0, 3, 1, 4, 2, 5))? 
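+        // Undo the 2x2 patch packing performed in `State::new`, recovering (b, c, h, w).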
// (b, c, h, ph, w, pw) + .reshape((b, c_ph_pw / 4, height * 2, width * 2)) +} + +#[allow(clippy::too_many_arguments)] +pub fn denoise( + model: &super::model::Flux, + img: &Tensor, + img_ids: &Tensor, + txt: &Tensor, + txt_ids: &Tensor, + vec_: &Tensor, + timesteps: &[f64], + guidance: f64, +) -> Result { + let b_sz = img.dim(0)?; + let dev = img.device(); + let guidance = Tensor::full(guidance as f32, b_sz, dev)?; + let mut img = img.clone(); + for window in timesteps.windows(2) { + let (t_curr, t_prev) = match window { + [a, b] => (a, b), + _ => continue, + }; + let t_vec = Tensor::full(*t_curr as f32, b_sz, dev)?; + let pred = model.forward(&img, img_ids, txt, txt_ids, &t_vec, vec_, Some(&guidance))?; + img = (img + pred * (t_prev - t_curr))? + } + Ok(img) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 4c9c523ede..4daebec299 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -17,6 +17,7 @@ pub mod efficientvit; pub mod encodec; pub mod eva2; pub mod falcon; +pub mod flux; pub mod gemma; pub mod hiera; pub mod jina_bert; From 0a146d77eff0423ca67eb5de2f40c4eef033812d Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 4 Aug 2024 10:09:54 +0100 Subject: [PATCH 34/75] Simplify handling of flux modulations. (#2394) --- candle-transformers/src/models/flux/model.rs | 134 ++++++++++++------- 1 file changed, 88 insertions(+), 46 deletions(-) diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs index aa00077e66..4e47873fe0 100644 --- a/candle-transformers/src/models/flux/model.rs +++ b/candle-transformers/src/models/flux/model.rs @@ -212,24 +212,82 @@ impl QkNorm { } } +struct ModulationOut { + shift: Tensor, + scale: Tensor, + gate: Tensor, +} + +impl ModulationOut { + fn scale_shift(&self, xs: &Tensor) -> Result { + xs.broadcast_mul(&(&self.scale + 1.)?)? + .broadcast_add(&self.shift) + } + + fn gate(&self, xs: &Tensor) -> Result { + self.gate.broadcast_mul(xs) + } +} + #[derive(Debug, Clone)] -pub struct Modulation { +struct Modulation1 { lin: Linear, - multiplier: usize, } -impl Modulation { - fn new(dim: usize, double: bool, vb: VarBuilder) -> Result { - let multiplier = if double { 6 } else { 3 }; - let lin = candle_nn::linear(dim, multiplier * dim, vb.pp("lin"))?; - Ok(Self { lin, multiplier }) +impl Modulation1 { + fn new(dim: usize, vb: VarBuilder) -> Result { + let lin = candle_nn::linear(dim, 3 * dim, vb.pp("lin"))?; + Ok(Self { lin }) } - fn forward(&self, vec_: &Tensor) -> Result> { - vec_.silu()? + fn forward(&self, vec_: &Tensor) -> Result { + let ys = vec_ + .silu()? .apply(&self.lin)? .unsqueeze(1)? - .chunk(self.multiplier, D::Minus1) + .chunk(3, D::Minus1)?; + if ys.len() != 3 { + candle::bail!("unexpected len from chunk {ys:?}") + } + Ok(ModulationOut { + shift: ys[0].clone(), + scale: ys[1].clone(), + gate: ys[2].clone(), + }) + } +} + +#[derive(Debug, Clone)] +struct Modulation2 { + lin: Linear, +} + +impl Modulation2 { + fn new(dim: usize, vb: VarBuilder) -> Result { + let lin = candle_nn::linear(dim, 6 * dim, vb.pp("lin"))?; + Ok(Self { lin }) + } + + fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> { + let ys = vec_ + .silu()? + .apply(&self.lin)? + .unsqueeze(1)? 
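+            // Six chunks: (shift, scale, gate) for the attention path and for the MLP path.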
+ .chunk(6, D::Minus1)?; + if ys.len() != 6 { + candle::bail!("unexpected len from chunk {ys:?}") + } + let mod1 = ModulationOut { + shift: ys[0].clone(), + scale: ys[1].clone(), + gate: ys[2].clone(), + }; + let mod2 = ModulationOut { + shift: ys[3].clone(), + scale: ys[4].clone(), + gate: ys[5].clone(), + }; + Ok((mod1, mod2)) } } @@ -296,12 +354,12 @@ impl candle::Module for Mlp { #[derive(Debug, Clone)] pub struct DoubleStreamBlock { - img_mod: Modulation, + img_mod: Modulation2, img_norm1: LayerNorm, img_attn: SelfAttention, img_norm2: LayerNorm, img_mlp: Mlp, - txt_mod: Modulation, + txt_mod: Modulation2, txt_norm1: LayerNorm, txt_attn: SelfAttention, txt_norm2: LayerNorm, @@ -312,12 +370,12 @@ impl DoubleStreamBlock { fn new(cfg: &Config, vb: VarBuilder) -> Result { let h_sz = cfg.hidden_size; let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize; - let img_mod = Modulation::new(h_sz, true, vb.pp("img_mod"))?; + let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?; let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?; let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?; let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?; let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?; - let txt_mod = Modulation::new(h_sz, true, vb.pp("txt_mod"))?; + let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?; let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?; let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?; let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?; @@ -343,18 +401,14 @@ impl DoubleStreamBlock { vec_: &Tensor, pe: &Tensor, ) -> Result<(Tensor, Tensor)> { - let img_mod = self.img_mod.forward(vec_)?; // shift, scale, gate - let txt_mod = self.txt_mod.forward(vec_)?; // shift, scale, gate + let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate + let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate let img_modulated = img.apply(&self.img_norm1)?; - let img_modulated = img_modulated - .broadcast_mul(&(&img_mod[1] + 1.)?)? - .broadcast_add(&img_mod[0])?; + let img_modulated = img_mod1.scale_shift(&img_modulated)?; let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?; let txt_modulated = txt.apply(&self.txt_norm1)?; - let txt_modulated = txt_modulated - .broadcast_mul(&(&txt_mod[1] + 1.)?)? - .broadcast_add(&txt_mod[0])?; + let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?; let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?; let q = Tensor::cat(&[txt_q, img_q], 2)?; @@ -365,27 +419,19 @@ impl DoubleStreamBlock { let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?; let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?; - let img = (img - + img_attn - .apply(&self.img_attn.proj)? - .broadcast_mul(&img_mod[2]))?; + let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?; let img = (&img - + &img_mod[5].broadcast_mul( - &img.apply(&self.img_norm2)? - .broadcast_mul(&(&img_mod[4] + 1.0)?)? - .broadcast_add(&img_mod[3])? + + img_mod2.gate( + &img_mod2 + .scale_shift(&img.apply(&self.img_norm2)?)? .apply(&self.img_mlp)?, )?)?; - let txt = (txt - + txt_attn - .apply(&self.txt_attn.proj)? - .broadcast_mul(&txt_mod[2]))?; + let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?; let txt = (&txt - + &txt_mod[5].broadcast_mul( - &txt.apply(&self.txt_norm2)? - .broadcast_mul(&(&txt_mod[4] + 1.0)?)? - .broadcast_add(&txt_mod[3])? 
+ + txt_mod2.gate( + &txt_mod2 + .scale_shift(&txt.apply(&self.txt_norm2)?)? .apply(&self.txt_mlp)?, )?)?; @@ -399,7 +445,7 @@ pub struct SingleStreamBlock { linear2: Linear, norm: QkNorm, pre_norm: LayerNorm, - modulation: Modulation, + modulation: Modulation1, h_sz: usize, mlp_sz: usize, num_heads: usize, @@ -414,7 +460,7 @@ impl SingleStreamBlock { let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?; let norm = QkNorm::new(head_dim, vb.pp("norm"))?; let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?; - let modulation = Modulation::new(h_sz, false, vb.pp("modulation"))?; + let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?; Ok(Self { linear1, linear2, @@ -429,11 +475,7 @@ impl SingleStreamBlock { fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result { let mod_ = self.modulation.forward(vec_)?; - let (shift, scale, gate) = (&mod_[0], &mod_[1], &mod_[2]); - let x_mod = xs - .apply(&self.pre_norm)? - .broadcast_mul(&(scale + 1.0)?)? - .broadcast_add(shift)?; + let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?; let x_mod = x_mod.apply(&self.linear1)?; let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?; let (b, l, _khd) = qkv.dims3()?; @@ -446,7 +488,7 @@ impl SingleStreamBlock { let k = k.apply(&self.norm.key_norm)?; let attn = attention(&q, &k, &v, pe)?; let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?; - xs + gate.broadcast_mul(&output) + xs + mod_.gate(&output) } } From 0f55c37c37b6e447726e37e57ef3c75dabf658aa Mon Sep 17 00:00:00 2001 From: MilkFather <31627231+MilkFather@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:24:17 +0800 Subject: [PATCH 35/75] optimize gradient for silu a bit (#2393) --- candle-core/src/backprop.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 3ea03b0b40..a556677478 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -623,9 +623,9 @@ impl Tensor { } Op::Unary(arg, UnaryOp::Silu) => { let sum_grad = grads.or_insert(arg)?; - // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) + // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?; - let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?; + let silu_grad = &sigmoid_arg * (1. - *node) + *node; *sum_grad = sum_grad.add(&(&grad * silu_grad)?)? } Op::Elu(arg, alpha) => { From aef4ebac550ebdf66786fc0c4bb3c0d4ed5a8d9f Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 4 Aug 2024 11:16:24 +0100 Subject: [PATCH 36/75] Support the flux-dev model too. 
(#2395) --- candle-examples/examples/flux/main.rs | 46 +++++++++++++++++++++------ 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs index 826174bc69..a9278d013d 100644 --- a/candle-examples/examples/flux/main.rs +++ b/candle-examples/examples/flux/main.rs @@ -37,6 +37,15 @@ struct Args { #[arg(long)] decode_only: Option, + + #[arg(long, value_enum, default_value = "schnell")] + model: Model, +} + +#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)] +enum Model { + Schnell, + Dev, } fn run(args: Args) -> Result<()> { @@ -50,6 +59,7 @@ fn run(args: Args) -> Result<()> { width, tracing, decode_only, + model, } = args; let width = width.unwrap_or(1360); let height = height.unwrap_or(768); @@ -63,9 +73,13 @@ fn run(args: Args) -> Result<()> { }; let api = hf_hub::api::sync::Api::new()?; - let bf_repo = api.repo(hf_hub::Repo::model( - "black-forest-labs/FLUX.1-schnell".to_string(), - )); + let bf_repo = { + let name = match model { + Model::Dev => "black-forest-labs/FLUX.1-dev", + Model::Schnell => "black-forest-labs/FLUX.1-schnell", + }; + api.repo(hf_hub::Repo::model(name.to_string())) + }; let device = candle_examples::device(cpu)?; let dtype = device.bf16_default_to_f32(); let img = match decode_only { @@ -132,16 +146,27 @@ fn run(args: Args) -> Result<()> { }; println!("CLIP\n{clip_emb}"); let img = { - let model_file = bf_repo.get("flux1-schnell.sft")?; + let model_file = match model { + Model::Schnell => bf_repo.get("flux1-schnell.sft")?, + Model::Dev => bf_repo.get("flux1-dev.sft")?, + }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; - let cfg = flux::model::Config::schnell(); - let model = flux::model::Flux::new(&cfg, vb)?; - + let cfg = match model { + Model::Dev => flux::model::Config::dev(), + Model::Schnell => flux::model::Config::schnell(), + }; let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?; let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?; + let timesteps = match model { + Model::Dev => { + flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15))) + } + Model::Schnell => flux::sampling::get_schedule(4, None), + }; + let model = flux::model::Flux::new(&cfg, vb)?; + println!("{state:?}"); - let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell println!("{timesteps:?}"); flux::sampling::denoise( &model, @@ -166,7 +191,10 @@ fn run(args: Args) -> Result<()> { let img = { let model_file = bf_repo.get("ae.sft")?; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; - let cfg = flux::autoencoder::Config::schnell(); + let cfg = match model { + Model::Dev => flux::autoencoder::Config::dev(), + Model::Schnell => flux::autoencoder::Config::schnell(), + }; let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?; model.decode(&img)? }; From c301efa449e1cc386d56440ba8d6ca1ca72cd72e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 4 Aug 2024 18:52:40 +0100 Subject: [PATCH 37/75] Support for mistral-nemo. 
(#2396) --- candle-examples/examples/mistral/main.rs | 21 ++++++++++++++------- candle-transformers/src/models/mistral.rs | 17 ++++++++++++----- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs index 39cf61422d..66265488a0 100644 --- a/candle-examples/examples/mistral/main.rs +++ b/candle-examples/examples/mistral/main.rs @@ -149,6 +149,10 @@ enum Which { Mistral7bInstructV02, #[value(name = "7b-maths-v0.1")] Mathstral7bV01, + #[value(name = "nemo-2407")] + MistralNemo2407, + #[value(name = "nemo-instruct-2407")] + MistralNemoInstruct2407, } #[derive(Parser, Debug)] @@ -263,13 +267,16 @@ fn main() -> Result<()> { } "lmz/candle-mistral".to_string() } else { - match args.which { - Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(), - Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(), - Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(), - Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(), - Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1".to_string(), - } + let name = match args.which { + Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1", + Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2", + Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1", + Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2", + Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1", + Which::MistralNemo2407 => "mistralai/Mistral-Nemo-Base-2407", + Which::MistralNemoInstruct2407 => "mistralai/Mistral-Nemo-Instruct-2407", + }; + name.to_string() } } }; diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index f7b70e6cbf..0d590399e7 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -15,6 +15,7 @@ pub struct Config { pub intermediate_size: usize, pub num_hidden_layers: usize, pub num_attention_heads: usize, + pub head_dim: Option, pub num_key_value_heads: usize, pub hidden_act: Activation, pub max_position_embeddings: usize, @@ -34,6 +35,7 @@ impl Config { intermediate_size: 14336, num_hidden_layers: 32, num_attention_heads: 32, + head_dim: None, num_key_value_heads: 8, hidden_act: Activation::Silu, max_position_embeddings: 32768, @@ -53,6 +55,7 @@ impl Config { intermediate_size: 14336, num_hidden_layers: 32, num_attention_heads: 32, + head_dim: None, num_key_value_heads: 8, hidden_act: Activation::Silu, max_position_embeddings: 32768, @@ -71,6 +74,7 @@ impl Config { intermediate_size: 14336, num_hidden_layers: 32, num_attention_heads: 32, + head_dim: None, num_key_value_heads: 8, hidden_act: Activation::Silu, max_position_embeddings: 32768, @@ -80,6 +84,11 @@ impl Config { use_flash_attn, } } + + fn head_dim(&self) -> usize { + self.head_dim + .unwrap_or(self.hidden_size / self.num_attention_heads) + } } #[derive(Debug, Clone)] @@ -91,7 +100,7 @@ struct RotaryEmbedding { impl RotaryEmbedding { fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { let rope_theta = cfg.rope_theta as f32; - let dim = cfg.hidden_size / cfg.num_attention_heads; + let dim = cfg.head_dim(); let max_seq_len = cfg.max_position_embeddings; let inv_freq: Vec<_> = (0..dim) .step_by(2) @@ -167,7 +176,6 @@ struct Attention { num_kv_heads: usize, num_kv_groups: usize, head_dim: usize, - hidden_size: usize, rotary_emb: Arc, kv_cache: Option<(Tensor, Tensor)>, use_flash_attn: bool, @@ -179,7 +187,7 @@ impl Attention { let 
num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; let num_kv_groups = num_heads / num_kv_heads; - let head_dim = hidden_sz / num_heads; + let head_dim = cfg.head_dim(); let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; @@ -193,7 +201,6 @@ impl Attention { num_kv_heads, num_kv_groups, head_dim, - hidden_size: hidden_sz, rotary_emb, kv_cache: None, use_flash_attn: cfg.use_flash_attn, @@ -254,7 +261,7 @@ impl Attention { attn_output .transpose(1, 2)? - .reshape((b_sz, q_len, self.hidden_size))? + .reshape((b_sz, q_len, self.num_heads * self.head_dim))? .apply(&self.o_proj) } From fd0e93391c2c8d2f87494232a2c4ddafbc0386f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=94=90=E7=92=9C?= <113148619+donjuanplatinum@users.noreply.github.com> Date: Mon, 5 Aug 2024 11:48:09 -0400 Subject: [PATCH 38/75] add models support and example for THUDM/glm-4 (#2362) * add models support and example for THUDM/glm-4 * fix the ci report * fmt * fix * Update README.org * Update README.org * fmt * Update README.org * README.md add codegeex4 * README.md add glm4 * Typo. * change expect into ? --------- Co-authored-by: Laurent Mazare --- README.md | 2 + candle-examples/examples/glm4/README.org | 77 +++ candle-examples/examples/glm4/main.rs | 255 ++++++++++ candle-transformers/src/models/glm4.rs | 595 +++++++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + 5 files changed, 930 insertions(+) create mode 100644 candle-examples/examples/glm4/README.org create mode 100644 candle-examples/examples/glm4/main.rs create mode 100644 candle-transformers/src/models/glm4.rs diff --git a/README.md b/README.md index 765625bd70..adea9942ef 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,8 @@ We also provide a some command line based examples using state of the art models - [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes the SOLAR-10.7B variant. - [Falcon](./candle-examples/examples/falcon/): general LLM. +- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion,code interpreter,web search,fuction calling,repository-level +- [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM - [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind. - [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b Griffin based models from Google that mix attention with a RNN like state. diff --git a/candle-examples/examples/glm4/README.org b/candle-examples/examples/glm4/README.org new file mode 100644 index 0000000000..364f61e8eb --- /dev/null +++ b/candle-examples/examples/glm4/README.org @@ -0,0 +1,77 @@ +* GLM4 +GLM-4-9B is the open-source version of the latest generation of pre-trained models in the GLM-4 series launched by Zhipu AI. + +- [[https://github.com/THUDM/GLM4][Github]] +- [[https://huggingface.co/THUDM/glm-4-9b][huggingface]] + +** Running with ~cuda~ + +#+begin_src shell + cargo run --example glm4 --release --features cuda +#+end_src + +** Running with ~cpu~ +#+begin_src shell + cargo run --example glm4 --release -- --cpu +#+end_src + +** Output Example +#+begin_src shell +cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache . 
+ Finished release [optimized] target(s) in 0.24s + Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .` +avx: true, neon: false, simd128: false, f16c: true +temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64 +cache path . +retrieved the files in 6.88963ms +loaded the model in 6.113752297s +starting the inference loop +[欢迎使用GLM-4,请输入prompt] +请你告诉我什么是FFT +266 tokens generated (34.50 token/s) +Result: +。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换(DFT)的方法,它广泛应用于信号处理、图像处理和数据分析等领域。 + +具体来说,FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中,我们通常需要知道信号的频率成分,这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高,而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。 + +以下是使用 Python 中的 numpy 进行 FFT 的简单示例: + +```python +import numpy as np + +# 创建一个时域信号 +t = np.linspace(0, 1, num=100) +f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t) + +# 对该信号做FFT变换,并计算其幅值谱 +fft_result = np.fft.fftshift(np.abs(np.fft.fft(f))) + +``` + +在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。 +#+end_src + +This example will read prompt from stdin + +* Citation +#+begin_src + @misc{glm2024chatglm, + title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools}, + author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang}, + year={2024}, + eprint={2406.12793}, + archivePrefix={arXiv}, + primaryClass={id='cs.CL' full_name='Computation and Language' is_active=True alt_name='cmp-lg' in_archive='cs' is_general=False description='Covers natural language processing. Roughly includes material in ACM Subject Class I.2.7. Note that work on artificial languages (programming languages, logics, formal systems) that does not explicitly address natural-language issues broadly construed (natural-language processing, computational linguistics, speech, text retrieval, etc.) 
is not appropriate for this area.'} +} +#+end_src + +#+begin_src + @misc{wang2023cogvlm, + title={CogVLM: Visual Expert for Pretrained Language Models}, + author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang}, + year={2023}, + eprint={2311.03079}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +#+end_src diff --git a/candle-examples/examples/glm4/main.rs b/candle-examples/examples/glm4/main.rs new file mode 100644 index 0000000000..55a27f349e --- /dev/null +++ b/candle-examples/examples/glm4/main.rs @@ -0,0 +1,255 @@ +use candle_transformers::models::glm4::*; +use clap::Parser; + +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{Repo, RepoType}; +use tokenizers::Tokenizer; + +struct TextGeneration { + model: Model, + device: Device, + tokenizer: Tokenizer, + logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, + verbose_prompt: bool, + dtype: DType, +} + +impl TextGeneration { + #[allow(clippy::too_many_arguments)] + fn new( + model: Model, + tokenizer: Tokenizer, + seed: u64, + temp: Option, + top_p: Option, + repeat_penalty: f32, + repeat_last_n: usize, + verbose_prompt: bool, + device: &Device, + dtype: DType, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp, top_p); + Self { + model, + tokenizer, + logits_processor, + repeat_penalty, + repeat_last_n, + verbose_prompt, + device: device.clone(), + dtype, + } + } + + fn run(&mut self, sample_len: usize) -> anyhow::Result<()> { + use std::io::BufRead; + use std::io::BufReader; + use std::io::Write; + println!("starting the inference loop"); + println!("[欢迎使用GLM-4,请输入prompt]"); + let stdin = std::io::stdin(); + let reader = BufReader::new(stdin); + for line in reader.lines() { + let line = line.expect("Failed to read line"); + + let tokens = self.tokenizer.encode(line, true).expect("tokens error"); + if tokens.is_empty() { + panic!("Empty prompts are not supported in the chatglm model.") + } + if self.verbose_prompt { + for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) { + let token = token.replace('▁', " ").replace("<0x0A>", "\n"); + println!("{id:7} -> '{token}'"); + } + } + let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") { + Some(token) => *token, + None => panic!("cannot find the endoftext token"), + }; + let mut tokens = tokens.get_ids().to_vec(); + let mut generated_tokens = 0usize; + + std::io::stdout().flush().expect("output flush error"); + let start_gen = std::time::Instant::now(); + + let mut count = 0; + let mut result = vec![]; + for index in 0..sample_len { + count += 1; + let context_size = if index > 0 { 1 } else { tokens.len() }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = self.model.forward(&input)?; + let logits = logits.squeeze(0)?.to_dtype(self.dtype)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? 
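+                    // Only the last `repeat_last_n` generated tokens are penalized.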
+ }; + + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + if next_token == eos_token { + break; + } + let token = self + .tokenizer + .decode(&[next_token], true) + .expect("Token error"); + if self.verbose_prompt { + println!( + "[Count: {}] [Raw Token: {}] [Decode Token: {}]", + count, next_token, token + ); + } + result.push(token); + std::io::stdout().flush()?; + } + let dt = start_gen.elapsed(); + println!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + println!("Result:"); + for tokens in result { + print!("{tokens}"); + } + self.model.reset_kv_cache(); // clean the cache + } + Ok(()) + } +} +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(name = "cache", short, long, default_value = ".")] + cache_path: String, + + #[arg(long)] + cpu: bool, + + /// Display the token for the specified prompt. + #[arg(long)] + verbose_prompt: bool, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, short = 'n', default_value_t = 8192)] + sample_len: usize, + + #[arg(long)] + model_id: Option, + + #[arg(long)] + revision: Option, + + #[arg(long)] + weight_file: Option, + + #[arg(long)] + tokenizer: Option, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.2)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. 
+ #[arg(long, default_value_t = 64)] + repeat_last_n: usize, +} + +fn main() -> anyhow::Result<()> { + let args = Args::parse(); + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.6), + args.repeat_penalty, + args.repeat_last_n + ); + + let start = std::time::Instant::now(); + println!("cache path {}", args.cache_path); + let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into())) + .build() + .map_err(anyhow::Error::msg)?; + + let model_id = match args.model_id { + Some(model_id) => model_id.to_string(), + None => "THUDM/glm-4-9b".to_string(), + }; + let revision = match args.revision { + Some(rev) => rev.to_string(), + None => "main".to_string(), + }; + let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); + let tokenizer_filename = match args.tokenizer { + Some(file) => std::path::PathBuf::from(file), + None => api + .model("THUDM/codegeex4-all-9b".to_string()) + .get("tokenizer.json") + .map_err(anyhow::Error::msg)?, + }; + let filenames = match args.weight_file { + Some(weight_file) => vec![std::path::PathBuf::from(weight_file)], + None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + println!("retrieved the files in {:?}", start.elapsed()); + let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error"); + + let start = std::time::Instant::now(); + let config = Config::glm4(); + let device = candle_examples::device(args.cpu)?; + let dtype = if device.is_cuda() { + DType::BF16 + } else { + DType::F32 + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? 
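+    // The safetensors weights are memory-mapped rather than copied into RAM, hence `unsafe`.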
}; + let model = Model::new(&config, vb)?; + + println!("loaded the model in {:?}", start.elapsed()); + + let mut pipeline = TextGeneration::new( + model, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + args.verbose_prompt, + &device, + dtype, + ); + pipeline.run(args.sample_len)?; + Ok(()) +} diff --git a/candle-transformers/src/models/glm4.rs b/candle-transformers/src/models/glm4.rs new file mode 100644 index 0000000000..3b436eaa6d --- /dev/null +++ b/candle-transformers/src/models/glm4.rs @@ -0,0 +1,595 @@ +use crate::models::with_tracing::{linear_b as linear, Linear}; +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_nn::VarBuilder; + +#[derive(Debug, Clone)] +pub struct Config { + pub num_layers: usize, + pub padded_vocab_size: usize, + pub hidden_size: usize, + pub ffn_hidden_size: usize, + pub kv_channels: usize, + pub num_attention_heads: usize, + pub seq_length: usize, + pub layernorm_epsilon: f64, + pub rmsnorm: bool, + pub apply_residual_connection_post_layernorm: bool, + pub post_layer_norm: bool, + pub add_bias_linear: bool, + pub add_qkv_bias: bool, + pub bias_dropout_fusion: bool, + pub multi_query_attention: bool, + pub multi_query_group_num: usize, + pub apply_query_key_layer_scaling: bool, + pub attention_softmax_in_fp32: bool, + pub fp32_residual_connection: bool, +} + +impl Config { + pub fn glm4() -> Self { + Self { + num_layers: 40, + padded_vocab_size: 151552, + hidden_size: 4096, + ffn_hidden_size: 13696, + kv_channels: 128, + num_attention_heads: 32, + seq_length: 8192, + layernorm_epsilon: 1e-5, + rmsnorm: true, + apply_residual_connection_post_layernorm: false, + post_layer_norm: true, + add_bias_linear: false, + add_qkv_bias: true, + bias_dropout_fusion: true, + multi_query_attention: true, + multi_query_group_num: 2, + apply_query_key_layer_scaling: true, + attention_softmax_in_fp32: true, + fp32_residual_connection: false, + } + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + cache: Tensor, +} + +impl RotaryEmbedding { + fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result { + let rotary_dim = cfg.kv_channels; + let n_elem = rotary_dim / 2; + let inv_freq: Vec<_> = (0..n_elem) + .step_by(2) + .map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)? + .to_dtype(dtype)? + .reshape((cfg.seq_length, 1))?; + let freqs = t.matmul(&inv_freq)?; + let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?; + Ok(Self { cache }) + } + + fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result { + let (seqlen, _b, np, _hn) = xs.dims4()?; + let cache = self.cache.narrow(0, seqlen_offset, seqlen)?; + let rot_dim = cache.dim(D::Minus2)? * 2; + let (xs, xs_pass) = ( + xs.narrow(D::Minus1, 0, rot_dim)?, + xs.narrow(D::Minus1, rot_dim, rot_dim)?, + ); + let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?; + let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?; + let (xshaped0, xshaped1) = ( + xshaped.i((.., .., .., .., 0))?, + xshaped.i((.., .., .., .., 1))?, + ); + let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?); + let xs_out = Tensor::stack( + &[ + (xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?, + (xshaped1.broadcast_mul(&cache0)? 
+ xshaped0.broadcast_mul(&cache1)?)?, + ], + D::Minus1, + )?; + let xs_out = xs_out.flatten_from(3)?; + Tensor::cat(&[xs_out, xs_pass], D::Minus1) + } +} + +#[derive(Debug, Clone)] +struct CoreAttention { + coeff: Option, + norm_factor: f64, + dtype: DType, +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?; + Ok(m) +} + +impl CoreAttention { + fn new(layer_number: usize, cfg: &Config, dtype: DType) -> Result { + let norm_factor = (cfg.kv_channels as f64).sqrt(); + let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling { + let coeff = f64::max(1.0, layer_number as f64); + (norm_factor * coeff, Some(coeff)) + } else { + (norm_factor, None) + }; + Ok(Self { + coeff, + norm_factor, + dtype, + }) + } + + fn forward( + &self, + query_layer: &Tensor, + key_layer: &Tensor, + value_layer: &Tensor, + attention_mask: &Option, + ) -> Result { + let output_size = ( + query_layer.dim(1)?, // b + query_layer.dim(2)?, // np + query_layer.dim(0)?, // sq + key_layer.dim(0)?, // sk + ); + let query_layer = + query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?; + let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?; + let matmul_result = Tensor::matmul( + &query_layer.transpose(0, 1)?.contiguous()?, + &key_layer.transpose(0, 1)?.transpose(1, 2)?.contiguous()?, + )?; + let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?; + let matmul_result = match self.coeff { + None => matmul_result, + Some(coeff) => (matmul_result * coeff)?, + }; + let attention_scores = match attention_mask { + Some(mask) => masked_fill( + &matmul_result, + &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?, + f32::NEG_INFINITY, + self.dtype, + )?, + None => matmul_result, + }; + let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + + let output_size = ( + value_layer.dim(1)?, + value_layer.dim(2)?, + query_layer.dim(0)?, + value_layer.dim(3)?, + ); + let value_layer = + value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?; + let attention_probs = + attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?; + let context_layer = Tensor::matmul( + &attention_probs.contiguous()?, + &value_layer.transpose(0, 1)?.contiguous()?, + )?; + let context_layer = context_layer.reshape(output_size)?; + let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?; + context_layer.flatten_from(D::Minus2) + } +} + +#[derive(Debug, Clone)] +struct SelfAttention { + query_key_value: Linear, + core_attention: CoreAttention, + dense: Linear, + multi_query_attention: bool, + num_attention_heads_per_partition: usize, + num_multi_query_groups_per_partition: usize, + hidden_size_per_attention_head: usize, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl SelfAttention { + fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result { + let projection_size = cfg.kv_channels * cfg.num_attention_heads; + let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads; + let qkv_hidden_size = if cfg.multi_query_attention { + projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num + } else { + 3 * projection_size + }; + let query_key_value = linear( + cfg.hidden_size, + qkv_hidden_size, + cfg.add_bias_linear || 
cfg.add_qkv_bias, + vb.pp("query_key_value"), + )?; + let core_attention = CoreAttention::new(layer_number, cfg, vb.dtype())?; + let dense = linear( + cfg.hidden_size, + cfg.hidden_size, + cfg.add_bias_linear, + vb.pp("dense"), + )?; + Ok(Self { + query_key_value, + core_attention, + dense, + multi_query_attention: cfg.multi_query_attention, + num_attention_heads_per_partition: cfg.num_attention_heads, + num_multi_query_groups_per_partition: cfg.multi_query_group_num, + hidden_size_per_attention_head: cfg.kv_channels, + kv_cache: None, + }) + } + + fn reset_kv_cache(&mut self) { + self.kv_cache = None + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: &Option, + rotary_emb: &RotaryEmbedding, + ) -> Result { + let mixed_x_layer = xs.apply(&self.query_key_value)?; + if !self.multi_query_attention { + candle::bail!("only multi_query_attention=true is supported") + } + let hpa = self.hidden_size_per_attention_head; + let query_layer = + mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?; + let key_layer = mixed_x_layer.narrow( + D::Minus1, + self.num_attention_heads_per_partition * hpa, + self.num_multi_query_groups_per_partition * hpa, + )?; + let value_layer = mixed_x_layer.narrow( + D::Minus1, + self.num_attention_heads_per_partition * hpa + + self.num_multi_query_groups_per_partition * hpa, + self.num_multi_query_groups_per_partition * hpa, + )?; + let query_layer = query_layer.reshape(( + query_layer.dim(0)?, + query_layer.dim(1)?, + self.num_attention_heads_per_partition, + hpa, + ))?; + let key_layer = key_layer.reshape(( + key_layer.dim(0)?, + key_layer.dim(1)?, + self.num_multi_query_groups_per_partition, + hpa, + ))?; + let value_layer = value_layer.reshape(( + value_layer.dim(0)?, + value_layer.dim(1)?, + self.num_multi_query_groups_per_partition, + hpa, + ))?; + + // Rotary embeddings. + let seqlen_offset = match &self.kv_cache { + None => 0, + Some((prev_k, _)) => prev_k.dim(0)?, + }; + let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?; + let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?; + + // KV cache. + let (key_layer, value_layer) = match &self.kv_cache { + None => (key_layer, value_layer), + Some((prev_k, prev_v)) => { + let k = Tensor::cat(&[prev_k, &key_layer], 0)?; + let v = Tensor::cat(&[prev_v, &value_layer], 0)?; + (k, v) + } + }; + self.kv_cache = Some((key_layer.clone(), value_layer.clone())); + + // Repeat KV. + let ratio = + self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition; + let key_layer = { + let (d0, d1, d2, d3) = key_layer.dims4()?; + key_layer + .unsqueeze(D::Minus2)? + .expand((d0, d1, d2, ratio, d3))? + .reshape(( + d0, + d1, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ))? + }; + let value_layer = { + let (d0, d1, d2, d3) = value_layer.dims4()?; + value_layer + .unsqueeze(D::Minus2)? + .expand((d0, d1, d2, ratio, d3))? + .reshape(( + d0, + d1, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ))? 
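+                // Each of the `multi_query_group_num` KV heads is repeated `ratio` times so
+                // that key/value shapes line up with the `num_attention_heads` query heads.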
+ }; + + let context_layer = + self.core_attention + .forward(&query_layer, &key_layer, &value_layer, attention_mask)?; + let output = context_layer.apply(&self.dense)?; + Ok(output) + } +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone)] +struct MLP { + dense_h_to_4h: Linear, + dense_4h_to_h: Linear, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let dense_h_to_4h = linear( + cfg.hidden_size, + cfg.ffn_hidden_size * 2, + cfg.add_bias_linear, + vb.pp("dense_h_to_4h"), + )?; + let dense_4h_to_h = linear( + cfg.ffn_hidden_size, + cfg.hidden_size, + cfg.add_bias_linear, + vb.pp("dense_4h_to_h"), + )?; + Ok(Self { + dense_4h_to_h, + dense_h_to_4h, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.dense_h_to_4h)? + .apply(&candle_nn::Activation::Swiglu)? + .apply(&self.dense_4h_to_h) + } +} + +#[derive(Debug, Clone)] +struct Block { + input_layernorm: candle_nn::LayerNorm, + self_attention: SelfAttention, + post_attention_layernorm: candle_nn::LayerNorm, + mlp: MLP, + apply_residual_connection_post_layernorm: bool, +} + +impl Block { + fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result { + let input_layernorm = if cfg.rmsnorm { + candle_nn::rms_norm( + cfg.hidden_size, + cfg.layernorm_epsilon, + vb.pp("input_layernorm"), + )? + .into_inner() + } else { + candle_nn::layer_norm( + cfg.hidden_size, + cfg.layernorm_epsilon, + vb.pp("input_layernorm"), + )? + }; + let post_attention_layernorm = if cfg.rmsnorm { + candle_nn::rms_norm( + cfg.hidden_size, + cfg.layernorm_epsilon, + vb.pp("post_attention_layernorm"), + )? + .into_inner() + } else { + candle_nn::layer_norm( + cfg.hidden_size, + cfg.layernorm_epsilon, + vb.pp("post_attention_layernorm"), + )? + }; + let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + Ok(Self { + input_layernorm, + self_attention, + post_attention_layernorm, + mlp, + apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm, + }) + } + + fn reset_kv_cache(&mut self) { + self.self_attention.reset_kv_cache() + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: &Option, + rotary_emb: &RotaryEmbedding, + ) -> Result { + let layernorm_output = xs.apply(&self.input_layernorm)?; + let attention_output = + self.self_attention + .forward(&layernorm_output, attention_mask, rotary_emb)?; + let residual = if self.apply_residual_connection_post_layernorm { + &layernorm_output + } else { + xs + }; + let layernorm_input = (residual + attention_output)?; + let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?; + let mlp_output = layernorm_output.apply(&self.mlp)?; + let residual = if self.apply_residual_connection_post_layernorm { + &layernorm_output + } else { + &layernorm_input + }; + mlp_output + residual + } +} + +#[derive(Debug, Clone)] +struct Transformer { + layers: Vec, + final_layernorm: Option, + rotary_emb: RotaryEmbedding, +} + +impl Transformer { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vb_l = vb.pp("layers"); + let mut layers = Vec::with_capacity(cfg.num_layers); + for layer_index in 0..cfg.num_layers { + let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?; + layers.push(block) + } + let final_layernorm = if cfg.post_layer_norm { + let ln = if cfg.rmsnorm { + candle_nn::rms_norm( + cfg.hidden_size, + cfg.layernorm_epsilon, + vb.pp("final_layernorm"), + )? 
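+                // `rms_norm` returns an `RmsNorm`; `into_inner()` unwraps it to the underlying
+                // `LayerNorm` so both branches of the `if` share the same type.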
+ .into_inner() + } else { + candle_nn::layer_norm( + cfg.hidden_size, + cfg.layernorm_epsilon, + vb.pp("final_layernorm"), + )? + }; + Some(ln) + } else { + None + }; + let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?; + Ok(Self { + layers, + final_layernorm, + rotary_emb, + }) + } + + fn reset_kv_cache(&mut self) { + for block in self.layers.iter_mut() { + block.reset_kv_cache() + } + } + + fn forward(&mut self, xs: &Tensor, attention_mask: &Option) -> Result { + let mut xs = xs.clone(); + for block in self.layers.iter_mut() { + xs = block.forward(&xs, attention_mask, &self.rotary_emb)? + } + match self.final_layernorm.as_ref() { + None => Ok(xs), + Some(ln) => xs.apply(ln), + } + } +} + +#[derive(Debug, Clone)] +struct Embedding { + word_embeddings: candle_nn::Embedding, + fp32_residual_connection: bool, +} + +impl Embedding { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let word_embeddings = candle_nn::embedding( + cfg.padded_vocab_size, + cfg.hidden_size, + vb.pp("word_embeddings"), + )?; + Ok(Self { + word_embeddings, + fp32_residual_connection: cfg.fp32_residual_connection, + }) + } +} + +impl Module for Embedding { + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; // b,s,h -> s,b,h + if self.fp32_residual_connection { + xs.to_dtype(candle::DType::F32) + } else { + xs.contiguous() + } + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embedding: Embedding, + encoder: Transformer, + output_layer: Linear, +} + +fn get_mask(size: usize, device: &Device) -> Result { + let mask: Vec<_> = (0..size) + .flat_map(|i| (0..size).map(move |j| u8::from(j > i))) + .collect(); + Tensor::from_slice(&mask, (size, size), device) +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vb = vb.pp("transformer"); + let embedding = Embedding::new(cfg, vb.pp("embedding"))?; + let encoder = Transformer::new(cfg, vb.pp("encoder"))?; + let output_layer = linear( + cfg.hidden_size, + cfg.padded_vocab_size, + false, + vb.pp("output_layer"), + )?; + + Ok(Self { + embedding, + encoder, + output_layer, + }) + } + + pub fn reset_kv_cache(&mut self) { + self.encoder.reset_kv_cache() + } + + pub fn forward(&mut self, xs: &Tensor) -> Result { + let (_b_size, seq_len) = xs.dims2()?; + let input_embeds = xs.apply(&self.embedding)?; + let attention_mask = if seq_len <= 1 { + None + } else { + Some(get_mask(seq_len, xs.device())?) + }; + let xs = self.encoder.forward(&input_embeds, &attention_mask)?; + let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?; + Ok(lm_logits) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 4daebec299..cf5a240231 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -19,6 +19,7 @@ pub mod eva2; pub mod falcon; pub mod flux; pub mod gemma; +pub mod glm4; pub mod hiera; pub mod jina_bert; pub mod llama; From f8e2b36b252d02be818730bd0d7464142fb04574 Mon Sep 17 00:00:00 2001 From: Czxck001 <10724409+Czxck001@users.noreply.github.com> Date: Mon, 5 Aug 2024 10:26:15 -0700 Subject: [PATCH 39/75] Add the MMDiT model of Stable Diffusion 3 (#2397) * add mmdit of stable diffusion 3 lint add comments * correct a misplaced comment * fix cargo fmt * fix clippy error * use bail! instead of assert! 
* use get_on_dim in splitting qkv --- .../src/models/mmdit/blocks.rs | 294 ++++++++++++++++++ .../src/models/mmdit/embedding.rs | 197 ++++++++++++ candle-transformers/src/models/mmdit/mod.rs | 4 + candle-transformers/src/models/mmdit/model.rs | 173 +++++++++++ .../src/models/mmdit/projections.rs | 94 ++++++ candle-transformers/src/models/mod.rs | 1 + 6 files changed, 763 insertions(+) create mode 100644 candle-transformers/src/models/mmdit/blocks.rs create mode 100644 candle-transformers/src/models/mmdit/embedding.rs create mode 100644 candle-transformers/src/models/mmdit/mod.rs create mode 100644 candle-transformers/src/models/mmdit/model.rs create mode 100644 candle-transformers/src/models/mmdit/projections.rs diff --git a/candle-transformers/src/models/mmdit/blocks.rs b/candle-transformers/src/models/mmdit/blocks.rs new file mode 100644 index 0000000000..e2b924a013 --- /dev/null +++ b/candle-transformers/src/models/mmdit/blocks.rs @@ -0,0 +1,294 @@ +use candle::{Module, Result, Tensor, D}; +use candle_nn as nn; + +use super::projections::{AttnProjections, Mlp, Qkv, QkvOnlyAttnProjections}; + +pub struct ModulateIntermediates { + gate_msa: Tensor, + shift_mlp: Tensor, + scale_mlp: Tensor, + gate_mlp: Tensor, +} + +pub struct DiTBlock { + norm1: LayerNormNoAffine, + attn: AttnProjections, + norm2: LayerNormNoAffine, + mlp: Mlp, + ada_ln_modulation: nn::Sequential, +} + +pub struct LayerNormNoAffine { + eps: f64, +} + +impl LayerNormNoAffine { + pub fn new(eps: f64) -> Self { + Self { eps } + } +} + +impl Module for LayerNormNoAffine { + fn forward(&self, x: &Tensor) -> Result { + nn::LayerNorm::new_no_bias(Tensor::ones_like(x)?, self.eps).forward(x) + } +} + +impl DiTBlock { + pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + // {'hidden_size': 1536, 'num_heads': 24} + let norm1 = LayerNormNoAffine::new(1e-6); + let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?; + let norm2 = LayerNormNoAffine::new(1e-6); + let mlp_ratio = 4; + let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?; + let n_mods = 6; + let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear( + hidden_size, + n_mods * hidden_size, + vb.pp("adaLN_modulation.1"), + )?); + + Ok(Self { + norm1, + attn, + norm2, + mlp, + ada_ln_modulation, + }) + } + + pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<(Qkv, ModulateIntermediates)> { + let modulation = self.ada_ln_modulation.forward(c)?; + let chunks = modulation.chunk(6, D::Minus1)?; + let (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = ( + chunks[0].clone(), + chunks[1].clone(), + chunks[2].clone(), + chunks[3].clone(), + chunks[4].clone(), + chunks[5].clone(), + ); + + let norm_x = self.norm1.forward(x)?; + let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?; + let qkv = self.attn.pre_attention(&modulated_x)?; + + Ok(( + qkv, + ModulateIntermediates { + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + }, + )) + } + + pub fn post_attention( + &self, + attn: &Tensor, + x: &Tensor, + mod_interm: &ModulateIntermediates, + ) -> Result { + let attn_out = self.attn.post_attention(attn)?; + let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?; + + let norm_x = self.norm2.forward(&x)?; + let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?; + let mlp_out = self.mlp.forward(&modulated_x)?; + let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?; + + Ok(x) + } +} + +pub struct 
QkvOnlyDiTBlock { + norm1: LayerNormNoAffine, + attn: QkvOnlyAttnProjections, + ada_ln_modulation: nn::Sequential, +} + +impl QkvOnlyDiTBlock { + pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + let norm1 = LayerNormNoAffine::new(1e-6); + let attn = QkvOnlyAttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?; + let n_mods = 2; + let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear( + hidden_size, + n_mods * hidden_size, + vb.pp("adaLN_modulation.1"), + )?); + + Ok(Self { + norm1, + attn, + ada_ln_modulation, + }) + } + + pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result { + let modulation = self.ada_ln_modulation.forward(c)?; + let chunks = modulation.chunk(2, D::Minus1)?; + let (shift_msa, scale_msa) = (chunks[0].clone(), chunks[1].clone()); + + let norm_x = self.norm1.forward(x)?; + let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?; + self.attn.pre_attention(&modulated_x) + } +} + +pub struct FinalLayer { + norm_final: LayerNormNoAffine, + linear: nn::Linear, + ada_ln_modulation: nn::Sequential, +} + +impl FinalLayer { + pub fn new( + hidden_size: usize, + patch_size: usize, + out_channels: usize, + vb: nn::VarBuilder, + ) -> Result { + let norm_final = LayerNormNoAffine::new(1e-6); + let linear = nn::linear( + hidden_size, + patch_size * patch_size * out_channels, + vb.pp("linear"), + )?; + let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear( + hidden_size, + 2 * hidden_size, + vb.pp("adaLN_modulation.1"), + )?); + + Ok(Self { + norm_final, + linear, + ada_ln_modulation, + }) + } + + pub fn forward(&self, x: &Tensor, c: &Tensor) -> Result { + let modulation = self.ada_ln_modulation.forward(c)?; + let chunks = modulation.chunk(2, D::Minus1)?; + let (shift, scale) = (chunks[0].clone(), chunks[1].clone()); + + let norm_x = self.norm_final.forward(x)?; + let modulated_x = modulate(&norm_x, &shift, &scale)?; + let output = self.linear.forward(&modulated_x)?; + + Ok(output) + } +} + +fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result { + let shift = shift.unsqueeze(1)?; + let scale = scale.unsqueeze(1)?; + let scale_plus_one = scale.add(&Tensor::ones_like(&scale)?)?; + shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?) 
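+    // (equivalently x * (1 + scale) + shift, the usual DiT/adaLN modulation of a normalized activation)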
+} + +pub struct JointBlock { + x_block: DiTBlock, + context_block: DiTBlock, + num_heads: usize, +} + +impl JointBlock { + pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; + let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; + + Ok(Self { + x_block, + context_block, + num_heads, + }) + } + + pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> { + let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?; + let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; + let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + let context_out = + self.context_block + .post_attention(&context_attn, context, &context_interm)?; + let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?; + Ok((context_out, x_out)) + } +} + +pub struct ContextQkvOnlyJointBlock { + x_block: DiTBlock, + context_block: QkvOnlyDiTBlock, + num_heads: usize, +} + +impl ContextQkvOnlyJointBlock { + pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; + let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; + Ok(Self { + x_block, + context_block, + num_heads, + }) + } + + pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result { + let context_qkv = self.context_block.pre_attention(context, c)?; + let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; + + let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + + let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?; + Ok(x_out) + } +} + +// A QKV-attention that is compatible with the interface of candle_flash_attn::flash_attn +// Flash attention regards q, k, v dimensions as (batch_size, seqlen, nheads, headdim) +fn flash_compatible_attention( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, +) -> Result { + let q_dims_for_matmul = q.transpose(1, 2)?.dims().to_vec(); + let rank = q_dims_for_matmul.len(); + let q = q.transpose(1, 2)?.flatten_to(rank - 3)?; + let k = k.transpose(1, 2)?.flatten_to(rank - 3)?; + let v = v.transpose(1, 2)?.flatten_to(rank - 3)?; + let attn_weights = (q.matmul(&k.t()?)? 
* softmax_scale as f64)?; + let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?; + attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2) +} + +fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> { + let qkv = Qkv { + q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?, + k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?, + v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?, + }; + + let (batch_size, seqlen, _) = qkv.q.dims3()?; + let qkv = Qkv { + q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?, + k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?, + v: qkv.v, + }; + + let headdim = qkv.q.dim(D::Minus1)?; + let softmax_scale = 1.0 / (headdim as f64).sqrt(); + // let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?; + let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?; + + let attn = attn.reshape((batch_size, seqlen, ()))?; + let context_qkv_seqlen = context_qkv.q.dim(1)?; + let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?; + let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?; + + Ok((context_attn, x_attn)) +} diff --git a/candle-transformers/src/models/mmdit/embedding.rs b/candle-transformers/src/models/mmdit/embedding.rs new file mode 100644 index 0000000000..6e200b18bd --- /dev/null +++ b/candle-transformers/src/models/mmdit/embedding.rs @@ -0,0 +1,197 @@ +use candle::{bail, DType, Module, Result, Tensor}; +use candle_nn as nn; + +pub struct PatchEmbedder { + proj: nn::Conv2d, +} + +impl PatchEmbedder { + pub fn new( + patch_size: usize, + in_channels: usize, + embed_dim: usize, + vb: nn::VarBuilder, + ) -> Result { + let proj = nn::conv2d( + in_channels, + embed_dim, + patch_size, + nn::Conv2dConfig { + stride: patch_size, + ..Default::default() + }, + vb.pp("proj"), + )?; + + Ok(Self { proj }) + } +} + +impl Module for PatchEmbedder { + fn forward(&self, x: &Tensor) -> Result { + let x = self.proj.forward(x)?; + + // flatten spatial dim and transpose to channels last + let (b, c, h, w) = x.dims4()?; + x.reshape((b, c, h * w))?.transpose(1, 2) + } +} + +pub struct Unpatchifier { + patch_size: usize, + out_channels: usize, +} + +impl Unpatchifier { + pub fn new(patch_size: usize, out_channels: usize) -> Result { + Ok(Self { + patch_size, + out_channels, + }) + } + + pub fn unpatchify(&self, x: &Tensor, h: usize, w: usize) -> Result { + let h = (h + 1) / self.patch_size; + let w = (w + 1) / self.patch_size; + + let x = x.reshape(( + x.dim(0)?, + h, + w, + self.patch_size, + self.patch_size, + self.out_channels, + ))?; + let x = x.permute((0, 5, 1, 3, 2, 4))?; // "nhwpqc->nchpwq" + x.reshape(( + x.dim(0)?, + self.out_channels, + self.patch_size * h, + self.patch_size * w, + )) + } +} + +pub struct PositionEmbedder { + pos_embed: Tensor, + patch_size: usize, + pos_embed_max_size: usize, +} + +impl PositionEmbedder { + pub fn new( + hidden_size: usize, + patch_size: usize, + pos_embed_max_size: usize, + vb: nn::VarBuilder, + ) -> Result { + let pos_embed = vb.get( + (1, pos_embed_max_size * pos_embed_max_size, hidden_size), + "pos_embed", + )?; + Ok(Self { + pos_embed, + patch_size, + pos_embed_max_size, + }) + } + pub fn get_cropped_pos_embed(&self, h: usize, w: usize) -> Result { + let h = (h + 1) / self.patch_size; + let w = (w + 1) / self.patch_size; + + if h > self.pos_embed_max_size || w > self.pos_embed_max_size { + bail!("Input size is too large for the position embedding") + } + + let top = 
(self.pos_embed_max_size - h) / 2; + let left = (self.pos_embed_max_size - w) / 2; + + let pos_embed = + self.pos_embed + .reshape((1, self.pos_embed_max_size, self.pos_embed_max_size, ()))?; + let pos_embed = pos_embed.narrow(1, top, h)?.narrow(2, left, w)?; + pos_embed.reshape((1, h * w, ())) + } +} + +pub struct TimestepEmbedder { + mlp: nn::Sequential, + frequency_embedding_size: usize, +} + +impl TimestepEmbedder { + pub fn new( + hidden_size: usize, + frequency_embedding_size: usize, + vb: nn::VarBuilder, + ) -> Result { + let mlp = nn::seq() + .add(nn::linear( + frequency_embedding_size, + hidden_size, + vb.pp("mlp.0"), + )?) + .add(nn::Activation::Silu) + .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?); + + Ok(Self { + mlp, + frequency_embedding_size, + }) + } + + fn timestep_embedding(t: &Tensor, dim: usize, max_period: f64) -> Result { + if dim % 2 != 0 { + bail!("Embedding dimension must be even") + } + + if t.dtype() != DType::F32 && t.dtype() != DType::F64 { + bail!("Input tensor must be floating point") + } + + let half = dim / 2; + let freqs = Tensor::arange(0f32, half as f32, t.device())? + .to_dtype(candle::DType::F32)? + .mul(&Tensor::full( + (-f64::ln(max_period) / half as f64) as f32, + half, + t.device(), + )?)? + .exp()?; + + let args = t + .unsqueeze(1)? + .to_dtype(candle::DType::F32)? + .matmul(&freqs.unsqueeze(0)?)?; + let embedding = Tensor::cat(&[args.cos()?, args.sin()?], 1)?; + embedding.to_dtype(candle::DType::F16) + } +} + +impl Module for TimestepEmbedder { + fn forward(&self, t: &Tensor) -> Result { + let t_freq = Self::timestep_embedding(t, self.frequency_embedding_size, 10000.0)?; + self.mlp.forward(&t_freq) + } +} + +pub struct VectorEmbedder { + mlp: nn::Sequential, +} + +impl VectorEmbedder { + pub fn new(input_dim: usize, hidden_size: usize, vb: nn::VarBuilder) -> Result { + let mlp = nn::seq() + .add(nn::linear(input_dim, hidden_size, vb.pp("mlp.0"))?) + .add(nn::Activation::Silu) + .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?); + + Ok(Self { mlp }) + } +} + +impl Module for VectorEmbedder { + fn forward(&self, x: &Tensor) -> Result { + self.mlp.forward(x) + } +} diff --git a/candle-transformers/src/models/mmdit/mod.rs b/candle-transformers/src/models/mmdit/mod.rs new file mode 100644 index 0000000000..9c4db6e085 --- /dev/null +++ b/candle-transformers/src/models/mmdit/mod.rs @@ -0,0 +1,4 @@ +pub mod blocks; +pub mod embedding; +pub mod model; +pub mod projections; diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs new file mode 100644 index 0000000000..1523836c7f --- /dev/null +++ b/candle-transformers/src/models/mmdit/model.rs @@ -0,0 +1,173 @@ +// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206). +// This follows the implementation of the MMDiT model in the ComfyUI repository. 
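+// MMDiT processes the text/context tokens and the latent image patches in two parallel streams with
+// separate weights, mixing the two only inside the joint attention blocks (see blocks.rs).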
+// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1 +use candle::{Module, Result, Tensor, D}; +use candle_nn as nn; + +use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock}; +use super::embedding::{ + PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder, +}; + +#[derive(Debug, Clone)] +pub struct Config { + pub patch_size: usize, + pub in_channels: usize, + pub out_channels: usize, + pub depth: usize, + pub head_size: usize, + pub adm_in_channels: usize, + pub pos_embed_max_size: usize, + pub context_embed_size: usize, + pub frequency_embedding_size: usize, +} + +impl Config { + pub fn sd3() -> Self { + Self { + patch_size: 2, + in_channels: 16, + out_channels: 16, + depth: 24, + head_size: 64, + adm_in_channels: 2048, + pos_embed_max_size: 192, + context_embed_size: 4096, + frequency_embedding_size: 256, + } + } +} + +pub struct MMDiT { + core: MMDiTCore, + patch_embedder: PatchEmbedder, + pos_embedder: PositionEmbedder, + timestep_embedder: TimestepEmbedder, + vector_embedder: VectorEmbedder, + context_embedder: nn::Linear, + unpatchifier: Unpatchifier, +} + +impl MMDiT { + pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result { + let hidden_size = cfg.head_size * cfg.depth; + let core = MMDiTCore::new( + cfg.depth, + hidden_size, + cfg.depth, + cfg.patch_size, + cfg.out_channels, + vb.clone(), + )?; + let patch_embedder = PatchEmbedder::new( + cfg.patch_size, + cfg.in_channels, + hidden_size, + vb.pp("x_embedder"), + )?; + let pos_embedder = PositionEmbedder::new( + hidden_size, + cfg.patch_size, + cfg.pos_embed_max_size, + vb.clone(), + )?; + let timestep_embedder = TimestepEmbedder::new( + hidden_size, + cfg.frequency_embedding_size, + vb.pp("t_embedder"), + )?; + let vector_embedder = + VectorEmbedder::new(cfg.adm_in_channels, hidden_size, vb.pp("y_embedder"))?; + let context_embedder = nn::linear( + cfg.context_embed_size, + hidden_size, + vb.pp("context_embedder"), + )?; + let unpatchifier = Unpatchifier::new(cfg.patch_size, cfg.out_channels)?; + + Ok(Self { + core, + patch_embedder, + pos_embedder, + timestep_embedder, + vector_embedder, + context_embedder, + unpatchifier, + }) + } + + pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result { + // Following the convention of the ComfyUI implementation. + // https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919 + // + // Forward pass of DiT. + // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + // t: (N,) tensor of diffusion timesteps + // y: (N,) tensor of class labels + let h = x.dim(D::Minus2)?; + let w = x.dim(D::Minus1)?; + let cropped_pos_embed = self.pos_embedder.get_cropped_pos_embed(h, w)?; + let x = self + .patch_embedder + .forward(x)? 
+ .broadcast_add(&cropped_pos_embed)?; + let c = self.timestep_embedder.forward(t)?; + let y = self.vector_embedder.forward(y)?; + let c = (c + y)?; + let context = self.context_embedder.forward(context)?; + + let x = self.core.forward(&context, &x, &c)?; + let x = self.unpatchifier.unpatchify(&x, h, w)?; + x.narrow(2, 0, h)?.narrow(3, 0, w) + } +} + +pub struct MMDiTCore { + joint_blocks: Vec, + context_qkv_only_joint_block: ContextQkvOnlyJointBlock, + final_layer: FinalLayer, +} + +impl MMDiTCore { + pub fn new( + depth: usize, + hidden_size: usize, + num_heads: usize, + patch_size: usize, + out_channels: usize, + vb: nn::VarBuilder, + ) -> Result { + let mut joint_blocks = Vec::with_capacity(depth - 1); + for i in 0..depth - 1 { + joint_blocks.push(JointBlock::new( + hidden_size, + num_heads, + vb.pp(format!("joint_blocks.{}", i)), + )?); + } + + Ok(Self { + joint_blocks, + context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new( + hidden_size, + num_heads, + vb.pp(format!("joint_blocks.{}", depth - 1)), + )?, + final_layer: FinalLayer::new( + hidden_size, + patch_size, + out_channels, + vb.pp("final_layer"), + )?, + }) + } + + pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result { + let (mut context, mut x) = (context.clone(), x.clone()); + for joint_block in &self.joint_blocks { + (context, x) = joint_block.forward(&context, &x, c)?; + } + let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?; + self.final_layer.forward(&x, c) + } +} diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs new file mode 100644 index 0000000000..1077398f5c --- /dev/null +++ b/candle-transformers/src/models/mmdit/projections.rs @@ -0,0 +1,94 @@ +use candle::{Module, Result, Tensor}; +use candle_nn as nn; + +pub struct Qkv { + pub q: Tensor, + pub k: Tensor, + pub v: Tensor, +} + +pub struct Mlp { + fc1: nn::Linear, + act: nn::Activation, + fc2: nn::Linear, +} + +impl Mlp { + pub fn new( + in_features: usize, + hidden_features: usize, + vb: candle_nn::VarBuilder, + ) -> Result { + let fc1 = nn::linear(in_features, hidden_features, vb.pp("fc1"))?; + let act = nn::Activation::GeluPytorchTanh; + let fc2 = nn::linear(hidden_features, in_features, vb.pp("fc2"))?; + + Ok(Self { fc1, act, fc2 }) + } +} + +impl Module for Mlp { + fn forward(&self, x: &Tensor) -> Result { + let x = self.fc1.forward(x)?; + let x = self.act.forward(&x)?; + self.fc2.forward(&x) + } +} + +pub struct QkvOnlyAttnProjections { + qkv: nn::Linear, + head_dim: usize, +} + +impl QkvOnlyAttnProjections { + pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + // {'dim': 1536, 'num_heads': 24} + let head_dim = dim / num_heads; + let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?; + Ok(Self { qkv, head_dim }) + } + + pub fn pre_attention(&self, x: &Tensor) -> Result { + let qkv = self.qkv.forward(x)?; + split_qkv(&qkv, self.head_dim) + } +} + +pub struct AttnProjections { + head_dim: usize, + qkv: nn::Linear, + proj: nn::Linear, +} + +impl AttnProjections { + pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + let head_dim = dim / num_heads; + let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?; + let proj = nn::linear(dim, dim, vb.pp("proj"))?; + Ok(Self { + head_dim, + qkv, + proj, + }) + } + + pub fn pre_attention(&self, x: &Tensor) -> Result { + let qkv = self.qkv.forward(x)?; + split_qkv(&qkv, self.head_dim) + } + + pub fn post_attention(&self, x: &Tensor) -> Result { + self.proj.forward(x) + } +} + 
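+// Split the packed qkv projection of shape (batch, seq, 3 * num_heads * head_dim) into q, k and v.
+// q and k are flattened back to (batch, seq, dim); v is taken directly from the head-split view.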
+fn split_qkv(qkv: &Tensor, head_dim: usize) -> Result { + let (batch_size, seq_len, _) = qkv.dims3()?; + let qkv = qkv.reshape((batch_size, seq_len, 3, (), head_dim))?; + let q = qkv.get_on_dim(2, 0)?; + let q = q.reshape((batch_size, seq_len, ()))?; + let k = qkv.get_on_dim(2, 1)?; + let k = k.reshape((batch_size, seq_len, ()))?; + let v = qkv.get_on_dim(2, 2)?; + Ok(Qkv { q, k, v }) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index cf5a240231..8bc37475e4 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -32,6 +32,7 @@ pub mod metavoice; pub mod mistral; pub mod mixformer; pub mod mixtral; +pub mod mmdit; pub mod mobilenetv4; pub mod mobileone; pub mod moondream; From 0e78d29b7b61c5c878517464fbb9034ecffa9791 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 5 Aug 2024 20:03:31 +0100 Subject: [PATCH 40/75] Add the import script for the T5 tokenizer. (#2399) --- candle-examples/examples/flux/t5_tokenizer.py | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 candle-examples/examples/flux/t5_tokenizer.py diff --git a/candle-examples/examples/flux/t5_tokenizer.py b/candle-examples/examples/flux/t5_tokenizer.py new file mode 100644 index 0000000000..cebd0624bf --- /dev/null +++ b/candle-examples/examples/flux/t5_tokenizer.py @@ -0,0 +1,6 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + +BASE_MODEL = "google/t5-v1_1-xxl" +tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) +# The tokenizer will be saved in /tmp/tokenizer/tokenizer.json +tokenizer.save_pretrained("/tmp/tokenizer/") From 1b796b9dac53dba5b703890c092ee6a8abc362a1 Mon Sep 17 00:00:00 2001 From: Hamir Mahal Date: Tue, 6 Aug 2024 01:59:34 -0700 Subject: [PATCH 41/75] fix: usage of `actions/checkout@v2` (#2403) * chore: changes from formatting on save * fix: usage of `actions/checkout@v2` --- .github/workflows/python.yml | 6 +++--- .github/workflows/rust-ci.yml | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index be9b917ec2..68e2eee31e 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -18,9 +18,9 @@ jobs: strategy: matrix: os: [ubuntu-latest] # For now, only test on Linux - steps: + steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Install Rust uses: actions-rs/toolchain@v1 @@ -65,4 +65,4 @@ jobs: working-directory: ./candle-pyo3 run: | source .env/bin/activate - python -m pytest -s -v tests \ No newline at end of file + python -m pytest -s -v tests diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 2ca53b2348..ee480c474c 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -1,6 +1,6 @@ -on: +on: push: - branches: + branches: - main pull_request: @@ -15,7 +15,7 @@ jobs: os: [ubuntu-latest, windows-latest, macOS-latest] rust: [stable] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: profile: minimal @@ -34,7 +34,7 @@ jobs: os: [ubuntu-latest, windows-latest, macOS-latest] rust: [stable] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: profile: minimal @@ -49,7 +49,7 @@ jobs: name: Rustfmt runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: profile: minimal @@ -65,7 +65,7 @@ jobs: name: Clippy 
runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: profile: minimal From c9cdd54eeb4a2bb93a07e20c98b8b9b1a00cfbc3 Mon Sep 17 00:00:00 2001 From: Joel Nises Date: Sat, 10 Aug 2024 07:49:05 +0200 Subject: [PATCH 42/75] Fix issues in the encodec example README.md (#2407) Also squeeze the first dimension of the codes tensor in the example file to get the expected three dimensions. --- candle-examples/examples/encodec/README.md | 2 +- .../examples/encodec/jfk-codes.safetensors | Bin 13328 -> 13328 bytes 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-examples/examples/encodec/README.md b/candle-examples/examples/encodec/README.md index 9de0d4adba..a339b67f50 100644 --- a/candle-examples/examples/encodec/README.md +++ b/candle-examples/examples/encodec/README.md @@ -7,7 +7,7 @@ quantization. ## Running one example ```bash -cargo run --example encodec --features symphonia --release -- code-to-audio \ +cargo run --example encodec --features encodec --release -- code-to-audio \ candle-examples/examples/encodec/jfk-codes.safetensors \ jfk.wav ``` diff --git a/candle-examples/examples/encodec/jfk-codes.safetensors b/candle-examples/examples/encodec/jfk-codes.safetensors index b8eb202618b3196ef37a9b04ca15f3445ec50311..675196120562d85a1efad3d6677a13fa4ea9f181 100644 GIT binary patch delta 13 UcmbP`F(G4u_C#IfjeduW0Vn Date: Fri, 9 Aug 2024 22:57:52 -0700 Subject: [PATCH 43/75] Soft Non-Maximum Suppression (#2400) * Soft NMS with thresholds * NMS Test * Soft nms w/ boxes removed below threshold * Soft nms test * No longer removing bounding boxes to fit Soft-NMS focus * Initialize confidence * Added comments * Refactored out updating based on IOU/sigma * Score_threshold -> confidence_threshold for clarity * Remove bboxes below confidence threshold * Softnms basic functionality test * Softnms confidence decay test * Softnms confidence threshold test * Softnms no overlapping bbox test * Testing confidence after no overlap test * Single bbox and no bbox tests * Signify test completion * Handling result of test functions * Checking all pairs of bboxes instead of a forward pass * Equal confidence overlap test * Clarified tests for implementation * No longer dropping boxes, just setting to 0.0 * Formatted w/ cargo --- candle-transformers/src/object_detection.rs | 58 +++++ candle-transformers/tests/nms_tests.rs | 222 ++++++++++++++++++++ 2 files changed, 280 insertions(+) create mode 100644 candle-transformers/tests/nms_tests.rs diff --git a/candle-transformers/src/object_detection.rs b/candle-transformers/src/object_detection.rs index ce5793165d..e922075fcc 100644 --- a/candle-transformers/src/object_detection.rs +++ b/candle-transformers/src/object_detection.rs @@ -50,3 +50,61 @@ pub fn non_maximum_suppression(bboxes: &mut [Vec>], threshold: f32) { bboxes_for_class.truncate(current_index); } } + +// Updates confidences starting at highest and comparing subsequent boxes. 
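+// Each lower-ranked box whose IoU with an earlier, higher-confidence box exceeds iou_threshold has its
+// confidence multiplied by exp(-iou^2 / sigma). For example, with iou = 0.7 and sigma = 0.5 the factor
+// is exp(-0.98) ≈ 0.38, so a box with confidence 0.8 is decayed to roughly 0.3.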
+fn update_confidences( + bboxes_for_class: &[Bbox], + updated_confidences: &mut [f32], + iou_threshold: f32, + sigma: f32, +) { + let len = bboxes_for_class.len(); + for current_index in 0..len { + let current_bbox = &bboxes_for_class[current_index]; + for index in (current_index + 1)..len { + let iou_val = iou(current_bbox, &bboxes_for_class[index]); + if iou_val > iou_threshold { + // Decay calculation from page 4 of: https://arxiv.org/pdf/1704.04503 + let decay = (-iou_val * iou_val / sigma).exp(); + let updated_confidence = bboxes_for_class[index].confidence * decay; + updated_confidences[index] = updated_confidence; + } + } + } +} + +// Sorts the bounding boxes by confidence and applies soft non-maximum suppression. +// This function is based on the algorithm described in https://arxiv.org/pdf/1704.04503 +pub fn soft_non_maximum_suppression( + bboxes: &mut [Vec>], + iou_threshold: Option, + confidence_threshold: Option, + sigma: Option, +) { + let iou_threshold = iou_threshold.unwrap_or(0.5); + let confidence_threshold = confidence_threshold.unwrap_or(0.1); + let sigma = sigma.unwrap_or(0.5); + + for bboxes_for_class in bboxes.iter_mut() { + // Sort boxes by confidence in descending order + bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap()); + let mut updated_confidences = bboxes_for_class + .iter() + .map(|bbox| bbox.confidence) + .collect::>(); + update_confidences( + bboxes_for_class, + &mut updated_confidences, + iou_threshold, + sigma, + ); + // Update confidences, set to 0.0 if below threshold + for (i, &confidence) in updated_confidences.iter().enumerate() { + bboxes_for_class[i].confidence = if confidence < confidence_threshold { + 0.0 + } else { + confidence + }; + } + } +} diff --git a/candle-transformers/tests/nms_tests.rs b/candle-transformers/tests/nms_tests.rs new file mode 100644 index 0000000000..d70f6fdf32 --- /dev/null +++ b/candle-transformers/tests/nms_tests.rs @@ -0,0 +1,222 @@ +use candle::Result; +use candle_transformers::object_detection::{ + non_maximum_suppression, soft_non_maximum_suppression, Bbox, +}; + +#[test] +fn nms_basic() -> Result<()> { + // Boxes based upon https://thepythoncode.com/article/non-maximum-suppression-using-opencv-in-python + let mut bboxes = vec![vec![ + Bbox { + xmin: 245.0, + ymin: 305.0, + xmax: 575.0, + ymax: 490.0, + confidence: 0.9, + data: (), + }, // Box 1 + Bbox { + xmin: 235.0, + ymin: 300.0, + xmax: 485.0, + ymax: 515.0, + confidence: 0.8, + data: (), + }, // Box 2 + Bbox { + xmin: 305.0, + ymin: 270.0, + xmax: 540.0, + ymax: 500.0, + confidence: 0.6, + data: (), + }, // Box 3 + ]]; + + non_maximum_suppression(&mut bboxes, 0.5); + let bboxes = bboxes.into_iter().next().unwrap(); + assert_eq!(bboxes.len(), 1); + assert_eq!(bboxes[0].confidence, 0.9); + + Ok(()) +} + +#[test] +fn softnms_basic_functionality() -> Result<()> { + let mut bboxes = vec![vec![ + Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.5, + data: (), + }, + Bbox { + xmin: 0.1, + ymin: 0.1, + xmax: 1.1, + ymax: 1.1, + confidence: 0.9, + data: (), + }, + Bbox { + xmin: 0.2, + ymin: 0.2, + xmax: 1.2, + ymax: 1.2, + confidence: 0.6, + data: (), + }, + ]]; + + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + + // Should decay boxes following highest confidence box + assert!(bboxes[0][0].confidence == 0.9); + assert!(bboxes[0][1].confidence < 0.5); + assert!(bboxes[0][2].confidence < 0.6); + Ok(()) +} + +#[test] +fn softnms_confidence_decay() -> Result<()> { + let mut 
bboxes = vec![vec![ + Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.9, + data: (), + }, // Reference box + Bbox { + xmin: 0.1, + ymin: 0.1, + xmax: 1.1, + ymax: 1.1, + confidence: 0.8, + data: (), + }, // Overlapping box + ]]; + + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + + // Check that confidence of the overlapping box is decayed + assert!(bboxes[0][0].confidence == 0.9); + assert!(bboxes[0][1].confidence < 0.8); + Ok(()) +} + +#[test] +fn softnms_confidence_threshold() -> Result<()> { + let mut bboxes = vec![vec![ + Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.9, + data: (), + }, + Bbox { + xmin: 0.1, + ymin: 0.1, + xmax: 1.1, + ymax: 1.1, + confidence: 0.05, + data: (), + }, + ]]; + + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + + // Box with confidence below the threshold should be removed + assert_eq!(bboxes[0].len(), 2); + assert_eq!(bboxes[0][0].confidence, 0.9); + assert_eq!(bboxes[0][1].confidence, 0.00); + Ok(()) +} + +#[test] +fn softnms_no_overlap() -> Result<()> { + let mut bboxes = vec![vec![ + Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.9, + data: (), + }, + Bbox { + xmin: 2.0, + ymin: 2.0, + xmax: 3.0, + ymax: 3.0, + confidence: 0.8, + data: (), + }, + ]]; + + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + + // Both boxes should remain as they do not significantly overlap + assert_eq!(bboxes[0].len(), 2); + assert_eq!(bboxes[0][0].confidence, 0.9); + assert_eq!(bboxes[0][1].confidence, 0.8); + Ok(()) +} +#[test] +fn softnms_no_bbox() -> Result<()> { + let mut bboxes: Vec>> = vec![]; + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + assert!(bboxes.is_empty()); + Ok(()) +} + +#[test] +fn softnms_single_bbox() -> Result<()> { + let mut bboxes = vec![vec![Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.9, + data: (), + }]]; + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + assert_eq!(bboxes[0].len(), 1); + Ok(()) +} + +#[test] +fn softnms_equal_confidence_overlap() -> Result<()> { + let mut bboxes = vec![vec![ + Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.5, + data: (), + }, + Bbox { + xmin: 0.1, + ymin: 0.1, + xmax: 1.1, + ymax: 1.1, + confidence: 0.5, + data: (), + }, + ]]; + + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + + // First box will be reference box, second box should be decayed + // Implementation must change to have both be decayed + assert_eq!(bboxes[0].len(), 2); + assert!(bboxes[0][0].confidence == 0.5); + assert!(bboxes[0][1].confidence < 0.5); + Ok(()) +} From de719a253c853b904fb0c074b97866bfb0d9096f Mon Sep 17 00:00:00 2001 From: Carsten Csiky Date: Sat, 10 Aug 2024 08:11:09 +0200 Subject: [PATCH 44/75] Add documentation examples for `Tensor::i` and `Tensor::narrow` methods (#2308) * Add documentation examples for `Tensor` methods * Apply fmt. * Cosmetic tweaks. 
--------- Co-authored-by: Laurent --- candle-core/src/indexer.rs | 102 ++++++++++++++++++++++++++++++++++--- candle-core/src/tensor.rs | 75 +++++++++++++++++++++++++++ 2 files changed, 169 insertions(+), 8 deletions(-) diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs index e3ed41e52b..a645d8b195 100644 --- a/candle-core/src/indexer.rs +++ b/candle-core/src/indexer.rs @@ -141,28 +141,114 @@ impl IndexOp for Tensor where T: Into, { + ///```rust + /// use candle_core::{Tensor, DType, Device, IndexOp}; + /// let a = Tensor::new(&[ + /// [0., 1.], + /// [2., 3.], + /// [4., 5.] + /// ], &Device::Cpu)?; + /// + /// let b = a.i(0)?; + /// assert_eq!(b.shape().dims(), &[2]); + /// assert_eq!(b.to_vec1::()?, &[0., 1.]); + /// + /// let c = a.i(..2)?; + /// assert_eq!(c.shape().dims(), &[2, 2]); + /// assert_eq!(c.to_vec2::()?, &[ + /// [0., 1.], + /// [2., 3.] + /// ]); + /// + /// let d = a.i(1..)?; + /// assert_eq!(d.shape().dims(), &[2, 2]); + /// assert_eq!(d.to_vec2::()?, &[ + /// [2., 3.], + /// [4., 5.] + /// ]); + /// # Ok::<(), candle_core::Error>(()) fn i(&self, index: T) -> Result { self.index(&[index.into()]) } } +impl IndexOp<(A,)> for Tensor +where + A: Into, +{ + ///```rust + /// use candle_core::{Tensor, DType, Device, IndexOp}; + /// let a = Tensor::new(&[ + /// [0f32, 1.], + /// [2. , 3.], + /// [4. , 5.] + /// ], &Device::Cpu)?; + /// + /// let b = a.i((0,))?; + /// assert_eq!(b.shape().dims(), &[2]); + /// assert_eq!(b.to_vec1::()?, &[0., 1.]); + /// + /// let c = a.i((..2,))?; + /// assert_eq!(c.shape().dims(), &[2, 2]); + /// assert_eq!(c.to_vec2::()?, &[ + /// [0., 1.], + /// [2., 3.] + /// ]); + /// + /// let d = a.i((1..,))?; + /// assert_eq!(d.shape().dims(), &[2, 2]); + /// assert_eq!(d.to_vec2::()?, &[ + /// [2., 3.], + /// [4., 5.] + /// ]); + /// # Ok::<(), candle_core::Error>(()) + fn i(&self, (a,): (A,)) -> Result { + self.index(&[a.into()]) + } +} +#[allow(non_snake_case)] +impl IndexOp<(A, B)> for Tensor +where + A: Into, + B: Into, +{ + ///```rust + /// use candle_core::{Tensor, DType, Device, IndexOp}; + /// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?; + /// + /// let b = a.i((1, 0))?; + /// assert_eq!(b.to_vec0::()?, 3.); + /// + /// let c = a.i((..2, 1))?; + /// assert_eq!(c.shape().dims(), &[2]); + /// assert_eq!(c.to_vec1::()?, &[1., 4.]); + /// + /// let d = a.i((2.., ..))?; + /// assert_eq!(c.shape().dims(), &[2]); + /// assert_eq!(c.to_vec1::()?, &[1., 4.]); + /// # Ok::<(), candle_core::Error>(()) + fn i(&self, (a, b): (A, B)) -> Result { + self.index(&[a.into(), b.into()]) + } +} + macro_rules! 
index_op_tuple { - ($($t:ident),+) => { + ($doc:tt, $($t:ident),+) => { #[allow(non_snake_case)] impl<$($t),*> IndexOp<($($t,)*)> for Tensor where $($t: Into,)* { + #[doc=$doc] fn i(&self, ($($t,)*): ($($t,)*)) -> Result { self.index(&[$($t.into(),)*]) } } }; } -index_op_tuple!(A); -index_op_tuple!(A, B); -index_op_tuple!(A, B, C); -index_op_tuple!(A, B, C, D); -index_op_tuple!(A, B, C, D, E); -index_op_tuple!(A, B, C, D, E, F); -index_op_tuple!(A, B, C, D, E, F, G); + +index_op_tuple!("see [TensorIndex#method.i]", A, B, C); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G); diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e66fcc629e..256c388810 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -431,6 +431,15 @@ impl Tensor { /// Returns a new tensor with all the elements having the same specified value. Note that /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec2::()?, &[ + /// [3.5, 3.5, 3.5, 3.5], + /// [3.5, 3.5, 3.5, 3.5], + /// ]); + /// # Ok::<(), candle_core::Error>(()) pub fn full>( value: D, shape: S, @@ -440,6 +449,13 @@ impl Tensor { } /// Creates a new 1D tensor from an iterator. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec1::()?, &[1.0, 2.0, 3.0, 4.0]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn from_iter( iter: impl IntoIterator, device: &Device, @@ -451,12 +467,26 @@ impl Tensor { /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common /// difference `1` from `start`. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::arange(2., 5., &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec1::()?, &[2., 3., 4.]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn arange(start: D, end: D, device: &Device) -> Result { Self::arange_step(start, end, D::one(), device) } /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common /// difference `step` from `start`. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec1::()?, &[2.0, 2.5, 3.0, 3.5]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn arange_step( start: D, end: D, @@ -502,6 +532,16 @@ impl Tensor { /// Creates a new tensor initialized with values from the input vector. The number of elements /// in this vector must be the same as the number of elements defined by the shape. /// If the device is cpu, no data copy is made. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec2::()?, &[ + /// [1., 2., 3.], + /// [4., 5., 6.] + /// ]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn from_vec, D: crate::WithDType>( data: Vec, shape: S, @@ -512,6 +552,17 @@ impl Tensor { /// Creates a new tensor initialized with values from the input slice. 
The number of elements /// in this vector must be the same as the number of elements defined by the shape. + ///```rust + /// use candle_core::{Tensor, Device}; + /// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.]; + /// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec2::()?, &[ + /// [2., 3., 4.], + /// [5., 6., 7.] + /// ]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn from_slice, D: crate::WithDType>( array: &[D], shape: S, @@ -793,6 +844,30 @@ impl Tensor { /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` /// ranges from `start` to `start + len`. + /// ``` + /// use candle_core::{Tensor, Device}; + /// let a = Tensor::new(&[ + /// [0f32, 1., 2.], + /// [3. , 4., 5.], + /// [6. , 7., 8.] + /// ], &Device::Cpu)?; + /// + /// let b = a.narrow(0, 1, 2)?; + /// assert_eq!(b.shape().dims(), &[2, 3]); + /// assert_eq!(b.to_vec2::()?, &[ + /// [3., 4., 5.], + /// [6., 7., 8.] + /// ]); + /// + /// let c = a.narrow(1, 1, 1)?; + /// assert_eq!(c.shape().dims(), &[3, 1]); + /// assert_eq!(c.to_vec2::()?, &[ + /// [1.], + /// [4.], + /// [7.] + /// ]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` pub fn narrow(&self, dim: D, start: usize, len: usize) -> Result { let dims = self.dims(); let dim = dim.to_index(self.shape(), "narrow")?; From 2e72a3d0597a4e0f5602f7f40a59bd3659f34bd7 Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Mon, 12 Aug 2024 22:21:19 +0300 Subject: [PATCH 45/75] Add Based LLM from Hazy Research. (#2411) --- candle-examples/examples/based/README.md | 20 + candle-examples/examples/based/main.rs | 275 +++++++++++ candle-transformers/src/models/based.rs | 589 +++++++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + 4 files changed, 885 insertions(+) create mode 100644 candle-examples/examples/based/README.md create mode 100644 candle-examples/examples/based/main.rs create mode 100644 candle-transformers/src/models/based.rs diff --git a/candle-examples/examples/based/README.md b/candle-examples/examples/based/README.md new file mode 100644 index 0000000000..16bfddb6c7 --- /dev/null +++ b/candle-examples/examples/based/README.md @@ -0,0 +1,20 @@ +# candle-based + +Experimental, not instruction-tuned small LLM from the Hazy Research group, combining local and linear attention layers. + +[Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based) + +[Simple linear attention language models balance the recall-throughput tradeoff](https://arxiv.org/abs/2402.18668) + +## Running an example + +```bash +$ cargo run --example based --release -- --prompt "Flying monkeys are" --which 1b-50b --sample-len 100 + +Flying monkeys are a common sight in the wild, but they are also a threat to humans. + +The new study, published today (July 31) in the journal Science Advances, shows that the monkeys are using their brains to solve the problem of how to get around the problem. + +"We found that the monkeys were using a strategy called 'cognitive mapping' - they would use their brains to map out the route ahead," says lead author Dr. David J. 
Smith from the University of California + +``` diff --git a/candle-examples/examples/based/main.rs b/candle-examples/examples/based/main.rs new file mode 100644 index 0000000000..a8bff15ba5 --- /dev/null +++ b/candle-examples/examples/based/main.rs @@ -0,0 +1,275 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::{Parser, ValueEnum}; + +use candle_transformers::models::based::Model; + +use candle::{DType, Device, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +struct TextGeneration { + model: Model, + device: Device, + tokenizer: TokenOutputStream, + logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, +} + +impl TextGeneration { + #[allow(clippy::too_many_arguments)] + fn new( + model: Model, + tokenizer: Tokenizer, + seed: u64, + temp: Option, + top_p: Option, + repeat_penalty: f32, + repeat_last_n: usize, + device: &Device, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp, top_p); + Self { + model, + tokenizer: TokenOutputStream::new(tokenizer), + logits_processor, + repeat_penalty, + repeat_last_n, + device: device.clone(), + } + } + + fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { + use std::io::Write; + self.tokenizer.clear(); + let mut tokens = self + .tokenizer + .tokenizer() + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + for &t in tokens.iter() { + if let Some(t) = self.tokenizer.next_token(t)? { + print!("{t}") + } + } + std::io::stdout().flush()?; + + let mut generated_tokens = 0usize; + let eos_token = match self.tokenizer.get_token("<|endoftext|>") { + Some(token) => token, + None => anyhow::bail!("cannot find the <|endoftext|> token"), + }; + let start_gen = std::time::Instant::now(); + for index in 0..sample_len { + let context_size = if index > 0 { 1 } else { tokens.len() }; + let start_pos = tokens.len().saturating_sub(context_size); + let ctxt = &tokens[start_pos..]; + let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = self.model.forward(&input, start_pos)?; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; + + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + if next_token == eos_token { + break; + } + if let Some(t) = self.tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + } + let dt = start_gen.elapsed(); + if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + println!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + Ok(()) + } +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Which { + #[value(name = "360m")] + W360m, + #[value(name = "1b")] + W1b, + #[value(name = "1b-50b")] + W1b50b, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. 
+ #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + prompt: String, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, short = 'n', default_value_t = 10000)] + sample_len: usize, + + #[arg(long)] + model_id: Option, + + #[arg(long, default_value = "refs/pr/1")] + revision: String, + + #[arg(long)] + config_file: Option, + + #[arg(long)] + tokenizer_file: Option, + + #[arg(long)] + weight_files: Option, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + #[arg(long, default_value = "360m")] + which: Which, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.), + args.repeat_penalty, + args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let model_id = match args.model_id { + Some(model_id) => model_id, + None => match args.which { + Which::W360m => "hazyresearch/based-360m".to_string(), + Which::W1b => "hazyresearch/based-1b".to_string(), + Which::W1b50b => "hazyresearch/based-1b-50b".to_string(), + }, + }; + let repo = api.repo(Repo::with_revision( + model_id, + RepoType::Model, + args.revision, + )); + let config_file = match args.config_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("config.json")?, + }; + let filenames = match args.weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::>(), + None => vec![repo.get("model.safetensors")?], + }; + + let repo = api.model("openai-community/gpt2".to_string()); + let tokenizer_file = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + + println!("retrieved the files in {:?}", start.elapsed()); + let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?; + + let start = std::time::Instant::now(); + let config = serde_json::from_reader(std::fs::File::open(config_file)?)?; + let device = candle_examples::device(args.cpu)?; + let dtype = if device.is_cuda() { + DType::BF16 + } else { + DType::F32 + }; + + let mut vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? 
}; + if args.which == Which::W1b50b { + vb = vb.pp("model"); + }; + + let model = Model::new(&config, vb)?; + + println!("loaded the model in {:?}", start.elapsed()); + + let mut pipeline = TextGeneration::new( + model, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + &device, + ); + pipeline.run(&args.prompt, args.sample_len)?; + Ok(()) +} diff --git a/candle-transformers/src/models/based.rs b/candle-transformers/src/models/based.rs new file mode 100644 index 0000000000..aa28f52333 --- /dev/null +++ b/candle-transformers/src/models/based.rs @@ -0,0 +1,589 @@ +//! Based from the Stanford Hazy Research group. +//! +//! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024 +//! + +//! Original code: +//! https://github.com/HazyResearch/based + +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_nn::{ + conv1d_no_bias, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Conv1d, Conv1dConfig, + Func, Linear, RmsNorm, VarBuilder, +}; +use std::sync::Arc; + +#[derive(Debug, Clone, serde::Deserialize)] +pub struct LinearAttentionFeatureMapConfig { + input_dim: usize, +} + +#[derive(Debug, Clone, serde::Deserialize)] +pub struct LinearAttentionConfig { + num_heads: usize, + feature_dim: usize, + feature_map: LinearAttentionFeatureMapConfig, +} + +#[derive(Debug, Clone, serde::Deserialize)] +pub struct SlidingWindowAttentionConfig { + num_heads: usize, + window_size: usize, +} + +#[derive(Debug, Clone, serde::Deserialize)] +pub struct Config { + vocab_size: usize, + #[serde(rename = "n_embd")] + hidden_size: usize, + #[serde(rename = "n_inner")] + intermediate_size: usize, + #[serde(rename = "n_layer")] + num_hidden_layers: usize, + #[serde(rename = "n_head")] + num_attention_heads: usize, + + layer_norm_epsilon: f64, + #[serde(default = "default_rope", rename = "rotary_emb_base")] + rope_theta: f64, + + alt_mixer_layers: Vec, + alt_mixer_2_layers: Vec, + #[serde(rename = "alt_mixer")] + la: LinearAttentionConfig, + #[serde(rename = "alt_mixer_2")] + swa: SlidingWindowAttentionConfig, +} + +fn default_rope() -> f64 { + 10_000.0 +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + fc1: Linear, + fc2: Linear, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let fc1 = linear_no_bias(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("fc1"))?; + let fc2 = linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?; + Ok(Self { fc1, fc2 }) + } +} + +// Swiglu implementation. +// Not using Activation::Swiglu because this has the gate and y arguments switched compared to the version in candle-nn/src/ops.rs +fn swiglu(xs: &Tensor) -> Result { + let xs = xs.chunk(2, D::Minus1)?; + &xs[1].silu()? * &xs[0] +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result { + let xs = xs.apply(&self.fc1)?; + let xs = swiglu(&xs)?; + let xs = xs.apply(&self.fc2)?; + Ok(xs) + } +} + +// A gated convolutional block. 
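+// The input is projected up to 4 * hidden_size and split in two; one half runs through a depthwise
+// causal Conv1d (kernel size 3) followed by SiLU and gates the other half, before the result is
+// projected back to hidden_size. `state` caches the last three conv inputs so that decoding can
+// proceed one token at a time.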
+#[derive(Debug, Clone)] +struct BasedConv { + in_proj: Linear, + out_proj: Linear, + conv: Conv1d, + state: Tensor, +} + +impl BasedConv { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let dim = cfg.hidden_size * 2; + + let conv1d_cfg = Conv1dConfig { + groups: dim, + padding: 2, + ..Default::default() + }; + + let in_proj = linear(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("in_proj"))?; + let out_proj = linear(dim, cfg.hidden_size, vb.pp("out_proj"))?; + let conv = conv1d_no_bias(dim, dim, 3, conv1d_cfg, vb.pp("conv.conv"))?; + let state = Tensor::zeros((1, dim, 3), vb.dtype(), vb.device())?; + Ok(Self { + in_proj, + out_proj, + conv, + state, + }) + } + + fn step(&mut self, xs: &Tensor) -> Result { + self.state = self.state.roll(-1, D::Minus1)?; + let (_, _, l) = self.state.dims3()?; + self.state = self.state.narrow(D::Minus1, 0, l - 1)?; + self.state = Tensor::cat(&[&self.state, &xs.transpose(1, 2)?], 2)?; + + let xs = (&self.state * self.conv.weight().permute((1, 0, 2))?)? + .sum_keepdim(0)? + .sum(D::Minus1)?; + + let xs = xs.unsqueeze(1)?; + + Ok(xs) + } + + fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result { + let xs = xs.apply(&self.in_proj)?; + let us = xs.chunk(2, D::Minus1)?; + let (_b, l, _d) = us[0].dims3()?; + let u_conv = if seqlen_offset > 0 { + self.step(&us[0])? + } else { + let k = std::cmp::min(3, l); + self.state = self.state.narrow(D::Minus1, 0, 3 - k)?; + let xs = us[0].narrow(1, l - k, k)?.transpose(1, 2)?; + self.state = Tensor::cat(&[&self.state, &xs], 2)?; + + us[0] + .transpose(1, 2)? + .apply(&self.conv)? + .narrow(D::Minus1, 0, l)? + .transpose(1, 2)? + }; + + let u_conv = u_conv.silu()?; + let v = u_conv.broadcast_mul(&us[1])?; + let xs = v.apply(&self.out_proj)?; + + Ok(xs) + } +} + +// Linear attention approximating softmax using second order Taylor polynomials. +#[derive(Debug, Clone)] +struct LinearAttention { + proj_q: Linear, + proj_k: Linear, + proj_v: Linear, + out_proj: Linear, + feature_dim: usize, + num_heads: usize, + input_dim: usize, + k_state: Tensor, + kv_state: Tensor, +} + +impl LinearAttention { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let input_dim = cfg.la.feature_map.input_dim; + let out_proj = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("out_proj"))?; + let proj_k = linear_no_bias( + cfg.hidden_size, + cfg.la.num_heads * cfg.la.feature_dim, + vb.pp("proj_k"), + )?; + let proj_q = linear_no_bias( + cfg.hidden_size, + cfg.la.num_heads * cfg.la.feature_dim, + vb.pp("proj_q"), + )?; + + let proj_v = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("proj_v"))?; + let expanded_size = cfg.la.feature_dim.pow(2) + cfg.la.feature_dim + 1; + let k_state = Tensor::zeros( + (1, cfg.la.num_heads, 1, 1, expanded_size), + vb.dtype(), + vb.device(), + )?; + let kv_state = Tensor::zeros( + (1, cfg.la.num_heads, cfg.la.feature_dim, expanded_size), + vb.dtype(), + vb.device(), + )?; + + Ok(Self { + proj_q, + proj_k, + proj_v, + out_proj, + feature_dim: cfg.la.feature_dim, + num_heads: cfg.la.num_heads, + input_dim, + k_state, + kv_state, + }) + } + + fn taylor_expansion(&self) -> Result> { + let r2 = std::f64::consts::SQRT_2; + let rd = (self.input_dim as f64).sqrt(); + let rrd = rd.sqrt(); + + Ok(Func::new(move |xs| { + let dims = xs.dims(); + let mut d = dims.to_vec(); + if let Some(last) = d.last_mut() { + *last = 1; + }; + + let x = xs + .unsqueeze(D::Minus1)? + .broadcast_mul(&xs.unsqueeze(D::Minus2)?)?; + let x = (x.flatten_from(D::Minus2)? 
/ r2)?; + let o = Tensor::ones(d, xs.dtype(), xs.device())?; + let x = Tensor::cat(&[o, (xs / rrd)?, (&x / rd)?], D::Minus1)?; + + Ok(x) + })) + } + + fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result { + let eps = 1e-12; + + let feature_map = self.taylor_expansion()?; + + let (b, l, d) = xs.dims3()?; + let q = xs.apply(&self.proj_q)?; + let k = xs.apply(&self.proj_k)?; + let v = xs.apply(&self.proj_v)?; + + let q = q + .reshape((b, l, self.num_heads, self.feature_dim))? + .transpose(1, 2)? + .contiguous()?; + let k = k + .reshape((b, l, self.num_heads, self.feature_dim))? + .transpose(1, 2)? + .contiguous()?; + let v = v + .reshape((b, l, self.num_heads, d / self.num_heads))? + .transpose(1, 2)? + .contiguous()?; + + let q = feature_map.forward(&q)?; + let k = feature_map.forward(&k)?; + + let y = if seqlen_offset > 0 { + let (_b, _h, l, _d) = k.dims4()?; + let q = q.unsqueeze(D::Minus2)?; + let k = k.unsqueeze(D::Minus2)?; + let v = v.unsqueeze(D::Minus1)?; + let kn = k.narrow(D::Minus1, l - 1, 1)?; + let vn = v.narrow(D::Minus1, l - 1, 1)?; + + self.k_state = self.k_state.broadcast_add(&kn)?; + self.kv_state = self.kv_state.broadcast_add(&kn.broadcast_mul(&vn)?)?; + + let num = q.broadcast_mul(&self.kv_state)?.sum(D::Minus1)?; + let den = (q.broadcast_mul(&self.k_state)?.sum(D::Minus1)? + eps)?; + num.broadcast_div(&den)? + } else { + self.k_state = k.sum(2)?.unsqueeze(2)?.unsqueeze(3)?; + self.kv_state = k + .transpose(2, 3)? + .matmul(&v)? + .transpose(2, 3)? + .unsqueeze(2)?; + let aqk = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?; + let tril = Tensor::tril2(l, aqk.dtype(), aqk.device())?; + let aqk = aqk.broadcast_mul(&tril)?.matmul(&v)?; + + let z = (1f64 / (q.mul(&k.cumsum(2)?)?.sum(D::Minus1)? + eps)?)?; + aqk.broadcast_mul(&z.unsqueeze(D::Minus1)?)? + }; + + let (b, h, l, d) = y.dims4()?; + let y = y.permute((0, 2, 1, 3))?.reshape((b, l, h * d))?; + let y = self.out_proj.forward(&y)?; + + Ok(y) + } +} + +// Rotary embeddings used in local attention. +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let dim = cfg.hidden_size / cfg.num_attention_heads; + let max_seq_len = 2048; // Hardcoded, missing from config. + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +// Local attention using a small sliding window. 
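+// Standard multi-head attention with rotary position embeddings and a KV cache; the mask built in
+// `Model::prepare_decoder_attention_mask` keeps attention causal and limits each query to a window of
+// recent positions.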
+#[derive(Debug, Clone)] +struct SlidingWindowAttention { + wqkv: Linear, + out_proj: Linear, + num_heads: usize, + head_dim: usize, + hidden_size: usize, + rotary_emb: Arc, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl SlidingWindowAttention { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let hidden_size = cfg.hidden_size; + let num_heads = cfg.swa.num_heads; + let head_dim = hidden_size / num_heads; + let out_proj = linear_no_bias(hidden_size, hidden_size, vb.pp("out_proj"))?; + let wqkv = linear_no_bias(hidden_size, hidden_size * 3, vb.pp("Wqkv"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); + Ok(Self { + wqkv, + out_proj, + hidden_size, + num_heads, + head_dim, + rotary_emb, + kv_cache: None, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let (b_sz, q_len, _) = xs.dims3()?; + + let qkv = xs.apply(&self.wqkv)?; + let qkv = qkv.reshape((b_sz, q_len, 3, (), self.head_dim))?; + + let q = qkv.i((.., .., 0))?; + let k = qkv.i((.., .., 1))?; + let v = qkv.i((.., .., 2))?; + + let q = q + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + + let (q, k) = self + .rotary_emb + .apply_rotary_emb_qkv(&q, &k, seqlen_offset)?; + + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((prev_k, prev_v)) => { + let k = Tensor::cat(&[prev_k, &k], 2)?; + let v = Tensor::cat(&[prev_v, &v], 2)?; + (k, v) + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = softmax_last_dim(&attn_weights)?; + let attn_output = attn_weights.matmul(&v)?; + let out = attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, self.hidden_size))? + .apply(&self.out_proj)?; + + Ok(out) + } +} + +// The model layers use three types of mixers. +#[derive(Debug, Clone)] +enum SequenceMixer { + Based(BasedConv), + Linear(LinearAttention), + Sliding(SlidingWindowAttention), +} + +impl SequenceMixer { + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + pos: usize, + ) -> Result { + match self { + Self::Based(b) => b.forward(xs, pos), + Self::Linear(b) => b.forward(xs, pos), + Self::Sliding(b) => b.forward(xs, attention_mask, pos), + } + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + mlp: MLP, + norm1: RmsNorm, + norm2: RmsNorm, + mixer: SequenceMixer, +} + +impl DecoderLayer { + fn new(layer_idx: usize, cfg: &Config, vb: VarBuilder) -> Result { + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let norm1 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?; + let norm2 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?; + + let l_attn = cfg.alt_mixer_layers.contains(&layer_idx); + let sw_attn = cfg.alt_mixer_2_layers.contains(&layer_idx); + + let mixer = if l_attn { + SequenceMixer::Linear(LinearAttention::new(cfg, vb.pp("mixer"))?) + } else if sw_attn { + SequenceMixer::Sliding(SlidingWindowAttention::new(cfg, vb.pp("mixer"))?) + } else { + SequenceMixer::Based(BasedConv::new(cfg, vb.pp("mixer"))?) 
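+            // (layers not listed in alt_mixer_layers or alt_mixer_2_layers fall back to the gated convolution)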
+ }; + + Ok(Self { + mlp, + norm1, + norm2, + mixer, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let residual = xs; + let xs = self.norm1.forward(xs)?; + let xs = self.mixer.forward(&xs, attention_mask, seqlen_offset)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.norm2)?.apply(&self.mlp)?; + residual + xs + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: super::with_tracing::Embedding, + layers: Vec, + norm: RmsNorm, + lm_head: Linear, + sliding_window: usize, + device: Device, + dtype: DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vocab_size = cfg.vocab_size + (8 - cfg.vocab_size % 8) % 8; + let lm_head = linear_no_bias(cfg.hidden_size, vocab_size, vb.pp("lm_head"))?; + let embed_tokens = super::with_tracing::Embedding::from_weights(lm_head.weight().clone())?; + let vb_m = vb.pp("transformer"); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(layer_idx, cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?; + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + sliding_window: cfg.swa.window_size, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result { + let sliding_window = self.sliding_window / 2; + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(self.dtype) + } + + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { + let (b_size, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + Some(mask) + }; + let mut xs = self.embed_tokens.forward(input_ids)?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? + } + xs.narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? + .apply(&self.lm_head) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 8bc37475e4..3ae56e839a 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -1,3 +1,4 @@ +pub mod based; pub mod beit; pub mod bert; pub mod bigcode; From d7a9bd00aba07582666ff7539e40b42fb3cc2378 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 14 Aug 2024 09:01:12 +0100 Subject: [PATCH 46/75] Fix the device for the bert attention mask. 
(#2414) --- candle-transformers/src/models/bert.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index 42486a2da2..2262aa1a8c 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -501,5 +501,6 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result< }; let attention_mask = attention_mask.to_dtype(dtype)?; // torch.finfo(dtype).min - (attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?) + (attention_mask.ones_like()? - &attention_mask)? + .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?) } From 3d40ffcec809e8742c56955600b0ddf27f7eb0f7 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 14 Aug 2024 09:13:53 +0100 Subject: [PATCH 47/75] Clippy fixes. (#2415) * Clippy fixes. * Bump the web_sys required version. --- candle-wasm-examples/llama2-c/Cargo.toml | 2 +- candle-wasm-examples/llama2-c/src/app.rs | 12 +++++------- candle-wasm-examples/whisper/Cargo.toml | 2 +- candle-wasm-examples/whisper/src/app.rs | 12 +++++------- candle-wasm-examples/yolo/Cargo.toml | 2 +- candle-wasm-examples/yolo/src/app.rs | 11 +++++------ 6 files changed, 18 insertions(+), 23 deletions(-) diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml index d46cdafa54..af4737656b 100644 --- a/candle-wasm-examples/llama2-c/Cargo.toml +++ b/candle-wasm-examples/llama2-c/Cargo.toml @@ -35,7 +35,7 @@ yew-agent = "0.2.0" yew = { version = "0.20.0", features = ["csr"] } [dependencies.web-sys] -version = "0.3.64" +version = "0.3.70" features = [ 'Blob', 'Document', diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs index 1e40b77e54..7456a5bdaa 100644 --- a/candle-wasm-examples/llama2-c/src/app.rs +++ b/candle-wasm-examples/llama2-c/src/app.rs @@ -9,13 +9,11 @@ use yew_agent::{Bridge, Bridged}; async fn fetch_url(url: &str) -> Result, JsValue> { use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response}; let window = web_sys::window().ok_or("window")?; - let mut opts = RequestInit::new(); - let opts = opts - .method("GET") - .mode(RequestMode::Cors) - .cache(RequestCache::NoCache); - - let request = Request::new_with_str_and_init(url, opts)?; + let opts = RequestInit::new(); + opts.set_method("GET"); + opts.set_mode(RequestMode::Cors); + opts.set_cache(RequestCache::NoCache); + let request = Request::new_with_str_and_init(url, &opts)?; let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?; diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml index 745b7ae782..526a64425a 100644 --- a/candle-wasm-examples/whisper/Cargo.toml +++ b/candle-wasm-examples/whisper/Cargo.toml @@ -35,7 +35,7 @@ yew-agent = "0.2.0" yew = { version = "0.20.0", features = ["csr"] } [dependencies.web-sys] -version = "0.3.64" +version = "0.3.70" features = [ 'Blob', 'Document', diff --git a/candle-wasm-examples/whisper/src/app.rs b/candle-wasm-examples/whisper/src/app.rs index e344096c59..a2c0ddabcb 100644 --- a/candle-wasm-examples/whisper/src/app.rs +++ b/candle-wasm-examples/whisper/src/app.rs @@ -18,13 +18,11 @@ const SAMPLE_NAMES: [&str; 6] = [ async fn fetch_url(url: &str) -> Result, JsValue> { use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response}; let window = web_sys::window().ok_or("window")?; - let mut opts = 
RequestInit::new(); - let opts = opts - .method("GET") - .mode(RequestMode::Cors) - .cache(RequestCache::NoCache); - - let request = Request::new_with_str_and_init(url, opts)?; + let opts = RequestInit::new(); + opts.set_method("GET"); + opts.set_mode(RequestMode::Cors); + opts.set_cache(RequestCache::NoCache); + let request = Request::new_with_str_and_init(url, &opts)?; let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?; diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml index ac76f9a7f2..e03319a043 100644 --- a/candle-wasm-examples/yolo/Cargo.toml +++ b/candle-wasm-examples/yolo/Cargo.toml @@ -35,7 +35,7 @@ yew-agent = "0.2.0" yew = { version = "0.20.0", features = ["csr"] } [dependencies.web-sys] -version = "0.3.64" +version = "0.3.70" features = [ 'Blob', 'CanvasRenderingContext2d', diff --git a/candle-wasm-examples/yolo/src/app.rs b/candle-wasm-examples/yolo/src/app.rs index a68284fa7b..61253fb5a8 100644 --- a/candle-wasm-examples/yolo/src/app.rs +++ b/candle-wasm-examples/yolo/src/app.rs @@ -8,13 +8,12 @@ use yew_agent::{Bridge, Bridged}; async fn fetch_url(url: &str) -> Result, JsValue> { use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response}; let window = web_sys::window().ok_or("window")?; - let mut opts = RequestInit::new(); - let opts = opts - .method("GET") - .mode(RequestMode::Cors) - .cache(RequestCache::NoCache); + let opts = RequestInit::new(); + opts.set_method("GET"); + opts.set_mode(RequestMode::Cors); + opts.set_cache(RequestCache::NoCache); - let request = Request::new_with_str_and_init(url, opts)?; + let request = Request::new_with_str_and_init(url, &opts)?; let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?; From c5c5d498aaec5591dea50a6b954824cd3fe1cb9d Mon Sep 17 00:00:00 2001 From: joshpopelka20 <107133849+joshpopelka20@users.noreply.github.com> Date: Wed, 14 Aug 2024 07:18:53 -0400 Subject: [PATCH 48/75] Update flash_fwd_launch_template.h with fix for kernels (#16) --- .../kernels/flash_fwd_launch_template.h | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h index 1f78041e16..29918c87c9 100644 --- a/candle-flash-attn/kernels/flash_fwd_launch_template.h +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -171,14 +171,12 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } }); } From 2386e4e4683f31530ab6d02e7516dfffe683331a Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 14 Aug 2024 07:25:04 -0400 Subject: [PATCH 49/75] Build fixes --- candle-nn/src/layer_norm.rs | 10 +--------- candle-nn/src/rope.rs | 4 +--- candle-transformers/src/models/based.rs | 16 ++++++++-------- candle-transformers/src/models/codegeex4_9b.rs | 6 +++--- candle-transformers/src/models/flux/model.rs | 10 +++++----- 
candle-transformers/src/models/glm4.rs | 6 +++--- 6 files changed, 21 insertions(+), 31 deletions(-) diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index 38fdbaaa5e..d38e64e582 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -30,15 +30,10 @@ //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450 use std::marker::PhantomData; -#[cfg(feature = "cuda")] -use std::{ - mem, - sync::{Arc, Mutex}, -}; #[cfg(feature = "cuda")] use candle::cuda_backend::{ - cudarc::driver::{sys, DeviceRepr, LaunchAsync, LaunchConfig}, + cudarc::driver::{DeviceRepr, LaunchAsync, LaunchConfig}, kernel_name, kernels, CudaDType, WrapErr, }; @@ -50,9 +45,6 @@ use candle::{ use candle::{DType, Module, Result, Tensor, D}; -#[cfg(feature = "cuda")] -static MAX_GRID_Y: Mutex> = Mutex::new(None); - #[derive(Debug, Clone, Copy, PartialEq)] pub struct LayerNormConfig { pub eps: f64, diff --git a/candle-nn/src/rope.rs b/candle-nn/src/rope.rs index f405ec02a4..2f3af072af 100644 --- a/candle-nn/src/rope.rs +++ b/candle-nn/src/rope.rs @@ -8,9 +8,7 @@ use candle::{ #[cfg(feature = "cuda")] use candle::cuda_backend::{ - cudarc::driver::{ - CudaFunction, CudaStream, DeviceRepr, DriverError, LaunchAsync, LaunchConfig, - }, + cudarc::driver::{DeviceRepr, LaunchAsync, LaunchConfig}, kernel_name, kernels, CudaDType, }; diff --git a/candle-transformers/src/models/based.rs b/candle-transformers/src/models/based.rs index aa28f52333..534ed3c964 100644 --- a/candle-transformers/src/models/based.rs +++ b/candle-transformers/src/models/based.rs @@ -8,8 +8,8 @@ use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{ - conv1d_no_bias, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Conv1d, Conv1dConfig, - Func, Linear, RmsNorm, VarBuilder, + conv1d_no_bias, layer_norm::RmsNormNonQuantized, linear, linear_no_bias, ops::softmax_last_dim, + rms_norm_non_quant, Conv1d, Conv1dConfig, Func, Linear, RmsNorm, VarBuilder, }; use std::sync::Arc; @@ -460,16 +460,16 @@ impl SequenceMixer { #[derive(Debug, Clone)] struct DecoderLayer { mlp: MLP, - norm1: RmsNorm, - norm2: RmsNorm, + norm1: RmsNorm, + norm2: RmsNorm, mixer: SequenceMixer, } impl DecoderLayer { fn new(layer_idx: usize, cfg: &Config, vb: VarBuilder) -> Result { let mlp = MLP::new(cfg, vb.pp("mlp"))?; - let norm1 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?; - let norm2 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?; + let norm1 = rms_norm_non_quant(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?; + let norm2 = rms_norm_non_quant(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?; let l_attn = cfg.alt_mixer_layers.contains(&layer_idx); let sw_attn = cfg.alt_mixer_2_layers.contains(&layer_idx); @@ -510,7 +510,7 @@ impl DecoderLayer { pub struct Model { embed_tokens: super::with_tracing::Embedding, layers: Vec, - norm: RmsNorm, + norm: RmsNorm, lm_head: Linear, sliding_window: usize, device: Device, @@ -529,7 +529,7 @@ impl Model { let layer = DecoderLayer::new(layer_idx, cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?; + let norm = rms_norm_non_quant(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?; Ok(Self { embed_tokens, layers, diff --git a/candle-transformers/src/models/codegeex4_9b.rs b/candle-transformers/src/models/codegeex4_9b.rs index aaa99fd96d..ae4b629601 100644 --- a/candle-transformers/src/models/codegeex4_9b.rs +++ 
b/candle-transformers/src/models/codegeex4_9b.rs @@ -384,7 +384,7 @@ struct Block { impl Block { fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result { let input_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("input_layernorm"), @@ -398,7 +398,7 @@ impl Block { )? }; let post_attention_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("post_attention_layernorm"), @@ -470,7 +470,7 @@ impl Transformer { } let final_layernorm = if cfg.post_layer_norm { let ln = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("final_layernorm"), diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs index 4e47873fe0..4a12e915d8 100644 --- a/candle-transformers/src/models/flux/model.rs +++ b/candle-transformers/src/models/flux/model.rs @@ -1,5 +1,5 @@ use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder}; +use candle_nn::{layer_norm::RmsNormNonQuantized, LayerNorm, Linear, RmsNorm, VarBuilder}; // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12 #[derive(Debug, Clone)] @@ -195,16 +195,16 @@ impl candle::Module for MlpEmbedder { #[derive(Debug, Clone)] pub struct QkNorm { - query_norm: RmsNorm, - key_norm: RmsNorm, + query_norm: RmsNorm, + key_norm: RmsNorm, } impl QkNorm { fn new(dim: usize, vb: VarBuilder) -> Result { let query_norm = vb.get(dim, "query_norm.scale")?; - let query_norm = RmsNorm::new(query_norm, 1e-6); + let query_norm = RmsNorm::::new(query_norm, 1e-6); let key_norm = vb.get(dim, "key_norm.scale")?; - let key_norm = RmsNorm::new(key_norm, 1e-6); + let key_norm = RmsNorm::::new(key_norm, 1e-6); Ok(Self { query_norm, key_norm, diff --git a/candle-transformers/src/models/glm4.rs b/candle-transformers/src/models/glm4.rs index 3b436eaa6d..00ead338d0 100644 --- a/candle-transformers/src/models/glm4.rs +++ b/candle-transformers/src/models/glm4.rs @@ -383,7 +383,7 @@ struct Block { impl Block { fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result { let input_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("input_layernorm"), @@ -397,7 +397,7 @@ impl Block { )? 
}; let post_attention_layernorm = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("post_attention_layernorm"), @@ -469,7 +469,7 @@ impl Transformer { } let final_layernorm = if cfg.post_layer_norm { let ln = if cfg.rmsnorm { - candle_nn::rms_norm( + candle_nn::rms_norm_non_quant( cfg.hidden_size, cfg.layernorm_epsilon, vb.pp("final_layernorm"), From 1b1974e0e7f8a89f3d00c01eb695edf9158d55b4 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Wed, 21 Aug 2024 09:16:15 -0400 Subject: [PATCH 50/75] Add GGUF BF16 support (#17) * Add GGUF bf16 type support * Add non avx impl for vec_dot_bf16 * Fix from_u32 * Fix loading * Fix dequant of bf16 --- candle-core/src/cpu/avx.rs | 83 +++++++++++++++++++++++++- candle-core/src/cpu/kernels.rs | 7 +++ candle-core/src/cpu/mod.rs | 62 ++++++++++++++++++- candle-core/src/quantized/cuda.rs | 1 + candle-core/src/quantized/ggml_file.rs | 1 + candle-core/src/quantized/k_quants.rs | 46 +++++++++++++- candle-core/src/quantized/mod.rs | 14 +++-- 7 files changed, 205 insertions(+), 9 deletions(-) diff --git a/candle-core/src/cpu/avx.rs b/candle-core/src/cpu/avx.rs index 9398a3460a..113fc14ced 100644 --- a/candle-core/src/cpu/avx.rs +++ b/candle-core/src/cpu/avx.rs @@ -1,10 +1,10 @@ -use super::{Cpu, CpuF16}; +use super::{Cpu, CpuBF16, CpuF16}; #[cfg(target_arch = "x86")] use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use half::f16; +use half::{bf16, f16}; pub struct CurrentCpu {} @@ -146,3 +146,82 @@ impl CpuF16 for CurrentCpuF16 { *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); } } + +pub struct CurrentCpuBF16 {} +impl CpuBF16 for CurrentCpuBF16 { + type Unit = __m256; + type Array = [__m256; ARR]; + + const STEP: usize = STEP; + const EPR: usize = EPR; + + fn n() -> usize { + ARR + } + + unsafe fn zero() -> Self::Unit { + _mm256_setzero_ps() + } + + unsafe fn zero_array() -> Self::Array { + [Self::zero(); ARR] + } + + unsafe fn from_f32(v: f32) -> Self::Unit { + _mm256_set1_ps(v) + } + + #[cfg(target_feature = "f16c")] + unsafe fn load(mem_addr: *const bf16) -> Self::Unit { + _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i)) + } + + #[cfg(not(target_feature = "f16c"))] + unsafe fn load(mem_addr: *const bf16) -> Self::Unit { + let mut tmp = [0.0f32; 8]; + for i in 0..8 { + tmp[i] = (*mem_addr.add(i)).to_f32(); + } + _mm256_loadu_ps(tmp.as_ptr()) + } + + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit { + _mm256_add_ps(a, b) + } + + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit { + _mm256_add_ps(_mm256_mul_ps(b, c), a) + } + + #[cfg(target_feature = "f16c")] + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) { + _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0)) + } + + #[cfg(not(target_feature = "f16c"))] + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) { + let mut tmp = [0.0f32; 8]; + _mm256_storeu_ps(tmp.as_mut_ptr(), a); + for i in 0..8 { + *mem_addr.add(i) = bf16::from_f32(tmp[i]); + } + } + + unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) { + let mut offset = ARR >> 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + offset >>= 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + offset >>= 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1)); + let t1 = _mm_hadd_ps(t0, 
t0); + *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); + } +} diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs index fe0e241622..fd6da1f1ff 100644 --- a/candle-core/src/cpu/kernels.rs +++ b/candle-core/src/cpu/kernels.rs @@ -121,6 +121,13 @@ impl VecOps for half::bf16 { fn max(self, other: Self) -> Self { Self::max(self, other) } + + #[inline(always)] + unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { + let mut res_f32 = 0f32; + super::vec_dot_bf16(lhs, rhs, &mut res_f32, len); + *res = half::bf16::from_f32(res_f32); + } } impl VecOps for u8 { #[inline(always)] diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs index e7d8b6906f..0b77e6ecb7 100644 --- a/candle-core/src/cpu/mod.rs +++ b/candle-core/src/cpu/mod.rs @@ -36,14 +36,33 @@ trait CpuF16 { unsafe fn from_f32(v: f32) -> Self::Unit; unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit); } -use half::f16; + +#[allow(unused)] +trait CpuBF16 { + type Unit; + type Array; + const STEP: usize; + const EPR: usize; + + fn n() -> usize; + unsafe fn zero() -> Self::Unit; + unsafe fn zero_array() -> Self::Array; + unsafe fn load(mem_addr: *const bf16) -> Self::Unit; + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit; + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit; + unsafe fn vec_reduce(x: Self::Array, y: *mut f32); + unsafe fn from_f32(v: f32) -> Self::Unit; + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit); +} + +use half::{bf16, f16}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(target_feature = "avx")] pub mod avx; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(target_feature = "avx")] -pub use avx::{CurrentCpu, CurrentCpuF16}; +pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16}; #[cfg(target_arch = "wasm32")] #[cfg(target_feature = "simd128")] @@ -170,6 +189,34 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f *c = sumf; } +#[cfg(target_feature = "avx")] +#[inline(always)] +pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) { + let mut sumf = 0.0f32; + let np = k & !(CurrentCpuBF16::STEP - 1); + + let mut sum = CurrentCpuBF16::zero_array(); + let mut ax = CurrentCpuBF16::zero_array(); + let mut ay = CurrentCpuBF16::zero_array(); + + for i in (0..np).step_by(CurrentCpuBF16::STEP) { + for j in 0..CurrentCpuBF16::n() { + ax[j] = CurrentCpuBF16::load(a_row.add(i + j * CurrentCpuBF16::EPR)); + ay[j] = CurrentCpuBF16::load(b_row.add(i + j * CurrentCpuBF16::EPR)); + + sum[j] = CurrentCpuBF16::vec_fma(sum[j], ax[j], ay[j]); + } + } + + CurrentCpuBF16::vec_reduce(sum, &mut sumf); + + // leftovers + for i in np..k { + sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32(); + } + *c = sumf; +} + #[cfg(not(target_feature = "avx"))] #[inline(always)] pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) { @@ -180,3 +227,14 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f } *c = sum; } + +#[cfg(not(target_feature = "avx"))] +#[inline(always)] +pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) { + // leftovers + let mut sum = 0.0; + for i in 0..k { + sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32(); + } + *c = sum; +} diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 30ca7b98e3..426b818c1d 100644 --- a/candle-core/src/quantized/cuda.rs +++ 
b/candle-core/src/quantized/cuda.rs @@ -409,6 +409,7 @@ impl QCudaStorage { match self.dtype { GgmlDType::F32 => deq::(&buffer, block_len, &mut out)?, GgmlDType::F16 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::BF16 => deq::(&buffer, block_len, &mut out)?, GgmlDType::Q4_0 => deq::(&buffer, block_len, &mut out)?, GgmlDType::Q4_1 => deq::(&buffer, block_len, &mut out)?, GgmlDType::Q5_0 => deq::(&buffer, block_len, &mut out)?, diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index 99200bbd06..ea5ec02578 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -153,6 +153,7 @@ pub fn qtensor_from_ggml( match ggml_dtype { GgmlDType::F32 => from_raw_data::(raw_data, size_in_bytes, dims, device), GgmlDType::F16 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::BF16 => from_raw_data::(raw_data, size_in_bytes, dims, device), GgmlDType::Q4_0 => { from_raw_data::(raw_data, size_in_bytes, dims, device) } diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 6210ac1e9f..2e92921954 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -5,7 +5,7 @@ use super::utils::{ use super::GgmlDType; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; -use half::f16; +use half::{bf16, f16}; use rayon::prelude::*; // Default to QK_K 256 rather than 64. @@ -1963,3 +1963,47 @@ impl GgmlType for f16 { Ok(()) } } + +impl GgmlType for bf16 { + const DTYPE: GgmlDType = GgmlDType::BF16; + const BLCK_SIZE: usize = 1; + type VecDotType = bf16; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let mut res = 0f32; + unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = bf16::from_f32(*x) + } + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = x.to_f32() + } + Ok(()) + } +} diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index ff00f36389..39a30d12c5 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -27,7 +27,7 @@ pub mod neon; #[cfg(target_feature = "simd128")] pub mod simd128; pub mod utils; -use half::f16; +use half::{bf16, f16}; pub use k_quants::GgmlType; @@ -145,6 +145,7 @@ impl QStorage { pub enum GgmlDType { F32, F16, + BF16, Q4_0, Q4_1, Q5_0, @@ -176,6 +177,8 @@ impl GgmlDType { 13 => Self::Q5K, 14 => Self::Q6K, 15 => Self::Q8K, + // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 + 30 => Self::BF16, _ => crate::bail!("unknown dtype for tensor {u}"), }; Ok(dtype) @@ -197,6 +200,8 @@ impl GgmlDType { Self::Q5K => 13, Self::Q6K => 14, Self::Q8K => 15, + // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 + 
Self::BF16 => 30, } } @@ -217,6 +222,7 @@ impl GgmlDType { Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]), Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), + Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]), } } /// The type size for blocks in bytes. @@ -224,7 +230,7 @@ impl GgmlDType { use k_quants::*; match self { Self::F32 => 4, - Self::F16 => 2, + Self::F16 | Self::BF16 => 2, Self::Q4_0 => std::mem::size_of::(), Self::Q4_1 => std::mem::size_of::(), Self::Q5_0 => std::mem::size_of::(), @@ -245,7 +251,7 @@ impl GgmlDType { pub fn block_size(&self) -> usize { match self { Self::F32 => 1, - Self::F16 => 1, + Self::F16 | Self::BF16 => 1, Self::Q4_0 => k_quants::QK4_0, Self::Q4_1 => k_quants::QK4_1, Self::Q5_0 => k_quants::QK5_0, @@ -461,7 +467,7 @@ thread_local! { impl QMatMul { pub fn from_arc(qtensor: std::sync::Arc) -> Result { let dequantize = match qtensor.dtype() { - GgmlDType::F32 | GgmlDType::F16 => true, + GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true, _ => DEQUANTIZE_ALL.with(|b| *b), }; let t = if dequantize { From 6fbddd6d4b27ce81dd677f6c81a922c7900402b3 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 21 Aug 2024 21:38:32 -0400 Subject: [PATCH 51/75] Complete merge --- candle-transformers/src/models/based.rs | 590 ------------------------ candle-transformers/src/models/mod.rs | 3 +- 2 files changed, 1 insertion(+), 592 deletions(-) diff --git a/candle-transformers/src/models/based.rs b/candle-transformers/src/models/based.rs index c976d8d240..534ed3c964 100644 --- a/candle-transformers/src/models/based.rs +++ b/candle-transformers/src/models/based.rs @@ -6,596 +6,6 @@ //! Original code: //! 
https://github.com/HazyResearch/based -use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{ - conv1d_no_bias, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Conv1d, Conv1dConfig, - Func, Linear, RmsNorm, VarBuilder, -}; -use std::sync::Arc; - -#[derive(Debug, Clone, serde::Deserialize)] -pub struct LinearAttentionFeatureMapConfig { - input_dim: usize, -} - -#[derive(Debug, Clone, serde::Deserialize)] -pub struct LinearAttentionConfig { - num_heads: usize, - feature_dim: usize, - feature_map: LinearAttentionFeatureMapConfig, -} - -#[derive(Debug, Clone, serde::Deserialize)] -pub struct SlidingWindowAttentionConfig { - num_heads: usize, - window_size: usize, -} - -#[derive(Debug, Clone, serde::Deserialize)] -pub struct Config { - vocab_size: usize, - #[serde(rename = "n_embd")] - hidden_size: usize, - #[serde(rename = "n_inner")] - intermediate_size: usize, - #[serde(rename = "n_layer")] - num_hidden_layers: usize, - #[serde(rename = "n_head")] - num_attention_heads: usize, - - layer_norm_epsilon: f64, - #[serde(default = "default_rope", rename = "rotary_emb_base")] - rope_theta: f64, - - alt_mixer_layers: Vec, - alt_mixer_2_layers: Vec, - #[serde(rename = "alt_mixer")] - la: LinearAttentionConfig, - #[serde(rename = "alt_mixer_2")] - swa: SlidingWindowAttentionConfig, -} - -fn default_rope() -> f64 { - 10_000.0 -} - -#[derive(Debug, Clone)] -#[allow(clippy::upper_case_acronyms)] -struct MLP { - fc1: Linear, - fc2: Linear, -} - -impl MLP { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let fc1 = linear_no_bias(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("fc1"))?; - let fc2 = linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?; - Ok(Self { fc1, fc2 }) - } -} - -// Swiglu implementation. -// Not using Activation::Swiglu because this has the gate and y arguments switched compared to the version in candle-nn/src/ops.rs -fn swiglu(xs: &Tensor) -> Result { - let xs = xs.chunk(2, D::Minus1)?; - &xs[1].silu()? * &xs[0] -} - -impl Module for MLP { - fn forward(&self, xs: &Tensor) -> Result { - let xs = xs.apply(&self.fc1)?; - let xs = swiglu(&xs)?; - let xs = xs.apply(&self.fc2)?; - Ok(xs) - } -} - -// A gated convolutional block. -#[derive(Debug, Clone)] -struct BasedConv { - in_proj: Linear, - out_proj: Linear, - conv: Conv1d, - state: Tensor, -} - -impl BasedConv { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let dim = cfg.hidden_size * 2; - - let conv1d_cfg = Conv1dConfig { - groups: dim, - padding: 2, - ..Default::default() - }; - - let in_proj = linear(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("in_proj"))?; - let out_proj = linear(dim, cfg.hidden_size, vb.pp("out_proj"))?; - let conv = conv1d_no_bias(dim, dim, 3, conv1d_cfg, vb.pp("conv.conv"))?; - let state = Tensor::zeros((1, dim, 3), vb.dtype(), vb.device())?; - Ok(Self { - in_proj, - out_proj, - conv, - state, - }) - } - - fn step(&mut self, xs: &Tensor) -> Result { - self.state = self.state.roll(-1, D::Minus1)?; - let (_, _, l) = self.state.dims3()?; - self.state = self.state.narrow(D::Minus1, 0, l - 1)?; - self.state = Tensor::cat(&[&self.state, &xs.transpose(1, 2)?], 2)?; - - let xs = (&self.state * self.conv.weight().permute((1, 0, 2))?)? - .sum_keepdim(0)? - .sum(D::Minus1)?; - - let xs = xs.unsqueeze(1)?; - - Ok(xs) - } - - fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result { - let xs = xs.apply(&self.in_proj)?; - let us = xs.chunk(2, D::Minus1)?; - let (_b, l, _d) = us[0].dims3()?; - let u_conv = if seqlen_offset > 0 { - self.step(&us[0])? 
- } else { - let k = std::cmp::min(3, l); - self.state = self.state.narrow(D::Minus1, 0, 3 - k)?; - let xs = us[0].narrow(1, l - k, k)?.transpose(1, 2)?; - self.state = Tensor::cat(&[&self.state, &xs], 2)?; - - us[0] - .transpose(1, 2)? - .apply(&self.conv)? - .narrow(D::Minus1, 0, l)? - .transpose(1, 2)? - }; - - let u_conv = u_conv.silu()?; - let v = u_conv.broadcast_mul(&us[1])?; - let xs = v.apply(&self.out_proj)?; - - Ok(xs) - } -} - -// Linear attention approximating softmax using second order Taylor polynomials. -#[derive(Debug, Clone)] -struct LinearAttention { - proj_q: Linear, - proj_k: Linear, - proj_v: Linear, - out_proj: Linear, - feature_dim: usize, - num_heads: usize, - input_dim: usize, - k_state: Tensor, - kv_state: Tensor, -} - -impl LinearAttention { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let input_dim = cfg.la.feature_map.input_dim; - let out_proj = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("out_proj"))?; - let proj_k = linear_no_bias( - cfg.hidden_size, - cfg.la.num_heads * cfg.la.feature_dim, - vb.pp("proj_k"), - )?; - let proj_q = linear_no_bias( - cfg.hidden_size, - cfg.la.num_heads * cfg.la.feature_dim, - vb.pp("proj_q"), - )?; - - let proj_v = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("proj_v"))?; - let expanded_size = cfg.la.feature_dim.pow(2) + cfg.la.feature_dim + 1; - let k_state = Tensor::zeros( - (1, cfg.la.num_heads, 1, 1, expanded_size), - vb.dtype(), - vb.device(), - )?; - let kv_state = Tensor::zeros( - (1, cfg.la.num_heads, cfg.la.feature_dim, expanded_size), - vb.dtype(), - vb.device(), - )?; - - Ok(Self { - proj_q, - proj_k, - proj_v, - out_proj, - feature_dim: cfg.la.feature_dim, - num_heads: cfg.la.num_heads, - input_dim, - k_state, - kv_state, - }) - } - - fn taylor_expansion(&self) -> Result> { - let r2 = std::f64::consts::SQRT_2; - let rd = (self.input_dim as f64).sqrt(); - let rrd = rd.sqrt(); - - Ok(Func::new(move |xs| { - let dims = xs.dims(); - let mut d = dims.to_vec(); - if let Some(last) = d.last_mut() { - *last = 1; - }; - - let x = xs - .unsqueeze(D::Minus1)? - .broadcast_mul(&xs.unsqueeze(D::Minus2)?)?; - let x = (x.flatten_from(D::Minus2)? / r2)?; - let o = Tensor::ones(d, xs.dtype(), xs.device())?; - let x = Tensor::cat(&[o, (xs / rrd)?, (&x / rd)?], D::Minus1)?; - - Ok(x) - })) - } - - fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result { - let eps = 1e-12; - - let feature_map = self.taylor_expansion()?; - - let (b, l, d) = xs.dims3()?; - let q = xs.apply(&self.proj_q)?; - let k = xs.apply(&self.proj_k)?; - let v = xs.apply(&self.proj_v)?; - - let q = q - .reshape((b, l, self.num_heads, self.feature_dim))? - .transpose(1, 2)? - .contiguous()?; - let k = k - .reshape((b, l, self.num_heads, self.feature_dim))? - .transpose(1, 2)? - .contiguous()?; - let v = v - .reshape((b, l, self.num_heads, d / self.num_heads))? - .transpose(1, 2)? - .contiguous()?; - - let q = feature_map.forward(&q)?; - let k = feature_map.forward(&k)?; - - let y = if seqlen_offset > 0 { - let (_b, _h, l, _d) = k.dims4()?; - let q = q.unsqueeze(D::Minus2)?; - let k = k.unsqueeze(D::Minus2)?; - let v = v.unsqueeze(D::Minus1)?; - let kn = k.narrow(D::Minus1, l - 1, 1)?; - let vn = v.narrow(D::Minus1, l - 1, 1)?; - - self.k_state = self.k_state.broadcast_add(&kn)?; - self.kv_state = self.kv_state.broadcast_add(&kn.broadcast_mul(&vn)?)?; - - let num = q.broadcast_mul(&self.kv_state)?.sum(D::Minus1)?; - let den = (q.broadcast_mul(&self.k_state)?.sum(D::Minus1)? + eps)?; - num.broadcast_div(&den)? 
- } else { - self.k_state = k.sum(2)?.unsqueeze(2)?.unsqueeze(3)?; - self.kv_state = k - .transpose(2, 3)? - .matmul(&v)? - .transpose(2, 3)? - .unsqueeze(2)?; - let aqk = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?; - let tril = Tensor::tril2(l, aqk.dtype(), aqk.device())?; - let aqk = aqk.broadcast_mul(&tril)?.matmul(&v)?; - - let z = (1f64 / (q.mul(&k.cumsum(2)?)?.sum(D::Minus1)? + eps)?)?; - aqk.broadcast_mul(&z.unsqueeze(D::Minus1)?)? - }; - - let (b, h, l, d) = y.dims4()?; - let y = y.permute((0, 2, 1, 3))?.reshape((b, l, h * d))?; - let y = self.out_proj.forward(&y)?; - - Ok(y) - } -} - -// Rotary embeddings used in local attention. -#[derive(Debug, Clone)] -struct RotaryEmbedding { - sin: Tensor, - cos: Tensor, -} - -impl RotaryEmbedding { - fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { - let dim = cfg.hidden_size / cfg.num_attention_heads; - let max_seq_len = 2048; // Hardcoded, missing from config. - let inv_freq: Vec<_> = (0..dim) - .step_by(2) - .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) - .collect(); - let inv_freq_len = inv_freq.len(); - let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; - let t = Tensor::arange(0u32, max_seq_len as u32, dev)? - .to_dtype(dtype)? - .reshape((max_seq_len, 1))?; - let freqs = t.matmul(&inv_freq)?; - Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, - }) - } - - fn apply_rotary_emb_qkv( - &self, - q: &Tensor, - k: &Tensor, - seqlen_offset: usize, - ) -> Result<(Tensor, Tensor)> { - let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; - let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; - let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; - let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; - let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; - Ok((q_embed, k_embed)) - } -} - -// Local attention using a small sliding window. -#[derive(Debug, Clone)] -struct SlidingWindowAttention { - wqkv: Linear, - out_proj: Linear, - num_heads: usize, - head_dim: usize, - hidden_size: usize, - rotary_emb: Arc, - kv_cache: Option<(Tensor, Tensor)>, -} - -impl SlidingWindowAttention { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let hidden_size = cfg.hidden_size; - let num_heads = cfg.swa.num_heads; - let head_dim = hidden_size / num_heads; - let out_proj = linear_no_bias(hidden_size, hidden_size, vb.pp("out_proj"))?; - let wqkv = linear_no_bias(hidden_size, hidden_size * 3, vb.pp("Wqkv"))?; - let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); - Ok(Self { - wqkv, - out_proj, - hidden_size, - num_heads, - head_dim, - rotary_emb, - kv_cache: None, - }) - } - - fn forward( - &mut self, - xs: &Tensor, - attention_mask: Option<&Tensor>, - seqlen_offset: usize, - ) -> Result { - let (b_sz, q_len, _) = xs.dims3()?; - - let qkv = xs.apply(&self.wqkv)?; - let qkv = qkv.reshape((b_sz, q_len, 3, (), self.head_dim))?; - - let q = qkv.i((.., .., 0))?; - let k = qkv.i((.., .., 1))?; - let v = qkv.i((.., .., 2))?; - - let q = q - .reshape((b_sz, q_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; - let k = k - .reshape((b_sz, q_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; - let v = v - .reshape((b_sz, q_len, self.num_heads, self.head_dim))? 
- .transpose(1, 2)?; - - let (q, k) = self - .rotary_emb - .apply_rotary_emb_qkv(&q, &k, seqlen_offset)?; - - let (k, v) = match &self.kv_cache { - None => (k, v), - Some((prev_k, prev_v)) => { - let k = Tensor::cat(&[prev_k, &k], 2)?; - let v = Tensor::cat(&[prev_v, &v], 2)?; - (k, v) - } - }; - self.kv_cache = Some((k.clone(), v.clone())); - - let scale = 1f64 / f64::sqrt(self.head_dim as f64); - let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?; - - let attn_weights = match attention_mask { - None => attn_weights, - Some(mask) => attn_weights.broadcast_add(mask)?, - }; - let attn_weights = softmax_last_dim(&attn_weights)?; - let attn_output = attn_weights.matmul(&v)?; - let out = attn_output - .transpose(1, 2)? - .reshape((b_sz, q_len, self.hidden_size))? - .apply(&self.out_proj)?; - - Ok(out) - } -} - -// The model layers use three types of mixers. -#[derive(Debug, Clone)] -enum SequenceMixer { - Based(BasedConv), - Linear(LinearAttention), - Sliding(SlidingWindowAttention), -} - -impl SequenceMixer { - fn forward( - &mut self, - xs: &Tensor, - attention_mask: Option<&Tensor>, - pos: usize, - ) -> Result { - match self { - Self::Based(b) => b.forward(xs, pos), - Self::Linear(b) => b.forward(xs, pos), - Self::Sliding(b) => b.forward(xs, attention_mask, pos), - } - } -} - -#[derive(Debug, Clone)] -struct DecoderLayer { - mlp: MLP, - norm1: RmsNorm, - norm2: RmsNorm, - mixer: SequenceMixer, -} - -impl DecoderLayer { - fn new(layer_idx: usize, cfg: &Config, vb: VarBuilder) -> Result { - let mlp = MLP::new(cfg, vb.pp("mlp"))?; - let norm1 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?; - let norm2 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?; - - let l_attn = cfg.alt_mixer_layers.contains(&layer_idx); - let sw_attn = cfg.alt_mixer_2_layers.contains(&layer_idx); - - let mixer = if l_attn { - SequenceMixer::Linear(LinearAttention::new(cfg, vb.pp("mixer"))?) - } else if sw_attn { - SequenceMixer::Sliding(SlidingWindowAttention::new(cfg, vb.pp("mixer"))?) - } else { - SequenceMixer::Based(BasedConv::new(cfg, vb.pp("mixer"))?) 
- }; - - Ok(Self { - mlp, - norm1, - norm2, - mixer, - }) - } - - fn forward( - &mut self, - xs: &Tensor, - attention_mask: Option<&Tensor>, - seqlen_offset: usize, - ) -> Result { - let residual = xs; - let xs = self.norm1.forward(xs)?; - let xs = self.mixer.forward(&xs, attention_mask, seqlen_offset)?; - let xs = (xs + residual)?; - let residual = &xs; - let xs = xs.apply(&self.norm2)?.apply(&self.mlp)?; - residual + xs - } -} - -#[derive(Debug, Clone)] -pub struct Model { - embed_tokens: super::with_tracing::Embedding, - layers: Vec, - norm: RmsNorm, - lm_head: Linear, - sliding_window: usize, - device: Device, - dtype: DType, -} - -impl Model { - pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let vocab_size = cfg.vocab_size + (8 - cfg.vocab_size % 8) % 8; - let lm_head = linear_no_bias(cfg.hidden_size, vocab_size, vb.pp("lm_head"))?; - let embed_tokens = super::with_tracing::Embedding::from_weights(lm_head.weight().clone())?; - let vb_m = vb.pp("transformer"); - let mut layers = Vec::with_capacity(cfg.num_hidden_layers); - let vb_l = vb_m.pp("layers"); - for layer_idx in 0..cfg.num_hidden_layers { - let layer = DecoderLayer::new(layer_idx, cfg, vb_l.pp(layer_idx))?; - layers.push(layer) - } - let norm = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?; - Ok(Self { - embed_tokens, - layers, - norm, - lm_head, - sliding_window: cfg.swa.window_size, - device: vb.device().clone(), - dtype: vb.dtype(), - }) - } - - fn prepare_decoder_attention_mask( - &self, - b_size: usize, - tgt_len: usize, - seqlen_offset: usize, - ) -> Result { - let sliding_window = self.sliding_window / 2; - let mask: Vec<_> = (0..tgt_len) - .flat_map(|i| { - (0..tgt_len).map(move |j| { - if i < j || j + sliding_window < i { - f32::NEG_INFINITY - } else { - 0. - } - }) - }) - .collect(); - let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; - let mask = if seqlen_offset > 0 { - let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?; - Tensor::cat(&[&mask0, &mask], D::Minus1)? - } else { - mask - }; - mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? - .to_dtype(self.dtype) - } - - pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { - let (b_size, seq_len) = input_ids.dims2()?; - let attention_mask = if seq_len <= 1 { - None - } else { - let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; - Some(mask) - }; - let mut xs = self.embed_tokens.forward(input_ids)?; - for layer in self.layers.iter_mut() { - xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? - } - xs.narrow(1, seq_len - 1, 1)? - .apply(&self.norm)? - .apply(&self.lm_head) - } -} - -//! Based from the Stanford Hazy Research group. -//! -//! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024 -//! - -//! Original code: -//! 
https://github.com/HazyResearch/based - use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{ conv1d_no_bias, layer_norm::RmsNormNonQuantized, linear, linear_no_bias, ops::softmax_last_dim, diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 5cada11a35..fc797a771b 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -1,5 +1,4 @@ pub mod based; -pub mod based; pub mod beit; pub mod bert; pub mod bigcode; @@ -21,8 +20,8 @@ pub mod encodec; pub mod eva2; pub mod falcon; pub mod flux; -pub mod flux; pub mod gemma; +pub mod gemma2; pub mod glm4; pub mod hiera; pub mod jina_bert; From f706ef27666e77c5b5a6cdbd28f47c3205171d39 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:42:27 -0400 Subject: [PATCH 52/75] Add softcapping support to flash attention (#18) * Expose the softcap methods * Add some tests * Fix generics --- .vscode/settings.json | 2 +- candle-flash-attn/kernels/flash_api.cu | 13 +- candle-flash-attn/src/ffi.rs | 1 + candle-flash-attn/src/lib.rs | 182 +++++++++++++++++++- candle-flash-attn/tests/flash_attn_tests.rs | 98 ++++++++++- 5 files changed, 286 insertions(+), 10 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index e510b688c4..f9b6ef02f9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -9,6 +9,6 @@ "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, "rust-analyzer.cargo.features": [ - "cuda" + "cuda", "flash-attn", ], } \ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu index 4ca41b0a16..ca5f2b255d 100644 --- a/candle-flash-attn/kernels/flash_api.cu +++ b/candle-flash-attn/kernels/flash_api.cu @@ -45,6 +45,7 @@ extern "C" void run_mha( uint32_t d, uint32_t d_rounded, float softmax_scale, + float softcap, uint32_t seqlen_q, uint32_t seqlen_k, @@ -99,8 +100,16 @@ extern "C" void run_mha( params.d_rounded = d_rounded; // Set the different scale values. 
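// Rough intuition for the softcap branch below (a sketch of the scaling, mirroring the
// reference path in the Rust tests later in this patch): with softcapping enabled the
// attention logits become softcap * tanh((q.k^T * softmax_scale) / softcap), so the
// pre-tanh factor stored in params.softcap is softmax_scale / softcap and the post-tanh
// scale_softmax is softcap itself.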
- params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; + if (softcap > 0.0) { + params.softcap = softmax_scale / softcap; + params.scale_softmax = softcap; + params.scale_softmax_log2 = softcap * M_LOG2E; + }else{ + // Remove potential NaN + params.softcap = 0.0; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + } params.p_dropout = 1.; // probability to keep params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs index ca65520be5..fe565beae6 100644 --- a/candle-flash-attn/src/ffi.rs +++ b/candle-flash-attn/src/ffi.rs @@ -34,6 +34,7 @@ extern "C" { d: u32, d_rounded: u32, softmax_scale: f32, + softcap: f32, seqlen_q: u32, seqlen_k: u32, diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index f171a9868f..5d991f0075 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -8,6 +8,7 @@ use half::{bf16, f16}; pub struct FlashAttn { pub softmax_scale: f32, + pub softcap: Option, pub alibi_slopes: Option, pub window_size_left: Option, pub window_size_right: Option, @@ -193,6 +194,7 @@ impl FlashAttn { /* d */ head_size as u32, /* d_rounded */ head_size_rounded as u32, /* softmax_scale*/ self.softmax_scale, + /* softcap */ self.softcap.unwrap_or(0.0), /* seqlen_q */ seqlen_q as u32, /* seqlen_k */ seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, @@ -262,12 +264,25 @@ pub fn flash_attn( v: &Tensor, softmax_scale: f32, causal: bool, +) -> Result { + flash_attn_softcap(q, k, v, softmax_scale, None, causal) +} + +/// Equivalent to [`flash_attn`], but with softcap support +pub fn flash_attn_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + softcap: Option, + causal: bool, ) -> Result { let window_size_left = None; let window_size_right = if causal { Some(0) } else { None }; let op = FlashAttn { softmax_scale, + softcap, alibi_slopes: None, window_size_left, window_size_right, @@ -302,9 +317,31 @@ pub fn flash_attn_windowed( softmax_scale: f32, window_size_left: Option, window_size_right: Option, +) -> Result { + flash_attn_windowed_softcap( + q, + k, + v, + softmax_scale, + None, + window_size_left, + window_size_right, + ) +} + +/// Equivalent to [`flash_attn_windowed`], but with softcap support. +pub fn flash_attn_windowed_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + softcap: Option, + window_size_left: Option, + window_size_right: Option, ) -> Result { let op = FlashAttn { softmax_scale, + softcap, alibi_slopes: None, window_size_left, window_size_right, @@ -333,12 +370,26 @@ pub fn flash_attn_alibi( alibi_slopes: &Tensor, softmax_scale: f32, causal: bool, +) -> Result { + flash_attn_alibi_softcap(q, k, v, alibi_slopes, softmax_scale, None, causal) +} + +/// Equivalent to [`flash_attn_alibi`], but with softcap support. 
+pub fn flash_attn_alibi_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + softcap: Option, + causal: bool, ) -> Result { let window_size_left = None; let window_size_right = if causal { Some(0) } else { None }; let op = FlashAttn { softmax_scale, + softcap, alibi_slopes: Some(alibi_slopes.clone()), window_size_left, window_size_right, @@ -378,6 +429,7 @@ pub fn flash_attn_alibi_windowed( ) -> Result { let op = FlashAttn { softmax_scale, + softcap: None, alibi_slopes: Some(alibi_slopes.clone()), window_size_left, window_size_right, @@ -387,6 +439,7 @@ pub fn flash_attn_alibi_windowed( struct FlashAttnVarLen { pub softmax_scale: f32, + pub softcap: Option, pub max_seqlen_q: usize, pub max_seqlen_k: usize, pub seqlens_q: Tensor, @@ -434,9 +487,9 @@ impl FlashAttnVarLen { None => candle::bail!("seqlens_k has to be contiguous"), }; - let q = q.as_cuda_slice::()?; - let k = k.as_cuda_slice::()?; - let v = v.as_cuda_slice::()?; + let q = q.as_cuda_slice::()?; + let k = k.as_cuda_slice::()?; + let v = v.as_cuda_slice::()?; let q = q.slice(q_l.start_offset()..); let k = k.slice(k_l.start_offset()..); let v = v.slice(v_l.start_offset()..); @@ -548,7 +601,7 @@ impl FlashAttnVarLen { let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128); let elem_count = out_shape.elem_count(); - let dst = unsafe { dev.alloc::(elem_count) }.w()?; + let dst = unsafe { dev.alloc::(elem_count) }.w()?; let softmax_lse = dev .alloc_zeros::(batch_size * num_heads * self.max_seqlen_q) .w()?; @@ -605,6 +658,7 @@ impl FlashAttnVarLen { /* d */ head_size as u32, /* d_rounded */ head_size_rounded as u32, /* softmax_scale*/ self.softmax_scale, + /* softcap */ self.softcap.unwrap_or(0.0), /* seqlen_q */ self.max_seqlen_q as u32, /* seqlen_k */ self.max_seqlen_k as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32, @@ -686,12 +740,40 @@ pub fn flash_attn_varlen( max_seqlen_k: usize, softmax_scale: f32, causal: bool, +) -> Result { + flash_attn_varlen_softcap( + q, + k, + v, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + causal, + ) +} + +/// Equivalent to [`flash_attn_varlen`], but with softcap support. +pub fn flash_attn_varlen_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + causal: bool, ) -> Result { let window_size_left = None; let window_size_right = if causal { Some(0) } else { None }; let op = FlashAttnVarLen { softmax_scale, + softcap, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), @@ -742,9 +824,39 @@ pub fn flash_attn_varlen_windowed( softmax_scale: f32, window_size_left: Option, window_size_right: Option, +) -> Result { + flash_attn_varlen_windowed_softcap( + q, + k, + v, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + window_size_left, + window_size_right, + ) +} + +/// Equivalent to [`flash_attn_varlen_windowed`], but with softcap support. 
+pub fn flash_attn_varlen_windowed_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + window_size_left: Option, + window_size_right: Option, ) -> Result { let op = FlashAttnVarLen { softmax_scale, + softcap, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), @@ -789,12 +901,42 @@ pub fn flash_attn_varlen_alibi( max_seqlen_k: usize, softmax_scale: f32, causal: bool, +) -> Result { + flash_attn_varlen_alibi_softcap( + q, + k, + v, + alibi_slopes, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + causal, + ) +} + +/// Equivalent to [`flash_attn_varlen_alibi`], but with softcap support +pub fn flash_attn_varlen_alibi_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + causal: bool, ) -> Result { let window_size_left = None; let window_size_right = if causal { Some(0) } else { None }; let op = FlashAttnVarLen { softmax_scale, + softcap, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), @@ -847,9 +989,41 @@ pub fn flash_attn_varlen_alibi_windowed( softmax_scale: f32, window_size_left: Option, window_size_right: Option, +) -> Result { + flash_attn_varlen_alibi_windowed_softcap( + q, + k, + v, + alibi_slopes, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + window_size_left, + window_size_right, + ) +} + +/// Equivalent to [`flash_attn_varlen_alibi_windowed`], but with softcap support. +pub fn flash_attn_varlen_alibi_windowed_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + window_size_left: Option, + window_size_right: Option, ) -> Result { let op = FlashAttnVarLen { softmax_scale, + softcap, max_seqlen_q, max_seqlen_k, seqlens_q: seqlens_q.clone(), diff --git a/candle-flash-attn/tests/flash_attn_tests.rs b/candle-flash-attn/tests/flash_attn_tests.rs index 250added04..fd51152ee6 100644 --- a/candle-flash-attn/tests/flash_attn_tests.rs +++ b/candle-flash-attn/tests/flash_attn_tests.rs @@ -15,12 +15,23 @@ fn to_vec3_round(t: Tensor, digits: i32) -> Result>>> { Ok(t) } -fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result { +fn fa_acausal( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + softcap: Option, +) -> Result { let in_dtype = q.dtype(); let q = q.to_dtype(DType::F32)?; let k = k.to_dtype(DType::F32)?; let v = v.to_dtype(DType::F32)?; - let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?; + let mut att = (q.matmul(&k.t()?)? * softmax_scale as f64)?; + if let Some(softcap) = softcap { + att = (att / softcap as f64)?; + att = att.tanh()?; + att = (att * softcap as f64)?; + } let att = candle_nn::ops::softmax(&att, D::Minus1)?; // Convert to contiguous as matmul doesn't support strided vs for now. 
let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?; @@ -37,7 +48,7 @@ fn flash_attn_acausal() -> Result<()> { let v = (&q / 50.)?; let q = (&q / 30.)?; - let ys1 = fa_acausal(&q, &k, &v, 0.5)?; + let ys1 = fa_acausal(&q, &k, &v, 0.5, None)?; let ys1 = ys1.i(0)?.to_dtype(DType::F32)?; let ys2 = { let q = q.transpose(1, 2)?; @@ -133,3 +144,84 @@ fn flash_attn_varlen() -> Result<()> { ); Ok(()) } + +#[test] +fn flash_attn_acausal_softcap() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 48, &device)? + .to_dtype(DType::F16)? + .reshape((1, 3, 2, 8))?; + let k = (&q / 40.)?; + let v = (&q / 50.)?; + let q = (&q / 30.)?; + + let ys1 = fa_acausal(&q, &k, &v, 0.5, Some(30.))?; + let ys1 = ys1.i(0)?.to_dtype(DType::F32)?; + let ys2 = { + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + candle_flash_attn::flash_attn_softcap(&q, &k, &v, 0.5, Some(30.), false)?.transpose(1, 2)? + }; + let ys2 = ys2.i(0)?.to_dtype(DType::F32)?; + let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?; + + assert_eq!(ys1.dims(), &[3, 2, 8]); + assert_eq!(ys2.dims(), &[3, 2, 8]); + assert!(diff.to_vec0::()?.abs() < 1e-5); + Ok(()) +} + +#[test] +fn flash_attn_varlen_softcap() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 48, &device)? + .to_dtype(DType::F16)? + .reshape((3, 2, 8))?; + let k = (&q / 40.)?; + let v = (&q / 50.)?; + let q = (&q / 30.)?; + + let seqlens_q = Tensor::new(&[0u32, 2u32], &device)?; + let seqlens_k = Tensor::new(&[0u32, 2u32], &device)?; + + let ys = { + let q = q.transpose(0, 1)?; + let k = k.transpose(0, 1)?; + let v = v.transpose(0, 1)?; + candle_flash_attn::flash_attn_varlen_softcap( + &q, + &k, + &v, + &seqlens_q, + &seqlens_k, + 32, + 32, + 0.5, + Some(30.), + false, + )? + .transpose(0, 1)? 
+ }; + let ys = ys.to_dtype(DType::F32)?; + + assert_eq!(ys.dims(), &[3, 2, 8]); + assert_eq!( + to_vec3_round(ys, 4)?, + &[ + [ + [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238], + [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322] + ], + [ + [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605], + [0.428, 0.448, 0.468, 0.488, 0.5078, 0.5278, 0.5479, 0.5679] + ], + [ + [0.7549, 0.7749, 0.7949, 0.8149, 0.835, 0.855, 0.875, 0.895], + [0.7607, 0.7808, 0.8008, 0.8208, 0.8408, 0.8608, 0.8809, 0.9009] + ] + ] + ); + Ok(()) +} From 3c8e120e8a6ae88b41ee251ec6255035864858d6 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Mon, 2 Sep 2024 10:42:38 -0400 Subject: [PATCH 53/75] Update kernels for metal bf16 (#19) * Update kernels for metal bf16 * Fix typo * Check if have bfloat --- candle-core/src/quantized/metal.rs | 5 + candle-metal-kernels/src/lib.rs | 4 +- candle-metal-kernels/src/quantized.metal | 200 ++++++++++++++++++++++- 3 files changed, 206 insertions(+), 3 deletions(-) diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index f23d6e15df..bc43658d13 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -55,6 +55,10 @@ impl QMetalStorage { let vec: Vec = read_to_vec(&buffer, block_len); half::f16::to_float(&vec, &mut out)?; } + GgmlDType::BF16 => { + let vec: Vec = read_to_vec(&buffer, block_len); + half::bf16::to_float(&vec, &mut out)?; + } GgmlDType::Q4_0 => { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?; @@ -241,6 +245,7 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, + GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16, } } } diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index d6e6dd69b8..d5e5b8eb66 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1844,6 +1844,7 @@ pub enum GgmlDType { Q8K, F16, F32, + BF16, } #[allow(clippy::too_many_arguments)] @@ -1921,7 +1922,7 @@ pub fn call_quantized_matmul_mv_t( let align = 2; (nth0, nth1, align) } - GgmlDType::F16 | GgmlDType::Q8K => { + GgmlDType::F16 | GgmlDType::BF16 | GgmlDType::Q8K => { // Original implem uses rows let nth0 = 32; let nth1 = 1; @@ -1959,6 +1960,7 @@ pub fn call_quantized_matmul_mv_t( GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32", GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32", GgmlDType::F16 => "kernel_mul_mv_f16_f32", + GgmlDType::BF16 => "kernel_mul_mv_bf16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", }; diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal index fef6ac54f8..162b7a2d19 100644 --- a/candle-metal-kernels/src/quantized.metal +++ b/candle-metal-kernels/src/quantized.metal @@ -1495,8 +1495,203 @@ kernel void kernel_mul_mv_f16_f32( kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); } +#if defined(__HAVE_BFLOAT__) +void kernel_mul_mv_bf16_f32_1row_impl( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + 
constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const bfloat* x = (device const bfloat*) (src0 + offset0); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + if (ne00 < 128) { + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + device const bfloat4* x4 = (device const bfloat4*) x; + device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_bf16_f32_1row")]] +kernel void kernel_mul_mv_bf16_f32_1row( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_bf16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} +#endif + +#define N_BF16_F32 4 + +#if defined(__HAVE_BFLOAT__) +void kernel_mul_mv_bf16_f32_impl( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_BF16_F32; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const bfloat * x = (device const bfloat *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_BF16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + 
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const bfloat4 * x4 = (device const bfloat4 *)x; + for (int row = 0; row < N_BF16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +[[host_name("kernel_mul_mv_bf16_f32")]] +kernel void kernel_mul_mv_bf16_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_bf16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} +#endif + +#if defined(__HAVE_BFLOAT__) // Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mv_f16_f32_l4( +kernel void kernel_mul_mv_bf16_f32_l4( device const char * src0, device const char * src1, device float * dst, @@ -1528,7 +1723,7 @@ kernel void kernel_mul_mv_f16_f32_l4( const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - device const half4 * x4 = (device const half4 *) (src0 + offset0); + device const bfloat4 * x4 = (device const bfloat4 *) (src0 + offset0); for (int r1 = 0; r1 < nrows; ++r1) { device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); @@ -1544,6 +1739,7 @@ kernel void kernel_mul_mv_f16_f32_l4( } } } +#endif kernel void kernel_alibi_f32( device const float * src0, From 014f1400b54d9345d72c3da81c959fe29da8c9cb Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 5 Sep 2024 11:26:32 +1000 Subject: [PATCH 54/75] fix(metal/accelerate): f64-f32 type mismatch (#20) --- candle-core/src/cpu_backend/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 3635001bbd..75c6e7bd38 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -1713,7 +1713,7 @@ impl Map3 for MatMulWithBias { /* m= */ n as i32, /* n= */ m as i32, /* k= */ k as i32, - /* alpha= */ s.unwrap_or(1.), + /* alpha= */ s.unwrap_or(1.) as f32, /* a= */ a, /* lda= */ lda, /* b= */ b, @@ -1743,7 +1743,7 @@ impl Map3 for MatMulWithBias { /* m= */ n as i32, /* n= */ m as i32, /* k= */ k as i32, - /* alpha= */ s.unwrap_or(1.), + /* alpha= */ s.unwrap_or(1.) as f64, /* a= */ a, /* lda= */ lda, /* b= */ b, @@ -2064,7 +2064,7 @@ impl Map2Alpha for MatMulWithAlpha { /* m= */ n as i32, /* n= */ m as i32, /* k= */ k as i32, - /* alpha= */ s.unwrap_or(1.), + /* alpha= */ s.unwrap_or(1.) 
as f32, /* a= */ a, /* lda= */ lda, /* b= */ b, From f317df825622e2a8a79a7ac2b4ccc84bccc05535 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 22 Aug 2024 08:23:52 +0100 Subject: [PATCH 55/75] Bump the version to 0.6.1. (#2438) --- Cargo.toml | 18 +++++++++--------- candle-flash-attn/Cargo.toml | 4 ++-- candle-kernels/Cargo.toml | 2 +- candle-metal-kernels/Cargo.toml | 2 +- candle-onnx/Cargo.toml | 6 +++--- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2dfc41f7cc..efd39165ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ exclude = [ resolver = "2" [workspace.package] -version = "0.6.0" +version = "0.6.1" edition = "2021" description = "Minimalist ML framework." repository = "https://github.com/huggingface/candle" @@ -33,14 +33,14 @@ ab_glyph = "0.2.23" accelerate-src = { version = "0.3.2" } anyhow = { version = "1", features = ["backtrace"] } byteorder = "1.4.3" -candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" } -candle-datasets = { path = "./candle-datasets", version = "0.6.0" } -candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" } -candle-kernels = { path = "./candle-kernels", version = "0.6.0" } -candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" } -candle-nn = { path = "./candle-nn", version = "0.6.0" } -candle-onnx = { path = "./candle-onnx", version = "0.6.0" } -candle-transformers = { path = "./candle-transformers", version = "0.6.0" } +candle = { path = "./candle-core", package = "candle-core", version = "0.6.1" } +candle-datasets = { path = "./candle-datasets", version = "0.6.1" } +candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.1" } +candle-kernels = { path = "./candle-kernels", version = "0.6.1" } +candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.1" } +candle-nn = { path = "./candle-nn", version = "0.6.1" } +candle-onnx = { path = "./candle-onnx", version = "0.6.1" } +candle-transformers = { path = "./candle-transformers", version = "0.6.1" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } cudarc = { version = "=0.11.6", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index a4c412ffa7..3be0e4d6c2 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-flash-attn" -version = "0.6.0" +version = "0.6.1" edition = "2021" description = "Flash attention layer for the candle ML framework." 
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0" readme = "README.md" [dependencies] -candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.0" } +candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.1" } half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml index 35d269fa87..84cacb89b1 100644 --- a/candle-kernels/Cargo.toml +++ b/candle-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-kernels" -version = "0.6.0" +version = "0.6.1" edition = "2021" description = "CUDA kernels for Candle" diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index e7a85f1f6d..a93d872902 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-metal-kernels" -version = "0.6.0" +version = "0.6.1" edition = "2021" description = "Metal kernels for Candle" diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml index ad9530e9b1..dc9bbcbbad 100644 --- a/candle-onnx/Cargo.toml +++ b/candle-onnx/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-onnx" -version = "0.6.0" +version = "0.6.1" edition = "2021" description = "ONNX support for Candle" @@ -10,8 +10,8 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -candle = { path = "../candle-core", package = "candle-core", version = "0.6.0" } -candle-nn = { path = "../candle-nn", version = "0.6.0" } +candle = { path = "../candle-core", package = "candle-core", version = "0.6.1" } +candle-nn = { path = "../candle-nn", version = "0.6.1" } prost = "0.12.1" [build-dependencies] From 8a9d2be9fd5073a4db7e8f105d23b9a06be242d4 Mon Sep 17 00:00:00 2001 From: shua Date: Thu, 22 Aug 2024 13:34:53 +0200 Subject: [PATCH 56/75] onnx: workaround pow with negative base (#2439) * onnx: workaround pow with negative base rather than fully defining pow in the cpu backend (as in #2318), this implements a much smaller change which is sufficient to evaluate silero-vad onnx models. Specifically, checking if pow is run with 2.0 exponent, and if so evaluate as simply `x*x` instead of the cpu backend of `e^(2.0 * ln(x))`. * PR: use Tensor::powf insead powf correctly handles a negative base. --- candle-onnx/src/eval.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 036f583873..fca51ef7ef 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -352,8 +352,15 @@ fn simple_eval_( "Pow" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?; - let output = input0.broadcast_pow(input1)?; - values.insert(node.output[0].clone(), output); + // HACK: current implementation of broadcast_pow cannot handle negative base, + // so we use powf where we can, which *does* correctly handle negative base. 
+ if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::())() { + let output = input0.powf(exp as f64)?; + values.insert(node.output[0].clone(), output); + } else { + let output = input0.broadcast_pow(input1)?; + values.insert(node.output[0].clone(), output); + } } "Exp" => { let xs = get(&node.input[0])?; From a7142d344df87bf5d4fbef1d51dd9f12a48e4ef3 Mon Sep 17 00:00:00 2001 From: shua Date: Thu, 22 Aug 2024 15:28:25 +0200 Subject: [PATCH 57/75] onnx: support negative index in Gather (#2440) index_select does not support negative indexing, but this change adds just enough workarounds in onnx to allow evaluating silero-vad models (which make use of negative indices). --- candle-onnx/src/eval.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index fca51ef7ef..5b66a743c3 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -629,6 +629,18 @@ fn simple_eval_( let axis = get_attr_opt::(node, "axis")?.copied().unwrap_or(0); let axis = xs.normalize_axis(axis)?; + // index_select does not support negative indices, so normalize them + // to positive indices. + let indices = &{ + let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?; + let max = Tensor::new(xs.dims()[axis] as i64, indices.device())? + .to_dtype(indices.dtype())?; + let mask = indices.lt(&zeros)?; + mask.to_dtype(indices.dtype())? + .broadcast_mul(&max)? + .add(&indices)? + }; + // In Pytorch or Numpy this can be done by indexing the xs tensor using the indices // tensor directly, but candle does not support tensor indexing at the moment, so // some workarounds must be done. From f62d7e851fdd7728cadf41a535d08ca8d268419d Mon Sep 17 00:00:00 2001 From: shua Date: Thu, 22 Aug 2024 22:50:42 +0200 Subject: [PATCH 58/75] silero-vad v5 example (#2321) * silero-vad v5 example This change adds an example of how to run silero-vad v5 * PR: rename 'vad' to 'silero-vad' * Update README.md --------- Co-authored-by: Laurent Mazare --- candle-examples/Cargo.toml | 4 + candle-examples/examples/silero-vad/README.md | 12 ++ candle-examples/examples/silero-vad/main.rs | 200 ++++++++++++++++++ 3 files changed, 216 insertions(+) create mode 100644 candle-examples/examples/silero-vad/README.md create mode 100644 candle-examples/examples/silero-vad/main.rs diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 56e3d535de..6879c48b28 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -108,3 +108,7 @@ required-features = ["encodec"] [[example]] name = "depth_anything_v2" required-features = ["depth_anything_v2"] + +[[example]] +name = "silero-vad" +required-features = ["onnx"] diff --git a/candle-examples/examples/silero-vad/README.md b/candle-examples/examples/silero-vad/README.md new file mode 100644 index 0000000000..14dd8a82b1 --- /dev/null +++ b/candle-examples/examples/silero-vad/README.md @@ -0,0 +1,12 @@ +# silero-vad: Voice Activity Detection + +[Silero VAD (v5)](https://github.com/snakers4/silero-vad) detects voice activity in streaming audio. + +This example uses the models available in the hugging face [onnx-community/silero-vad](https://huggingface.co/onnx-community/silero-vad). 
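A minimal sketch of a single inference step with `candle-onnx`, assuming the model has already been downloaded to a local `model.onnx` (the path is an assumption; the full example below fetches it via hf-hub, streams frames from stdin, and carries the recurrent state across calls — the `input`/`sr`/`state` names follow that example):

```rust
use candle::{DType, Tensor};

fn main() -> anyhow::Result<()> {
    let device = candle::Device::Cpu;
    let model = candle_onnx::read_file("model.onnx")?; // local path is an assumption

    // At 16 kHz the model expects 512-sample frames preceded by a 64-sample context.
    let frame = Tensor::zeros((1, 512), DType::F32, &device)?; // audio samples in [-1, 1]
    let context = Tensor::zeros((1, 64), DType::F32, &device)?;
    let input = Tensor::cat(&[&context, &frame], 1)?;

    let inputs = std::collections::HashMap::from_iter([
        ("input".to_string(), input),
        ("sr".to_string(), Tensor::new(16000i64, &device)?),
        ("state".to_string(), Tensor::zeros((2, 1, 128), DType::F32, &device)?),
    ]);
    let out = candle_onnx::simple_eval(&model, inputs)?;
    let names = &model.graph.as_ref().unwrap().output;
    // First output is the speech probability; the second is the state to feed back next call.
    let prob = out[&names[0].name].flatten_all()?.to_vec1::<f32>()?[0];
    println!("speech probability: {prob}");
    Ok(())
}
```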
+ +## Running the example + +```bash +$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000 +``` + diff --git a/candle-examples/examples/silero-vad/main.rs b/candle-examples/examples/silero-vad/main.rs new file mode 100644 index 0000000000..4618ad80df --- /dev/null +++ b/candle-examples/examples/silero-vad/main.rs @@ -0,0 +1,200 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Result; +use clap::Parser; + +use candle::{DType, Tensor}; +use candle_onnx; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum Which { + #[value(name = "silero")] + Silero, +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum SampleRate { + #[value(name = "8000")] + Sr8k, + #[value(name = "16000")] + Sr16k, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + input: Option, + + #[arg(long)] + sample_rate: SampleRate, + + #[arg(long)] + model_id: Option, + + #[arg(long)] + config_file: Option, + + /// The model to use. + #[arg(long, default_value = "silero")] + which: Which, +} + +/// an iterator which reads consecutive frames of le i16 values from a reader +struct I16Frames { + rdr: R, + buf: Box<[u8]>, + len: usize, + eof: bool, +} +impl I16Frames { + fn new(rdr: R, frame_size: usize) -> Self { + I16Frames { + rdr, + buf: vec![0; frame_size * std::mem::size_of::()].into_boxed_slice(), + len: 0, + eof: false, + } + } +} +impl Iterator for I16Frames { + type Item = std::io::Result>; + + fn next(&mut self) -> Option { + if self.eof { + return None; + } + self.len += match self.rdr.read(&mut self.buf[self.len..]) { + Ok(0) => { + self.eof = true; + 0 + } + Ok(n) => n, + Err(e) => return Some(Err(e)), + }; + if self.eof || self.len == self.buf.len() { + let buf = self.buf[..self.len] + .chunks(2) + .map(|bs| match bs { + [a, b] => i16::from_le_bytes([*a, *b]), + _ => unreachable!(), + }) + .map(|i| i as f32 / i16::MAX as f32) + .collect(); + self.len = 0; + Some(Ok(buf)) + } else { + self.next() + } + } +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + + let start = std::time::Instant::now(); + let model_id = match &args.model_id { + Some(model_id) => std::path::PathBuf::from(model_id), + None => match args.which { + Which::Silero => hf_hub::api::sync::Api::new()? + .model("onnx-community/silero-vad".into()) + .get("onnx/model.onnx")?, + // TODO: candle-onnx doesn't support Int8 dtype + // Which::SileroQuantized => hf_hub::api::sync::Api::new()? 
+ // .model("onnx-community/silero-vad".into()) + // .get("onnx/model_quantized.onnx")?, + }, + }; + let (sample_rate, frame_size, context_size): (i64, usize, usize) = match args.sample_rate { + SampleRate::Sr8k => (8000, 256, 32), + SampleRate::Sr16k => (16000, 512, 64), + }; + println!("retrieved the files in {:?}", start.elapsed()); + + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + let model = candle_onnx::read_file(model_id)?; + + println!("loaded the model in {:?}", start.elapsed()); + + let start = std::time::Instant::now(); + struct State { + frame_size: usize, + sample_rate: Tensor, + state: Tensor, + context: Tensor, + } + + let mut state = State { + frame_size, + sample_rate: Tensor::new(sample_rate, &device)?, + state: Tensor::zeros((2, 1, 128), DType::F32, &device)?, + context: Tensor::zeros((1, context_size), DType::F32, &device)?, + }; + let mut res = vec![]; + for chunk in I16Frames::new(std::io::stdin().lock(), state.frame_size) { + let chunk = chunk.unwrap(); + if chunk.len() < state.frame_size { + continue; + } + let next_context = Tensor::from_slice( + &chunk[state.frame_size - context_size..], + (1, context_size), + &device, + )?; + let chunk = Tensor::from_vec(chunk, (1, state.frame_size), &device)?; + let chunk = Tensor::cat(&[&state.context, &chunk], 1)?; + let inputs = std::collections::HashMap::from_iter([ + ("input".to_string(), chunk), + ("sr".to_string(), state.sample_rate.clone()), + ("state".to_string(), state.state.clone()), + ]); + let out = candle_onnx::simple_eval(&model, inputs).unwrap(); + let out_names = &model.graph.as_ref().unwrap().output; + let output = out.get(&out_names[0].name).unwrap().clone(); + state.state = out.get(&out_names[1].name).unwrap().clone(); + assert_eq!(state.state.dims(), &[2, 1, 128]); + state.context = next_context; + + let output = output.flatten_all()?.to_vec1::()?; + assert_eq!(output.len(), 1); + let output = output[0]; + println!("vad chunk prediction: {output}"); + res.push(output); + } + println!("calculated prediction in {:?}", start.elapsed()); + + let res_len = res.len() as f32; + let prediction = res.iter().sum::() / res_len; + println!("vad average prediction: {prediction}"); + Ok(()) +} From ceab78ebaf72ca1a1f0c63163b8f702c512d0a60 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 22 Aug 2024 22:22:03 +0100 Subject: [PATCH 59/75] Fix for parler-tts, do not add the last slice of padding tokens. (#2442) * Fix for parler-tts, do not add the last slice of padding tokens. * Support for the mini model. --- candle-examples/examples/parler-tts/main.rs | 23 ++++++++++++++++++-- candle-transformers/src/models/parler_tts.rs | 1 - 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/candle-examples/examples/parler-tts/main.rs b/candle-examples/examples/parler-tts/main.rs index 4e3730e2cb..88e0ef8b30 100644 --- a/candle-examples/examples/parler-tts/main.rs +++ b/candle-examples/examples/parler-tts/main.rs @@ -86,6 +86,17 @@ struct Args { /// The output wav file. 
#[arg(long, default_value = "out.wav")] out_file: String, + + #[arg(long, default_value = "large-v1")] + which: Which, +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum Which { + #[value(name = "large-v1")] + LargeV1, + #[value(name = "mini-v1")] + MiniV1, } fn main() -> anyhow::Result<()> { @@ -117,7 +128,10 @@ fn main() -> anyhow::Result<()> { let api = hf_hub::api::sync::Api::new()?; let model_id = match args.model_id { Some(model_id) => model_id.to_string(), - None => "parler-tts/parler-tts-large-v1".to_string(), + None => match args.which { + Which::LargeV1 => "parler-tts/parler-tts-large-v1".to_string(), + Which::MiniV1 => "parler-tts/parler-tts-mini-v1".to_string(), + }, }; let revision = match args.revision { Some(r) => r, @@ -130,7 +144,12 @@ fn main() -> anyhow::Result<()> { )); let model_files = match args.model_file { Some(m) => vec![m.into()], - None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + None => match args.which { + Which::MiniV1 => vec![repo.get("model.safetensors")?], + Which::LargeV1 => { + candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? + } + }, }; let config = match args.config_file { Some(m) => m.into(), diff --git a/candle-transformers/src/models/parler_tts.rs b/candle-transformers/src/models/parler_tts.rs index 16023a7c6f..da40124741 100644 --- a/candle-transformers/src/models/parler_tts.rs +++ b/candle-transformers/src/models/parler_tts.rs @@ -429,7 +429,6 @@ impl Model { let min_len = all_audio_tokens.iter().map(|v| v.len()).min().unwrap_or(0); all_audio_tokens.iter_mut().for_each(|v| { v.resize(min_len, 0); - v.push(self.pad_token_id) }); let all_audio_tokens = Tensor::new(all_audio_tokens, &candle::Device::Cpu)?; Ok(all_audio_tokens) From 5b4c593933057e12bd55ef0ae9c20d0f01e4d81e Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Fri, 23 Aug 2024 17:06:54 +0300 Subject: [PATCH 60/75] Add FastViT model. (#2444) --- candle-examples/examples/fastvit/README.md | 20 + candle-examples/examples/fastvit/main.rs | 102 ++++ candle-transformers/src/models/fastvit.rs | 512 +++++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + 4 files changed, 635 insertions(+) create mode 100644 candle-examples/examples/fastvit/README.md create mode 100644 candle-examples/examples/fastvit/main.rs create mode 100644 candle-transformers/src/models/fastvit.rs diff --git a/candle-examples/examples/fastvit/README.md b/candle-examples/examples/fastvit/README.md new file mode 100644 index 0000000000..499685bd3c --- /dev/null +++ b/candle-examples/examples/fastvit/README.md @@ -0,0 +1,20 @@ +# candle-fastvit + +[FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189). +This candle implementation uses a pre-trained FastViT network for inference. The +classification head has been trained on the ImageNet dataset and returns the +probabilities for the top-5 classes. 
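The "Structural Reparameterization" in the paper title means the train-time Conv + BatchNorm (plus scale/identity) branches are folded into a single convolution when the weights are loaded; the `fuse_conv_bn` and `mobileone_block` code later in this patch does this with candle tensors. A scalar sketch of the algebra for one output channel, with illustrative statistics:

```rust
// Fold y = gamma * (w·x - mu) / sqrt(var + eps) + beta into a single conv:
// w' = w * gamma / sqrt(var + eps) and b' = beta - mu * gamma / sqrt(var + eps).
fn fold_bn(w: &mut [f32], gamma: f32, beta: f32, mu: f32, var: f32, eps: f32) -> f32 {
    let scale = gamma / (var + eps).sqrt();
    for wi in w.iter_mut() {
        *wi *= scale;
    }
    beta - mu * scale // fused bias for this output channel
}

fn main() {
    let mut w = vec![0.5f32, -0.25, 0.1]; // one output channel's kernel, flattened
    let b = fold_bn(&mut w, 1.2, 0.3, 0.05, 0.9, 1e-5);
    // Applying the fused (w, b) now reproduces conv -> batchnorm from training.
    println!("fused weights: {w:?}, fused bias: {b}");
}
```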
+ +## Running an example + +``` +$ cargo run --example fastvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which sa12 + +loaded image Tensor[dims 3, 256, 256; f32] +model built +mountain bike, all-terrain bike, off-roader: 43.45% +bicycle-built-for-two, tandem bicycle, tandem: 14.16% +unicycle, monocycle : 4.12% +crash helmet : 2.26% +alp : 1.40% +``` diff --git a/candle-examples/examples/fastvit/main.rs b/candle-examples/examples/fastvit/main.rs new file mode 100644 index 0000000000..520fd0aed3 --- /dev/null +++ b/candle-examples/examples/fastvit/main.rs @@ -0,0 +1,102 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::fastvit; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Which { + T8, + T12, + S12, + SA12, + SA24, + SA36, + MA36, +} + +impl Which { + fn model_filename(&self) -> String { + let name = match self { + Self::T8 => "t8", + Self::T12 => "t12", + Self::S12 => "s12", + Self::SA12 => "sa12", + Self::SA24 => "sa24", + Self::SA36 => "sa36", + Self::MA36 => "ma36", + }; + format!("timm/fastvit_{}.apple_in1k", name) + } + + fn config(&self) -> fastvit::Config { + match self { + Self::T8 => fastvit::Config::t8(), + Self::T12 => fastvit::Config::t12(), + Self::S12 => fastvit::Config::s12(), + Self::SA12 => fastvit::Config::sa12(), + Self::SA24 => fastvit::Config::sa24(), + Self::SA36 => fastvit::Config::sa36(), + Self::MA36 => fastvit::Config::ma36(), + } + } +} + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(value_enum, long, default_value_t=Which::S12)] + which: Which, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = candle_examples::imagenet::load_image(args.image, 256)?.to_device(&device)?; + println!("loaded image {image:?}"); + + let model_file = match args.model { + None => { + let model_name = args.which.model_filename(); + let api = hf_hub::api::sync::Api::new()?; + let api = api.model(model_name); + api.get("model.safetensors")? + } + Some(model) => model.into(), + }; + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = fastvit::fastvit(&args.which.config(), 1000, vb)?; + println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::()?; + let mut prs = prs.iter().enumerate().collect::>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!( + "{:24}: {:.2}%", + candle_examples::imagenet::CLASSES[category_idx], + 100. * pr + ); + } + Ok(()) +} diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs new file mode 100644 index 0000000000..a0b3cc3e57 --- /dev/null +++ b/candle-transformers/src/models/fastvit.rs @@ -0,0 +1,512 @@ +//! FastViT inference implementation based on timm +//! +//! See "FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization" +//! https://arxiv.org/pdf/2303.14189 +//! +//! 
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py + +use candle::{DType, Result, Tensor, D}; +use candle_nn::{ + batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax, + BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder, +}; + +#[derive(Clone, Debug)] +pub struct Config { + exp_ratio: usize, + in_channels: usize, + blocks: [usize; 4], + attn: bool, + lkc_use_act: bool, +} + +impl Config { + pub fn t8() -> Self { + Self { + exp_ratio: 3, + in_channels: 48, + blocks: [2, 2, 4, 2], + attn: false, + lkc_use_act: false, + } + } + + pub fn t12() -> Self { + Self { + exp_ratio: 3, + in_channels: 64, + blocks: [2, 2, 6, 2], + attn: false, + lkc_use_act: false, + } + } + pub fn s12() -> Self { + Self { + exp_ratio: 4, + in_channels: 64, + blocks: [2, 2, 6, 2], + attn: false, + lkc_use_act: false, + } + } + pub fn sa12() -> Self { + Self { + exp_ratio: 4, + in_channels: 64, + blocks: [2, 2, 6, 2], + attn: true, + lkc_use_act: false, + } + } + pub fn sa24() -> Self { + Self { + exp_ratio: 4, + in_channels: 64, + blocks: [4, 4, 12, 4], + attn: true, + lkc_use_act: false, + } + } + pub fn sa36() -> Self { + Self { + exp_ratio: 4, + in_channels: 64, + blocks: [6, 6, 18, 6], + attn: true, + lkc_use_act: false, + } + } + pub fn ma36() -> Self { + Self { + exp_ratio: 4, + in_channels: 76, + blocks: [6, 6, 18, 6], + attn: true, + lkc_use_act: false, + } + } + + // configs used by MobileCLIP's image encoder + pub fn mci0() -> Self { + Self { + exp_ratio: 3, + in_channels: 64, + blocks: [2, 6, 10, 2], + attn: true, + lkc_use_act: true, + } + } + pub fn mci1() -> Self { + Self { + exp_ratio: 3, + in_channels: 64, + blocks: [4, 12, 20, 4], + attn: true, + lkc_use_act: true, + } + } + pub fn mci2() -> Self { + Self { + exp_ratio: 3, + in_channels: 80, + blocks: [4, 12, 24, 4], + attn: true, + lkc_use_act: true, + } + } +} + +fn conv_norm( + in_channels: usize, + out_channels: usize, + kernel: usize, + stride: usize, + vb: VarBuilder, +) -> Result> { + let conv2d_cfg = Conv2dConfig { + stride, + padding: kernel / 2, + groups: in_channels, + ..Default::default() + }; + + let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?; + let conv = conv2d_no_bias(in_channels, out_channels, kernel, conv2d_cfg, vb.pp("conv"))?; + let conv = conv.absorb_bn(&bn)?; + Ok(Func::new(move |xs| { + let xs = xs.apply(&conv)?; + Ok(xs) + })) +} + +fn conv_mlp(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result> { + let conv2d_cfg = Conv2dConfig { + ..Default::default() + }; + + let conv = conv_norm(dim, dim, 7, 1, vb.pp("conv"))?; + let fc1 = conv2d(dim, dim * exp_ratio, 1, conv2d_cfg, vb.pp("fc1"))?; + let fc2 = conv2d(dim * exp_ratio, dim, 1, conv2d_cfg, vb.pp("fc2"))?; + + Ok(Func::new(move |xs| { + let xs = xs.apply(&conv)?.apply(&fc1)?.gelu_erf()?.apply(&fc2)?; + Ok(xs) + })) +} + +fn squeeze_and_excitation( + in_channels: usize, + squeeze_channels: usize, + vb: VarBuilder, +) -> Result> { + let conv2d_cfg = Conv2dConfig { + ..Default::default() + }; + let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?; + let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?; + + Ok(Func::new(move |xs| { + let residual = xs; + let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?; + let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?; + + residual.broadcast_mul(&xs) + })) +} + +// fuses a convolutional kernel and a batchnorm layer into a convolutional layer +// based on the _fuse_bn_tensor method in timm +// see 
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602 +fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> { + let (gamma, beta) = bn.weight_and_bias().unwrap(); + let mu = bn.running_mean(); + let sigma = (bn.running_var() + bn.eps())?.sqrt(); + let gps = (gamma / sigma)?; + let bias = (beta - mu * &gps)?; + let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?; + + Ok((weights, bias)) +} + +fn mobileone_block( + in_channels: usize, + out_channels: usize, + kernel: usize, + stride: usize, + group_size: usize, + use_act: bool, + vb: VarBuilder, +) -> Result> { + let groups = if group_size == 0 { + 1 + } else { + in_channels / group_size + }; + + let padding = kernel / 2; + let conv2d_cfg = Conv2dConfig { + stride, + groups, + padding, + ..Default::default() + }; + + let mut w = Tensor::zeros( + (out_channels, in_channels / groups, kernel, kernel), + DType::F32, + vb.device(), + )?; + let dim = out_channels; + + let mut b = Tensor::zeros(dim, DType::F32, vb.device())?; + + let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.0.bn")); + let conv_kxk = conv2d_no_bias( + in_channels, + out_channels, + kernel, + conv2d_cfg, + vb.pp("conv_kxk.0.conv"), + ); + + if let (Ok(conv), Ok(bn)) = (conv_kxk, conv_kxk_bn) { + let (wk, bk) = fuse_conv_bn(conv.weight(), bn)?; + w = (w + wk)?; + b = (b + bk)?; + }; + + let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn")); + let conv_scale = conv2d_no_bias( + in_channels, + out_channels, + 1, + conv2d_cfg, + vb.pp("conv_scale.conv"), + ); + + if let (Ok(conv), Ok(bn)) = (conv_scale, conv_scale_bn) { + let (ws, bs) = fuse_conv_bn(conv.weight(), bn)?; + // pad to 3x3 + let ws = ws + .pad_with_zeros(D::Minus1, 1, 1)? + .pad_with_zeros(D::Minus2, 1, 1)?; + + w = (w + ws)?; + b = (b + bs)?; + }; + + let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("se")); + + // read and reparameterize the identity bn into wi and bi + let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity")); + + if let Ok(id_bn) = identity_bn { + let mut weights: Vec = vec![0.0; w.elem_count()]; + let id = in_channels / groups; + // See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809 + for i in 0..in_channels { + if kernel > 1 { + weights[i * kernel * kernel + 4] = 1.0; + } else { + weights[i * (id + 1)] = 1.0; + } + } + + let weights = &Tensor::from_vec(weights, w.shape(), w.device())?; + let (wi, bi) = fuse_conv_bn(weights, id_bn)?; + + w = (w + wi)?; + b = (b + bi)?; + }; + let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg); + + Ok(Func::new(move |xs| { + let mut xs = xs.apply(&reparam_conv)?; + if let Ok(f) = &se { + xs = xs.apply(f)?; + } + if use_act { + xs = xs.gelu_erf()?; + }; + Ok(xs) + })) +} + +fn repmixer(dim: usize, kernel: usize, vb: VarBuilder) -> Result> { + let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?; + let norm = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("norm"))?; + let mixer = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("mixer"))?; + + Ok(Func::new(move |xs| { + let residual = xs.clone(); + let xs = (xs.apply(&mixer)? 
- xs.apply(&norm)?)?; + let xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?; + let xs = (xs + residual)?; + Ok(xs) + })) +} + +fn repmixer_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result> { + let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?; + let token_mixer = repmixer(dim, 3, vb.pp("token_mixer"))?; + let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?; + + Ok(Func::new(move |xs| { + let residual = xs.apply(&token_mixer)?; + let mut xs = residual.apply(&mlp)?; + xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?; + let xs = (xs + residual)?; + Ok(xs) + })) +} + +fn positional_encoding(dim: usize, vb: VarBuilder) -> Result> { + let conv2d_cfg = Conv2dConfig { + stride: 1, + padding: 3, + groups: dim, + ..Default::default() + }; + + let conv = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("pos_enc"))?; + + Ok(Func::new(move |xs| { + let xs = (xs + xs.apply(&conv)?)?; + Ok(xs) + })) +} + +fn attention(dim: usize, vb: VarBuilder) -> Result> { + let qkv = linear_no_bias(dim, dim * 3, vb.pp("qkv"))?; + let proj = linear(dim, dim, vb.pp("proj"))?; + let num_heads = 32; + let head_dim = dim / num_heads; + let scale = (head_dim as f64).powf(-0.5); + + Ok(Func::new(move |xs| { + let xs = xs.clone(); + let (b, c, h, w) = xs.dims4()?; + let n = h * w; + let xs = xs.flatten_from(2)?.transpose(D::Minus1, D::Minus2)?; + let qkv = xs + .apply(&qkv)? + .reshape((b, n, 3, num_heads, head_dim))? + .permute((2, 0, 3, 1, 4))?; + + let q = qkv.get(0)?; + let k = qkv.get(1)?; + let v = qkv.get(2)?; + + let q = (q * scale)?; + + let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?; + let att = softmax(&att, D::Minus1)?; + let xs = att.matmul(&v)?; + + let xs = xs.transpose(1, 2)?.reshape((b, n, c))?; + let xs = xs.apply(&proj)?; + let xs = xs.transpose(D::Minus1, D::Minus2)?.reshape((b, c, h, w))?; + + Ok(xs) + })) +} + +fn attention_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result> { + let gamma1 = vb.get((dim, 1, 1), "layer_scale_1.gamma")?; + let gamma2 = vb.get((dim, 1, 1), "layer_scale_2.gamma")?; + let norm = batch_norm(dim, 1e-5, vb.pp("norm"))?; + let token_mixer = attention(dim, vb.pp("token_mixer"))?; + let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?; + + Ok(Func::new(move |xs| { + let xs = xs.clone(); + let xs = (&xs + + &xs + .apply_t(&norm, false)? + .apply(&token_mixer)? + .broadcast_mul(&gamma1.reshape((1, (), 1, 1))?)?)?; + + let xs = (&xs + + &xs + .apply(&mlp)? + .broadcast_mul(&gamma2.reshape((1, (), 1, 1))?)?)?; + + Ok(xs) + })) +} + +fn fastvit_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result> { + let nblocks = cfg.blocks[idx]; + let mut blocks = Vec::with_capacity(nblocks); + + let dim = cfg.in_channels << idx; + let downsample = fastvit_patch_embed(dim / 2, dim, cfg.lkc_use_act, vb.pp("downsample")); + for block_idx in 0..nblocks { + let block = if cfg.attn && idx == 3 { + attention_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))? + } else { + repmixer_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))? 
+ }; + blocks.push(block); + } + let pos_emb = positional_encoding(dim, vb.pp("pos_emb")); + + Ok(Func::new(move |xs| { + let mut xs = xs.clone(); + if let Ok(ds) = &downsample { + xs = xs.apply(ds)?; + } + if let Ok(pos) = &pos_emb { + xs = xs.apply(pos)?; + } + for block in blocks.iter() { + xs = xs.apply(block)?; + } + Ok(xs) + })) +} + +fn fastvit_patch_embed( + in_channels: usize, + out_channels: usize, + use_act: bool, + vb: VarBuilder, +) -> Result> { + let lk = conv_norm(in_channels, out_channels, 7, 2, vb.pp("proj.0.large_conv"))?; + let sk = conv_norm(in_channels, out_channels, 3, 2, vb.pp("proj.0.small_conv"))?; + let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("proj.0.se")); + let mb = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp("proj.1"))?; + + Ok(Func::new(move |xs| { + let mut xs = (xs.apply(&lk)? + xs.apply(&sk)?)?; + if let Ok(f) = &se { + xs = xs.apply(f)?; + } + if use_act { + xs = xs.gelu_erf()?; + }; + let xs = xs.apply(&mb)?; + Ok(xs) + })) +} + +fn fastvit_stem(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result> { + let mb0 = mobileone_block(in_channels, out_channels, 3, 2, 0, true, vb.pp(0))?; + let mb1 = mobileone_block(out_channels, out_channels, 3, 2, 1, true, vb.pp(1))?; + let mb2 = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp(2))?; + Ok(Func::new(move |xs| { + let xs = xs.apply(&mb0)?.apply(&mb1)?.apply(&mb2)?; + Ok(xs) + })) +} + +// Build a fastvit model for a given configuration. +fn fastvit_model(cfg: &Config, nclasses: Option, vb: VarBuilder) -> Result> { + let cls = match nclasses { + None => None, + Some(nclasses) => { + let linear = linear(cfg.in_channels * 16, nclasses, vb.pp("head.fc"))?; + Some(linear) + } + }; + + let stem = fastvit_stem(3, cfg.in_channels, vb.pp("stem"))?; + let final_conv = mobileone_block( + cfg.in_channels * 8, + cfg.in_channels * 16, + 3, + 1, + 1, + true, + vb.pp("final_conv"), + )?; + + let vb = vb.pp("stages"); + let stage1 = fastvit_stage(cfg, 0, vb.pp(0))?; + let stage2 = fastvit_stage(cfg, 1, vb.pp(1))?; + let stage3 = fastvit_stage(cfg, 2, vb.pp(2))?; + let stage4 = fastvit_stage(cfg, 3, vb.pp(3))?; + + Ok(Func::new(move |xs| { + let xs = xs + .apply(&stem)? + .apply(&stage1)? + .apply(&stage2)? + .apply(&stage3)? + .apply(&stage4)? 
+ .apply(&final_conv)?; + + match &cls { + None => Ok(xs), + Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls), + } + })) +} + +pub fn fastvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result> { + fastvit_model(cfg, Some(nclasses), vb) +} + +pub fn fastvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result> { + fastvit_model(cfg, None, vb) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index fc797a771b..39f6ba2fd7 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -19,6 +19,7 @@ pub mod efficientvit; pub mod encodec; pub mod eva2; pub mod falcon; +pub mod fastvit; pub mod flux; pub mod gemma; pub mod gemma2; From ef9649c2cb2ce469466bba608268dd5709a998a9 Mon Sep 17 00:00:00 2001 From: ilookee Date: Fri, 23 Aug 2024 22:50:02 +0800 Subject: [PATCH 61/75] fix: qwen2 lm_head loading #2443 (#2445) Co-authored-by: Yi Xu --- candle-transformers/src/models/qwen2.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs index 3dce5c6a6a..187ea98a10 100644 --- a/candle-transformers/src/models/qwen2.rs +++ b/candle-transformers/src/models/qwen2.rs @@ -361,7 +361,7 @@ pub struct ModelForCausalLM { impl ModelForCausalLM { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let base_model = Model::new(cfg, vb.clone())?; - let lm_head = if vb.contains_tensor("lm_head") { + let lm_head = if vb.contains_tensor("lm_head.weight") { linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? } else { Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None) From 7412bd0bba5f73107619e73a374c5709c03c96f1 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 27 Aug 2024 09:10:30 +0100 Subject: [PATCH 62/75] Update cudarc to 0.12. (#2451) * Update cudarc to 0.12. * Some cudnn tweaks. 
--- Cargo.toml | 2 +- candle-core/src/cuda_backend/cudnn.rs | 4 ++-- candle-core/src/cuda_backend/mod.rs | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index efd39165ab..54ae266d1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ candle-onnx = { path = "./candle-onnx", version = "0.6.1" } candle-transformers = { path = "./candle-transformers", version = "0.6.1" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } -cudarc = { version = "=0.11.6", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } +cudarc = { version = "0.12.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.3.0" diff --git a/candle-core/src/cuda_backend/cudnn.rs b/candle-core/src/cuda_backend/cudnn.rs index 0c149cd0d0..d604863d35 100644 --- a/candle-core/src/cuda_backend/cudnn.rs +++ b/candle-core/src/cuda_backend/cudnn.rs @@ -1,6 +1,6 @@ use crate::WithDType; use cudarc; -use cudarc::cudnn::safe::{Conv2dForward, Cudnn}; +use cudarc::cudnn::safe::{ConvForward, Cudnn}; use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits}; use std::cell::RefCell; use std::collections::HashMap; @@ -87,7 +87,7 @@ pub(crate) fn launch_conv2d< cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [params.b_size as i32, params.c_out as i32, h_out, w_out], )?; - let conv2d = Conv2dForward { + let conv2d = ConvForward { conv: &conv, x: &x, w: &w, diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index e68bd3477b..231e24715c 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -175,6 +175,7 @@ impl Map1 for Im2Col1D { } } +#[allow(unused)] struct Im2Col { h_k: usize, w_k: usize, @@ -184,6 +185,7 @@ struct Im2Col { } impl Im2Col { + #[allow(unused)] fn hw_out(&self, h: usize, w: usize) -> (usize, usize) { let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1; let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1; From 8e390866e8720b7de43ffa2e530f85c03cbe3224 Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Wed, 28 Aug 2024 12:20:09 +0300 Subject: [PATCH 63/75] FastViT fixes. (#2452) * correct optional SE layer dimensions. * head_dim instead of num_heads is 32. * update test example output. --- README.md | 2 +- candle-examples/examples/fastvit/README.md | 10 +++++----- candle-transformers/src/models/fastvit.rs | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 55c0923ef0..173f907d6f 100644 --- a/README.md +++ b/README.md @@ -243,7 +243,7 @@ If you have an addition to this list, please submit a pull request. - Parler-TTS, text-to-speech model. - Computer Vision Models. - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT, - ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera. + ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT. - yolo-v3, yolo-v8. - Segment-Anything Model (SAM). - SegFormer. 
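The `head_dim` fix in this commit pins the per-head width to 32 instead of the head count, so the number of heads now grows with the stage width (attention runs in the last stage, whose width is `cfg.in_channels << 3`, i.e. 512 for the SA12 config earlier in this series). A quick check of the corrected relationship:

```rust
fn heads_and_scale(dim: usize) -> (usize, f64) {
    // After the fix, 32 is the per-head width; the head count follows the stage width.
    let head_dim = 32;
    let num_heads = dim / head_dim;
    let scale = (head_dim as f64).powf(-0.5);
    (num_heads, scale)
}

fn main() {
    let (heads, scale) = heads_and_scale(64 << 3); // SA12: in_channels = 64, last stage width 512
    assert_eq!(heads, 16);
    println!("heads = {heads}, softmax scale = {scale:.4}"); // 0.1768 = 1/sqrt(32)
}
```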
diff --git a/candle-examples/examples/fastvit/README.md b/candle-examples/examples/fastvit/README.md index 499685bd3c..467e1032b1 100644 --- a/candle-examples/examples/fastvit/README.md +++ b/candle-examples/examples/fastvit/README.md @@ -12,9 +12,9 @@ $ cargo run --example fastvit --release -- --image candle-examples/examples/yolo loaded image Tensor[dims 3, 256, 256; f32] model built -mountain bike, all-terrain bike, off-roader: 43.45% -bicycle-built-for-two, tandem bicycle, tandem: 14.16% -unicycle, monocycle : 4.12% -crash helmet : 2.26% -alp : 1.40% +mountain bike, all-terrain bike, off-roader: 52.67% +bicycle-built-for-two, tandem bicycle, tandem: 7.93% +unicycle, monocycle : 3.46% +maillot : 1.32% +crash helmet : 1.28% ``` diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs index a0b3cc3e57..8199874276 100644 --- a/candle-transformers/src/models/fastvit.rs +++ b/candle-transformers/src/models/fastvit.rs @@ -339,8 +339,8 @@ fn positional_encoding(dim: usize, vb: VarBuilder) -> Result> { fn attention(dim: usize, vb: VarBuilder) -> Result> { let qkv = linear_no_bias(dim, dim * 3, vb.pp("qkv"))?; let proj = linear(dim, dim, vb.pp("proj"))?; - let num_heads = 32; - let head_dim = dim / num_heads; + let head_dim = 32; + let num_heads = dim / head_dim; let scale = (head_dim as f64).powf(-0.5); Ok(Func::new(move |xs| { @@ -434,7 +434,7 @@ fn fastvit_patch_embed( ) -> Result> { let lk = conv_norm(in_channels, out_channels, 7, 2, vb.pp("proj.0.large_conv"))?; let sk = conv_norm(in_channels, out_channels, 3, 2, vb.pp("proj.0.small_conv"))?; - let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("proj.0.se")); + let se = squeeze_and_excitation(out_channels, out_channels / 4, vb.pp("proj.0.se")); let mb = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp("proj.1"))?; Ok(Func::new(move |xs| { From 8632a2f66082b9595ba3eae0221c60d177bdc2a5 Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Thu, 29 Aug 2024 16:38:58 +0300 Subject: [PATCH 64/75] MobileCLIP models S1 and S2 (#2454) * Allow loading images with given std and mean * OpenCLIP text encoder component * Two MobileCLIP models * Clippy fixes. --------- Co-authored-by: Laurent --- candle-examples/examples/mobileclip/README.md | 28 ++ candle-examples/examples/mobileclip/main.rs | 192 +++++++++++++ candle-examples/examples/mobilenetv4/main.rs | 5 +- candle-examples/src/imagenet.rs | 35 ++- candle-transformers/src/models/mobileclip.rs | 89 ++++++ candle-transformers/src/models/mod.rs | 2 + .../src/models/openclip/mod.rs | 1 + .../src/models/openclip/text_model.rs | 266 ++++++++++++++++++ 8 files changed, 608 insertions(+), 10 deletions(-) create mode 100644 candle-examples/examples/mobileclip/README.md create mode 100644 candle-examples/examples/mobileclip/main.rs create mode 100644 candle-transformers/src/models/mobileclip.rs create mode 100644 candle-transformers/src/models/openclip/mod.rs create mode 100644 candle-transformers/src/models/openclip/text_model.rs diff --git a/candle-examples/examples/mobileclip/README.md b/candle-examples/examples/mobileclip/README.md new file mode 100644 index 0000000000..a3869b2571 --- /dev/null +++ b/candle-examples/examples/mobileclip/README.md @@ -0,0 +1,28 @@ +# candle-mobileclip + +MobileCLIP is family of efficient CLIP-like models using FastViT-based image encoders. 
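Zero-shot classification with these models follows the usual CLIP recipe: embed the image and each caption, L2-normalize, and softmax the scaled cosine similarities. A minimal sketch of that similarity step (whether `MobileClipModel::forward` computes it exactly this way is an assumption; the example below only consumes the returned logits):

```rust
use candle::{Result, Tensor, D};

/// Scaled cosine similarities between image and text embeddings, CLIP style.
fn clip_logits(img: &Tensor, txt: &Tensor, logit_scale: &Tensor) -> Result<(Tensor, Tensor)> {
    let norm = |v: &Tensor| -> Result<Tensor> {
        v.broadcast_div(&v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?)
    };
    let (img, txt) = (norm(img)?, norm(txt)?);
    // logits_per_text: (n_texts, n_images); logits_per_image is its transpose.
    let logits_per_text = txt.matmul(&img.t()?)?.broadcast_mul(&logit_scale.exp()?)?;
    let logits_per_image = logits_per_text.t()?;
    Ok((logits_per_text, logits_per_image))
}
```

The example program below obtains these from `model.forward(&images, &input_ids)` and softmaxes `logits_per_image` across the captions to rank them for each image.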
+ +See [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/abs/2311.17049) + + +## Running on an example on cpu + +``` +$ cargo run --example mobileclip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle" + +softmax_image_vec: [2.4819004e-5, 3.81081e-6, 0.9999714, 0.9999738, 2.382714e-5, 2.3317718e-6] + + +Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg + +Probability: 0.0025% Text: a cycling race +Probability: 0.0004% Text: a photo of two cats +Probability: 99.9971% Text: a robot holding a candle + + +Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg + +Probability: 99.9974% Text: a cycling race +Probability: 0.0024% Text: a photo of two cats +Probability: 0.0002% Text: a robot holding a candle +``` diff --git a/candle-examples/examples/mobileclip/main.rs b/candle-examples/examples/mobileclip/main.rs new file mode 100644 index 0000000000..d505fc7c48 --- /dev/null +++ b/candle-examples/examples/mobileclip/main.rs @@ -0,0 +1,192 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Error as E; +use clap::{Parser, ValueEnum}; + +use candle::{DType, Device, Tensor}; +use candle_nn::{ops::softmax, VarBuilder}; +use candle_transformers::models::mobileclip; + +use tokenizers::Tokenizer; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Which { + S1, + S2, +} + +impl Which { + fn model_name(&self) -> String { + let name = match self { + Self::S1 => "S1", + Self::S2 => "S2", + }; + format!("apple/MobileCLIP-{}-OpenCLIP", name) + } + + fn config(&self) -> mobileclip::MobileClipConfig { + match self { + Self::S1 => mobileclip::MobileClipConfig::s1(), + Self::S2 => mobileclip::MobileClipConfig::s2(), + } + } +} + +#[derive(Parser)] +struct Args { + #[arg(long, use_value_delimiter = true)] + images: Option>, + + #[arg(long)] + cpu: bool, + + /// Use the pytorch weights rather than the safetensors ones + #[arg(long)] + use_pth: bool, + + #[arg(long, use_value_delimiter = true)] + sequences: Option>, + + #[arg(value_enum, long, default_value_t=Which::S1)] + which: Which, +} + +fn load_images>( + paths: &Vec, + image_size: usize, +) -> anyhow::Result { + let mut images = vec![]; + + for path in paths { + let tensor = candle_examples::imagenet::load_image_with_std_mean( + path, + image_size, + &[0.0, 0.0, 0.0], + &[1.0, 1.0, 1.0], + )?; + images.push(tensor); + } + + let images = Tensor::stack(&images, 0)?; + + Ok(images) +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let model_name = args.which.model_name(); + + let api = hf_hub::api::sync::Api::new()?; + let api = api.model(model_name); + + let model_file = if args.use_pth { + api.get("open_clip_pytorch_model.bin")? + } else { + api.get("open_clip_model.safetensors")? 
+ }; + + let tokenizer = api.get("tokenizer.json")?; + + let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; + + let config = &args.which.config(); + + let device = candle_examples::device(args.cpu)?; + + let vec_imgs = match args.images { + Some(imgs) => imgs, + None => vec![ + "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(), + "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(), + ], + }; + + let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?; + + let vb = if args.use_pth { + VarBuilder::from_pth(&model_file, DType::F32, &device)? + } else { + unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? } + }; + + let model = mobileclip::MobileClipModel::new(vb, config)?; + + let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?; + + let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?; + + let softmax_image = softmax(&logits_per_image, 1)?; + + let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::()?; + + println!("softmax_image_vec: {:?}", softmax_image_vec); + + let probability_vec = softmax_image_vec + .iter() + .map(|v| v * 100.0) + .collect::>(); + + let probability_per_image = probability_vec.len() / vec_imgs.len(); + + for (i, img) in vec_imgs.iter().enumerate() { + let start = i * probability_per_image; + let end = start + probability_per_image; + let prob = &probability_vec[start..end]; + println!("\n\nResults for image: {}\n", img); + + for (i, p) in prob.iter().enumerate() { + println!("Probability: {:.4}% Text: {}", p, vec_seq[i]); + } + } + + Ok(()) +} + +pub fn tokenize_sequences( + sequences: Option>, + tokenizer: &Tokenizer, + device: &Device, +) -> anyhow::Result<(Tensor, Vec)> { + // let pad_id = *tokenizer + // .get_vocab(true) + // .get("<|endoftext|>") + // .ok_or(E::msg("No pad token"))?; + + // The model does not work well if the text is padded using the <|endoftext|> token, using 0 + // as the original OpenCLIP code. + let pad_id = 0; + + let vec_seq = match sequences { + Some(seq) => seq, + None => vec![ + "a cycling race".to_string(), + "a photo of two cats".to_string(), + "a robot holding a candle".to_string(), + ], + }; + + let mut tokens = vec![]; + + for seq in vec_seq.clone() { + let encoding = tokenizer.encode(seq, true).map_err(E::msg)?; + tokens.push(encoding.get_ids().to_vec()); + } + + let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0); + // Pad the sequences to have the same length + for token_vec in tokens.iter_mut() { + let len_diff = max_len - token_vec.len(); + if len_diff > 0 { + token_vec.extend(vec![pad_id; len_diff]); + } + } + + let input_ids = Tensor::new(tokens, device)?; + + Ok((input_ids, vec_seq)) +} diff --git a/candle-examples/examples/mobilenetv4/main.rs b/candle-examples/examples/mobilenetv4/main.rs index 26c0dad95c..c31b91e6e4 100644 --- a/candle-examples/examples/mobilenetv4/main.rs +++ b/candle-examples/examples/mobilenetv4/main.rs @@ -72,8 +72,9 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image(args.image, args.which.resolution())? - .to_device(&device)?; + let image = + candle_examples::imagenet::load_image(args.image, args.which.resolution() as usize)? 
+ .to_device(&device)?; println!("loaded image {image:?}"); let model_file = match args.model { diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs index 6fcda4243b..a3b1242387 100644 --- a/candle-examples/src/imagenet.rs +++ b/candle-examples/src/imagenet.rs @@ -1,23 +1,42 @@ use candle::{Device, Result, Tensor}; -/// Loads an image from disk using the image crate at the requested resolution. -// This returns a tensor with shape (3, res, res). imagenet normalization is applied. -pub fn load_image>(p: P, res: u32) -> Result { +pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406]; +pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225]; + +/// Loads an image from disk using the image crate at the requested resolution, +/// using the given std and mean parameters. +/// This returns a tensor with shape (3, res, res). imagenet normalization is applied. + +pub fn load_image_with_std_mean>( + p: P, + res: usize, + mean: &[f32; 3], + std: &[f32; 3], +) -> Result { let img = image::ImageReader::open(p)? .decode() .map_err(candle::Error::wrap)? - .resize_to_fill(res, res, image::imageops::FilterType::Triangle); + .resize_to_fill( + res as u32, + res as u32, + image::imageops::FilterType::Triangle, + ); let img = img.to_rgb8(); let data = img.into_raw(); - let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)? - .permute((2, 0, 1))?; - let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?; - let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?; + let data = Tensor::from_vec(data, (res, res, 3), &Device::Cpu)?.permute((2, 0, 1))?; + let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?; (data.to_dtype(candle::DType::F32)? / 255.)? .broadcast_sub(&mean)? .broadcast_div(&std) } +/// Loads an image from disk using the image crate at the requested resolution. +/// This returns a tensor with shape (3, res, res). imagenet normalization is applied. +pub fn load_image>(p: P, res: usize) -> Result { + load_image_with_std_mean(p, res, &IMAGENET_MEAN, &IMAGENET_STD) +} + /// Loads an image from disk using the image crate, this returns a tensor with shape /// (3, 224, 224). imagenet normalization is applied. 
pub fn load_image224>(p: P) -> Result { diff --git a/candle-transformers/src/models/mobileclip.rs b/candle-transformers/src/models/mobileclip.rs new file mode 100644 index 0000000000..4953d835b5 --- /dev/null +++ b/candle-transformers/src/models/mobileclip.rs @@ -0,0 +1,89 @@ +use super::fastvit; +use super::openclip::text_model; +use candle::{Result, Tensor, D}; +use candle_nn::{Func, VarBuilder}; + +#[derive(Clone, Debug)] +pub struct MobileClipModel { + text_model: text_model::OpenClipTextTransformer, + vision_model: Func<'static>, + text_projection: Tensor, + logit_scale: Tensor, +} + +#[derive(Clone, Debug)] +pub struct MobileClipConfig { + pub text_config: text_model::Config, + pub vision_config: fastvit::Config, + pub image_size: usize, +} + +impl MobileClipConfig { + pub fn s1() -> Self { + let text_config = text_model::Config::vit_base_patch32(); + let vision_config = fastvit::Config::mci1(); + + Self { + text_config, + vision_config, + image_size: 256, + } + } + pub fn s2() -> Self { + let text_config = text_model::Config::vit_base_patch32(); + let vision_config = fastvit::Config::mci2(); + + Self { + text_config, + vision_config, + image_size: 256, + } + } +} + +impl MobileClipModel { + pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result { + let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?; + let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?; + + let text_projection = vs.get( + (c.text_config.embed_dim, c.text_config.projection_dim), + "text.text_projection", + )?; + + let logit_scale = vs.get(&[], "logit_scale")?; + Ok(Self { + text_model, + vision_model, + text_projection, + logit_scale, + }) + } + + pub fn get_text_features(&self, input_ids: &Tensor) -> Result { + input_ids + .apply(&self.text_model)? 
+ .matmul(&self.text_projection) + } + + pub fn get_image_features(&self, pixel_values: &Tensor) -> Result { + pixel_values.apply(&self.vision_model) + } + + pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> { + let image_features = self.get_image_features(pixel_values)?; + let text_features = self.get_text_features(input_ids)?; + let image_features_normalized = div_l2_norm(&image_features)?; + let text_features_normalized = div_l2_norm(&text_features)?; + let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?; + let logit_scale = self.logit_scale.exp()?; + let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?; + let logits_per_image = logits_per_text.t()?; + Ok((logits_per_text, logits_per_image)) + } +} + +pub fn div_l2_norm(v: &Tensor) -> Result { + let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?; + v.broadcast_div(&l2_norm) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 39f6ba2fd7..38a01595e0 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -37,11 +37,13 @@ pub mod mistral; pub mod mixformer; pub mod mixtral; pub mod mmdit; +pub mod mobileclip; pub mod mobilenetv4; pub mod mobileone; pub mod moondream; pub mod mpt; pub mod olmo; +pub mod openclip; pub mod parler_tts; pub mod persimmon; pub mod phi; diff --git a/candle-transformers/src/models/openclip/mod.rs b/candle-transformers/src/models/openclip/mod.rs new file mode 100644 index 0000000000..ee2a501d6a --- /dev/null +++ b/candle-transformers/src/models/openclip/mod.rs @@ -0,0 +1 @@ +pub mod text_model; diff --git a/candle-transformers/src/models/openclip/text_model.rs b/candle-transformers/src/models/openclip/text_model.rs new file mode 100644 index 0000000000..7b444e797e --- /dev/null +++ b/candle-transformers/src/models/openclip/text_model.rs @@ -0,0 +1,266 @@ +//! Text encoder as used in most OpenCLIP pretrained models +//! 
https://github.com/mlfoundations/open_clip + +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::{ + embedding, layer_norm, linear, ops::softmax_last_dim, Embedding, LayerNorm, Linear, Module, + VarBuilder, +}; + +#[derive(Debug, Clone)] +pub struct Config { + pub vocab_size: usize, + pub embed_dim: usize, + pub intermediate_size: usize, + pub max_position_embeddings: usize, + pub pad_with: Option, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub projection_dim: usize, +} + +impl Config { + pub fn vit_base_patch32() -> Self { + Self { + vocab_size: 49408, + embed_dim: 512, + intermediate_size: 2048, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 12, + num_attention_heads: 8, + projection_dim: 512, + } + } +} + +#[derive(Clone, Debug)] +struct TextEmbeddings { + token_embedding: Embedding, + position_embedding: Tensor, +} + +impl TextEmbeddings { + fn new(vs: VarBuilder, c: &Config) -> Result { + let token_embedding = embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?; + let position_embedding = vs.get( + (c.max_position_embeddings, c.embed_dim), + "positional_embedding", + )?; + Ok(TextEmbeddings { + token_embedding, + position_embedding, + }) + } +} + +impl Module for TextEmbeddings { + fn forward(&self, input_ids: &Tensor) -> Result { + let seq_length = input_ids.dim(D::Minus1)?; + let inputs_embeds = self.token_embedding.forward(input_ids)?; + + let position_embedding = self.position_embedding.narrow(0, 0, seq_length)?; + + inputs_embeds.broadcast_add(&position_embedding) + } +} + +#[derive(Clone, Debug)] +struct Attention { + k_proj: candle_nn::Linear, + v_proj: candle_nn::Linear, + q_proj: candle_nn::Linear, + out_proj: Linear, + head_dim: usize, + scale: f64, + num_attention_heads: usize, +} + +impl Attention { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result { + let embed_dim = c.embed_dim; + let num_attention_heads = c.num_attention_heads; + + let in_proj_weights = vs + .get((embed_dim * 3, embed_dim), "in_proj_weight")? + .chunk(3, 0)?; + let (q_w, k_w, v_w) = ( + &in_proj_weights[0], + &in_proj_weights[1], + &in_proj_weights[2], + ); + let in_proj_biases = vs.get(embed_dim * 3, "in_proj_bias")?.chunk(3, 0)?; + let (q_b, k_b, v_b) = (&in_proj_biases[0], &in_proj_biases[1], &in_proj_biases[2]); + + let q_proj = Linear::new(q_w.clone(), Some(q_b.clone())); + let k_proj = Linear::new(k_w.clone(), Some(k_b.clone())); + let v_proj = Linear::new(v_w.clone(), Some(v_b.clone())); + let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?; + let head_dim = embed_dim / num_attention_heads; + let scale = (head_dim as f64).powf(-0.5); + + Ok(Attention { + k_proj, + v_proj, + q_proj, + out_proj, + head_dim, + scale, + num_attention_heads, + }) + } + + fn shape_multihead(&self, xs: &Tensor, bsz: usize, seq_len: usize) -> Result { + xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()? 
+ .to_dtype(DType::F32) + } + + fn forward(&self, xs: &Tensor) -> Result { + let in_dtype = xs.dtype(); + let (bsz, seq_len, embed_dim) = xs.dims3()?; + + let q = self.shape_multihead(&self.q_proj.forward(xs)?, bsz, seq_len)?; + let k = self.shape_multihead(&self.k_proj.forward(xs)?, bsz, seq_len)?; + let v = self.shape_multihead(&self.v_proj.forward(xs)?, bsz, seq_len)?; + let q = (q * self.scale)?; + + let attn_weights = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?; + + let attn_weights = softmax_last_dim(&attn_weights)?; + + let attn_output = attn_weights.matmul(&v)?.to_dtype(in_dtype)?; + let attn_output = attn_output + .transpose(1, 2)? + .contiguous()? + .reshape((bsz, seq_len, embed_dim))?; + let out = self.out_proj.forward(&attn_output)?; + Ok(out) + } +} + +#[derive(Clone, Debug)] +struct Mlp { + fc1: Linear, + fc2: Linear, +} + +impl Mlp { + fn new(vs: VarBuilder, c: &Config) -> Result { + let fc1 = linear(c.embed_dim, c.intermediate_size, vs.pp("c_fc"))?; + let fc2 = linear(c.intermediate_size, c.embed_dim, vs.pp("c_proj"))?; + + Ok(Mlp { fc1, fc2 }) + } +} + +impl Mlp { + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.fc1.forward(xs)?; + self.fc2.forward(&xs.gelu_erf()?) + } +} + +#[derive(Clone, Debug)] +struct EncoderLayer { + self_attn: Attention, + layer_norm1: LayerNorm, + mlp: Mlp, + layer_norm2: LayerNorm, +} + +impl EncoderLayer { + fn new(vs: VarBuilder, c: &Config) -> Result { + let self_attn = Attention::new(vs.pp("attn"), c)?; + let layer_norm1 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_1"))?; + let mlp = Mlp::new(vs.pp("mlp"), c)?; + let layer_norm2 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_2"))?; + + Ok(EncoderLayer { + self_attn, + layer_norm1, + mlp, + layer_norm2, + }) + } + + fn forward(&self, xs: &Tensor) -> Result { + let residual = xs; + let xs = self.layer_norm1.forward(xs)?; + let xs = self.self_attn.forward(&xs)?; + let xs = (xs + residual)?; + + let residual = &xs; + let xs = self.layer_norm2.forward(&xs)?; + let xs = self.mlp.forward(&xs)?; + let out = (xs + residual)?; + Ok(out) + } +} + +#[derive(Clone, Debug)] +pub struct Encoder { + layers: Vec, +} + +impl Encoder { + pub fn new(vs: VarBuilder, c: &Config) -> Result { + let vs = vs.pp("resblocks"); + let mut layers: Vec = Vec::new(); + for index in 0..c.num_hidden_layers { + let layer = EncoderLayer::new(vs.pp(index.to_string()), c)?; + layers.push(layer) + } + Ok(Encoder { layers }) + } + + pub fn forward(&self, xs: &Tensor) -> Result { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs)?; + } + Ok(xs) + } +} + +/// A text transformer as used in CLIP variants. 
+#[derive(Clone, Debug)] +pub struct OpenClipTextTransformer { + embeddings: TextEmbeddings, + encoder: Encoder, + final_layer_norm: LayerNorm, +} + +impl OpenClipTextTransformer { + pub fn new(vs: VarBuilder, c: &Config) -> Result { + let embeddings = TextEmbeddings::new(vs.clone(), c)?; + let final_layer_norm = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_final"))?; + let encoder = Encoder::new(vs.pp("transformer"), c)?; + Ok(OpenClipTextTransformer { + embeddings, + encoder, + final_layer_norm, + }) + } + + pub fn forward(&self, input_ids: &Tensor) -> Result { + let input_ids = self.embeddings.forward(input_ids)?; + let input_ids = self.encoder.forward(&input_ids)?; + self.final_layer_norm.forward(&input_ids) + } +} + +impl Module for OpenClipTextTransformer { + fn forward(&self, input_ids: &Tensor) -> Result { + let output = self.forward(input_ids)?; + let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?; + + let mut indices = Vec::new(); + for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::()?.iter().enumerate() { + let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?; + indices.push(index); + } + Tensor::cat(&indices, 0) + } +} From f492c04973ef8669e45472be7f8967a47f9d591a Mon Sep 17 00:00:00 2001 From: Eugene Hauptmann Date: Thu, 29 Aug 2024 17:10:28 +0200 Subject: [PATCH 65/75] Fix FLUX.1 weights (#2457) * fix FLUX.1 weights * added flux1-dev.safetensors --- candle-examples/examples/flux/main.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs index a9278d013d..539ae6f260 100644 --- a/candle-examples/examples/flux/main.rs +++ b/candle-examples/examples/flux/main.rs @@ -147,8 +147,8 @@ fn run(args: Args) -> Result<()> { println!("CLIP\n{clip_emb}"); let img = { let model_file = match model { - Model::Schnell => bf_repo.get("flux1-schnell.sft")?, - Model::Dev => bf_repo.get("flux1-dev.sft")?, + Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?, + Model::Dev => bf_repo.get("flux1-dev.safetensors")?, }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; @@ -189,7 +189,7 @@ fn run(args: Args) -> Result<()> { println!("latent img\n{img}"); let img = { - let model_file = bf_repo.get("ae.sft")?; + let model_file = bf_repo.get("ae.safetensors")?; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; let cfg = match model { Model::Dev => flux::autoencoder::Config::dev(), From 91e0c6ed3376e46a050f3493123852b88d3b56cd Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 5 Sep 2024 22:46:55 +0100 Subject: [PATCH 66/75] Clippy fixes for 1.81.0. (#2461) * Clippy fixes for 1.81.0. * Another fix. 
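The Clippy changes in the diffs below all apply the same pattern: the string prefix built with `format!` is now passed to `VarBuilder::pp` by value instead of through an extra `&` borrow, which the lint flags as needless because the parameter is generic over string-like arguments. A minimal sketch of the pattern, using a stand-in `pp` function rather than the real `VarBuilder` API (the exact generic bound is an assumption for illustration):

```
// Stand-in for VarBuilder::pp, assumed here to be generic over ToString.
fn pp<S: ToString>(prefix: S) -> String {
    prefix.to_string()
}

fn main() {
    let i = 3;
    // Before this commit: pp(&format!("model.layers.{i}")) -- Clippy flags the needless borrow.
    // After: pass the temporary String by value.
    let name = pp(format!("model.layers.{i}"));
    assert_eq!(name, "model.layers.3");
}
```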
--- candle-examples/examples/musicgen/musicgen_model.rs | 6 +++--- candle-examples/examples/silero-vad/main.rs | 1 - candle-examples/examples/yolo-v3/darknet.rs | 6 +++--- candle-examples/examples/yolo-v8/model.rs | 2 +- candle-transformers/src/models/bert.rs | 6 +++--- candle-transformers/src/models/bigcode.rs | 2 +- candle-transformers/src/models/distilbert.rs | 6 +++--- candle-transformers/src/models/falcon.rs | 2 +- candle-transformers/src/models/jina_bert.rs | 2 +- candle-transformers/src/models/llama.rs | 2 +- candle-transformers/src/models/llama2_c.rs | 2 +- candle-transformers/src/models/moondream.rs | 2 +- candle-transformers/src/models/segformer.rs | 8 ++++---- candle-transformers/src/models/t5.rs | 2 +- candle-transformers/src/models/whisper/model.rs | 4 ++-- candle-wasm-examples/llama2-c/src/model.rs | 2 +- candle-wasm-examples/yolo/src/model.rs | 2 +- 17 files changed, 28 insertions(+), 29 deletions(-) diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs index 03e9661418..7fbe8b5306 100644 --- a/candle-examples/examples/musicgen/musicgen_model.rs +++ b/candle-examples/examples/musicgen/musicgen_model.rs @@ -284,11 +284,11 @@ impl MusicgenDecoder { }; let embed_dim = cfg.vocab_size + 1; let embed_tokens = (0..cfg.num_codebooks) - .map(|i| embedding(embed_dim, h, vb.pp(&format!("embed_tokens.{i}")))) + .map(|i| embedding(embed_dim, h, vb.pp(format!("embed_tokens.{i}")))) .collect::>>()?; let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb.clone(), cfg)?; let layers = (0..cfg.num_hidden_layers) - .map(|i| MusicgenDecoderLayer::load(vb.pp(&format!("layers.{i}")), cfg)) + .map(|i| MusicgenDecoderLayer::load(vb.pp(format!("layers.{i}")), cfg)) .collect::>>()?; let layer_norm = layer_norm(h, 1e-5, vb.pp("layer_norm"))?; Ok(Self { @@ -341,7 +341,7 @@ impl MusicgenForCausalLM { let h = cfg.hidden_size; let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?; let lm_heads = (0..cfg.num_codebooks) - .map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}")))) + .map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(format!("lm_heads.{i}")))) .collect::>>()?; Ok(Self { decoder, diff --git a/candle-examples/examples/silero-vad/main.rs b/candle-examples/examples/silero-vad/main.rs index 4618ad80df..20a0b7a238 100644 --- a/candle-examples/examples/silero-vad/main.rs +++ b/candle-examples/examples/silero-vad/main.rs @@ -8,7 +8,6 @@ use anyhow::Result; use clap::Parser; use candle::{DType, Tensor}; -use candle_onnx; #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] enum Which { diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs index 1892acdd68..944f4dcb59 100644 --- a/candle-examples/examples/yolo-v3/darknet.rs +++ b/candle-examples/examples/yolo-v3/darknet.rs @@ -123,7 +123,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl) let padding = if pad != 0 { (size - 1) / 2 } else { 0 }; let (bn, bias) = match b.parameters.get("batch_normalize") { Some(p) if p.parse::()? != 0 => { - let bn = batch_norm(filters, 1e-5, vb.pp(&format!("batch_norm_{index}")))?; + let bn = batch_norm(filters, 1e-5, vb.pp(format!("batch_norm_{index}")))?; (Some(bn), false) } Some(_) | None => (None, true), @@ -135,9 +135,9 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl) dilation: 1, }; let conv = if bias { - conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))? 
+ conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))? } else { - conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))? + conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))? }; let leaky = match activation { "leaky" => true, diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs index cecd4ce6c4..e1be1f3c80 100644 --- a/candle-examples/examples/yolo-v8/model.rs +++ b/candle-examples/examples/yolo-v8/model.rs @@ -161,7 +161,7 @@ impl C2f { let cv2 = ConvBlock::load(vb.pp("cv2"), (2 + n) * c, c2, 1, 1, None)?; let mut bottleneck = Vec::with_capacity(n); for idx in 0..n { - let b = Bottleneck::load(vb.pp(&format!("bottleneck.{idx}")), c, c, shortcut)?; + let b = Bottleneck::load(vb.pp(format!("bottleneck.{idx}")), c, c, shortcut)?; bottleneck.push(b) } Ok(Self { diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index 2262aa1a8c..354048de97 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -419,7 +419,7 @@ struct BertEncoder { impl BertEncoder { fn load(vb: VarBuilder, config: &Config) -> Result { let layers = (0..config.num_hidden_layers) - .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config)) + .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config)) .collect::>>()?; let span = tracing::span!(tracing::Level::TRACE, "encoder"); Ok(BertEncoder { layers, span }) @@ -454,8 +454,8 @@ impl BertModel { (Err(err), _) | (_, Err(err)) => { if let Some(model_type) = &config.model_type { if let (Ok(embeddings), Ok(encoder)) = ( - BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), - BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), + BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config), + BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config), ) { (embeddings, encoder) } else { diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs index 2e1bbd37af..f6b4a4efdc 100644 --- a/candle-transformers/src/models/bigcode.rs +++ b/candle-transformers/src/models/bigcode.rs @@ -298,7 +298,7 @@ impl GPTBigCode { let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?; let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?; let blocks = (0..cfg.num_hidden_layers) - .map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg)) + .map(|i| Block::load(vb_t.pp(format!("h.{i}")), &cfg)) .collect::>>()?; let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?; let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?; diff --git a/candle-transformers/src/models/distilbert.rs b/candle-transformers/src/models/distilbert.rs index ea074c9782..f899d772a2 100644 --- a/candle-transformers/src/models/distilbert.rs +++ b/candle-transformers/src/models/distilbert.rs @@ -275,7 +275,7 @@ struct Transformer { impl Transformer { fn load(vb: VarBuilder, config: &Config) -> Result { let layers = (0..config.n_layers) - .map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config)) + .map(|index| TransformerBlock::load(vb.pp(format!("layer.{index}")), config)) .collect::>>()?; let span = tracing::span!(tracing::Level::TRACE, "encoder"); Ok(Transformer { layers, span }) @@ -311,8 +311,8 @@ impl DistilBertModel { (Err(err), _) | (_, Err(err)) => { if let Some(model_type) = &config.model_type { if let (Ok(embeddings), 
Ok(encoder)) = ( - Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), - Transformer::load(vb.pp(&format!("{model_type}.transformer")), config), + Embeddings::load(vb.pp(format!("{model_type}.embeddings")), config), + Transformer::load(vb.pp(format!("{model_type}.transformer")), config), ) { (embeddings, encoder) } else { diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs index 3a3575aac2..50ec66f316 100644 --- a/candle-transformers/src/models/falcon.rs +++ b/candle-transformers/src/models/falcon.rs @@ -448,7 +448,7 @@ impl Falcon { vb.pp("transformer.word_embeddings"), )?; let blocks = (0..cfg.num_hidden_layers) - .map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg)) + .map(|i| FalconDecoderLayer::load(vb.pp(format!("transformer.h.{i}")), &cfg)) .collect::>>()?; let ln_f = layer_norm( cfg.hidden_size, diff --git a/candle-transformers/src/models/jina_bert.rs b/candle-transformers/src/models/jina_bert.rs index a9ae37e914..1f0fae1ee4 100644 --- a/candle-transformers/src/models/jina_bert.rs +++ b/candle-transformers/src/models/jina_bert.rs @@ -344,7 +344,7 @@ impl BertEncoder { candle::bail!("only alibi is supported as a position-embedding-type") } let layers = (0..cfg.num_hidden_layers) - .map(|index| BertLayer::new(vb.pp(&format!("layer.{index}")), cfg)) + .map(|index| BertLayer::new(vb.pp(format!("layer.{index}")), cfg)) .collect::>>()?; let span = tracing::span!(tracing::Level::TRACE, "encoder"); let alibi = build_alibi_bias(cfg)?.to_device(vb.device())?; diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index 3681472be8..e96bb855b2 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -507,7 +507,7 @@ impl Llama { let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.num_hidden_layers) - .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap()) + .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap()) .collect(); Ok(Self { diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index ac9bed2850..91d298e97f 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -360,7 +360,7 @@ impl Llama { let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; let ln_f = rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.n_layers) - .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap()) + .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap()) .collect(); Ok(Self { wte, diff --git a/candle-transformers/src/models/moondream.rs b/candle-transformers/src/models/moondream.rs index e63fcf6e74..cde59d43d6 100644 --- a/candle-transformers/src/models/moondream.rs +++ b/candle-transformers/src/models/moondream.rs @@ -167,7 +167,7 @@ impl VisionTransformer { let blocks = (0..cfg.num_blocks) .map(|i| { VitBlock::new( - vb.pp(&format!("blocks.{}", i)), + vb.pp(format!("blocks.{}", i)), cfg.embed_dim, cfg.num_heads, cfg, diff --git a/candle-transformers/src/models/segformer.rs b/candle-transformers/src/models/segformer.rs index 3727e00427..260ceb3a84 100644 --- a/candle-transformers/src/models/segformer.rs +++ b/candle-transformers/src/models/segformer.rs @@ -404,7 +404,7 @@ impl 
SegformerEncoder { stride, num_channels, hidden_size, - vb.pp(&format!("patch_embeddings.{}", i)), + vb.pp(format!("patch_embeddings.{}", i)), )?); let mut layers = Vec::with_capacity(config.depths[i]); for j in 0..config.depths[i] { @@ -417,14 +417,14 @@ impl SegformerEncoder { num_attention_heads, sequence_reduction_ratio, mlp_ratio, - vb.pp(&format!("block.{}.{}", i, j)), + vb.pp(format!("block.{}.{}", i, j)), )?); } blocks.push(layers); layer_norms.push(layer_norm( hidden_size, config.layer_norm_eps, - vb.pp(&format!("layer_norm.{}", i)), + vb.pp(format!("layer_norm.{}", i)), )?); } Ok(Self { @@ -507,7 +507,7 @@ impl SegformerDecodeHead { linear_c.push(SegformerMLP::new( config, hidden_size, - vb.pp(&format!("linear_c.{}", i)), + vb.pp(format!("linear_c.{}", i)), )?); } let linear_fuse = conv2d_no_bias( diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 21517d64b5..84e072a294 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -659,7 +659,7 @@ struct T5Stack { impl T5Stack { fn load(decoder: bool, vb: VarBuilder, shared: &Arc, cfg: &Config) -> Result { let block = (0..cfg.num_layers) - .map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg)) + .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg)) .collect::>>()?; let final_layer_norm = T5LayerNorm::load( cfg.d_model, diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs index 593ed373da..dc50e0dbc3 100644 --- a/candle-transformers/src/models/whisper/model.rs +++ b/candle-transformers/src/models/whisper/model.rs @@ -260,7 +260,7 @@ impl AudioEncoder { let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?; let blocks = (0..cfg.encoder_layers) .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) + ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}"))) }) .collect::>>()?; let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; @@ -321,7 +321,7 @@ impl TextDecoder { let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; let blocks = (0..cfg.decoder_layers) .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) + ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}"))) }) .collect::>>()?; let ln = layer_norm(n_state, vb.pp("layer_norm"))?; diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs index f640bbd08e..dae6cb4fdc 100644 --- a/candle-wasm-examples/llama2-c/src/model.rs +++ b/candle-wasm-examples/llama2-c/src/model.rs @@ -290,7 +290,7 @@ impl Llama { let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; let norm = candle_nn::rms_norm_non_quant(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.n_layers) - .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap()) + .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, cfg).unwrap()) .collect(); Ok(Self::new(wte, blocks, norm, lm_head)) } diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs index f1d7ea2056..ee98c1256f 100644 --- a/candle-wasm-examples/yolo/src/model.rs +++ b/candle-wasm-examples/yolo/src/model.rs @@ -155,7 +155,7 @@ impl C2f { let cv2 = ConvBlock::load(vb.pp("cv2"), (2 + n) * c, c2, 1, 1, None)?; let mut bottleneck = Vec::with_capacity(n); for idx in 0..n { - let 
b = Bottleneck::load(vb.pp(&format!("bottleneck.{idx}")), c, c, shortcut)?; + let b = Bottleneck::load(vb.pp(format!("bottleneck.{idx}")), c, c, shortcut)?; bottleneck.push(b) } Ok(Self { From ad84486241dc646487b562704f1aed40c5ef5f36 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Tue, 10 Sep 2024 20:45:52 -0400 Subject: [PATCH 67/75] Improve candle_core::Error to make it more ergonomic (#21) * Bump the version to 0.6.1. (#2438) * onnx: workaround pow with negative base (#2439) * onnx: workaround pow with negative base rather than fully defining pow in the cpu backend (as in #2318), this implements a much smaller change which is sufficient to evaluate silero-vad onnx models. Specifically, checking if pow is run with 2.0 exponent, and if so evaluate as simply `x*x` instead of the cpu backend of `e^(2.0 * ln(x))`. * PR: use Tensor::powf insead powf correctly handles a negative base. * onnx: support negative index in Gather (#2440) index_select does not support negative indexing, but this change adds just enough workarounds in onnx to allow evaluating silero-vad models (which make use of negative indices). * silero-vad v5 example (#2321) * silero-vad v5 example This change adds an example of how to run silero-vad v5 * PR: rename 'vad' to 'silero-vad' * Update README.md --------- Co-authored-by: Laurent Mazare * Fix for parler-tts, do not add the last slice of padding tokens. (#2442) * Fix for parler-tts, do not add the last slice of padding tokens. * Support for the mini model. * Add FastViT model. (#2444) * fix: qwen2 lm_head loading #2443 (#2445) Co-authored-by: Yi Xu * Update cudarc to 0.12. (#2451) * Update cudarc to 0.12. * Some cudnn tweaks. * FastViT fixes. (#2452) * correct optional SE layer dimensions. * head_dim instead of num_heads is 32. * update test example output. * MobileCLIP models S1 and S2 (#2454) * Allow loading images with given std and mean * OpenCLIP text encoder component * Two MobileCLIP models * Clippy fixes. --------- Co-authored-by: Laurent * Fix FLUX.1 weights (#2457) * fix FLUX.1 weights * added flux1-dev.safetensors * Clippy fixes for 1.81.0. (#2461) * Clippy fixes for 1.81.0. * Another fix. * Make Error::msg more in line with anyhow::Error::msg * Add context trait * Even more flexible * Format --------- Co-authored-by: Laurent Mazare Co-authored-by: shua Co-authored-by: Jani Monoses Co-authored-by: ilookee Co-authored-by: Yi Xu Co-authored-by: Eugene Hauptmann --- candle-core/src/error.rs | 106 ++++++++++++++++++++++++++++++++++++++- candle-core/src/lib.rs | 2 +- 2 files changed, 105 insertions(+), 3 deletions(-) diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 67b391451b..66f9fd4175 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -1,3 +1,8 @@ +use std::{ + convert::Infallible, + fmt::{Debug, Display}, +}; + use crate::{DType, DeviceLocation, Layout, MetalError, Shape}; #[derive(Debug, Clone)] @@ -210,6 +215,13 @@ pub enum Error { #[error(transparent)] Wrapped(Box), + /// Arbitrary errors wrapping with context. + #[error("{wrapped:?}\n{context:?}")] + WrappedContext { + wrapped: Box, + context: String, + }, + /// Adding path information to an error. #[error("path: {path:?} {inner}")] WithPath { @@ -231,14 +243,21 @@ pub enum Error { pub type Result = std::result::Result; impl Error { + /// Create a new error by wrapping another. 
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self { Self::Wrapped(Box::new(err)).bt() } - pub fn msg(err: impl std::error::Error) -> Self { - Self::Msg(err.to_string()).bt() + /// Create a new error based on a printable error message. + /// + /// If the message implements `std::error::Error`, prefer using [`Error::wrap`] instead. + pub fn msg(msg: M) -> Self { + Self::Msg(msg.to_string()).bt() } + /// Create a new error based on a debuggable error message. + /// + /// If the message implements `std::error::Error`, prefer using [`Error::wrap`] instead. pub fn debug(err: impl std::fmt::Debug) -> Self { Self::Msg(format!("{err:?}")).bt() } @@ -283,3 +302,86 @@ pub fn zip(r1: Result, r2: Result) -> Result<(T, U)> { (_, Err(e)) => Err(e), } } + +pub(crate) mod private { + pub trait Sealed {} + + impl Sealed for std::result::Result where E: std::error::Error {} + impl Sealed for Option {} +} + +/// Attach more context to an error. +/// +/// Inspired by [`anyhow::Context`]. +pub trait Context: private::Sealed { + /// Wrap the error value with additional context. + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static; + + /// Wrap the error value with additional context that is evaluated lazily + /// only once an error does occur. + fn with_context(self, f: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C; +} + +impl Context for std::result::Result +where + E: std::error::Error + Send + Sync + 'static, +{ + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using map_err to save 2 useless frames off the captured backtrace + // in ext_context. + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context.to_string(), + }), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context().to_string(), + }), + } + } +} + +impl Context for Option { + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using ok_or_else to save 2 useless frames off the captured + // backtrace. 
+ match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context)), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context())), + } + } +} diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 4861294ac4..6bd48c81c2 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -81,7 +81,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef}; pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}; pub use device::{Device, DeviceLocation, NdArray}; pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; -pub use error::{Error, Result}; +pub use error::{Context, Error, Result}; pub use indexer::IndexOp; pub use layout::Layout; pub use shape::{Shape, D}; From 7f5a47040e798f0076014c9d9e82cc6cb25708a0 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Wed, 11 Sep 2024 06:04:36 -0400 Subject: [PATCH 68/75] Add API to get current device seed (#22) * Add api to get current seed * Remove cell for rwlock --- .vscode/settings.json | 2 +- candle-core/src/backend.rs | 1 + candle-core/src/cpu_backend/mod.rs | 4 ++++ candle-core/src/cuda_backend/device.rs | 9 ++++++++- candle-core/src/device.rs | 9 +++++++++ candle-core/src/dummy_cuda_backend.rs | 4 ++++ candle-core/src/dummy_metal_backend.rs | 4 ++++ candle-core/src/metal_backend/device.rs | 2 ++ candle-core/src/metal_backend/mod.rs | 7 +++++++ 9 files changed, 40 insertions(+), 2 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index f9b6ef02f9..6abf0d3d6d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -9,6 +9,6 @@ "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, "rust-analyzer.cargo.features": [ - "cuda", "flash-attn", + "cuda", ], } \ No newline at end of file diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 40fd63f1ec..655c7894d8 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -158,6 +158,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone { fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result; fn set_seed(&self, _: u64) -> Result<()>; + fn get_current_seed(&self) -> Result; /// Synchronize should block until all the operations on the device are completed. 
fn synchronize(&self) -> Result<()>; diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 75c6e7bd38..54d2da7d12 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -3313,6 +3313,10 @@ impl BackendDevice for CpuDevice { crate::bail!("cannot seed the CPU rng with set_seed") } + fn get_current_seed(&self) -> Result { + crate::bail!("cannot get the CPU rng seed with get_current_seed") + } + fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result { use rand::prelude::*; diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 9e0b64067b..9dd4477639 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -4,7 +4,7 @@ pub use candle_kernels as kernels; pub use cudarc; use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig}; use half::{bf16, f16}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr}; @@ -30,6 +30,7 @@ pub struct CudaDevice { device: Arc, pub(crate) blas: Arc, curand: Arc>, + seed_value: Arc>, } impl std::fmt::Debug for CudaDevice { @@ -168,6 +169,7 @@ impl BackendDevice for CudaDevice { device, blas: Arc::new(blas), curand: Arc::new(Mutex::new(CudaRng(curand))), + seed_value: Arc::new(RwLock::new(299792458)), }) } @@ -176,9 +178,14 @@ impl BackendDevice for CudaDevice { // state will be identical and the same random numbers will be generated. let mut curand = self.curand.lock().unwrap(); curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?; + *self.seed_value.write().unwrap() = seed; Ok(()) } + fn get_current_seed(&self) -> Result { + Ok(*self.seed_value.read().unwrap()) + } + fn location(&self) -> crate::DeviceLocation { crate::DeviceLocation::Cuda { gpu_id: self.device.ordinal(), diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 52e3e22812..55384d0b7b 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -142,6 +142,15 @@ impl Device { } } + /// Get the current seed for the device RNG. 
+ pub fn get_current_seed(&self) -> Result { + match self { + Self::Cpu => CpuDevice.get_current_seed(), + Self::Cuda(c) => c.get_current_seed(), + Self::Metal(m) => m.get_current_seed(), + } + } + pub fn same_device(&self, rhs: &Self) -> bool { match (self, rhs) { (Self::Cpu, Self::Cpu) => true, diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 26d9b2f629..9fa1970b00 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -208,6 +208,10 @@ impl crate::backend::BackendDevice for CudaDevice { Err(Error::NotCompiledWithCudaSupport) } + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn location(&self) -> crate::DeviceLocation { fail!() } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index 2ec89f97a5..2a3ea93c03 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -220,6 +220,10 @@ impl crate::backend::BackendDevice for MetalDevice { Err(Error::NotCompiledWithMetalSupport) } + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + fn location(&self) -> crate::DeviceLocation { fail!() } diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 07210c68c1..1c4e3f0171 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -70,6 +70,8 @@ pub struct MetalDevice { pub(crate) buffers: AllocatedBuffers, /// Seed for random number generation. pub(crate) seed: Arc>, + /// Value of the current seed + pub(crate) seed_value: Arc>, } impl std::fmt::Debug for MetalDevice { diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 194c5a6254..58aacdfe38 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1944,6 +1944,7 @@ impl BackendDevice for MetalDevice { buffers, kernels, seed, + seed_value: Arc::new(RwLock::new(299792458)), }) } @@ -2105,9 +2106,15 @@ impl BackendDevice for MetalDevice { } seed_buffer.did_modify_range(metal::NSRange::new(0, 4)); + *self.seed_value.write().unwrap() = seed; + Ok(()) } + fn get_current_seed(&self) -> Result { + Ok(*self.seed_value.read().unwrap()) + } + fn synchronize(&self) -> Result<()> { self.wait_until_completed() } From 9240d03a8bd322eb38794f849a020bf7f89e20e0 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Thu, 12 Sep 2024 22:00:03 -0400 Subject: [PATCH 69/75] Add QStorage::data for cuda and metal (#23) --- candle-core/src/quantized/cuda.rs | 4 ++++ candle-core/src/quantized/dummy_cuda.rs | 4 ++++ candle-core/src/quantized/dummy_metal.rs | 4 ++++ candle-core/src/quantized/metal.rs | 16 ++++++++++++++++ candle-core/src/quantized/mod.rs | 5 ++--- 5 files changed, 30 insertions(+), 3 deletions(-) diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 426b818c1d..09f44d570e 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -493,6 +493,10 @@ impl QCudaStorage { self.dequantize_matmul(self_shape, storage, layout) } } + + pub fn data(&self) -> Result> { + self.device.dtoh_sync_copy(&self.data.slice(..)).w() + } } impl QCudaStorage { diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs index 69daad3cc4..23a9e05bc2 100644 --- a/candle-core/src/quantized/dummy_cuda.rs +++ 
b/candle-core/src/quantized/dummy_cuda.rs @@ -48,6 +48,10 @@ impl QCudaStorage { ) -> Result<(CudaStorage, crate::Shape)> { Err(Error::NotCompiledWithCudaSupport) } + + pub fn data(&self) -> Result> { + Err(Error::NotCompiledWithCudaSupport) + } } pub fn load_quantized( diff --git a/candle-core/src/quantized/dummy_metal.rs b/candle-core/src/quantized/dummy_metal.rs index fc51214c19..c5c8db9282 100644 --- a/candle-core/src/quantized/dummy_metal.rs +++ b/candle-core/src/quantized/dummy_metal.rs @@ -44,6 +44,10 @@ impl QMetalStorage { ) -> Result<(MetalStorage, crate::Shape)> { Err(Error::NotCompiledWithMetalSupport) } + + pub fn data(&self) -> Result> { + Err(Error::NotCompiledWithMetalSupport) + } } pub fn load_quantized( diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index bc43658d13..8eb38c849c 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -206,6 +206,22 @@ impl QMetalStorage { let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32); Ok((dst_storage, dst_shape)) } + + pub fn data(&self) -> Result> { + let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger; + + let buffer = self.device.new_buffer_managed(size)?; + { + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("to_cpu"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("blit_to_cpu"); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size); + blit.end_encoding(); + } + self.device.wait_until_completed()?; + Ok(read_to_vec::(&buffer, self.count)) + } } pub fn load_quantized( diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 39a30d12c5..7f8dbfcf2a 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -134,9 +134,8 @@ impl QStorage { let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) }; Ok(Cow::from(data)) } - QStorage::Metal(_) | QStorage::Cuda(_) => { - crate::bail!("not implemented"); - } + QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)), + QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)), } } } From 8a99f7cf31a1d8f175281492eaa7026730067d08 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Fri, 13 Sep 2024 11:50:03 -0400 Subject: [PATCH 70/75] Fix build error with seed (#25) --- candle-core/src/metal_backend/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 58aacdfe38..19bac09e15 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -2106,7 +2106,7 @@ impl BackendDevice for MetalDevice { } seed_buffer.did_modify_range(metal::NSRange::new(0, 4)); - *self.seed_value.write().unwrap() = seed; + *self.seed_value.write().unwrap() = seed as u64; Ok(()) } From 9e31a192642b4048e3df75173efddebaf663fef2 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 14 Sep 2024 21:08:54 -0300 Subject: [PATCH 71/75] Add the i16 dtype (2) (#26) * Add the i16 dtype * Added I16 and I32 to fix the missing arms issue (candle-onnx/eval) * Update rust-ci.yml * Update ci_cuda.yaml * fmt adjustment * Revert "Update rust-ci.yml" This reverts commit f659d36aed9e6e7ab7377c408a3859b8c8b94908. * Revert "Update ci_cuda.yaml" This reverts commit 62a4b3977e24bc7ac60195a6ae4363df36127125. 
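For reference, a minimal sketch (not part of this patch) of what the new `I16` dtype enables once the backend changes below are in place; it assumes `i16` also receives the usual `WithDType`/`NdArray` plumbing so that `Tensor::new` and `to_vec1` accept it:

```
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Build an i16 tensor directly from host data.
    let t = Tensor::new(&[1i16, 2, 3, 4], &device)?;
    assert_eq!(t.dtype(), DType::I16);
    // Round-trip through a cast: i16 -> f32 -> i16.
    let back = t.to_dtype(DType::F32)?.to_dtype(DType::I16)?;
    assert_eq!(back.to_vec1::<i16>()?, vec![1i16, 2, 3, 4]);
    Ok(())
}
```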
--- candle-core/src/convert.rs | 5 + candle-core/src/cpu/kernels.rs | 11 +++ candle-core/src/cpu_backend/mod.rs | 124 +++++++++++++++++++++++- candle-core/src/cpu_backend/utils.rs | 2 + candle-core/src/cuda_backend/device.rs | 64 +++++++++--- candle-core/src/cuda_backend/mod.rs | 55 ++++++++++- candle-core/src/cuda_backend/utils.rs | 2 + candle-core/src/display.rs | 7 ++ candle-core/src/dtype.rs | 19 +++- candle-core/src/metal_backend/mod.rs | 69 +++++++++++++ candle-core/src/npy.rs | 6 ++ candle-core/src/op.rs | 56 +++++++++++ candle-core/src/safetensors.rs | 4 + candle-core/src/sort.rs | 1 + candle-core/tests/tensor_tests.rs | 6 +- candle-kernels/src/affine.cu | 1 + candle-kernels/src/binary.cu | 12 +++ candle-kernels/src/cast.cu | 14 +++ candle-kernels/src/cuda_utils.cuh | 2 + candle-kernels/src/fill.cu | 2 + candle-kernels/src/indexing.cu | 52 ++++++++++ candle-kernels/src/reduce.cu | 1 + candle-kernels/src/sort.cu | 1 + candle-kernels/src/ternary.cu | 12 +++ candle-kernels/src/unary.cu | 1 + candle-metal-kernels/src/binary.metal | 2 + candle-metal-kernels/src/cast.metal | 18 ++++ candle-metal-kernels/src/indexing.metal | 26 ++++- candle-metal-kernels/src/lib.rs | 7 ++ candle-metal-kernels/src/reduce.metal | 6 ++ candle-metal-kernels/src/sort.metal | 1 + candle-metal-kernels/src/ternary.metal | 15 +++ candle-metal-kernels/src/unary.metal | 3 + candle-onnx/src/eval.rs | 4 +- candle-pyo3/src/lib.rs | 2 + 35 files changed, 586 insertions(+), 27 deletions(-) diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs index b29ff346f6..3e19d970c3 100644 --- a/candle-core/src/convert.rs +++ b/candle-core/src/convert.rs @@ -130,6 +130,11 @@ impl Tensor { f.write_u32::(v)? } } + DType::I16 => { + for v in vs.to_vec1::()? { + f.write_i16::(v)? + } + } DType::I32 => { for v in vs.to_vec1::()? { f.write_i32::(v)? diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs index fd6da1f1ff..f81ad625d3 100644 --- a/candle-core/src/cpu/kernels.rs +++ b/candle-core/src/cpu/kernels.rs @@ -151,6 +151,17 @@ impl VecOps for u32 { ::max(self, other) } } +impl VecOps for i16 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} impl VecOps for i32 { #[inline(always)] fn min(self, other: Self) -> Self { diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 54d2da7d12..24ce83581c 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -22,6 +22,7 @@ const USE_IM2COL_CONV2D: bool = true; pub enum CpuStorage { U8(Vec), U32(Vec), + I16(Vec), I32(Vec), I64(Vec), BF16(Vec), @@ -34,6 +35,7 @@ pub enum CpuStorage { pub enum CpuStorageRef<'a> { U8(&'a [u8]), U32(&'a [u32]), + I16(&'a [i16]), I32(&'a [i32]), I64(&'a [i64]), BF16(&'a [bf16]), @@ -2287,6 +2289,17 @@ impl CpuStorage { .concat(); Self::U32(storages) } + Self::I16(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I16(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? 
+ .concat(); + Self::I16(storages) + } Self::I32(_) => { let storages = storages .iter() @@ -2365,6 +2378,7 @@ impl BackendStorage for CpuStorage { match self { Self::U8(_) => DType::U8, Self::U32(_) => DType::U32, + Self::I16(_) => DType::I16, Self::I32(_) => DType::I32, Self::I64(_) => DType::I64, Self::BF16(_) => DType::BF16, @@ -2385,6 +2399,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); Ok(Self::BF16(data)) } + (Self::I16(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } (Self::I32(storage), DType::BF16) => { let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); Ok(Self::BF16(data)) @@ -2417,6 +2435,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) } + (Self::I16(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } (Self::I32(storage), DType::F16) => { let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) @@ -2449,6 +2471,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } + (Self::I16(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } (Self::I32(storage), DType::F32) => { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) @@ -2497,6 +2523,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) } + (Self::I16(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } (Self::I32(storage), DType::U8) => { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) @@ -2513,6 +2543,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::U32(data)) } + (Self::I16(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } (Self::I32(storage), DType::U32) => { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) @@ -2537,6 +2571,42 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } + (Self::U8(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::U32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::I16(data)) + } + (Self::I32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::I64(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::BF16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + (Self::F16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + (Self::F32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::F64(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } (Self::U8(storage), DType::I32) => { let data = unary_map(storage, 
layout, |v| v as i64); Ok(Self::I64(data)) @@ -2545,6 +2615,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) } + (Self::I16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } (Self::I32(storage), DType::I32) => { let data = unary_map(storage, layout, |v| v); Ok(Self::I32(data)) @@ -2577,6 +2651,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) } + (Self::I16(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } (Self::I32(storage), DType::I64) => { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) @@ -2609,6 +2687,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) } + (Self::I16(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } (Self::I32(storage), DType::F64) => { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) @@ -2748,6 +2830,7 @@ impl BackendStorage for CpuStorage { } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), + Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()), Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), } @@ -2774,7 +2857,8 @@ impl BackendStorage for CpuStorage { } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), - Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), + Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), } } @@ -2825,6 +2909,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, B::u32); Ok(Self::U32(data)) } + Self::I16(storage) => { + let data = unary_map(storage, layout, B::i16); + Ok(Self::I16(data)) + } Self::I32(storage) => { let data = unary_map(storage, layout, B::i32); Ok(Self::I32(data)) @@ -2883,6 +2971,14 @@ impl BackendStorage for CpuStorage { }; Ok(Self::U32(data)) } + (Self::I16(lhs), Self::I16(rhs)) => { + let data = if B::I16_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i16, B::i16_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::i16) + }; + Ok(Self::I16(data)) + } (Self::I32(lhs), Self::I32(rhs)) => { let data = if B::I32_VEC { binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i32, B::i32_vec) @@ -2934,6 +3030,9 @@ impl BackendStorage for CpuStorage { (Self::U32(src), Self::U32(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } + (Self::I16(src), Self::I16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } (Self::I32(src), Self::I32(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } @@ -2968,6 +3067,7 @@ impl BackendStorage for CpuStorage { match (self, dst) { (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I16(src), Self::I16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::I32(src), Self::I32(dst)) => 
copy_strided_src_(src, dst, dst_offset, src_l), (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), @@ -2998,6 +3098,7 @@ impl BackendStorage for CpuStorage { match self { Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I16(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")), @@ -3169,6 +3270,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + Self::I16(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::I32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()), @@ -3179,6 +3281,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l), + Self::I16(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::I32(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()), @@ -3197,6 +3300,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + Self::I16(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::I32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()), @@ -3227,6 +3331,13 @@ impl BackendStorage for CpuStorage { }; IndexAdd { ids, dim }.map(self, l, src, src_l) } + Self::I16(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } Self::I32(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], @@ -3323,7 +3434,7 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::I32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()) } DType::BF16 => { @@ -3369,7 +3480,7 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::I32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) } DType::BF16 => { @@ -3428,6 +3539,11 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::U32(v) } + DType::I16 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I16(v) + } DType::I32 => { let mut v = Vec::with_capacity(elem_count); v.set_len(elem_count); @@ -3467,6 +3583,7 @@ impl BackendDevice for 
CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![1u8; elem_count]), DType::U32 => CpuStorage::U32(vec![1u32; elem_count]), + DType::I16 => CpuStorage::I16(vec![1i16; elem_count]), DType::I32 => CpuStorage::I32(vec![1i32; elem_count]), DType::I64 => CpuStorage::I64(vec![1i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]), @@ -3482,6 +3599,7 @@ impl BackendDevice for CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![0u8; elem_count]), DType::U32 => CpuStorage::U32(vec![0u32; elem_count]), + DType::I16 => CpuStorage::I16(vec![0i16; elem_count]), DType::I32 => CpuStorage::I32(vec![0i32; elem_count]), DType::I64 => CpuStorage::I64(vec![0i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]), diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 297ccd3de6..20f362e8c4 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -10,6 +10,7 @@ pub trait Map1 { match vs { C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)), C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)), + C::I16(vs) => Ok(C::I16(self.f(vs, layout)?)), C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)), C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)), C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)), @@ -27,6 +28,7 @@ pub trait Map1Any { match vs { C::U8(vs) => Ok(self.f(vs, layout, C::U8)?), C::U32(vs) => Ok(self.f(vs, layout, C::U32)?), + C::I16(vs) => Ok(self.f(vs, layout, C::I16)?), C::I32(vs) => Ok(self.f(vs, layout, C::I32)?), C::I64(vs) => Ok(self.f(vs, layout, C::I64)?), C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?), diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 9dd4477639..ccca8c039c 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -80,6 +80,14 @@ impl CudaDevice { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U32(data) } + DType::I16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_i16", kernels::FILL)?; + let params = (&data, v as i16, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I16(data) + } DType::I32 => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count) }.w()?; @@ -207,6 +215,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::U32(data) } + DType::I16 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::I16(data) + } DType::I32 => { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::I32(data) @@ -244,13 +256,17 @@ impl BackendDevice for CudaDevice { let slice = match dtype { // TODO: Add support for F16 and BF16 though this is likely to require some upstream // cudarc changes. - DType::U8 | DType::U32 | DType::I64 | DType::I32 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_uniform", - }) - .w()? 
- } + DType::U8 + | DType::U32 + | DType::I64 + | DType::I32 + | DType::I16 + | DType::F16 + | DType::BF16 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_uniform", + }) + .w()?, DType::F32 => { let mut data = unsafe { self.alloc::(elem_count) }.w()?; curand.0.fill_with_uniform(&mut data).w()?; @@ -288,13 +304,17 @@ impl BackendDevice for CudaDevice { elem_count }; let slice = match dtype { - DType::U8 | DType::U32 | DType::I32 | DType::I64 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_normal", - }) - .w()? - } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F16 + | DType::BF16 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()?, DType::F32 => { let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; curand @@ -330,6 +350,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::U32(data) } + DType::I16 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::I16(data) + } DType::I32 => { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::I32(data) @@ -371,6 +395,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorageRef::I16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I16(data) + } CpuStorageRef::I32(storage) => { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::I32(data) @@ -412,6 +440,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorage::I16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I16(data) + } CpuStorage::I32(storage) => { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::I32(data) @@ -453,6 +485,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorage::I16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I16(data) + } CpuStorage::I32(storage) => { let data = self.htod_copy(storage).w()?; CudaStorageSlice::I32(data) diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 231e24715c..1a394d4b58 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -47,6 +47,7 @@ impl SlicePtrOrNull { pub enum CudaStorageSlice { U8(CudaSlice), U32(CudaSlice), + I16(CudaSlice), I32(CudaSlice), I64(CudaSlice), BF16(CudaSlice), @@ -364,6 +365,9 @@ impl<'a> Map1 for IndexSelect<'a> { CudaStorageSlice::U8(slice) => { ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr()) } + CudaStorageSlice::I16(slice) => { + ("is_i16", *slice.slice(ids_l.start_offset()..).device_ptr()) + } CudaStorageSlice::I32(slice) => { ("is_i32", *slice.slice(ids_l.start_offset()..).device_ptr()) } @@ -371,7 +375,7 @@ impl<'a> Map1 for IndexSelect<'a> { ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr()) } _ => Err(CudaError::UnexpectedDType { - msg: "index_select ids should be u8/u32/i32/i64", + msg: "index_select ids should be u8/u32/i16/i32/i64", expected: DType::U32, got: self.0.dtype(), }) @@ -431,6 +435,9 @@ impl<'a> Map1 for Gather<'a> { ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr()) } CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I16(slice) => { + ("gather_i16", *slice.slice(ids_o1..ids_o2).device_ptr()) + } 
CudaStorageSlice::I32(slice) => { ("gather_i32", *slice.slice(ids_o1..ids_o2).device_ptr()) } @@ -438,7 +445,7 @@ impl<'a> Map1 for Gather<'a> { ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr()) } _ => Err(CudaError::UnexpectedDType { - msg: "gather ids should be u8/u32/i32/i64", + msg: "gather ids should be u8/u32/i16/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -484,11 +491,12 @@ impl<'a> Map2InPlace for IndexAdd<'a> { }; let (name, ids) = match &ids.slice { CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I16(slice) => ("ia_i16", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I32(slice) => ("ia_i32", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), _ => Err(CudaError::UnexpectedDType { - msg: "index-add ids should be u8/u32/i32/i64", + msg: "index-add ids should be u8/u32/i16/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -533,11 +541,12 @@ impl<'a> Map2InPlace for ScatterAdd<'a> { }; let (name, ids) = match &ids.slice { CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I16(slice) => ("sa_i16", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I32(slice) => ("sa_i32", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), _ => Err(CudaError::UnexpectedDType { - msg: "scatter-add ids should be u8/u32/i32/i64", + msg: "scatter-add ids should be u8/u32/i16/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -876,6 +885,10 @@ impl<'a> Map2 for WhereCond<'a> { let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); (ptr, "where_u32") } + CudaStorageSlice::I16(slice) => { + let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); + (ptr, "where_i16") + } CudaStorageSlice::I32(slice) => { let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); (ptr, "where_i32") @@ -885,7 +898,7 @@ impl<'a> Map2 for WhereCond<'a> { (ptr, "where_i64") } _ => Err(CudaError::UnexpectedDType { - msg: "where conditions should be u8/u32/i64", + msg: "where conditions should be u8/u32/i16/i32/i64", expected: DType::U32, got: self.0.dtype(), }) @@ -1039,6 +1052,7 @@ macro_rules! 
cuda_dtype { } cuda_dtype!(u8, U8); cuda_dtype!(u32, U32); +cuda_dtype!(i16, I16); cuda_dtype!(i32, I32); cuda_dtype!(i64, I64); cuda_dtype!(f16, F16); @@ -1162,6 +1176,7 @@ impl BackendStorage for CudaStorage { match self.slice { CudaStorageSlice::U8(_) => DType::U8, CudaStorageSlice::U32(_) => DType::U32, + CudaStorageSlice::I16(_) => DType::I16, CudaStorageSlice::I32(_) => DType::I32, CudaStorageSlice::I64(_) => DType::I64, CudaStorageSlice::BF16(_) => DType::BF16, @@ -1189,6 +1204,7 @@ impl BackendStorage for CudaStorage { let inp = match &self.slice { CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::I16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::I32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(), @@ -1213,6 +1229,12 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U32(out) } + DType::I16 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I16(out) + } DType::I32 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); @@ -1315,6 +1337,11 @@ impl BackendStorage for CudaStorage { let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::U32(cpu_storage)) } + CudaStorageSlice::I16(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::I16(cpu_storage)) + } CudaStorageSlice::I32(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; @@ -1587,6 +1614,7 @@ impl BackendStorage for CudaStorage { S::F64(out) } (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?, + (S::I16(_), S::I16(_)) => Err(CudaError::InternalError("conv2d does not support i16"))?, (S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv2d does not support i32"))?, (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?, _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?, @@ -1854,6 +1882,11 @@ impl BackendStorage for CudaStorage { *d.slice(dst_o..).device_ptr(), "copy2d_u32", ), + (S::I16(s), S::I16(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_i16", + ), (S::I32(s), S::I32(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), @@ -1965,6 +1998,18 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()? } } + (CudaStorageSlice::I16(src), CudaStorageSlice::I16(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_i16", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? 
+ } + } (CudaStorageSlice::I32(src), CudaStorageSlice::I32(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index ae009b26ab..df06756d78 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -19,6 +19,7 @@ pub trait Map1 { let out = match s { S::U8(s) => S::U8(self.f(s, d, l)?), S::U32(s) => S::U32(self.f(s, d, l)?), + S::I16(s) => S::I16(self.f(s, d, l)?), S::I32(s) => S::I32(self.f(s, d, l)?), S::I64(s) => S::I64(self.f(s, d, l)?), S::BF16(s) => S::BF16(self.f(s, d, l)?), @@ -137,6 +138,7 @@ pub trait Map1Any { let out = match s { S::U8(s) => self.f(s, d, l, S::U8)?, S::U32(s) => self.f(s, d, l, S::U32)?, + S::I16(s) => self.f(s, d, l, S::I16)?, S::I32(s) => self.f(s, d, l, S::I32)?, S::I64(s) => self.f(s, d, l, S::I64)?, S::BF16(s) => self.f(s, d, l, S::BF16)?, diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 5fb370b696..50e0129aeb 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -55,6 +55,7 @@ impl std::fmt::Debug for Tensor { match self.dtype() { DType::U8 => self.fmt_dt::(f), DType::U32 => self.fmt_dt::(f), + DType::I16 => self.fmt_dt::(f), DType::I32 => self.fmt_dt::(f), DType::I64 => self.fmt_dt::(f), DType::BF16 => self.fmt_dt::(f), @@ -464,6 +465,12 @@ impl std::fmt::Display for Tensor { tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; writeln!(f)?; } + DType::I16 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } DType::I32 => { let tf: IntFormatter = IntFormatter::new(); let max_w = tf.max_width(&to_display); diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index c6a0800b24..42d3b1eef9 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -10,6 +10,8 @@ pub enum DType { U8, // Unsigned 32 bits integer. U32, + // Signed 16 bits integer. + I16, // Signed 32 bits integer. I32, // Signed 64 bits integer. 
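Not part of the patch: a minimal sketch of how the new `DType::I16` variant added in the hunk above can be exercised once `with_dtype!(i16, I16, ...)` is in place. It mirrors the updated `index_select` and `ones` cases in `tensor_tests.rs` further down in this diff; `Tensor::new`, `index_select`, and `to_vec2` are assumed to behave as in upstream candle and are not defined by this patch.

```rust
use candle_core::{DType, Device, Result, Tensor};

fn i16_smoke_test() -> Result<()> {
    let dev = Device::Cpu;
    // i16 now implements WithDType, so tensors can be built from i16 data directly.
    let ids = Tensor::new(&[0i16, 2i16], &dev)?;
    assert_eq!(ids.dtype(), DType::I16);

    // I16 ids are accepted by the integer-index dispatch paths, e.g. index_select,
    // matching the dtype loop added to the index_select test below.
    let t = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &dev)?;
    let rows = t.index_select(&ids, 0)?;
    assert_eq!(rows.to_vec2::<f32>()?, [[0., 1., 2.], [6., 7., 8.]]);
    Ok(())
}
```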
@@ -41,6 +43,7 @@ impl std::str::FromStr for DType { match s { "u8" => Ok(Self::U8), "u32" => Ok(Self::U32), + "i16" => Ok(Self::I16), "i32" => Ok(Self::I32), "i64" => Ok(Self::I64), "bf16" => Ok(Self::BF16), @@ -58,6 +61,7 @@ impl DType { match self { Self::U8 => "u8", Self::U32 => "u32", + Self::I16 => "i16", Self::I32 => "i32", Self::I64 => "i64", Self::BF16 => "bf16", @@ -72,6 +76,7 @@ impl DType { match self { Self::U8 => 1, Self::U32 => 4, + Self::I16 => 2, Self::I32 => 4, Self::I64 => 8, Self::BF16 => 2, @@ -83,14 +88,14 @@ impl DType { pub fn is_int(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I32 | Self::I64 => true, + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => true, Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false, } } pub fn is_float(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I32 | Self::I64 => false, + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => false, Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true, } } @@ -174,6 +179,7 @@ use half::{bf16, f16}; with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64); +with_dtype!(i16, I16, |v: f64| v as i16, |v: i16| v as f64); with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64); with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64); with_dtype!(f16, F16, f16::from_f64, f16::to_f64); @@ -186,6 +192,15 @@ pub trait IntDType: WithDType { fn as_usize(&self) -> usize; } +impl IntDType for i16 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + impl IntDType for i32 { fn is_true(&self) -> bool { *self != 0 diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 19bac09e15..bf641eb86a 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -96,6 +96,7 @@ impl BackendStorage for MetalStorage { match self.dtype { DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)), DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)), + DType::I16 => Ok(CpuStorage::I16(self.to_cpu()?)), DType::I32 => Ok(CpuStorage::I32(self.to_cpu()?)), DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)), DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)), @@ -305,6 +306,11 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false), (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true), (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true), + (ReduceOp::Sum, DType::I16) => ("fast_sum_i16_strided", false, false), + (ReduceOp::Min, DType::I16) => ("fast_min_i16_strided", true, false), + (ReduceOp::Max, DType::I16) => ("fast_max_i16_strided", true, false), + (ReduceOp::ArgMin, DType::I16) => ("fast_argmin_i16_strided", true, true), + (ReduceOp::ArgMax, DType::I16) => ("fast_argmax_i16_strided", true, true), (ReduceOp::Sum, DType::I32) => ("fast_sum_i32_strided", false, false), (ReduceOp::Min, DType::I32) => ("fast_min_i32_strided", true, false), (ReduceOp::Max, DType::I32) => ("fast_max_i32_strided", true, false), @@ -369,6 +375,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::BF16) => "cast_u32_bf16", (DType::U32, DType::F16) => "cast_u32_f16", (DType::U32, DType::F32) => "cast_u32_f32", + (DType::U32, DType::I16) => "cast_u32_i16", (DType::U32, DType::I32) => "cast_u32_i32", (DType::U32, DType::I64) => "cast_u32_i64", (DType::U32, DType::U8) => "cast_u32_u8", @@ -376,17 +383,25 @@ impl 
BackendStorage for MetalStorage { (DType::U8, DType::BF16) => "cast_u8_bf16", (DType::U8, DType::F16) => "cast_u8_f16", (DType::U8, DType::F32) => "cast_u8_f32", + (DType::U8, DType::I16) => "cast_u8_i16", (DType::U8, DType::I32) => "cast_u8_i32", (DType::U8, DType::I64) => "cast_u8_i64", (DType::U8, DType::U32) => "cast_u8_u32", (DType::F32, DType::BF16) => "cast_f32_bf16", (DType::F32, DType::F16) => "cast_f32_f16", + (DType::F32, DType::I16) => "cast_f32_i16", (DType::F32, DType::I32) => "cast_f32_i32", (DType::F32, DType::I64) => "cast_f32_i64", (DType::F32, DType::U32) => "cast_f32_u32", (DType::F32, DType::U8) => "cast_f32_u8", + (DType::I16, DType::BF16) => "cast_i16_bf16", + (DType::I16, DType::F16) => "cast_i16_f16", + (DType::I16, DType::F32) => "cast_i16_f32", + (DType::I16, DType::U32) => "cast_i16_u32", + (DType::I16, DType::U8) => "cast_i16_u8", + (DType::I32, DType::BF16) => "cast_i32_bf16", (DType::I32, DType::F16) => "cast_i32_f16", (DType::I32, DType::F32) => "cast_i32_f32", @@ -401,6 +416,7 @@ impl BackendStorage for MetalStorage { (DType::F16, DType::BF16) => "cast_f16_bf16", (DType::F16, DType::F32) => "cast_f16_f32", + (DType::F16, DType::I16) => "cast_f16_i16", (DType::F16, DType::I32) => "cast_f16_i32", (DType::F16, DType::I64) => "cast_f16_i64", (DType::F16, DType::U32) => "cast_f16_u32", @@ -408,6 +424,7 @@ impl BackendStorage for MetalStorage { (DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F32) => "cast_bf16_f32", + (DType::BF16, DType::I16) => "cast_bf16_i16", (DType::BF16, DType::I32) => "cast_bf16_i32", (DType::BF16, DType::I64) => "cast_bf16_i64", (DType::BF16, DType::U32) => "cast_bf16_u32", @@ -431,14 +448,17 @@ impl BackendStorage for MetalStorage { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32_strided", (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U32, DType::I16) => "cast_u32_i16_strided", (DType::U32, DType::I32) => "cast_u32_i32_strided", (DType::U32, DType::I64) => "cast_u32_i64_strided", (DType::U8, DType::U32) => "cast_u8_u32_strided", (DType::U8, DType::F32) => "cast_u8_f32_strided", + (DType::U8, DType::I16) => "cast_u8_i16_strided", (DType::U8, DType::I32) => "cast_u8_i32_strided", (DType::U8, DType::I64) => "cast_u8_i64_strided", (DType::F32, DType::F16) => "cast_f32_f16_strided", (DType::F16, DType::F32) => "cast_f16_f32_strided", + (DType::I16, DType::F32) => "cast_i16_f32_strided", (DType::I32, DType::F32) => "cast_i32_f32_strided", (DType::I64, DType::F32) => "cast_i64_f32_strided", (DType::F32, DType::BF16) => "cast_f32_bf16_strided", @@ -534,6 +554,7 @@ impl BackendStorage for MetalStorage { ("usign", DType::F16) => contiguous_tiled::sign::HALF, ("usign", DType::F32) => contiguous_tiled::sign::FLOAT, ("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT, + ("usign", DType::I16) => contiguous_tiled::sign::I16, ("usign", DType::I32) => contiguous_tiled::sign::I32, ("usign", DType::I64) => contiguous_tiled::sign::I64, (name, dtype) => { @@ -613,6 +634,7 @@ impl BackendStorage for MetalStorage { ("usign", DType::F16) => contiguous::sign::HALF, ("usign", DType::F32) => contiguous::sign::FLOAT, ("usign", DType::BF16) => contiguous::sign::BFLOAT, + ("usign", DType::I16) => contiguous::sign::I16, ("usign", DType::I32) => contiguous::sign::I32, ("usign", DType::I64) => contiguous::sign::I64, (name, dtype) => { @@ -745,6 +767,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "where_u32_f32", (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, 
DType::F16) => "where_u8_f16", + (DType::U8, DType::I16) => "where_u8_i16", (DType::U8, DType::I32) => "where_u8_i32", (DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::U32) => "where_u8_u32", @@ -1283,6 +1306,9 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "sa_u32_f32", (DType::U32, DType::F16) => "sa_u32_f16", (DType::U32, DType::BF16) => "sa_u32_bf16", + (DType::I16, DType::F32) => "sa_i16_f32", + (DType::I16, DType::F16) => "sa_i16_f16", + (DType::I16, DType::BF16) => "sa_i16_bf16", (DType::I32, DType::F32) => "sa_i32_f32", (DType::I32, DType::F16) => "sa_i32_f16", (DType::I32, DType::BF16) => "sa_i32_bf16", @@ -1334,6 +1360,10 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::BF16) => "is_u32_bf16", + (DType::I16, DType::F32) => "is_i16_f32", + (DType::I16, DType::F16) => "is_i16_f16", + (DType::I16, DType::BF16) => "is_i16_bf16", + (DType::I32, DType::F32) => "is_i32_f32", (DType::I32, DType::F16) => "is_i32_f16", (DType::I32, DType::BF16) => "is_i32_bf16", @@ -1383,6 +1413,14 @@ impl BackendStorage for MetalStorage { return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt()); }; let name = match (ids.dtype, self.dtype) { + (DType::I16, DType::BF16) => "ia_i16_bf16", + (DType::I16, DType::F16) => "ia_i16_f16", + (DType::I16, DType::F32) => "ia_i16_f32", + (DType::I16, DType::I32) => "ia_i16_i32", + (DType::I16, DType::I64) => "ia_i16_i64", + (DType::I16, DType::U32) => "ia_i16_u32", + (DType::I16, DType::U8) => "ia_i16_u8", + (DType::I32, DType::BF16) => "ia_i32_bf16", (DType::I32, DType::F16) => "ia_i32_f16", (DType::I32, DType::F32) => "ia_i32_f32", @@ -1394,6 +1432,7 @@ impl BackendStorage for MetalStorage { (DType::I64, DType::BF16) => "ia_i64_bf16", (DType::I64, DType::F16) => "ia_i64_f16", (DType::I64, DType::F32) => "ia_i64_f32", + (DType::I64, DType::I16) => "ia_i64_i16", (DType::I64, DType::I32) => "ia_i64_i32", (DType::I64, DType::I64) => "ia_i64_i64", (DType::I64, DType::U32) => "ia_i64_u32", @@ -1402,6 +1441,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::BF16) => "ia_u32_bf16", (DType::U32, DType::F16) => "ia_u32_f16", (DType::U32, DType::F32) => "ia_u32_f32", + (DType::U32, DType::I16) => "ia_u32_i16", (DType::U32, DType::I32) => "ia_u32_i32", (DType::U32, DType::I64) => "ia_u32_i64", (DType::U32, DType::U32) => "ia_u32_u32", @@ -1410,6 +1450,7 @@ impl BackendStorage for MetalStorage { (DType::U8, DType::BF16) => "ia_u8_bf16", (DType::U8, DType::F16) => "ia_u8_f16", (DType::U8, DType::F32) => "ia_u8_f32", + (DType::U8, DType::I16) => "ia_u8_i16", (DType::U8, DType::I32) => "ia_u8_i32", (DType::U8, DType::I64) => "ia_u8_i64", (DType::U8, DType::U32) => "ia_u8_u32", @@ -1577,6 +1618,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::copy2d::FLOAT, DType::F16 => candle_metal_kernels::copy2d::HALF, DType::BF16 => candle_metal_kernels::copy2d::BFLOAT, + DType::I16 => candle_metal_kernels::copy2d::I16, DType::I32 => candle_metal_kernels::copy2d::I32, DType::I64 => candle_metal_kernels::copy2d::I64, DType::U32 => candle_metal_kernels::copy2d::U32, @@ -1624,6 +1666,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::I16 => candle_metal_kernels::unary::strided::copy::I16, DType::I32 => candle_metal_kernels::unary::strided::copy::I32, 
DType::I64 => candle_metal_kernels::unary::strided::copy::I64, DType::U32 => candle_metal_kernels::unary::strided::copy::U32, @@ -1716,6 +1759,17 @@ impl MetalStorage { ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8), ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8), + ("add", DType::I16) => (contiguous::add::I16, self.dtype), + ("sub", DType::I16) => (contiguous::sub::I16, self.dtype), + ("mul", DType::I16) => (contiguous::mul::I16, self.dtype), + ("div", DType::I16) => (contiguous::div::I16, self.dtype), + ("eq", DType::I16) => (contiguous::eq::I16, DType::U8), + ("ne", DType::I16) => (contiguous::ne::I16, DType::U8), + ("le", DType::I16) => (contiguous::le::I16, DType::U8), + ("lt", DType::I16) => (contiguous::lt::I16, DType::U8), + ("ge", DType::I16) => (contiguous::ge::I16, DType::U8), + ("gt", DType::I16) => (contiguous::gt::I16, DType::U8), + ("add", DType::I32) => (contiguous::add::I32, self.dtype), ("sub", DType::I32) => (contiguous::sub::I32, self.dtype), ("mul", DType::I32) => (contiguous::mul::I32, self.dtype), @@ -1820,6 +1874,19 @@ impl MetalStorage { ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8), ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8), + ("badd", DType::I16) => (strided::add::I16, self.dtype), + ("bsub", DType::I16) => (strided::sub::I16, self.dtype), + ("bmul", DType::I16) => (strided::mul::I16, self.dtype), + ("bdiv", DType::I16) => (strided::div::I16, self.dtype), + ("bminimum", DType::I16) => (strided::min::I16, self.dtype), + ("bmaximum", DType::I16) => (strided::max::I16, self.dtype), + ("eq", DType::I16) => (strided::eq::I16, DType::U8), + ("ne", DType::I16) => (strided::ne::I16, DType::U8), + ("le", DType::I16) => (strided::le::I16, DType::U8), + ("lt", DType::I16) => (strided::lt::I16, DType::U8), + ("ge", DType::I16) => (strided::ge::I16, DType::U8), + ("gt", DType::I16) => (strided::gt::I16, DType::U8), + ("badd", DType::I32) => (strided::add::I32, self.dtype), ("bsub", DType::I32) => (strided::sub::I32, self.dtype), ("bmul", DType::I32) => (strided::mul::I32, self.dtype), @@ -1989,6 +2056,7 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match T::cpu_storage_ref(s) { CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), @@ -2003,6 +2071,7 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match storage { CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index b321a619f8..33a4f4c728 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -85,6 +85,7 @@ impl Header { DType::F16 => "f2", DType::F32 => "f4", DType::F64 => "f8", + DType::I16 => "i2", 
DType::I32 => "i4", DType::I64 => "i8", DType::U32 => "u4", @@ -235,6 +236,11 @@ impl Tensor { reader.read_u32_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::I16 => { + let mut data_t = vec![0i16; elem_count]; + reader.read_i16_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } DType::I32 => { let mut data_t = vec![0i32; elem_count]; reader.read_i32_into::(&mut data_t)?; diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 75931ee2fe..3786a82aaf 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -189,6 +189,7 @@ pub trait UnaryOpT { fn f64(v1: f64) -> f64; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; + fn i16(v1: i16) -> i16; fn i32(v1: i32) -> i32; fn i64(v1: i64) -> i64; @@ -214,6 +215,7 @@ pub trait BinaryOpT { fn f64(v1: f64, v2: f64) -> f64; fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; + fn i16(v1: i16, v2: i16) -> i16; fn i32(v1: i32, v2: i32) -> i32; fn i64(v1: i64, v2: i64) -> i64; @@ -233,6 +235,8 @@ pub trait BinaryOpT { fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {} const I32_VEC: bool = false; fn i32_vec(_xs1: &[i32], _xs2: &[i32], _ys: &mut [i32]) {} + const I16_VEC: bool = false; + fn i16_vec(_xs1: &[i16], _xs2: &[i16], _ys: &mut [i16]) {} } pub(crate) struct Add; @@ -292,6 +296,10 @@ macro_rules! bin_op { $e(v1, v2) } #[inline(always)] + fn i16(v1: i16, v2: i16) -> i16 { + $e(v1, v2) + } + #[inline(always)] fn i32(v1: i32, v2: i32) -> i32 { $e(v1, v2) } @@ -391,6 +399,10 @@ macro_rules! unary_op { fn i32(_: i32) -> i32 { todo!("no unary function for i32") } + #[inline(always)] + fn i16(_: i16) -> i16 { + todo!("no unary function for i16") + } } }; @@ -431,6 +443,10 @@ macro_rules! unary_op { fn i32(_: i32) -> i32 { todo!("no unary function for i32") } + #[inline(always)] + fn i16(_: i16) -> i16 { + todo!("no unary function for i16") + } #[cfg(feature = "mkl")] const F32_VEC: bool = true; @@ -534,6 +550,10 @@ impl UnaryOpT for Gelu { fn i32(_: i32) -> i32 { 0 } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } const KERNEL: &'static str = "ugelu"; #[cfg(feature = "mkl")] @@ -611,6 +631,10 @@ impl UnaryOpT for Erf { fn i32(_: i32) -> i32 { 0 } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } } /// Silu operation @@ -649,6 +673,10 @@ impl UnaryOpT for Silu { fn i32(_: i32) -> i32 { 0 } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } const KERNEL: &'static str = "usilu"; #[cfg(feature = "mkl")] @@ -724,6 +752,10 @@ impl UnaryOpT for Abs { fn i32(v: i32) -> i32 { v.abs() } + #[inline(always)] + fn i16(v: i16) -> i16 { + v.abs() + } } impl UnaryOpT for Ceil { @@ -762,6 +794,10 @@ impl UnaryOpT for Ceil { fn i32(v: i32) -> i32 { v } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } } impl UnaryOpT for Floor { @@ -800,6 +836,10 @@ impl UnaryOpT for Floor { fn i32(v: i32) -> i32 { v } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } } impl UnaryOpT for Round { @@ -838,6 +878,10 @@ impl UnaryOpT for Round { fn i32(v: i32) -> i32 { v } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } } impl UnaryOpT for GeluErf { @@ -876,6 +920,10 @@ impl UnaryOpT for GeluErf { fn i32(_: i32) -> i32 { 0 } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } } impl UnaryOpT for Relu { @@ -914,6 +962,10 @@ impl UnaryOpT for Relu { fn i32(v: i32) -> i32 { v } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } } /// `BackpropOp` is a wrapper around `Option`. 
The main goal is to ensure that dependencies are @@ -1016,4 +1068,8 @@ impl UnaryOpT for Sign { fn i32(v: i32) -> i32 { (v > 0) as i32 - (v < 0) as i32 } + #[inline(always)] + fn i16(v: i16) -> i16 { + (v > 0) as i16 - (v < 0) as i16 + } } diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 162928ec7d..12436a0903 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -11,6 +11,7 @@ impl From for st::Dtype { DType::U8 => st::Dtype::U8, DType::U32 => st::Dtype::U32, DType::I64 => st::Dtype::I64, + DType::I16 => st::Dtype::I16, DType::I32 => st::Dtype::I32, DType::BF16 => st::Dtype::BF16, DType::F16 => st::Dtype::F16, @@ -188,6 +189,7 @@ impl Tensor { match dtype { DType::U8 => convert_slice::(data, shape, device), DType::U32 => convert_slice::(data, shape, device), + DType::I16 => convert_slice::(data, shape, device), DType::I32 => convert_slice::(data, shape, device), DType::I64 => convert_slice::(data, shape, device), DType::BF16 => convert_slice::(data, shape, device), @@ -206,6 +208,7 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { convert_with_cast_::(view, device, conv) } st::Dtype::U32 => convert_::(view, device), + st::Dtype::I16 => convert_::(view, device), st::Dtype::I32 => convert_::(view, device), st::Dtype::I64 => convert_::(view, device), st::Dtype::BF16 => convert_::(view, device), @@ -222,6 +225,7 @@ fn convert_back(tensor: &Tensor) -> Result> { match tensor.dtype() { DType::U8 => Ok(convert_back_::(tensor.to_vec1()?)), DType::U32 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::I32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::I64 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F16 => Ok(convert_back_::(tensor.to_vec1()?)), diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 9d9fd59634..14e3417138 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -65,6 +65,7 @@ impl crate::CustomOp1 for ArgSort { let sort_indexes = match storage { crate::CpuStorage::U8(vs) => self.asort(vs, layout), crate::CpuStorage::U32(vs) => self.asort(vs, layout), + crate::CpuStorage::I16(vs) => self.asort(vs, layout), crate::CpuStorage::I32(vs) => self.asort(vs, layout), crate::CpuStorage::I64(vs) => self.asort(vs, layout), crate::CpuStorage::BF16(vs) => self.asort(vs, layout), diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index bff8f36042..ede9f3c708 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -17,6 +17,10 @@ fn ones(device: &Device) -> Result<()> { Tensor::ones((2, 3), DType::U32, device)?.to_vec2::()?, [[1, 1, 1], [1, 1, 1]], ); + assert_eq!( + Tensor::ones((2, 3), DType::I16, device)?.to_vec2::()?, + [[1, 1, 1], [1, 1, 1]], + ); assert_eq!( Tensor::ones((2, 3), DType::I32, device)?.to_vec2::()?, [[1, 1, 1], [1, 1, 1]], @@ -809,7 +813,7 @@ fn index_select(device: &Device) -> Result<()> { [9.0, 10.0, 11.0] ] ); - for dtype in [DType::U8, DType::U32, DType::I32, DType::I64] { + for dtype in [DType::U8, DType::U32, DType::I16, DType::I32, DType::I64] { let ids = ids.to_dtype(dtype)?; let hs = t.index_select(&ids, 1)?; assert_eq!( diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu index c3ff5b8753..301bcd5a64 100644 --- a/candle-kernels/src/affine.cu +++ b/candle-kernels/src/affine.cu @@ -40,5 +40,6 @@ AFFINE_OP(float, affine_f32) AFFINE_OP(double, affine_f64) AFFINE_OP(uint8_t, affine_u8) AFFINE_OP(uint32_t, affine_u32) 
+AFFINE_OP(int16_t, affine_i16) AFFINE_OP(int32_t, affine_i32) AFFINE_OP(int64_t, affine_i64) diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu index f534fc76ad..99ab23b875 100644 --- a/candle-kernels/src/binary.cu +++ b/candle-kernels/src/binary.cu @@ -35,36 +35,42 @@ BINARY_OP(float, badd_f32, x + y) BINARY_OP(double, badd_f64, x + y); BINARY_OP(uint8_t, badd_u8, x + y); BINARY_OP(uint32_t, badd_u32, x + y); +BINARY_OP(int16_t, badd_i16, x + y); BINARY_OP(int32_t, badd_i32, x + y); BINARY_OP(int64_t, badd_i64, x + y); BINARY_OP(float, bdiv_f32, x / y) BINARY_OP(double, bdiv_f64, x / y); BINARY_OP(uint8_t, bdiv_u8, x / y); BINARY_OP(uint32_t, bdiv_u32, x / y); +BINARY_OP(int16_t, bdiv_i16, x / y); BINARY_OP(int32_t, bdiv_i32, x / y); BINARY_OP(int64_t, bdiv_i64, x / y); BINARY_OP(float, bmul_f32, x * y) BINARY_OP(double, bmul_f64, x * y); BINARY_OP(uint8_t, bmul_u8, x * y); BINARY_OP(uint32_t, bmul_u32, x * y); +BINARY_OP(int16_t, bmul_i16, x * y); BINARY_OP(int32_t, bmul_i32, x * y); BINARY_OP(int64_t, bmul_i64, x * y); BINARY_OP(float, bsub_f32, x - y) BINARY_OP(double, bsub_f64, x - y); BINARY_OP(uint8_t, bsub_u8, x - y); BINARY_OP(uint32_t, bsub_u32, x - y); +BINARY_OP(int16_t, bsub_i16, x - y); BINARY_OP(int32_t, bsub_i32, x - y); BINARY_OP(int64_t, bsub_i64, x - y); BINARY_OP(float, bminimum_f32, ming(x, y)); BINARY_OP(double, bminimum_f64, ming(x, y)); BINARY_OP(uint8_t, bminimum_u8, ming(x, y)); BINARY_OP(uint32_t, bminimum_u32, ming(x, y)); +BINARY_OP(int16_t, bminimum_i16, ming(x, y)); BINARY_OP(int32_t, bminimum_i32, ming(x, y)); BINARY_OP(int64_t, bminimum_i64, ming(x, y)); BINARY_OP(float, bmaximum_f32, maxg(x, y)); BINARY_OP(double, bmaximum_f64, maxg(x, y)); BINARY_OP(uint8_t, bmaximum_u8, maxg(x, y)); BINARY_OP(uint32_t, bmaximum_u32, maxg(x, y)); +BINARY_OP(int16_t, bmaximum_i16, maxg(x, y)); BINARY_OP(int32_t, bmaximum_i32, maxg(x, y)); BINARY_OP(int64_t, bmaximum_i64, maxg(x, y)); @@ -72,6 +78,7 @@ BINARY_OP_OUT(float, uint8_t, eq_f32, x == y) BINARY_OP_OUT(double, uint8_t, eq_f64, x == y) BINARY_OP_OUT(uint8_t, uint8_t, eq_u8, x == y) BINARY_OP_OUT(uint32_t, uint8_t, eq_u32, x == y) +BINARY_OP_OUT(int16_t, uint8_t, eq_i16, x == y) BINARY_OP_OUT(int32_t, uint8_t, eq_i32, x == y) BINARY_OP_OUT(int64_t, uint8_t, eq_i64, x == y) @@ -79,6 +86,7 @@ BINARY_OP_OUT(float, uint8_t, ne_f32, x != y) BINARY_OP_OUT(double, uint8_t, ne_f64, x != y) BINARY_OP_OUT(uint8_t, uint8_t, ne_u8, x != y) BINARY_OP_OUT(uint32_t, uint8_t, ne_u32, x != y) +BINARY_OP_OUT(int16_t, uint8_t, ne_i16, x != y) BINARY_OP_OUT(int32_t, uint8_t, ne_i32, x != y) BINARY_OP_OUT(int64_t, uint8_t, ne_i64, x != y) @@ -86,6 +94,7 @@ BINARY_OP_OUT(float, uint8_t, lt_f32, x < y) BINARY_OP_OUT(double, uint8_t, lt_f64, x < y) BINARY_OP_OUT(uint8_t, uint8_t, lt_u8, x < y) BINARY_OP_OUT(uint32_t, uint8_t, lt_u32, x < y) +BINARY_OP_OUT(int16_t, uint8_t, lt_i16, x < y) BINARY_OP_OUT(int32_t, uint8_t, lt_i32, x < y) BINARY_OP_OUT(int64_t, uint8_t, lt_i64, x < y) @@ -93,6 +102,7 @@ BINARY_OP_OUT(float, uint8_t, le_f32, x <= y) BINARY_OP_OUT(double, uint8_t, le_f64, x <= y) BINARY_OP_OUT(uint8_t, uint8_t, le_u8, x <= y) BINARY_OP_OUT(uint32_t, uint8_t, le_u32, x <= y) +BINARY_OP_OUT(int16_t, uint8_t, le_i16, x <= y) BINARY_OP_OUT(int32_t, uint8_t, le_i32, x <= y) BINARY_OP_OUT(int64_t, uint8_t, le_i64, x <= y) @@ -100,6 +110,7 @@ BINARY_OP_OUT(float, uint8_t, gt_f32, x > y) BINARY_OP_OUT(double, uint8_t, gt_f64, x > y) BINARY_OP_OUT(uint8_t, uint8_t, gt_u8, x > y) BINARY_OP_OUT(uint32_t, uint8_t, 
gt_u32, x > y) +BINARY_OP_OUT(int16_t, uint8_t, gt_i16, x > y) BINARY_OP_OUT(int32_t, uint8_t, gt_i32, x > y) BINARY_OP_OUT(int64_t, uint8_t, gt_i64, x > y) @@ -107,5 +118,6 @@ BINARY_OP_OUT(float, uint8_t, ge_f32, x >= y) BINARY_OP_OUT(double, uint8_t, ge_f64, x >= y) BINARY_OP_OUT(uint8_t, uint8_t, ge_u8, x >= y) BINARY_OP_OUT(uint32_t, uint8_t, ge_u32, x >= y) +BINARY_OP_OUT(int16_t, uint8_t, ge_i16, x >= y) BINARY_OP_OUT(int32_t, uint8_t, ge_i32, x >= y) BINARY_OP_OUT(int64_t, uint8_t, ge_i64, x >= y) diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index f92ac0cbf9..e288bf1812 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -120,11 +120,13 @@ CAST_OP(uint32_t, uint32_t, cast_u32_u32) CAST_OP(uint32_t, uint8_t, cast_u32_u8 ) CAST_OP(uint32_t, int64_t, cast_u32_i64 ) CAST_OP(uint32_t, int32_t, cast_u32_i32 ) +CAST_OP(uint32_t, int16_t, cast_u32_i16 ) CAST_OP(uint32_t, float, cast_u32_f32) CAST_OP(uint32_t, double, cast_u32_f64) CAST_OP(uint8_t, uint32_t, cast_u8_u32) CAST_OP(uint8_t, uint8_t, cast_u8_u8 ) +CAST_OP(uint8_t, int16_t, cast_u8_i16 ) CAST_OP(uint8_t, int32_t, cast_u8_i32 ) CAST_OP(uint8_t, int64_t, cast_u8_i64 ) CAST_OP(uint8_t, float, cast_u8_f32) @@ -132,6 +134,7 @@ CAST_OP(uint8_t, double, cast_u8_f64) CAST_OP(int64_t, uint32_t, cast_i64_u32) CAST_OP(int64_t, uint8_t, cast_i64_u8 ) +CAST_OP(int64_t, int16_t, cast_i64_i16 ) CAST_OP(int64_t, int32_t, cast_i64_i32 ) CAST_OP(int64_t, int64_t, cast_i64_i64 ) CAST_OP(int64_t, float, cast_i64_f32) @@ -141,11 +144,21 @@ CAST_OP(int32_t, uint32_t, cast_i32_u32) CAST_OP(int32_t, uint8_t, cast_i32_u8 ) CAST_OP(int32_t, int64_t, cast_i32_i64 ) CAST_OP(int32_t, int32_t, cast_i32_i32 ) +CAST_OP(int32_t, int16_t, cast_i32_i16 ) CAST_OP(int32_t, float, cast_i32_f32) CAST_OP(int32_t, double, cast_i32_f64) +CAST_OP(int16_t, uint32_t, cast_i16_u32) +CAST_OP(int16_t, uint8_t, cast_i16_u8 ) +CAST_OP(int16_t, int64_t, cast_i16_i64 ) +CAST_OP(int16_t, int32_t, cast_i16_i32 ) +CAST_OP(int16_t, int16_t, cast_i16_i16 ) +CAST_OP(int16_t, float, cast_i16_f32) +CAST_OP(int16_t, double, cast_i16_f64) + CAST_OP(float, uint8_t, cast_f32_u8 ) CAST_OP(float, uint32_t, cast_f32_u32) +CAST_OP(float, int16_t, cast_f32_i16 ) CAST_OP(float, int32_t, cast_f32_i32 ) CAST_OP(float, int64_t, cast_f32_i64 ) CAST_OP(float, float, cast_f32_f32) @@ -153,6 +166,7 @@ CAST_OP(float, double, cast_f32_f64) CAST_OP(double, uint8_t, cast_f64_u8 ) CAST_OP(double, uint32_t, cast_f64_u32) +CAST_OP(double, int16_t, cast_f64_i16 ) CAST_OP(double, int32_t, cast_f64_i32 ) CAST_OP(double, int64_t, cast_f64_i64 ) CAST_OP(double, float, cast_f64_f32) diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index 08aa2b089a..df1497f672 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -181,6 +181,8 @@ __device__ __forceinline__ double absg(double a) { return fabs(a); } __device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); } __device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); } +__device__ __forceinline__ int16_t ming(int16_t a, int16_t b) { return min(a, b); } +__device__ __forceinline__ int16_t maxg(int16_t a, int16_t b) { return max(a, b); } __device__ __forceinline__ int32_t ming(int32_t a, int32_t b) { return min(a, b); } __device__ __forceinline__ int32_t maxg(int32_t a, int32_t b) { return max(a, b); } __device__ __forceinline__ int64_t ming(int64_t a, int64_t b) { return min(a, b); } diff --git 
a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu index 42bfddfd9f..0654c2631b 100644 --- a/candle-kernels/src/fill.cu +++ b/candle-kernels/src/fill.cu @@ -9,6 +9,7 @@ __device__ void fill_with(T *buf, T value, const size_t numel) { } extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_i16(int16_t *buf, int16_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_i32(int32_t *buf, int32_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); } @@ -35,6 +36,7 @@ COPY2D_OP(float, copy2d_f32) COPY2D_OP(double, copy2d_f64) COPY2D_OP(uint8_t, copy2d_u8) COPY2D_OP(uint32_t, copy2d_u32) +COPY2D_OP(int16_t, copy2d_i16) COPY2D_OP(int32_t, copy2d_i32) COPY2D_OP(int64_t, copy2d_i64) diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index 2f3df4de1b..df0e3a071d 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -147,18 +147,22 @@ extern "C" __global__ void FN_NAME( \ #if __CUDA_ARCH__ >= 800 +IS_OP(__nv_bfloat16, int16_t, is_i16_bf16) IS_OP(__nv_bfloat16, int32_t, is_i32_bf16) IS_OP(__nv_bfloat16, int64_t, is_i64_bf16) IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16) IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16) +GATHER_OP(__nv_bfloat16, int16_t, gather_i16_bf16) GATHER_OP(__nv_bfloat16, int32_t, gather_i32_bf16) GATHER_OP(__nv_bfloat16, int64_t, gather_i64_bf16) GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16) GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16) +IA_OP(__nv_bfloat16, int16_t, ia_i16_bf16) IA_OP(__nv_bfloat16, int32_t, ia_i32_bf16) IA_OP(__nv_bfloat16, int64_t, ia_i64_bf16) IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16) IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16) +SA_OP(__nv_bfloat16, int16_t, sa_i16_bf16) SA_OP(__nv_bfloat16, int32_t, sa_i32_bf16) SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16) SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16) @@ -166,28 +170,41 @@ SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16) #endif #if __CUDA_ARCH__ >= 530 +IS_OP(__half, int16_t, is_i16_f16) IS_OP(__half, int32_t, is_i32_f16) IS_OP(__half, int64_t, is_i64_f16) IS_OP(__half, uint32_t, is_u32_f16) IS_OP(__half, uint8_t, is_u8_f16) +GATHER_OP(__half, int16_t, gather_i16_f16) GATHER_OP(__half, int32_t, gather_i32_f16) GATHER_OP(__half, int64_t, gather_i64_f16) GATHER_OP(__half, uint32_t, gather_u32_f16) GATHER_OP(__half, uint8_t, gather_u8_f16) +IA_OP(__half, int16_t, ia_i16_f16) IA_OP(__half, int32_t, ia_i32_f16) IA_OP(__half, int64_t, ia_i64_f16) IA_OP(__half, uint32_t, ia_u32_f16) IA_OP(__half, uint8_t, ia_u8_f16) +SA_OP(__half, int16_t, sa_i16_f16) SA_OP(__half, int32_t, sa_i32_f16) SA_OP(__half, int64_t, sa_i64_f16) SA_OP(__half, uint32_t, sa_u32_f16) SA_OP(__half, uint8_t, sa_u8_f16) #endif +IS_OP(float, int16_t, is_i16_f32) +IS_OP(double, int16_t, is_i16_f64) +IS_OP(uint8_t, int16_t, is_i16_u8) +IS_OP(uint32_t, int16_t, is_i16_u32) +IS_OP(int16_t, int16_t, is_i16_i16) +IS_OP(int32_t, int16_t, is_i16_i32) +IS_OP(int64_t, int16_t, is_i16_i64) + IS_OP(float, int32_t, is_i32_f32) IS_OP(double, int32_t, is_i32_f64) IS_OP(uint8_t, int32_t, is_i32_u8) IS_OP(uint32_t, 
int32_t, is_i32_u32) +IS_OP(int16_t, int32_t, is_i32_i16) IS_OP(int32_t, int32_t, is_i32_i32) IS_OP(int64_t, int32_t, is_i32_i64) @@ -197,10 +214,12 @@ IS_OP(uint8_t, int64_t, is_i64_u8) IS_OP(uint32_t, int64_t, is_i64_u32) IS_OP(int64_t, int64_t, is_i64_i64) IS_OP(int32_t, int64_t, is_i64_i32) +IS_OP(int16_t, int64_t, is_i64_i16) IS_OP(float, uint32_t, is_u32_f32) IS_OP(double, uint32_t, is_u32_f64) IS_OP(uint8_t, uint32_t, is_u32_u8) +IS_OP(int16_t, uint32_t, is_u32_i16) IS_OP(int32_t, uint32_t, is_u32_i32) IS_OP(int64_t, uint32_t, is_u32_i64) IS_OP(uint32_t, uint32_t, is_u32_u32) @@ -209,13 +228,23 @@ IS_OP(float, uint8_t, is_u8_f32) IS_OP(double, uint8_t, is_u8_f64) IS_OP(uint8_t, uint8_t, is_u8_u8) IS_OP(uint32_t, uint8_t, is_u8_u32) +IS_OP(int16_t, uint8_t, is_u8_i16) IS_OP(int32_t, uint8_t, is_u8_i32) IS_OP(int64_t, uint8_t, is_u8_i64) +GATHER_OP(float, int16_t, gather_i16_f32) +GATHER_OP(double, int16_t, gather_i16_f64) +GATHER_OP(uint8_t, int16_t, gather_i16_u8) +GATHER_OP(uint32_t, int16_t, gather_i16_u32) +GATHER_OP(int16_t, int16_t, gather_i16_i16) +GATHER_OP(int32_t, int16_t, gather_i16_i32) +GATHER_OP(int64_t, int16_t, gather_i16_i64) + GATHER_OP(float, int32_t, gather_i32_f32) GATHER_OP(double, int32_t, gather_i32_f64) GATHER_OP(uint8_t, int32_t, gather_i32_u8) GATHER_OP(uint32_t, int32_t, gather_i32_u32) +GATHER_OP(int16_t, int32_t, gather_i32_i16) GATHER_OP(int32_t, int32_t, gather_i32_i32) GATHER_OP(int64_t, int32_t, gather_i32_i64) @@ -225,10 +254,12 @@ GATHER_OP(uint8_t, int64_t, gather_i64_u8) GATHER_OP(uint32_t, int64_t, gather_i64_u32) GATHER_OP(int64_t, int64_t, gather_i64_i64) GATHER_OP(int32_t, int64_t, gather_i64_i32) +GATHER_OP(int16_t, int64_t, gather_i64_i16) GATHER_OP(float, uint32_t, gather_u32_f32) GATHER_OP(double, uint32_t, gather_u32_f64) GATHER_OP(uint8_t, uint32_t, gather_u32_u8) +GATHER_OP(int16_t, uint32_t, gather_u32_i16) GATHER_OP(int32_t, uint32_t, gather_u32_i32) GATHER_OP(int64_t, uint32_t, gather_u32_i64) GATHER_OP(uint32_t, uint32_t, gather_u32_u32) @@ -237,9 +268,16 @@ GATHER_OP(float, uint8_t, gather_u8_f32) GATHER_OP(double, uint8_t, gather_u8_f64) GATHER_OP(uint8_t, uint8_t, gather_u8_u8) GATHER_OP(uint32_t, uint8_t, gather_u8_u32) +GATHER_OP(int16_t, uint8_t, gather_u8_i16) GATHER_OP(int32_t, uint8_t, gather_u8_i32) GATHER_OP(int64_t, uint8_t, gather_u8_i64) +IA_OP(float, int16_t, ia_i16_f32) +IA_OP(double, int16_t, ia_i16_f64) +IA_OP(uint8_t, int16_t, ia_i16_u8) +IA_OP(int16_t, int16_t, ia_i16_i16) +IA_OP(uint16_t, int16_t, ia_i16_u16) + IA_OP(float, int32_t, ia_i32_f32) IA_OP(double, int32_t, ia_i32_f64) IA_OP(uint8_t, int32_t, ia_i32_u8) @@ -252,10 +290,12 @@ IA_OP(uint8_t, int64_t, ia_i64_u8) IA_OP(int64_t, int64_t, ia_i64_i64) IA_OP(uint32_t, int64_t, ia_i64_u32) IA_OP(int32_t, int64_t, ia_i64_i32) +IA_OP(int16_t, int64_t, ia_i64_i16) IA_OP(float, uint32_t, ia_u32_f32) IA_OP(double, uint32_t, ia_u32_f64) IA_OP(uint8_t, uint32_t, ia_u32_u8) +IA_OP(int16_t, uint32_t, ia_u32_i16) IA_OP(int32_t, uint32_t, ia_u32_i32) IA_OP(int64_t, uint32_t, ia_u32_i64) IA_OP(uint32_t, uint32_t, ia_u32_u32) @@ -264,18 +304,28 @@ IA_OP(float, uint8_t, ia_u8_f32) IA_OP(double, uint8_t, ia_u8_f64) IA_OP(uint8_t, uint8_t, ia_u8_u8) IA_OP(uint32_t, uint8_t, ia_u8_u32) +IA_OP(int16_t, uint8_t, ia_u8_i16) IA_OP(int32_t, uint8_t, ia_u8_i32) IA_OP(int64_t, uint8_t, ia_u8_i64) +SA_OP(float, int16_t, sa_i16_f32) +SA_OP(double, int16_t, sa_i16_f64) +SA_OP(uint8_t, int16_t, sa_i16_u8) +SA_OP(int16_t, int16_t, sa_i16_i16) +SA_OP(int32_t, int16_t, sa_i16_i32) 
+SA_OP(uint32_t, int16_t, sa_i16_u32) + SA_OP(float, int32_t, sa_i32_f32) SA_OP(double, int32_t, sa_i32_f64) SA_OP(uint8_t, int32_t, sa_i32_u8) +SA_OP(int16_t, int32_t, sa_i32_i16) SA_OP(int32_t, int32_t, sa_i32_i32) SA_OP(uint32_t, int32_t, sa_i32_u32) SA_OP(float, int64_t, sa_i64_f32) SA_OP(double, int64_t, sa_i64_f64) SA_OP(uint8_t, int64_t, sa_i64_u8) +SA_OP(int16_t, int64_t, sa_i64_i16) SA_OP(int32_t, int64_t, sa_i64_i32) SA_OP(int64_t, int64_t, sa_i64_i64) SA_OP(uint32_t, int64_t, sa_i64_u32) @@ -283,6 +333,7 @@ SA_OP(uint32_t, int64_t, sa_i64_u32) SA_OP(float, uint32_t, sa_u32_f32) SA_OP(double, uint32_t, sa_u32_f64) SA_OP(uint8_t, uint32_t, sa_u32_u8) +SA_OP(int16_t, uint32_t, sa_u32_i16) SA_OP(int32_t, uint32_t, sa_u32_i32) SA_OP(int64_t, uint32_t, sa_u32_i64) SA_OP(uint32_t, uint32_t, sa_u32_u32) @@ -291,5 +342,6 @@ SA_OP(float, uint8_t, sa_u8_f32) SA_OP(double, uint8_t, sa_u8_f64) SA_OP(uint8_t, uint8_t, sa_u8_u8) SA_OP(uint32_t, uint8_t, sa_u8_u32) +SA_OP(int16_t, uint8_t, sa_u8_i16) SA_OP(int32_t, uint8_t, sa_u8_i32) SA_OP(int64_t, uint8_t, sa_u8_i64) diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index 9a1354a8dc..fe2e30160a 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -606,6 +606,7 @@ ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64) FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32) FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64) FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32) +FAST_OP(int16_t, fast_min_i16, fast_max_i16, fast_argmin_i16, fast_argmax_i16, fast_sum_i16) FAST_OP(int32_t, fast_min_i32, fast_max_i32, fast_argmin_i32, fast_argmax_i32, fast_sum_i32) FAST_OP(int64_t, fast_min_i64, fast_max_i64, fast_argmin_i64, fast_argmax_i64, fast_sum_i64) FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8) diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu index 7fecf8413e..f2b2e9d458 100644 --- a/candle-kernels/src/sort.cu +++ b/candle-kernels/src/sort.cu @@ -85,5 +85,6 @@ ASORT_OP(float, f32) ASORT_OP(double, f64) ASORT_OP(uint8_t, u8) ASORT_OP(uint32_t, u32) +ASORT_OP(int16_t, i16) ASORT_OP(int32_t, i32) ASORT_OP(int64_t, i64) diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu index 4617c08fbe..18beede021 100644 --- a/candle-kernels/src/ternary.cu +++ b/candle-kernels/src/ternary.cu @@ -33,6 +33,7 @@ extern "C" __global__ void FN_NAME( \ } \ #if __CUDA_ARCH__ >= 800 +WHERE_OP(__nv_bfloat16, int16_t, where_i16_bf16) WHERE_OP(__nv_bfloat16, int32_t, where_i32_bf16) WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16) WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16) @@ -40,12 +41,21 @@ WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16) #endif #if __CUDA_ARCH__ >= 530 +WHERE_OP(__half, int16_t, where_i16_f16) WHERE_OP(__half, int32_t, where_i32_f16) WHERE_OP(__half, int64_t, where_i64_f16) WHERE_OP(__half, uint32_t, where_u32_f16) WHERE_OP(__half, uint8_t, where_u8_f16) #endif +WHERE_OP(float, int16_t, where_i16_f32) +WHERE_OP(double, int16_t, where_i16_f64) +WHERE_OP(uint8_t, int16_t, where_i16_u8) +WHERE_OP(uint32_t, int16_t, where_i16_u32) +WHERE_OP(int16_t, int16_t, where_i16_i16) +WHERE_OP(int32_t, int16_t, where_i16_i32) +WHERE_OP(int64_t, int16_t, where_i16_i64) + WHERE_OP(float, int32_t, where_i32_f32) WHERE_OP(double, int32_t, where_i32_f64) WHERE_OP(uint8_t, int32_t, where_i32_u8) @@ -62,6 +72,7 @@ 
WHERE_OP(float, uint32_t, where_u32_f32) WHERE_OP(double, uint32_t, where_u32_f64) WHERE_OP(uint8_t, uint32_t, where_u32_u8) WHERE_OP(uint32_t, uint32_t, where_u32_u32) +WHERE_OP(int16_t, uint32_t, where_u32_i16) WHERE_OP(int32_t, uint32_t, where_u32_i32) WHERE_OP(int64_t, uint32_t, where_u32_i64) @@ -69,5 +80,6 @@ WHERE_OP(float, uint8_t, where_u8_f32) WHERE_OP(double, uint8_t, where_u8_f64) WHERE_OP(uint8_t, uint8_t, where_u8_u8) WHERE_OP(uint32_t, uint8_t, where_u8_u32) +WHERE_OP(int16_t, uint8_t, where_u8_i16) WHERE_OP(int32_t, uint8_t, where_u8_i32) WHERE_OP(int64_t, uint8_t, where_u8_i64) diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index 21d3d995c0..bfd60de0b1 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -153,6 +153,7 @@ UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x)) UNARY_OP(uint8_t, ucopy_u8, x) UNARY_OP(uint32_t, ucopy_u32, x) +UNARY_OP(int16_t, ucopy_i16, x) UNARY_OP(int32_t, ucopy_i32, x) UNARY_OP(int64_t, ucopy_i64, x) UNARY_OP(float, ucopy_f32, x) diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index a9b8129c3a..4c558c2cdb 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -59,6 +59,7 @@ BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \ BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int16_t, int16_t, NAME##_i16, NAME##_i16_strided); \ BINARY(FN, int32_t, int32_t, NAME##_i32, NAME##_i32_strided); #define BINARY_OP_OUT(NAME, FN) \ @@ -66,6 +67,7 @@ BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \ BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int16_t, uint8_t, NAME##_i16, NAME##_i16_strided); \ BINARY(FN, int32_t, uint8_t, NAME##_i32, NAME##_i32_strided); #define INT64_BINARY_OP(NAME, FN) \ diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index c8122ccf0a..5a8324bf11 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -77,6 +77,7 @@ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half) CAST(cast_u32_i32, cast_u32_i32_strided, uint32_t, int32_t) +CAST(cast_u32_i16, cast_u32_i16_strided, uint32_t, int16_t) #if __METAL_VERSION__ >= 220 CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) #endif @@ -89,6 +90,7 @@ CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half) CAST(cast_u8_i32, cast_u8_i32_strided, uint8_t, int64_t) +CAST(cast_u8_i16, cast_u8_i16_strided, uint8_t, int16_t) #if __METAL_VERSION__ >= 220 CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) #endif @@ -100,6 +102,7 @@ CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) CAST(cast_f16_f32, cast_f16_f32_strided, half, float) CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t) CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t) +CAST(cast_f16_i16, cast_f16_i16_strided, half, int16_t) CAST(cast_f16_i32, cast_f16_i32_strided, half, int64_t) CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t) #if defined(__HAVE_BFLOAT__) 
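Not part of the patch: a hedged sketch of the dtype round-trips these new i16 cast kernels (together with their CUDA counterparts such as `cast_f32_i16` and `cast_i16_i64`) are meant to back. `to_dtype` and `to_vec1` are the upstream candle APIs, and the corresponding CPU cast arms for f32 -> i16 are assumed to be added in parts of `cpu_backend` not shown in this hunk; only the i16 -> i64 CPU arm appears in this excerpt.

```rust
use candle_core::{DType, Device, Result, Tensor};

fn i16_cast_round_trip(dev: &Device) -> Result<()> {
    // f32 -> i16 (cast_f32_i16) and the values survive the narrowing cast.
    let x = Tensor::new(&[1.0f32, -2.0, 3.0], dev)?;
    let xi = x.to_dtype(DType::I16)?;
    assert_eq!(xi.to_vec1::<i16>()?, vec![1i16, -2, 3]);

    // Widening casts are covered as well, e.g. i16 -> i64 (cast_i16_i64).
    let xl = xi.to_dtype(DType::I64)?;
    assert_eq!(xl.to_vec1::<i64>()?, vec![1i64, -2, 3]);
    Ok(())
}
```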
@@ -111,6 +114,7 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t) CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t) CAST(cast_i64_i32, cast_i64_i32_strided, int64_t, int32_t) +CAST(cast_i64_i16, cast_i64_i16_strided, int64_t, int16_t) CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half) #if defined(__HAVE_BFLOAT__) CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float) @@ -121,15 +125,28 @@ CAST(cast_i32_f32, cast_i32_f32_strided, int32_t, float) CAST(cast_i32_u8, cast_i32_u8_strided, int32_t, uint8_t) CAST(cast_i32_u32, cast_i32_u32_strided, int32_t, uint32_t) CAST(cast_i32_i64, cast_i32_i64_strided, int32_t, int64_t) +CAST(cast_i32_i16, cast_i32_i16_strided, int32_t, int16_t) CAST(cast_i32_f16, cast_i32_f16_strided, int32_t, half) #if defined(__HAVE_BFLOAT__) CAST_THROUGH(cast_i32_bf16, cast_i32_bf16_strided, int64_t, bfloat, float) #endif +// i16 +CAST(cast_i16_f32, cast_i16_f32_strided, int16_t, float) +CAST(cast_i16_u8, cast_i16_u8_strided, int16_t, uint8_t) +CAST(cast_i16_u32, cast_i16_u32_strided, int16_t, uint32_t) +CAST(cast_i16_i32, cast_i16_i32_strided, int16_t, int32_t) +CAST(cast_i16_i64, cast_i16_i64_strided, int16_t, int64_t) +CAST(cast_i16_f16, cast_i16_f16_strided, int16_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i16_bf16, cast_i16_bf16_strided, int16_t, bfloat, float) +#endif + // f32 CAST(cast_f32_f16, cast_f32_f16_strided, float, half) CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t) CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t) +CAST(cast_f32_i16, cast_f32_i16_strided, float, int16_t) CAST(cast_f32_i32, cast_f32_i32_strided, float, int32_t) CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t) #if defined(__HAVE_BFLOAT__) @@ -139,6 +156,7 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) // bf16 #if defined(__HAVE_BFLOAT__) CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) +CAST(cast_bf16_i16, cast_bf16_i16_strided, bfloat, int16_t) CAST(cast_bf16_i32, cast_bf16_i32_strided, bfloat, int32_t) CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index eaa78d7b73..f01d4795d8 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -199,6 +199,12 @@ INDEX_OP(is_i32_f16, int32_t, half) INDEX_OP(is_i32_bf16, int32_t, bfloat) #endif +INDEX_OP(is_i16_f32, int16_t, float) +INDEX_OP(is_i16_f16, int16_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_i16_bf16, int16_t, bfloat) +#endif + INDEX_OP(is_u32_f32, uint32_t, float) INDEX_OP(is_u32_f16, uint32_t, half) #if defined(__HAVE_BFLOAT__) @@ -219,10 +225,12 @@ GATHER_OP(gather_u32_bf16, uint, bfloat) SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) SCATTER_ADD_OP(sa_u8_f32, uint8_t, float) +SCATTER_ADD_OP(sa_i16_f32, int16_t, float) SCATTER_ADD_OP(sa_i32_f32, int32_t, float) SCATTER_ADD_OP(sa_i64_f32, int64_t, float) SCATTER_ADD_OP(sa_u32_f16, uint32_t, half) SCATTER_ADD_OP(sa_u8_f16, uint8_t, half) +SCATTER_ADD_OP(sa_i16_f16, int16_t, half) SCATTER_ADD_OP(sa_i32_f16, int32_t, half) SCATTER_ADD_OP(sa_i64_f16, int64_t, half) #if defined(__HAVE_BFLOAT__) @@ -234,6 +242,7 @@ SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat) // i64 INDEX_ADD_OP(ia_i64_f16, int64_t, half) INDEX_ADD_OP(ia_i64_f32, int64_t, float) +INDEX_ADD_OP(ia_i64_i16, int64_t, int16_t) INDEX_ADD_OP(ia_i64_i32, 
int64_t, int32_t) INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t) INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t) @@ -242,7 +251,7 @@ INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) #endif -// i64 +// i32 INDEX_ADD_OP(ia_i32_f16, int32_t, half) INDEX_ADD_OP(ia_i32_f32, int32_t, float) INDEX_ADD_OP(ia_i32_i64, int32_t, int64_t) @@ -253,9 +262,23 @@ INDEX_ADD_OP(ia_i32_u8, int32_t, uint8_t) INDEX_ADD_OP(ia_i32_bf16, int32_t, bfloat) #endif +// i16 +INDEX_ADD_OP(ia_i16_f16, int16_t, half) +INDEX_ADD_OP(ia_i16_f32, int16_t, float) +INDEX_ADD_OP(ia_i16_i16, int16_t, int16_t) +INDEX_ADD_OP(ia_i16_i32, int16_t, int32_t) +INDEX_ADD_OP(ia_i16_i64, int16_t, int64_t) +INDEX_ADD_OP(ia_i16_u32, int16_t, uint32_t) +INDEX_ADD_OP(ia_i16_u8, int16_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_i16_bf16, int16_t, bfloat) +#endif + + // u32 INDEX_ADD_OP(ia_u32_f16, uint32_t, half) INDEX_ADD_OP(ia_u32_f32, uint32_t, float) +INDEX_ADD_OP(ia_u32_i16, uint32_t, int16_t) INDEX_ADD_OP(ia_u32_i32, uint32_t, int32_t) INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t) INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t) @@ -267,6 +290,7 @@ INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) // u8 INDEX_ADD_OP(ia_u8_f16, uint8_t, half) INDEX_ADD_OP(ia_u8_f32, uint8_t, float) +INDEX_ADD_OP(ia_u8_i16, uint8_t, int16_t) INDEX_ADD_OP(ia_u8_i32, uint8_t, int32_t) INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t) INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index d5e5b8eb66..ea1656193f 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -47,6 +47,7 @@ pub mod copy2d { pub const BFLOAT: Kernel = Kernel("copy2d_bf16"); pub const I64: Kernel = Kernel("copy2d_i64"); pub const I32: Kernel = Kernel("copy2d_i32"); + pub const I16: Kernel = Kernel("copy2d_i16"); pub const U32: Kernel = Kernel("copy2d_u32"); pub const U8: Kernel = Kernel("copy2d_u8"); } @@ -64,6 +65,7 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64")); pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8")); } @@ -75,6 +77,7 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel("copy_bf16"); pub const I64: Kernel = Kernel("copy_i64"); pub const I32: Kernel = Kernel("copy_i32"); + pub const I16: Kernel = Kernel("copy_i16"); pub const U32: Kernel = Kernel("copy_u32"); pub const U8: Kernel = Kernel("copy_u8"); } @@ -90,6 +93,7 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled")); pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_tiled")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16_tiled")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled")); } @@ -101,6 +105,7 @@ macro_rules! 
ops{ pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled"); pub const I64: Kernel = Kernel("copy_i64_tiled"); pub const I32: Kernel = Kernel("copy_i32_tiled"); + pub const I16: Kernel = Kernel("copy_i16_tiled"); pub const U32: Kernel = Kernel("copy_u32_tiled"); pub const U8: Kernel = Kernel("copy_u8_tiled"); } @@ -116,6 +121,7 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided")); pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_strided")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16_strided")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided")); } @@ -127,6 +133,7 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); pub const I64: Kernel = Kernel("copy_i64_strided"); pub const I32: Kernel = Kernel("copy_i32_strided"); + pub const I16: Kernel = Kernel("copy_i16_strided"); pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U8: Kernel = Kernel("copy_u8_strided"); } diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 484fa0a1b1..56ef56f7e0 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -608,6 +608,12 @@ REDUCE(MAX(x, y), fast_max_i32_strided, int32_t, INT_MIN) ARGMIN(fast_argmin_i32_strided, int32_t, INT_MAX) ARGMAX(fast_argmax_i32_strided, int32_t, INT_MIN) +REDUCE(x + y, fast_sum_i16_strided, int16_t, 0) +REDUCE(MIN(x, y), fast_min_i16_strided, int16_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i16_strided, int16_t, INT_MIN) +ARGMIN(fast_argmin_i16_strided, int16_t, INT_MAX) +ARGMAX(fast_argmax_i16_strided, int16_t, INT_MIN) + #if defined(__HAVE_BFLOAT__) REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x + y, fast_sum_bf16_strided, half, 0) diff --git a/candle-metal-kernels/src/sort.metal b/candle-metal-kernels/src/sort.metal index b7cf71bb58..9f001d8fb6 100644 --- a/candle-metal-kernels/src/sort.metal +++ b/candle-metal-kernels/src/sort.metal @@ -89,6 +89,7 @@ ARGSORT(half, f16) ARGSORT(uint8_t, u8) ARGSORT(uint32_t, u32) ARGSORT(int32_t, i32) +ARGSORT(int16_t, i16) #if __METAL_VERSION__ >= 220 ARGSORT(int64_t, i64) diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 0e043332fe..98aacd0036 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -76,6 +76,7 @@ WHERE_OP(uint8_t, int64_t, where_i64_u8) WHERE_OP(uint32_t, int64_t, where_i64_u32) WHERE_OP(int64_t, int64_t, where_i64_i64) WHERE_OP(int64_t, int32_t, where_i64_i32) +WHERE_OP(int64_t, int16_t, where_i64_i16) #if defined(__HAVE_BFLOAT__) WHERE_OP(bfloat, int64_t, where_i64_bf16) #endif @@ -94,6 +95,20 @@ WHERE_OP(int32_t, int32_t, where_i32_i32) WHERE_OP(bfloat, int32_t, where_i32_bf16) #endif +WHERE_OP(int64_t, uint8_t, where_u8_i16) +WHERE_OP(int64_t, uint32_t, where_u32_i16) + +WHERE_OP(half, int16_t, where_i16_f16) +WHERE_OP(float, int16_t, where_i16_f32) +WHERE_OP(uint8_t, int16_t, where_i16_u8) +WHERE_OP(uint32_t, int16_t, where_i16_u32) +WHERE_OP(int64_t, int16_t, where_i16_i64) +WHERE_OP(int32_t, int16_t, where_i16_i32) +WHERE_OP(int16_t, int16_t, where_i16_i16) +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, int16_t, where_i16_bf16) +#endif + #if defined(__HAVE_BFLOAT__) WHERE_OP(bfloat, uint8_t, where_u8_bf16) WHERE_OP(bfloat, uint32_t, where_u32_bf16) diff 
--git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 0c5a2736ee..a76c311a3a 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -172,6 +172,9 @@ COPY2D(copy2d_i64, int64_t) UNARY(id, int32_t, copy_i32, copy_i32_strided) COPY2D(copy2d_i32, int32_t) +UNARY(id, int16_t, copy_i16, copy_i16_strided) +COPY2D(copy2d_i16, int16_t) + #if defined(__HAVE_BFLOAT__) BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 5b66a743c3..d8fcc77769 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -712,6 +712,8 @@ fn simple_eval_( let output = match start.dtype() { DType::U8 => arange_step!(u8), DType::U32 => arange_step!(u32), + DType::I16 => arange_step!(i16), + DType::I32 => arange_step!(i32), DType::I64 => arange_step!(i64), DType::BF16 => arange_step!(f32), DType::F16 => arange_step!(f32), @@ -1305,7 +1307,7 @@ fn simple_eval_( let input = get(&node.input[0])?; let dt = input.dtype(); match dt { - DType::U8 | DType::U32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I64 | DType::I16 | DType::I32 => { bail!( "unsupported dtype {}, only float types are allowed for LeakyRelu", dt.as_str() diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 55b5542ed8..d2179d577f 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -151,6 +151,7 @@ macro_rules! pydtype { }; } +pydtype!(i16, |v| v); pydtype!(i32, |v| v); pydtype!(i64, |v| v); pydtype!(u8, |v| v); @@ -201,6 +202,7 @@ trait MapDType { match t.dtype() { DType::U8 => self.f::(t), DType::U32 => self.f::(t), + DType::I16 => self.f::(t), DType::I32 => self.f::(t), DType::I64 => self.f::(t), DType::BF16 => self.f::(t), From c04861daeda1eae409157dd0ad00572d327531db Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 2 Oct 2024 05:10:25 -0400 Subject: [PATCH 72/75] Should compile now on metal --- candle-core/src/metal_backend/mod.rs | 6 ++++-- candle-core/src/quantized/metal.rs | 8 ++++---- candle-core/src/sort.rs | 2 ++ candle-metal-kernels/src/fill.metal | 2 ++ 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 56d37fbe98..7fad400eae 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1578,7 +1578,7 @@ impl BackendStorage for MetalStorage { if self.dtype == DType::BF16 { if s.unwrap_or(1.) != 1. { return Err( - MetalError::Message(format!("mlx matmul doesn't support alpha {s}")).into(), + MetalError::Message(format!("mlx matmul doesn't support alpha {s:?}")).into(), ); } candle_metal_kernels::call_mlx_gemm( @@ -1599,7 +1599,7 @@ impl BackendStorage for MetalStorage { } else if self.device.use_mlx_mm { if s.unwrap_or(1.) != 1. 
{ return Err( - MetalError::Message(format!("mlx matmul doesn't support alpha {s}")).into(), + MetalError::Message(format!("mlx matmul doesn't support alpha {s:?}")).into(), ); } let dtype = match self.dtype { @@ -2131,6 +2131,8 @@ impl BackendDevice for MetalDevice { DType::F16 => "fill_f16", DType::BF16 => "fill_bf16", DType::F32 => "fill_f32", + DType::I32 => "fill_i32", + DType::I16 => "fill_i16", DType::F64 => { let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; return self.storage_from_cpu_storage(&cpu_storage); diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index 8eb38c849c..8f12769060 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -208,19 +208,19 @@ impl QMetalStorage { } pub fn data(&self) -> Result> { - let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger; + use metal::NSUInteger; - let buffer = self.device.new_buffer_managed(size)?; + let buffer = self.device.new_buffer_managed(self.buffer.length())?; { let command_buffer = self.device.command_buffer()?; command_buffer.set_label("to_cpu"); let blit = command_buffer.new_blit_command_encoder(); blit.set_label("blit_to_cpu"); - blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); blit.end_encoding(); } self.device.wait_until_completed()?; - Ok(read_to_vec::(&buffer, self.count)) + Ok(read_to_vec::(&buffer, self.buffer.length())) } } diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 14e3417138..b48f74ba5c 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -152,6 +152,7 @@ impl crate::CustomOp1 for ArgSort { DType::U32 => "asort_asc_u32", DType::I64 => "asort_asc_i64", DType::I32 => "asort_asc_i32", + DType::I16 => "asort_asc_i16", } } else { match storage.dtype() { @@ -163,6 +164,7 @@ impl crate::CustomOp1 for ArgSort { DType::U32 => "asort_desc_u32", DType::I64 => "asort_desc_i64", DType::I32 => "asort_desc_i32", + DType::I16 => "asort_desc_i16", } } }; diff --git a/candle-metal-kernels/src/fill.metal b/candle-metal-kernels/src/fill.metal index 35c3fe7ab2..7e99a8525d 100644 --- a/candle-metal-kernels/src/fill.metal +++ b/candle-metal-kernels/src/fill.metal @@ -33,6 +33,8 @@ FILL_OPS(u32, uint) FILL_OPS(i64, long) FILL_OPS(f16, half) FILL_OPS(f32, float) +FILL_OPS(i32, int) +FILL_OPS(i16, short) #if __METAL_VERSION__ >= 310 FILL_OPS(bf16, bfloat) From 156ebd1d84d6e08b11640a99cafb6d28d4ba30b5 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 2 Oct 2024 05:30:30 -0400 Subject: [PATCH 73/75] Fix dtype cast --- candle-core/src/quantized/metal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index 8f12769060..038ba7b531 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -220,7 +220,7 @@ impl QMetalStorage { blit.end_encoding(); } self.device.wait_until_completed()?; - Ok(read_to_vec::(&buffer, self.buffer.length())) + Ok(read_to_vec::(&buffer, self.buffer.length() as usize)) } } From 20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 3 Oct 2024 14:19:18 -0400 Subject: [PATCH 74/75] Fix set_dtype --- candle-nn/src/var_builder.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 8728764b7b..dfd4977042 100644 --- a/candle-nn/src/var_builder.rs +++ 
b/candle-nn/src/var_builder.rs @@ -236,6 +236,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { dtype, device: self.data.device.clone(), }), + dtype, ..self } } From 121bdfdadc537f6664a8b5874ad0e5b7495e227e Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sun, 6 Oct 2024 16:45:41 -0400 Subject: [PATCH 75/75] Add initial f8 e4m3 type --- .vscode/settings.json | 5 + Cargo.toml | 1 + candle-core/Cargo.toml | 3 +- candle-core/src/convert.rs | 6 + candle-core/src/cpu_backend/mod.rs | 133 ++++++++++ candle-core/src/cpu_backend/utils.rs | 6 + candle-core/src/cuda_backend/device.rs | 35 ++- candle-core/src/cuda_backend/mod.rs | 15 ++ candle-core/src/cuda_backend/utils.rs | 8 + candle-core/src/display.rs | 5 + candle-core/src/dtype.rs | 23 +- candle-core/src/npy.rs | 10 + candle-core/src/op.rs | 67 +++++ candle-core/src/safetensors.rs | 5 + candle-core/src/sort.rs | 1 + candle-kernels/src/affine.cu | 28 ++- candle-kernels/src/binary.cu | 15 ++ candle-kernels/src/cast.cu | 14 ++ candle-kernels/src/compatibility.cuh | 1 + candle-kernels/src/conv.cu | 12 + candle-kernels/src/cuda_utils.cuh | 23 ++ candle-kernels/src/fill.cu | 5 + candle-kernels/src/fused_layer_norm.cu | 329 ------------------------- candle-kernels/src/indexing.cu | 95 +++++++ candle-kernels/src/kvconcat.cu | 1 + candle-kernels/src/lib.rs | 1 - candle-kernels/src/reduce.cu | 8 + candle-kernels/src/sort.cu | 3 + candle-kernels/src/ternary.cu | 6 + candle-kernels/src/unary.cu | 27 ++ candle-pyo3/Cargo.toml | 1 + candle-pyo3/src/lib.rs | 3 + 32 files changed, 548 insertions(+), 347 deletions(-) delete mode 100644 candle-kernels/src/fused_layer_norm.cu diff --git a/.vscode/settings.json b/.vscode/settings.json index 6abf0d3d6d..280ea2e7f0 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -11,4 +11,9 @@ "rust-analyzer.cargo.features": [ "cuda", ], + "files.associations": { + "random": "cpp", + "ratio": "cpp", + "cmath": "cpp" + }, } \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index d6cf18614f..6dc5e85cd1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.3.0" half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +float8 = { version = "0.1.0", features = ["num-traits", "rand_distr"], git = "https://github.com/EricLBuehler/float8.git" } hound = "3.5.1" image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] } imageproc = { version = "0.24.0", default-features = false } diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index cbf8f2007f..6ce7e31e1c 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -18,6 +18,7 @@ metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } half = { workspace = true } +float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } libc = { workspace = true, optional = true } memmap2 = { workspace = true } @@ -39,7 +40,7 @@ criterion = { workspace = true } [features] default = [] -cuda = ["cudarc", "dep:candle-kernels"] +cuda = ["cudarc", "dep:candle-kernels", "float8/cuda"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs index 3e19d970c3..173a96d6e6 100644 --- a/candle-core/src/convert.rs +++ b/candle-core/src/convert.rs @@ -1,5 +1,6 @@ //! 
Implement conversion traits for tensors use crate::{DType, Device, Error, Tensor, WithDType}; +use float8::F8E4M3; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::convert::TryFrom; @@ -149,6 +150,11 @@ impl Tensor { let vs = vs.to_vec1::()?; f.write_all(&vs)?; } + DType::F8E4M3 => { + for v in vs.to_vec1::()? { + f.write_u8(v.to_bits())? + } + } } Ok(()) } diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 24ce83581c..6ef74c0725 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -3,6 +3,7 @@ use std::ops::Deref; use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; +use float8::F8E4M3; use half::{bf16, f16}; use rayon::prelude::*; @@ -29,6 +30,7 @@ pub enum CpuStorage { F16(Vec), F32(Vec), F64(Vec), + F8E4M3(Vec), } #[derive(Debug, Clone)] @@ -42,6 +44,7 @@ pub enum CpuStorageRef<'a> { F16(&'a [f16]), F32(&'a [f32]), F64(&'a [f64]), + F8E4M3(&'a [F8E4M3]), } #[derive(Debug, Clone)] @@ -2366,6 +2369,17 @@ impl CpuStorage { .concat(); Self::F64(storages) } + Self::F8E4M3(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F8E4M3(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F8E4M3(storages) + } }; Ok(s) } @@ -2385,6 +2399,7 @@ impl BackendStorage for CpuStorage { Self::F16(_) => DType::F16, Self::F32(_) => DType::F32, Self::F64(_) => DType::F64, + Self::F8E4M3(_) => DType::F8E4M3, } } @@ -2427,6 +2442,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, bf16::from_f64); Ok(Self::BF16(data)) } + (Self::F8E4M3(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); + Ok(Self::BF16(data)) + } (Self::U8(storage), DType::F16) => { let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) @@ -2463,6 +2482,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, f16::from_f64); Ok(Self::F16(data)) } + (Self::F8E4M3(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32())); + Ok(Self::F16(data)) + } (Self::U8(storage), DType::F32) => { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) @@ -2499,6 +2522,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } + (Self::F8E4M3(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v.to_f32()); + Ok(Self::F32(data)) + } (Self::U8(storage), DType::U8) => { let data = unary_map(storage, layout, |v| v); Ok(Self::U8(data)) @@ -2535,6 +2562,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) } + (Self::F8E4M3(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } (Self::U8(storage), DType::U32) => { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) @@ -2571,6 +2602,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } + (Self::F8E4M3(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u32); + Ok(Self::U32(data)) + } (Self::U8(storage), DType::I16) => { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) @@ -2607,6 +2642,10 @@ impl BackendStorage 
for CpuStorage { let data = unary_map(storage, layout, |v| v as i16); Ok(Self::I16(data)) } + (Self::F8E4M3(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } (Self::U8(storage), DType::I32) => { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) @@ -2643,6 +2682,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as i32); Ok(Self::I32(data)) } + (Self::F8E4M3(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } (Self::U8(storage), DType::I64) => { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) @@ -2679,6 +2722,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) } + (Self::F8E4M3(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i64); + Ok(Self::I64(data)) + } (Self::U8(storage), DType::F64) => { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) @@ -2715,6 +2762,50 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::F64(data)) } + (Self::F8E4M3(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v.to_f64()); + Ok(Self::F64(data)) + } + (Self::U8(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::U32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I64(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::BF16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, F8E4M3::from_f32); + Ok(Self::F8E4M3(data)) + } + (Self::F64(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, F8E4M3::from_f64); + Ok(Self::F8E4M3(data)) + } + (Self::F8E4M3(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::F8E4M3(data)) + } } } @@ -2828,6 +2919,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v.powf(e)); Ok(Self::F64(data)) } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e))); + Ok(Self::F8E4M3(data)) + } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()), @@ -2855,6 +2950,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| elu(v, alpha)); Ok(Self::F64(data)) } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha))); + Ok(Self::F8E4M3(data)) + } Self::U8(_) => 
Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()), @@ -2901,6 +3000,15 @@ impl BackendStorage for CpuStorage { Ok(Self::F64(data)) } } + Self::F8E4M3(storage) => { + if B::F8E4M3_VEC { + let data = unary_map_vec(storage, layout, B::f8e4m3, B::f8e4m3_vec); + Ok(Self::F8E4M3(data)) + } else { + let data = unary_map(storage, layout, B::f8e4m3); + Ok(Self::F8E4M3(data)) + } + } Self::U8(storage) => { let data = unary_map(storage, layout, B::u8); Ok(Self::U8(data)) @@ -3455,6 +3563,15 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + rand::distributions::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max)); + for _i in 0..elem_count { + data.push(rng.sample::(uniform)) + } + Ok(CpuStorage::F8E4M3(data)) + } DType::F32 => { let mut data = Vec::with_capacity(elem_count); let uniform = rand::distributions::Uniform::new(min as f32, max as f32); @@ -3501,6 +3618,15 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std)) + .map_err(Error::wrap)?; + for _i in 0..elem_count { + data.push(normal.sample(&mut rng)) + } + Ok(CpuStorage::F8E4M3(data)) + } DType::F32 => { let mut data = Vec::with_capacity(elem_count); let normal = @@ -3574,6 +3700,11 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::F64(v) } + DType::F8E4M3 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F8E4M3(v) + } }; Ok(storage) } @@ -3588,6 +3719,7 @@ impl BackendDevice for CpuDevice { DType::I64 => CpuStorage::I64(vec![1i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ONE; elem_count]), DType::F32 => CpuStorage::F32(vec![1f32; elem_count]), DType::F64 => CpuStorage::F64(vec![1f64; elem_count]), }; @@ -3604,6 +3736,7 @@ impl BackendDevice for CpuDevice { DType::I64 => CpuStorage::I64(vec![0i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]), DType::F32 => CpuStorage::F32(vec![0f32; elem_count]), DType::F64 => CpuStorage::F64(vec![0f64; elem_count]), }; diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 20f362e8c4..495fcd660b 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -17,6 +17,7 @@ pub trait Map1 { C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)), C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)), + C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)), } } } @@ -35,6 +36,7 @@ pub trait Map1Any { C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), C::F32(vs) => Ok(self.f(vs, layout, C::F32)?), C::F64(vs) => Ok(self.f(vs, layout, C::F64)?), + C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?), } } } @@ -52,6 +54,7 @@ pub trait Map2 { (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, 
l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), @@ -96,6 +99,7 @@ pub trait Map3 { (C::F16(v1), C::F16(v2), C::F16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), (C::F32(v1), C::F32(v2), C::F32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), (C::F64(v1), C::F64(v2), C::F64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F8E4M3(v1), C::F8E4M3(v2), C::F8E4M3(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), _ => Err(Error::DTypeMismatchBinaryOp3 { lhs: v1.dtype(), rhs: v2.dtype(), @@ -129,6 +133,7 @@ pub trait Map2Alpha { (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2, s)?)), (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2, s)?)), (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2, s)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2, s)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), @@ -152,6 +157,7 @@ pub trait Map2U8 { (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index ccca8c039c..1d8cb7d34b 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -3,6 +3,7 @@ use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; pub use candle_kernels as kernels; pub use cudarc; use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig}; +use float8::F8E4M3; use half::{bf16, f16}; use std::sync::{Arc, Mutex, RwLock}; @@ -136,6 +137,14 @@ impl CudaDevice { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + // SAFETY: Set later by running the fill kernel. 
+ let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f8_e4m3", kernels::FILL)?; + let params = (&data, v, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -243,6 +252,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F64(data) + } }; Ok(CudaStorage { slice, @@ -262,7 +275,8 @@ impl BackendDevice for CudaDevice { | DType::I32 | DType::I16 | DType::F16 - | DType::BF16 => Err(CudaError::UnsupportedDtype { + | DType::BF16 + | DType::F8E4M3 => Err(CudaError::UnsupportedDtype { dtype, op: "rand_uniform", }) @@ -310,7 +324,8 @@ impl BackendDevice for CudaDevice { | DType::I32 | DType::I64 | DType::F16 - | DType::BF16 => Err(CudaError::UnsupportedDtype { + | DType::BF16 + | DType::F8E4M3 => Err(CudaError::UnsupportedDtype { dtype, op: "rand_normal", }) @@ -378,6 +393,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -423,6 +442,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorageRef::F8E4M3(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -468,6 +491,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorage::F8E4M3(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -513,6 +540,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorage::F8E4M3(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 1a394d4b58..9e8e099b3c 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -7,6 +7,7 @@ use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{ CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits, }; +use float8::F8E4M3; use half::{bf16, f16}; #[cfg(feature = "cudnn")] @@ -54,6 +55,7 @@ pub enum CudaStorageSlice { F16(CudaSlice), F32(CudaSlice), F64(CudaSlice), + F8E4M3(CudaSlice), } struct Clone; @@ -1183,6 +1185,7 @@ impl BackendStorage for CudaStorage { CudaStorageSlice::F16(_) => DType::F16, CudaStorageSlice::F32(_) => DType::F32, CudaStorageSlice::F64(_) => DType::F64, + CudaStorageSlice::F8E4M3(_) => DType::F8E4M3, } } @@ -1211,6 +1214,7 @@ impl BackendStorage for CudaStorage { CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F8E4M3(inp) => *inp.slice(start_o..).device_ptr(), }; let inp = &inp; @@ -1271,6 +1275,12 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::F64(out) } + DType::F8E4M3 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); 
+ unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F8E4M3(out) + } }; Ok(Self { slice, @@ -1372,6 +1382,11 @@ impl BackendStorage for CudaStorage { let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::F64(cpu_storage)) } + CudaStorageSlice::F8E4M3(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::F8E4M3(cpu_storage)) + } } } diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index df06756d78..581d687aac 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -26,6 +26,7 @@ pub trait Map1 { S::F16(s) => S::F16(self.f(s, d, l)?), S::F32(s) => S::F32(self.f(s, d, l)?), S::F64(s) => S::F64(self.f(s, d, l)?), + S::F8E4M3(s) => S::F8E4M3(self.f(s, d, l)?), }; Ok(out) } @@ -50,6 +51,7 @@ pub trait Map2 { (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), + (S::F8E4M3(s1), S::F8E4M3(s2)) => S::F8E4M3(self.f(s1, l1, s2, l2, d)?), _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, }; Ok(out) @@ -88,6 +90,9 @@ pub trait Map3 { (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F8E4M3(s1), S::F8E4M3(s2), S::F8E4M3(s3)) => { + S::F8E4M3(self.f(s1, l1, s2, l2, s3, l3, d)?) + } _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?, }; Ok(out) @@ -120,6 +125,7 @@ pub trait Map2InPlace { (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d), (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d), (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F8E4M3(dst), S::F8E4M3(src)) => self.f(dst, dst_s, src, src_l, d), _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, } } @@ -145,6 +151,7 @@ pub trait Map1Any { S::F16(s) => self.f(s, d, l, S::F16)?, S::F32(s) => self.f(s, d, l, S::F32)?, S::F64(s) => self.f(s, d, l, S::F64)?, + S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, }; Ok(out) } @@ -169,6 +176,7 @@ pub trait Map2Any { (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?, }; Ok(out) diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 50e0129aeb..c975440aa9 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -2,6 +2,7 @@ /// This implementation should be in line with the PyTorch version. 
/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py use crate::{DType, Result, Tensor, WithDType}; +use float8::F8E4M3; use half::{bf16, f16}; impl Tensor { @@ -62,6 +63,7 @@ impl std::fmt::Debug for Tensor { DType::F16 => self.fmt_dt::(f), DType::F32 => self.fmt_dt::(f), DType::F64 => self.fmt_dt::(f), + DType::F8E4M3 => self.fmt_dt::(f), } } } @@ -511,6 +513,9 @@ impl std::fmt::Display for Tensor { writeln!(f)?; } } + DType::F8E4M3 => { + return write!(f, "F8E4M3 does not support display."); + } }; let device_str = match self.device().location() { diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 42d3b1eef9..f40ec3f7e1 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -1,11 +1,14 @@ //! Types for elements that can be stored and manipulated using tensors. #![allow(clippy::redundant_closure_call)] use crate::backend::BackendStorage; +use crate::cpu::kernels::VecOps; use crate::{CpuStorage, CpuStorageRef, Error, Result}; /// The different types of elements allowed in tensors. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DType { + // Floating-point 8 bits integer (4-bit exponent, 3-bit mantissa). + F8E4M3, // Unsigned 8 bits integer. U8, // Unsigned 32 bits integer. @@ -50,6 +53,7 @@ impl std::str::FromStr for DType { "f16" => Ok(Self::F16), "f32" => Ok(Self::F32), "f64" => Ok(Self::F64), + "f8_e4m3" => Ok(Self::F8E4M3), _ => Err(DTypeParseError(s.to_string())), } } @@ -68,6 +72,7 @@ impl DType { Self::F16 => "f16", Self::F32 => "f32", Self::F64 => "f64", + Self::F8E4M3 => "f8_e4m3", } } @@ -75,6 +80,7 @@ impl DType { pub fn size_in_bytes(&self) -> usize { match self { Self::U8 => 1, + Self::F8E4M3 => 1, Self::U32 => 4, Self::I16 => 2, Self::I32 => 4, @@ -89,14 +95,14 @@ impl DType { pub fn is_int(&self) -> bool { match self { Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => true, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false, + Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => false, } } pub fn is_float(&self) -> bool { match self { Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => false, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true, + Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => true, } } } @@ -175,6 +181,7 @@ macro_rules! with_dtype { } }; } +use float8::F8E4M3; use half::{bf16, f16}; with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); @@ -186,6 +193,17 @@ with_dtype!(f16, F16, f16::from_f64, f16::to_f64); with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64); with_dtype!(f64, F64, |v: f64| v, |v: f64| v); +with_dtype!(F8E4M3, F8E4M3, |v: f64| F8E4M3::from_f64(v), |v: F8E4M3| v + .to_f64()); + +impl VecOps for F8E4M3 { + fn max(self, rhs: Self) -> Self { + F8E4M3::max(self, rhs) + } + fn min(self, rhs: Self) -> Self { + F8E4M3::min(self, rhs) + } +} pub trait IntDType: WithDType { fn is_true(&self) -> bool; @@ -243,3 +261,4 @@ impl FloatDType for f16 {} impl FloatDType for bf16 {} impl FloatDType for f32 {} impl FloatDType for f64 {} +impl FloatDType for F8E4M3 {} diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 33a4f4c728..28d5a63e90 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -27,11 +27,13 @@ //! 
``` use crate::{DType, Device, Error, Result, Shape, Tensor}; use byteorder::{LittleEndian, ReadBytesExt}; +use float8::F8E4M3; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Write}; use std::path::Path; +use std::slice; const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY"; const NPY_SUFFIX: &str = ".npy"; @@ -90,6 +92,7 @@ impl Header { DType::I64 => "i8", DType::U32 => "u4", DType::U8 => "u1", + DType::F8E4M3 => Err(Error::Npy("f8e4m3 is not supported".into()))?, }; if !shape.is_empty() { shape.push(',') @@ -251,6 +254,13 @@ impl Tensor { reader.read_i64_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::F8E4M3 => { + let mut data_t = vec![F8E4M3::ZERO; elem_count]; + let ptr = data_t.as_mut_ptr().cast::(); + let len = data_t.len(); + reader.read_i8_into(unsafe { slice::from_raw_parts_mut(ptr, len) })?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } } } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 3786a82aaf..208977913a 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -1,5 +1,6 @@ #![allow(clippy::redundant_closure_call)] use crate::Tensor; +use float8::F8E4M3; use half::{bf16, f16}; use num_traits::float::Float; @@ -187,6 +188,7 @@ pub trait UnaryOpT { fn f16(v1: f16) -> f16; fn f32(v1: f32) -> f32; fn f64(v1: f64) -> f64; + fn f8e4m3(v1: F8E4M3) -> F8E4M3; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; fn i16(v1: i16) -> i16; @@ -199,6 +201,8 @@ pub trait UnaryOpT { fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {} const F16_VEC: bool = false; fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {} + const F8E4M3_VEC: bool = false; + fn f8e4m3_vec(_xs: &[F8E4M3], _ys: &mut [F8E4M3]) {} const F32_VEC: bool = false; fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; @@ -213,6 +217,7 @@ pub trait BinaryOpT { fn f16(v1: f16, v2: f16) -> f16; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; + fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3; fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; fn i16(v1: i16, v2: i16) -> i16; @@ -227,6 +232,8 @@ pub trait BinaryOpT { fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {} + const F8E4M3_VEC: bool = false; + fn f8e4m3_vec(_xs1: &[F8E4M3], __xs2: &[F8E4M3], _ys: &mut [F8E4M3]) {} const U8_VEC: bool = false; fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {} const U32_VEC: bool = false; @@ -288,6 +295,10 @@ macro_rules! bin_op { $e(v1, v2) } #[inline(always)] + fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3 { + $e(v1, v2) + } + #[inline(always)] fn u8(v1: u8, v2: u8) -> u8 { $e(v1, v2) } @@ -376,6 +387,10 @@ macro_rules! unary_op { $e } #[inline(always)] + fn f8e4m3($a: F8E4M3) -> F8E4M3 { + $e + } + #[inline(always)] fn f32($a: f32) -> f32 { $e } @@ -428,6 +443,10 @@ macro_rules! 
unary_op { $e } #[inline(always)] + fn f8e4m3($a: F8E4M3) -> F8E4M3 { + $e + } + #[inline(always)] fn u8(_: u8) -> u8 { todo!("no unary function for u8") } @@ -527,6 +546,17 @@ impl UnaryOpT for Gelu { )) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f32(0.5) + * v + * (F8E4M3::ONE + + F8E4M3::tanh( + F8E4M3::from_f32(SQRT_TWO_OVER_PI_F32) + * v + * (F8E4M3::ONE + F8E4M3::from_f32(0.044715) * v * v), + )) + } + #[inline(always)] fn f32(v: f32) -> f32 { 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v))) } @@ -608,6 +638,10 @@ impl UnaryOpT for Erf { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] fn f32(v: f32) -> f32 { Self::f64(v as f64) as f32 } @@ -650,6 +684,10 @@ impl UnaryOpT for Silu { v / (f16::ONE + (-v).exp()) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v / (F8E4M3::ONE + (-v).exp()) + } + #[inline(always)] fn f32(v: f32) -> f32 { v / (1.0 + (-v).exp()) } @@ -729,6 +767,10 @@ impl UnaryOpT for Abs { v.abs() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.abs() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.abs() } @@ -771,6 +813,10 @@ impl UnaryOpT for Ceil { v.ceil() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.ceil() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.ceil() } @@ -813,6 +859,10 @@ impl UnaryOpT for Floor { v.floor() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.floor() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.floor() } @@ -855,6 +905,10 @@ impl UnaryOpT for Round { v.round() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.round() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.round() } @@ -897,6 +951,10 @@ impl UnaryOpT for GeluErf { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] fn f32(v: f32) -> f32 { Self::f64(v as f64) as f32 } @@ -939,6 +997,10 @@ impl UnaryOpT for Relu { v.max(f16::ZERO) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.max(F8E4M3::ZERO) + } + #[inline(always)] fn f32(v: f32) -> f32 { v.max(0f32) } @@ -1045,6 +1107,11 @@ impl UnaryOpT for Sign { f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from((v > F8E4M3::ZERO) as i8 as f32) + - F8E4M3::from((v < F8E4M3::ZERO) as i8 as f32) + } + #[inline(always)] fn f32(v: f32) -> f32 { f32::from(v > 0.) - f32::from(v < 0.) 
} diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 12436a0903..52df166313 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -1,4 +1,5 @@ use crate::{DType, Device, Error, Result, Tensor, WithDType}; +use float8::F8E4M3; use safetensors::tensor as st; use safetensors::tensor::SafeTensors; use std::borrow::Cow; @@ -17,6 +18,7 @@ impl From for st::Dtype { DType::F16 => st::Dtype::F16, DType::F32 => st::Dtype::F32, DType::F64 => st::Dtype::F64, + DType::F8E4M3 => st::Dtype::F8_E4M3, } } } @@ -32,6 +34,7 @@ impl TryFrom for DType { st::Dtype::F16 => Ok(DType::F16), st::Dtype::F32 => Ok(DType::F32), st::Dtype::F64 => Ok(DType::F64), + st::Dtype::F8_E4M3 => Ok(DType::F8E4M3), dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } @@ -196,6 +199,7 @@ impl Tensor { DType::F16 => convert_slice::(data, shape, device), DType::F32 => convert_slice::(data, shape, device), DType::F64 => convert_slice::(data, shape, device), + DType::F8E4M3 => convert_slice::(data, shape, device), } } } @@ -232,6 +236,7 @@ fn convert_back(tensor: &Tensor) -> Result> { DType::BF16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F64 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F8E4M3 => Ok(convert_back_::(tensor.to_vec1()?)), } } diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index b48f74ba5c..c7236e7f5f 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -72,6 +72,7 @@ impl crate::CustomOp1 for ArgSort { crate::CpuStorage::F16(vs) => self.asort(vs, layout), crate::CpuStorage::F32(vs) => self.asort(vs, layout), crate::CpuStorage::F64(vs) => self.asort(vs, layout), + crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout), }; let sort_indexes = crate::CpuStorage::U32(sort_indexes); Ok((sort_indexes, layout.shape().into())) diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu index 301bcd5a64..ef75dffd36 100644 --- a/candle-kernels/src/affine.cu +++ b/candle-kernels/src/affine.cu @@ -1,7 +1,7 @@ #include "cuda_utils.cuh" #include -#define AFFINE_OP(TYPENAME, FN_NAME) \ +#define AFFINE_OP(TYPENAME, FN_NAME, AFFINE) \ extern "C" __global__ void FN_NAME( \ const size_t numel, \ const size_t num_dims, \ @@ -16,30 +16,34 @@ extern "C" __global__ void FN_NAME( \ if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ TYPENAME x = inp ? inp[i] : out[i]; \ - out[i] = x * mul + add; \ + out[i] = AFFINE; \ } \ } \ else { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ TYPENAME x = inp ? 
inp[strided_i] : out[i]; \ - out[i] = x * mul + add; \ + out[i] = AFFINE; \ } \ } \ } \ #if __CUDA_ARCH__ >= 800 -AFFINE_OP(__nv_bfloat16, affine_bf16) +AFFINE_OP(__nv_bfloat16, affine_bf16, x * mul + add) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add))) #endif #if __CUDA_ARCH__ >= 530 -AFFINE_OP(__half, affine_f16) +AFFINE_OP(__half, affine_f16, x * mul + add) #endif -AFFINE_OP(float, affine_f32) -AFFINE_OP(double, affine_f64) -AFFINE_OP(uint8_t, affine_u8) -AFFINE_OP(uint32_t, affine_u32) -AFFINE_OP(int16_t, affine_i16) -AFFINE_OP(int32_t, affine_i32) -AFFINE_OP(int64_t, affine_i64) +AFFINE_OP(float, affine_f32, x * mul + add) +AFFINE_OP(double, affine_f64, x * mul + add) +AFFINE_OP(uint8_t, affine_u8, x * mul + add) +AFFINE_OP(uint32_t, affine_u32, x * mul + add) +AFFINE_OP(int16_t, affine_i16, x * mul + add) +AFFINE_OP(int32_t, affine_i32, x * mul + add) +AFFINE_OP(int64_t, affine_i64, x * mul + add) diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu index 99ab23b875..7bda3e463e 100644 --- a/candle-kernels/src/binary.cu +++ b/candle-kernels/src/binary.cu @@ -14,6 +14,21 @@ BINARY_OP_OUT(__nv_bfloat16, uint8_t, lt_bf16, x < y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, le_bf16, x <= y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, gt_bf16, x > y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +BINARY_OP(__nv_fp8_e4m3, badd_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) + F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bdiv_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) / F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bmul_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bsub_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) - F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bmaximum_f8_e4m3, maxg(x, y)) +BINARY_OP(__nv_fp8_e4m3, bminimum_f8_e4m3, ming(x, y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, eq_f8_e4m3, F8E4M3_TO_FLOAT(x) == F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ne_f8_e4m3, F8E4M3_TO_FLOAT(x) != F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, lt_f8_e4m3, F8E4M3_TO_FLOAT(x) < F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, le_f8_e4m3, F8E4M3_TO_FLOAT(x) <= F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, gt_f8_e4m3, F8E4M3_TO_FLOAT(x) > F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ge_f8_e4m3, F8E4M3_TO_FLOAT(x) >= F8E4M3_TO_FLOAT(y)) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index e288bf1812..7176825b8d 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -98,6 +98,20 @@ CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16) CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16) CAST_THROUGH_OP(int32_t, __nv_bfloat16, float, cast_i32_bf16) CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32) +CAST_THROUGH_OP(__nv_bfloat16, __nv_fp8_e4m3, float, cast_bf16_f8_e4m3) + +CAST_OP(__nv_fp8_e4m3, float, cast_f8_e4m3_f32) +CAST_OP(float, __nv_fp8_e4m3, cast_f32_f8_e4m3) +CAST_THROUGH_OP(__nv_fp8_e4m3, uint8_t, float, cast_f8_e4m3_u8) +CAST_THROUGH_OP(__nv_fp8_e4m3, __half, float, cast_f8_e4m3_f16) +CAST_THROUGH_OP(__nv_fp8_e4m3, double, float, cast_f8_e4m3_f64) +CAST_THROUGH_OP(__half, __nv_fp8_e4m3, float, cast_f16_f8_e4m3) 
+CAST_THROUGH_OP(double, __nv_fp8_e4m3, float, cast_f64_f8_e4m3) +CAST_THROUGH_OP(uint8_t, __nv_fp8_e4m3, float, cast_u8_f8_e4m3) +CAST_THROUGH_OP(int32_t, __nv_fp8_e4m3, float, cast_i32_f8_e4m3) +CAST_THROUGH_OP(__nv_fp8_e4m3, int32_t, float, cast_f8_e4m3_i32) +CAST_THROUGH_OP(__nv_fp8_e4m3, __nv_bfloat16, float, cast_f8_e4m3_bf16) +CAST_THROUGH_OP(__nv_bfloat16, __nv_fp8_e4m3, float, cast_bf16_f8_e4m3) #endif #endif diff --git a/candle-kernels/src/compatibility.cuh b/candle-kernels/src/compatibility.cuh index d0791749bb..1e4cf215c1 100644 --- a/candle-kernels/src/compatibility.cuh +++ b/candle-kernels/src/compatibility.cuh @@ -1,5 +1,6 @@ #include "cuda_fp16.h" #include "cuda_bf16.h" +#include "cuda_fp8.h" // Table showing which features are supported on which compute capability // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index fa834faa3a..6ca6fd7c2b 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -702,6 +702,18 @@ UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16) IM2COL_OP(__nv_bfloat16, im2col_bf16) IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16) COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16) + +// NOTE: No conv ops for f8 +// CONV1D_OP(__nv_bfloat16, float, conv1d_f8_e5m) +// CONV2D_OP(__nv_fp8_e4m3, float, conv2d_f8_e5m) +// CONVT1D_OP(__nv_fp8_e4m3, float, conv_transpose1d_f8_e5m) +// CONVT2D_OP(__nv_fp8_e4m3, float, conv_transpose2d_f8_e5m) +// AVG_POOL2D_OP(__nv_fp8_e4m3, float, avg_pool2d_f8_e5m) +// MAX_POOL2D_OP(__nv_fp8_e4m3, max_pool2d_f8_e5m) +// UPSAMPLE_NEAREST2D_OP(__nv_fp8_e4m3, upsample_nearest2d_f8_e5m) +// IM2COL_OP(__nv_fp8_e4m3, im2col_f8_e5m) +// IM2COL1D_OP(__nv_fp8_e4m3, im2col1d_f8_e5m) +// COL2IM1D_OP(__nv_fp8_e4m3, col2im1d_f8_e5m) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index df1497f672..da8a1fe1c1 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -231,4 +231,27 @@ __device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); __device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); } __device__ __forceinline__ __nv_bfloat16 absg(__nv_bfloat16 a) { return __habs(a); } __device__ __forceinline__ __nv_bfloat16 copysigng(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(copysignf(__bfloat162float(a), __bfloat162float(b))); } + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +__device__ __forceinline__ __nv_fp8_e4m3 powg(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(powf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ bool isnang(__nv_fp8_e4m3 a) { return isnanf(F8E4M3_TO_FLOAT(a)); } +__device__ __forceinline__ __nv_fp8_e4m3 sqrtg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(sqrtf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 cosg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(cosf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 sing(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(sinf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 recipg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(1. 
/ F8E4M3_TO_FLOAT(a)); } +__device__ __forceinline__ __nv_fp8_e4m3 maxg(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(fmaxf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ __nv_fp8_e4m3 tanhg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(tanhf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 erfg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(erff(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 ceilg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(ceilf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 floorg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(floorf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 roundg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(roundf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 normcdfg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(normcdff(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 ming(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(fminf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ __nv_fp8_e4m3 logg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(logf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 expg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(expf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 absg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(fabsf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 copysigng(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(copysignf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } + + #endif diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu index 0654c2631b..eeea8d4cd4 100644 --- a/candle-kernels/src/fill.cu +++ b/candle-kernels/src/fill.cu @@ -47,6 +47,11 @@ COPY2D_OP(__half, copy2d_f16) #if __CUDA_ARCH__ >= 800 #include +#include + extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); } COPY2D_OP(__nv_bfloat16, copy2d_bf16) + +extern "C" __global__ void fill_f8_e4m3(__nv_fp8_e4m3 *buf, __nv_fp8_e4m3 value, const size_t numel) { fill_with(buf, value, numel); } +COPY2D_OP(__nv_fp8_e4m3, copy2d_f8_e4m3) #endif diff --git a/candle-kernels/src/fused_layer_norm.cu b/candle-kernels/src/fused_layer_norm.cu deleted file mode 100644 index cea64c519b..0000000000 --- a/candle-kernels/src/fused_layer_norm.cu +++ /dev/null @@ -1,329 +0,0 @@ -// Based on https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/multihead_attn/layer_norm.cuh#L243 -// Modified Eric Buehler 2024 - -#include "cuda_fp16.h" -#include -#include - -#if __CUDA_ARCH__ >= 800 -#include -#endif - -template -__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) { - count = count + U(1); - U delta = curr - mu; - U lmean = mu + delta / count; - mu = lmean; - U delta2 = curr - lmean; - sigma2 = sigma2 + delta * delta2; -} - -template -__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, - U &mu, U &sigma2, U &count) { - U delta = muB - mu; - U nA = count; - U nB = countB; - count = count + countB; - U nX = count; - if (nX > U(0)) { - nA = nA / nX; - nB = nB / nX; - mu = nA * mu + nB * muB; - sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; - } else { - mu = U(0); - sigma2 = U(0); - } -} - -// https://github.com/pytorch/pytorch/blob/7fe0cc53e903e515e86b4a350614011c66e3b32d/aten/src/ATen/cuda/DeviceUtils.cuh#L50 -template -__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 
0xffffffff) -{ -#if !defined(USE_ROCM) - return __shfl_sync(mask, value, srcLane, width); -#else - return __shfl(value, srcLane, width); -#endif -} - -template -__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, - const int n2, const int i1, U &mu, U &sigma2, - U *buf) { - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - U count = U(0); - mu = U(0); - sigma2 = U(0); - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const T *lvals = vals + i1 * n2; - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - U curr = static_cast(lvals[l + k]); - cuWelfordOnlineSum(curr, mu, sigma2, count); - } - } - for (; l < n2; ++l) { - U curr = static_cast(lvals[l]); - cuWelfordOnlineSum(curr, mu, sigma2, count); - } - // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - U muB = WARP_SHFL(mu, srcLaneB); - U countB = WARP_SHFL(count, srcLaneB); - U sigma2B = WARP_SHFL(sigma2, srcLaneB); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - U *ubuf = (U *)buf; - U *ibuf = (U *)(ubuf + blockDim.y); - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { - const int wrt_y = threadIdx.y - offset; - ubuf[2 * wrt_y] = mu; - ubuf[2 * wrt_y + 1] = sigma2; - ibuf[wrt_y] = count; - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - U muB = ubuf[2 * threadIdx.y]; - U sigma2B = ubuf[2 * threadIdx.y + 1]; - U countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - ubuf[0] = mu; - ubuf[1] = sigma2; - } - __syncthreads(); - mu = ubuf[0]; - sigma2 = ubuf[1] / U(n2); - // don't care about final value of count, we know count == n2 - } else { - mu = WARP_SHFL(mu, 0); - sigma2 = WARP_SHFL(sigma2 / U(n2), 0); - } - } -} - -template <> -__device__ void cuWelfordMuSigma2(const __half *__restrict__ vals, - const int n1, const int n2, const int i1, - float &mu, float &sigma2, float *buf) { - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
- // - // compute variance and mean over n2 - float count = 0.0f; - mu = float(0); - sigma2 = float(0); - - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const __half *lvals = vals + i1 * n2; - int l = 8 * thrx; - if ((((size_t)lvals) & 3) != 0) { - // 16 bit alignment - // first thread consumes first point - if (thrx == 0) { - float curr = static_cast(lvals[0]); - cuWelfordOnlineSum(curr, mu, sigma2, count); - } - ++l; - } - // at this point, lvals[l] are 32 bit aligned for all threads. - for (; l + 7 < n2; l += 8 * numx) { - for (int k = 0; k < 8; k += 2) { - float2 curr = __half22float2(*((__half2 *)(lvals + l + k))); - cuWelfordOnlineSum(curr.x, mu, sigma2, count); - cuWelfordOnlineSum(curr.y, mu, sigma2, count); - } - } - for (; l < n2; ++l) { - float curr = static_cast(lvals[l]); - cuWelfordOnlineSum(curr, mu, sigma2, count); - } - // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - float muB = WARP_SHFL(mu, srcLaneB); - float countB = WARP_SHFL(count, srcLaneB); - float sigma2B = WARP_SHFL(sigma2, srcLaneB); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - float *ubuf = (float *)buf; - float *ibuf = (float *)(ubuf + blockDim.y); - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { - const int wrt_y = threadIdx.y - offset; - ubuf[2 * wrt_y] = mu; - ubuf[2 * wrt_y + 1] = sigma2; - ibuf[wrt_y] = count; - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - float muB = ubuf[2 * threadIdx.y]; - float sigma2B = ubuf[2 * threadIdx.y + 1]; - float countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - ubuf[0] = mu; - ubuf[1] = sigma2; - } - __syncthreads(); - mu = ubuf[0]; - sigma2 = ubuf[1] / float(n2); - // don't care about final value of count, we know count == n2 - } else { - mu = WARP_SHFL(mu, 0); - sigma2 = WARP_SHFL(sigma2 / float(n2), 0); - } - } -} - -template __device__ U rsqrt(U v) { return U(1) / sqrt(v); } -template <> __device__ float rsqrt(float v) { return rsqrtf(v); } -template <> __device__ double rsqrt(double v) { return rsqrt(v); } -template <> __device__ __half rsqrt(__half v) { return rsqrt(v); } -#if __CUDA_ARCH__ >= 800 -template <> __device__ __nv_bfloat16 rsqrt(__nv_bfloat16 v) { return rsqrt(v); } -#endif - -// This is the un-specialized struct. Note that we prevent instantiation of -// this struct by putting an undefined symbol in the function body so it won't -// compile. 
-// template -// struct SharedMemory -// { -// // Ensure that we won't compile any un-specialized types -// __device__ T *getPointer() -// { -// extern __device__ void error(void); -// error(); -// return NULL; -// } -// }; -// https://github.com/NVIDIA/apex/issues/246 -template struct SharedMemory; -template <> struct SharedMemory { - __device__ float *getPointer() { - extern __shared__ float s_float[]; - return s_float; - } -}; - -template <> struct SharedMemory<__half> { - __device__ __half *getPointer() { - extern __shared__ __half s_half[]; - return s_half; - } -}; - -#if __CUDA_ARCH__ >= 800 -template <> struct SharedMemory<__nv_bfloat16> { - __device__ __nv_bfloat16 *getPointer() { - extern __shared__ __nv_bfloat16 s_bf[]; - return s_bf; - } -}; -#endif - -template -__device__ void -cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean, - U *__restrict__ invvar, const T *__restrict__ vals, - const int n1, const int n2, const U epsilon, - const T *__restrict__ gamma, const T *__restrict__ beta) { - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensors are contiguous - // - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { - SharedMemory shared; - U *buf = shared.getPointer(); - U mu, sigma2; - cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf); - const T *lvals = vals + i1 * n2; - T *ovals = output_vals + i1 * n2; - U c_invvar = rsqrt(sigma2 + epsilon); - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL && beta != NULL) { - for (int i = thrx; i < n2; i += numx) { - U curr = static_cast(lvals[i]); - ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; - } - } else { - for (int i = thrx; i < n2; i += numx) { - U curr = static_cast(lvals[i]); - ovals[i] = static_cast(c_invvar * (curr - mu)); - } - } - if (threadIdx.x == 0 && threadIdx.y == 0) { - mean[i1] = mu; - invvar[i1] = c_invvar; - } - } -} - -extern "C" __global__ void layernorm_f16(__half *__restrict__ output_vals, __half *__restrict__ mean, - __half *__restrict__ invvar, const __half *__restrict__ vals, - const int n1, const int n2, const __half epsilon, - const __half *__restrict__ gamma, const __half *__restrict__ beta) { - cuApplyLayerNorm(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta); -} - -extern "C" __global__ void layernorm_f32(float *__restrict__ output_vals, float *__restrict__ mean, - float *__restrict__ invvar, const float *__restrict__ vals, - const int n1, const int n2, const float epsilon, - const float *__restrict__ gamma, const float *__restrict__ beta) { - cuApplyLayerNorm(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta); -} - -#if __CUDA_ARCH__ >= 800 -#include -extern "C" __global__ void layernorm_bf16(__nv_bfloat16 *__restrict__ output_vals, __nv_bfloat16 *__restrict__ mean, - __nv_bfloat16 *__restrict__ invvar, const __nv_bfloat16 *__restrict__ vals, - const int n1, const int n2, const __nv_bfloat16 epsilon, - const __nv_bfloat16 *__restrict__ gamma, const __nv_bfloat16 *__restrict__ beta) { - cuApplyLayerNorm(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta); -} -#endif diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index df0e3a071d..52846a04bf 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -99,6 +99,57 @@ __device__ void index_add( } } +#if __CUDA_ARCH__ >= 800 +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +template +__device__ void 
scatter_add_f8( + const I *ids, + const __nv_fp8_e4m3 *inp, + __nv_fp8_e4m3 *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < src_dim_size; ++j) { + const size_t src_i = (pre * src_dim_size + j) * right_size + post; + const size_t idx = ids[src_i]; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(out[dst_i]) + F8E4M3_TO_FLOAT(inp[src_i])); + } + } +} + +template +__device__ void index_add_f8( + const I *ids, + const size_t ids_dim_size, + const __nv_fp8_e4m3 *inp, + __nv_fp8_e4m3 *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < ids_dim_size; ++j) { + const size_t idx = ids[j]; + const size_t src_i = (pre * ids_dim_size + j) * right_size + post; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(out[dst_i]) + F8E4M3_TO_FLOAT(inp[src_i])); + } + } +} +#endif + #define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ extern "C" __global__ void FN_NAME( \ const INDEX_TYPENAME *ids, \ @@ -111,6 +162,18 @@ extern "C" __global__ void FN_NAME( \ const size_t right_size \ ) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ +#define IA_OP_F8(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const size_t ids_dim_size, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { index_add_f8(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ + template __device__ void scatter_add( const I *ids, @@ -145,6 +208,17 @@ extern "C" __global__ void FN_NAME( \ const size_t right_size \ ) { scatter_add(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ +#define SA_OP_F8(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { scatter_add_f8(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ + #if __CUDA_ARCH__ >= 800 IS_OP(__nv_bfloat16, int16_t, is_i16_bf16) @@ -167,6 +241,27 @@ SA_OP(__nv_bfloat16, int32_t, sa_i32_bf16) SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16) SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16) SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16) + +IS_OP(__nv_fp8_e4m3, int16_t, is_i16_f8_e4m3) +IS_OP(__nv_fp8_e4m3, int32_t, is_i32_f8_e4m3) +IS_OP(__nv_fp8_e4m3, int64_t, is_i64_f8_e4m3) +IS_OP(__nv_fp8_e4m3, uint32_t, is_u32_f8_e4m3) +IS_OP(__nv_fp8_e4m3, uint8_t, is_u8_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int16_t, gather_i16_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int32_t, gather_i32_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int64_t, gather_i64_f8_e4m3) 
+GATHER_OP(__nv_fp8_e4m3, uint32_t, gather_u32_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, uint8_t, gather_u8_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int16_t, ia_i16_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int32_t, ia_i32_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int64_t, ia_i64_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, uint32_t, ia_u32_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, uint8_t, ia_u8_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int16_t, sa_i16_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int32_t, sa_i32_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int64_t, sa_i64_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, uint32_t, sa_u32_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, uint8_t, sa_u8_f8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/kvconcat.cu b/candle-kernels/src/kvconcat.cu index 7c78d3abe7..2bbd6c53a0 100644 --- a/candle-kernels/src/kvconcat.cu +++ b/candle-kernels/src/kvconcat.cu @@ -50,4 +50,5 @@ KVCONCAT_OP(__half, kvconcat_f16) #if __CUDA_ARCH__ >= 800 KVCONCAT_OP(__nv_bfloat16, kvconcat_bf16) +KVCONCAT_OP(__nv_fp8_e4m3, kvconcat_f8_e4m3) #endif \ No newline at end of file diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs index cec1b1e2d4..0bb490ca1c 100644 --- a/candle-kernels/src/lib.rs +++ b/candle-kernels/src/lib.rs @@ -3,7 +3,6 @@ pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx")); pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx")); pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx")); pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx")); -pub const FUSED_LAYER_NORM: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_layer_norm.ptx")); pub const FUSED_RMS_NORM: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rms_norm.ptx")); pub const FUSED_ROPE: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rope.ptx")); pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx")); diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index fe2e30160a..f42cad471e 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -580,6 +580,14 @@ LAYERNORM_OP(__nv_bfloat16, layernorm_bf16) ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16) SUM_OP(__nv_bfloat16, sum_bf16) FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16) + +// NOTE: No reduce ops for f8 +// SUM_OP(__nv_fp8_e4m3, sum_fp8_e4m3) +// SOFTMAX_OP(__nv_fp8_e4m3, float, softmax_fp8_e4m3) +// RMSNORM_OP(__nv_fp8_e4m3, rmsnorm_fp8_e4m3) +// LAYERNORM_OP(__nv_fp8_e4m3, layernorm_fp8_e4m3) +// ROPE_OP(__nv_fp8_e4m3, rope_fp8_e4m3, rope_i_fp8_e4m3, rope_thd_fp8_e4m3) +// FAST_OP(__nv_fp8_e4m3, fast_min_fp8_e4m3, fast_max_fp8_e4m3, fast_argmin_fp8_e4m3, fast_argmax_fp8_e4m3, fast_sum_fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu index f2b2e9d458..7db1b20ec5 100644 --- a/candle-kernels/src/sort.cu +++ b/candle-kernels/src/sort.cu @@ -75,6 +75,9 @@ extern "C" __global__ void asort_desc_##RUST_NAME( \ #if __CUDA_ARCH__ >= 800 ASORT_OP(__nv_bfloat16, bf16) + +// NOTE: No sort ops for f8 +// ASORT_OP(__nv_fp8_e4m3, fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu index 18beede021..c426640b39 100644 --- a/candle-kernels/src/ternary.cu +++ b/candle-kernels/src/ternary.cu @@ -38,6 +38,12 @@ WHERE_OP(__nv_bfloat16, int32_t, where_i32_bf16) WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16) WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16) 
WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16) + +WHERE_OP(__nv_fp8_e4m3, int16_t, where_i16_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, int32_t, where_i32_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, int64_t, where_i64_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, uint32_t, where_u32_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, uint8_t, where_u8_fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index bfd60de0b1..ba899e643c 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -122,6 +122,33 @@ UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x)) UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param)) UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x)) UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x)) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +UNARY_OP(__nv_fp8_e4m3, ucopy_fp8_e4m3, x) +UNARY_OP(__nv_fp8_e4m3, uneg_fp8_e4m3, __nv_fp8_e4m3(-F8E4M3_TO_FLOAT(x))) +UNARY_OP(__nv_fp8_e4m3, urecip_fp8_e4m3, recipg(x)) +UNARY_OP(__nv_fp8_e4m3, uexp_fp8_e4m3, expg(x)) +UNARY_OP(__nv_fp8_e4m3, ulog_fp8_e4m3, logg(x)) +UNARY_OP(__nv_fp8_e4m3, usin_fp8_e4m3, sing(x)) +UNARY_OP(__nv_fp8_e4m3, ucos_fp8_e4m3, cosg(x)) +UNARY_OP(__nv_fp8_e4m3, utanh_fp8_e4m3, tanhg(x)) +UNARY_OP(__nv_fp8_e4m3, uerf_fp8_e4m3, erfg(x)) +UNARY_OP(__nv_fp8_e4m3, uceil_fp8_e4m3, ceilg(x)) +UNARY_OP(__nv_fp8_e4m3, ufloor_fp8_e4m3, floorg(x)) +UNARY_OP(__nv_fp8_e4m3, uround_fp8_e4m3, roundg(x)) +UNARY_OP(__nv_fp8_e4m3, unormcdf_fp8_e4m3, normcdfg(x)) +UNARY_OP(__nv_fp8_e4m3, uabs_fp8_e4m3, absg(x)) +UNARY_OP(__nv_fp8_e4m3, usqr_fp8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x)*F8E4M3_TO_FLOAT(x))) +UNARY_OP(__nv_fp8_e4m3, usqrt_fp8_e4m3, sqrtg(x)) +UNARY_OP(__nv_fp8_e4m3, ugelu_fp8_e4m3, __nv_fp8_e4m3(gelu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, ugelu_erf_fp8_e4m3, __nv_fp8_e4m3(gelu_erf_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, urelu_fp8_e4m3, __nv_fp8_e4m3(relu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP1(__nv_fp8_e4m3, uelu_fp8_e4m3, __nv_fp8_e4m3(elu_fwd(F8E4M3_TO_FLOAT(x), F8E4M3_TO_FLOAT(param)))) +UNARY_OP(__nv_fp8_e4m3, usilu_fp8_e4m3, __nv_fp8_e4m3(silu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP1(__nv_fp8_e4m3, upowf_fp8_e4m3, powg(x, param)) +UNARY_OP(__nv_fp8_e4m3, usign_fp8_e4m3, __nv_fp8_e4m3(sign_(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, usigmoid_fp8_e4m3, __nv_fp8_e4m3(sigmoid_fwd(F8E4M3_TO_FLOAT(x)))) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 8800133429..bfed9eb48b 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -19,6 +19,7 @@ candle = { workspace = true } candle-nn = { workspace = true } candle-onnx = { workspace = true, optional = true } half = { workspace = true } +float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] } diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index d2179d577f..ab7f07d985 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,4 +1,5 @@ #![allow(clippy::redundant_closure_call)] +use float8::F8E4M3; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::pyclass::CompareOp; @@ -160,6 +161,7 @@ pydtype!(f16, f32::from); pydtype!(bf16, f32::from); pydtype!(f32, |v| v); pydtype!(f64, |v| v); +pydtype!(F8E4M3, f32::from); fn actual_index(t: &Tensor, dim: usize, index: i64) -> ::candle::Result { let dim = t.dim(dim)?; @@ -209,6 +211,7 @@ trait MapDType { DType::F16 => 
self.f::<f16>(t), DType::F32 => self.f::<f32>(t), DType::F64 => self.f::<f64>(t), + DType::F8E4M3 => self.f::<F8E4M3>(t), } } }
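Note on the FP8 kernels added above (illustrative, not part of the diff): every new *_f8_e4m3 entry point follows the same decode/compute/encode idiom. The F8E4M3_TO_FLOAT macro expands the stored byte to f32 through a half intermediate, the arithmetic runs in f32, and the result is re-encoded with the __nv_fp8_e4m3 constructor (round-to-nearest, saturating). A minimal sketch of that pattern, assuming the CUDA 11.8+ fp8 headers; the kernel name scale_f8_e4m3_example is made up for illustration and does not exist in the patch:

// Illustrative sketch of the decode/compute/encode idiom used by the *_f8_e4m3 kernels.
#include "cuda_fp16.h"
#include "cuda_fp8.h"

#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))

extern "C" __global__ void scale_f8_e4m3_example(const __nv_fp8_e4m3 *inp,
                                                 __nv_fp8_e4m3 *out,
                                                 const float mul,
                                                 const size_t numel) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += (size_t)blockDim.x * gridDim.x) {
        // decode fp8 -> f32, compute in f32, re-encode f32 -> fp8
        out[i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(inp[i]) * mul);
    }
}

This is the same shape as, e.g., bmul_f8_e4m3 in binary.cu and affine_f8_e4m3 in affine.cu, where every operand is decoded to f32 before the multiply/add and only the final value is narrowed back to fp8.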
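Note on index_add_f8 / scatter_add_f8 (illustrative, not part of the diff): dedicated f8 variants are needed because __nv_fp8_e4m3 has no native addition, so each update decodes the destination element, accumulates in f32, and re-encodes, exactly as in the indexing.cu hunk above. A self-contained sketch of that read-modify-write step; the helper name accumulate_f8_e4m3_example is made up for illustration:

#include "cuda_fp16.h"
#include "cuda_fp8.h"

#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))

// Plain (non-atomic) read-modify-write, mirroring the sequential inner loops of
// index_add_f8 / scatter_add_f8, where a single thread owns each (pre, post) column.
__device__ __forceinline__ void accumulate_f8_e4m3_example(__nv_fp8_e4m3 *dst,
                                                           __nv_fp8_e4m3 src) {
    *dst = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(*dst) + F8E4M3_TO_FLOAT(src));
}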