Commit
resolve conflicts after merging main
jorgeantonio21 committed Apr 3, 2024
2 parents 5e55a46 + eff8f1a commit c531d5b
Showing 17 changed files with 704 additions and 336 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -22,6 +22,7 @@ candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "
 candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" }
 candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" }
 config = "0.14.0"
+dotenv = "0.15.0"
 ed25519-consensus = "2.1.0"
 futures = "0.3.30"
 hf-hub = "0.3.2"
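The newly added dotenv dependency points at .env-based configuration, plausibly for the Hugging Face API key. A minimal sketch of the usual pattern, assuming dotenv 0.15's dotenv::dotenv() entry point; the HF_API_KEY variable name is hypothetical, not taken from this commit:

// Load key=value pairs from a local `.env` file into the process
// environment, then read one back. `HF_API_KEY` is a hypothetical name.
fn load_api_key() -> Option<String> {
    dotenv::dotenv().ok(); // a missing .env file is not fatal here
    std::env::var("HF_API_KEY").ok()
}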
3 changes: 2 additions & 1 deletion atoma-inference/Cargo.toml
@@ -10,7 +10,8 @@ candle.workspace = true
 candle-flash-attn = { workspace = true, optional = true }
 candle-nn.workspace = true
 candle-transformers.workspace = true
-config.true = true
+config.workspace = true
+dotenv.workspace = true
 ed25519-consensus.workspace = true
 futures.workspace = true
 hf-hub.workspace = true
120 changes: 26 additions & 94 deletions atoma-inference/src/apis/hugging_face.rs
@@ -1,95 +1,16 @@
 use std::path::PathBuf;
 
 use async_trait::async_trait;
-use hf_hub::api::sync::{Api, ApiBuilder};
+use hf_hub::{
+    api::sync::{Api, ApiBuilder},
+    Repo, RepoType,
+};
+use tracing::error;
 
 use crate::models::ModelId;
 
 use super::{ApiError, ApiTrait};
 
-struct FilePaths {
-    file_paths: Vec<String>,
-}
-
-fn get_model_safe_tensors_from_hf(model_id: &ModelId) -> (String, FilePaths) {
-    match model_id.as_str() {
-        "Llama2_7b" => (
-            String::from("meta-llama/Llama-2-7b-hf"),
-            FilePaths {
-                file_paths: vec![
-                    "model-00001-of-00002.safetensors".to_string(),
-                    "model-00002-of-00002.safetensors".to_string(),
-                ],
-            },
-        ),
-        "Mamba3b" => (
-            String::from("state-spaces/mamba-2.8b-hf"),
-            FilePaths {
-                file_paths: vec![
-                    "model-00001-of-00003.safetensors".to_string(),
-                    "model-00002-of-00003.safetensors".to_string(),
-                    "model-00003-of-00003.safetensors".to_string(),
-                ],
-            },
-        ),
-        "Mistral7b" => (
-            String::from("mistralai/Mistral-7B-Instruct-v0.2"),
-            FilePaths {
-                file_paths: vec![
-                    "model-00001-of-00003.safetensors".to_string(),
-                    "model-00002-of-00003.safetensors".to_string(),
-                    "model-00003-of-00003.safetensors".to_string(),
-                ],
-            },
-        ),
-        "Mixtral8x7b" => (
-            String::from("mistralai/Mixtral-8x7B-Instruct-v0.1"),
-            FilePaths {
-                file_paths: vec![
-                    "model-00001-of-00019.safetensors".to_string(),
-                    "model-00002-of-00019.safetensors".to_string(),
-                    "model-00003-of-00019.safetensors".to_string(),
-                    "model-00004-of-00019.safetensors".to_string(),
-                    "model-00005-of-00019.safetensors".to_string(),
-                    "model-00006-of-00019.safetensors".to_string(),
-                    "model-00007-of-00019.safetensors".to_string(),
-                    "model-00008-of-00019.safetensors".to_string(),
-                    "model-00009-of-00019.safetensors".to_string(),
-                    "model-000010-of-00019.safetensors".to_string(),
-                    "model-000011-of-00019.safetensors".to_string(),
-                    "model-000012-of-00019.safetensors".to_string(),
-                    "model-000013-of-00019.safetensors".to_string(),
-                    "model-000014-of-00019.safetensors".to_string(),
-                    "model-000015-of-00019.safetensors".to_string(),
-                    "model-000016-of-00019.safetensors".to_string(),
-                    "model-000017-of-00019.safetensors".to_string(),
-                    "model-000018-of-00019.safetensors".to_string(),
-                    "model-000019-of-00019.safetensors".to_string(),
-                ],
-            },
-        ),
-        "StableDiffusion2" => (
-            String::from("stabilityai/stable-diffusion-2"),
-            FilePaths {
-                file_paths: vec!["768-v-ema.safetensors".to_string()],
-            },
-        ),
-        "StableDiffusionXl" => (
-            String::from("stabilityai/stable-diffusion-xl-base-1.0"),
-            FilePaths {
-                file_paths: vec![
-                    "sd_xl_base_1.0.safetensors".to_string(),
-                    "sd_xl_base_1.0_0.9vae.safetensors".to_string(),
-                    "sd_xl_offset_example-lora_1.0.safetensors".to_string(),
-                ],
-            },
-        ),
-        _ => {
-            panic!("Invalid model id")
-        }
-    }
-}
 
 #[async_trait]
 impl ApiTrait for Api {
     fn create(api_key: String, cache_dir: PathBuf) -> Result<Self, ApiError>
@@ -103,15 +24,26 @@ impl ApiTrait for Api {
             .build()?)
     }
 
-    fn fetch(&self, model_id: &ModelId) -> Result<Vec<PathBuf>, ApiError> {
-        let (model_path, files) = get_model_safe_tensors_from_hf(model_id);
-        let api_repo = self.model(model_path);
-        let mut path_bufs = Vec::with_capacity(files.file_paths.len());
-
-        for file in files.file_paths {
-            path_bufs.push(api_repo.get(&file)?);
-        }
-
-        Ok(path_bufs)
+    fn fetch(&self, model_id: ModelId, revision: String) -> Result<Vec<PathBuf>, ApiError> {
+        let repo = self.repo(Repo::with_revision(
+            model_id.clone(),
+            RepoType::Model,
+            revision,
+        ));
+
+        Ok(vec![
+            repo.get("config.json")?,
+            if model_id.contains("mamba") {
+                self.model("EleutherAI/gpt-neox-20b".to_string())
+                    .get("tokenizer.json")
+                    .map_err(|e| {
+                        error!("Failed to fetch tokenizer file: {e}");
+                        e
+                    })?
+            } else {
+                repo.get("tokenizer.json")?
+            },
+            repo.get("model.safetensors")?,
+        ])
     }
 }
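For orientation, a minimal sketch of driving the revised fetch end to end; the token, cache path, model id, and "main" revision are illustrative assumptions, and the import paths assume the crate layout shown in this commit:

use std::path::PathBuf;

use hf_hub::api::sync::Api;
use inference::apis::{ApiError, ApiTrait};

fn fetch_example() -> Result<Vec<PathBuf>, ApiError> {
    // Hypothetical token and cache directory.
    let api = Api::create("hf_xxx".to_string(), PathBuf::from("./model-cache"))?;
    // Any model id containing "mamba" resolves its tokenizer.json from
    // EleutherAI/gpt-neox-20b; other ids fetch it from their own repo.
    api.fetch("state-spaces/mamba-130m".to_string(), "main".to_string())
}

Note the new revision parameter: callers now pin a specific branch or commit of the model repo instead of always taking the default.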
2 changes: 1 addition & 1 deletion atoma-inference/src/apis/mod.rs
@@ -22,7 +22,7 @@ impl From<HuggingFaceError> for ApiError {
 }
 
 pub trait ApiTrait: Send {
-    fn fetch(&self, model_id: &ModelId) -> Result<Vec<PathBuf>, ApiError>;
+    fn fetch(&self, model_id: ModelId, revision: String) -> Result<Vec<PathBuf>, ApiError>;
     fn create(api_key: String, cache_dir: PathBuf) -> Result<Self, ApiError>
     where
         Self: Sized;
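With fetch now taking the model id by value plus a revision, any other backend has to match this shape. A hedged stub, assuming ModelId aliases String and the trait's methods stay synchronous as shown; MockApi is invented for illustration:

use std::path::PathBuf;

use inference::apis::{ApiError, ApiTrait};
use inference::models::ModelId;

// Hypothetical test double, not part of this commit.
struct MockApi {
    cache_dir: PathBuf,
}

impl ApiTrait for MockApi {
    fn create(_api_key: String, cache_dir: PathBuf) -> Result<Self, ApiError> {
        Ok(Self { cache_dir })
    }

    fn fetch(&self, model_id: ModelId, revision: String) -> Result<Vec<PathBuf>, ApiError> {
        // Pretend the three expected files are already cached locally.
        let base = self.cache_dir.join(model_id).join(revision);
        Ok(vec![
            base.join("config.json"),
            base.join("tokenizer.json"),
            base.join("model.safetensors"),
        ])
    }
}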
89 changes: 0 additions & 89 deletions atoma-inference/src/candle/mod.rs

This file was deleted.

7 changes: 3 additions & 4 deletions atoma-inference/src/lib.rs
@@ -1,7 +1,6 @@
-pub mod apis;
-pub mod candle;
 pub mod model_thread;
-pub mod models;
 pub mod service;
 pub mod specs;
-pub mod types;
+
+pub mod apis;
+pub mod models;
72 changes: 62 additions & 10 deletions atoma-inference/src/main.rs
@@ -1,16 +1,68 @@
-// use hf_hub::api::sync::Api;
-// use inference::service::ModelService;
+use std::time::Duration;
 
+use ed25519_consensus::SigningKey as PrivateKey;
+use hf_hub::api::sync::Api;
+use inference::{
+    models::{
+        candle::mamba::MambaModel,
+        config::ModelConfig,
+        types::{TextRequest, TextResponse},
+    },
+    service::{ModelService, ModelServiceError},
+};
+
 #[tokio::main]
-async fn main() {
+async fn main() -> Result<(), ModelServiceError> {
     tracing_subscriber::fmt::init();
 
-    // let (_, receiver) = tokio::sync::mpsc::channel(32);
+    let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<TextRequest>(32);
+    let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<TextResponse>(32);
+
+    let model_config = ModelConfig::from_file_path("../inference.toml".parse().unwrap());
+    let private_key_bytes =
+        std::fs::read("../private_key").map_err(ModelServiceError::PrivateKeyError)?;
+    let private_key_bytes: [u8; 32] = private_key_bytes
+        .try_into()
+        .expect("Incorrect private key bytes length");
+
+    let private_key = PrivateKey::from(private_key_bytes);
+    let mut service = ModelService::start::<MambaModel, Api>(
+        model_config,
+        private_key,
+        req_receiver,
+        resp_sender,
+    )
+    .expect("Failed to start inference service");
+
+    let pk = service.public_key();
+
+    tokio::spawn(async move {
+        service.run().await?;
+        Ok::<(), ModelServiceError>(())
+    });
+
+    tokio::time::sleep(Duration::from_millis(5000)).await;
+
+    req_sender
+        .send(TextRequest {
+            request_id: 0,
+            prompt: "Leon, the professional is a movie".to_string(),
+            model: "state-spaces/mamba-130m".to_string(),
+            max_tokens: 512,
+            temperature: Some(0.0),
+            random_seed: 42,
+            repeat_last_n: 64,
+            repeat_penalty: 1.1,
+            sampled_nodes: vec![pk],
+            top_p: Some(1.0),
+            top_k: 10,
+        })
+        .await
+        .expect("Failed to send request");
+
+    if let Some(response) = resp_receiver.recv().await {
+        println!("Got a response: {:?}", response);
+    }
 
-    // let _ = ModelService::start::<Model, Api>(
-    //     "../inference.toml".parse().unwrap(),
-    //     "../private_key".parse().unwrap(),
-    //     receiver,
-    // )
-    // .expect("Failed to start inference service");
+    Ok(())
 }