From 525e6b645217930a5b967b1ba20649106320d95d Mon Sep 17 00:00:00 2001 From: Martin Stefcek Date: Fri, 26 Apr 2024 12:25:26 +0400 Subject: [PATCH] multi-gpu --- Cargo.toml | 19 +- atoma-inference/Cargo.toml | 1 + atoma-inference/src/models/candle/falcon.rs | 3 +- atoma-inference/src/models/candle/llama.rs | 3 +- .../src/models/candle/llama_multi_gpu.rs | 201 +++++++++ atoma-inference/src/models/candle/mamba.rs | 3 +- atoma-inference/src/models/candle/mistral.rs | 3 +- atoma-inference/src/models/candle/mixtral.rs | 3 +- atoma-inference/src/models/candle/mod.rs | 1 + .../src/models/candle/quantized.rs | 3 +- atoma-inference/src/models/config.rs | 8 + atoma-inference/src/models/types.rs | 55 +-- atoma-multi-gpu-mock/Cargo.toml | 38 ++ atoma-multi-gpu-mock/src/main.rs | 162 +++++++ atoma-multi-gpu-mock/src/models/llama/mod.rs | 32 ++ atoma-multi-gpu-mock/src/models/mod.rs | 1 + atoma-multi-gpu-mock/src/types.rs | 32 ++ atoma-multi-gpu/Cargo.toml | 38 ++ atoma-multi-gpu/src/main.rs | 223 ++++++++++ atoma-multi-gpu/src/models/llama/mod.rs | 109 +++++ atoma-multi-gpu/src/models/llama/model.rs | 411 ++++++++++++++++++ atoma-multi-gpu/src/models/mod.rs | 1 + atoma-multi-gpu/src/token_output_stream.rs | 90 ++++ atoma-multi-gpu/src/types.rs | 54 +++ atoma-types/src/lib.rs | 73 +++- 25 files changed, 1503 insertions(+), 64 deletions(-) create mode 100644 atoma-inference/src/models/candle/llama_multi_gpu.rs create mode 100644 atoma-multi-gpu-mock/Cargo.toml create mode 100644 atoma-multi-gpu-mock/src/main.rs create mode 100644 atoma-multi-gpu-mock/src/models/llama/mod.rs create mode 100644 atoma-multi-gpu-mock/src/models/mod.rs create mode 100644 atoma-multi-gpu-mock/src/types.rs create mode 100644 atoma-multi-gpu/Cargo.toml create mode 100644 atoma-multi-gpu/src/main.rs create mode 100644 atoma-multi-gpu/src/models/llama/mod.rs create mode 100644 atoma-multi-gpu/src/models/llama/model.rs create mode 100644 atoma-multi-gpu/src/models/mod.rs create mode 100644 atoma-multi-gpu/src/token_output_stream.rs create mode 100644 atoma-multi-gpu/src/types.rs diff --git a/Cargo.toml b/Cargo.toml index 0e23548a..4fa4876c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,10 +9,11 @@ members = [ "atoma-event-subscribe/sui", "atoma-inference", "atoma-networking", + "atoma-multi-gpu-mock", "atoma-node", "atoma-json-rpc", "atoma-storage", - "atoma-types" + "atoma-types", ] [workspace.package] @@ -29,18 +30,28 @@ atoma-sui = { path = "./atoma-event-subscribe/sui/" } atoma-types = { path = "./atoma-types" } axum = "0.7.5" blake2 = "0.10.6" +bindgen_cuda = { version = "0.1.1" } candle = { git = "https://github.com/huggingface/candle", package = "candle-core", branch = "main" } candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", branch = "main" } candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", branch = "main" } candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", branch = "main" } clap = "4.5.4" config = "0.14.0" +cudarc = { version = "0.10.0", features = ["f16"] } dotenv = "0.15.0" ethers = "2.0.14" futures = "0.3.30" futures-util = "0.3.30" +half = { version = "2.3.1", features = [ + "num-traits", + "use-intrinsics", + "rand_distr", +] } hf-hub = "0.3.2" -image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] } +image = { version = "0.25.0", default-features = false, features = [ + "jpeg", + "png", +] } rand = "0.8.5" rayon = "1.10.0" reqwest = "0.12.1" @@ -50,10 
+61,12 @@ serde_json = "1.0.114" # solana-client = "1.18.9" # solana-sdk = "1.18.8" sui-keys = { git = "https://github.com/mystenlabs/sui", package = "sui-keys" } -sui-sdk = { git = "https://github.com/mystenlabs/sui", package = "sui-sdk"} +sui-sdk = { git = "https://github.com/mystenlabs/sui", package = "sui-sdk" } thiserror = "1.0.58" tokenizers = "0.15.2" tokio = "1.36.0" toml = "0.8.12" tracing = "0.1.40" tracing-subscriber = "0.3.18" +tungstenite = "0.21.0" +url = { version = "2.1.0" } diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index 5eaf970f..93b83a86 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -23,6 +23,7 @@ tokenizers = { workspace = true } tokio = { workspace = true, features = ["full", "tracing"] } tracing.workspace = true tracing-subscriber.workspace = true +tungstenite.workspace = true [dev-dependencies] rand.workspace = true diff --git a/atoma-inference/src/models/candle/falcon.rs b/atoma-inference/src/models/candle/falcon.rs index 59fcbc8b..78c7bf76 100644 --- a/atoma-inference/src/models/candle/falcon.rs +++ b/atoma-inference/src/models/candle/falcon.rs @@ -1,5 +1,6 @@ use std::{path::PathBuf, str::FromStr, time::Instant}; +use atoma_types::{TextModelInput, TextModelOutput}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::{ @@ -14,7 +15,7 @@ use tracing::{debug, error, info}; use crate::models::{ candle::hub_load_safetensors, config::ModelConfig, - types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput}, + types::{LlmLoadData, ModelType}, ModelError, ModelTrait, }; diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs index becaaba5..c261c097 100644 --- a/atoma-inference/src/models/candle/llama.rs +++ b/atoma-inference/src/models/candle/llama.rs @@ -1,5 +1,6 @@ use std::{path::PathBuf, str::FromStr, time::Instant}; +use atoma_types::{TextModelInput, TextModelOutput}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::{ @@ -15,7 +16,7 @@ use tracing::info; use crate::models::{ config::ModelConfig, token_output_stream::TokenOutputStream, - types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput}, + types::{LlmLoadData, ModelType}, ModelError, ModelTrait, }; diff --git a/atoma-inference/src/models/candle/llama_multi_gpu.rs b/atoma-inference/src/models/candle/llama_multi_gpu.rs new file mode 100644 index 00000000..83160673 --- /dev/null +++ b/atoma-inference/src/models/candle/llama_multi_gpu.rs @@ -0,0 +1,201 @@ +use std::{ + net::TcpListener, + path::PathBuf, + rc::Rc, + str::FromStr, + sync::{Arc, Barrier, Mutex}, + thread, + time::Instant, +}; + +use atoma_types::{AtomaChildMessage, AtomaInferenceMessage, TextModelInput, TextModelOutput}; +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::{ + generation::LogitsProcessor, + models::llama::{Config, LlamaConfig}, +}; +use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; + +use candle_transformers::models::llama as model; +use tokenizers::Tokenizer; +use tracing::info; +use tungstenite::accept; + +use crate::models::{ + config::ModelConfig, + token_output_stream::TokenOutputStream, + types::{LlmLoadData, ModelType}, + ModelError, ModelTrait, +}; + +use super::hub_load_safetensors; + +pub struct MultiGpuLlamaModel { + model_type: ModelType, + gpu_instances: Arc>>>>, + sync_barrier: Arc, + result: Arc>>, +} + +pub struct MultiGpuLlamaLoadData { + file_paths: Vec, + dtype: String, + model_type: 
ModelType, +} + +impl ModelTrait for MultiGpuLlamaModel { + type Input = TextModelInput; + type Output = TextModelOutput; + type LoadData = MultiGpuLlamaLoadData; + + fn fetch( + api_key: String, + cache_dir: PathBuf, + config: ModelConfig, + ) -> Result { + let api = ApiBuilder::new() + .with_progress(true) + .with_token(Some(api_key)) + .with_cache_dir(cache_dir) + .build()?; + + let model_type = ModelType::from_str(&config.model_id())?; + let repo_id = model_type.repo().to_string(); + let revision = model_type.default_revision().to_string(); + + let repo = api.repo(Repo::with_revision( + repo_id.clone(), + RepoType::Model, + revision, + )); + let config_file_path = repo.get("config.json")?; + let tokenizer_file_path = repo.get("tokenizer.json")?; + + let model_weights_file_paths = if &repo_id == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" { + vec![repo.get("model.safetensors")?] + } else { + hub_load_safetensors(&repo, "model.safetensors.index.json")? + }; + + let mut file_paths = Vec::with_capacity(2 + model_weights_file_paths.len()); + file_paths.extend(vec![config_file_path, tokenizer_file_path]); + file_paths.extend(model_weights_file_paths); + + Ok(Self::LoadData { + file_paths, + model_type: ModelType::from_str(&config.model_id())?, + dtype: config.dtype(), // use_flash_attention: config.use_flash_attention(), + }) + } + + fn model_type(&self) -> ModelType { + self.model_type.clone() + } + + fn load(load_data: Self::LoadData) -> Result { + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let port = server.local_addr().unwrap().port(); + println!("Server listening on port {}", port); + let num_shards = 2; + let mut num_connections = 0; + let gpu_instances = Arc::new(Mutex::new(Vec::with_capacity(num_shards))); + let init_barrier = Arc::new(Barrier::new(num_shards)); + let sync_barrier = Arc::new(Barrier::new(num_shards + 1)); + let result = Arc::new(Mutex::new(None)); + for stream in server.incoming() { + let gpu_instances = Arc::clone(&gpu_instances); + let init_barrier = Arc::clone(&init_barrier); + let sync_barrier = Arc::clone(&sync_barrier); + let config_file_path = load_data.file_paths[0].clone(); + let tokenizer_filename = load_data.file_paths[1].clone(); + let filenames = load_data.file_paths[2..].to_vec(); + let dtype = load_data.dtype.clone(); + let result = Arc::clone(&result); + thread::spawn(move || { + let mut websocket = Arc::new(accept(stream.unwrap()).unwrap()); + let index = { + let mut gpu_instances = gpu_instances.lock().unwrap(); + gpu_instances.push(Arc::clone(&websocket)); + gpu_instances.len() - 1 + }; + loop { + init_barrier.wait(); + let msg = Arc::get_mut(&mut websocket).unwrap().read().unwrap(); + if msg.is_text() { + let message: AtomaChildMessage = + serde_json::from_str(msg.to_string().as_str()).unwrap(); + match message { + AtomaChildMessage::Initialized(nccl_id) => { + if let Some(nccl_id) = nccl_id { + let mut gpu_instances = gpu_instances.lock().unwrap(); + for (i, websocket) in gpu_instances.iter_mut().enumerate() { + if i != index { + Arc::get_mut(websocket) + .unwrap() + .send(tungstenite::Message::Text( + serde_json::to_string( + &AtomaInferenceMessage::InitializeComm( + nccl_id.clone(), + ), + ) + .unwrap(), + )) + .unwrap(); + } + } + } + } + AtomaChildMessage::CommsReady => { + Arc::get_mut(&mut websocket) + .unwrap() + .send(tungstenite::Message::Text( + serde_json::to_string(&AtomaInferenceMessage::LoadModel( + config_file_path.clone(), + dtype.clone(), + filenames.clone(), + tokenizer_filename.clone(), + )) + .unwrap(), + )) + 
.unwrap(); + } + AtomaChildMessage::Loaded => { + sync_barrier.wait(); + } + AtomaChildMessage::InferenceResult(output) => { + *result.lock().unwrap() = Some(output); + sync_barrier.wait(); + } + } + } + } + }); + num_connections += 1; + if num_connections == num_shards { + break; + } + } + sync_barrier.wait(); + Ok(MultiGpuLlamaModel { + model_type: load_data.model_type, + gpu_instances, + sync_barrier, + result, + }) + } + + fn run(&mut self, input: Self::Input) -> Result { + for websocket in self.gpu_instances.lock().unwrap().iter_mut() { + Arc::get_mut(websocket) + .unwrap() + .send(tungstenite::Message::Text( + serde_json::to_string(&AtomaInferenceMessage::Inference(input.clone())) + .unwrap(), + )) + .unwrap(); + } + self.sync_barrier.wait(); + Ok(self.result.lock().unwrap().take().unwrap()) + } +} diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs index 9788d7d8..aba67b46 100644 --- a/atoma-inference/src/models/candle/mamba.rs +++ b/atoma-inference/src/models/candle/mamba.rs @@ -1,5 +1,6 @@ use std::{path::PathBuf, str::FromStr, time::Instant}; +use atoma_types::{TextModelInput, TextModelOutput}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::{ @@ -17,7 +18,7 @@ use crate::{ candle::device, config::ModelConfig, token_output_stream::TokenOutputStream, - types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput}, + types::{LlmLoadData, ModelType}, ModelError, ModelTrait, }, }; diff --git a/atoma-inference/src/models/candle/mistral.rs b/atoma-inference/src/models/candle/mistral.rs index 942f3a97..077027a8 100644 --- a/atoma-inference/src/models/candle/mistral.rs +++ b/atoma-inference/src/models/candle/mistral.rs @@ -1,5 +1,6 @@ use std::str::FromStr; +use atoma_types::{TextModelInput, TextModelOutput}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::{ @@ -16,7 +17,7 @@ use crate::{ models::{ candle::{device, hub_load_safetensors}, token_output_stream::TokenOutputStream, - types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput}, + types::{LlmLoadData, ModelType}, ModelError, ModelTrait, }, }; diff --git a/atoma-inference/src/models/candle/mixtral.rs b/atoma-inference/src/models/candle/mixtral.rs index 864c174a..7676ec5f 100644 --- a/atoma-inference/src/models/candle/mixtral.rs +++ b/atoma-inference/src/models/candle/mixtral.rs @@ -1,5 +1,6 @@ use std::str::FromStr; +use atoma_types::{TextModelInput, TextModelOutput}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::{ @@ -16,7 +17,7 @@ use crate::{ models::{ candle::{device, hub_load_safetensors}, token_output_stream::TokenOutputStream, - types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput}, + types::{LlmLoadData, ModelType}, ModelError, ModelTrait, }, }; diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs index cbd4c562..28bc932c 100644 --- a/atoma-inference/src/models/candle/mod.rs +++ b/atoma-inference/src/models/candle/mod.rs @@ -12,6 +12,7 @@ use super::ModelError; pub mod falcon; pub mod llama; +pub mod llama_multi_gpu; pub mod mamba; pub mod mistral; pub mod mixtral; diff --git a/atoma-inference/src/models/candle/quantized.rs b/atoma-inference/src/models/candle/quantized.rs index f6a0e7dc..5de11a8c 100644 --- a/atoma-inference/src/models/candle/quantized.rs +++ b/atoma-inference/src/models/candle/quantized.rs @@ -1,5 +1,6 @@ use std::{path::PathBuf, str::FromStr}; +use 
atoma_types::{TextModelInput, TextModelOutput}; use candle::{ quantized::{ggml_file, gguf_file}, DType, Device, Tensor, @@ -16,7 +17,7 @@ use crate::models::{ candle::device, config::ModelConfig, token_output_stream::TokenOutputStream, - types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput}, + types::{LlmLoadData, ModelType}, ModelError, ModelTrait, }; use candle_transformers::models::quantized_llama as model; diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index a1e4e20c..23cb4bc7 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -63,6 +63,7 @@ pub struct ModelsConfig { models: Vec, tracing: bool, jrpc_port: u64, + multi_gpu: bool, } impl ModelsConfig { @@ -73,6 +74,7 @@ impl ModelsConfig { models: Vec, tracing: bool, jrpc_port: u64, + multi_gpu: bool, ) -> Self { Self { api_key, @@ -81,6 +83,7 @@ impl ModelsConfig { models, tracing, jrpc_port, + multi_gpu, } } @@ -108,6 +111,10 @@ impl ModelsConfig { self.jrpc_port } + pub fn is_multi_gpu(&self) -> bool { + self.multi_gpu + } + pub fn from_file_path>(config_file_path: P) -> Self { let builder = Config::builder().add_source(config::File::with_name( config_file_path.as_ref().to_str().unwrap(), @@ -156,6 +163,7 @@ impl ModelsConfig { models, tracing, jrpc_port, + multi_gpu: false, } } } diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs index 36c8b26c..8785e6f3 100644 --- a/atoma-inference/src/models/types.rs +++ b/atoma-inference/src/models/types.rs @@ -1,5 +1,6 @@ use std::{fmt::Display, path::PathBuf, str::FromStr}; +use atoma_types::TextModelInput; use candle::{DType, Device}; use serde::{Deserialize, Serialize}; @@ -322,60 +323,6 @@ impl Request for TextRequest { } } -#[derive(Deserialize)] -pub struct TextModelInput { - pub(crate) prompt: String, - pub(crate) temperature: f64, - pub(crate) random_seed: u64, - pub(crate) repeat_penalty: f32, - pub(crate) repeat_last_n: usize, - pub(crate) max_tokens: usize, - pub(crate) top_k: Option, - pub(crate) top_p: Option, -} - -impl TextModelInput { - #[allow(clippy::too_many_arguments)] - pub fn new( - prompt: String, - temperature: f64, - random_seed: u64, - repeat_penalty: f32, - repeat_last_n: usize, - max_tokens: usize, - top_k: Option, - top_p: Option, - ) -> Self { - Self { - prompt, - temperature, - random_seed, - repeat_penalty, - repeat_last_n, - max_tokens, - top_k, - top_p, - } - } -} - -#[derive(Serialize)] -pub struct TextModelOutput { - pub text: String, - pub time: f64, - pub tokens_count: usize, -} - -impl Display for TextModelOutput { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Output: {}\nTime: {}\nTokens count: {}", - self.text, self.time, self.tokens_count - ) - } -} - #[derive(Clone, Debug, Deserialize, Serialize)] pub struct TextResponse { pub output: String, diff --git a/atoma-multi-gpu-mock/Cargo.toml b/atoma-multi-gpu-mock/Cargo.toml new file mode 100644 index 00000000..6d18ef62 --- /dev/null +++ b/atoma-multi-gpu-mock/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "atoma-multi-gpu-mock" +version.workspace = true +edition = "2021" + +[dependencies] +atoma-types.workspace = true +candle.workspace = true +candle-flash-attn = { workspace = true, optional = true } +candle-nn.workspace = true +candle-transformers.workspace = true +cudarc = { workspace = true, optional = true } +half = { workspace = true, optional = true } +hf-hub.workspace = true +serde.workspace = true +serde_json.workspace = true 
+thiserror.workspace = true +tokenizers = { workspace = true } +tokio = { workspace = true, features = ["full", "tracing"] } +tungstenite.workspace = true +url.workspace = true + +[dev-dependencies] + +[build-dependencies] +bindgen_cuda = { workspace = true, optional = true } + +[features] +default = [] +cuda = [ + "candle/cuda", + "candle-nn/cuda", + "candle-transformers/cuda", + "dep:bindgen_cuda", +] +cudnn = ["candle/cudnn"] +flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] +nccl = ["cuda", "cudarc/nccl", "dep:half"] diff --git a/atoma-multi-gpu-mock/src/main.rs b/atoma-multi-gpu-mock/src/main.rs new file mode 100644 index 00000000..84805c71 --- /dev/null +++ b/atoma-multi-gpu-mock/src/main.rs @@ -0,0 +1,162 @@ +// An implementation of LLaMA https://github.com/facebookresearch/llama +// +// This is based on nanoGPT in a similar way to: +// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py +// +// The tokenizer config can be retrieved from: +// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use atoma_types::{AtomaChildMessage, AtomaInferenceMessage, TextModelInput, TextModelOutput}; +use candle::DType; +use std::env; +use std::path::PathBuf; +use std::str::FromStr; +use tungstenite::{connect, Message}; +use url::Url; + +mod models; +mod types; +pub use models::*; +pub use types::*; + +struct SingleGpu { + id: Option, + rank: usize, + num_shards: usize, + comm: Option, + model: Option, +} + +impl SingleGpu { + fn new(rank: usize, num_shards: usize) -> Self { + Self { + id: None, + rank, + num_shards, + comm: None, + model: None, + } + } + + // This should be called only on the main gpu. + fn create_main_id(&mut self) -> Result<(), ModelError> { + self.id = Some(42); + self.create_comm()?; + Ok(()) + } + + fn copy_id_from(&mut self, data: Vec) -> Result<(), ModelError> { + self.id = Some(data[0]); + self.create_comm()?; + Ok(()) + } + + fn create_comm(&mut self) -> Result<(), ModelError> { + // Don't mistake nccl device with cuda device. 
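+        // Mock stand-in: use a placeholder value instead of a real NCCL communicator.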
+ self.comm = Some(43); + println!("Rank {} spawned", self.rank); + Ok(()) + } + + fn load_model( + &mut self, + config_filename: PathBuf, + dtype: DType, + filenames: Vec, + tokenizer_filename: PathBuf, + ) -> Result<(), ModelError> { + if self.comm.is_none() { + panic!("Comm not initialized"); + } + if self.model.is_some() { + panic!("Model already loaded"); + } + self.model = Some(llama::Model::load( + config_filename, + dtype, + filenames, + tokenizer_filename, + )?); + Ok(()) + } + + fn inference(&self, input: TextModelInput) -> Result { + self.model.as_ref().unwrap().inference(input) + } +} + +#[tokio::main] +async fn main() -> Result<(), ()> { + let port = env::args().nth(1).expect("Expected a port number"); + let num_shards = env::args() + .nth(2) + .expect("Expected the number of shards") + .parse() + .expect("Expected a number"); + let rank = env::args() + .nth(3) + .expect("Expected the rank") + .parse() + .expect("Expected a number"); + + let url = format!("ws://127.0.0.1:{}/socket", port); + let (mut socket, response) = connect(Url::parse(url.as_str()).unwrap()).expect("Can't connect"); + println!("Connected to the server: {}", response.status()); + let mut gpu_instance = SingleGpu::new(rank, num_shards); + if rank == 0 { + gpu_instance.create_main_id().unwrap(); + } + let id = gpu_instance.id.map(|x| vec![x]); + + socket + .send(Message::Text( + serde_json::to_string(&AtomaChildMessage::Initialized(id)).unwrap(), + )) + .unwrap(); + loop { + let msg = socket.read().unwrap(); + let msg: AtomaInferenceMessage = serde_json::from_str(msg.to_string().as_str()).unwrap(); + match msg { + AtomaInferenceMessage::InitializeComm(data) => { + if gpu_instance.id.is_some() { + panic!("Id already initialized"); + } + gpu_instance.copy_id_from(data).unwrap(); + socket + .send(Message::Text( + serde_json::to_string(&AtomaChildMessage::CommsReady).unwrap(), + )) + .unwrap(); + } + AtomaInferenceMessage::LoadModel( + config_filename, + dtype, + filenames, + tokenizer_filename, + ) => { + let dtype = DType::from_str(dtype.as_str()).unwrap(); + gpu_instance + .load_model(config_filename, dtype, filenames, tokenizer_filename) + .unwrap(); + socket + .send(Message::Text( + serde_json::to_string(&AtomaChildMessage::Loaded).unwrap(), + )) + .unwrap(); + } + AtomaInferenceMessage::Inference(input) => { + let result = gpu_instance.inference(input).unwrap(); + socket + .send(Message::Text( + serde_json::to_string(&AtomaChildMessage::InferenceResult(result)).unwrap(), + )) + .unwrap(); + } + AtomaInferenceMessage::Exit => break, + } + } + Ok(()) +} diff --git a/atoma-multi-gpu-mock/src/models/llama/mod.rs b/atoma-multi-gpu-mock/src/models/llama/mod.rs new file mode 100644 index 00000000..23ecdfb7 --- /dev/null +++ b/atoma-multi-gpu-mock/src/models/llama/mod.rs @@ -0,0 +1,32 @@ +use std::path::PathBuf; + +use atoma_types::{TextModelInput, TextModelOutput}; +use candle::DType; + +use crate::ModelError; + +pub struct Model {} + +impl Model { + pub fn load( + config_filename: PathBuf, + dtype: DType, + filenames: Vec, + tokenizer_filename: PathBuf, + ) -> Result { + println!("Loading model..."); + println!("Config file: {:?}", config_filename); + println!("tokenizer file: {:?}", tokenizer_filename); + println!("filenames: {:?}", filenames); + println!("Dtype: {:?}", dtype); + Ok(Self {}) + } + + pub fn inference(&self, input: TextModelInput) -> Result { + Ok(TextModelOutput { + text: input.prompt, + time: 1.0, + tokens_count: 1, + }) + } +} diff --git a/atoma-multi-gpu-mock/src/models/mod.rs 
b/atoma-multi-gpu-mock/src/models/mod.rs new file mode 100644 index 00000000..93183888 --- /dev/null +++ b/atoma-multi-gpu-mock/src/models/mod.rs @@ -0,0 +1 @@ +pub mod llama; diff --git a/atoma-multi-gpu-mock/src/types.rs b/atoma-multi-gpu-mock/src/types.rs new file mode 100644 index 00000000..57fbc191 --- /dev/null +++ b/atoma-multi-gpu-mock/src/types.rs @@ -0,0 +1,32 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ModelError { + // NcclError(NcclError), + #[error("{0}")] + CandleError(#[from] candle::Error), + #[error("{0}")] + IoError(#[from] std::io::Error), + #[error("{0}")] + SerdeJsonError(#[from] serde_json::Error), + #[error("Error: `{0}`")] + BoxedError(#[from] Box), + // TokenizerError(tokenizers::Error), + #[error("Error: `{0}`")] + Msg(String), + #[error("ApiError error: `{0}`")] + ApiError(#[from] hf_hub::api::sync::ApiError), +} + +#[macro_export] +macro_rules! bail { + ($msg:literal $(,)?) => { + return Err(ModelError::Msg(format!($msg).into())) + }; + ($err:expr $(,)?) => { + return Err(ModelError::Msg(format!($err).into()).bt()) + }; + ($fmt:expr, $($arg:tt)*) => { + return Err(ModelError::Msg(format!($fmt, $($arg)*).into()).bt()) + }; +} diff --git a/atoma-multi-gpu/Cargo.toml b/atoma-multi-gpu/Cargo.toml new file mode 100644 index 00000000..5b6410c9 --- /dev/null +++ b/atoma-multi-gpu/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "atoma-multi-gpu" +version.workspace = true +edition = "2021" + +[dependencies] +atoma-types.workspace = true +candle.workspace = true +candle-flash-attn = { workspace = true, optional = true } +candle-nn.workspace = true +candle-transformers.workspace = true +cudarc = { workspace = true, optional = true } +half = { workspace = true, optional = true } +hf-hub.workspace = true +serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true +tokenizers = { workspace = true } +tokio = { workspace = true, features = ["full", "tracing"] } +tungstenite.workspace = true +url.workspace = true + +[dev-dependencies] + +[build-dependencies] +bindgen_cuda = { workspace = true, optional = true } + +[features] +default = ["cuda", "nccl", "flash-attn"] +cuda = [ + "candle/cuda", + "candle-nn/cuda", + "candle-transformers/cuda", + "dep:bindgen_cuda", +] +cudnn = ["candle/cudnn"] +flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] +nccl = ["cuda", "cudarc/nccl", "dep:half"] diff --git a/atoma-multi-gpu/src/main.rs b/atoma-multi-gpu/src/main.rs new file mode 100644 index 00000000..12eb3ba8 --- /dev/null +++ b/atoma-multi-gpu/src/main.rs @@ -0,0 +1,223 @@ +// An implementation of LLaMA https://github.com/facebookresearch/llama +// +// This is based on nanoGPT in a similar way to: +// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py +// +// The tokenizer config can be retrieved from: +// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use atoma_types::{AtomaChildMessage, AtomaInferenceMessage}; +use candle::{DType, Device}; +use cudarc::driver::safe::CudaDevice; +use cudarc::nccl::safe::{Comm, Id}; +use std::env; +use std::path::PathBuf; +use std::rc::Rc; +use std::str::FromStr; +use tungstenite::{connect, Message}; +use url::Url; + +mod models; +mod token_output_stream; +mod types; +pub use models::*; +pub use types::*; + +#[derive(Clone, Debug, Copy, PartialEq, Eq)] +enum Which { + V2_7b, + V2_70b, + V3_8b, + V3_70b, +} + +#[derive(Debug)] +struct Args { + num_shards: usize, + rank: 
usize, + temperature: f64, + top_p: Option, + seed: u64, + sample_len: usize, + no_kv_cache: bool, + prompt: Option, + model_id: Option, + revision: Option, + dtype: Option, + which: Which, + comm_file: String, +} + +pub fn hub_load_safetensors( + repo: &hf_hub::api::sync::ApiRepo, + json_file: &str, +) -> Result, ModelError> { + let json_file = repo.get(json_file)?; + let json_file = std::fs::File::open(json_file)?; + let json: serde_json::Value = serde_json::from_reader(&json_file)?; + let weight_map = match json.get("weight_map") { + None => bail!("no weight map in {json_file:?}"), + Some(serde_json::Value::Object(map)) => map, + Some(_) => bail!("weight map in {json_file:?} is not a map"), + }; + let mut safetensors_files = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_files.insert(file.to_string()); + } + } + let safetensors_files = safetensors_files + .iter() + .map(|v| repo.get(v).map_err(candle::Error::wrap)) + .collect::>>()?; + Ok(safetensors_files) +} + +struct SingleGpu { + id: Option, + rank: usize, + num_shards: usize, + comm: Option>, + model: Option, +} + +impl SingleGpu { + fn new(rank: usize, num_shards: usize) -> Self { + Self { + id: None, + rank, + num_shards, + comm: None, + model: None, + } + } + + // This should be called only on the main gpu. + fn create_main_id(&mut self) -> Result<(), ModelError> { + self.id = Some(Id::new().map_err(|err| ModelError::NcclError(err))?); + self.create_comm()?; + Ok(()) + } + + fn copy_id_from(&mut self, data: Vec) -> Result<(), ModelError> { + self.id = Some(Id::uninit( + data.into_iter() + .map(|i| i as i8) + .collect::>() + .try_into() + .unwrap(), + )); + self.create_comm()?; + Ok(()) + } + + fn create_comm(&mut self) -> Result<(), ModelError> { + // Don't mistake nccl device with cuda device. 
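+        // CudaDevice::new takes the CUDA device ordinal; reusing the NCCL rank here assumes one process per GPU on a single host.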
+        let device = CudaDevice::new(self.rank).unwrap();
+        self.comm = Some(
+            match Comm::from_rank(device, self.rank, self.num_shards, self.id.unwrap()) {
+                Ok(comm) => Rc::new(comm),
+                Err(err) => panic!("nccl error {:?}", err.0),
+            },
+        );
+        println!("Rank {} spawned", self.rank);
+        Ok(())
+    }
+
+    fn load_model(
+        &mut self,
+        config_filename: PathBuf,
+        dtype: DType,
+        filenames: Vec<PathBuf>,
+        tokenizer_filename: PathBuf,
+    ) -> Result<(), ModelError> {
+        if self.comm.is_none() {
+            panic!("Comm not initialized");
+        }
+        if self.model.is_some() {
+            panic!("Model already loaded");
+        }
+        self.model = Some(llama::Model::load(
+            config_filename,
+            Device::new_cuda(self.rank).unwrap(),
+            dtype,
+            filenames,
+            tokenizer_filename,
+            Rc::clone(self.comm.as_ref().unwrap()),
+        )?);
+        Ok(())
+    }
+
+    fn inference(&self, input: TextModelInput) -> Result<TextModelOutput, ModelError> {
+        self.model.as_ref().unwrap().inference(input)
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), ()> {
+    let port = env::args().nth(1).expect("Expected a port number");
+    let num_shards = env::args()
+        .nth(2)
+        .expect("Expected the number of shards")
+        .parse()
+        .expect("Expected a number");
+    let comm_file = env::args().nth(3).expect("Expected a comm file");
+    let rank = env::args()
+        .nth(4)
+        .expect("Expected the rank")
+        .parse()
+        .expect("Expected a number");
+
+    let url = format!("ws://127.0.0.1:{}/socket", port);
+    let (mut socket, response) = connect(Url::parse(url.as_str()).unwrap()).expect("Can't connect");
+    println!("Connected to the server: {}", response.status());
+    let mut gpu_instance = SingleGpu::new(rank, num_shards);
+    if rank == 0 {
+        gpu_instance.create_main_id().unwrap();
+    }
+    let id = gpu_instance
+        .id
+        .map(|id| id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>());
+
+    socket
+        .send(Message::Text(
+            serde_json::to_string(&AtomaChildMessage::Initialized(id)).unwrap(),
+        ))
+        .unwrap();
+    loop {
+        let msg = socket.read().unwrap();
+        let msg: AtomaInferenceMessage = serde_json::from_str(msg.to_string().as_str()).unwrap();
+        match msg {
+            AtomaInferenceMessage::InitializeComm(data) => {
+                if gpu_instance.id.is_some() {
+                    panic!("Id already initialized");
+                }
+                gpu_instance.copy_id_from(data).unwrap();
+                // Acknowledge that the NCCL communicator is ready so the parent can send LoadModel.
+                socket
+                    .send(Message::Text(
+                        serde_json::to_string(&AtomaChildMessage::CommsReady).unwrap(),
+                    ))
+                    .unwrap();
+            }
+            AtomaInferenceMessage::LoadModel(
+                config_filename,
+                dtype,
+                filenames,
+                tokenizer_filename,
+            ) => {
+                let dtype = DType::from_str(dtype.as_str()).unwrap();
+                gpu_instance
+                    .load_model(config_filename, dtype, filenames, tokenizer_filename)
+                    .unwrap();
+                // Report back so the parent can release its load barrier.
+                socket
+                    .send(Message::Text(
+                        serde_json::to_string(&AtomaChildMessage::Loaded).unwrap(),
+                    ))
+                    .unwrap();
+            }
+            AtomaInferenceMessage::Inference(input) => {
+                let result = gpu_instance.inference(input).unwrap();
+                socket
+                    .send(Message::Text(
+                        serde_json::to_string(&AtomaChildMessage::InferenceResult(result)).unwrap(),
+                    ))
+                    .unwrap();
+            }
+            AtomaInferenceMessage::Exit => break,
+        }
+    }
+    Ok(())
+}
diff --git a/atoma-multi-gpu/src/models/llama/mod.rs b/atoma-multi-gpu/src/models/llama/mod.rs
new file mode 100644
index 00000000..813a011c
--- /dev/null
+++ b/atoma-multi-gpu/src/models/llama/mod.rs
@@ -0,0 +1,109 @@
+mod model;
+
+use std::{path::PathBuf, rc::Rc};
+
+use candle::{DType, Device, Tensor};
+use candle_transformers::{generation::LogitsProcessor, models::llama::LlamaConfig};
+use cudarc::nccl::Comm;
+use tokenizers::Tokenizer;
+
+use crate::{token_output_stream::TokenOutputStream, ModelError, TextModelInput, TextModelOutput};
+
+const MAX_SEQ_LEN: usize = 4096;
+
+pub struct Model {
+    config: LlamaConfig,
+    tokenizer: Tokenizer,
+    model: model::Llama,
+    device: Device,
+}
+
+impl Model {
+    // fn load_model(dtype: DType, model_type: ModelType) -> Result<Vec<PathBuf>,
ModelError> { + // let api = Api::new()?; + // let model_id = model_type.repo(); + // println!("loading the model weights from {model_id}"); + // let revision = model_type.default_revision(); + // let api = api.repo(Repo::with_revision( + // model_id.to_string(), + // RepoType::Model, + // revision.to_string(), + // )); + // let config_filename = api.get("config.json")?; + // let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; + // let tokenizer_filename = api.get("tokenizer.json")?; + // hub_load_safetensors(&api, "model.safetensors.index.json") + // } + pub fn load( + config_filename: PathBuf, + device: Device, + dtype: DType, + filenames: Vec, + tokenizer_filename: PathBuf, + comm: Rc, + ) -> Result { + let config = serde_json::from_slice(&std::fs::read(config_filename)?)?; + let cache = model::Cache::new(dtype, &config, &device)?; + let vb = unsafe { + candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)? + }; + let model = model::Llama::load(vb, &cache, &config, comm).unwrap(); + let tokenizer = Tokenizer::from_file(tokenizer_filename)?; + + Ok(Self { + config, + tokenizer, + model, + device, + }) + } + + pub fn inference(&self, input: TextModelInput) -> Result { + let temperature = if input.temperature <= 0. { + None + } else { + Some(input.temperature) + }; + + let prompt = input.prompt; + let mut tokens = self.tokenizer.encode(prompt, true)?.get_ids().to_vec(); + + let mut tokenizer: TokenOutputStream = TokenOutputStream::new(self.tokenizer.clone()); + let mut logits_processor = + LogitsProcessor::new(input.random_seed, temperature, input.top_p); + let mut new_tokens = vec![]; + let mut start_gen = std::time::Instant::now(); + let mut index_pos = 0; + let mut res = String::new(); + + for index in 0..input.max_tokens { + // Only start timing at the second token as processing the first token waits for all the + // weights to be loaded in an async way. + if index == 1 { + start_gen = std::time::Instant::now() + }; + let context_size = if index > 0 { 1 } else { tokens.len() }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = self.model.forward(&input, index_pos)?; + let logits = logits.squeeze(0)?; + index_pos += ctxt.len(); + + let next_token = logits_processor.sample(&logits)?; + tokens.push(next_token); + new_tokens.push(next_token); + if Some(next_token) == self.config.eos_token_id { + break; + } + if let Some(t) = tokenizer.next_token(next_token)? 
{ + res += &t; + } + } + let dt = start_gen.elapsed(); + Ok(TextModelOutput { + text: res, + time: dt.as_secs_f64(), + tokens_count: tokenizer.get_num_generated_tokens(), + }) + } +} diff --git a/atoma-multi-gpu/src/models/llama/model.rs b/atoma-multi-gpu/src/models/llama/model.rs new file mode 100644 index 00000000..dde7908b --- /dev/null +++ b/atoma-multi-gpu/src/models/llama/model.rs @@ -0,0 +1,411 @@ +use candle::backend::BackendStorage; +use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D}; +use candle_nn::var_builder::ShardedVarBuilder as VarBuilder; +use candle_nn::{Embedding, Linear, Module, RmsNorm}; +use cudarc::nccl::safe::{Comm, ReduceOp}; +use std::rc::Rc; +use std::sync::{Arc, Mutex}; + +use super::MAX_SEQ_LEN; + +pub type Config = candle_transformers::models::llama::LlamaConfig; + +struct TensorParallelColumnLinear { + linear: Linear, +} + +impl TensorParallelColumnLinear { + fn new(linear: Linear) -> Self { + Self { linear } + } + fn forward(&self, x: &Tensor) -> Result { + self.linear.forward(x) + } +} + +struct TensorParallelRowLinear { + linear: Linear, + all_reduce: AllReduce, +} + +struct AllReduce { + comm: Rc, +} + +/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html +/// But for this example purposes, this will work +unsafe impl Sync for AllReduce {} +unsafe impl Send for AllReduce {} + +impl CustomOp1 for AllReduce { + fn name(&self) -> &'static str { + "allreduce" + } + + fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> { + candle::bail!("AllReduce is never used on cpu") + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + storage: &candle::CudaStorage, + layout: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::cuda_backend::WrapErr; + use cudarc::driver::DeviceSlice; + use half::{bf16, f16}; + + let elem_count = layout.shape().elem_count(); + let device = storage.device().clone(); + let dst = match storage.dtype() { + DType::BF16 => { + let slice = storage.as_cuda_slice::()?; + let slice = match layout.contiguous_offsets() { + Some((0, l)) if l == slice.len() => slice, + Some(_) | None => candle::bail!("input has to be contiguous"), + }; + let mut dst = unsafe { device.alloc::(elem_count) }.w()?; + self.comm + .all_reduce(slice, &mut dst, &ReduceOp::Sum) + .map_err(candle::Error::debug)?; + candle::CudaStorage::wrap_cuda_slice(dst, device) + } + DType::F16 => { + let s = storage.as_cuda_slice::()?; + let s = match layout.contiguous_offsets() { + Some((0, l)) if l == s.len() => s, + Some(_) | None => candle::bail!("input has to be contiguous"), + }; + let mut dst = unsafe { device.alloc::(elem_count) }.w()?; + self.comm + .all_reduce(s, &mut dst, &ReduceOp::Sum) + .map_err(candle::Error::debug)?; + candle::CudaStorage::wrap_cuda_slice(dst, device) + } + dtype => candle::bail!("unsupported dtype {dtype:?}"), + }; + Ok((dst, layout.shape().clone())) + } +} + +impl TensorParallelRowLinear { + fn new(linear: Linear, comm: Rc) -> Self { + let all_reduce = AllReduce { comm }; + Self { linear, all_reduce } + } + fn forward(&self, x: &Tensor) -> Result { + self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce) + } +} + +fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard { + candle_nn::var_builder::Shard { + dim, + rank, + world_size, + } +} + +impl TensorParallelColumnLinear { + fn load(vb: VarBuilder, comm: Rc) -> Result { + let rank = comm.rank(); + let size = 
comm.world_size(); + let weight = vb.get_with_hints((), "weight", shard(0, rank, size))?; + Ok(Self::new(Linear::new(weight, None))) + } + + fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + let weights: Vec<_> = prefixes + .iter() + .map(|p| vb.pp(p).get_with_hints((), "weight", shard(0, rank, size))) + .collect::>>()?; + let weight = Tensor::cat(&weights, 0)?; + Ok(Self::new(Linear::new(weight, None))) + } +} + +impl TensorParallelRowLinear { + fn load(vb: VarBuilder, comm: Rc) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + let weight = vb.get_with_hints((), "weight", shard(1, rank, size))?; + Ok(Self::new(Linear::new(weight, None), comm)) + } +} + +#[derive(Clone)] +pub struct Cache { + #[allow(clippy::type_complexity)] + kvs: Arc>>>, + cos: Tensor, + sin: Tensor, +} + +impl Cache { + pub fn new(dtype: DType, config: &Config, device: &Device) -> Result { + // precompute freqs_cis + let n_elem = config.hidden_size / config.num_attention_heads; + let theta: Vec<_> = (0..n_elem) + .step_by(2) + .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + // This is different from the paper, see: + // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 + let cos = idx_theta.cos()?.to_dtype(dtype)?; + let sin = idx_theta.sin()?.to_dtype(dtype)?; + Ok(Self { + kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])), + cos, + sin, + }) + } +} + +fn silu(xs: &Tensor) -> Result { + xs / (xs.neg()?.exp()? + 1.0)? +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { + let weight = vb.get((size2, size1), "weight")?; + Ok(Linear::new(weight, None)) +} + +fn embedding(cfg: &Config, vb: VarBuilder) -> Result { + let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; + Ok(Embedding::new(embeddings, cfg.hidden_size)) +} + +struct CausalSelfAttention { + qkv_proj: TensorParallelColumnLinear, + o_proj: TensorParallelRowLinear, + num_attention_heads: usize, + num_key_value_heads: usize, + head_dim: usize, + cache: Cache, +} + +impl CausalSelfAttention { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result { + let (_b_sz, _, seq_len, _hidden_size) = x.shape().dims4()?; + let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; + let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; + candle_nn::rotary_emb::rope(x, &cos, &sin) + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result { + let (b_sz, seq_len, _) = x.shape().dims3()?; + + let qkv = self.qkv_proj.forward(x)?; + let hidden_size = self.num_attention_heads * self.head_dim; + + let q = qkv.i((.., .., ..self.num_attention_heads * self.head_dim))?; + let k = qkv.i(( + .., + .., + self.num_attention_heads * self.head_dim + ..self.num_attention_heads * self.head_dim + + self.num_key_value_heads * self.head_dim, + ))?; + let v = qkv.i(( + .., + .., + self.num_attention_heads * self.head_dim + self.num_key_value_heads * self.head_dim.., + ))?; + // todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape()); + + let q = q + .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? 
+ .transpose(1, 2)? + .contiguous()?; + let k = k + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let mut v = v + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let mut k = self.apply_rotary_emb(&k, index_pos)?; + + let mut cache = self.cache.kvs.lock().unwrap(); + if let Some((cache_k, cache_v)) = &cache[block_idx] { + k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?; + let k_seq_len = k.dims()[1]; + if k_seq_len > MAX_SEQ_LEN { + k = k + .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + let v_seq_len = v.dims()[1]; + if v_seq_len > 2 * MAX_SEQ_LEN { + v = v + .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + } + cache[block_idx] = Some((k.clone(), v.clone())); + + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); + let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)? + .reshape((b_sz, seq_len, hidden_size))?; + let y = self.o_proj.forward(&y)?; + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result { + let n_rep = self.num_attention_heads / self.num_key_value_heads; + candle_transformers::utils::repeat_kv(x, n_rep) + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { + let qkv_proj = TensorParallelColumnLinear::load_multi( + vb.clone(), + &["q_proj", "k_proj", "v_proj"], + comm.clone(), + )?; + let o_proj = TensorParallelRowLinear::load(vb.pp("o_proj"), comm.clone())?; + Ok(Self { + qkv_proj, + o_proj, + num_attention_heads: cfg.num_attention_heads / comm.world_size(), + num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(), + head_dim: cfg.hidden_size / cfg.num_attention_heads, + cache: cache.clone(), + }) + } +} + +struct Mlp { + c_fc1: TensorParallelColumnLinear, + c_fc2: TensorParallelColumnLinear, + c_proj: TensorParallelRowLinear, +} + +impl Mlp { + fn forward(&self, x: &Tensor) -> Result { + let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; + dbg!(x.shape()); + self.c_proj.forward(&x) + } + + fn load(vb: VarBuilder, _cfg: &Config, comm: Rc) -> Result { + let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?; + let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?; + let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?; + Ok(Self { + c_fc1, + c_fc2, + c_proj, + }) + } +} + +struct Block { + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, +} + +fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get_with_hints(size, "weight", shard(0, 0, 1))?; + Ok(RmsNorm::new(weight, eps)) +} + +impl Block { + fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { + Self { + rms_1, + attn, + rms_2, + mlp, + } + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result { + let residual = x; + let x = self.rms_1.forward(x)?; + let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; + let residual = &x; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? 
+ residual)?;
+        Ok(x)
+    }
+
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
+        let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
+        let input_layernorm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("input_layernorm"))?;
+        let post_attention_layernorm =
+            rms_norm(cfg.hidden_size, 1e-5, vb.pp("post_attention_layernorm"))?;
+        Ok(Self::new(
+            input_layernorm,
+            attn,
+            post_attention_layernorm,
+            mlp,
+        ))
+    }
+}
+
+pub struct Llama {
+    wte: Embedding,
+    blocks: Vec<Block>,
+    ln_f: RmsNorm,
+    lm_head: Linear,
+}
+
+impl Llama {
+    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+        Self {
+            wte,
+            blocks,
+            ln_f,
+            lm_head,
+        }
+    }
+
+    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let (_b_sz, seq_len) = x.shape().dims2()?;
+        let mut x = self.wte.forward(x)?;
+        for (block_idx, block) in self.blocks.iter().enumerate() {
+            x = block.forward(&x, index_pos, block_idx)?;
+        }
+        let x = self.ln_f.forward(&x)?;
+        let x = x.i((.., seq_len - 1, ..))?;
+        let logits = self.lm_head.forward(&x)?;
+        dbg!(logits.shape());
+        logits.to_dtype(DType::F32)
+    }
+
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+        let norm: RmsNorm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?;
+        let blocks: Vec<_> = (0..cfg.num_hidden_layers)
+            .map(|i| {
+                Block::load(
+                    vb.pp(&format!("model.layers.{i}")),
+                    cache,
+                    cfg,
+                    comm.clone(),
+                )
+            })
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Self::new(wte, blocks, norm, lm_head))
+    }
+}
diff --git a/atoma-multi-gpu/src/models/mod.rs b/atoma-multi-gpu/src/models/mod.rs
new file mode 100644
index 00000000..93183888
--- /dev/null
+++ b/atoma-multi-gpu/src/models/mod.rs
@@ -0,0 +1 @@
+pub mod llama;
diff --git a/atoma-multi-gpu/src/token_output_stream.rs b/atoma-multi-gpu/src/token_output_stream.rs
new file mode 100644
index 00000000..49bdb5dc
--- /dev/null
+++ b/atoma-multi-gpu/src/token_output_stream.rs
@@ -0,0 +1,90 @@
+use crate::{bail, ModelError};
+
+/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
+/// streaming way rather than having to wait for the full decoding.
+pub struct TokenOutputStream {
+    tokenizer: tokenizers::Tokenizer,
+    tokens: Vec<u32>,
+    prev_index: usize,
+    current_index: usize,
+}
+
+impl TokenOutputStream {
+    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
+        Self {
+            tokenizer,
+            tokens: Vec::new(),
+            prev_index: 0,
+            current_index: 0,
+        }
+    }
+
+    pub fn into_inner(self) -> tokenizers::Tokenizer {
+        self.tokenizer
+    }
+
+    fn decode(&self, tokens: &[u32]) -> Result<String, ModelError> {
+        match self.tokenizer.decode(tokens, true) {
+            Ok(str) => Ok(str),
+            Err(err) => bail!("cannot decode: {err}"),
+        }
+    }
+
+    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
+    pub fn next_token(&mut self, token: u32) -> Result<Option<String>, ModelError> {
+        let prev_text = if self.tokens.is_empty() {
+            String::new()
+        } else {
+            let tokens = &self.tokens[self.prev_index..self.current_index];
+            self.decode(tokens)?
+        };
+        self.tokens.push(token);
+        let text = self.decode(&self.tokens[self.prev_index..])?;
+        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
+            let text = text.split_at(prev_text.len());
+            self.prev_index = self.current_index;
+            self.current_index = self.tokens.len();
+            Ok(Some(text.1.to_string()))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub fn decode_rest(&self) -> Result<Option<String>, ModelError> {
+        let prev_text = if self.tokens.is_empty() {
+            String::new()
+        } else {
+            let tokens = &self.tokens[self.prev_index..self.current_index];
+            self.decode(tokens)?
+        };
+        let text = self.decode(&self.tokens[self.prev_index..])?;
+        if text.len() > prev_text.len() {
+            let text = text.split_at(prev_text.len());
+            Ok(Some(text.1.to_string()))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub fn decode_all(&self) -> Result<String, ModelError> {
+        self.decode(&self.tokens)
+    }
+
+    pub fn get_token(&self, token_s: &str) -> Option<u32> {
+        self.tokenizer.get_vocab(true).get(token_s).copied()
+    }
+
+    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
+        &self.tokenizer
+    }
+
+    pub fn get_num_generated_tokens(&self) -> usize {
+        self.tokens.len()
+    }
+
+    pub fn clear(&mut self) {
+        self.tokens.clear();
+        self.prev_index = 0;
+        self.current_index = 0;
+    }
+}
diff --git a/atoma-multi-gpu/src/types.rs b/atoma-multi-gpu/src/types.rs
new file mode 100644
index 00000000..3d970987
--- /dev/null
+++ b/atoma-multi-gpu/src/types.rs
@@ -0,0 +1,54 @@
+use cudarc::nccl::result::NcclError;
+use thiserror::Error;
+
+// Re-export the shared input/output types so they are the same types carried by the
+// AtomaChildMessage/AtomaInferenceMessage protocol in atoma-types.
+pub use atoma_types::{TextModelInput, TextModelOutput};
+
+#[derive(Debug, Error)]
+pub enum ModelError {
+    // NcclError(NcclError),
+    #[error("{0}")]
+    CandleError(#[from] candle::Error),
+    #[error("{0}")]
+    IoError(#[from] std::io::Error),
+    #[error("{0}")]
+    SerdeJsonError(#[from] serde_json::Error),
+    #[error("Error: `{0}`")]
+    BoxedError(#[from] Box<dyn std::error::Error + Send + Sync>),
+    // TokenizerError(tokenizers::Error),
+    #[error("Error: `{0}`")]
+    Msg(String),
+    #[error("ApiError error: `{0}`")]
+    ApiError(#[from] hf_hub::api::sync::ApiError),
+    #[error("NcclError error: `{0:?}`")]
+    NcclError(NcclError),
+}
+
+#[macro_export]
+macro_rules! bail {
+    ($msg:literal $(,)?) => {
+        return Err(ModelError::Msg(format!($msg).into()))
+    };
+    ($err:expr $(,)?) => {
+        return Err(ModelError::Msg(format!($err).into()))
+    };
+    ($fmt:expr, $($arg:tt)*) => {
+        return Err(ModelError::Msg(format!($fmt, $($arg)*).into()))
+    };
+}
diff --git a/atoma-types/src/lib.rs b/atoma-types/src/lib.rs
index 0dc63279..c4e34d7b 100644
--- a/atoma-types/src/lib.rs
+++ b/atoma-types/src/lib.rs
@@ -1,5 +1,6 @@
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use serde_json::Value;
+use std::{fmt::Display, path::PathBuf};
 
 #[derive(Clone, Debug, Deserialize)]
 pub struct Request {
@@ -64,3 +65,73 @@ impl Response {
         self.response.clone()
     }
 }
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct TextModelInput {
+    pub prompt: String,
+    pub temperature: f64,
+    pub random_seed: u64,
+    pub repeat_penalty: f32,
+    pub repeat_last_n: usize,
+    pub max_tokens: usize,
+    pub top_k: Option<usize>,
+    pub top_p: Option<f64>,
+}
+
+impl TextModelInput {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        prompt: String,
+        temperature: f64,
+        random_seed: u64,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+        max_tokens: usize,
+        top_k: Option<usize>,
+        top_p: Option<f64>,
+    ) -> Self {
+        Self {
+            prompt,
+            temperature,
+            random_seed,
+            repeat_penalty,
+            repeat_last_n,
+            max_tokens,
+            top_k,
+            top_p,
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct TextModelOutput {
+    pub text: String,
+    pub time: f64,
+    pub tokens_count: usize,
+}
+
+impl Display for TextModelOutput {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "Output: {}\nTime: {}\nTokens count: {}",
+            self.text, self.time, self.tokens_count
+        )
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub enum AtomaChildMessage {
+    Initialized(Option<Vec<u8>>),
+    CommsReady,
+    Loaded,
+    InferenceResult(TextModelOutput),
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub enum AtomaInferenceMessage {
+    InitializeComm(Vec<u8>),
+    LoadModel(PathBuf, String, Vec<PathBuf>, PathBuf),
+    Inference(TextModelInput),
+    Exit,
+}