Skip to content

Commit

Permalink
Fixing the tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Nov 9, 2023
1 parent e1bab1b commit 16fd771
Show file tree
Hide file tree
Showing 22 changed files with 316 additions and 88 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
# metal = { path = "../metal-rs", features = ["mps"] }

[profile.release-with-debug]
inherits = "release"
Expand Down
6 changes: 4 additions & 2 deletions candle-core/examples/tensor-tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
let device = Device::Cpu;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, &device)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() {
Expand Down Expand Up @@ -291,6 +292,7 @@ fn run_quantize(
q: Quantization,
qmode: QuantizationMode,
) -> Result<()> {
let device = Device::Cpu;
if in_files.is_empty() {
candle_core::bail!("no specified input files")
}
Expand Down Expand Up @@ -338,7 +340,7 @@ fn run_quantize(
.map(|(name, _)| {
println!(" quantizing {name}");
let mut in_file = std::fs::File::open(&in_files[0])?;
let tensor = content.tensor(&mut in_file, name)?;
let tensor = content.tensor(&mut in_file, name, &device)?;
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
Ok((name, tensor))
})
Expand Down
82 changes: 47 additions & 35 deletions candle-core/src/metal_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ impl From<String> for MetalError {
}
}

impl MetalError {
fn msg<S: AsRef<str>>(msg: S) -> Self {
MetalError::Message(msg.as_ref().to_string())
}
}

#[derive(Clone)]
pub struct MetalDevice {
device: metal::Device,
Expand Down Expand Up @@ -410,10 +404,42 @@ impl BackendStorage for MetalStorage {
})
}

fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
debug!("TODO where_cond");
Ok(rhs.clone())
// todo!()
fn where_cond(
&self,
layout: &Layout,
t: &Self,
t_l: &Layout,
f: &Self,
f_l: &Layout,
) -> Result<Self> {
let device = self.device.clone();
let shape = t_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
let dtype = t.dtype;
let mut buffer = self.device.new_buffer(el, dtype);
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
&device.kernels,
"where_u8_f32",
&dims,
&self.buffer,
(layout.stride(), layout.start_offset()),
&t.buffer,
(&t_l.stride(), t_l.start_offset()),
&f.buffer,
(&f_l.stride(), f_l.start_offset()),
&mut buffer,
)
.map_err(MetalError::from)?;
command_buffer.commit();
Ok(Self {
buffer,
device,
dtype,
})
}

