From 40e4a615f9faa06e25273dca8a3d70eaa6e2a74c Mon Sep 17 00:00:00 2001
From: Martin Stefcek
Date: Fri, 22 Mar 2024 22:43:34 +0400
Subject: [PATCH] add llama and stable diffusion

---
 .vscode/settings.json                         |  24 +
 Cargo.toml                                    |  17 +-
 atoma-inference/Cargo.toml                    |  17 +-
 atoma-inference/src/main.rs                   |  18 +-
 atoma-inference/src/models/llama.rs           | 162 ++++++
 atoma-inference/src/models/mod.rs             |  65 +++
 .../src/models/stable_diffusion.rs            | 532 ++++++++++++++++++
 .../src/models/token_output_stream.rs         |  86 +++
 8 files changed, 913 insertions(+), 8 deletions(-)
 create mode 100644 .vscode/settings.json
 create mode 100644 atoma-inference/src/models/llama.rs
 create mode 100644 atoma-inference/src/models/mod.rs
 create mode 100644 atoma-inference/src/models/stable_diffusion.rs
 create mode 100644 atoma-inference/src/models/token_output_stream.rs

diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 00000000..da84b231
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,24 @@
+{
+    "cSpell.words": [
+        "amall",
+        "bsize",
+        "Catmull",
+        "ctxt",
+        "dtype",
+        "endoftext",
+        "laion",
+        "logits",
+        "madebyollin",
+        "mmaped",
+        "Narsil",
+        "openai",
+        "runwayml",
+        "safetensors",
+        "sdxl",
+        "stabilityai",
+        "timestep",
+        "uncond",
+        "Unet",
+        "unsqueeze"
+    ]
+}
diff --git a/Cargo.toml b/Cargo.toml
index 0d1fdee6..fe66bf3e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,12 +1,20 @@
 [workspace]
 resolver = "2"
-
-members = ["atoma-event-subscribe", "atoma-inference", "atoma-networking", "atoma-json-rpc", "atoma-storage"]
+members = [
+    "atoma-event-subscribe",
+    "atoma-inference",
+    "atoma-networking",
+    "atoma-json-rpc",
+    "atoma-storage",
+]
 
 [workspace.package]
 version = "0.1.0"
+edition = "2021"
 
 [workspace.dependencies]
+anyhow = "1.0.81"
+clap = "4.5.3"
 async-trait = "0.1.78"
 candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" }
 candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" }
@@ -17,3 +25,8 @@ thiserror = "1.0.58"
 tokenizers = "0.15.2"
 tokio = "1.36.0"
 tracing = "0.1.40"
+hf-hub = "0.3.0"
+image = { version = "0.25.0", default-features = false, features = [
+    "jpeg",
+    "png",
+] }
diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml
index 82e2740f..647795bd 100644
--- a/atoma-inference/Cargo.toml
+++ b/atoma-inference/Cargo.toml
@@ -1,18 +1,25 @@
 [package]
 name = "inference"
-version = "0.1.0"
+version.workspace = true
 edition = "2021"
 
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
 [dependencies]
+anyhow.workspace = true
 async-trait.workspace = true
-candle.workspace = true
 candle-nn.workspace = true
 candle-transformers.workspace = true
+candle.workspace = true
+clap.workspace = true
 ed25519-consensus.workspace = true
+hf-hub = { workspace = true, features = ["tokio"] }
+image = { workspace = true }
 serde = { workspace = true, features = ["derive"] }
+serde_json = "1.0.114"
 thiserror.workspace = true
-tokenizers.workspace = true
+tokenizers = { workspace = true, features = ["onig"] }
 tokio = { workspace = true, features = ["full", "tracing"] }
 tracing.workspace = true
+
+[features]
+cuda = ["candle/cuda", "candle-nn/cuda"]
+metal = ["candle/metal", "candle-nn/metal"]
diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs
index e7a11a96..cf7f5f9c 100644
--- 
a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,3 +1,19 @@ +// use models::llama::run; + +use models::stable_diffusion::run; + +mod models; + fn main() { - println!("Hello, world!"); + // run( + // "The most important thing is ".to_string(), + // Default::default(), + // ) + // .unwrap(); + run( + "Green boat on ocean during storm".to_string(), + "".to_string(), + Default::default(), + ) + .unwrap(); } diff --git a/atoma-inference/src/models/llama.rs b/atoma-inference/src/models/llama.rs new file mode 100644 index 00000000..d5f03aae --- /dev/null +++ b/atoma-inference/src/models/llama.rs @@ -0,0 +1,162 @@ +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::{bail, Error as E, Result}; + +use candle::{DType, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use std::io::Write; + +use candle_transformers::models::llama as model; +use model::{Llama, LlamaConfig}; +use tokenizers::Tokenizer; + +use crate::models::{device, hub_load_safetensors, token_output_stream::TokenOutputStream}; + +const EOS_TOKEN: &str = ""; +const DEFAULT_PROMPT: &str = "My favorite theorem is "; + +#[derive(Clone, Debug, Copy, PartialEq, Eq)] +enum Which { + V1, + V2, + Solar10_7B, + TinyLlama1_1BChat, +} + +pub struct Config { + temperature: Option, + top_p: Option, + seed: u64, + sample_len: usize, + no_kv_cache: bool, + dtype: Option, + model_id: Option, + revision: Option, + which: Which, + use_flash_attn: bool, + repeat_penalty: f32, + repeat_last_n: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + temperature: None, + top_p: None, + seed: 299792458, + sample_len: 10000, + no_kv_cache: false, + dtype: None, + model_id: None, + revision: None, + which: Which::TinyLlama1_1BChat, + use_flash_attn: false, + repeat_penalty: 1., + repeat_last_n: 64, + } + } +} + +pub fn run(prompt: String, cfg: Config) -> Result<()> { + let device = device()?; + let dtype = match cfg.dtype.as_deref() { + Some("f16") => DType::F16, + Some("bf16") => DType::BF16, + Some("f32") => DType::F32, + Some(dtype) => bail!("Unsupported dtype {dtype}"), + None => DType::F16, + }; + let (llama, tokenizer_filename, mut cache) = { + let api = Api::new()?; + let model_id = cfg.model_id.unwrap_or_else(|| match cfg.which { + Which::V1 => "Narsil/amall-7b".to_string(), + Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(), + Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(), + Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(), + }); + println!("loading the model weights from {model_id}"); + let revision = cfg.revision.unwrap_or("main".to_string()); + let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); + + let tokenizer_filename = api.get("tokenizer.json")?; + let config_filename = api.get("config.json")?; + let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?; + let config = config.into_config(cfg.use_flash_attn); + + let filenames = match cfg.which { + Which::V1 | Which::V2 | Which::Solar10_7B => { + hub_load_safetensors(&api, "model.safetensors.index.json")? + } + Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?], + }; + let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?; + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? 
}; + (Llama::load(vb, &config)?, tokenizer_filename, cache) + }; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let eos_token_id = tokenizer.token_to_id(EOS_TOKEN); + let mut tokens = tokenizer + .encode(prompt.clone(), true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + + let mut tokenizer = TokenOutputStream::new(tokenizer); + println!("starting the inference loop"); + print!("{prompt}"); + let mut logits_processor = LogitsProcessor::new(cfg.seed, cfg.temperature, cfg.top_p); + let start_gen = std::time::Instant::now(); + let mut index_pos = 0; + let mut token_generated = 0; + for index in 0..cfg.sample_len { + let (context_size, context_index) = if cache.use_kv_cache && index > 0 { + (1, index_pos) + } else { + (tokens.len(), 0) + }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; + let logits = llama.forward(&input, context_index, &mut cache)?; + let logits = logits.squeeze(0)?; + let logits = if cfg.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(cfg.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + cfg.repeat_penalty, + &tokens[start_at..], + )? + }; + index_pos += ctxt.len(); + + let next_token = logits_processor.sample(&logits)?; + token_generated += 1; + tokens.push(next_token); + + if Some(next_token) == eos_token_id { + break; + } + if let Some(t) = tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + } + if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + let dt = start_gen.elapsed(); + println!( + "\n\n{} tokens generated ({} token/s)\n", + token_generated, + token_generated as f64 / dt.as_secs_f64(), + ); + Ok(()) +} diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs new file mode 100644 index 00000000..41eb67b1 --- /dev/null +++ b/atoma-inference/src/models/mod.rs @@ -0,0 +1,65 @@ +pub mod llama; +pub mod stable_diffusion; +pub mod token_output_stream; + +use anyhow::Result; +use candle::{ + utils::{cuda_is_available, metal_is_available}, + Device, Tensor, +}; + +pub fn device() -> Result { + if cuda_is_available() { + println!("Using CUDA"); + Ok(Device::new_cuda(0)?) + } else if metal_is_available() { + println!("Using Metal"); + Ok(Device::new_metal(0)?) 
+    } else {
+        println!("Using CPU");
+        Ok(Device::Cpu)
+    }
+}
+
+pub fn hub_load_safetensors(
+    repo: &hf_hub::api::sync::ApiRepo,
+    json_file: &str,
+) -> candle::Result<Vec<std::path::PathBuf>> {
+    let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
+    let json_file = std::fs::File::open(json_file)?;
+    let json: serde_json::Value =
+        serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
+    let weight_map = match json.get("weight_map") {
+        None => candle::bail!("no weight map in {json_file:?}"),
+        Some(serde_json::Value::Object(map)) => map,
+        Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
+    };
+    let mut safetensors_files = std::collections::HashSet::new();
+    for value in weight_map.values() {
+        if let Some(file) = value.as_str() {
+            safetensors_files.insert(file.to_string());
+        }
+    }
+    let safetensors_files = safetensors_files
+        .iter()
+        .map(|v| repo.get(v).map_err(candle::Error::wrap))
+        .collect::<candle::Result<Vec<_>>>()?;
+    Ok(safetensors_files)
+}
+
+pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> candle::Result<()> {
+    let p = p.as_ref();
+    let (channel, height, width) = img.dims3()?;
+    if channel != 3 {
+        candle::bail!("save_image expects an input of shape (3, height, width)")
+    }
+    let img = img.permute((1, 2, 0))?.flatten_all()?;
+    let pixels = img.to_vec1::<u8>()?;
+    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
+            Some(image) => image,
+            None => candle::bail!("error saving image {p:?}"),
+        };
+    image.save(p).map_err(candle::Error::wrap)?;
+    Ok(())
+}
diff --git a/atoma-inference/src/models/stable_diffusion.rs b/atoma-inference/src/models/stable_diffusion.rs
new file mode 100644
index 00000000..75b04ae9
--- /dev/null
+++ b/atoma-inference/src/models/stable_diffusion.rs
@@ -0,0 +1,532 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use candle_transformers::models::stable_diffusion;
+
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, IndexOp, Module, Tensor, D};
+use tokenizers::Tokenizer;
+
+use crate::models::{device, save_image};
+
+pub struct Config {
+    height: Option<usize>,
+    width: Option<usize>,
+
+    /// The UNet weight file, in .safetensors format.
+    unet_weights: Option<String>,
+
+    /// The CLIP weight file, in .safetensors format.
+    clip_weights: Option<String>,
+
+    /// The VAE weight file, in .safetensors format.
+    vae_weights: Option<String>,
+
+    /// The file specifying the tokenizer to use for tokenization.
+    tokenizer: Option<String>,
+
+    /// The size of the sliced attention or 0 for automatic slicing (disabled by default).
+    sliced_attention_size: Option<usize>,
+
+    /// The number of steps to run the diffusion for.
+    n_steps: Option<usize>,
+
+    /// The number of samples to generate.
+    num_samples: i64,
+
+    /// The name of the final image to generate.
+    final_image: String,
+
+    sd_version: StableDiffusionVersion,
+
+    /// Generate intermediary images at each step.
+    intermediary_images: bool,
+
+    use_flash_attn: bool,
+
+    use_f16: bool,
+
+    guidance_scale: Option<f64>,
+
+    img2img: Option<String>,
+
+    /// The strength indicates how much to transform the initial image. The
+    /// value must be between 0 and 1; a value of 1 discards the initial image
+    /// information.
+    img2img_strength: f64,
+
+    /// The seed to use when generating random samples.
+ seed: Option, +} + +impl Default for Config { + fn default() -> Self { + Self { + height: Some(256), + width: Some(256), + unet_weights: None, + clip_weights: None, + vae_weights: None, + tokenizer: None, + sliced_attention_size: None, + n_steps: Some(1000), + num_samples: 1, + final_image: "sd_final.png".to_string(), + sd_version: StableDiffusionVersion::V1_5, + intermediary_images: false, + use_flash_attn: false, + use_f16: true, + guidance_scale: None, + img2img: None, + img2img_strength: 0.8, + seed: None, + } + } +} + +#[derive(Clone, Copy)] +enum StableDiffusionVersion { + V1_5, + V2_1, + Xl, + Turbo, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ModelFile { + Tokenizer, + Tokenizer2, + Clip, + Clip2, + Unet, + Vae, +} + +impl StableDiffusionVersion { + fn repo(&self) -> &'static str { + match self { + Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0", + Self::V2_1 => "stabilityai/stable-diffusion-2-1", + Self::V1_5 => "runwayml/stable-diffusion-v1-5", + Self::Turbo => "stabilityai/sdxl-turbo", + } + } + + fn unet_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "unet/diffusion_pytorch_model.fp16.safetensors" + } else { + "unet/diffusion_pytorch_model.safetensors" + } + } + } + } + + fn vae_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "vae/diffusion_pytorch_model.fp16.safetensors" + } else { + "vae/diffusion_pytorch_model.safetensors" + } + } + } + } + + fn clip_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "text_encoder/model.fp16.safetensors" + } else { + "text_encoder/model.safetensors" + } + } + } + } + + fn clip2_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { + if use_f16 { + "text_encoder_2/model.fp16.safetensors" + } else { + "text_encoder_2/model.safetensors" + } + } + } + } +} + +impl ModelFile { + fn get( + &self, + filename: Option, + version: StableDiffusionVersion, + use_f16: bool, + ) -> Result { + use hf_hub::api::sync::Api; + match filename { + Some(filename) => Ok(std::path::PathBuf::from(filename)), + None => { + let (repo, path) = match self { + Self::Tokenizer => { + let tokenizer_repo = match version { + StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => { + "openai/clip-vit-base-patch32" + } + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => { + // This seems similar to the patch32 version except some very small + // difference in the split regex. + "openai/clip-vit-large-patch14" + } + }; + (tokenizer_repo, "tokenizer.json") + } + Self::Tokenizer2 => { + ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json") + } + Self::Clip => (version.repo(), version.clip_file(use_f16)), + Self::Clip2 => (version.repo(), version.clip2_file(use_f16)), + Self::Unet => (version.repo(), version.unet_file(use_f16)), + Self::Vae => { + // Override for SDXL when using f16 weights. 
+ // See https://github.com/huggingface/candle/issues/1060 + if matches!( + version, + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo, + ) && use_f16 + { + ( + "madebyollin/sdxl-vae-fp16-fix", + "diffusion_pytorch_model.safetensors", + ) + } else { + (version.repo(), version.vae_file(use_f16)) + } + } + }; + let filename = Api::new()?.model(repo.to_string()).get(path)?; + Ok(filename) + } + } + } +} + +fn output_filename( + basename: &str, + sample_idx: i64, + num_samples: i64, + timestep_idx: Option, +) -> String { + let filename = if num_samples > 1 { + match basename.rsplit_once('.') { + None => format!("{basename}.{sample_idx}.png"), + Some((filename_no_extension, extension)) => { + format!("{filename_no_extension}.{sample_idx}.{extension}") + } + } + } else { + basename.to_string() + }; + match timestep_idx { + None => filename, + Some(timestep_idx) => match filename.rsplit_once('.') { + None => format!("{filename}-{timestep_idx}.png"), + Some((filename_no_extension, extension)) => { + format!("{filename_no_extension}-{timestep_idx}.{extension}") + } + }, + } +} + +#[allow(clippy::too_many_arguments)] +fn text_embeddings( + prompt: &str, + uncond_prompt: &str, + tokenizer: Option, + clip_weights: Option, + sd_version: StableDiffusionVersion, + sd_config: &stable_diffusion::StableDiffusionConfig, + use_f16: bool, + device: &Device, + dtype: DType, + use_guide_scale: bool, + first: bool, +) -> Result { + let tokenizer_file = if first { + ModelFile::Tokenizer + } else { + ModelFile::Tokenizer2 + }; + let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?; + let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; + let pad_id = match &sd_config.clip.pad_with { + Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(), + None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(), + }; + println!("Running with prompt \"{prompt}\"."); + let mut tokens = tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + while tokens.len() < sd_config.clip.max_position_embeddings { + tokens.push(pad_id) + } + let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; + + println!("Building the Clip transformer."); + let clip_weights_file = if first { + ModelFile::Clip + } else { + ModelFile::Clip2 + }; + let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?; + let clip_config = if first { + &sd_config.clip + } else { + sd_config.clip2.as_ref().unwrap() + }; + let text_model = + stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?; + let text_embeddings = text_model.forward(&tokens)?; + + let text_embeddings = if use_guide_scale { + let mut uncond_tokens = tokenizer + .encode(uncond_prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + while uncond_tokens.len() < sd_config.clip.max_position_embeddings { + uncond_tokens.push(pad_id) + } + + let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; + let uncond_embeddings = text_model.forward(&uncond_tokens)?; + + Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)? + } else { + text_embeddings.to_dtype(dtype)? 
+ }; + Ok(text_embeddings) +} + +fn image_preprocess>(path: T) -> anyhow::Result { + let img = image::io::Reader::open(path)?.decode()?; + let (height, width) = (img.height() as usize, img.width() as usize); + let height = height - height % 32; + let width = width - width % 32; + let img = img.resize_to_fill( + width as u32, + height as u32, + image::imageops::FilterType::CatmullRom, + ); + let img = img.to_rgb8(); + let img = img.into_raw(); + let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)? + .permute((2, 0, 1))? + .to_dtype(DType::F32)? + .affine(2. / 255., -1.)? + .unsqueeze(0)?; + Ok(img) +} + +pub fn run(prompt: String, uncond_prompt: String, cfg: Config) -> Result<()> { + if !(0. ..=1.).contains(&cfg.img2img_strength) { + anyhow::bail!( + "img2img-strength should be between 0 and 1, got {}", + cfg.img2img_strength + ) + } + + let guidance_scale = match cfg.guidance_scale { + Some(guidance_scale) => guidance_scale, + None => match cfg.sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 7.5, + StableDiffusionVersion::Turbo => 0., + }, + }; + let n_steps = match cfg.n_steps { + Some(n_steps) => n_steps, + None => match cfg.sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 30, + StableDiffusionVersion::Turbo => 1, + }, + }; + let dtype = if cfg.use_f16 { DType::F16 } else { DType::F32 }; + let sd_config = match cfg.sd_version { + StableDiffusionVersion::V1_5 => stable_diffusion::StableDiffusionConfig::v1_5( + cfg.sliced_attention_size, + cfg.height, + cfg.width, + ), + StableDiffusionVersion::V2_1 => stable_diffusion::StableDiffusionConfig::v2_1( + cfg.sliced_attention_size, + cfg.height, + cfg.width, + ), + StableDiffusionVersion::Xl => stable_diffusion::StableDiffusionConfig::sdxl( + cfg.sliced_attention_size, + cfg.height, + cfg.width, + ), + StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo( + cfg.sliced_attention_size, + cfg.height, + cfg.width, + ), + }; + + let scheduler = sd_config.build_scheduler(n_steps)?; + let device = device()?; + if let Some(seed) = cfg.seed { + device.set_seed(seed)?; + } + let use_guide_scale = guidance_scale > 1.0; + + let which = match cfg.sd_version { + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false], + _ => vec![true], + }; + let text_embeddings = which + .iter() + .map(|first| { + text_embeddings( + &prompt, + &uncond_prompt, + cfg.tokenizer.clone(), + cfg.clip_weights.clone(), + cfg.sd_version, + &sd_config, + cfg.use_f16, + &device, + dtype, + use_guide_scale, + *first, + ) + }) + .collect::>>()?; + + let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?; + println!("{text_embeddings:?}"); + + println!("Building the autoencoder."); + let vae_weights = ModelFile::Vae.get(cfg.vae_weights, cfg.sd_version, cfg.use_f16)?; + let vae = sd_config.build_vae(vae_weights, &device, dtype)?; + let init_latent_dist = match &cfg.img2img { + None => None, + Some(image) => { + let image = image_preprocess(image)?.to_device(&device)?; + Some(vae.encode(&image)?) 
+ } + }; + println!("Building the unet."); + let unet_weights = ModelFile::Unet.get(cfg.unet_weights, cfg.sd_version, cfg.use_f16)?; + let unet = sd_config.build_unet(unet_weights, &device, 4, cfg.use_flash_attn, dtype)?; + + let t_start = if cfg.img2img.is_some() { + n_steps - (n_steps as f64 * cfg.img2img_strength) as usize + } else { + 0 + }; + let bsize = 1; + + let vae_scale = match cfg.sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 0.18215, + StableDiffusionVersion::Turbo => 0.13025, + }; + + for idx in 0..cfg.num_samples { + let timesteps = scheduler.timesteps(); + let latents = match &init_latent_dist { + Some(init_latent_dist) => { + let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?; + if t_start < timesteps.len() { + let noise = latents.randn_like(0f64, 1f64)?; + scheduler.add_noise(&latents, noise, timesteps[t_start])? + } else { + latents + } + } + None => { + let latents = Tensor::randn( + 0f32, + 1f32, + (bsize, 4, sd_config.height / 8, sd_config.width / 8), + &device, + )?; + // scale the initial noise by the standard deviation required by the scheduler + (latents * scheduler.init_noise_sigma())? + } + }; + let mut latents = latents.to_dtype(dtype)?; + + println!("starting sampling"); + for (timestep_index, ×tep) in timesteps.iter().enumerate() { + if timestep_index < t_start { + continue; + } + let start_time = std::time::Instant::now(); + let latent_model_input = if use_guide_scale { + Tensor::cat(&[&latents, &latents], 0)? + } else { + latents.clone() + }; + + let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; + let noise_pred = + unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?; + + let noise_pred = if use_guide_scale { + let noise_pred = noise_pred.chunk(2, 0)?; + let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); + + (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)? + } else { + noise_pred + }; + + latents = scheduler.step(&noise_pred, timestep, &latents)?; + let dt = start_time.elapsed().as_secs_f32(); + println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt); + + if cfg.intermediary_images { + let image = vae.decode(&(&latents / vae_scale)?)?; + let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; + let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?; + let image_filename = output_filename( + &cfg.final_image, + idx + 1, + cfg.num_samples, + Some(timestep_index + 1), + ); + save_image(&image, image_filename)? + } + } + + println!( + "Generating the final image for sample {}/{}.", + idx + 1, + cfg.num_samples + ); + let image = vae.decode(&(&latents / vae_scale)?)?; + let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; + let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?; + let image_filename = output_filename(&cfg.final_image, idx + 1, cfg.num_samples, None); + save_image(&image, image_filename)? + } + Ok(()) +} diff --git a/atoma-inference/src/models/token_output_stream.rs b/atoma-inference/src/models/token_output_stream.rs new file mode 100644 index 00000000..1f507c5e --- /dev/null +++ b/atoma-inference/src/models/token_output_stream.rs @@ -0,0 +1,86 @@ +use candle::Result; + +/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a +/// streaming way rather than having to wait for the full decoding. 
+pub struct TokenOutputStream {
+    tokenizer: tokenizers::Tokenizer,
+    tokens: Vec<u32>,
+    prev_index: usize,
+    current_index: usize,
+}
+
+impl TokenOutputStream {
+    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
+        Self {
+            tokenizer,
+            tokens: Vec::new(),
+            prev_index: 0,
+            current_index: 0,
+        }
+    }
+
+    pub fn into_inner(self) -> tokenizers::Tokenizer {
+        self.tokenizer
+    }
+
+    fn decode(&self, tokens: &[u32]) -> Result<String> {
+        match self.tokenizer.decode(tokens, true) {
+            Ok(str) => Ok(str),
+            Err(err) => candle::bail!("cannot decode: {err}"),
+        }
+    }
+
+    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
+    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
+        let prev_text = if self.tokens.is_empty() {
+            String::new()
+        } else {
+            let tokens = &self.tokens[self.prev_index..self.current_index];
+            self.decode(tokens)?
+        };
+        self.tokens.push(token);
+        let text = self.decode(&self.tokens[self.prev_index..])?;
+        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
+            let text = text.split_at(prev_text.len());
+            self.prev_index = self.current_index;
+            self.current_index = self.tokens.len();
+            Ok(Some(text.1.to_string()))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub fn decode_rest(&self) -> Result<Option<String>> {
+        let prev_text = if self.tokens.is_empty() {
+            String::new()
+        } else {
+            let tokens = &self.tokens[self.prev_index..self.current_index];
+            self.decode(tokens)?
+        };
+        let text = self.decode(&self.tokens[self.prev_index..])?;
+        if text.len() > prev_text.len() {
+            let text = text.split_at(prev_text.len());
+            Ok(Some(text.1.to_string()))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub fn decode_all(&self) -> Result<String> {
+        self.decode(&self.tokens)
+    }
+
+    pub fn get_token(&self, token_s: &str) -> Option<u32> {
+        self.tokenizer.get_vocab(true).get(token_s).copied()
+    }
+
+    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
+        &self.tokenizer
+    }
+
+    pub fn clear(&mut self) {
+        self.tokens.clear();
+        self.prev_index = 0;
+        self.current_index = 0;
+    }
+}
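
A minimal usage sketch for the two entry points this patch adds, assuming the module layout above. The prompts are placeholders, and `Default::default()` picks up the Config defaults defined in llama.rs (TinyLlama-1.1B-Chat) and stable_diffusion.rs (v1.5, 256x256 output to sd_final.png); running it downloads the corresponding weights from the Hugging Face Hub.

// Sketch of atoma-inference/src/main.rs driving both pipelines added in this patch.
// The prompts below are illustrative only.
mod models;

fn main() -> anyhow::Result<()> {
    // Text generation with the LLaMA-family pipeline and its default Config.
    models::llama::run(
        "The most important thing is ".to_string(),
        Default::default(),
    )?;

    // Text-to-image with the Stable Diffusion pipeline; the empty string is the
    // unconditional (negative) prompt.
    models::stable_diffusion::run(
        "Green boat on ocean during storm".to_string(),
        String::new(),
        Default::default(),
    )?;
    Ok(())
}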