From 6f5487b6c28fe61ca39581a93bbc44251e4d16c9 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 22 Mar 2024 04:00:27 +0000 Subject: [PATCH 1/4] feat: Update dependencies to include latest versions of candle and candle-transformers This commit updates the Cargo.toml files to include the latest versions of the `candle` and `candle-transformers` crates from the Hugging Face organization. Specifically, it adds `candle-core` version 0.4.2 and `candle-transformers` version 0.4.2 as dependencies. Additionally, it adjusts the crate imports and usage in the codebase to reflect these changes. Changes made: - In `Cargo.toml`, replaced the `candle-nn` dependency with `candle` and `candle-transformers` dependencies with their respective versions. - Updated the imports and usage of the `candle` and `candle-transformers` crates in the codebase. --- Cargo.toml | 3 +- atoma-inference/Cargo.toml | 3 +- atoma-inference/src/config.rs | 2 +- atoma-inference/src/lib.rs | 1 + atoma-inference/src/models.rs | 150 +++++++++++++++++++++++++++++++++ atoma-inference/src/service.rs | 4 +- atoma-inference/src/types.rs | 34 +------- 7 files changed, 163 insertions(+), 34 deletions(-) create mode 100644 atoma-inference/src/models.rs diff --git a/Cargo.toml b/Cargo.toml index a93ff302..acd35f3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,8 @@ version = "0.1.0" [workspace.dependencies] async-trait = "0.1.78" -candle-nn = "0.4.1" +candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" } +candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" } ed25519-consensus = "2.1.0" serde = "1.0.197" thiserror = "1.0.58" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index 31cea4b1..b8c76dd3 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -7,7 +7,8 @@ edition = "2021" [dependencies] async-trait.workspace = true -candle-nn.workspace = true +candle.workspace = true +candle-transformers.workspace = true ed25519-consensus.workspace = true serde = { workspace = true, features = ["derive"] } thiserror.workspace = true diff --git a/atoma-inference/src/config.rs b/atoma-inference/src/config.rs index 413d8088..63d35627 100644 --- a/atoma-inference/src/config.rs +++ b/atoma-inference/src/config.rs @@ -1,8 +1,8 @@ use std::path::PathBuf; use crate::{ + models::ModelType, specs::{HardwareSpec, SoftwareSpec}, - types::ModelType, }; pub struct InferenceConfig { diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index 708d5da1..1ca7c952 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -1,5 +1,6 @@ pub mod config; pub mod core_thread; +pub mod models; pub mod service; pub mod specs; pub mod types; diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs new file mode 100644 index 00000000..59a15a7b --- /dev/null +++ b/atoma-inference/src/models.rs @@ -0,0 +1,150 @@ +use std::{error::Error, fmt::Display}; + +use candle::Device; +use candle_transformers::{ + models::{ + llama::{Config as LlamaConfig, Llama}, + llama2_c::{Config as Llama2Config, Llama as Llama2}, + mamba::{Config as MambaConfig, Model as MambaModel}, + mistral::{Config as MistralConfig, Model as MistralModel}, + mixtral::{Config as MixtralConfig, Model as MixtralModel}, + stable_diffusion::StableDiffusionConfig, + }, + quantized_var_builder::VarBuilder, +}; + +use tokenizers::Tokenizer; + +#[derive(Clone, Debug)] +pub enum ModelType { + Llama(usize), + Llama2(usize), + Mamba(usize), + Mixtral8x7b, + Mistral(usize), + StableDiffusionV1_5, + StableDiffusionV2_1, + StableDiffusionXl, + StableDiffusionTurbo, +} + +impl Display for ModelType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Llama(size) => write!(f, "llama({})", size), + Self::Llama2(size) => write!(f, "llama2({})", size), + Self::Mamba(size) => write!(f, "mamba({})", size), + Self::Mixtral8x7b => write!(f, "mixtral_8x7b"), + Self::Mistral(size) => write!(f, "mistral({})", size), + Self::StableDiffusionV1_5 => write!(f, "stable_diffusion_v1_5"), + Self::StableDiffusionV2_1 => write!(f, "stable_diffusion_v2_1"), + Self::StableDiffusionXl => write!(f, "stable_diffusion_xl"), + Self::StableDiffusionTurbo => write!(f, "stable_diffusion_turbo"), + } + } +} + +#[derive(Clone)] +pub enum ModelConfig { + Llama(LlamaConfig), + Llama2(Llama2Config), + Mamba(MambaConfig), + Mixtral8x7b(MixtralConfig), + Mistral(MistralConfig), + StableDiffusion(StableDiffusionConfig), +} + +impl From for ModelConfig { + fn from(_model_type: ModelType) -> Self { + todo!() + } +} + +pub trait ModelApi { + fn load(model_specs: ModelSpecs, var_builder: VarBuilder) -> Self; + fn run(&self, input: String) -> Result>; +} + +#[allow(dead_code)] +pub struct ModelSpecs { + pub(crate) config: ModelConfig, + pub(crate) device: Device, + pub(crate) tokenizer: Tokenizer, +} + +pub enum Model { + Llama { + model_specs: ModelSpecs, + model: Llama, + }, + Llama2 { + model_specs: ModelSpecs, + model: Llama2, + }, + Mamba { + model_specs: ModelSpecs, + model: MambaModel, + }, + Mixtral8x7b { + model_specs: ModelSpecs, + model: MixtralModel, + }, + Mistral { + model_specs: ModelSpecs, + model: MistralModel, + }, +} + +impl ModelApi for Model { + fn load(model_specs: ModelSpecs, var_builder: VarBuilder) -> Self { + let model_config = model_specs.config.clone(); + match model_config { + ModelConfig::Llama(config) => { + let model = load_llama_model(config, var_builder); + Self::Llama { model, model_specs } + } + ModelConfig::Llama2(config) => { + let model = load_llama2_model(config, var_builder); + Self::Llama2 { model_specs, model } + } + ModelConfig::Mamba(config) => { + let model = load_mamba_model(config, var_builder); + Self::Mamba { model_specs, model } + } + ModelConfig::Mistral(config) => { + let model = load_mistral(config, var_builder); + Self::Mistral { model_specs, model } + } + ModelConfig::Mixtral8x7b(config) => { + let model = load_mixtral(config, var_builder); + Self::Mixtral8x7b { model_specs, model } + } + ModelConfig::StableDiffusion(config) => { + panic!("TODO: implement it") + } + } + } + + fn run(&self, input: String) -> Result> { + todo!() + } +} + +fn load_llama_model(config: LlamaConfig, var_builder: VarBuilder) -> Llama { + todo!() +} + +fn load_llama2_model(config: Llama2Config, var_builder: VarBuilder) -> Llama2 { + todo!() +} + +fn load_mamba_model(config: MambaConfig, var_builder: VarBuilder) -> MambaModel { + todo!() +} +fn load_mistral(config: MistralConfig, var_builder: VarBuilder) -> MistralModel { + todo!() +} + +fn load_mixtral(config: MixtralConfig, var_builder: VarBuilder) -> MixtralModel { + todo!() +} diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index e1a7ec19..924e94f5 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -3,7 +3,8 @@ use thiserror::Error; use crate::{ config::InferenceConfig, - types::{InferenceResponse, ModelResponse, ModelType, Prompt, QuantizationMethod, Temperature}, + models::ModelType, + types::{InferenceResponse, ModelResponse, Prompt, QuantizationMethod, Temperature}, }; #[derive(Debug, Error)] @@ -23,6 +24,7 @@ pub trait ApiTrait { #[allow(dead_code)] pub struct InferenceCore { config: InferenceConfig, + // models: Vec, pub(crate) public_key: PublicKey, private_key: PrivateKey, web2_api: T, diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs index 87e5010b..9c7e4cd7 100644 --- a/atoma-inference/src/types.rs +++ b/atoma-inference/src/types.rs @@ -1,38 +1,12 @@ -use std::fmt::Display; - +use crate::models::ModelType; use ed25519_consensus::VerificationKey; -#[derive(Clone, Debug)] -pub struct Prompt(pub(crate) String); - -#[derive(Clone, Debug)] -pub enum ModelType { - Llama2(usize), - Mamba(usize), - Mixtral8x7b, - Mistral(usize), - StableDiffusionV1, - StableDiffusionV2, - StableDiffusionV3, -} - -impl Display for ModelType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Llama2(size) => write!(f, "Llama2({})", size), - Self::Mamba(size) => write!(f, "Mamba({})", size), - Self::Mixtral8x7b => write!(f, "Mixtral8x7b"), - Self::Mistral(size) => write!(f, "Mistral({})", size), - Self::StableDiffusionV1 => write!(f, "StableDiffusionV1"), - Self::StableDiffusionV2 => write!(f, "StableDiffusionV2"), - Self::StableDiffusionV3 => write!(f, "StableDiffusionV3"), - } - } -} - pub type NodeId = VerificationKey; pub type Temperature = f32; +#[derive(Clone, Debug)] +pub struct Prompt(pub(crate) String); + #[derive(Clone, Debug)] pub struct InferenceRequest { pub(crate) prompt: Prompt, From b5cd4aa37512347e16fd65a99a9faf485ba96aac Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 22 Mar 2024 04:59:27 +0000 Subject: [PATCH 2/4] feat: Integrate candle-nn dependency and update model inference functionality This commit updates the project dependencies by integrating as a workspace dependency in the Cargo.toml files. Additionally, it enhances the model inference functionality to support more advanced options such as specifying the maximum number of tokens, random seed, temperature, and top-p value. Changes made: - Added as a workspace dependency in the Cargo.toml files. - Updated the model inference functionality in the trait to accept additional parameters for controlling model inference. - Implemented the model inference logic in the enum variants to utilize the specified parameters for generating model predictions. - Defined the enum to handle errors related to model loading and tokenization. --- Cargo.toml | 1 + atoma-inference/Cargo.toml | 1 + atoma-inference/src/models.rs | 97 ++++++++++++++++++++++++---------- atoma-inference/src/service.rs | 4 +- atoma-inference/src/types.rs | 13 +++-- 5 files changed, 80 insertions(+), 36 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index acd35f3f..0d1fdee6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ version = "0.1.0" [workspace.dependencies] async-trait = "0.1.78" candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" } +candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" } candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" } ed25519-consensus = "2.1.0" serde = "1.0.197" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index b8c76dd3..82e2740f 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" [dependencies] async-trait.workspace = true candle.workspace = true +candle-nn.workspace = true candle-transformers.workspace = true ed25519-consensus.workspace = true serde = { workspace = true, features = ["derive"] } diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs index 59a15a7b..884e0626 100644 --- a/atoma-inference/src/models.rs +++ b/atoma-inference/src/models.rs @@ -1,7 +1,9 @@ -use std::{error::Error, fmt::Display}; +use std::fmt::Display; -use candle::Device; +use candle::{DType, Device, Error as CandleError}; +use candle_nn::VarBuilder; use candle_transformers::{ + generation::LogitsProcessor, models::{ llama::{Config as LlamaConfig, Llama}, llama2_c::{Config as Llama2Config, Llama as Llama2}, @@ -10,11 +12,13 @@ use candle_transformers::{ mixtral::{Config as MixtralConfig, Model as MixtralModel}, stable_diffusion::StableDiffusionConfig, }, - quantized_var_builder::VarBuilder, }; +use thiserror::Error; use tokenizers::Tokenizer; +use crate::types::Temperature; + #[derive(Clone, Debug)] pub enum ModelType { Llama(usize), @@ -62,13 +66,21 @@ impl From for ModelConfig { pub trait ModelApi { fn load(model_specs: ModelSpecs, var_builder: VarBuilder) -> Self; - fn run(&self, input: String) -> Result>; + fn run( + &self, + input: String, + max_tokens: usize, + random_seed: usize, + temperature: Temperature, + top_p: f32, + ) -> Result; } #[allow(dead_code)] pub struct ModelSpecs { pub(crate) config: ModelConfig, pub(crate) device: Device, + pub(crate) dtype: DType, pub(crate) tokenizer: Tokenizer, } @@ -100,51 +112,82 @@ impl ModelApi for Model { let model_config = model_specs.config.clone(); match model_config { ModelConfig::Llama(config) => { - let model = load_llama_model(config, var_builder); + let model = Llama::load(var_builder, &config).expect("Failed to load LlaMa model"); Self::Llama { model, model_specs } } ModelConfig::Llama2(config) => { - let model = load_llama2_model(config, var_builder); + let model = Llama2::load(var_builder, config).expect("Failed to load LlaMa2 model"); Self::Llama2 { model_specs, model } } ModelConfig::Mamba(config) => { - let model = load_mamba_model(config, var_builder); + let model = + MambaModel::new(&config, var_builder).expect("Failed to load Mamba model"); Self::Mamba { model_specs, model } } ModelConfig::Mistral(config) => { - let model = load_mistral(config, var_builder); + let model = + MistralModel::new(&config, var_builder).expect("Failed to load Mistral model"); Self::Mistral { model_specs, model } } ModelConfig::Mixtral8x7b(config) => { - let model = load_mixtral(config, var_builder); + let model = + MixtralModel::new(&config, var_builder).expect("Failed to load Mixtral model"); Self::Mixtral8x7b { model_specs, model } } - ModelConfig::StableDiffusion(config) => { + ModelConfig::StableDiffusion(_) => { panic!("TODO: implement it") } } } - fn run(&self, input: String) -> Result> { - todo!() - } -} + fn run( + &self, + input: String, + max_tokens: usize, + random_seed: usize, + temperature: Temperature, + top_p: f32, + ) -> Result { + match self { + Self::Llama { model_specs, model } => { + let tokenizer = model_specs + .tokenizer + .encode(input, true) + .map_err(ModelError::TokenizerError)?; -fn load_llama_model(config: LlamaConfig, var_builder: VarBuilder) -> Llama { - todo!() -} + let mut logits = LogitsProcessor::new( + random_seed as u64, + Some(temperature as f64), + Some(top_p as f64), + ); -fn load_llama2_model(config: Llama2Config, var_builder: VarBuilder) -> Llama2 { - todo!() -} + let start = std::time::Instant::now(); -fn load_mamba_model(config: MambaConfig, var_builder: VarBuilder) -> MambaModel { - todo!() -} -fn load_mistral(config: MistralConfig, var_builder: VarBuilder) -> MistralModel { - todo!() + let index_pos = 0; + let mut tokens_generated = 0; + + todo!() + } + Self::Llama2 { model_specs, model } => { + todo!() + } + Self::Mamba { model_specs, model } => { + todo!() + } + Self::Mistral { model_specs, model } => { + todo!() + } + Self::Mixtral8x7b { model_specs, model } => { + todo!() + } + } + } } -fn load_mixtral(config: MixtralConfig, var_builder: VarBuilder) -> MixtralModel { - todo!() +#[derive(Debug, Error)] +pub enum ModelError { + #[error("Failed to load error: `{0}`")] + LoadError(CandleError), + #[error("Failed input tokenization: `{0}`")] + TokenizerError(Box), } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 924e94f5..a1aec840 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -4,7 +4,7 @@ use thiserror::Error; use crate::{ config::InferenceConfig, models::ModelType, - types::{InferenceResponse, ModelResponse, Prompt, QuantizationMethod, Temperature}, + types::{InferenceResponse, ModelResponse, QuantizationMethod, Temperature}, }; #[derive(Debug, Error)] @@ -50,7 +50,7 @@ impl InferenceCore { #[allow(clippy::too_many_arguments)] pub fn inference( &mut self, - _prompt: Prompt, + _prompt: String, model: ModelType, _temperature: Option, _max_tokens: usize, diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs index 9c7e4cd7..9dcda742 100644 --- a/atoma-inference/src/types.rs +++ b/atoma-inference/src/types.rs @@ -4,12 +4,9 @@ use ed25519_consensus::VerificationKey; pub type NodeId = VerificationKey; pub type Temperature = f32; -#[derive(Clone, Debug)] -pub struct Prompt(pub(crate) String); - #[derive(Clone, Debug)] pub struct InferenceRequest { - pub(crate) prompt: Prompt, + pub(crate) prompt: String, pub(crate) model: ModelType, pub(crate) max_tokens: usize, pub(crate) random_seed: usize, @@ -43,15 +40,17 @@ pub struct ModelResponse { #[derive(Clone, Debug)] pub enum QuantizationMethod { - Ggml(QuantizationBits), - Gptq(QuantizationBits), + Ggml(PrecisionBits), + Gptq(PrecisionBits), } #[derive(Clone, Debug)] -pub enum QuantizationBits { +pub enum PrecisionBits { Q1, Q2, Q4, Q5, Q8, + F16, + F32, } From 886afb2e3c4654fa6889608ea1295cb7bc6a8b71 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 22 Mar 2024 06:14:32 +0000 Subject: [PATCH 3/4] intermediate steps --- atoma-inference/src/models.rs | 46 ++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs index 884e0626..224481ee 100644 --- a/atoma-inference/src/models.rs +++ b/atoma-inference/src/models.rs @@ -1,6 +1,6 @@ use std::fmt::Display; -use candle::{DType, Device, Error as CandleError}; +use candle::{DType, Device, Error as CandleError, Tensor}; use candle_nn::VarBuilder; use candle_transformers::{ generation::LogitsProcessor, @@ -150,10 +150,12 @@ impl ModelApi for Model { ) -> Result { match self { Self::Llama { model_specs, model } => { - let tokenizer = model_specs + let mut tokens = model_specs .tokenizer .encode(input, true) - .map_err(ModelError::TokenizerError)?; + .map_err(ModelError::TokenizerError)? + .get_ids() + .to_vec(); let mut logits = LogitsProcessor::new( random_seed as u64, @@ -166,6 +168,42 @@ impl ModelApi for Model { let index_pos = 0; let mut tokens_generated = 0; + let mut output = String::with_capacity(max_tokens); + + for index in 0..max_tokens { + let (context_size, context_index) = if cache.use_kv_cache && index > 0 { + (1, index_pos) + } else { + (tokens.len(), 0) + }; + let ctx = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctx, &model_specs.device)?.unsqueeze(0)?; + let logits = model.forward(&input, context_index, &mut cache)?; + let logits = logits.squeeze(0).map_err(ModelError::LogitsError)?; + let logits = if repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + repeat_penalty, + &tokens[start_at..], + )? + }; + index_pos += ctx.len(); + + let next_token = logits_processor.sample(&logits)?; + token_generated += 1; + tokens.push(next_token); + + if Some(next_token) == eos_token_id { + break; + } + if let Some(t) = model_specs.tokenizer(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + } todo!() } Self::Llama2 { model_specs, model } => { @@ -190,4 +228,6 @@ pub enum ModelError { LoadError(CandleError), #[error("Failed input tokenization: `{0}`")] TokenizerError(Box), + #[error("Logits error: `{0}`")] + LogitsError(CandleError) } From 57537891cbaba6c9126898352ac64b71bd28e7cf Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Fri, 22 Mar 2024 17:42:26 +0000 Subject: [PATCH 4/4] run inference for llama --- atoma-inference/src/models.rs | 85 ++++++++++++++++++++++++++--------- 1 file changed, 63 insertions(+), 22 deletions(-) diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs index 224481ee..c2d4612b 100644 --- a/atoma-inference/src/models.rs +++ b/atoma-inference/src/models.rs @@ -5,8 +5,8 @@ use candle_nn::VarBuilder; use candle_transformers::{ generation::LogitsProcessor, models::{ - llama::{Config as LlamaConfig, Llama}, - llama2_c::{Config as Llama2Config, Llama as Llama2}, + llama::{Cache as LlamaCache, Config as LlamaConfig, Llama}, + llama2_c::{Cache as Llama2Cache, Config as Llama2Config, Llama as Llama2}, mamba::{Config as MambaConfig, Model as MambaModel}, mistral::{Config as MistralConfig, Model as MistralModel}, mixtral::{Config as MixtralConfig, Model as MixtralModel}, @@ -19,6 +19,8 @@ use tokenizers::Tokenizer; use crate::types::Temperature; +const EOS_TOKEN: &str = ""; + #[derive(Clone, Debug)] pub enum ModelType { Llama(usize), @@ -64,6 +66,12 @@ impl From for ModelConfig { } } +#[derive(Clone)] +pub enum ModelCache { + Llama(LlamaCache), + Llama2(Llama2Cache), +} + pub trait ModelApi { fn load(model_specs: ModelSpecs, var_builder: VarBuilder) -> Self; fn run( @@ -71,6 +79,8 @@ pub trait ModelApi { input: String, max_tokens: usize, random_seed: usize, + repeat_last_n: usize, + repeat_penalty: f32, temperature: Temperature, top_p: f32, ) -> Result; @@ -78,6 +88,7 @@ pub trait ModelApi { #[allow(dead_code)] pub struct ModelSpecs { + pub(crate) cache: Option, pub(crate) config: ModelConfig, pub(crate) device: Device, pub(crate) dtype: DType, @@ -145,11 +156,22 @@ impl ModelApi for Model { input: String, max_tokens: usize, random_seed: usize, + repeat_last_n: usize, + repeat_penalty: f32, temperature: Temperature, top_p: f32, ) -> Result { match self { Self::Llama { model_specs, model } => { + let mut cache = if let ModelCache::Llama(cache) = + model_specs.cache.clone().expect("Failed to get cache") + { + cache + } else { + return Err(ModelError::CacheError(String::from( + "Failed to obtain correct cache", + ))); + }; let mut tokens = model_specs .tokenizer .encode(input, true) @@ -157,18 +179,20 @@ impl ModelApi for Model { .get_ids() .to_vec(); - let mut logits = LogitsProcessor::new( + let mut logits_processor = LogitsProcessor::new( random_seed as u64, Some(temperature as f64), Some(top_p as f64), ); + let eos_token_id = model_specs.tokenizer.token_to_id(EOS_TOKEN); + let start = std::time::Instant::now(); - let index_pos = 0; + let mut index_pos = 0; let mut tokens_generated = 0; - let mut output = String::with_capacity(max_tokens); + let mut output = Vec::with_capacity(max_tokens); for index in 0..max_tokens { let (context_size, context_index) = if cache.use_kv_cache && index > 0 { @@ -177,8 +201,13 @@ impl ModelApi for Model { (tokens.len(), 0) }; let ctx = &tokens[tokens.len().saturating_sub(context_size)..]; - let input = Tensor::new(ctx, &model_specs.device)?.unsqueeze(0)?; - let logits = model.forward(&input, context_index, &mut cache)?; + let input = Tensor::new(ctx, &model_specs.device) + .map_err(ModelError::TensorError)? + .unsqueeze(0) + .map_err(ModelError::TensorError)?; + let logits = model + .forward(&input, context_index, &mut cache) + .map_err(ModelError::TensorError)?; let logits = logits.squeeze(0).map_err(ModelError::LogitsError)?; let logits = if repeat_penalty == 1. { logits @@ -188,34 +217,42 @@ impl ModelApi for Model { &logits, repeat_penalty, &tokens[start_at..], - )? + ) + .map_err(ModelError::TensorError)? }; index_pos += ctx.len(); - - let next_token = logits_processor.sample(&logits)?; - token_generated += 1; + + let next_token = logits_processor + .sample(&logits) + .map_err(ModelError::TensorError)?; + tokens_generated += 1; tokens.push(next_token); - + if Some(next_token) == eos_token_id { break; } - if let Some(t) = model_specs.tokenizer(next_token)? { - print!("{t}"); - std::io::stdout().flush()?; + // TODO: possibly do this in batches will speed up the process + if let Ok(t) = model_specs.tokenizer.decode(&[next_token], true) { + output.push(t); } + let dt = start.elapsed(); + tracing::info!( + "Generated {tokens_generated} tokens ({} tokens/s)", + tokens_generated as f64 / dt.as_secs_f64() + ); } - todo!() + Ok(output.join(" ")) } - Self::Llama2 { model_specs, model } => { + Self::Llama2 { .. } => { todo!() } - Self::Mamba { model_specs, model } => { + Self::Mamba { .. } => { todo!() } - Self::Mistral { model_specs, model } => { + Self::Mistral { .. } => { todo!() } - Self::Mixtral8x7b { model_specs, model } => { + Self::Mixtral8x7b { .. } => { todo!() } } @@ -224,10 +261,14 @@ impl ModelApi for Model { #[derive(Debug, Error)] pub enum ModelError { + #[error("Cache error: `{0}`")] + CacheError(String), #[error("Failed to load error: `{0}`")] LoadError(CandleError), + #[error("Logits error: `{0}`")] + LogitsError(CandleError), + #[error("Tensor error: `{0}`")] + TensorError(CandleError), #[error("Failed input tokenization: `{0}`")] TokenizerError(Box), - #[error("Logits error: `{0}`")] - LogitsError(CandleError) }