From 16fd77112a15f157caac784f73600593b3972194 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 9 Nov 2023 12:27:30 +0100
Subject: [PATCH] Fixing the tests.

---
 Cargo.toml                                    |   1 +
 candle-core/examples/tensor-tools.rs          |   6 +-
 candle-core/src/metal_backend.rs              |  82 +++++----
 candle-core/src/quantized/mod.rs              |  29 ++--
 candle-core/tests/quantized_tests.rs          |   4 +-
 candle-examples/examples/blip/main.rs         |   2 +-
 candle-examples/examples/llama2-c/main.rs     |   2 +-
 candle-examples/examples/mistral/main.rs      |   4 +-
 candle-examples/examples/phi/main.rs          |   4 +-
 candle-examples/examples/quantized-t5/main.rs |   2 +-
 candle-examples/examples/replit-code/main.rs  |   6 +-
 candle-examples/examples/stable-lm/main.rs    |   6 +-
 candle-examples/examples/whisper/main.rs      |   6 +-
 candle-metal-kernels/src/lib.rs               | 163 +++++++++++++++++-
 candle-metal-kernels/src/ternary.metal        |  57 ++++++
 candle-pyo3/src/lib.rs                        |   7 +-
 .../src/models/quantized_llama.rs             |   6 -
 candle-wasm-examples/blip/src/bin/m.rs        |   2 +-
 candle-wasm-examples/phi/src/bin/m.rs         |   6 +-
 .../t5/src/bin/m-quantized.rs                 |   6 +-
 candle-wasm-examples/whisper/src/worker.rs    |   1 +
 candle-wasm-tests/tests/quantized_tests.rs    |   2 +-
 22 files changed, 316 insertions(+), 88 deletions(-)
 create mode 100644 candle-metal-kernels/src/ternary.metal

diff --git a/Cargo.toml b/Cargo.toml
index d313010570..27875de501 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -56,6 +56,7 @@ wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
 metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+# metal = { path = "../metal-rs", features = ["mps"] }
 
 [profile.release-with-debug]
 inherits = "release"
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index d06b30d106..b3ba231f6e 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -191,7 +191,8 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
         }
         Format::Ggml => {
             let mut file = std::fs::File::open(file)?;
-            let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
+            let device = Device::Cpu;
+            let content = candle_core::quantized::ggml_file::Content::read(&mut file, &device)?;
             let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, qtensor) in tensors.iter() {
@@ -291,6 +292,7 @@ fn run_quantize(
     q: Quantization,
     qmode: QuantizationMode,
 ) -> Result<()> {
+    let device = Device::Cpu;
     if in_files.is_empty() {
         candle_core::bail!("no specified input files")
     }
@@ -338,7 +340,7 @@ fn run_quantize(
         .map(|(name, _)| {
             println!(" quantizing {name}");
             let mut in_file = std::fs::File::open(&in_files[0])?;
-            let tensor = content.tensor(&mut in_file, name)?;
+            let tensor = content.tensor(&mut in_file, name, &device)?;
             let tensor = qmode.quantize(name, tensor, quantize_fn)?;
             Ok((name, tensor))
         })
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index a7e20e8f7e..04a2c3dd9c 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -28,12 +28,6 @@ impl From<String> for MetalError {
     }
 }
 
-impl MetalError {
-    fn msg<S: AsRef<str>>(msg: S) -> Self {
-        MetalError::Message(msg.as_ref().to_string())
-    }
-}
-
 #[derive(Clone)]
 pub struct MetalDevice {
     device: metal::Device,
@@ -410,10 +404,42 @@ impl BackendStorage for MetalStorage {
         })
     }
 
-    fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
-        debug!("TODO where_cond");
-        Ok(rhs.clone())
-        // todo!()
+    fn where_cond(
+        &self,
+        layout: &Layout,
+        t: &Self,
+        t_l: &Layout,
+        f: &Self,
+        f_l: &Layout,
+    ) -> Result<Self> {
+        let device = self.device.clone();
+        let shape = t_l.shape();
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let dtype = t.dtype;
+        let mut buffer = self.device.new_buffer(el, dtype);
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        candle_metal_kernels::call_where_cond_strided(
+            &device.device,
+            &command_buffer,
+            &device.kernels,
+            "where_u8_f32",
+            &dims,
+            &self.buffer,
+            (layout.stride(), layout.start_offset()),
+            &t.buffer,
+            (&t_l.stride(), t_l.start_offset()),
+            &f.buffer,
+            (&f_l.stride(), f_l.start_offset()),
+            &mut buffer,
+        )
+        .map_err(MetalError::from)?;
+        command_buffer.commit();
+        Ok(Self {
+            buffer,
+            device,
+            dtype,
+        })
     }
 
     fn conv1d(
@@ -528,7 +554,7 @@ impl BackendStorage for MetalStorage {
         rhs_l: &Layout,
     ) -> Result<Self> {
         let transpose_left = false;
-        let transpose_right = false;
+        let transpose_right = !rhs_l.is_contiguous();
         let alpha = 1.0;
         let beta = 0.0;
         self.matmul_generic(
@@ -588,27 +614,12 @@ impl BackendStorage for MetalStorage {
 }
 
 impl MetalStorage {
-    pub(crate) fn matmul_t(
-        &self,
-        rhs: &Self,
-        (b, m, n, k): (usize, usize, usize, usize),
-        lhs_l: &Layout,
-        rhs_l: &Layout,
-    ) -> Result<Self> {
-        let transpose_left = false;
-        let transpose_right = true;
-        let alpha = 1.0;
-        let beta = 0.0;
-        self.matmul_generic(
-            rhs,
-            (b, m, n, k),
-            lhs_l,
-            rhs_l,
-            transpose_left,
-            transpose_right,
-            alpha,
-            beta,
-        )
+    pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
+        Self {
+            buffer,
+            device,
+            dtype,
+        }
     }
     pub(crate) fn matmul_generic(
         &self,
@@ -636,9 +647,10 @@ impl MetalStorage {
         }
         if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
             debug!(
-                "TODO non contiguous matmul yet {:?} {:?}",
+                "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}",
                 lhs_l.is_contiguous(),
-                rhs_l.is_contiguous()
+                rhs_l.is_contiguous(),
+                rhs_l
             );
             return Ok(Self {
                 buffer: out_buffer,
@@ -647,7 +659,7 @@ impl MetalStorage {
             });
         }
 
-        debug!("GEMM");
+        debug!("TODO GEMM");
         let command_buffer = self.device.command_queue.new_command_buffer();
         encode_gemm::<Float32, Float32, Float32>(
             &self.device,
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 9e3306ed03..96d2eaed1c 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -1,5 +1,4 @@
-use crate::{Device, Result, Shape, Tensor};
-use tracing::debug;
+use crate::{backend::BackendStorage, Device, Result, Shape, Tensor};
 
 #[cfg(target_feature = "avx")]
 pub mod avx;
@@ -317,12 +316,14 @@ impl crate::CustomOp1 for QTensor {
         Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
     }
 
+    #[cfg(feature = "metal")]
     fn metal_fwd(
         &self,
         storage: &crate::MetalStorage,
         layout: &crate::Layout,
     ) -> Result<(crate::MetalStorage, Shape)> {
-        debug!("TODO qmatmul");
+        use tracing::debug;
+        debug!("TODO qmatmul {self:?} - {layout:?}");
         if !layout.is_contiguous() {
             crate::bail!("input tensor is not contiguous {layout:?}")
         }
@@ -339,22 +340,12 @@ impl crate::CustomOp1 for QTensor {
         }
         dst_shape.push(n);
         let dst_shape = Shape::from(dst_shape);
-        // let storage = storage.as_slice::<f32>()?;
-        // let storage =
-        //     &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
-        let dst_storage = vec![0f32; dst_shape.elem_count()];
-        // self.matmul_t(
-        //     (dst_shape.elem_count() / n, k, n),
-        //     storage,
-        //     &mut dst_storage,
-        // )?;
-        let cpu_storage = crate::CpuStorage::F32(dst_storage);
-        use crate::backend::{BackendDevice, BackendStorage};
-        if let Device::Metal(device) = &self.device {
-            Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
-        } else {
-            crate::bail!("qtensor not on metal device")
-        }
+        let dtype = storage.dtype();
+        let buffer = storage.device().new_buffer(dst_shape.elem_count(), dtype);
+
+        let device: crate::MetalDevice = storage.device().clone();
+        let dst_storage = crate::MetalStorage::new(buffer, device, dtype);
+        Ok((dst_storage, dst_shape))
     }
 }
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index a2cecbc3b0..37b82138ed 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -42,7 +42,7 @@ fn quantized_matmul() -> Result<()> {
         ]
     );
 
-    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
+    let qtensor = quantized::QTensor::new(rhs_t, (4, 64), &cpu)?;
     let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
@@ -90,7 +90,7 @@ fn quantized_matmul_neg() -> Result<()> {
         ]
     );
 
-    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
+    let qtensor = quantized::QTensor::new(rhs_t, (4, 64), cpu)?;
     let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs
index a1051a8eaa..8b3083dedf 100644
--- a/candle-examples/examples/blip/main.rs
+++ b/candle-examples/examples/blip/main.rs
@@ -111,7 +111,7 @@ pub fn main() -> anyhow::Result<()> {
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");
 
-        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
+        let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
         let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
         let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
         (image_embeds, device, Model::Q(model))
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 0ceb27af7e..7b93002f84 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .extension()
         .map_or(false, |v| v == "safetensors");
     let (model, config) = if is_gguf {
-        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+        let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
         let (_vocab_size, dim) = vb
             .get_no_shape("model.embed_tokens.weight")?
            .shape()
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 18f18e5d89..00692d4ac2 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -234,12 +234,14 @@ fn main() -> Result<()> {
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+    let device = Device::Cpu;
 
     let start = std::time::Instant::now();
     let config = Config::config_7b_v0_1(args.use_flash_attn);
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
         let model = QMistral::new(&config, vb)?;
         (Model::Quantized(model), Device::Cpu)
     } else {
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 720a4441eb..a6ee6938fe 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -286,7 +286,9 @@ fn main() -> Result<()> {
         WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
     };
     let (model, device) = if args.quantized {
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
+        let device = Device::Cpu;
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
         let model = QMixFormer::new(&config, vb)?;
         (Model::Quantized(model), Device::Cpu)
     } else {
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
index 5a1cdf0c40..b25c44f4d4 100644
--- a/candle-examples/examples/quantized-t5/main.rs
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -132,7 +132,7 @@ impl T5ModelBuilder {
     }
 
     pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
-        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
+        let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &self.device)?;
         Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
     }
diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
index 0f72b86251..715920ec1a 100644
--- a/candle-examples/examples/replit-code/main.rs
+++ b/candle-examples/examples/replit-code/main.rs
@@ -238,9 +238,11 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config = Config::replit_code_v1_5_3b();
     let (model, device) = if args.quantized {
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
+        let device = Device::Cpu;
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
         let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
-        (model, Device::Cpu)
+        (model, device)
     } else {
         let device = candle_examples::device(args.cpu)?;
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)?
        };
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index 0535aa702b..ec1180216f 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -235,10 +235,12 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
     let (model, device) = if args.quantized {
+        let device = Device::Cpu;
         let filename = &filenames[0];
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
         let model = QStableLM::new(&config, vb)?;
-        (Model::Quantized(model), Device::Cpu)
+        (Model::Quantized(model), device)
     } else {
         let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 0724745111..01df46b893 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -529,8 +529,10 @@ fn main() -> Result<()> {
     let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
 
     let mut model = if args.quantized {
-        let vb =
-            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+            &weights_filename,
+            &device,
+        )?;
         Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
     } else {
         let vb =
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 7a74de9e40..b8803d71d9 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -10,6 +10,7 @@ const AFFINE: &str = include_str!("affine.metal");
 const INDEXING: &str = include_str!("indexing.metal");
 const UNARY: &str = include_str!("unary.metal");
 const BINARY: &str = include_str!("binary.metal");
+const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
 const REDUCE: &str = include_str!("reduce.metal");
 
@@ -19,6 +20,7 @@ pub enum Source {
     Indexing,
     Unary,
     Binary,
+    Ternary,
     Cast,
     Reduce,
 }
@@ -119,6 +121,7 @@ impl Kernels {
             Source::Affine => AFFINE,
             Source::Unary => UNARY,
             Source::Binary => BINARY,
+            Source::Ternary => TERNARY,
             Source::Indexing => INDEXING,
             Source::Cast => CAST,
             Source::Reduce => REDUCE,
@@ -600,6 +603,83 @@ pub fn call_affine(
     Ok(())
 }
 
+pub fn call_where_cond_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    cond: &Buffer,
+    (cond_stride, cond_offset): (&[usize], usize),
+    left: &Buffer,
+    (left_stride, left_offset): (&[usize], usize),
+    right: &Buffer,
+    (right_stride, right_offset): (&[usize], usize),
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    let func = kernels.load_function(device, Source::Ternary, name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    let size: usize = shape.iter().product();
+    encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&size));
+    encoder.set_bytes(
+        1,
+        core::mem::size_of::<usize>() as u64,
+        void_ptr(&shape.len()),
+    );
+    encoder.set_bytes(
+        2,
+        (shape.len() * core::mem::size_of::<usize>()) as u64,
+        shape.as_ptr() as *const c_void,
+    );
+    encoder.set_bytes(
+        3,
+        (cond_stride.len() * core::mem::size_of::<usize>()) as u64,
+        cond_stride.as_ptr() as *const c_void,
+    );
+    encoder.set_bytes(
+        4,
+        (left_stride.len() * core::mem::size_of::<usize>()) as u64,
+        left_stride.as_ptr() as *const c_void,
+    );
+    encoder.set_bytes(
+        5,
+        (right_stride.len() * core::mem::size_of::<usize>()) as u64,
+        right_stride.as_ptr() as *const c_void,
+    );
+    encoder.set_buffer(6, Some(&cond), cond_offset as u64);
+    encoder.set_buffer(7, Some(&left), left_offset as u64);
+    encoder.set_buffer(8, Some(&right), right_offset as u64);
+    encoder.set_buffer(9, Some(&output), 0);
+
+    let thread_group_count = MTLSize {
+        width: 1,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -978,11 +1058,7 @@ mod tests {
         assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
     }
 
-    fn run_reduce<T: Clone>(
-        v: &[T],
-        out_length: usize,
-        name: &'static str,
-    ) -> Vec<T> {
+    fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
         let device = device();
         let kernels = Kernels::new();
         let command_queue = device.new_command_queue();
@@ -1089,4 +1165,81 @@ mod tests {
             vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
         );
     }
+
+    fn run_where_cond<I: Clone, T: Clone>(
+        shape: &[usize],
+        cond: &[I],
+        (cond_stride, cond_offset): (Vec<usize>, usize),
+        left_true: &[T],
+        (left_stride, left_offset): (Vec<usize>, usize),
+        right_false: &[T],
+        (right_stride, right_offset): (Vec<usize>, usize),
+        name: &'static str,
+    ) -> Vec<T> {
+        let device = device();
+        let kernels = Kernels::new();
+        let command_queue = device.new_command_queue();
+        let command_buffer = command_queue.new_command_buffer();
+        let options = MTLResourceOptions::StorageModeManaged;
+
+        let length = cond.len();
+        let cond = device.new_buffer_with_data(
+            cond.as_ptr() as *const core::ffi::c_void,
+            (cond.len() * core::mem::size_of::<I>()) as u64,
+            options,
+        );
+        let left = device.new_buffer_with_data(
+            left_true.as_ptr() as *const core::ffi::c_void,
+            (length * core::mem::size_of::<T>()) as u64,
+            options,
+        );
+        let right = device.new_buffer_with_data(
+            right_false.as_ptr() as *const core::ffi::c_void,
+            (length * core::mem::size_of::<T>()) as u64,
+            options,
+        );
+
+        let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+        call_where_cond_strided(
+            &device,
+            &command_buffer,
+            &kernels,
+            name,
+            &shape,
+            &cond,
+            (&cond_stride, cond_offset),
+            &left,
+            (&left_stride, left_offset),
+            &right,
+            (&right_stride, right_offset),
+            &mut output,
+        )
+        .unwrap();
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        output.read_to_vec::<T>(length)
+    }
+
+    #[test]
+    fn where_cond() {
+        let shape = vec![6];
+        let cond = vec![0u8, 1, 0, 0, 1, 1];
+        let cond_l = (vec![1], 0);
+        let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+        let left_l = (vec![1], 0);
+        let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
+        let right_l = (vec![1], 0);
+        let results = run_where_cond(
+            &shape,
+            &cond,
+            cond_l,
+            &left_true,
+            left_l,
+            &right_false,
+            right_l,
+            "where_u8_f32",
+        );
+        assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
+    }
 }
diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal
new file mode 100644
index 0000000000..0945b355cf
--- /dev/null
+++ b/candle-metal-kernels/src/ternary.metal
@@ -0,0 +1,57 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+METAL_FUNC uint get_strided_index(
+    uint idx,
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides
+) {
+    uint strided_i = 0;
+    for (uint d = 0; d < num_dims; d++) {
+        uint dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return strided_i;
+}
+
+
+#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
+kernel void FN_NAME( \
+    constant size_t &numel, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    constant size_t *strides_t, \
+    constant size_t *strides_f, \
+    device const ID_TYPENAME *ids, \
+    device const TYPENAME *t, \
+    device const TYPENAME *f, \
+    device TYPENAME *out ,\
+    uint i [[ thread_position_in_grid ]] \
+) { \
+    uint strided_i = get_strided_index(i, num_dims, dims, strides); \
+    uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
+    uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
+    out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
+} \
+
+// WHERE_OP(float, int64_t, where_i64_f32)
+// WHERE_OP(double, int64_t, where_i64_f64)
+// WHERE_OP(uint8_t, int64_t, where_i64_u8)
+// WHERE_OP(uint32_t, int64_t, where_i64_u32)
+// WHERE_OP(int64_t, int64_t, where_i64_i64)
+//
+// WHERE_OP(float, uint32_t, where_u32_f32)
+// WHERE_OP(double, uint32_t, where_u32_f64)
+// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
+// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
+// WHERE_OP(int64_t, uint32_t, where_u32_i64)
+
+WHERE_OP(float, uint8_t, where_u8_f32)
+// WHERE_OP(double, uint8_t, where_u8_f64)
+// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
+// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
+// WHERE_OP(int64_t, uint8_t, where_u8_i64)
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index f83357940d..fff32d2a10 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1271,7 +1271,9 @@ fn save_safetensors(
 /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
 fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
     let mut file = std::fs::File::open(path)?;
-    let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
+    let device = Device::Cpu;
+    let ggml =
+        ::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?;
     let tensors = ggml
         .tensors
         .into_iter()
@@ -1306,6 +1308,7 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
 /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
 fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
     use ::candle::quantized::gguf_file;
+    let device = Device::Cpu;
     fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
         let v: PyObject = match v {
             gguf_file::Value::U8(x) => x.into_py(py),
@@ -1336,7 +1339,7 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
         .tensor_infos
         .keys()
         .map(|key| {
-            let qtensor = gguf.tensor(&mut file, key)?;
+            let qtensor = gguf.tensor(&mut file, key, &device)?;
             Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
         })
         .collect::<::candle::Result<Vec<_>>>()
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 8c04093a24..8025693d7c 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -195,12 +195,6 @@ fn precomput_freqs_cis(
         .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
         .collect();
     let theta = Tensor::new(theta.as_slice(), device)?;
-    let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
-    // let idx_theta = Tensor::new(range.as_slice(), device)?
-    //     .reshape((MAX_SEQ_LEN, 1))?
-    //     .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-    // TODO This change avoids allocating on Metal and then casting since allocating directly on
-    // CPU as f32 seems just as fast
     let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
         .to_dtype(DType::F32)?
         .reshape((MAX_SEQ_LEN, 1))?
diff --git a/candle-wasm-examples/blip/src/bin/m.rs b/candle-wasm-examples/blip/src/bin/m.rs
index 660bb71743..e2ba4fed48 100644
--- a/candle-wasm-examples/blip/src/bin/m.rs
+++ b/candle-wasm-examples/blip/src/bin/m.rs
@@ -61,7 +61,7 @@ impl Model {
 
         let start = Date::now();
         let model: SelectedModel = if quantized {
-            let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights)?;
+            let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights, &device)?;
             let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
             SelectedModel::Q(model)
         } else {
diff --git a/candle-wasm-examples/phi/src/bin/m.rs b/candle-wasm-examples/phi/src/bin/m.rs
index c18e6c3814..03a1b67319 100644
--- a/candle-wasm-examples/phi/src/bin/m.rs
+++ b/candle-wasm-examples/phi/src/bin/m.rs
@@ -38,9 +38,11 @@ impl Model {
         let tokenizer =
             Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
         let start = Date::now();
+        let device = Device::Cpu;
         let model = if quantized {
-            let vb =
-                candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
+            let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
+                &weights, &device,
+            )?;
             let model = QMixFormer::new(&config, vb)?;
             SelectedModel::Quantized(model)
         } else {
diff --git a/candle-wasm-examples/t5/src/bin/m-quantized.rs b/candle-wasm-examples/t5/src/bin/m-quantized.rs
index 2f490b84d2..002a7830dd 100644
--- a/candle-wasm-examples/t5/src/bin/m-quantized.rs
+++ b/candle-wasm-examples/t5/src/bin/m-quantized.rs
@@ -31,7 +31,8 @@ impl ModelConditionalGeneration {
     ) -> Result<ModelConditionalGeneration, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
-        let vb = VarBuilder::from_gguf_buffer(&weights)?;
+        let device = Device::Cpu;
+        let vb = VarBuilder::from_gguf_buffer(&weights, &device)?;
         let mut config: Config = serde_json::from_slice(&config)?;
         let tokenizer =
             Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
@@ -128,7 +129,8 @@ impl ModelEncoder {
     ) -> Result<ModelEncoder, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
-        let vb = VarBuilder::from_gguf_buffer(&weights)?;
+        let device = Device::Cpu;
+        let vb = VarBuilder::from_gguf_buffer(&weights, &device)?;
         let mut config: Config = serde_json::from_slice(&config)?;
         config.use_cache = false;
         let tokenizer =
diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs
index a8646e3d01..412772ce36 100644
--- a/candle-wasm-examples/whisper/src/worker.rs
+++ b/candle-wasm-examples/whisper/src/worker.rs
@@ -309,6 +309,7 @@ impl Decoder {
         let model = if md.quantized {
             let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
                 &md.weights,
+                &device,
             )?;
             Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
        } else {
diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs
index 5d53728c75..df3b6cf6f6 100644
--- a/candle-wasm-tests/tests/quantized_tests.rs
+++ b/candle-wasm-tests/tests/quantized_tests.rs
@@ -40,7 +40,7 @@ fn quantized_matmul_neg() -> Result<()> {
         ]
     );
 
-    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
+    let qtensor = quantized::QTensor::new(rhs_t, (4, 64), cpu)?;
     let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(