correct minor bugs
jorgeantonio21 committed Apr 6, 2024
1 parent b7115ae commit ceb52cf
Showing 5 changed files with 214 additions and 78 deletions.
63 changes: 45 additions & 18 deletions atoma-inference/src/main.rs
@@ -3,9 +3,11 @@ use std::time::Duration;
use ed25519_consensus::SigningKey as PrivateKey;
use inference::{
models::{
candle::falcon::FalconModel,
candle::stable_diffusion::StableDiffusion,
config::ModelsConfig,
types::{TextRequest, TextResponse},
types::{
ModelType, StableDiffusionRequest, StableDiffusionResponse,
},
},
service::{ModelService, ModelServiceError},
};
@@ -14,8 +16,9 @@ use inference::{
async fn main() -> Result<(), ModelServiceError> {
tracing_subscriber::fmt::init();

let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<TextRequest>(32);
let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<TextResponse>(32);
let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<StableDiffusionRequest>(32);
let (resp_sender, mut resp_receiver) =
tokio::sync::mpsc::channel::<StableDiffusionResponse>(32);

let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap());
let private_key_bytes =
@@ -25,9 +28,14 @@ async fn main() -> Result<(), ModelServiceError> {
.expect("Incorrect private key bytes length");

let private_key = PrivateKey::from(private_key_bytes);
let mut service =
ModelService::start::<FalconModel>(model_config, private_key, req_receiver, resp_sender)
.expect("Failed to start inference service");
let mut service: ModelService<StableDiffusionRequest, StableDiffusionResponse> =
ModelService::start::<StableDiffusion>(
model_config,
private_key,
req_receiver,
resp_sender,
)
.expect("Failed to start inference service");

let pk = service.public_key();

@@ -36,21 +44,40 @@ async fn main() -> Result<(), ModelServiceError> {
Ok::<(), ModelServiceError>(())
});

tokio::time::sleep(Duration::from_millis(50000000)).await;
tokio::time::sleep(Duration::from_millis(5_000)).await;
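Note: the replaced value slept for 50,000,000 ms (roughly 14 hours) before any request was sent; 5_000 ms pauses just long enough for the service to come up. A sketch of a more self-documenting spelling, using only std and tokio:

    use std::time::Duration;

    // Duration::from_secs makes the 5-second intent explicit,
    // avoiding a hard-to-count millisecond literal.
    tokio::time::sleep(Duration::from_secs(5)).await;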

// req_sender
// .send(TextRequest {
// request_id: 0,
// prompt: "Leon, the professional is a movie".to_string(),
// model: "falcon_7b".to_string(),
// max_tokens: 512,
// temperature: Some(0.0),
// random_seed: 42,
// repeat_last_n: 64,
// repeat_penalty: 1.1,
// sampled_nodes: vec![pk],
// top_p: Some(1.0),
// top_k: 10,
// })
// .await
// .expect("Failed to send request");

req_sender
.send(TextRequest {
.send(StableDiffusionRequest {
request_id: 0,
prompt: "Leon, the professional is a movie".to_string(),
model: "llama_tiny_llama_1_1b_chat".to_string(),
max_tokens: 512,
temperature: Some(0.0),
random_seed: 42,
repeat_last_n: 64,
repeat_penalty: 1.1,
prompt: "A portrait of young Natalie Portman".to_string(),
uncond_prompt: "".to_string(),
height: None,
width: None,
num_samples: 1,
n_steps: None,
model_type: ModelType::StableDiffusionV1_5,
guidance_scale: None,
img2img: None,
img2img_strength: 0.5,
random_seed: Some(42),
sampled_nodes: vec![pk],
top_p: Some(1.0),
top_k: 10,
})
.await
.expect("Failed to send request");
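The hunk is truncated here, before the response side of main. A hypothetical sketch of draining the resp_receiver created above, assuming StableDiffusionResponse implements Debug (this loop is not part of the commit):

    if let Some(response) = resp_receiver.recv().await {
        println!("received response: {response:?}");
    }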
67 changes: 45 additions & 22 deletions atoma-inference/src/model_thread.rs
@@ -1,4 +1,4 @@
use std::{collections::HashMap, fmt::Debug, sync::mpsc};
use std::{collections::HashMap, fmt::Debug, path::PathBuf, sync::mpsc, thread::JoinHandle};

use ed25519_consensus::VerificationKey as PublicKey;
use futures::stream::FuturesUnordered;
@@ -8,7 +8,10 @@ use tracing::{debug, error, info, warn};

use crate::{
apis::ApiError,
models::{config::ModelsConfig, ModelError, ModelId, ModelTrait, Request, Response},
models::{
config::{ModelConfig, ModelsConfig},
ModelError, ModelId, ModelTrait, Request, Response,
},
};

pub struct ModelThreadCommand<Req, Resp>
@@ -134,31 +137,19 @@
for model_config in config.models() {
info!("Spawning new thread for model: {}", model_config.model_id());

let model_api_key = api_key.clone();
let model_cache_dir = cache_dir.clone();
let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand<_, _>>();
let model_name = model_config.model_id().clone();
model_senders.insert(model_name.clone(), model_sender.clone());

let join_handle = std::thread::spawn(move || {
info!("Fetching files for model: {model_name}");
let load_data = M::fetch(model_api_key, model_cache_dir, model_config)?;

let model = M::load(load_data)?;
let model_thread = ModelThread {
model,
receiver: model_receiver,
};

if let Err(e) = model_thread.run(public_key) {
error!("Model thread error: {e}");
if !matches!(e, ModelThreadError::Shutdown(_)) {
panic!("Fatal error occurred: {e}");
}
}
let join_handle = Self::start_model_thread::<M>(
model_name,
api_key.clone(),
cache_dir.clone(),
model_config,
public_key,
model_receiver,
);

Ok(())
});
handles.push(ModelThreadHandle {
join_handle,
sender: model_sender.clone(),
@@ -173,6 +164,38 @@
Ok((model_dispatcher, handles))
}

fn start_model_thread<M>(
model_name: String,
api_key: String,
cache_dir: PathBuf,
model_config: ModelConfig,
public_key: PublicKey,
model_receiver: mpsc::Receiver<ModelThreadCommand<Req, Resp>>,
) -> JoinHandle<Result<(), ModelThreadError>>
where
M: ModelTrait<Input = Req::ModelInput, Output = Resp::ModelOutput> + Send + 'static,
{
std::thread::spawn(move || {
info!("Fetching files for model: {model_name}");
let load_data = M::fetch(api_key, cache_dir, model_config)?;

let model = M::load(load_data)?;
let model_thread = ModelThread {
model,
receiver: model_receiver,
};

if let Err(e) = model_thread.run(public_key) {
error!("Model thread error: {e}");
if !matches!(e, ModelThreadError::Shutdown(_)) {
panic!("Fatal error occurred: {e}");
}
}

Ok(())
})
}
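start_model_thread hands the JoinHandle back so the dispatcher can keep one ModelThreadHandle per model. A sketch of a shutdown path a caller might use, assuming the join_handle and sender fields shown in this diff and that dropping the sender closes the command channel (both assumptions, not code from this commit):

    for handle in handles {
        // With the sender dropped, the thread's run() loop sees a closed
        // channel and can return, so join() does not block forever.
        drop(handle.sender);
        match handle.join_handle.join() {
            Ok(Ok(())) => info!("model thread exited cleanly"),
            Ok(Err(e)) => error!("model thread error: {e}"),
            Err(_) => error!("model thread panicked"),
        }
    }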

fn send(&self, command: ModelThreadCommand<Req, Resp>) {
let request = command.request.clone();
let model_id = request.requested_model();
23 changes: 16 additions & 7 deletions atoma-inference/src/models/candle/falcon.rs
@@ -9,10 +9,13 @@ use candle_transformers::{
};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::{debug, info};
use tracing::{debug, error, info};

use crate::models::{
candle::hub_load_safetensors, config::ModelConfig, types::{LlmLoadData, ModelType, TextModelInput}, ModelError, ModelTrait
candle::hub_load_safetensors,
config::ModelConfig,
types::{LlmLoadData, ModelType, TextModelInput},
ModelError, ModelTrait,
};

use super::device;
@@ -68,12 +71,19 @@ impl ModelTrait for FalconModel {

info!("{repo_id} <> {revision}");

let repo = api.repo(Repo::with_revision(repo_id, RepoType::Model, revision));

let mut file_paths = vec![];
let repo = api.repo(Repo::new(repo_id.clone(), RepoType::Model));
file_paths.push(repo.get("config.json")?);

let repo = api.repo(Repo::with_revision(repo_id, RepoType::Model, revision));
file_paths.push(repo.get("tokenizer.json")?);
file_paths.extend(hub_load_safetensors(&repo, "model.safetensors.index.json")?);

file_paths.extend(
hub_load_safetensors(&repo, "model.safetensors.index.json").map_err(|e| {
error!("{e}");
e
})?,
);

Ok(Self::LoadData {
device,
@@ -97,11 +107,10 @@
let weights_filenames = load_data.file_paths[2..].to_vec();

let tokenizer = Tokenizer::from_file(tokenizer_filename)?;

let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
config.validate()?;

if load_data.dtype != DType::BF16 || load_data.dtype != DType::F32 {
if load_data.dtype != DType::BF16 && load_data.dtype != DType::F32 {
panic!("Invalid dtype, it must be either BF16 or F32 precision");
}

54 changes: 24 additions & 30 deletions atoma-inference/src/models/candle/stable_diffusion.rs
@@ -9,6 +9,7 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
use hf_hub::api::sync::ApiBuilder;
use serde::Deserialize;
use tokenizers::Tokenizer;
use tracing::info;

use crate::{
bail,
@@ -18,32 +19,32 @@ use crate::{
use super::{convert_to_image, device, save_tensor_to_file};

#[derive(Deserialize)]
pub struct Input {
prompt: String,
uncond_prompt: String,
pub struct StableDiffusionInput {
pub prompt: String,
pub uncond_prompt: String,

height: Option<usize>,
width: Option<usize>,
pub height: Option<usize>,
pub width: Option<usize>,

/// The number of steps to run the diffusion for.
n_steps: Option<usize>,
pub n_steps: Option<usize>,

/// The number of samples to generate.
num_samples: i64,
pub num_samples: i64,

sd_version: StableDiffusionVersion,
pub model_type: ModelType,

guidance_scale: Option<f64>,
pub guidance_scale: Option<f64>,

img2img: Option<String>,
pub img2img: Option<String>,

/// The strength, indicates how much to transform the initial image. The
/// value must be between 0 and 1, a value of 1 discards the initial image
/// information.
img2img_strength: f64,
pub img2img_strength: f64,

/// The seed to use when generating random samples.
seed: Option<u64>,
pub random_seed: Option<u64>,
}
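With the fields now pub and Deserialize derived, a request body can be parsed straight from JSON. A hypothetical sketch, assuming serde_json is available and that ModelType deserializes from its variant name (both unverified here):

    let input: StableDiffusionInput = serde_json::from_str(
        r#"{
            "prompt": "A portrait of young Natalie Portman",
            "uncond_prompt": "",
            "height": null,
            "width": null,
            "n_steps": null,
            "num_samples": 1,
            "model_type": "StableDiffusionV1_5",
            "guidance_scale": null,
            "img2img": null,
            "img2img_strength": 0.5,
            "random_seed": 42
        }"#,
    )?;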

pub struct StableDiffusionLoadData {
@@ -72,7 +73,7 @@ pub struct StableDiffusion {
}

impl ModelTrait for StableDiffusion {
type Input = Input;
type Input = StableDiffusionInput;
type Output = Vec<(Vec<u8>, usize, usize)>;
type LoadData = StableDiffusionLoadData;

@@ -236,8 +237,8 @@ impl ModelTrait for StableDiffusion {
)))?
}

// self.config.height = input.height;
// self.config.width = input.width;
// self.config.height = input.height.unwrap_or(512);
// self.config.width = input.width.unwrap_or(512);

let guidance_scale = match input.guidance_scale {
Some(guidance_scale) => guidance_scale,
@@ -261,7 +262,7 @@
};

let scheduler = self.config.build_scheduler(n_steps)?;
if let Some(seed) = input.seed {
if let Some(seed) = input.random_seed {
self.device.set_seed(seed)?;
}
let use_guide_scale = guidance_scale > 1.0;
@@ -306,11 +307,12 @@
};
let bsize = 1;

let vae_scale = match input.sd_version {
StableDiffusionVersion::V1_5
| StableDiffusionVersion::V2_1
| StableDiffusionVersion::Xl => 0.18215,
StableDiffusionVersion::Turbo => 0.13025,
let vae_scale = match input.model_type {
ModelType::StableDiffusionV1_5
| ModelType::StableDiffusionV2_1
| ModelType::StableDiffusionXl => 0.18215,
ModelType::StableDiffusionTurbo => 0.13025,
_ => bail!("Invalid stable diffusion model type"),
};
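The rewritten match keys the VAE latent scale off ModelType instead of the removed StableDiffusionVersion enum, and now rejects non-diffusion variants. The same mapping factored into a helper, assuming the crate's bail! expands to an early Err return, as its use above suggests (a sketch only):

    fn vae_scale(model_type: &ModelType) -> Result<f64, ModelError> {
        match model_type {
            ModelType::StableDiffusionV1_5
            | ModelType::StableDiffusionV2_1
            | ModelType::StableDiffusionXl => Ok(0.18215),
            ModelType::StableDiffusionTurbo => Ok(0.13025),
            // Falcon, Llama, etc. have no VAE, so this is an error.
            _ => bail!("Invalid stable diffusion model type"),
        }
    }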
let mut res = Vec::new();

@@ -355,6 +357,7 @@
latents.clone()
};

info!("FLAG: {:?}", latent_model_input.shape());
let latent_model_input =
scheduler.scale_model_input(latent_model_input, timestep)?;
let noise_pred =
@@ -386,15 +389,6 @@
}
}

#[allow(dead_code)]
#[derive(Clone, Copy, Deserialize)]
enum StableDiffusionVersion {
V1_5,
V2_1,
Xl,
Turbo,
}

impl ModelType {
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {