From 41a3da974894551783058df54aac3aa3904462ab Mon Sep 17 00:00:00 2001 From: Cifko Date: Wed, 27 Mar 2024 16:49:58 +0100 Subject: [PATCH] add stable diffusion --- .gitignore | 1 + Cargo.toml | 16 +- atoma-inference/Cargo.toml | 18 +- atoma-inference/src/candle/mod.rs | 75 +++ .../src/candle/stable_diffusion.rs | 539 ++++++++++++++++++ .../src/candle/token_output_stream.rs | 86 +++ atoma-inference/src/lib.rs | 1 + atoma-inference/src/llama/mod.rs | 40 ++ atoma-inference/src/main.rs | 40 +- atoma-inference/src/models.rs | 41 +- 10 files changed, 830 insertions(+), 27 deletions(-) create mode 100644 atoma-inference/src/candle/mod.rs create mode 100644 atoma-inference/src/candle/stable_diffusion.rs create mode 100644 atoma-inference/src/candle/token_output_stream.rs create mode 100644 atoma-inference/src/llama/mod.rs diff --git a/.gitignore b/.gitignore index 1e7caa9e..9182e2f4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ Cargo.lock target/ +.vscode/ diff --git a/Cargo.toml b/Cargo.toml index 0d1fdee6..ee28f58f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,17 +1,31 @@ [workspace] resolver = "2" +edition = "2021" -members = ["atoma-event-subscribe", "atoma-inference", "atoma-networking", "atoma-json-rpc", "atoma-storage"] +members = [ + "atoma-event-subscribe", + "atoma-inference", + "atoma-networking", + "atoma-json-rpc", + "atoma-storage", +] [workspace.package] version = "0.1.0" [workspace.dependencies] +anyhow = "1.0.81" async-trait = "0.1.78" candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" } candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" } candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" } +clap = "4.5.3" ed25519-consensus = "2.1.0" +hf-hub = "0.3.0" +image = { version = "0.25.0", default-features = false, features = [ + "jpeg", + "png", +] } serde = "1.0.197" thiserror = "1.0.58" tokenizers = "0.15.2" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index 82e2740f..810bfe22 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -1,18 +1,26 @@ [package] name = "inference" -version = "0.1.0" +version.workspace = true edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] +anyhow.workspace = true async-trait.workspace = true -candle.workspace = true candle-nn.workspace = true candle-transformers.workspace = true +candle.workspace = true +clap.workspace = true ed25519-consensus.workspace = true +hf-hub = { workspace = true, features = ["tokio"] } +image = { workspace = true } serde = { workspace = true, features = ["derive"] } +serde_json = "1.0.114" thiserror.workspace = true -tokenizers.workspace = true +tokenizers = { workspace = true, features = ["onig"] } tokio = { workspace = true, features = ["full", "tracing"] } tracing.workspace = true +llama_cpp = "0.3.1" + +[features] +cuda = ["candle/cuda", "candle-nn/cuda"] +metal = ["candle/metal", "candle-nn/metal"] diff --git a/atoma-inference/src/candle/mod.rs b/atoma-inference/src/candle/mod.rs new file mode 100644 index 00000000..d2bdd966 --- /dev/null +++ b/atoma-inference/src/candle/mod.rs @@ -0,0 +1,75 @@ +pub mod stable_diffusion; +pub mod token_output_stream; + +use std::path::PathBuf; + +use candle::{ + utils::{cuda_is_available, metal_is_available}, + Device, Tensor, +}; +use tracing::info; + +use crate::models::ModelError; + +pub trait CandleModel { + type Fetch; + type Input; + fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError>; + fn inference(input: Self::Input) -> Result, ModelError>; +} + +pub fn device() -> Result { + if cuda_is_available() { + info!("Using CUDA"); + Device::new_cuda(0) + } else if metal_is_available() { + info!("Using Metal"); + Device::new_metal(0) + } else { + info!("Using Cpu"); + Ok(Device::Cpu) + } +} + +pub fn hub_load_safetensors( + repo: &hf_hub::api::sync::ApiRepo, + json_file: &str, +) -> candle::Result> { + let json_file = repo.get(json_file).map_err(candle::Error::wrap)?; + let json_file = std::fs::File::open(json_file)?; + let json: serde_json::Value = + serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; + let weight_map = match json.get("weight_map") { + None => candle::bail!("no weight map in {json_file:?}"), + Some(serde_json::Value::Object(map)) => map, + Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), + }; + let mut safetensors_files = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_files.insert(file.to_string()); + } + } + let safetensors_files = safetensors_files + .iter() + .map(|v| repo.get(v).map_err(candle::Error::wrap)) + .collect::>>()?; + Ok(safetensors_files) +} + +pub fn save_image>(img: &Tensor, p: P) -> candle::Result<()> { + let p = p.as_ref(); + let (channel, height, width) = img.dims3()?; + if channel != 3 { + candle::bail!("save_image expects an input of shape (3, height, width)") + } + let img = img.permute((1, 2, 0))?.flatten_all()?; + let pixels = img.to_vec1::()?; + let image: image::ImageBuffer, Vec> = + match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { + Some(image) => image, + None => candle::bail!("error saving image {p:?}"), + }; + image.save(p).map_err(candle::Error::wrap)?; + Ok(()) +} diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/candle/stable_diffusion.rs new file mode 100644 index 00000000..23ecb95e --- /dev/null +++ b/atoma-inference/src/candle/stable_diffusion.rs @@ -0,0 +1,539 @@ +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use candle_transformers::models::stable_diffusion::{self, StableDiffusionConfig}; +use std::path::PathBuf; + +use candle::{DType, Device, IndexOp, Module, Tensor, D}; +use tokenizers::Tokenizer; + +use crate::{candle::device, models::ModelError}; + +use super::CandleModel; + +pub struct Input { + prompt: String, + uncond_prompt: String, + + height: Option, + width: Option, + + /// The UNet weight file, in .safetensors format. + unet_weights: Option, + + /// The CLIP weight file, in .safetensors format. + clip_weights: Option, + + /// The VAE weight file, in .safetensors format. + vae_weights: Option, + + /// The file specifying the tokenizer to used for tokenization. + tokenizer: Option, + + /// The size of the sliced attention or 0 for automatic slicing (disabled by default) + sliced_attention_size: Option, + + /// The number of steps to run the diffusion for. + n_steps: Option, + + /// The number of samples to generate. + num_samples: i64, + + sd_version: StableDiffusionVersion, + + /// Generate intermediary images at each step. + intermediary_images: bool, + + use_flash_attn: bool, + + use_f16: bool, + + guidance_scale: Option, + + img2img: Option, + + /// The strength, indicates how much to transform the initial image. The + /// value must be between 0 and 1, a value of 1 discards the initial image + /// information. + img2img_strength: f64, + + /// The seed to use when generating random samples. + seed: Option, +} + +impl Input { + pub fn default_prompt(prompt: String) -> Self { + Self { + prompt, + uncond_prompt: "".to_string(), + height: Some(256), + width: Some(256), + unet_weights: None, + clip_weights: None, + vae_weights: None, + tokenizer: None, + sliced_attention_size: None, + n_steps: Some(20), + num_samples: 1, + sd_version: StableDiffusionVersion::V1_5, + intermediary_images: false, + use_flash_attn: false, + use_f16: true, + guidance_scale: None, + img2img: None, + img2img_strength: 0.8, + seed: None, + } + } +} + +impl From<&Input> for Fetch { + fn from(input: &Input) -> Self { + Self { + tokenizer: input.tokenizer.clone(), + sd_version: input.sd_version, + use_f16: input.use_f16, + clip_weights: input.clip_weights.clone(), + vae_weights: input.vae_weights.clone(), + unet_weights: input.unet_weights.clone(), + } + } +} +pub struct StableDiffusion {} + +pub struct Fetch { + tokenizer: Option, + sd_version: StableDiffusionVersion, + use_f16: bool, + clip_weights: Option, + vae_weights: Option, + unet_weights: Option, +} + +impl CandleModel for StableDiffusion { + type Input = Input; + type Fetch = Fetch; + + fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError> { + let which = match fetch.sd_version { + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false], + _ => vec![true], + }; + for first in which { + let (clip_weights_file, tokenizer_file) = if first { + (ModelFile::Clip, ModelFile::Tokenizer) + } else { + (ModelFile::Clip2, ModelFile::Tokenizer2) + }; + + clip_weights_file.get(fetch.clip_weights.clone(), fetch.sd_version, false)?; + ModelFile::Vae.get(fetch.vae_weights.clone(), fetch.sd_version, fetch.use_f16)?; + tokenizer_file.get(fetch.tokenizer.clone(), fetch.sd_version, fetch.use_f16)?; + ModelFile::Unet.get(fetch.unet_weights.clone(), fetch.sd_version, fetch.use_f16)?; + } + Ok(()) + } + + fn inference(input: Self::Input) -> Result, ModelError> { + if !(0. ..=1.).contains(&input.img2img_strength) { + Err(ModelError::Config(format!( + "img2img_strength must be between 0 and 1, got {}", + input.img2img_strength, + )))? + } + + let guidance_scale = match input.guidance_scale { + Some(guidance_scale) => guidance_scale, + None => match input.sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 7.5, + StableDiffusionVersion::Turbo => 0., + }, + }; + let n_steps = match input.n_steps { + Some(n_steps) => n_steps, + None => match input.sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 30, + StableDiffusionVersion::Turbo => 1, + }, + }; + let dtype = if input.use_f16 { + DType::F16 + } else { + DType::F32 + }; + let sd_config = match input.sd_version { + StableDiffusionVersion::V1_5 => stable_diffusion::StableDiffusionConfig::v1_5( + input.sliced_attention_size, + input.height, + input.width, + ), + StableDiffusionVersion::V2_1 => stable_diffusion::StableDiffusionConfig::v2_1( + input.sliced_attention_size, + input.height, + input.width, + ), + StableDiffusionVersion::Xl => stable_diffusion::StableDiffusionConfig::sdxl( + input.sliced_attention_size, + input.height, + input.width, + ), + StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo( + input.sliced_attention_size, + input.height, + input.width, + ), + }; + + let scheduler = sd_config.build_scheduler(n_steps)?; + let device = device()?; + if let Some(seed) = input.seed { + device.set_seed(seed)?; + } + let use_guide_scale = guidance_scale > 1.0; + + let which = match input.sd_version { + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false], + _ => vec![true], + }; + let text_embeddings = which + .iter() + .map(|first| { + Self::text_embeddings( + &input.prompt, + &input.uncond_prompt, + input.tokenizer.clone(), + input.clip_weights.clone(), + input.sd_version, + &sd_config, + input.use_f16, + &device, + dtype, + use_guide_scale, + *first, + ) + }) + .collect::, _>>()?; + + let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?; + + let vae_weights = ModelFile::Vae.get(input.vae_weights, input.sd_version, input.use_f16)?; + let vae = sd_config.build_vae(vae_weights, &device, dtype)?; + let init_latent_dist = match &input.img2img { + None => None, + Some(image) => { + let image = Self::image_preprocess(image)?.to_device(&device)?; + Some(vae.encode(&image)?) + } + }; + let unet_weights = + ModelFile::Unet.get(input.unet_weights, input.sd_version, input.use_f16)?; + let unet = sd_config.build_unet(unet_weights, &device, 4, input.use_flash_attn, dtype)?; + + let t_start = if input.img2img.is_some() { + n_steps - (n_steps as f64 * input.img2img_strength) as usize + } else { + 0 + }; + let bsize = 1; + + let vae_scale = match input.sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 0.18215, + StableDiffusionVersion::Turbo => 0.13025, + }; + let mut res = Vec::new(); + + for _ in 0..input.num_samples { + let timesteps = scheduler.timesteps(); + let latents = match &init_latent_dist { + Some(init_latent_dist) => { + let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?; + if t_start < timesteps.len() { + let noise = latents.randn_like(0f64, 1f64)?; + scheduler.add_noise(&latents, noise, timesteps[t_start])? + } else { + latents + } + } + None => { + let latents = Tensor::randn( + 0f32, + 1f32, + (bsize, 4, sd_config.height / 8, sd_config.width / 8), + &device, + )?; + // scale the initial noise by the standard deviation required by the scheduler + (latents * scheduler.init_noise_sigma())? + } + }; + let mut latents = latents.to_dtype(dtype)?; + + for (timestep_index, ×tep) in timesteps.iter().enumerate() { + if timestep_index < t_start { + continue; + } + let latent_model_input = if use_guide_scale { + Tensor::cat(&[&latents, &latents], 0)? + } else { + latents.clone() + }; + + let latent_model_input = + scheduler.scale_model_input(latent_model_input, timestep)?; + let noise_pred = + unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?; + + let noise_pred = if use_guide_scale { + let noise_pred = noise_pred.chunk(2, 0)?; + let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); + + (noise_pred_uncond + + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)? + } else { + noise_pred + }; + + latents = scheduler.step(&noise_pred, timestep, &latents)?; + } + + let image = vae.decode(&(&latents / vae_scale)?)?; + let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; + let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?; + res.push(image); + } + Ok(res) + } +} + +#[derive(Clone, Copy)] +enum StableDiffusionVersion { + V1_5, + V2_1, + Xl, + Turbo, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ModelFile { + Tokenizer, + Tokenizer2, + Clip, + Clip2, + Unet, + Vae, +} + +impl StableDiffusionVersion { + fn repo(&self) -> &'static str { + match self { + Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0", + Self::V2_1 => "stabilityai/stable-diffusion-2-1", + Self::V1_5 => "runwayml/stable-diffusion-v1-5", + Self::Turbo => "stabilityai/sdxl-turbo", + } + } + + fn unet_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "unet/diffusion_pytorch_model.fp16.safetensors" + } else { + "unet/diffusion_pytorch_model.safetensors" + } + } + } + } + + fn vae_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "vae/diffusion_pytorch_model.fp16.safetensors" + } else { + "vae/diffusion_pytorch_model.safetensors" + } + } + } + } + + fn clip_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "text_encoder/model.fp16.safetensors" + } else { + "text_encoder/model.safetensors" + } + } + } + } + + fn clip2_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "text_encoder_2/model.fp16.safetensors" + } else { + "text_encoder_2/model.safetensors" + } + } + } + } +} + +impl ModelFile { + fn get( + &self, + filename: Option, + version: StableDiffusionVersion, + use_f16: bool, + ) -> Result { + use hf_hub::api::sync::Api; + match filename { + Some(filename) => Ok(std::path::PathBuf::from(filename)), + None => { + let (repo, path) = match self { + Self::Tokenizer => { + let tokenizer_repo = match version { + StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => { + "openai/clip-vit-base-patch32" + } + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => { + // This seems similar to the patch32 version except some very small + // difference in the split regex. + "openai/clip-vit-large-patch14" + } + }; + (tokenizer_repo, "tokenizer.json") + } + Self::Tokenizer2 => { + ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json") + } + Self::Clip => (version.repo(), version.clip_file(use_f16)), + Self::Clip2 => (version.repo(), version.clip2_file(use_f16)), + Self::Unet => (version.repo(), version.unet_file(use_f16)), + Self::Vae => { + // Override for SDXL when using f16 weights. + // See https://github.com/huggingface/candle/issues/1060 + if matches!( + version, + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo, + ) && use_f16 + { + ( + "madebyollin/sdxl-vae-fp16-fix", + "diffusion_pytorch_model.safetensors", + ) + } else { + (version.repo(), version.vae_file(use_f16)) + } + } + }; + let filename = Api::new()?.model(repo.to_string()).get(path)?; + Ok(filename) + } + } + } +} + +impl StableDiffusion { + #[allow(clippy::too_many_arguments)] + fn text_embeddings( + prompt: &str, + uncond_prompt: &str, + tokenizer: Option, + clip_weights: Option, + sd_version: StableDiffusionVersion, + sd_config: &stable_diffusion::StableDiffusionConfig, + use_f16: bool, + device: &Device, + dtype: DType, + use_guide_scale: bool, + first: bool, + ) -> Result { + let (clip_weights_file, tokenizer_file) = if first { + (ModelFile::Clip, ModelFile::Tokenizer) + } else { + (ModelFile::Clip2, ModelFile::Tokenizer2) + }; + let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?; + let tokenizer = Tokenizer::from_file(tokenizer)?; + let pad_id = match &sd_config.clip.pad_with { + Some(padding) => { + *tokenizer + .get_vocab(true) + .get(padding.as_str()) + .ok_or(ModelError::Msg(format!( + "Padding token {padding} not found in the tokenizer vocabulary" + )))? + } + None => *tokenizer + .get_vocab(true) + .get("<|endoftext|>") + .ok_or(ModelError::Msg("".to_string()))?, + }; + let mut tokens = tokenizer.encode(prompt, true)?.get_ids().to_vec(); + while tokens.len() < sd_config.clip.max_position_embeddings { + tokens.push(pad_id) + } + let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; + + let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?; + let clip_config = if first { + &sd_config.clip + } else { + sd_config.clip2.as_ref().unwrap() + }; + let text_model = stable_diffusion::build_clip_transformer( + clip_config, + clip_weights, + device, + DType::F32, + )?; + let text_embeddings = text_model.forward(&tokens)?; + + let text_embeddings = if use_guide_scale { + let mut uncond_tokens = tokenizer.encode(uncond_prompt, true)?.get_ids().to_vec(); + while uncond_tokens.len() < sd_config.clip.max_position_embeddings { + uncond_tokens.push(pad_id) + } + + let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; + let uncond_embeddings = text_model.forward(&uncond_tokens)?; + + Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)? + } else { + text_embeddings.to_dtype(dtype)? + }; + Ok(text_embeddings) + } + + fn image_preprocess>(path: T) -> Result { + let img = image::io::Reader::open(path)?.decode()?; + let (height, width) = (img.height() as usize, img.width() as usize); + let height = height - height % 32; + let width = width - width % 32; + let img = img.resize_to_fill( + width as u32, + height as u32, + image::imageops::FilterType::CatmullRom, + ); + let img = img.to_rgb8(); + let img = img.into_raw(); + let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)? + .permute((2, 0, 1))? + .to_dtype(DType::F32)? + .affine(2. / 255., -1.)? + .unsqueeze(0)?; + Ok(img) + } +} diff --git a/atoma-inference/src/candle/token_output_stream.rs b/atoma-inference/src/candle/token_output_stream.rs new file mode 100644 index 00000000..1f507c5e --- /dev/null +++ b/atoma-inference/src/candle/token_output_stream.rs @@ -0,0 +1,86 @@ +use candle::Result; + +/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a +/// streaming way rather than having to wait for the full decoding. +pub struct TokenOutputStream { + tokenizer: tokenizers::Tokenizer, + tokens: Vec, + prev_index: usize, + current_index: usize, +} + +impl TokenOutputStream { + pub fn new(tokenizer: tokenizers::Tokenizer) -> Self { + Self { + tokenizer, + tokens: Vec::new(), + prev_index: 0, + current_index: 0, + } + } + + pub fn into_inner(self) -> tokenizers::Tokenizer { + self.tokenizer + } + + fn decode(&self, tokens: &[u32]) -> Result { + match self.tokenizer.decode(tokens, true) { + Ok(str) => Ok(str), + Err(err) => candle::bail!("cannot decode: {err}"), + } + } + + // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68 + pub fn next_token(&mut self, token: u32) -> Result> { + let prev_text = if self.tokens.is_empty() { + String::new() + } else { + let tokens = &self.tokens[self.prev_index..self.current_index]; + self.decode(tokens)? + }; + self.tokens.push(token); + let text = self.decode(&self.tokens[self.prev_index..])?; + if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() { + let text = text.split_at(prev_text.len()); + self.prev_index = self.current_index; + self.current_index = self.tokens.len(); + Ok(Some(text.1.to_string())) + } else { + Ok(None) + } + } + + pub fn decode_rest(&self) -> Result> { + let prev_text = if self.tokens.is_empty() { + String::new() + } else { + let tokens = &self.tokens[self.prev_index..self.current_index]; + self.decode(tokens)? + }; + let text = self.decode(&self.tokens[self.prev_index..])?; + if text.len() > prev_text.len() { + let text = text.split_at(prev_text.len()); + Ok(Some(text.1.to_string())) + } else { + Ok(None) + } + } + + pub fn decode_all(&self) -> Result { + self.decode(&self.tokens) + } + + pub fn get_token(&self, token_s: &str) -> Option { + self.tokenizer.get_vocab(true).get(token_s).copied() + } + + pub fn tokenizer(&self) -> &tokenizers::Tokenizer { + &self.tokenizer + } + + pub fn clear(&mut self) { + self.tokens.clear(); + self.prev_index = 0; + self.current_index = 0; + } +} diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index 1ca7c952..bd81752c 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -1,3 +1,4 @@ +pub mod candle; pub mod config; pub mod core_thread; pub mod models; diff --git a/atoma-inference/src/llama/mod.rs b/atoma-inference/src/llama/mod.rs new file mode 100644 index 00000000..bf294ff4 --- /dev/null +++ b/atoma-inference/src/llama/mod.rs @@ -0,0 +1,40 @@ +use llama_cpp::standard_sampler::StandardSampler; +use llama_cpp::{LlamaContextError, LlamaLoadError, LlamaModel, LlamaParams, SessionParams}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum LlamaError { + #[error("Llama load error: {0}")] + LlamaLoadError(#[from] LlamaLoadError), + #[error("Llama context error: {0}")] + LlamaContextError(#[from] LlamaContextError), +} + +pub fn run(prompt: &str) -> Result { + let model = + LlamaModel::load_from_file("../.gguf/llama-2-7b.Q5_K_M.gguf", LlamaParams::default())?; + + let mut ctx = model.create_session(SessionParams::default())?; + + ctx.advance_context(prompt)?; + + let max_tokens = 10; + let mut decoded_tokens = 0; + + let completions = ctx + .start_completing_with(StandardSampler::default(), 1024) + .into_strings(); + + let mut result = String::new(); + + for completion in completions { + result.push_str(&completion); + + decoded_tokens += 1; + + if decoded_tokens > max_tokens { + break; + } + } + Ok(result) +} diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index e7a11a96..f02aae54 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,3 +1,41 @@ +// use candle::llama::run; + +// use candle::stable_diffusion::run; + +mod candle; +mod llama; +mod models; +mod types; + +fn stable_diffusion() { + use crate::candle::stable_diffusion; + use crate::candle::{save_image, CandleModel}; + let input = stable_diffusion::Input::default_prompt("Mountains over purple lake".to_string()); + stable_diffusion::StableDiffusion::fetch(&(&input).into()).unwrap(); + let images = stable_diffusion::StableDiffusion::inference(input).unwrap(); + let mut i = 0; + for image in &images { + save_image(image, format!("image-{i}.png")).unwrap(); + i += 1; + } +} + fn main() { - println!("Hello, world!"); + stable_diffusion(); + // let result = llama::run("One day I will").unwrap(); + // println!("{}", result); + + // let x = run( + // "The most important thing is ".to_string(), + // Default::default(), + // ) + // .unwrap(); + // println!("{}", x); + // run( + // "Green boat on ocean during storm".to_string(), + // "".to_string(), + // Default::default(), + // ) + // .unwrap(); + // let mut i = 0; } diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs index c2d4612b..86c512bd 100644 --- a/atoma-inference/src/models.rs +++ b/atoma-inference/src/models.rs @@ -13,8 +13,9 @@ use candle_transformers::{ stable_diffusion::StableDiffusionConfig, }, }; -use thiserror::Error; +use hf_hub::api::sync::ApiError; +use thiserror::Error; use tokenizers::Tokenizer; use crate::types::Temperature; @@ -201,14 +202,9 @@ impl ModelApi for Model { (tokens.len(), 0) }; let ctx = &tokens[tokens.len().saturating_sub(context_size)..]; - let input = Tensor::new(ctx, &model_specs.device) - .map_err(ModelError::TensorError)? - .unsqueeze(0) - .map_err(ModelError::TensorError)?; - let logits = model - .forward(&input, context_index, &mut cache) - .map_err(ModelError::TensorError)?; - let logits = logits.squeeze(0).map_err(ModelError::LogitsError)?; + let input = Tensor::new(ctx, &model_specs.device)?.unsqueeze(0)?; + let logits = model.forward(&input, context_index, &mut cache)?; + let logits = logits.squeeze(0)?; let logits = if repeat_penalty == 1. { logits } else { @@ -217,14 +213,11 @@ impl ModelApi for Model { &logits, repeat_penalty, &tokens[start_at..], - ) - .map_err(ModelError::TensorError)? + )? }; index_pos += ctx.len(); - let next_token = logits_processor - .sample(&logits) - .map_err(ModelError::TensorError)?; + let next_token = logits_processor.sample(&logits)?; tokens_generated += 1; tokens.push(next_token); @@ -263,12 +256,20 @@ impl ModelApi for Model { pub enum ModelError { #[error("Cache error: `{0}`")] CacheError(String), - #[error("Failed to load error: `{0}`")] - LoadError(CandleError), - #[error("Logits error: `{0}`")] - LogitsError(CandleError), - #[error("Tensor error: `{0}`")] - TensorError(CandleError), + #[error("Candle error: `{0}`")] + CandleError(#[from] CandleError), #[error("Failed input tokenization: `{0}`")] TokenizerError(Box), + #[error("ApiError error: `{0}`")] + ApiError(#[from] ApiError), + #[error("Model error: `{0}`")] + Msg(String), + #[error("Config error: `{0}`")] + Config(String), + #[error("Error: `{0}`")] + BoxedError(#[from] Box), + #[error("Io error: `{0}`")] + IoError(#[from] std::io::Error), + #[error("Image error: `{0}`")] + ImageError(#[from] image::ImageError), }