fn conv1d(
Expand Down Expand Up @@ -528,7 +554,7 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
let transpose_left = false;
let transpose_right = false;
let transpose_right = !rhs_l.is_contiguous();
let alpha = 1.0;
let beta = 0.0;
self.matmul_generic(
Expand Down Expand Up @@ -588,27 +614,12 @@ impl BackendStorage for MetalStorage {
}

impl MetalStorage {
pub(crate) fn matmul_t(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let transpose_left = false;
let transpose_right = true;
let alpha = 1.0;
let beta = 0.0;
self.matmul_generic(
rhs,
(b, m, n, k),
lhs_l,
rhs_l,
transpose_left,
transpose_right,
alpha,
beta,
)
/// Builds a `MetalStorage` directly from its three parts: the buffer, the
/// device it is associated with, and the element dtype. The values are stored
/// as given; nothing is validated or copied here.
pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
    Self { buffer, device, dtype }
}
pub(crate) fn matmul_generic(
&self,
Expand Down Expand Up @@ -636,9 +647,10 @@ impl MetalStorage {
}
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
debug!(
"TODO non contiguous matmul yet {:?} {:?}",
"TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}",
lhs_l.is_contiguous(),
rhs_l.is_contiguous()
rhs_l.is_contiguous(),
rhs_l
);
return Ok(Self {
buffer: out_buffer,
Expand All @@ -647,7 +659,7 @@ impl MetalStorage {
});
}

debug!("GEMM");
debug!("TODO GEMM");
let command_buffer = self.device.command_queue.new_command_buffer();
encode_gemm::<Float32, Float32, Float32>(
&self.device,
Expand Down
29 changes: 10 additions & 19 deletions candle-core/src/quantized/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::{Device, Result, Shape, Tensor};
use tracing::debug;
use crate::{backend::BackendStorage, Device, Result, Shape, Tensor};

Check warning on line 1 in candle-core/src/quantized/mod.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

unused import: `backend::BackendStorage`

Check warning on line 1 in candle-core/src/quantized/mod.rs

View workflow job for this annotation

GitHub Actions / Check (ubuntu-latest, stable)

unused import: `backend::BackendStorage`

Check failure on line 1 in candle-core/src/quantized/mod.rs

View workflow job for this annotation

GitHub Actions / Clippy

unused import: `backend::BackendStorage`

#[cfg(target_feature = "avx")]
pub mod avx;
Expand Down Expand Up @@ -317,12 +316,14 @@ impl crate::CustomOp1 for QTensor {
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
}

#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> {
debug!("TODO qmatmul");
use tracing::debug;
debug!("TODO qmatmul {self:?} - {layout:?}");
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
Expand All @@ -339,22 +340,12 @@ impl crate::CustomOp1 for QTensor {
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
// let storage = storage.as_slice::<f32>()?;
// let storage =
// &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let dst_storage = vec![0f32; dst_shape.elem_count()];
// self.matmul_t(
// (dst_shape.elem_count() / n, k, n),
// storage,
// &mut dst_storage,
// )?;
let cpu_storage = crate::CpuStorage::F32(dst_storage);
use crate::backend::{BackendDevice, BackendStorage};
if let Device::Metal(device) = &self.device {
Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
} else {
crate::bail!("qtensor not on metal device")
}
let dtype = storage.dtype();
let buffer = storage.device().new_buffer(dst_shape.elem_count(), dtype);

let device: crate::MetalDevice = storage.device().clone();
let dst_storage = crate::MetalStorage::new(buffer, device, dtype);
Ok((dst_storage, dst_shape))
}
}

Expand Down
4 changes: 2 additions & 2 deletions candle-core/tests/quantized_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ fn quantized_matmul() -> Result<()> {
]
);

let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let qtensor = quantized::QTensor::new(rhs_t, (4, 64), &cpu)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(
Expand Down Expand Up @@ -90,7 +90,7 @@ fn quantized_matmul_neg() -> Result<()> {
]
);

let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let qtensor = quantized::QTensor::new(rhs_t, (4, 64), cpu)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion candle-examples/examples/blip/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ pub fn main() -> anyhow::Result<()> {
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");

let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::Q(model))
Expand Down
2 changes: 1 addition & 1 deletion candle-examples/examples/llama2-c/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
.shape()
Expand Down
4 changes: 3 additions & 1 deletion candle-examples/examples/mistral/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,14 @@ fn main() -> Result<()> {
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = Device::Cpu;

let start = std::time::Instant::now();
let config = Config::config_7b_v0_1(args.use_flash_attn);
let (model, device) = if args.quantized {
let filename = &filenames[0];
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QMistral::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
Expand Down
4 changes: 3 additions & 1 deletion candle-examples/examples/phi/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,9 @@ fn main() -> Result<()> {
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
let device = Device::Cpu;
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
let model = QMixFormer::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
Expand Down
2 changes: 1 addition & 1 deletion candle-examples/examples/quantized-t5/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl T5ModelBuilder {
}

pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &self.device)?;
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}

Expand Down
6 changes: 4 additions & 2 deletions candle-examples/examples/replit-code/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,11 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config = Config::replit_code_v1_5_3b();
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
let device = Device::Cpu;
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
(model, Device::Cpu)
(model, device)
} else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
Expand Down
6 changes: 4 additions & 2 deletions candle-examples/examples/stable-lm/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,12 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
let (model, device) = if args.quantized {
let device = Device::Cpu;
let filename = &filenames[0];
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QStableLM::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
(Model::Quantized(model), device)
} else {
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
Expand Down
6 changes: 4 additions & 2 deletions candle-examples/examples/whisper/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,10 @@ fn main() -> Result<()> {

let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = if args.quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&weights_filename,
&device,
)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else {
let vb =
Expand Down
Loading

0 comments on commit 16fd771

Please sign in to comment.