Merge pull request #9 from atoma-network/build-model-interface
[wip] Model interface
jorgeantonio21 authored Mar 23, 2024
2 parents 308185d + 5753789 commit 48e7845
Showing 7 changed files with 292 additions and 38 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
@@ -8,7 +8,9 @@ version = "0.1.0"

[workspace.dependencies]
async-trait = "0.1.78"
candle-nn = "0.4.1"
candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" }
candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" }
candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" }
ed25519-consensus = "2.1.0"
serde = "1.0.197"
thiserror = "1.0.58"
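The three candle crates now track the Hugging Face git repository rather than crates.io, presumably to pick up the 0.4.2 model code ahead of a release. If reproducible builds matter, Cargo git dependencies can also be pinned to a fixed revision; a hedged variant of the new entries (the rev value is a placeholder, not a commit referenced by this PR):

# Hypothetical variant of the new entries, pinned to a specific commit of the
# candle repository. Replace <commit-sha> with a real revision.
[workspace.dependencies]
candle = { git = "https://github.com/huggingface/candle", package = "candle-core", rev = "<commit-sha>" }
candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", rev = "<commit-sha>" }
candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", rev = "<commit-sha>" }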
2 changes: 2 additions & 0 deletions atoma-inference/Cargo.toml
@@ -7,7 +7,9 @@ edition = "2021"

[dependencies]
async-trait.workspace = true
+candle.workspace = true
candle-nn.workspace = true
+candle-transformers.workspace = true
ed25519-consensus.workspace = true
serde = { workspace = true, features = ["derive"] }
thiserror.workspace = true
2 changes: 1 addition & 1 deletion atoma-inference/src/config.rs
@@ -1,8 +1,8 @@
use std::path::PathBuf;

use crate::{
+    models::ModelType,
    specs::{HardwareSpec, SoftwareSpec},
-    types::ModelType,
};

pub struct InferenceConfig {
1 change: 1 addition & 0 deletions atoma-inference/src/lib.rs
@@ -1,5 +1,6 @@
pub mod config;
pub mod core_thread;
+pub mod models;
pub mod service;
pub mod specs;
pub mod types;
274 changes: 274 additions & 0 deletions atoma-inference/src/models.rs
@@ -0,0 +1,274 @@
use std::fmt::Display;

use candle::{DType, Device, Error as CandleError, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
    generation::LogitsProcessor,
    models::{
        llama::{Cache as LlamaCache, Config as LlamaConfig, Llama},
        llama2_c::{Cache as Llama2Cache, Config as Llama2Config, Llama as Llama2},
        mamba::{Config as MambaConfig, Model as MambaModel},
        mistral::{Config as MistralConfig, Model as MistralModel},
        mixtral::{Config as MixtralConfig, Model as MixtralModel},
        stable_diffusion::StableDiffusionConfig,
    },
};
use thiserror::Error;

use tokenizers::Tokenizer;

use crate::types::Temperature;

const EOS_TOKEN: &str = "</s>";

#[derive(Clone, Debug)]
pub enum ModelType {
    Llama(usize),
    Llama2(usize),
    Mamba(usize),
    Mixtral8x7b,
    Mistral(usize),
    StableDiffusionV1_5,
    StableDiffusionV2_1,
    StableDiffusionXl,
    StableDiffusionTurbo,
}

impl Display for ModelType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Llama(size) => write!(f, "llama({})", size),
            Self::Llama2(size) => write!(f, "llama2({})", size),
            Self::Mamba(size) => write!(f, "mamba({})", size),
            Self::Mixtral8x7b => write!(f, "mixtral_8x7b"),
            Self::Mistral(size) => write!(f, "mistral({})", size),
            Self::StableDiffusionV1_5 => write!(f, "stable_diffusion_v1_5"),
            Self::StableDiffusionV2_1 => write!(f, "stable_diffusion_v2_1"),
            Self::StableDiffusionXl => write!(f, "stable_diffusion_xl"),
            Self::StableDiffusionTurbo => write!(f, "stable_diffusion_turbo"),
        }
    }
}

#[derive(Clone)]
pub enum ModelConfig {
    Llama(LlamaConfig),
    Llama2(Llama2Config),
    Mamba(MambaConfig),
    Mixtral8x7b(MixtralConfig),
    Mistral(MistralConfig),
    StableDiffusion(StableDiffusionConfig),
}

impl From<ModelType> for ModelConfig {
    fn from(_model_type: ModelType) -> Self {
        todo!()
    }
}

#[derive(Clone)]
pub enum ModelCache {
    Llama(LlamaCache),
    Llama2(Llama2Cache),
}

pub trait ModelApi {
    fn load(model_specs: ModelSpecs, var_builder: VarBuilder) -> Self;
    fn run(
        &self,
        input: String,
        max_tokens: usize,
        random_seed: usize,
        repeat_last_n: usize,
        repeat_penalty: f32,
        temperature: Temperature,
        top_p: f32,
    ) -> Result<String, ModelError>;
}

#[allow(dead_code)]
pub struct ModelSpecs {
    pub(crate) cache: Option<ModelCache>,
    pub(crate) config: ModelConfig,
    pub(crate) device: Device,
    pub(crate) dtype: DType,
    pub(crate) tokenizer: Tokenizer,
}

pub enum Model {
    Llama {
        model_specs: ModelSpecs,
        model: Llama,
    },
    Llama2 {
        model_specs: ModelSpecs,
        model: Llama2,
    },
    Mamba {
        model_specs: ModelSpecs,
        model: MambaModel,
    },
    Mixtral8x7b {
        model_specs: ModelSpecs,
        model: MixtralModel,
    },
    Mistral {
        model_specs: ModelSpecs,
        model: MistralModel,
    },
}

impl ModelApi for Model {
    fn load(model_specs: ModelSpecs, var_builder: VarBuilder) -> Self {
        let model_config = model_specs.config.clone();
        match model_config {
            ModelConfig::Llama(config) => {
                let model = Llama::load(var_builder, &config).expect("Failed to load Llama model");
                Self::Llama { model, model_specs }
            }
            ModelConfig::Llama2(config) => {
                let model = Llama2::load(var_builder, config).expect("Failed to load Llama2 model");
                Self::Llama2 { model_specs, model }
            }
            ModelConfig::Mamba(config) => {
                let model =
                    MambaModel::new(&config, var_builder).expect("Failed to load Mamba model");
                Self::Mamba { model_specs, model }
            }
            ModelConfig::Mistral(config) => {
                let model =
                    MistralModel::new(&config, var_builder).expect("Failed to load Mistral model");
                Self::Mistral { model_specs, model }
            }
            ModelConfig::Mixtral8x7b(config) => {
                let model =
                    MixtralModel::new(&config, var_builder).expect("Failed to load Mixtral model");
                Self::Mixtral8x7b { model_specs, model }
            }
            ModelConfig::StableDiffusion(_) => {
                panic!("TODO: implement it")
            }
        }
    }

    fn run(
        &self,
        input: String,
        max_tokens: usize,
        random_seed: usize,
        repeat_last_n: usize,
        repeat_penalty: f32,
        temperature: Temperature,
        top_p: f32,
    ) -> Result<String, ModelError> {
        match self {
            Self::Llama { model_specs, model } => {
                let mut cache = if let ModelCache::Llama(cache) =
                    model_specs.cache.clone().expect("Failed to get cache")
                {
                    cache
                } else {
                    return Err(ModelError::CacheError(String::from(
                        "Failed to obtain correct cache",
                    )));
                };
                let mut tokens = model_specs
                    .tokenizer
                    .encode(input, true)
                    .map_err(ModelError::TokenizerError)?
                    .get_ids()
                    .to_vec();

                let mut logits_processor = LogitsProcessor::new(
                    random_seed as u64,
                    Some(temperature as f64),
                    Some(top_p as f64),
                );

                let eos_token_id = model_specs.tokenizer.token_to_id(EOS_TOKEN);

                let start = std::time::Instant::now();

                let mut index_pos = 0;
                let mut tokens_generated = 0;

                let mut output = Vec::with_capacity(max_tokens);

                for index in 0..max_tokens {
                    // With the KV cache enabled, only the newest token has to be fed
                    // after the first forward pass.
                    let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
                        (1, index_pos)
                    } else {
                        (tokens.len(), 0)
                    };
                    let ctx = &tokens[tokens.len().saturating_sub(context_size)..];
                    let input = Tensor::new(ctx, &model_specs.device)
                        .map_err(ModelError::TensorError)?
                        .unsqueeze(0)
                        .map_err(ModelError::TensorError)?;
                    let logits = model
                        .forward(&input, context_index, &mut cache)
                        .map_err(ModelError::TensorError)?;
                    let logits = logits.squeeze(0).map_err(ModelError::LogitsError)?;
                    let logits = if repeat_penalty == 1. {
                        logits
                    } else {
                        // Penalize tokens seen in the last `repeat_last_n` positions.
                        let start_at = tokens.len().saturating_sub(repeat_last_n);
                        candle_transformers::utils::apply_repeat_penalty(
                            &logits,
                            repeat_penalty,
                            &tokens[start_at..],
                        )
                        .map_err(ModelError::TensorError)?
                    };
                    index_pos += ctx.len();

                    let next_token = logits_processor
                        .sample(&logits)
                        .map_err(ModelError::TensorError)?;
                    tokens_generated += 1;
                    tokens.push(next_token);

                    if Some(next_token) == eos_token_id {
                        break;
                    }
                    // TODO: decoding tokens in batches would likely speed this up
                    if let Ok(t) = model_specs.tokenizer.decode(&[next_token], true) {
                        output.push(t);
                    }
                    let dt = start.elapsed();
                    tracing::info!(
                        "Generated {tokens_generated} tokens ({} tokens/s)",
                        tokens_generated as f64 / dt.as_secs_f64()
                    );
                }
                Ok(output.join(" "))
            }
            Self::Llama2 { .. } => {
                todo!()
            }
            Self::Mamba { .. } => {
                todo!()
            }
            Self::Mistral { .. } => {
                todo!()
            }
            Self::Mixtral8x7b { .. } => {
                todo!()
            }
        }
    }
}

#[derive(Debug, Error)]
pub enum ModelError {
    #[error("Cache error: `{0}`")]
    CacheError(String),
    #[error("Failed to load model: `{0}`")]
    LoadError(CandleError),
    #[error("Logits error: `{0}`")]
    LogitsError(CandleError),
    #[error("Tensor error: `{0}`")]
    TensorError(CandleError),
    #[error("Failed input tokenization: `{0}`")]
    TokenizerError(Box<dyn std::error::Error + Send + Sync>),
}
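The commit itself ships no example of driving this interface, so the following is a minimal usage sketch for the Llama path, written as if from inside the atoma-inference crate (the ModelSpecs fields are pub(crate)). The file paths, the config_7b_v2 and Cache::new constructors from candle-transformers, the mmap-based VarBuilder, and the assumption that Temperature is an f32 alias are assumptions rather than code from this PR.

// Hypothetical usage of the new ModelApi from inside the crate. Paths,
// constructors and the Temperature = f32 assumption are illustrative only.
use candle::{DType, Device};
use candle_nn::VarBuilder;
use candle_transformers::models::llama::{Cache as LlamaCache, Config as LlamaConfig};
use tokenizers::Tokenizer;

use crate::models::{Model, ModelApi, ModelCache, ModelConfig, ModelError, ModelSpecs};

fn run_llama_example() -> Result<String, ModelError> {
    let device = Device::Cpu;
    let dtype = DType::F32;

    // Assumed local artifacts; in practice these would be fetched from a model hub.
    let tokenizer =
        Tokenizer::from_file("models/llama/tokenizer.json").expect("failed to load tokenizer");
    let config = LlamaConfig::config_7b_v2(false); // no flash attention
    let cache = LlamaCache::new(true, dtype, &config, &device).expect("failed to build KV cache");

    // Memory-map the safetensors weights into a VarBuilder (unsafe because of the mmap).
    let var_builder = unsafe {
        VarBuilder::from_mmaped_safetensors(&["models/llama/model.safetensors"], dtype, &device)
            .expect("failed to mmap weights")
    };

    let model_specs = ModelSpecs {
        cache: Some(ModelCache::Llama(cache)),
        config: ModelConfig::Llama(config),
        device,
        dtype,
        tokenizer,
    };

    let model = Model::load(model_specs, var_builder);
    model.run(
        "The capital of France is".to_string(),
        128,  // max_tokens
        42,   // random_seed
        64,   // repeat_last_n
        1.1,  // repeat_penalty
        0.8,  // temperature (assuming Temperature is an f32 alias)
        0.95, // top_p
    )
}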
6 changes: 4 additions & 2 deletions atoma-inference/src/service.rs
@@ -3,7 +3,8 @@ use thiserror::Error;

use crate::{
    config::InferenceConfig,
-    types::{InferenceResponse, ModelResponse, ModelType, Prompt, QuantizationMethod, Temperature},
+    models::ModelType,
+    types::{InferenceResponse, ModelResponse, QuantizationMethod, Temperature},
};

#[derive(Debug, Error)]
@@ -23,6 +24,7 @@ pub trait ApiTrait
#[allow(dead_code)]
pub struct InferenceCore<T> {
    config: InferenceConfig,
+    // models: Vec<Model>,
    pub(crate) public_key: PublicKey,
    private_key: PrivateKey,
    web2_api: T,
@@ -48,7 +50,7 @@ impl<T: ApiTrait> InferenceCore<T> {
    #[allow(clippy::too_many_arguments)]
    pub fn inference(
        &mut self,
-        _prompt: Prompt,
+        _prompt: String,
        model: ModelType,
        _temperature: Option<Temperature>,
        _max_tokens: usize,
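The commented-out models: Vec<Model> field and the change of the prompt parameter from Prompt to String hint at how InferenceCore may eventually hand requests to the new model interface. Below is a speculative sketch of that hand-off from inside the crate; the helper and its sampling defaults are assumptions, not code from this PR.

// Speculative glue (not part of this commit): forwarding a prompt from the
// inference service to a loaded Model via ModelApi::run. Sampling defaults
// are placeholders.
use crate::models::{Model, ModelApi, ModelError};
use crate::types::Temperature;

fn run_prompt(model: &Model, prompt: String, temperature: Temperature) -> Result<String, ModelError> {
    model.run(
        prompt,
        256,  // max_tokens
        42,   // random_seed
        64,   // repeat_last_n
        1.1,  // repeat_penalty
        temperature,
        0.95, // top_p
    )
}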