From 4f179308a6adc5a13896d2244ed431b24a01a4c6 Mon Sep 17 00:00:00 2001 From: Martin Stefcek Date: Tue, 2 Apr 2024 21:20:49 +0400 Subject: [PATCH 1/6] add stable diffusion --- .gitignore | 1 + Cargo.toml | 15 +- atoma-inference/Cargo.toml | 15 +- atoma-inference/src/candle/mod.rs | 89 +++ .../src/candle/stable_diffusion.rs | 541 ++++++++++++++++++ .../src/candle/token_output_stream.rs | 86 +++ atoma-inference/src/lib.rs | 4 + 7 files changed, 745 insertions(+), 6 deletions(-) create mode 100644 atoma-inference/src/candle/mod.rs create mode 100644 atoma-inference/src/candle/stable_diffusion.rs create mode 100644 atoma-inference/src/candle/token_output_stream.rs diff --git a/.gitignore b/.gitignore index 1e7caa9e..9182e2f4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ Cargo.lock target/ +.vscode/ diff --git a/Cargo.toml b/Cargo.toml index c2ccab6a..fc2b0a3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,20 @@ [workspace] resolver = "2" +edition = "2021" -members = ["atoma-event-subscribe", "atoma-inference", "atoma-networking", "atoma-json-rpc", "atoma-storage"] +members = [ + "atoma-event-subscribe", + "atoma-inference", + "atoma-networking", + "atoma-json-rpc", + "atoma-storage", +] [workspace.package] version = "0.1.0" [workspace.dependencies] +anyhow = "1.0.81" async-trait = "0.1.78" candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" } candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", version = "0.4.2" } @@ -17,6 +25,11 @@ dotenv = "0.15.0" ed25519-consensus = "2.1.0" futures = "0.3.30" hf-hub = "0.3.2" +clap = "4.5.3" +image = { version = "0.25.0", default-features = false, features = [ + "jpeg", + "png", +] } serde = "1.0.197" serde_json = "1.0.114" rand = "0.8.5" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index a9cbffc5..c6037821 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -1,11 +1,10 @@ [package] name = "inference" -version = "0.1.0" +version.workspace = true edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] +anyhow.workspace = true async-trait.workspace = true candle.workspace = true candle-flash-attn = { workspace = true, optional = true } @@ -19,8 +18,10 @@ hf-hub.workspace = true reqwest = { workspace = true, features = ["json"] } serde = { workspace = true, features = ["derive"] } serde_json.workspace = true +clap.workspace = true +image = { workspace = true } thiserror.workspace = true -tokenizers.workspace = true +tokenizers = { workspace = true, features = ["onig"] } tokio = { workspace = true, features = ["full", "tracing"] } tracing.workspace = true tracing-subscriber.workspace = true @@ -31,7 +32,11 @@ toml.workspace = true [features] -accelerate = ["candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] +accelerate = [ + "candle/accelerate", + "candle-nn/accelerate", + "candle-transformers/accelerate", +] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cudnn = ["candle/cudnn"] flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] diff --git a/atoma-inference/src/candle/mod.rs b/atoma-inference/src/candle/mod.rs new file mode 100644 index 00000000..8fe9d9b6 --- /dev/null +++ b/atoma-inference/src/candle/mod.rs @@ -0,0 +1,89 @@ +pub mod stable_diffusion; +pub mod token_output_stream; + +use std::{fs::File, io::Write, path::PathBuf}; + +use candle::{ + 
utils::{cuda_is_available, metal_is_available}, + DType, Device, Tensor, +}; +use tracing::info; + +use crate::models::ModelError; + +pub trait CandleModel { + type Fetch; + type Input; + fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError>; + fn inference(input: Self::Input) -> Result, ModelError>; +} + +pub fn device() -> Result { + if cuda_is_available() { + info!("Using CUDA"); + Device::new_cuda(0) + } else if metal_is_available() { + info!("Using Metal"); + Device::new_metal(0) + } else { + info!("Using Cpu"); + Ok(Device::Cpu) + } +} + +pub fn hub_load_safetensors( + repo: &hf_hub::api::sync::ApiRepo, + json_file: &str, +) -> candle::Result> { + let json_file = repo.get(json_file).map_err(candle::Error::wrap)?; + let json_file = std::fs::File::open(json_file)?; + let json: serde_json::Value = + serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; + let weight_map = match json.get("weight_map") { + None => candle::bail!("no weight map in {json_file:?}"), + Some(serde_json::Value::Object(map)) => map, + Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), + }; + let mut safetensors_files = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_files.insert(file.to_string()); + } + } + let safetensors_files = safetensors_files + .iter() + .map(|v| repo.get(v).map_err(candle::Error::wrap)) + .collect::>>()?; + Ok(safetensors_files) +} + +pub fn save_image>(img: &Tensor, p: P) -> candle::Result<()> { + let p = p.as_ref(); + let (channel, height, width) = img.dims3()?; + if channel != 3 { + candle::bail!("save_image expects an input of shape (3, height, width)") + } + let img = img.permute((1, 2, 0))?.flatten_all()?; + let pixels = img.to_vec1::()?; + let image: image::ImageBuffer, Vec> = + match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { + Some(image) => image, + None => candle::bail!("error saving image {p:?}"), + }; + image.save(p).map_err(candle::Error::wrap)?; + Ok(()) +} + +pub fn save_tensor_to_file(tensor: &Tensor, filename: &str) -> Result<(), candle::Error> { + let json_output = serde_json::to_string( + &tensor + .to_device(&Device::Cpu)? + .flatten_all()? + .to_dtype(DType::F64)? + .to_vec1::()?, + ) + .unwrap(); + let mut file = File::create(PathBuf::from(filename))?; + file.write_all(json_output.as_bytes())?; + Ok(()) +} diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/candle/stable_diffusion.rs new file mode 100644 index 00000000..c8718e4c --- /dev/null +++ b/atoma-inference/src/candle/stable_diffusion.rs @@ -0,0 +1,541 @@ +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use candle_transformers::models::stable_diffusion::{self}; + +use candle::{DType, Device, IndexOp, Module, Tensor, D}; +use tokenizers::Tokenizer; + +use crate::{candle::device, models::ModelError}; + +use super::{save_tensor_to_file, CandleModel}; + +pub struct Input { + prompt: String, + uncond_prompt: String, + + height: Option, + width: Option, + + /// The UNet weight file, in .safetensors format. + unet_weights: Option, + + /// The CLIP weight file, in .safetensors format. + clip_weights: Option, + + /// The VAE weight file, in .safetensors format. + vae_weights: Option, + + /// The file specifying the tokenizer to used for tokenization. 
+ tokenizer: Option, + + /// The size of the sliced attention or 0 for automatic slicing (disabled by default) + sliced_attention_size: Option, + + /// The number of steps to run the diffusion for. + n_steps: Option, + + /// The number of samples to generate. + num_samples: i64, + + sd_version: StableDiffusionVersion, + + /// Generate intermediary images at each step. + intermediary_images: bool, + + use_flash_attn: bool, + + use_f16: bool, + + guidance_scale: Option, + + img2img: Option, + + /// The strength, indicates how much to transform the initial image. The + /// value must be between 0 and 1, a value of 1 discards the initial image + /// information. + img2img_strength: f64, + + /// The seed to use when generating random samples. + seed: Option, +} + +impl Input { + pub fn default_prompt(prompt: String) -> Self { + Self { + prompt, + uncond_prompt: "".to_string(), + height: Some(256), + width: Some(256), + unet_weights: None, + clip_weights: None, + vae_weights: None, + tokenizer: None, + sliced_attention_size: None, + n_steps: Some(20), + num_samples: 1, + sd_version: StableDiffusionVersion::V1_5, + intermediary_images: false, + use_flash_attn: false, + use_f16: true, + guidance_scale: None, + img2img: None, + img2img_strength: 0.8, + seed: Some(0), + } + } +} + +impl From<&Input> for Fetch { + fn from(input: &Input) -> Self { + Self { + tokenizer: input.tokenizer.clone(), + sd_version: input.sd_version, + use_f16: input.use_f16, + clip_weights: input.clip_weights.clone(), + vae_weights: input.vae_weights.clone(), + unet_weights: input.unet_weights.clone(), + } + } +} +pub struct StableDiffusion {} + +pub struct Fetch { + tokenizer: Option, + sd_version: StableDiffusionVersion, + use_f16: bool, + clip_weights: Option, + vae_weights: Option, + unet_weights: Option, +} + +impl CandleModel for StableDiffusion { + type Input = Input; + type Fetch = Fetch; + + fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError> { + let which = match fetch.sd_version { + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false], + _ => vec![true], + }; + for first in which { + let (clip_weights_file, tokenizer_file) = if first { + (ModelFile::Clip, ModelFile::Tokenizer) + } else { + (ModelFile::Clip2, ModelFile::Tokenizer2) + }; + + clip_weights_file.get(fetch.clip_weights.clone(), fetch.sd_version, false)?; + ModelFile::Vae.get(fetch.vae_weights.clone(), fetch.sd_version, fetch.use_f16)?; + tokenizer_file.get(fetch.tokenizer.clone(), fetch.sd_version, fetch.use_f16)?; + ModelFile::Unet.get(fetch.unet_weights.clone(), fetch.sd_version, fetch.use_f16)?; + } + Ok(()) + } + + fn inference(input: Self::Input) -> Result, ModelError> { + if !(0. ..=1.).contains(&input.img2img_strength) { + Err(ModelError::Config(format!( + "img2img_strength must be between 0 and 1, got {}", + input.img2img_strength, + )))? 
+ } + + let guidance_scale = match input.guidance_scale { + Some(guidance_scale) => guidance_scale, + None => match input.sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 7.5, + StableDiffusionVersion::Turbo => 0., + }, + }; + let n_steps = match input.n_steps { + Some(n_steps) => n_steps, + None => match input.sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 30, + StableDiffusionVersion::Turbo => 1, + }, + }; + let dtype = if input.use_f16 { + DType::F16 + } else { + DType::F32 + }; + let sd_config = match input.sd_version { + StableDiffusionVersion::V1_5 => stable_diffusion::StableDiffusionConfig::v1_5( + input.sliced_attention_size, + input.height, + input.width, + ), + StableDiffusionVersion::V2_1 => stable_diffusion::StableDiffusionConfig::v2_1( + input.sliced_attention_size, + input.height, + input.width, + ), + StableDiffusionVersion::Xl => stable_diffusion::StableDiffusionConfig::sdxl( + input.sliced_attention_size, + input.height, + input.width, + ), + StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo( + input.sliced_attention_size, + input.height, + input.width, + ), + }; + + let scheduler = sd_config.build_scheduler(n_steps)?; + let device = device()?; + if let Some(seed) = input.seed { + device.set_seed(seed)?; + } + let use_guide_scale = guidance_scale > 1.0; + + let which = match input.sd_version { + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false], + _ => vec![true], + }; + let text_embeddings = which + .iter() + .map(|first| { + Self::text_embeddings( + &input.prompt, + &input.uncond_prompt, + input.tokenizer.clone(), + input.clip_weights.clone(), + input.sd_version, + &sd_config, + input.use_f16, + &device, + dtype, + use_guide_scale, + *first, + ) + }) + .collect::, _>>()?; + + let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?; + + let vae_weights = ModelFile::Vae.get(input.vae_weights, input.sd_version, input.use_f16)?; + let vae = sd_config.build_vae(vae_weights, &device, dtype)?; + let init_latent_dist = match &input.img2img { + None => None, + Some(image) => { + let image = Self::image_preprocess(image)?.to_device(&device)?; + Some(vae.encode(&image)?) + } + }; + let unet_weights = + ModelFile::Unet.get(input.unet_weights, input.sd_version, input.use_f16)?; + let unet = sd_config.build_unet(unet_weights, &device, 4, input.use_flash_attn, dtype)?; + + let t_start = if input.img2img.is_some() { + n_steps - (n_steps as f64 * input.img2img_strength) as usize + } else { + 0 + }; + let bsize = 1; + + let vae_scale = match input.sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 0.18215, + StableDiffusionVersion::Turbo => 0.13025, + }; + let mut res = Vec::new(); + + for _ in 0..input.num_samples { + let timesteps = scheduler.timesteps(); + let latents = match &init_latent_dist { + Some(init_latent_dist) => { + let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?; + if t_start < timesteps.len() { + let noise = latents.randn_like(0f64, 1f64)?; + scheduler.add_noise(&latents, noise, timesteps[t_start])? + } else { + latents + } + } + None => { + let latents = Tensor::randn( + 0f32, + 1f32, + (bsize, 4, sd_config.height / 8, sd_config.width / 8), + &device, + )?; + // scale the initial noise by the standard deviation required by the scheduler + (latents * scheduler.init_noise_sigma())? 
+ } + }; + let mut latents = latents.to_dtype(dtype)?; + + for (timestep_index, ×tep) in timesteps.iter().enumerate() { + if timestep_index < t_start { + continue; + } + let latent_model_input = if use_guide_scale { + Tensor::cat(&[&latents, &latents], 0)? + } else { + latents.clone() + }; + + let latent_model_input = + scheduler.scale_model_input(latent_model_input, timestep)?; + let noise_pred = + unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?; + + let noise_pred = if use_guide_scale { + let noise_pred = noise_pred.chunk(2, 0)?; + let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); + + (noise_pred_uncond + + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)? + } else { + noise_pred + }; + + latents = scheduler.step(&noise_pred, timestep, &latents)?; + } + save_tensor_to_file(&latents, "tensor1")?; + let image = vae.decode(&(&latents / vae_scale)?)?; + save_tensor_to_file(&image, "tensor2")?; + let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; + save_tensor_to_file(&image, "tensor3")?; + let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?; + save_tensor_to_file(&image, "tensor4")?; + res.push(image); + } + Ok(res) + } +} + +#[derive(Clone, Copy)] +enum StableDiffusionVersion { + V1_5, + V2_1, + Xl, + Turbo, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ModelFile { + Tokenizer, + Tokenizer2, + Clip, + Clip2, + Unet, + Vae, +} + +impl StableDiffusionVersion { + fn repo(&self) -> &'static str { + match self { + Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0", + Self::V2_1 => "stabilityai/stable-diffusion-2-1", + Self::V1_5 => "runwayml/stable-diffusion-v1-5", + Self::Turbo => "stabilityai/sdxl-turbo", + } + } + + fn unet_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "unet/diffusion_pytorch_model.fp16.safetensors" + } else { + "unet/diffusion_pytorch_model.safetensors" + } + } + } + } + + fn vae_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "vae/diffusion_pytorch_model.fp16.safetensors" + } else { + "vae/diffusion_pytorch_model.safetensors" + } + } + } + } + + fn clip_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "text_encoder/model.fp16.safetensors" + } else { + "text_encoder/model.safetensors" + } + } + } + } + + fn clip2_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "text_encoder_2/model.fp16.safetensors" + } else { + "text_encoder_2/model.safetensors" + } + } + } + } +} + +impl ModelFile { + fn get( + &self, + filename: Option, + version: StableDiffusionVersion, + use_f16: bool, + ) -> Result { + use hf_hub::api::sync::Api; + match filename { + Some(filename) => Ok(std::path::PathBuf::from(filename)), + None => { + let (repo, path) = match self { + Self::Tokenizer => { + let tokenizer_repo = match version { + StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => { + "openai/clip-vit-base-patch32" + } + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => { + // This seems similar to the patch32 version except some very small + // difference in the split regex. 
+ "openai/clip-vit-large-patch14" + } + }; + (tokenizer_repo, "tokenizer.json") + } + Self::Tokenizer2 => { + ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json") + } + Self::Clip => (version.repo(), version.clip_file(use_f16)), + Self::Clip2 => (version.repo(), version.clip2_file(use_f16)), + Self::Unet => (version.repo(), version.unet_file(use_f16)), + Self::Vae => { + // Override for SDXL when using f16 weights. + // See https://github.com/huggingface/candle/issues/1060 + if matches!( + version, + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo, + ) && use_f16 + { + ( + "madebyollin/sdxl-vae-fp16-fix", + "diffusion_pytorch_model.safetensors", + ) + } else { + (version.repo(), version.vae_file(use_f16)) + } + } + }; + let filename = Api::new()?.model(repo.to_string()).get(path)?; + Ok(filename) + } + } + } +} + +impl StableDiffusion { + #[allow(clippy::too_many_arguments)] + fn text_embeddings( + prompt: &str, + uncond_prompt: &str, + tokenizer: Option, + clip_weights: Option, + sd_version: StableDiffusionVersion, + sd_config: &stable_diffusion::StableDiffusionConfig, + use_f16: bool, + device: &Device, + dtype: DType, + use_guide_scale: bool, + first: bool, + ) -> Result { + let (clip_weights_file, tokenizer_file) = if first { + (ModelFile::Clip, ModelFile::Tokenizer) + } else { + (ModelFile::Clip2, ModelFile::Tokenizer2) + }; + let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?; + let tokenizer = Tokenizer::from_file(tokenizer)?; + let pad_id = match &sd_config.clip.pad_with { + Some(padding) => { + *tokenizer + .get_vocab(true) + .get(padding.as_str()) + .ok_or(ModelError::Msg(format!( + "Padding token {padding} not found in the tokenizer vocabulary" + )))? + } + None => *tokenizer + .get_vocab(true) + .get("<|endoftext|>") + .ok_or(ModelError::Msg("".to_string()))?, + }; + let mut tokens = tokenizer.encode(prompt, true)?.get_ids().to_vec(); + while tokens.len() < sd_config.clip.max_position_embeddings { + tokens.push(pad_id) + } + let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; + + let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?; + let clip_config = if first { + &sd_config.clip + } else { + sd_config.clip2.as_ref().unwrap() + }; + let text_model = stable_diffusion::build_clip_transformer( + clip_config, + clip_weights, + device, + DType::F64, + )?; + let text_embeddings = text_model.forward(&tokens)?; + + let text_embeddings = if use_guide_scale { + let mut uncond_tokens = tokenizer.encode(uncond_prompt, true)?.get_ids().to_vec(); + while uncond_tokens.len() < sd_config.clip.max_position_embeddings { + uncond_tokens.push(pad_id) + } + + let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; + let uncond_embeddings = text_model.forward(&uncond_tokens)?; + + Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)? + } else { + text_embeddings.to_dtype(dtype)? + }; + Ok(text_embeddings) + } + + fn image_preprocess>(path: T) -> Result { + let img = image::io::Reader::open(path)?.decode()?; + let (height, width) = (img.height() as usize, img.width() as usize); + let height = height - height % 32; + let width = width - width % 32; + let img = img.resize_to_fill( + width as u32, + height as u32, + image::imageops::FilterType::CatmullRom, + ); + let img = img.to_rgb8(); + let img = img.into_raw(); + let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)? + .permute((2, 0, 1))? + .to_dtype(DType::F32)? + .affine(2. / 255., -1.)? 
+ .unsqueeze(0)?; + Ok(img) + } +} diff --git a/atoma-inference/src/candle/token_output_stream.rs b/atoma-inference/src/candle/token_output_stream.rs new file mode 100644 index 00000000..1f507c5e --- /dev/null +++ b/atoma-inference/src/candle/token_output_stream.rs @@ -0,0 +1,86 @@ +use candle::Result; + +/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a +/// streaming way rather than having to wait for the full decoding. +pub struct TokenOutputStream { + tokenizer: tokenizers::Tokenizer, + tokens: Vec, + prev_index: usize, + current_index: usize, +} + +impl TokenOutputStream { + pub fn new(tokenizer: tokenizers::Tokenizer) -> Self { + Self { + tokenizer, + tokens: Vec::new(), + prev_index: 0, + current_index: 0, + } + } + + pub fn into_inner(self) -> tokenizers::Tokenizer { + self.tokenizer + } + + fn decode(&self, tokens: &[u32]) -> Result { + match self.tokenizer.decode(tokens, true) { + Ok(str) => Ok(str), + Err(err) => candle::bail!("cannot decode: {err}"), + } + } + + // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68 + pub fn next_token(&mut self, token: u32) -> Result> { + let prev_text = if self.tokens.is_empty() { + String::new() + } else { + let tokens = &self.tokens[self.prev_index..self.current_index]; + self.decode(tokens)? + }; + self.tokens.push(token); + let text = self.decode(&self.tokens[self.prev_index..])?; + if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() { + let text = text.split_at(prev_text.len()); + self.prev_index = self.current_index; + self.current_index = self.tokens.len(); + Ok(Some(text.1.to_string())) + } else { + Ok(None) + } + } + + pub fn decode_rest(&self) -> Result> { + let prev_text = if self.tokens.is_empty() { + String::new() + } else { + let tokens = &self.tokens[self.prev_index..self.current_index]; + self.decode(tokens)? 
+ }; + let text = self.decode(&self.tokens[self.prev_index..])?; + if text.len() > prev_text.len() { + let text = text.split_at(prev_text.len()); + Ok(Some(text.1.to_string())) + } else { + Ok(None) + } + } + + pub fn decode_all(&self) -> Result { + self.decode(&self.tokens) + } + + pub fn get_token(&self, token_s: &str) -> Option { + self.tokenizer.get_vocab(true).get(token_s).copied() + } + + pub fn tokenizer(&self) -> &tokenizers::Tokenizer { + &self.tokenizer + } + + pub fn clear(&mut self) { + self.tokens.clear(); + self.prev_index = 0; + self.current_index = 0; + } +} diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index 4ad5a4d4..c6f12e7a 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -1,4 +1,8 @@ pub mod model_thread; +pub mod candle; +pub mod config; +pub mod core_thread; +pub mod models; pub mod service; pub mod specs; From 2e748e8a19c1f3b1f37c65e3aa68ae27e9d68a79 Mon Sep 17 00:00:00 2001 From: Martin Stefcek Date: Tue, 2 Apr 2024 21:37:01 +0400 Subject: [PATCH 2/6] fix clippy --- Cargo.toml | 2 +- .../src/candle/stable_diffusion.rs | 5 +--- atoma-inference/src/lib.rs | 8 ++---- atoma-inference/src/models/candle/mamba.rs | 6 ++--- atoma-inference/src/models/mod.rs | 26 +++++++++---------- 5 files changed, 19 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fc2b0a3a..1758ec34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ members = [ version = "0.1.0" [workspace.dependencies] +reqwest = "0.12.1" anyhow = "1.0.81" async-trait = "0.1.78" candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" } @@ -33,7 +34,6 @@ image = { version = "0.25.0", default-features = false, features = [ serde = "1.0.197" serde_json = "1.0.114" rand = "0.8.5" -reqwest = "0.12.1" thiserror = "1.0.58" tokenizers = "0.15.2" tokio = "1.36.0" diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/candle/stable_diffusion.rs index c8718e4c..46dadbc1 100644 --- a/atoma-inference/src/candle/stable_diffusion.rs +++ b/atoma-inference/src/candle/stable_diffusion.rs @@ -43,9 +43,6 @@ pub struct Input { sd_version: StableDiffusionVersion, - /// Generate intermediary images at each step. 
- intermediary_images: bool, - use_flash_attn: bool, use_f16: bool, @@ -78,7 +75,6 @@ impl Input { n_steps: Some(20), num_samples: 1, sd_version: StableDiffusionVersion::V1_5, - intermediary_images: false, use_flash_attn: false, use_f16: true, guidance_scale: None, @@ -315,6 +311,7 @@ impl CandleModel for StableDiffusion { } } +#[allow(dead_code)] #[derive(Clone, Copy)] enum StableDiffusionVersion { V1_5, diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index c6f12e7a..80915305 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -1,10 +1,6 @@ -pub mod model_thread; +pub mod apis; pub mod candle; -pub mod config; -pub mod core_thread; +pub mod model_thread; pub mod models; pub mod service; pub mod specs; - -pub mod apis; -pub mod models; diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 284e25e4..65ed9835 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -64,8 +64,7 @@ impl ModelTrait for MambaModel { let tokenizer_filename = filenames[1].clone(); let weights_filenames = filenames[2..].to_vec(); - let tokenizer = - Tokenizer::from_file(tokenizer_filename).map_err(ModelError::TokenizerError)?; + let tokenizer = Tokenizer::from_file(tokenizer_filename)?; let config: Config = serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?) @@ -110,8 +109,7 @@ impl ModelTrait for MambaModel { let mut tokens = self .tokenizer .tokenizer() - .encode(prompt, true) - .map_err(ModelError::TokenizerError)? + .encode(prompt, true)? .get_ids() .to_vec(); let mut logits_processor = diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index dc82f4c1..724a5d42 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -41,22 +41,22 @@ pub trait Response: Send + 'static { #[derive(Debug, Error)] pub enum ModelError { - #[error("Tokenizer error: `{0}`")] - TokenizerError(Box), - #[error("IO error: `{0}`")] - IoError(std::io::Error), #[error("Deserialize error: `{0}`")] - DeserializeError(serde_json::Error), - #[error("Candle error: `{0}`")] - CandleError(CandleError), + DeserializeError(#[from] serde_json::Error), #[error("{0}")] Msg(String), -} - -impl From for ModelError { - fn from(error: CandleError) -> Self { - Self::CandleError(error) - } + #[error("Candle error: `{0}`")] + CandleError(#[from] CandleError), + #[error("Config error: `{0}`")] + Config(String), + #[error("Image error: `{0}`")] + ImageError(#[from] image::ImageError), + #[error("Io error: `{0}`")] + IoError(#[from] std::io::Error), + #[error("Error: `{0}`")] + BoxedError(#[from] Box), + #[error("ApiError error: `{0}`")] + ApiError(#[from] hf_hub::api::sync::ApiError), } #[macro_export] From bed664e80c807928f30ee838dbb88f1afcc41387 Mon Sep 17 00:00:00 2001 From: Martin Stefcek Date: Wed, 3 Apr 2024 17:40:52 +0400 Subject: [PATCH 3/6] change to modeltrait --- .../src/candle/stable_diffusion.rs | 26 ++++++++++++++++--- atoma-inference/src/models/candle/mamba.rs | 1 + atoma-inference/src/models/mod.rs | 4 +++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/candle/stable_diffusion.rs index 46dadbc1..013dfe36 100644 --- a/atoma-inference/src/candle/stable_diffusion.rs +++ b/atoma-inference/src/candle/stable_diffusion.rs @@ -9,9 +9,12 @@ use candle_transformers::models::stable_diffusion::{self}; use candle::{DType, Device, 
IndexOp, Module, Tensor, D}; use tokenizers::Tokenizer; -use crate::{candle::device, models::ModelError}; +use crate::{ + candle::device, + models::{types::PrecisionBits, ModelError, ModelId, ModelTrait}, +}; -use super::{save_tensor_to_file, CandleModel}; +use super::save_tensor_to_file; pub struct Input { prompt: String, @@ -108,9 +111,20 @@ pub struct Fetch { unet_weights: Option, } -impl CandleModel for StableDiffusion { +impl ModelTrait for StableDiffusion { type Input = Input; type Fetch = Fetch; + type Output = Vec; + + fn load( + _filenames: Vec, + _precision: PrecisionBits, + ) -> Result + where + Self: Sized, + { + Ok(Self {}) + } fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError> { let which = match fetch.sd_version { @@ -132,7 +146,11 @@ impl CandleModel for StableDiffusion { Ok(()) } - fn inference(input: Self::Input) -> Result, ModelError> { + fn model_id(&self) -> ModelId { + "candle/stable_diffusion".to_string() + } + + fn run(&mut self, input: Self::Input) -> Result { if !(0. ..=1.).contains(&input.img2img_strength) { Err(ModelError::Config(format!( "img2img_strength must be between 0 and 1, got {}", diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 65ed9835..bdaa7e2d 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -49,6 +49,7 @@ impl MambaModel { } impl ModelTrait for MambaModel { + type Fetch = (); type Input = TextModelInput; type Output = String; diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index 724a5d42..467b2398 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -14,9 +14,13 @@ pub mod types; pub type ModelId = String; pub trait ModelTrait { + type Fetch; type Input; type Output; + fn fetch(_fetch: &Self::Fetch) -> Result<(), ModelError> { + Ok(()) + } fn load(filenames: Vec, precision: PrecisionBits) -> Result where Self: Sized; From fe8d06a3541459c49792b44ab5e6a1a8c9cce11e Mon Sep 17 00:00:00 2001 From: Martin Stefcek Date: Wed, 3 Apr 2024 18:50:41 +0400 Subject: [PATCH 4/6] fix clippy --- atoma-inference/src/service.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 1d13024f..60cfdf9a 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -207,6 +207,11 @@ mod tests { impl ModelTrait for TestModelInstance { type Input = (); type Output = (); + type Fetch = (); + + fn fetch(_fetch: &Self::Fetch) -> Result<(), crate::models::ModelError> { + Ok(()) + } fn load(_: Vec, _: PrecisionBits) -> Result { Ok(Self {}) From 253ed4cd0401b6298b28eb5fd63a6988089f46aa Mon Sep 17 00:00:00 2001 From: Martin Stefcek Date: Wed, 3 Apr 2024 18:55:04 +0400 Subject: [PATCH 5/6] move --- atoma-inference/src/candle/mod.rs | 89 ------------------- .../src/candle/token_output_stream.rs | 86 ------------------ atoma-inference/src/lib.rs | 1 - atoma-inference/src/models/candle/mod.rs | 79 ++++++++++++++++ .../{ => models}/candle/stable_diffusion.rs | 7 +- 5 files changed, 81 insertions(+), 181 deletions(-) delete mode 100644 atoma-inference/src/candle/mod.rs delete mode 100644 atoma-inference/src/candle/token_output_stream.rs rename atoma-inference/src/{ => models}/candle/stable_diffusion.rs (99%) diff --git a/atoma-inference/src/candle/mod.rs b/atoma-inference/src/candle/mod.rs deleted file mode 100644 index 8fe9d9b6..00000000 --- a/atoma-inference/src/candle/mod.rs +++ /dev/null @@ -1,89 +0,0 
@@ -pub mod stable_diffusion; -pub mod token_output_stream; - -use std::{fs::File, io::Write, path::PathBuf}; - -use candle::{ - utils::{cuda_is_available, metal_is_available}, - DType, Device, Tensor, -}; -use tracing::info; - -use crate::models::ModelError; - -pub trait CandleModel { - type Fetch; - type Input; - fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError>; - fn inference(input: Self::Input) -> Result, ModelError>; -} - -pub fn device() -> Result { - if cuda_is_available() { - info!("Using CUDA"); - Device::new_cuda(0) - } else if metal_is_available() { - info!("Using Metal"); - Device::new_metal(0) - } else { - info!("Using Cpu"); - Ok(Device::Cpu) - } -} - -pub fn hub_load_safetensors( - repo: &hf_hub::api::sync::ApiRepo, - json_file: &str, -) -> candle::Result> { - let json_file = repo.get(json_file).map_err(candle::Error::wrap)?; - let json_file = std::fs::File::open(json_file)?; - let json: serde_json::Value = - serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; - let weight_map = match json.get("weight_map") { - None => candle::bail!("no weight map in {json_file:?}"), - Some(serde_json::Value::Object(map)) => map, - Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), - }; - let mut safetensors_files = std::collections::HashSet::new(); - for value in weight_map.values() { - if let Some(file) = value.as_str() { - safetensors_files.insert(file.to_string()); - } - } - let safetensors_files = safetensors_files - .iter() - .map(|v| repo.get(v).map_err(candle::Error::wrap)) - .collect::>>()?; - Ok(safetensors_files) -} - -pub fn save_image>(img: &Tensor, p: P) -> candle::Result<()> { - let p = p.as_ref(); - let (channel, height, width) = img.dims3()?; - if channel != 3 { - candle::bail!("save_image expects an input of shape (3, height, width)") - } - let img = img.permute((1, 2, 0))?.flatten_all()?; - let pixels = img.to_vec1::()?; - let image: image::ImageBuffer, Vec> = - match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { - Some(image) => image, - None => candle::bail!("error saving image {p:?}"), - }; - image.save(p).map_err(candle::Error::wrap)?; - Ok(()) -} - -pub fn save_tensor_to_file(tensor: &Tensor, filename: &str) -> Result<(), candle::Error> { - let json_output = serde_json::to_string( - &tensor - .to_device(&Device::Cpu)? - .flatten_all()? - .to_dtype(DType::F64)? - .to_vec1::()?, - ) - .unwrap(); - let mut file = File::create(PathBuf::from(filename))?; - file.write_all(json_output.as_bytes())?; - Ok(()) -} diff --git a/atoma-inference/src/candle/token_output_stream.rs b/atoma-inference/src/candle/token_output_stream.rs deleted file mode 100644 index 1f507c5e..00000000 --- a/atoma-inference/src/candle/token_output_stream.rs +++ /dev/null @@ -1,86 +0,0 @@ -use candle::Result; - -/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a -/// streaming way rather than having to wait for the full decoding. 
-pub struct TokenOutputStream { - tokenizer: tokenizers::Tokenizer, - tokens: Vec, - prev_index: usize, - current_index: usize, -} - -impl TokenOutputStream { - pub fn new(tokenizer: tokenizers::Tokenizer) -> Self { - Self { - tokenizer, - tokens: Vec::new(), - prev_index: 0, - current_index: 0, - } - } - - pub fn into_inner(self) -> tokenizers::Tokenizer { - self.tokenizer - } - - fn decode(&self, tokens: &[u32]) -> Result { - match self.tokenizer.decode(tokens, true) { - Ok(str) => Ok(str), - Err(err) => candle::bail!("cannot decode: {err}"), - } - } - - // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68 - pub fn next_token(&mut self, token: u32) -> Result> { - let prev_text = if self.tokens.is_empty() { - String::new() - } else { - let tokens = &self.tokens[self.prev_index..self.current_index]; - self.decode(tokens)? - }; - self.tokens.push(token); - let text = self.decode(&self.tokens[self.prev_index..])?; - if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() { - let text = text.split_at(prev_text.len()); - self.prev_index = self.current_index; - self.current_index = self.tokens.len(); - Ok(Some(text.1.to_string())) - } else { - Ok(None) - } - } - - pub fn decode_rest(&self) -> Result> { - let prev_text = if self.tokens.is_empty() { - String::new() - } else { - let tokens = &self.tokens[self.prev_index..self.current_index]; - self.decode(tokens)? - }; - let text = self.decode(&self.tokens[self.prev_index..])?; - if text.len() > prev_text.len() { - let text = text.split_at(prev_text.len()); - Ok(Some(text.1.to_string())) - } else { - Ok(None) - } - } - - pub fn decode_all(&self) -> Result { - self.decode(&self.tokens) - } - - pub fn get_token(&self, token_s: &str) -> Option { - self.tokenizer.get_vocab(true).get(token_s).copied() - } - - pub fn tokenizer(&self) -> &tokenizers::Tokenizer { - &self.tokenizer - } - - pub fn clear(&mut self) { - self.tokens.clear(); - self.prev_index = 0; - self.current_index = 0; - } -} diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index 80915305..5bec0058 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -1,5 +1,4 @@ pub mod apis; -pub mod candle; pub mod model_thread; pub mod models; pub mod service; diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs index 323f72f5..de38d19d 100644 --- a/atoma-inference/src/models/candle/mod.rs +++ b/atoma-inference/src/models/candle/mod.rs @@ -1 +1,80 @@ +use std::{fs::File, io::Write, path::PathBuf}; + +use candle::{ + utils::{cuda_is_available, metal_is_available}, + DType, Device, Tensor, +}; +use tracing::info; + pub mod mamba; +pub mod stable_diffusion; + +pub fn device() -> Result { + if cuda_is_available() { + info!("Using CUDA"); + Device::new_cuda(0) + } else if metal_is_available() { + info!("Using Metal"); + Device::new_metal(0) + } else { + info!("Using Cpu"); + Ok(Device::Cpu) + } +} + +pub fn hub_load_safetensors( + repo: &hf_hub::api::sync::ApiRepo, + json_file: &str, +) -> candle::Result> { + let json_file = repo.get(json_file).map_err(candle::Error::wrap)?; + let json_file = std::fs::File::open(json_file)?; + let json: serde_json::Value = + serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; + let weight_map = match json.get("weight_map") { + None => candle::bail!("no weight map in {json_file:?}"), + Some(serde_json::Value::Object(map)) => map, + Some(_) => 
candle::bail!("weight map in {json_file:?} is not a map"), + }; + let mut safetensors_files = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_files.insert(file.to_string()); + } + } + let safetensors_files = safetensors_files + .iter() + .map(|v| repo.get(v).map_err(candle::Error::wrap)) + .collect::>>()?; + Ok(safetensors_files) +} + +pub fn save_image>(img: &Tensor, p: P) -> candle::Result<()> { + let p = p.as_ref(); + let (channel, height, width) = img.dims3()?; + if channel != 3 { + candle::bail!("save_image expects an input of shape (3, height, width)") + } + let img = img.permute((1, 2, 0))?.flatten_all()?; + let pixels = img.to_vec1::()?; + let image: image::ImageBuffer, Vec> = + match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { + Some(image) => image, + None => candle::bail!("error saving image {p:?}"), + }; + image.save(p).map_err(candle::Error::wrap)?; + Ok(()) +} + +pub fn save_tensor_to_file(tensor: &Tensor, filename: &str) -> Result<(), candle::Error> { + let json_output = serde_json::to_string( + &tensor + .to_device(&Device::Cpu)? + .flatten_all()? + .to_dtype(DType::F64)? + .to_vec1::()?, + ) + .unwrap(); + let mut file = File::create(PathBuf::from(filename))?; + file.write_all(json_output.as_bytes())?; + Ok(()) +} diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/models/candle/stable_diffusion.rs similarity index 99% rename from atoma-inference/src/candle/stable_diffusion.rs rename to atoma-inference/src/models/candle/stable_diffusion.rs index 013dfe36..3c3ab7b9 100644 --- a/atoma-inference/src/candle/stable_diffusion.rs +++ b/atoma-inference/src/models/candle/stable_diffusion.rs @@ -9,12 +9,9 @@ use candle_transformers::models::stable_diffusion::{self}; use candle::{DType, Device, IndexOp, Module, Tensor, D}; use tokenizers::Tokenizer; -use crate::{ - candle::device, - models::{types::PrecisionBits, ModelError, ModelId, ModelTrait}, -}; +use crate::models::{types::PrecisionBits, ModelError, ModelId, ModelTrait}; -use super::save_tensor_to_file; +use super::{device, save_tensor_to_file}; pub struct Input { prompt: String, From b6d03125a4b819012b816686b3961511e66c4073 Mon Sep 17 00:00:00 2001 From: Martin Stefcek Date: Thu, 4 Apr 2024 09:35:53 +0400 Subject: [PATCH 6/6] fixes --- Cargo.toml | 2 -- atoma-inference/Cargo.toml | 2 -- atoma-inference/src/models/candle/mod.rs | 16 ++++++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1758ec34..50995b9b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ version = "0.1.0" [workspace.dependencies] reqwest = "0.12.1" -anyhow = "1.0.81" async-trait = "0.1.78" candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" } candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", version = "0.4.2" } @@ -26,7 +25,6 @@ dotenv = "0.15.0" ed25519-consensus = "2.1.0" futures = "0.3.30" hf-hub = "0.3.2" -clap = "4.5.3" image = { version = "0.25.0", default-features = false, features = [ "jpeg", "png", diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index c6037821..3da618eb 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -4,7 +4,6 @@ version.workspace = true edition = "2021" [dependencies] -anyhow.workspace = true async-trait.workspace = true candle.workspace = true candle-flash-attn = { workspace = true, optional = true 
} @@ -18,7 +17,6 @@ hf-hub.workspace = true reqwest = { workspace = true, features = ["json"] } serde = { workspace = true, features = ["derive"] } serde_json.workspace = true -clap.workspace = true image = { workspace = true } thiserror.workspace = true tokenizers = { workspace = true, features = ["onig"] } diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs index de38d19d..a10f6f6f 100644 --- a/atoma-inference/src/models/candle/mod.rs +++ b/atoma-inference/src/models/candle/mod.rs @@ -6,6 +6,10 @@ use candle::{ }; use tracing::info; +use crate::bail; + +use super::ModelError; + pub mod mamba; pub mod stable_diffusion; @@ -25,15 +29,15 @@ pub fn device() -> Result { pub fn hub_load_safetensors( repo: &hf_hub::api::sync::ApiRepo, json_file: &str, -) -> candle::Result> { +) -> Result, ModelError> { let json_file = repo.get(json_file).map_err(candle::Error::wrap)?; let json_file = std::fs::File::open(json_file)?; let json: serde_json::Value = serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; let weight_map = match json.get("weight_map") { - None => candle::bail!("no weight map in {json_file:?}"), + None => bail!("no weight map in {json_file:?}"), Some(serde_json::Value::Object(map)) => map, - Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), + Some(_) => bail!("weight map in {json_file:?} is not a map"), }; let mut safetensors_files = std::collections::HashSet::new(); for value in weight_map.values() { @@ -48,18 +52,18 @@ pub fn hub_load_safetensors( Ok(safetensors_files) } -pub fn save_image>(img: &Tensor, p: P) -> candle::Result<()> { +pub fn save_image>(img: &Tensor, p: P) -> Result<(), ModelError> { let p = p.as_ref(); let (channel, height, width) = img.dims3()?; if channel != 3 { - candle::bail!("save_image expects an input of shape (3, height, width)") + bail!("save_image expects an input of shape (3, height, width)") } let img = img.permute((1, 2, 0))?.flatten_all()?; let pixels = img.to_vec1::()?; let image: image::ImageBuffer, Vec> = match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { Some(image) => image, - None => candle::bail!("error saving image {p:?}"), + None => bail!("error saving image {p:?}"), }; image.save(p).map_err(candle::Error::wrap)?; Ok(())