feat: add the option to set device_id for each model
Cifko committed Apr 4, 2024
1 parent 3370fce commit d9c18bb
Showing 8 changed files with 56 additions and 43 deletions.
4 changes: 2 additions & 2 deletions atoma-inference/src/main.rs
@@ -5,7 +5,7 @@ use hf_hub::api::sync::Api;
use inference::{
models::{
candle::mamba::MambaModel,
config::ModelConfig,
config::ModelsConfig,
types::{TextRequest, TextResponse},
},
service::{ModelService, ModelServiceError},
@@ -18,7 +18,7 @@ async fn main() -> Result<(), ModelServiceError> {
let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<TextRequest>(32);
let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<TextResponse>(32);

let model_config = ModelConfig::from_file_path("../inference.toml".parse().unwrap());
let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap());
let private_key_bytes =
std::fs::read("../private_key").map_err(ModelServiceError::PrivateKeyError)?;
let private_key_bytes: [u8; 32] = private_key_bytes
21 changes: 10 additions & 11 deletions atoma-inference/src/model_thread.rs
@@ -11,7 +11,7 @@ use tracing::{debug, error, info, warn};

use crate::{
apis::{ApiError, ApiTrait},
models::{config::ModelConfig, ModelError, ModelId, ModelTrait, Request, Response},
models::{config::ModelsConfig, ModelError, ModelId, ModelTrait, Request, Response},
};

pub struct ModelThreadCommand<Req, Resp>(Req, oneshot::Sender<Resp>)
@@ -111,33 +111,32 @@
Resp: Response,
{
pub(crate) fn start<M, F>(
config: ModelConfig,
config: ModelsConfig,
public_key: PublicKey,
) -> Result<(Self, Vec<ModelThreadHandle<Req, Resp>>), ModelThreadError>
where
F: ApiTrait + Send + Sync + 'static,
M: ModelTrait<Input = Req::ModelInput, Output = Resp::ModelOutput> + Send + 'static,
{
let model_ids = config.model_ids();
let api_key = config.api_key();
let storage_path = config.storage_path();
let api = Arc::new(F::create(api_key, storage_path)?);

let mut handles = Vec::with_capacity(model_ids.len());
let mut model_senders = HashMap::with_capacity(model_ids.len());
let mut handles = Vec::new();
let mut model_senders = HashMap::new();

for (model_id, precision, revision) in model_ids {
info!("Spawning new thread for model: {model_id}");
for model_config in config.models() {
info!("Spawning new thread for model: {}", model_config.model_id);
let api = api.clone();

let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand<_, _>>();
let model_name = model_id.clone();
let model_name = model_config.model_id.clone();

let join_handle = std::thread::spawn(move || {
info!("Fetching files for model: {model_name}");
let filenames = api.fetch(model_name, revision)?;
let filenames = api.fetch(model_name, model_config.revision)?;

let model = M::load(filenames, precision)?;
let model = M::load(filenames, model_config.precision, model_config.device_id)?;
let model_thread = ModelThread {
model,
receiver: model_receiver,
@@ -156,7 +155,7 @@ where
join_handle,
sender: model_sender.clone(),
});
model_senders.insert(model_id, model_sender);
model_senders.insert(model_config.model_id, model_sender);
}

let model_dispatcher = ModelThreadDispatcher {
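
For illustration, a minimal sketch of what the new per-model iteration yields (not part of this commit; the printout is purely hypothetical):

for model_config in config.models() {
    // Each entry now carries everything a worker thread needs,
    // including which accelerator ordinal to run on.
    println!(
        "{} -> device {} (revision {})",
        model_config.model_id, model_config.device_id, model_config.revision
    );
}
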
27 changes: 13 additions & 14 deletions atoma-inference/src/models/candle/mamba.rs
@@ -1,9 +1,6 @@
use std::{path::PathBuf, time::Instant};

use candle::{
utils::{cuda_is_available, metal_is_available},
DType, Device, Tensor,
};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
generation::LogitsProcessor,
@@ -15,8 +12,12 @@ use tracing::info;

use crate::{
bail,
models::types::{PrecisionBits, TextModelInput},
models::{token_output_stream::TokenOutputStream, ModelError, ModelId, ModelTrait},
models::{
candle::device,
token_output_stream::TokenOutputStream,
types::{PrecisionBits, TextModelInput},
ModelError, ModelId, ModelTrait,
},
};

pub struct MambaModel {
@@ -53,7 +54,7 @@ impl ModelTrait for MambaModel {
type Input = TextModelInput;
type Output = String;

fn load(filenames: Vec<PathBuf>, precision: PrecisionBits) -> Result<Self, ModelError>
fn load(
filenames: Vec<PathBuf>,
precision: PrecisionBits,
device_id: usize,
) -> Result<Self, ModelError>
where
Self: Sized,
{
@@ -70,13 +75,7 @@ impl ModelTrait for MambaModel {
let config: Config =
serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?)
.map_err(ModelError::DeserializeError)?;
let device = if cuda_is_available() {
Device::new_cuda(0).map_err(ModelError::CandleError)?
} else if metal_is_available() {
Device::new_metal(0).map_err(ModelError::CandleError)?
} else {
Device::Cpu
};
let device = device(device_id)?;
let dtype = precision.into_dtype();

info!("Loading model weights..");
6 changes: 3 additions & 3 deletions atoma-inference/src/models/candle/mod.rs
@@ -13,13 +13,13 @@ use super::ModelError;
pub mod mamba;
pub mod stable_diffusion;

pub fn device() -> Result<Device, candle::Error> {
pub fn device(device_id: usize) -> Result<Device, candle::Error> {
if cuda_is_available() {
info!("Using CUDA");
Device::new_cuda(0)
Device::new_cuda(device_id)
} else if metal_is_available() {
info!("Using Metal");
Device::new_metal(0)
Device::new_metal(device_id)
} else {
info!("Using Cpu");
Ok(Device::Cpu)
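
For illustration (not part of this commit): with the ordinal threaded through, two models configured with different device_id values can be pinned to different accelerators, while on a host without CUDA or Metal both calls simply fall back to Device::Cpu and the id is ignored.

let gpu0 = device(0)?; // first CUDA/Metal device, if available
let gpu1 = device(1)?; // second CUDA/Metal device, if available
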
9 changes: 6 additions & 3 deletions atoma-inference/src/models/candle/stable_diffusion.rs
@@ -97,7 +97,9 @@ impl From<&Input> for Fetch {
}
}
}
pub struct StableDiffusion {}
pub struct StableDiffusion {
device_id: usize,
}

pub struct Fetch {
tokenizer: Option<String>,
@@ -116,11 +118,12 @@ impl ModelTrait for StableDiffusion {
fn load(
_filenames: Vec<std::path::PathBuf>,
_precision: PrecisionBits,
device_id: usize,
) -> Result<Self, ModelError>
where
Self: Sized,
{
Ok(Self {})
Ok(Self { device_id })
}

fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError> {
@@ -202,7 +205,7 @@ impl ModelTrait for StableDiffusion {
};

let scheduler = sd_config.build_scheduler(n_steps)?;
let device = device()?;
let device = device(self.device_id)?;
if let Some(seed) = input.seed {
device.set_seed(seed)?;
}
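
StableDiffusion keeps only the ordinal at load time and resolves it into a concrete Device when inference actually runs. The same pattern in miniature (not part of this commit; LazyDeviceModel is a hypothetical type):

struct LazyDeviceModel {
    device_id: usize,
}

impl LazyDeviceModel {
    fn run(&self) -> Result<(), ModelError> {
        // Resolve the configured ordinal into a device only when it is needed.
        let device = device(self.device_id)?;
        // ... build tensors and schedulers against `device` here ...
        Ok(())
    }
}
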
20 changes: 14 additions & 6 deletions atoma-inference/src/models/config.rs
@@ -8,20 +8,28 @@ use crate::{models::types::PrecisionBits, models::ModelId};

type Revision = String;

#[derive(Debug, Deserialize, Serialize)]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelConfig {
pub model_id: ModelId,
pub precision: PrecisionBits,
pub revision: Revision,
pub device_id: usize,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct ModelsConfig {
api_key: String,
flush_storage: bool,
models: Vec<(ModelId, PrecisionBits, Revision)>,
models: Vec<ModelConfig>,
storage_path: PathBuf,
tracing: bool,
}

impl ModelConfig {
impl ModelsConfig {
pub fn new(
api_key: String,
flush_storage: bool,
models: Vec<(ModelId, PrecisionBits, Revision)>,
models: Vec<ModelConfig>,
storage_path: PathBuf,
tracing: bool,
) -> Self {
@@ -42,7 +50,7 @@ impl ModelConfig {
self.flush_storage
}

pub fn model_ids(&self) -> Vec<(ModelId, PrecisionBits, Revision)> {
pub fn models(&self) -> Vec<ModelConfig> {
self.models.clone()
}

@@ -103,7 +111,7 @@ pub mod tests {

#[test]
fn test_config() {
let config = ModelConfig::new(
let config = ModelsConfig::new(
String::from("my_key"),
true,
vec![("Llama2_7b".to_string(), PrecisionBits::F16, "".to_string())],
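
With the tuple list replaced by a dedicated struct, each model entry can now name its own device. A sketch of constructing the new configuration in code (not part of this commit; the values are placeholders and the fields are public per the struct above):

let config = ModelsConfig::new(
    String::from("my_key"),
    true,
    vec![ModelConfig {
        model_id: "Llama2_7b".to_string(),
        precision: PrecisionBits::F16,
        revision: "".to_string(),
        device_id: 0,
    }],
    "./storage".parse().unwrap(),
    true,
);
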
6 changes: 5 additions & 1 deletion atoma-inference/src/models/mod.rs
@@ -21,7 +21,11 @@ pub trait ModelTrait {
fn fetch(_fetch: &Self::Fetch) -> Result<(), ModelError> {
Ok(())
}
fn load(filenames: Vec<PathBuf>, precision: PrecisionBits) -> Result<Self, ModelError>
fn load(
filenames: Vec<PathBuf>,
precision: PrecisionBits,
device_id: usize,
) -> Result<Self, ModelError>
where
Self: Sized;
fn model_id(&self) -> ModelId;
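
Every ModelTrait implementor now receives the device ordinal at load time. A hypothetical generic helper (not part of this commit) showing the updated call from the caller's side:

fn load_on_device<M: ModelTrait>(
    filenames: Vec<PathBuf>,
    precision: PrecisionBits,
    device_id: usize,
) -> Result<M, ModelError> {
    // The extra argument flows through unchanged to whichever backend M is.
    M::load(filenames, precision, device_id)
}
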
6 changes: 3 additions & 3 deletions atoma-inference/src/service.rs
@@ -10,7 +10,7 @@ use thiserror::Error;
use crate::{
apis::{ApiError, ApiTrait},
model_thread::{ModelThreadDispatcher, ModelThreadError, ModelThreadHandle},
models::{config::ModelConfig, ModelTrait, Request, Response},
models::{config::ModelsConfig, ModelTrait, Request, Response},
};

pub struct ModelService<Req, Resp>
@@ -34,7 +34,7 @@ where
Resp: std::fmt::Debug + Response,
{
pub fn start<M, F>(
model_config: ModelConfig,
model_config: ModelsConfig,
private_key: PrivateKey,
request_receiver: Receiver<Req>,
response_sender: Sender<Resp>,
@@ -250,7 +250,7 @@ mod tests {
let (_, req_receiver) = tokio::sync::mpsc::channel::<()>(1);
let (resp_sender, _) = tokio::sync::mpsc::channel::<()>(1);

let config = ModelConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap());
let config = ModelsConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap());

let _ = ModelService::<(), ()>::start::<TestModelInstance, MockApi>(
config,
