Commit

Merge pull request #26 from atoma-network/20240405-refactor-model-trait-method-signatures

refactor model trait method signatures
jorgeantonio21 authored Apr 8, 2024
2 parents f5538b1 + 45a43ec commit 74b88a4
Showing 12 changed files with 1,038 additions and 703 deletions.
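The diffs below revolve around the reshaped ModelTrait. The trait definition itself is not part of this excerpt; the following is a minimal sketch inferred from the call sites in model_thread.rs (M::fetch(api_key, cache_dir, model_config), M::load(load_data), model.run(input)). Associated-type names, the receiver of run, and the exact error types are assumptions, not the verbatim definition.

use std::path::PathBuf;
use serde::{de::DeserializeOwned, Serialize};

// Sketch only. ModelConfig and ModelError are existing crate types
// (crate::models::config::ModelConfig, crate::models::ModelError), not defined here.
pub trait ModelTrait {
    type Input: DeserializeOwned;  // decoded from the request JSON value
    type Output: Serialize;        // encoded into the response JSON value
    type LoadData;                 // whatever fetch() resolves, e.g. weight file paths

    // Resolve or download model artifacts into the cache directory.
    fn fetch(api_key: String, cache_dir: PathBuf, config: ModelConfig)
        -> Result<Self::LoadData, ModelError>;

    // Construct the model from previously fetched data.
    fn load(load_data: Self::LoadData) -> Result<Self, ModelError>
    where
        Self: Sized;

    // Execute a single inference request.
    fn run(&mut self, input: Self::Input) -> Result<Self::Output, ModelError>;
}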
3 changes: 1 addition & 2 deletions atoma-inference/Cargo.toml
@@ -14,12 +14,11 @@ dotenv.workspace = true
ed25519-consensus.workspace = true
futures.workspace = true
hf-hub.workspace = true
reqwest = { workspace = true, features = ["json"] }
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
image = { workspace = true }
thiserror.workspace = true
tokenizers = { workspace = true, features = ["onig"] }
tokenizers = { workspace = true }
tokio = { workspace = true, features = ["full", "tracing"] }
tracing.workspace = true
tracing-subscriber.workspace = true
57 changes: 35 additions & 22 deletions atoma-inference/src/main.rs
@@ -1,18 +1,17 @@
use std::time::Duration;

use ed25519_consensus::SigningKey as PrivateKey;
use hf_hub::api::sync::Api;
use inference::{
models::{candle::mamba::MambaModel, config::ModelsConfig, types::TextRequest},
models::{config::ModelsConfig, types::StableDiffusionRequest},
service::{ModelService, ModelServiceError},
};

#[tokio::main]
async fn main() -> Result<(), ModelServiceError> {
tracing_subscriber::fmt::init();

let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<serde_json::Value>(32);
let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<serde_json::Value>(32);
let (req_sender, req_receiver) = tokio::sync::mpsc::channel(32);
let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel(32);

let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap());
let private_key_bytes =
@@ -22,13 +21,8 @@ async fn main() -> Result<(), ModelServiceError> {
.expect("Incorrect private key bytes length");

let private_key = PrivateKey::from(private_key_bytes);
let mut service = ModelService::start::<MambaModel, Api>(
model_config,
private_key,
req_receiver,
resp_sender,
)
.expect("Failed to start inference service");
let mut service = ModelService::start(model_config, private_key, req_receiver, resp_sender)
.expect("Failed to start inference service");

let pk = service.public_key();

@@ -37,22 +31,41 @@ async fn main() -> Result<(), ModelServiceError> {
Ok::<(), ModelServiceError>(())
});

tokio::time::sleep(Duration::from_millis(5000)).await;
tokio::time::sleep(Duration::from_millis(5_000)).await;

// req_sender
// .send(serde_json::to_value(TextRequest {
// request_id: 0,
// prompt: "Leon, the professional is a movie".to_string(),
// model: "llama_tiny_llama_1_1b_chat".to_string(),
// max_tokens: 512,
// temperature: Some(0.0),
// random_seed: 42,
// repeat_last_n: 64,
// repeat_penalty: 1.1,
// sampled_nodes: vec![pk],
// top_p: Some(1.0),
// _top_k: 10,
// }).unwrap())
// .await
// .expect("Failed to send request");

req_sender
.send(
serde_json::to_value(TextRequest {
serde_json::to_value(StableDiffusionRequest {
request_id: 0,
prompt: "Leon, the professional is a movie".to_string(),
model: "state-spaces/mamba-130m".to_string(),
max_tokens: 512,
temperature: Some(0.0),
random_seed: 42,
repeat_last_n: 64,
repeat_penalty: 1.1,
prompt: "A depiction of Natalie Portman".to_string(),
uncond_prompt: "".to_string(),
height: Some(256),
width: Some(256),
num_samples: 1,
n_steps: None,
model: "stable_diffusion_v1-5".to_string(),
guidance_scale: None,
img2img: None,
img2img_strength: 0.8,
random_seed: Some(42),
sampled_nodes: vec![pk],
top_p: Some(1.0),
top_k: 10,
})
.unwrap(),
)
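On the caller side, ModelService::start is no longer generic over a concrete model and an API client; the MambaModel/Api turbofish disappears and model selection now happens per request inside the dispatcher. A hedged sketch of the signature this call site implies (parameter names, the serde_json::Value channel payloads, and the error type are inferred from main.rs, not confirmed by this diff):

impl ModelService {
    // Inferred from the call in main.rs above; the real body is not shown in this excerpt.
    pub fn start(
        model_config: ModelsConfig,
        private_key: PrivateKey,
        req_receiver: tokio::sync::mpsc::Receiver<serde_json::Value>,
        resp_sender: tokio::sync::mpsc::Sender<serde_json::Value>,
    ) -> Result<Self, ModelServiceError> {
        todo!()
    }
}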
175 changes: 130 additions & 45 deletions atoma-inference/src/model_thread.rs
@@ -1,6 +1,5 @@
use std::{
collections::HashMap,
sync::{mpsc, Arc},
collections::HashMap, fmt::Debug, path::PathBuf, str::FromStr, sync::mpsc, thread::JoinHandle,
};

use ed25519_consensus::VerificationKey as PublicKey;
@@ -10,8 +9,16 @@ use tokio::sync::oneshot::{self, error::RecvError};
use tracing::{debug, error, info, warn};

use crate::{
apis::{ApiError, ApiTrait},
models::{config::ModelsConfig, ModelError, ModelId, ModelTrait},
apis::ApiError,
models::{
candle::{
falcon::FalconModel, llama::LlamaModel, mamba::MambaModel,
stable_diffusion::StableDiffusion,
},
config::{ModelConfig, ModelsConfig},
types::ModelType,
ModelError, ModelId, ModelTrait,
},
};

pub struct ModelThreadCommand {
@@ -73,17 +80,13 @@
response_sender,
} = command;

// TODO: Implement node authorization
// if !request.is_node_authorized(&public_key) {
// error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id());
// continue;
// }

let model_input = serde_json::from_value(request).unwrap();
let model_output = self
.model
.run(model_input)
.map_err(ModelThreadError::ModelError)?;
let model_input = serde_json::from_value(request)?;
let model_output = self.model.run(model_input)?;
let response = serde_json::to_value(model_output)?;
response_sender.send(response).ok();
}
@@ -98,49 +101,35 @@ pub struct ModelThreadDispatcher {
}

impl ModelThreadDispatcher {
pub(crate) fn start<M, F>(
pub(crate) fn start(
config: ModelsConfig,
public_key: PublicKey,
) -> Result<(Self, Vec<ModelThreadHandle>), ModelThreadError>
where
F: ApiTrait + Send + Sync + 'static,
M: ModelTrait, //<Input = Req::ModelInput, Output = Resp::ModelOutput> + Send + 'static,
{
let api_key = config.api_key();
let storage_path = config.storage_path();
let api = Arc::new(F::create(api_key, storage_path)?);

) -> Result<(Self, Vec<ModelThreadHandle>), ModelThreadError> {
let mut handles = Vec::new();
let mut model_senders = HashMap::new();

let api_key = config.api_key();
let cache_dir = config.cache_dir();

for model_config in config.models() {
info!("Spawning new thread for model: {}", model_config.model_id());
let api = api.clone();

let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand>();
let model_name = model_config.model_id().clone();
let model_type = ModelType::from_str(&model_name)?;

let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand>();
model_senders.insert(model_name.clone(), model_sender.clone());

let join_handle = std::thread::spawn(move || {
info!("Fetching files for model: {model_name}");
let filenames = api.fetch(model_name, model_config.revision())?;
let x = serde_json::from_value(model_config.params().clone()).unwrap();

let model = M::load(filenames, x, model_config.device_id())?;
let model_thread = ModelThread {
model,
receiver: model_receiver,
};

if let Err(e) = model_thread.run(public_key) {
error!("Model thread error: {e}");
if !matches!(e, ModelThreadError::Shutdown(_)) {
panic!("Fatal error occurred: {e}");
}
}

Ok(())
});
let join_handle = dispatch_model_thread(
api_key.clone(),
cache_dir.clone(),
model_name,
model_type,
model_config,
public_key,
model_receiver,
);

handles.push(ModelThreadHandle {
join_handle,
sender: model_sender.clone(),
@@ -157,10 +146,15 @@ impl ModelThreadDispatcher {

fn send(&self, command: ModelThreadCommand) {
let request = command.request.clone();
let model_id = request.get("model").unwrap().as_str().unwrap().to_string();
println!("model_id {model_id}");
let model_id = if let Some(model_id) = request.get("model") {
model_id.as_str().unwrap().to_string()
} else {
error!("Request malformed: Missing model_id from request");
return;
};

info!("model_id {model_id}");

println!("{:?}", self.model_senders);
let sender = self
.model_senders
.get(&model_id)
@@ -182,3 +176,94 @@ impl ModelThreadDispatcher {
self.responses.push(receiver);
}
}

fn dispatch_model_thread(
api_key: String,
cache_dir: PathBuf,
model_name: String,
model_type: ModelType,
model_config: ModelConfig,
public_key: PublicKey,
model_receiver: mpsc::Receiver<ModelThreadCommand>,
) -> JoinHandle<Result<(), ModelThreadError>> {
match model_type {
ModelType::Falcon7b | ModelType::Falcon40b | ModelType::Falcon180b => {
spawn_model_thread::<FalconModel>(
model_name,
api_key.clone(),
cache_dir.clone(),
model_config,
public_key,
model_receiver,
)
}
ModelType::LlamaV1
| ModelType::LlamaV2
| ModelType::LlamaTinyLlama1_1BChat
| ModelType::LlamaSolar10_7B => spawn_model_thread::<LlamaModel>(
model_name,
api_key,
cache_dir,
model_config,
public_key,
model_receiver,
),
ModelType::Mamba130m
| ModelType::Mamba370m
| ModelType::Mamba790m
| ModelType::Mamba1_4b
| ModelType::Mamba2_8b => spawn_model_thread::<MambaModel>(
model_name,
api_key,
cache_dir,
model_config,
public_key,
model_receiver,
),
ModelType::Mistral7b => todo!(),
ModelType::Mixtral8x7b => todo!(),
ModelType::StableDiffusionV1_5
| ModelType::StableDiffusionV2_1
| ModelType::StableDiffusionTurbo
| ModelType::StableDiffusionXl => spawn_model_thread::<StableDiffusion>(
model_name,
api_key,
cache_dir,
model_config,
public_key,
model_receiver,
),
}
}

fn spawn_model_thread<M>(
model_name: String,
api_key: String,
cache_dir: PathBuf,
model_config: ModelConfig,
public_key: PublicKey,
model_receiver: mpsc::Receiver<ModelThreadCommand>,
) -> JoinHandle<Result<(), ModelThreadError>>
where
M: ModelTrait + Send + 'static,
{
std::thread::spawn(move || {
info!("Fetching files for model: {model_name}");
let load_data = M::fetch(api_key, cache_dir, model_config)?;

let model = M::load(load_data)?;
let model_thread = ModelThread {
model,
receiver: model_receiver,
};

if let Err(e) = model_thread.run(public_key) {
error!("Model thread error: {e}");
if !matches!(e, ModelThreadError::Shutdown(_)) {
panic!("Fatal error occurred: {e}");
}
}

Ok(())
})
}
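spawn_model_thread pins down what the new trait demands of an implementor: M must be ModelTrait + Send + 'static, fetch runs on the spawned thread before load, and run is driven once per incoming ModelThreadCommand. A minimal hypothetical implementor, assuming the trait shape sketched near the top of this page (EchoModel and every name below are illustrative only, not part of the codebase):

use std::path::PathBuf;

// Hypothetical model that simply echoes its prompt back.
struct EchoModel;

impl ModelTrait for EchoModel {
    type Input = String;      // serde_json::from_value(request) must produce this
    type Output = String;     // serde_json::to_value(output) must accept this
    type LoadData = PathBuf;  // whatever fetch() hands over to load()

    fn fetch(_api_key: String, cache_dir: PathBuf, _config: ModelConfig)
        -> Result<Self::LoadData, ModelError> {
        // A real model would resolve or download its weight files here.
        Ok(cache_dir)
    }

    fn load(_load_data: Self::LoadData) -> Result<Self, ModelError> {
        Ok(EchoModel)
    }

    fn run(&mut self, input: Self::Input) -> Result<Self::Output, ModelError> {
        // A real model would run inference here; this just returns the prompt.
        Ok(input)
    }
}

// Such a type could then be wired in exactly like the real models:
// let handle = spawn_model_thread::<EchoModel>(
//     model_name, api_key, cache_dir, model_config, public_key, model_receiver);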
[Diffs for the remaining 9 changed files are not shown in this excerpt.]