Merge pull request #24 from atoma-network/change-to-serde-json-value
feat: change request/response to json value
jorgeantonio21 authored Apr 5, 2024
2 parents c2f064c + 6419af3 commit f5538b1
Showing 11 changed files with 152 additions and 162 deletions.
39 changes: 19 additions & 20 deletions atoma-inference/src/main.rs
@@ -3,20 +3,16 @@ use std::time::Duration;
 use ed25519_consensus::SigningKey as PrivateKey;
 use hf_hub::api::sync::Api;
 use inference::{
-    models::{
-        candle::mamba::MambaModel,
-        config::ModelsConfig,
-        types::{TextRequest, TextResponse},
-    },
+    models::{candle::mamba::MambaModel, config::ModelsConfig, types::TextRequest},
     service::{ModelService, ModelServiceError},
 };

 #[tokio::main]
 async fn main() -> Result<(), ModelServiceError> {
     tracing_subscriber::fmt::init();

-    let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<TextRequest>(32);
-    let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<TextResponse>(32);
+    let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<serde_json::Value>(32);
+    let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<serde_json::Value>(32);

     let model_config = ModelsConfig::from_file_path("../inference.toml".parse().unwrap());
     let private_key_bytes =
@@ -44,19 +40,22 @@ async fn main() -> Result<(), ModelServiceError> {
     tokio::time::sleep(Duration::from_millis(5000)).await;

     req_sender
-        .send(TextRequest {
-            request_id: 0,
-            prompt: "Leon, the professional is a movie".to_string(),
-            model: "state-spaces/mamba-130m".to_string(),
-            max_tokens: 512,
-            temperature: Some(0.0),
-            random_seed: 42,
-            repeat_last_n: 64,
-            repeat_penalty: 1.1,
-            sampled_nodes: vec![pk],
-            top_p: Some(1.0),
-            top_k: 10,
-        })
+        .send(
+            serde_json::to_value(TextRequest {
+                request_id: 0,
+                prompt: "Leon, the professional is a movie".to_string(),
+                model: "state-spaces/mamba-130m".to_string(),
+                max_tokens: 512,
+                temperature: Some(0.0),
+                random_seed: 42,
+                repeat_last_n: 64,
+                repeat_penalty: 1.1,
+                sampled_nodes: vec![pk],
+                top_p: Some(1.0),
+                top_k: 10,
+            })
+            .unwrap(),
+        )
         .await
         .expect("Failed to send request");

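Note on the main.rs change above: the request now travels over the mpsc channel as a serde_json::Value and is built with serde_json::to_value before sending. Below is a minimal, self-contained sketch of that round trip; DemoRequest is a simplified stand-in for the crate's TextRequest (the real type has more fields and lives in inference::models::types).

    use serde::{Deserialize, Serialize};

    // Simplified stand-in for TextRequest, for illustration only.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct DemoRequest {
        request_id: u64,
        prompt: String,
        model: String,
        max_tokens: usize,
    }

    fn main() -> Result<(), serde_json::Error> {
        let request = DemoRequest {
            request_id: 0,
            prompt: "Leon, the professional is a movie".to_string(),
            model: "state-spaces/mamba-130m".to_string(),
            max_tokens: 512,
        };

        // What the sender now does: turn the typed request into a JSON value
        // before pushing it onto the channel.
        let value: serde_json::Value = serde_json::to_value(&request)?;

        // What the receiving model thread does: recover the typed input.
        let decoded: DemoRequest = serde_json::from_value(value)?;
        assert_eq!(decoded, request);
        Ok(())
    }

The trade-off is that type checking of the payload moves from compile time to run time, which is why model_thread.rs gains a SerdeError variant in the diff below.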
112 changes: 49 additions & 63 deletions atoma-inference/src/model_thread.rs
@@ -11,13 +11,13 @@ use tracing::{debug, error, info, warn};

 use crate::{
     apis::{ApiError, ApiTrait},
-    models::{config::ModelsConfig, ModelError, ModelId, ModelTrait, Request, Response},
+    models::{config::ModelsConfig, ModelError, ModelId, ModelTrait},
 };

-pub struct ModelThreadCommand<Req, Resp>(Req, oneshot::Sender<Resp>)
-where
-    Req: Request,
-    Resp: Response;
+pub struct ModelThreadCommand {
+    request: serde_json::Value,
+    response_sender: oneshot::Sender<serde_json::Value>,
+}

 #[derive(Debug, Error)]
 pub enum ModelThreadError {
@@ -27,6 +27,8 @@ pub enum ModelThreadError {
     ModelError(ModelError),
     #[error("Core thread shutdown: `{0}`")]
     Shutdown(RecvError),
+    #[error("Serde error: `{0}`")]
+    SerdeError(#[from] serde_json::Error),
 }

 impl From<ModelError> for ModelThreadError {
@@ -41,82 +43,68 @@ impl From<ApiError> for ModelThreadError {
     }
 }

-pub struct ModelThreadHandle<Req, Resp>
-where
-    Req: Request,
-    Resp: Response,
-{
-    sender: mpsc::Sender<ModelThreadCommand<Req, Resp>>,
+pub struct ModelThreadHandle {
+    sender: mpsc::Sender<ModelThreadCommand>,
     join_handle: std::thread::JoinHandle<Result<(), ModelThreadError>>,
 }

-impl<Req, Resp> ModelThreadHandle<Req, Resp>
-where
-    Req: Request,
-    Resp: Response,
-{
+impl ModelThreadHandle {
     pub fn stop(self) {
         drop(self.sender);
         self.join_handle.join().ok();
     }
 }

-pub struct ModelThread<M: ModelTrait, Req: Request, Resp: Response> {
+pub struct ModelThread<M: ModelTrait> {
     model: M,
-    receiver: mpsc::Receiver<ModelThreadCommand<Req, Resp>>,
+    receiver: mpsc::Receiver<ModelThreadCommand>,
 }

-impl<M, Req, Resp> ModelThread<M, Req, Resp>
+impl<M> ModelThread<M>
 where
-    M: ModelTrait<Input = Req::ModelInput, Output = Resp::ModelOutput>,
-    Req: Request,
-    Resp: Response,
+    M: ModelTrait,
 {
-    pub fn run(mut self, public_key: PublicKey) -> Result<(), ModelThreadError> {
+    pub fn run(mut self, _public_key: PublicKey) -> Result<(), ModelThreadError> {
         debug!("Start Model thread");

         while let Ok(command) = self.receiver.recv() {
-            let ModelThreadCommand(request, sender) = command;
-
-            if !request.is_node_authorized(&public_key) {
-                error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id());
-                continue;
-            }
-
-            let model_input = request.into_model_input();
+            let ModelThreadCommand {
+                request,
+                response_sender,
+            } = command;
+
+            // TODO: Implement node authorization
+            // if !request.is_node_authorized(&public_key) {
+            //     error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id());
+            //     continue;
+            // }
+
+            let model_input = serde_json::from_value(request).unwrap();
             let model_output = self
                 .model
                 .run(model_input)
                 .map_err(ModelThreadError::ModelError)?;
-            let response = Resp::from_model_output(model_output);
-            sender.send(response).ok();
+            let response = serde_json::to_value(model_output)?;
+            response_sender.send(response).ok();
         }

         Ok(())
     }
 }

-pub struct ModelThreadDispatcher<Req, Resp>
-where
-    Req: Request,
-    Resp: Response,
-{
-    model_senders: HashMap<ModelId, mpsc::Sender<ModelThreadCommand<Req, Resp>>>,
-    pub(crate) responses: FuturesUnordered<oneshot::Receiver<Resp>>,
+pub struct ModelThreadDispatcher {
+    model_senders: HashMap<ModelId, mpsc::Sender<ModelThreadCommand>>,
+    pub(crate) responses: FuturesUnordered<oneshot::Receiver<serde_json::Value>>,
 }

-impl<Req, Resp> ModelThreadDispatcher<Req, Resp>
-where
-    Req: Clone + Request,
-    Resp: Response,
-{
+impl ModelThreadDispatcher {
     pub(crate) fn start<M, F>(
         config: ModelsConfig,
         public_key: PublicKey,
-    ) -> Result<(Self, Vec<ModelThreadHandle<Req, Resp>>), ModelThreadError>
+    ) -> Result<(Self, Vec<ModelThreadHandle>), ModelThreadError>
     where
         F: ApiTrait + Send + Sync + 'static,
-        M: ModelTrait<Input = Req::ModelInput, Output = Resp::ModelOutput> + Send + 'static,
+        M: ModelTrait, //<Input = Req::ModelInput, Output = Resp::ModelOutput> + Send + 'static,
     {
         let api_key = config.api_key();
         let storage_path = config.storage_path();
@@ -129,19 +117,16 @@ where
             info!("Spawning new thread for model: {}", model_config.model_id());
             let api = api.clone();

-            let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand<_, _>>();
+            let (model_sender, model_receiver) = mpsc::channel::<ModelThreadCommand>();
             let model_name = model_config.model_id().clone();
             model_senders.insert(model_name.clone(), model_sender.clone());

             let join_handle = std::thread::spawn(move || {
                 info!("Fetching files for model: {model_name}");
                 let filenames = api.fetch(model_name, model_config.revision())?;
+                let x = serde_json::from_value(model_config.params().clone()).unwrap();

-                let model = M::load(
-                    filenames,
-                    model_config.precision(),
-                    model_config.device_id(),
-                )?;
+                let model = M::load(filenames, x, model_config.device_id())?;
                 let model_thread = ModelThread {
                     model,
                     receiver: model_receiver,
@@ -170,10 +155,12 @@ where
         Ok((model_dispatcher, handles))
     }

-    fn send(&self, command: ModelThreadCommand<Req, Resp>) {
-        let request = command.0.clone();
-        let model_id = request.requested_model();
+    fn send(&self, command: ModelThreadCommand) {
+        let request = command.request.clone();
+        let model_id = request.get("model").unwrap().as_str().unwrap().to_string();
+        println!("model_id {model_id}");

+        println!("{:?}", self.model_senders);
         let sender = self
             .model_senders
             .get(&model_id)
@@ -185,14 +172,13 @@ where
     }
 }

-impl<T, U> ModelThreadDispatcher<T, U>
-where
-    T: Clone + Request,
-    U: Response,
-{
-    pub(crate) fn run_inference(&self, request: T) {
+impl ModelThreadDispatcher {
+    pub(crate) fn run_inference(&self, request: serde_json::Value) {
         let (sender, receiver) = oneshot::channel();
-        self.send(ModelThreadCommand(request, sender));
+        self.send(ModelThreadCommand {
+            request,
+            response_sender: sender,
+        });
         self.responses.push(receiver);
     }
 }
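Note on the dispatcher above: send() now picks the target model thread by reading the "model" field straight out of the untyped JSON request, with unwrap() calls as committed. The sketch below shows the same lookup written defensively; resolve_model_id and DispatchError are hypothetical names used only for illustration, not part of the codebase.

    // Hedged sketch: resolve the model id from an untyped request without panicking.
    #[derive(Debug)]
    enum DispatchError {
        MissingModelField,
    }

    fn resolve_model_id(request: &serde_json::Value) -> Result<String, DispatchError> {
        request
            .get("model")
            .and_then(|value| value.as_str())
            .map(str::to_string)
            .ok_or(DispatchError::MissingModelField)
    }

    fn main() {
        let request = serde_json::json!({ "model": "state-spaces/mamba-130m", "prompt": "hi" });
        match resolve_model_id(&request) {
            Ok(model_id) => println!("route to the thread registered for {model_id}"),
            Err(err) => eprintln!("cannot dispatch request: {err:?}"),
        }
    }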
3 changes: 2 additions & 1 deletion atoma-inference/src/models/candle/falcon.rs
@@ -49,14 +49,15 @@ impl ModelTrait for FalconModel {
     type Fetch = ();
     type Input = TextModelInput;
     type Output = String;
+    type Load = PrecisionBits;

     fn fetch(_fetch: &Self::Fetch) -> Result<(), ModelError> {
         Ok(())
     }

     fn load(
         filenames: Vec<PathBuf>,
-        precision: PrecisionBits,
+        precision: Self::Load,
         device_id: usize,
     ) -> Result<Self, ModelError>
     where
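Note on the falcon.rs change: the new Load associated type lets each model declare the parameter type it is constructed from, which model_thread.rs now supplies via serde_json::from_value(model_config.params()). A rough reconstruction of the trait shape implied by this commit follows; the exact signatures and the ModelError stand-in are assumptions, not the verbatim definitions.

    use std::path::PathBuf;

    // Stand-in for the crate's real error type.
    #[derive(Debug)]
    struct ModelError;

    // Approximate shape of ModelTrait after this commit: `Load` carries the
    // model-specific parameters, replacing the hard-coded PrecisionBits argument
    // that load() previously took.
    trait ModelTrait: Sized {
        type Fetch;
        type Input;
        type Output;
        type Load;

        fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError>;
        fn load(filenames: Vec<PathBuf>, params: Self::Load, device_id: usize) -> Result<Self, ModelError>;
        fn run(&mut self, input: Self::Input) -> Result<Self::Output, ModelError>;
    }

    fn main() {} // sketch only; nothing to execute here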

