From 313ef8569be3d28f608fc418ad52abaa4b7a4b57 Mon Sep 17 00:00:00 2001
From: Martin Stefcek
Date: Tue, 2 Apr 2024 21:20:49 +0400
Subject: [PATCH] add stable diffusion

---
 .gitignore                                    |   1 +
 Cargo.toml                                    |  15 +-
 atoma-inference/Cargo.toml                    |  15 +-
 atoma-inference/src/candle/mod.rs             |  89 +++
 .../src/candle/stable_diffusion.rs            | 541 ++++++++++++++++++
 .../src/candle/token_output_stream.rs         |  86 +++
 atoma-inference/src/lib.rs                    |   4 +
 7 files changed, 745 insertions(+), 6 deletions(-)
 create mode 100644 atoma-inference/src/candle/mod.rs
 create mode 100644 atoma-inference/src/candle/stable_diffusion.rs
 create mode 100644 atoma-inference/src/candle/token_output_stream.rs

diff --git a/.gitignore b/.gitignore
index 1e7caa9e..9182e2f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 Cargo.lock
 target/
+.vscode/
diff --git a/Cargo.toml b/Cargo.toml
index e37a9e35..35d112f9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,12 +1,20 @@
 [workspace]
 resolver = "2"
+edition = "2021"
 
-members = ["atoma-event-subscribe", "atoma-inference", "atoma-networking", "atoma-json-rpc", "atoma-storage"]
+members = [
+    "atoma-event-subscribe",
+    "atoma-inference",
+    "atoma-networking",
+    "atoma-json-rpc",
+    "atoma-storage",
+]
 
 [workspace.package]
 version = "0.1.0"
 
 [workspace.dependencies]
+anyhow = "1.0.81"
 async-trait = "0.1.78"
 candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" }
 candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", version = "0.4.2" }
@@ -16,6 +24,11 @@ config = "0.14.0"
 ed25519-consensus = "2.1.0"
 futures = "0.3.30"
 hf-hub = "0.3.2"
+clap = "4.5.3"
+image = { version = "0.25.0", default-features = false, features = [
+    "jpeg",
+    "png",
+] }
 serde = "1.0.197"
 serde_json = "1.0.114"
 rand = "0.8.5"
diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml
index 4b7e81e3..74f6cdaa 100644
--- a/atoma-inference/Cargo.toml
+++ b/atoma-inference/Cargo.toml
@@ -1,11 +1,10 @@
 [package]
 name = "inference"
-version = "0.1.0"
+version.workspace = true
 edition = "2021"
 
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
 [dependencies]
+anyhow.workspace = true
 async-trait.workspace = true
 candle.workspace = true
 candle-flash-attn = { workspace = true, optional = true }
@@ -18,8 +17,10 @@ hf-hub.workspace = true
 reqwest = { workspace = true, features = ["json"] }
 serde = { workspace = true, features = ["derive"] }
 serde_json.workspace = true
+clap.workspace = true
+image = { workspace = true }
 thiserror.workspace = true
-tokenizers.workspace = true
+tokenizers = { workspace = true, features = ["onig"] }
 tokio = { workspace = true, features = ["full", "tracing"] }
 tracing.workspace = true
 tracing-subscriber.workspace = true
@@ -30,7 +31,11 @@ toml.workspace = true
 
 
 [features]
-accelerate = ["candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
+accelerate = [
+    "candle/accelerate",
+    "candle-nn/accelerate",
+    "candle-transformers/accelerate",
+]
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
 cudnn = ["candle/cudnn"]
 flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
diff --git a/atoma-inference/src/candle/mod.rs b/atoma-inference/src/candle/mod.rs
new file mode 100644
index 00000000..8fe9d9b6
--- /dev/null
+++ b/atoma-inference/src/candle/mod.rs
@@ -0,0 +1,89 @@
+pub mod stable_diffusion;
+pub mod token_output_stream;
+
+use std::{fs::File, io::Write, path::PathBuf};
+
+use candle::{
+    utils::{cuda_is_available, metal_is_available},
+    DType, Device, Tensor,
+};
+use tracing::info;
+
+use crate::models::ModelError;
+
+pub trait CandleModel {
+    type Fetch;
+    type Input;
+    fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError>;
+    fn inference(input: Self::Input) -> Result<Vec<Tensor>, ModelError>;
+}
+
+pub fn device() -> Result<Device, candle::Error> {
+    if cuda_is_available() {
+        info!("Using CUDA");
+        Device::new_cuda(0)
+    } else if metal_is_available() {
+        info!("Using Metal");
+        Device::new_metal(0)
+    } else {
+        info!("Using Cpu");
+        Ok(Device::Cpu)
+    }
+}
+
+pub fn hub_load_safetensors(
+    repo: &hf_hub::api::sync::ApiRepo,
+    json_file: &str,
+) -> candle::Result<Vec<std::path::PathBuf>> {
+    let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
+    let json_file = std::fs::File::open(json_file)?;
+    let json: serde_json::Value =
+        serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
+    let weight_map = match json.get("weight_map") {
+        None => candle::bail!("no weight map in {json_file:?}"),
+        Some(serde_json::Value::Object(map)) => map,
+        Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
+    };
+    let mut safetensors_files = std::collections::HashSet::new();
+    for value in weight_map.values() {
+        if let Some(file) = value.as_str() {
+            safetensors_files.insert(file.to_string());
+        }
+    }
+    let safetensors_files = safetensors_files
+        .iter()
+        .map(|v| repo.get(v).map_err(candle::Error::wrap))
+        .collect::<candle::Result<Vec<_>>>()?;
+    Ok(safetensors_files)
+}
+
+pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> candle::Result<()> {
+    let p = p.as_ref();
+    let (channel, height, width) = img.dims3()?;
+    if channel != 3 {
+        candle::bail!("save_image expects an input of shape (3, height, width)")
+    }
+    let img = img.permute((1, 2, 0))?.flatten_all()?;
+    let pixels = img.to_vec1::<u8>()?;
+    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
+            Some(image) => image,
+            None => candle::bail!("error saving image {p:?}"),
+        };
+    image.save(p).map_err(candle::Error::wrap)?;
+    Ok(())
+}
+
+pub fn save_tensor_to_file(tensor: &Tensor, filename: &str) -> Result<(), candle::Error> {
+    let json_output = serde_json::to_string(
+        &tensor
+            .to_device(&Device::Cpu)?
+            .flatten_all()?
+            .to_dtype(DType::F64)?
+            .to_vec1::<f64>()?,
+    )
+    .unwrap();
+    let mut file = File::create(PathBuf::from(filename))?;
+    file.write_all(json_output.as_bytes())?;
+    Ok(())
+}
diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/candle/stable_diffusion.rs
new file mode 100644
index 00000000..c8718e4c
--- /dev/null
+++ b/atoma-inference/src/candle/stable_diffusion.rs
@@ -0,0 +1,541 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use candle_transformers::models::stable_diffusion::{self};
+
+use candle::{DType, Device, IndexOp, Module, Tensor, D};
+use tokenizers::Tokenizer;
+
+use crate::{candle::device, models::ModelError};
+
+use super::{save_tensor_to_file, CandleModel};
+
+pub struct Input {
+    prompt: String,
+    uncond_prompt: String,
+
+    height: Option<usize>,
+    width: Option<usize>,
+
+    /// The UNet weight file, in .safetensors format.
+    unet_weights: Option<String>,
+
+    /// The CLIP weight file, in .safetensors format.
+    clip_weights: Option<String>,
+
+    /// The VAE weight file, in .safetensors format.
+    vae_weights: Option<String>,
+
+    /// The file specifying the tokenizer to use for tokenization.
+    tokenizer: Option<String>,
+
+    /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
+    sliced_attention_size: Option<usize>,
+
+    /// The number of steps to run the diffusion for.
+    n_steps: Option<usize>,
+
+    /// The number of samples to generate.
+    num_samples: i64,
+
+    sd_version: StableDiffusionVersion,
+
+    /// Generate intermediary images at each step.
+    intermediary_images: bool,
+
+    use_flash_attn: bool,
+
+    use_f16: bool,
+
+    guidance_scale: Option<f64>,
+
+    img2img: Option<String>,
+
+    /// The strength indicates how much to transform the initial image. The
+    /// value must be between 0 and 1, a value of 1 discards the initial image
+    /// information.
+    img2img_strength: f64,
+
+    /// The seed to use when generating random samples.
+    seed: Option<u64>,
+}
+
+impl Input {
+    pub fn default_prompt(prompt: String) -> Self {
+        Self {
+            prompt,
+            uncond_prompt: "".to_string(),
+            height: Some(256),
+            width: Some(256),
+            unet_weights: None,
+            clip_weights: None,
+            vae_weights: None,
+            tokenizer: None,
+            sliced_attention_size: None,
+            n_steps: Some(20),
+            num_samples: 1,
+            sd_version: StableDiffusionVersion::V1_5,
+            intermediary_images: false,
+            use_flash_attn: false,
+            use_f16: true,
+            guidance_scale: None,
+            img2img: None,
+            img2img_strength: 0.8,
+            seed: Some(0),
+        }
+    }
+}
+
+impl From<&Input> for Fetch {
+    fn from(input: &Input) -> Self {
+        Self {
+            tokenizer: input.tokenizer.clone(),
+            sd_version: input.sd_version,
+            use_f16: input.use_f16,
+            clip_weights: input.clip_weights.clone(),
+            vae_weights: input.vae_weights.clone(),
+            unet_weights: input.unet_weights.clone(),
+        }
+    }
+}
+pub struct StableDiffusion {}
+
+pub struct Fetch {
+    tokenizer: Option<String>,
+    sd_version: StableDiffusionVersion,
+    use_f16: bool,
+    clip_weights: Option<String>,
+    vae_weights: Option<String>,
+    unet_weights: Option<String>,
+}
+
+impl CandleModel for StableDiffusion {
+    type Input = Input;
+    type Fetch = Fetch;
+
+    fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError> {
+        let which = match fetch.sd_version {
+            StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
+            _ => vec![true],
+        };
+        for first in which {
+            let (clip_weights_file, tokenizer_file) = if first {
+                (ModelFile::Clip, ModelFile::Tokenizer)
+            } else {
+                (ModelFile::Clip2, ModelFile::Tokenizer2)
+            };
+
+            clip_weights_file.get(fetch.clip_weights.clone(), fetch.sd_version, false)?;
+            ModelFile::Vae.get(fetch.vae_weights.clone(), fetch.sd_version, fetch.use_f16)?;
+            tokenizer_file.get(fetch.tokenizer.clone(), fetch.sd_version, fetch.use_f16)?;
+            ModelFile::Unet.get(fetch.unet_weights.clone(), fetch.sd_version, fetch.use_f16)?;
+        }
+        Ok(())
+    }
+
+    fn inference(input: Self::Input) -> Result<Vec<Tensor>, ModelError> {
+        if !(0. ..=1.).contains(&input.img2img_strength) {
+            Err(ModelError::Config(format!(
+                "img2img_strength must be between 0 and 1, got {}",
+                input.img2img_strength,
+            )))?
+        }
+
+        let guidance_scale = match input.guidance_scale {
+            Some(guidance_scale) => guidance_scale,
+            None => match input.sd_version {
+                StableDiffusionVersion::V1_5
+                | StableDiffusionVersion::V2_1
+                | StableDiffusionVersion::Xl => 7.5,
+                StableDiffusionVersion::Turbo => 0.,
+            },
+        };
+        let n_steps = match input.n_steps {
+            Some(n_steps) => n_steps,
+            None => match input.sd_version {
+                StableDiffusionVersion::V1_5
+                | StableDiffusionVersion::V2_1
+                | StableDiffusionVersion::Xl => 30,
+                StableDiffusionVersion::Turbo => 1,
+            },
+        };
+        let dtype = if input.use_f16 {
+            DType::F16
+        } else {
+            DType::F32
+        };
+        let sd_config = match input.sd_version {
+            StableDiffusionVersion::V1_5 => stable_diffusion::StableDiffusionConfig::v1_5(
+                input.sliced_attention_size,
+                input.height,
+                input.width,
+            ),
+            StableDiffusionVersion::V2_1 => stable_diffusion::StableDiffusionConfig::v2_1(
+                input.sliced_attention_size,
+                input.height,
+                input.width,
+            ),
+            StableDiffusionVersion::Xl => stable_diffusion::StableDiffusionConfig::sdxl(
+                input.sliced_attention_size,
+                input.height,
+                input.width,
+            ),
+            StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
+                input.sliced_attention_size,
+                input.height,
+                input.width,
+            ),
+        };
+
+        let scheduler = sd_config.build_scheduler(n_steps)?;
+        let device = device()?;
+        if let Some(seed) = input.seed {
+            device.set_seed(seed)?;
+        }
+        let use_guide_scale = guidance_scale > 1.0;
+
+        let which = match input.sd_version {
+            StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
+            _ => vec![true],
+        };
+        let text_embeddings = which
+            .iter()
+            .map(|first| {
+                Self::text_embeddings(
+                    &input.prompt,
+                    &input.uncond_prompt,
+                    input.tokenizer.clone(),
+                    input.clip_weights.clone(),
+                    input.sd_version,
+                    &sd_config,
+                    input.use_f16,
+                    &device,
+                    dtype,
+                    use_guide_scale,
+                    *first,
+                )
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
+
+        let vae_weights = ModelFile::Vae.get(input.vae_weights, input.sd_version, input.use_f16)?;
+        let vae = sd_config.build_vae(vae_weights, &device, dtype)?;
+        let init_latent_dist = match &input.img2img {
+            None => None,
+            Some(image) => {
+                let image = Self::image_preprocess(image)?.to_device(&device)?;
+                Some(vae.encode(&image)?)
+            }
+        };
+        let unet_weights =
+            ModelFile::Unet.get(input.unet_weights, input.sd_version, input.use_f16)?;
+        let unet = sd_config.build_unet(unet_weights, &device, 4, input.use_flash_attn, dtype)?;
+
+        let t_start = if input.img2img.is_some() {
+            n_steps - (n_steps as f64 * input.img2img_strength) as usize
+        } else {
+            0
+        };
+        let bsize = 1;
+
+        let vae_scale = match input.sd_version {
+            StableDiffusionVersion::V1_5
+            | StableDiffusionVersion::V2_1
+            | StableDiffusionVersion::Xl => 0.18215,
+            StableDiffusionVersion::Turbo => 0.13025,
+        };
+        let mut res = Vec::new();
+
+        for _ in 0..input.num_samples {
+            let timesteps = scheduler.timesteps();
+            let latents = match &init_latent_dist {
+                Some(init_latent_dist) => {
+                    let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
+                    if t_start < timesteps.len() {
+                        let noise = latents.randn_like(0f64, 1f64)?;
+                        scheduler.add_noise(&latents, noise, timesteps[t_start])?
+                    } else {
+                        latents
+                    }
+                }
+                None => {
+                    let latents = Tensor::randn(
+                        0f32,
+                        1f32,
+                        (bsize, 4, sd_config.height / 8, sd_config.width / 8),
+                        &device,
+                    )?;
+                    // scale the initial noise by the standard deviation required by the scheduler
+                    (latents * scheduler.init_noise_sigma())?
+                }
+            };
+            let mut latents = latents.to_dtype(dtype)?;
+
+            for (timestep_index, &timestep) in timesteps.iter().enumerate() {
+                if timestep_index < t_start {
+                    continue;
+                }
+                let latent_model_input = if use_guide_scale {
+                    Tensor::cat(&[&latents, &latents], 0)?
+                } else {
+                    latents.clone()
+                };
+
+                let latent_model_input =
+                    scheduler.scale_model_input(latent_model_input, timestep)?;
+                let noise_pred =
+                    unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
+
+                let noise_pred = if use_guide_scale {
+                    let noise_pred = noise_pred.chunk(2, 0)?;
+                    let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
+
+                    (noise_pred_uncond
+                        + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)?
+                } else {
+                    noise_pred
+                };
+
+                latents = scheduler.step(&noise_pred, timestep, &latents)?;
+            }
+            save_tensor_to_file(&latents, "tensor1")?;
+            let image = vae.decode(&(&latents / vae_scale)?)?;
+            save_tensor_to_file(&image, "tensor2")?;
+            let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+            save_tensor_to_file(&image, "tensor3")?;
+            let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
+            save_tensor_to_file(&image, "tensor4")?;
+            res.push(image);
+        }
+        Ok(res)
+    }
+}
+
+#[derive(Clone, Copy)]
+enum StableDiffusionVersion {
+    V1_5,
+    V2_1,
+    Xl,
+    Turbo,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum ModelFile {
+    Tokenizer,
+    Tokenizer2,
+    Clip,
+    Clip2,
+    Unet,
+    Vae,
+}
+
+impl StableDiffusionVersion {
+    fn repo(&self) -> &'static str {
+        match self {
+            Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
+            Self::V2_1 => "stabilityai/stable-diffusion-2-1",
+            Self::V1_5 => "runwayml/stable-diffusion-v1-5",
+            Self::Turbo => "stabilityai/sdxl-turbo",
+        }
+    }
+
+    fn unet_file(&self, use_f16: bool) -> &'static str {
+        match self {
+            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
+                if use_f16 {
+                    "unet/diffusion_pytorch_model.fp16.safetensors"
+                } else {
+                    "unet/diffusion_pytorch_model.safetensors"
+                }
+            }
+        }
+    }
+
+    fn vae_file(&self, use_f16: bool) -> &'static str {
+        match self {
+            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
+                if use_f16 {
+                    "vae/diffusion_pytorch_model.fp16.safetensors"
+                } else {
+                    "vae/diffusion_pytorch_model.safetensors"
+                }
+            }
+        }
+    }
+
+    fn clip_file(&self, use_f16: bool) -> &'static str {
+        match self {
+            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
+                if use_f16 {
+                    "text_encoder/model.fp16.safetensors"
+                } else {
+                    "text_encoder/model.safetensors"
+                }
+            }
+        }
+    }
+
+    fn clip2_file(&self, use_f16: bool) -> &'static str {
+        match self {
+            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
+                if use_f16 {
+                    "text_encoder_2/model.fp16.safetensors"
+                } else {
+                    "text_encoder_2/model.safetensors"
+                }
+            }
+        }
+    }
+}
+
+impl ModelFile {
+    fn get(
+        &self,
+        filename: Option<String>,
+        version: StableDiffusionVersion,
+        use_f16: bool,
+    ) -> Result<std::path::PathBuf, ModelError> {
+        use hf_hub::api::sync::Api;
+        match filename {
+            Some(filename) => Ok(std::path::PathBuf::from(filename)),
+            None => {
+                let (repo, path) = match self {
+                    Self::Tokenizer => {
+                        let tokenizer_repo = match version {
+                            StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
+                                "openai/clip-vit-base-patch32"
+                            }
+                            StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
+                                // This seems similar to the patch32 version except some very small
+                                // difference in the split regex.
+ "openai/clip-vit-large-patch14" + } + }; + (tokenizer_repo, "tokenizer.json") + } + Self::Tokenizer2 => { + ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json") + } + Self::Clip => (version.repo(), version.clip_file(use_f16)), + Self::Clip2 => (version.repo(), version.clip2_file(use_f16)), + Self::Unet => (version.repo(), version.unet_file(use_f16)), + Self::Vae => { + // Override for SDXL when using f16 weights. + // See https://github.com/huggingface/candle/issues/1060 + if matches!( + version, + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo, + ) && use_f16 + { + ( + "madebyollin/sdxl-vae-fp16-fix", + "diffusion_pytorch_model.safetensors", + ) + } else { + (version.repo(), version.vae_file(use_f16)) + } + } + }; + let filename = Api::new()?.model(repo.to_string()).get(path)?; + Ok(filename) + } + } + } +} + +impl StableDiffusion { + #[allow(clippy::too_many_arguments)] + fn text_embeddings( + prompt: &str, + uncond_prompt: &str, + tokenizer: Option, + clip_weights: Option, + sd_version: StableDiffusionVersion, + sd_config: &stable_diffusion::StableDiffusionConfig, + use_f16: bool, + device: &Device, + dtype: DType, + use_guide_scale: bool, + first: bool, + ) -> Result { + let (clip_weights_file, tokenizer_file) = if first { + (ModelFile::Clip, ModelFile::Tokenizer) + } else { + (ModelFile::Clip2, ModelFile::Tokenizer2) + }; + let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?; + let tokenizer = Tokenizer::from_file(tokenizer)?; + let pad_id = match &sd_config.clip.pad_with { + Some(padding) => { + *tokenizer + .get_vocab(true) + .get(padding.as_str()) + .ok_or(ModelError::Msg(format!( + "Padding token {padding} not found in the tokenizer vocabulary" + )))? + } + None => *tokenizer + .get_vocab(true) + .get("<|endoftext|>") + .ok_or(ModelError::Msg("".to_string()))?, + }; + let mut tokens = tokenizer.encode(prompt, true)?.get_ids().to_vec(); + while tokens.len() < sd_config.clip.max_position_embeddings { + tokens.push(pad_id) + } + let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; + + let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?; + let clip_config = if first { + &sd_config.clip + } else { + sd_config.clip2.as_ref().unwrap() + }; + let text_model = stable_diffusion::build_clip_transformer( + clip_config, + clip_weights, + device, + DType::F64, + )?; + let text_embeddings = text_model.forward(&tokens)?; + + let text_embeddings = if use_guide_scale { + let mut uncond_tokens = tokenizer.encode(uncond_prompt, true)?.get_ids().to_vec(); + while uncond_tokens.len() < sd_config.clip.max_position_embeddings { + uncond_tokens.push(pad_id) + } + + let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; + let uncond_embeddings = text_model.forward(&uncond_tokens)?; + + Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)? + } else { + text_embeddings.to_dtype(dtype)? + }; + Ok(text_embeddings) + } + + fn image_preprocess>(path: T) -> Result { + let img = image::io::Reader::open(path)?.decode()?; + let (height, width) = (img.height() as usize, img.width() as usize); + let height = height - height % 32; + let width = width - width % 32; + let img = img.resize_to_fill( + width as u32, + height as u32, + image::imageops::FilterType::CatmullRom, + ); + let img = img.to_rgb8(); + let img = img.into_raw(); + let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)? + .permute((2, 0, 1))? + .to_dtype(DType::F32)? + .affine(2. / 255., -1.)? 
+            .unsqueeze(0)?;
+        Ok(img)
+    }
+}
diff --git a/atoma-inference/src/candle/token_output_stream.rs b/atoma-inference/src/candle/token_output_stream.rs
new file mode 100644
index 00000000..1f507c5e
--- /dev/null
+++ b/atoma-inference/src/candle/token_output_stream.rs
@@ -0,0 +1,86 @@
+use candle::Result;
+
+/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
+/// streaming way rather than having to wait for the full decoding.
+pub struct TokenOutputStream {
+    tokenizer: tokenizers::Tokenizer,
+    tokens: Vec<u32>,
+    prev_index: usize,
+    current_index: usize,
+}
+
+impl TokenOutputStream {
+    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
+        Self {
+            tokenizer,
+            tokens: Vec::new(),
+            prev_index: 0,
+            current_index: 0,
+        }
+    }
+
+    pub fn into_inner(self) -> tokenizers::Tokenizer {
+        self.tokenizer
+    }
+
+    fn decode(&self, tokens: &[u32]) -> Result<String> {
+        match self.tokenizer.decode(tokens, true) {
+            Ok(str) => Ok(str),
+            Err(err) => candle::bail!("cannot decode: {err}"),
+        }
+    }
+
+    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
+    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
+        let prev_text = if self.tokens.is_empty() {
+            String::new()
+        } else {
+            let tokens = &self.tokens[self.prev_index..self.current_index];
+            self.decode(tokens)?
+        };
+        self.tokens.push(token);
+        let text = self.decode(&self.tokens[self.prev_index..])?;
+        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
+            let text = text.split_at(prev_text.len());
+            self.prev_index = self.current_index;
+            self.current_index = self.tokens.len();
+            Ok(Some(text.1.to_string()))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub fn decode_rest(&self) -> Result<Option<String>> {
+        let prev_text = if self.tokens.is_empty() {
+            String::new()
+        } else {
+            let tokens = &self.tokens[self.prev_index..self.current_index];
+            self.decode(tokens)?
+        };
+        let text = self.decode(&self.tokens[self.prev_index..])?;
+        if text.len() > prev_text.len() {
+            let text = text.split_at(prev_text.len());
+            Ok(Some(text.1.to_string()))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub fn decode_all(&self) -> Result<String> {
+        self.decode(&self.tokens)
+    }
+
+    pub fn get_token(&self, token_s: &str) -> Option<u32> {
+        self.tokenizer.get_vocab(true).get(token_s).copied()
+    }
+
+    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
+        &self.tokenizer
+    }
+
+    pub fn clear(&mut self) {
+        self.tokens.clear();
+        self.prev_index = 0;
+        self.current_index = 0;
+    }
+}
diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs
index 539230f0..27cbfd46 100644
--- a/atoma-inference/src/lib.rs
+++ b/atoma-inference/src/lib.rs
@@ -1,4 +1,8 @@
 pub mod model_thread;
+pub mod candle;
+pub mod config;
+pub mod core_thread;
+pub mod models;
 pub mod service;
 pub mod specs;
 pub mod types;
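
Usage sketch (not part of the patch): a minimal, hypothetical driver for the new `StableDiffusion` implementation of `CandleModel`. It only calls items introduced above (`Input::default_prompt`, the `From<&Input> for Fetch` conversion, `StableDiffusion::fetch`/`inference`, `save_image`); the `main` wrapper, the `anyhow::Result` return type (which assumes `ModelError` and `candle::Error` implement `std::error::Error`), the prompt, and the output file names are illustrative.

    use inference::candle::stable_diffusion::{Input, StableDiffusion};
    use inference::candle::{save_image, CandleModel};

    fn main() -> anyhow::Result<()> {
        // Default text-to-image request: Stable Diffusion v1.5, 256x256, 20 steps, f16 weights.
        let input = Input::default_prompt("A rusty robot holding a candle".to_string());
        // Download the tokenizer, CLIP, VAE and UNet weights from the Hugging Face Hub.
        StableDiffusion::fetch(&(&input).into())?;
        // Run the diffusion loop; each returned tensor is a (3, height, width) u8 image.
        let images = StableDiffusion::inference(input)?;
        for (i, image) in images.iter().enumerate() {
            save_image(image, format!("sd_{i}.png"))?;
        }
        Ok(())
    }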