Commit: multi-gpu
Cifko committed Apr 28, 2024
1 parent c72c8ab commit 525e6b6
Showing 25 changed files with 1,503 additions and 64 deletions.
19 changes: 16 additions & 3 deletions Cargo.toml
@@ -9,10 +9,11 @@ members = [
"atoma-event-subscribe/sui",
"atoma-inference",
"atoma-networking",
"atoma-multi-gpu-mock",
"atoma-node",
"atoma-json-rpc",
"atoma-storage",
"atoma-types"
"atoma-types",
]

[workspace.package]
@@ -29,18 +30,28 @@ atoma-sui = { path = "./atoma-event-subscribe/sui/" }
atoma-types = { path = "./atoma-types" }
axum = "0.7.5"
blake2 = "0.10.6"
bindgen_cuda = { version = "0.1.1" }
candle = { git = "https://github.com/huggingface/candle", package = "candle-core", branch = "main" }
candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", branch = "main" }
candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", branch = "main" }
candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", branch = "main" }
clap = "4.5.4"
config = "0.14.0"
cudarc = { version = "0.10.0", features = ["f16"] }
dotenv = "0.15.0"
ethers = "2.0.14"
futures = "0.3.30"
futures-util = "0.3.30"
half = { version = "2.3.1", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
hf-hub = "0.3.2"
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
image = { version = "0.25.0", default-features = false, features = [
"jpeg",
"png",
] }
rand = "0.8.5"
rayon = "1.10.0"
reqwest = "0.12.1"
@@ -50,10 +61,12 @@ serde_json = "1.0.114"
# solana-client = "1.18.9"
# solana-sdk = "1.18.8"
sui-keys = { git = "https://github.com/mystenlabs/sui", package = "sui-keys" }
sui-sdk = { git = "https://github.com/mystenlabs/sui", package = "sui-sdk"}
sui-sdk = { git = "https://github.com/mystenlabs/sui", package = "sui-sdk" }
thiserror = "1.0.58"
tokenizers = "0.15.2"
tokio = "1.36.0"
toml = "0.8.12"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
tungstenite = "0.21.0"
url = { version = "2.1.0" }
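The new workspace dependencies above back the multi-GPU path introduced in this commit: tungstenite carries the WebSocket protocol between the coordinating process and the per-GPU workers (see llama_multi_gpu.rs below), while bindgen_cuda, cudarc (with its "f16" feature) and half bring in the CUDA and half-precision plumbing. Note that the dtype travels through the protocol as a plain string (config.dtype()), so a worker has to map it onto a candle DType before loading weights. A minimal sketch of that mapping follows, assuming the usual "f16"/"bf16"/"f32" spellings; the workspace may already ship its own helper for this.

```rust
// Minimal sketch, not the crate's actual helper: map the dtype string that
// the config carries onto a candle DType. The accepted spellings are an
// assumption based on common candle usage.
use candle::DType;

fn parse_dtype(dtype: &str) -> Option<DType> {
    match dtype {
        "f16" => Some(DType::F16),
        "bf16" => Some(DType::BF16),
        "f32" => Some(DType::F32),
        _ => None,
    }
}
```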
1 change: 1 addition & 0 deletions atoma-inference/Cargo.toml
@@ -23,6 +23,7 @@ tokenizers = { workspace = true }
tokio = { workspace = true, features = ["full", "tracing"] }
tracing.workspace = true
tracing-subscriber.workspace = true
tungstenite.workspace = true

[dev-dependencies]
rand.workspace = true
3 changes: 2 additions & 1 deletion atoma-inference/src/models/candle/falcon.rs
@@ -1,5 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::{TextModelInput, TextModelOutput};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -14,7 +15,7 @@ use tracing::{debug, error, info};
use crate::models::{
candle::hub_load_safetensors,
config::ModelConfig,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, ModelType},
ModelError, ModelTrait,
};

3 changes: 2 additions & 1 deletion atoma-inference/src/models/candle/llama.rs
@@ -1,5 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::{TextModelInput, TextModelOutput};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -15,7 +16,7 @@ use tracing::info;
use crate::models::{
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, ModelType},
ModelError, ModelTrait,
};

201 changes: 201 additions & 0 deletions atoma-inference/src/models/candle/llama_multi_gpu.rs
@@ -0,0 +1,201 @@
use std::{
net::TcpListener,
path::PathBuf,
rc::Rc,
str::FromStr,
sync::{Arc, Barrier, Mutex},
thread,
time::Instant,
};

use atoma_types::{AtomaChildMessage, AtomaInferenceMessage, TextModelInput, TextModelOutput};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
generation::LogitsProcessor,
models::llama::{Config, LlamaConfig},
};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};

use candle_transformers::models::llama as model;
use tokenizers::Tokenizer;
use tracing::info;
use tungstenite::accept;

use crate::models::{
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType},
ModelError, ModelTrait,
};

use super::hub_load_safetensors;

pub struct MultiGpuLlamaModel {
model_type: ModelType,
gpu_instances: Arc<Mutex<Vec<Arc<tungstenite::WebSocket<std::net::TcpStream>>>>>,
sync_barrier: Arc<Barrier>,
result: Arc<Mutex<Option<TextModelOutput>>>,
}

pub struct MultiGpuLlamaLoadData {
file_paths: Vec<PathBuf>,
dtype: String,
model_type: ModelType,
}

impl ModelTrait for MultiGpuLlamaModel {
type Input = TextModelInput;
type Output = TextModelOutput;
type LoadData = MultiGpuLlamaLoadData;

fn fetch(
api_key: String,
cache_dir: PathBuf,
config: ModelConfig,
) -> Result<Self::LoadData, ModelError> {
let api = ApiBuilder::new()
.with_progress(true)
.with_token(Some(api_key))
.with_cache_dir(cache_dir)
.build()?;

let model_type = ModelType::from_str(&config.model_id())?;
let repo_id = model_type.repo().to_string();
let revision = model_type.default_revision().to_string();

let repo = api.repo(Repo::with_revision(
repo_id.clone(),
RepoType::Model,
revision,
));
let config_file_path = repo.get("config.json")?;
let tokenizer_file_path = repo.get("tokenizer.json")?;

let model_weights_file_paths = if &repo_id == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" {
vec![repo.get("model.safetensors")?]
} else {
hub_load_safetensors(&repo, "model.safetensors.index.json")?
};

let mut file_paths = Vec::with_capacity(2 + model_weights_file_paths.len());
file_paths.extend(vec![config_file_path, tokenizer_file_path]);
file_paths.extend(model_weights_file_paths);

Ok(Self::LoadData {
file_paths,
model_type: ModelType::from_str(&config.model_id())?,
dtype: config.dtype(), // use_flash_attention: config.use_flash_attention(),
})
}

fn model_type(&self) -> ModelType {
self.model_type.clone()
}

fn load(load_data: Self::LoadData) -> Result<Self, ModelError> {
let server = TcpListener::bind("127.0.0.1:0").unwrap();
let port = server.local_addr().unwrap().port();
println!("Server listening on port {}", port);
let num_shards = 2;
let mut num_connections = 0;
let gpu_instances = Arc::new(Mutex::new(Vec::with_capacity(num_shards)));
let init_barrier = Arc::new(Barrier::new(num_shards));
let sync_barrier = Arc::new(Barrier::new(num_shards + 1));
let result = Arc::new(Mutex::new(None));
for stream in server.incoming() {
let gpu_instances = Arc::clone(&gpu_instances);
let init_barrier = Arc::clone(&init_barrier);
let sync_barrier = Arc::clone(&sync_barrier);
let config_file_path = load_data.file_paths[0].clone();
let tokenizer_filename = load_data.file_paths[1].clone();
let filenames = load_data.file_paths[2..].to_vec();
let dtype = load_data.dtype.clone();
let result = Arc::clone(&result);
thread::spawn(move || {
let mut websocket = Arc::new(accept(stream.unwrap()).unwrap());
let index = {
let mut gpu_instances = gpu_instances.lock().unwrap();
gpu_instances.push(Arc::clone(&websocket));
gpu_instances.len() - 1
};
loop {
init_barrier.wait();
let msg = Arc::get_mut(&mut websocket).unwrap().read().unwrap();
if msg.is_text() {
let message: AtomaChildMessage =
serde_json::from_str(msg.to_string().as_str()).unwrap();
match message {
AtomaChildMessage::Initialized(nccl_id) => {
if let Some(nccl_id) = nccl_id {
let mut gpu_instances = gpu_instances.lock().unwrap();
for (i, websocket) in gpu_instances.iter_mut().enumerate() {
if i != index {
Arc::get_mut(websocket)
.unwrap()
.send(tungstenite::Message::Text(
serde_json::to_string(
&AtomaInferenceMessage::InitializeComm(
nccl_id.clone(),
),
)
.unwrap(),
))
.unwrap();
}
}
}
}
AtomaChildMessage::CommsReady => {
Arc::get_mut(&mut websocket)
.unwrap()
.send(tungstenite::Message::Text(
serde_json::to_string(&AtomaInferenceMessage::LoadModel(
config_file_path.clone(),
dtype.clone(),
filenames.clone(),
tokenizer_filename.clone(),
))
.unwrap(),
))
.unwrap();
}
AtomaChildMessage::Loaded => {
sync_barrier.wait();
}
AtomaChildMessage::InferenceResult(output) => {
*result.lock().unwrap() = Some(output);
sync_barrier.wait();
}
}
}
}
});
num_connections += 1;
if num_connections == num_shards {
break;
}
}
sync_barrier.wait();
Ok(MultiGpuLlamaModel {
model_type: load_data.model_type,
gpu_instances,
sync_barrier,
result,
})
}

fn run(&mut self, input: Self::Input) -> Result<Self::Output, ModelError> {
for websocket in self.gpu_instances.lock().unwrap().iter_mut() {
Arc::get_mut(websocket)
.unwrap()
.send(tungstenite::Message::Text(
serde_json::to_string(&AtomaInferenceMessage::Inference(input.clone()))
.unwrap(),
))
.unwrap();
}
self.sync_barrier.wait();
Ok(self.result.lock().unwrap().take().unwrap())
}
}
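The new llama_multi_gpu.rs above shows only the coordinator side of the multi-GPU setup: it binds an ephemeral TCP port, accepts one WebSocket connection per shard, rebroadcasts the NCCL id announced by the first worker that reports one, then drives LoadModel and Inference over the sockets and collects the single InferenceResult behind a barrier. The worker side lives in the new atoma-multi-gpu-mock crate, which is not included in this view. Below is a hedged sketch of what the counterpart worker loop could look like; the AtomaChildMessage and AtomaInferenceMessage definitions are reconstructed from how this file uses them (their real definitions live in atoma-types), and the NCCL id and input/output payload types are stand-ins.

```rust
// Sketch of the per-GPU worker loop the coordinator above expects. The enums
// mirror the variants used in llama_multi_gpu.rs; their payload types are
// assumptions, since atoma-types is not part of this diff.
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use tungstenite::{connect, Message};

#[derive(Serialize)]
enum AtomaChildMessage {
    Initialized(Option<Vec<u8>>), // NCCL unique id; only rank 0 sends Some (payload type assumed)
    CommsReady,
    Loaded,
    InferenceResult(String), // stand-in for TextModelOutput
}

#[derive(Deserialize)]
enum AtomaInferenceMessage {
    InitializeComm(Vec<u8>),
    LoadModel(PathBuf, String, Vec<PathBuf>, PathBuf),
    Inference(String), // stand-in for TextModelInput
}

fn main() {
    // Hypothetical: the coordinator's port is passed on the command line
    // (it prints "Server listening on port {port}" when it binds).
    let port: u16 = std::env::args().nth(1).unwrap().parse().unwrap();
    let (mut ws, _response) = connect(format!("ws://127.0.0.1:{port}").as_str()).unwrap();

    // Rank 0 would create the NCCL id here and send Some(id); other ranks send None.
    let hello = serde_json::to_string(&AtomaChildMessage::Initialized(None)).unwrap();
    ws.send(Message::Text(hello)).unwrap();

    loop {
        let msg = ws.read().unwrap();
        if !msg.is_text() {
            continue;
        }
        match serde_json::from_str::<AtomaInferenceMessage>(&msg.to_string()).unwrap() {
            AtomaInferenceMessage::InitializeComm(_id) => {
                // Join the NCCL communicator with the broadcast id, then acknowledge.
                let ack = serde_json::to_string(&AtomaChildMessage::CommsReady).unwrap();
                ws.send(Message::Text(ack)).unwrap();
            }
            AtomaInferenceMessage::LoadModel(_config, _dtype, _weights, _tokenizer) => {
                // Load this rank's shard of the weights, then report readiness.
                let done = serde_json::to_string(&AtomaChildMessage::Loaded).unwrap();
                ws.send(Message::Text(done)).unwrap();
            }
            AtomaInferenceMessage::Inference(_input) => {
                // Run the sharded forward pass; one rank reports the generated text.
                let out = AtomaChildMessage::InferenceResult("generated text".to_string());
                ws.send(Message::Text(serde_json::to_string(&out).unwrap())).unwrap();
            }
        }
    }
}
```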
3 changes: 2 additions & 1 deletion atoma-inference/src/models/candle/mamba.rs
@@ -1,5 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::{TextModelInput, TextModelOutput};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -17,7 +18,7 @@ use crate::{
candle::device,
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, ModelType},
ModelError, ModelTrait,
},
};
3 changes: 2 additions & 1 deletion atoma-inference/src/models/candle/mistral.rs
@@ -1,5 +1,6 @@
use std::str::FromStr;

use atoma_types::{TextModelInput, TextModelOutput};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -16,7 +17,7 @@ use crate::{
models::{
candle::{device, hub_load_safetensors},
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, ModelType},
ModelError, ModelTrait,
},
};
3 changes: 2 additions & 1 deletion atoma-inference/src/models/candle/mixtral.rs
@@ -1,5 +1,6 @@
use std::str::FromStr;

use atoma_types::{TextModelInput, TextModelOutput};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -16,7 +17,7 @@ use crate::{
models::{
candle::{device, hub_load_safetensors},
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, ModelType},
ModelError, ModelTrait,
},
};
1 change: 1 addition & 0 deletions atoma-inference/src/models/candle/mod.rs
@@ -12,6 +12,7 @@ use super::ModelError;

pub mod falcon;
pub mod llama;
pub mod llama_multi_gpu;
pub mod mamba;
pub mod mistral;
pub mod mixtral;
3 changes: 2 additions & 1 deletion atoma-inference/src/models/candle/quantized.rs
@@ -1,5 +1,6 @@
use std::{path::PathBuf, str::FromStr};

use atoma_types::{TextModelInput, TextModelOutput};
use candle::{
quantized::{ggml_file, gguf_file},
DType, Device, Tensor,
@@ -16,7 +17,7 @@ use crate::models::{
candle::device,
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, ModelType},
ModelError, ModelTrait,
};
use candle_transformers::models::quantized_llama as model;
(The remaining changed files in this commit are not shown in this view.)
