Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Apr 4, 2024
1 parent d9c18bb commit bddd088
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 10 deletions.
14 changes: 9 additions & 5 deletions atoma-inference/src/model_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,22 @@ where
let mut model_senders = HashMap::new();

for model_config in config.models() {
info!("Spawning new thread for model: {}", model_config.model_id);
info!("Spawning new thread for model: {}", model_config.model_id());
let api = api.clone();

let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand<_, _>>();
let model_name = model_config.model_id.clone();
let model_name = model_config.model_id().clone();
model_senders.insert(model_name.clone(), model_sender.clone());

let join_handle = std::thread::spawn(move || {
info!("Fetching files for model: {model_name}");
let filenames = api.fetch(model_name, model_config.revision)?;
let filenames = api.fetch(model_name, model_config.revision())?;

let model = M::load(filenames, model_config.precision, model_config.device_id)?;
let model = M::load(
filenames,
model_config.precision(),
model_config.device_id(),
)?;
let model_thread = ModelThread {
model,
receiver: model_receiver,
Expand All @@ -155,7 +160,6 @@ where
join_handle,
sender: model_sender.clone(),
});
model_senders.insert(model_config.model_id, model_sender);
}

let model_dispatcher = ModelThreadDispatcher {
Expand Down
47 changes: 42 additions & 5 deletions atoma-inference/src/models/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,42 @@ type Revision = String;

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelConfig {
pub model_id: ModelId,
pub precision: PrecisionBits,
pub revision: Revision,
pub device_id: usize,
model_id: ModelId,
precision: PrecisionBits,
revision: Revision,
device_id: usize,
}

impl ModelConfig {
pub fn new(
model_id: ModelId,
precision: PrecisionBits,
revision: Revision,
device_id: usize,
) -> Self {
Self {
model_id,
precision,
revision,
device_id,
}
}

pub fn model_id(&self) -> &ModelId {
&self.model_id
}

pub fn precision(&self) -> PrecisionBits {
self.precision
}

pub fn revision(&self) -> Revision {
self.revision.clone()
}

pub fn device_id(&self) -> usize {
self.device_id
}
}

#[derive(Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -114,7 +146,12 @@ pub mod tests {
let config = ModelsConfig::new(
String::from("my_key"),
true,
vec![("Llama2_7b".to_string(), PrecisionBits::F16, "".to_string())],
vec![ModelConfig::new(
"Llama2_7b".to_string(),
PrecisionBits::F16,
"".to_string(),
0,
)],
"storage_path".parse().unwrap(),
true,
);
Expand Down

0 comments on commit bddd088

Please sign in to comment.