Skip to content

Commit

Permalink
remove full dependency of std::sync
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeantonio21 committed Mar 27, 2024
1 parent 1cdb66a commit b56d0b5
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions atoma-inference/src/model_thread.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::{collections::HashMap, sync::mpsc};

use candle_nn::VarBuilder;
use ed25519_consensus::VerificationKey as PublicKey;
Expand All @@ -23,7 +23,7 @@ pub enum ModelThreadError {
}

pub struct ModelThreadHandle {
sender: std::sync::mpsc::Sender<ModelThreadCommand>,
sender: mpsc::Sender<ModelThreadCommand>,
join_handle: std::thread::JoinHandle<()>,
}

Expand All @@ -36,7 +36,7 @@ impl ModelThreadHandle {

pub struct ModelThread<T: ModelApi> {
model: T,
receiver: std::sync::mpsc::Receiver<ModelThreadCommand>,
receiver: mpsc::Receiver<ModelThreadCommand>,
}

impl<T> ModelThread<T>
Expand Down Expand Up @@ -85,7 +85,7 @@ where
}

pub struct ModelThreadDispatcher {
model_senders: HashMap<ModelType, std::sync::mpsc::Sender<ModelThreadCommand>>,
model_senders: HashMap<ModelType, mpsc::Sender<ModelThreadCommand>>,
pub(crate) responses: FuturesUnordered<oneshot::Receiver<InferenceResponse>>,
}

Expand All @@ -101,7 +101,7 @@ impl ModelThreadDispatcher {
let mut model_senders = HashMap::with_capacity(models.len());

for (model_type, model_specs, var_builder) in models {
let (model_sender, model_receiver) = std::sync::mpsc::channel::<ModelThreadCommand>();
let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand>();
let model = T::load(model_specs, var_builder); // TODO: for now this piece of code cannot be shared among threads safely
let model_thread = ModelThread {
model,
Expand Down

0 comments on commit b56d0b5

Please sign in to comment.