Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeantonio21 committed Apr 2, 2024
1 parent 32975c3 commit f5b92af
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
14 changes: 12 additions & 2 deletions atoma-inference/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::time::Duration;

use ed25519_consensus::SigningKey as PrivateKey;
use hf_hub::api::sync::Api;
use inference::{
models::{
candle::mamba::MambaModel,
config::ModelConfig,
types::{TextRequest, TextResponse},
},
service::{ModelService, ModelServiceError},
Expand All @@ -16,9 +18,17 @@ async fn main() -> Result<(), ModelServiceError> {
let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<TextRequest>(32);
let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<TextResponse>(32);

let model_config = ModelConfig::from_file_path("../inference.toml".parse().unwrap());
let private_key_bytes =
std::fs::read("../private_key").map_err(ModelServiceError::PrivateKeyError)?;
let private_key_bytes: [u8; 32] = private_key_bytes
.try_into()
.expect("Incorrect private key bytes length");

let private_key = PrivateKey::from(private_key_bytes);
let mut service = ModelService::start::<MambaModel, Api>(
"../inference.toml".parse().unwrap(),
"../private_key".parse().unwrap(),
model_config,
private_key,
req_receiver,
resp_sender,
)
Expand Down
21 changes: 6 additions & 15 deletions atoma-inference/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,16 @@ where
Resp: std::fmt::Debug + Response,
{
pub fn start<M, F>(
config_file_path: PathBuf,
private_key_path: PathBuf,
model_config: ModelConfig,
private_key: PrivateKey,
request_receiver: Receiver<Req>,
response_sender: Sender<Resp>,
) -> Result<Self, ModelServiceError>
where
M: ModelTrait<Input = Req::ModelInput, Output = Resp::ModelOutput> + Send + 'static,
F: ApiTrait + Send + Sync + 'static,
{
let private_key_bytes =
std::fs::read(private_key_path).map_err(ModelServiceError::PrivateKeyError)?;
let private_key_bytes: [u8; 32] = private_key_bytes
.try_into()
.expect("Incorrect private key bytes length");

let private_key = PrivateKey::from(private_key_bytes);
let public_key = private_key.verification_key();
let model_config = ModelConfig::from_file_path(config_file_path);

let flush_storage = model_config.flush_storage();
let storage_path = model_config.storage_path();
Expand Down Expand Up @@ -232,10 +224,8 @@ mod tests {
#[tokio::test]
async fn test_inference_service_initialization() {
const CONFIG_FILE_PATH: &str = "./inference.toml";
const PRIVATE_KEY_FILE_PATH: &str = "./private_key";

let private_key = PrivateKey::new(OsRng);
std::fs::write(PRIVATE_KEY_FILE_PATH, private_key.to_bytes()).unwrap();

let config_data = Value::Table(toml! {
api_key = "your_api_key"
Expand All @@ -255,15 +245,16 @@ mod tests {
let (_, req_receiver) = tokio::sync::mpsc::channel::<()>(1);
let (resp_sender, _) = tokio::sync::mpsc::channel::<()>(1);

let config = ModelConfig::from_file_path(CONFIG_FILE_PATH.parse().unwrap());

let _ = ModelService::<(), ()>::start::<TestModelInstance, MockApi>(
PathBuf::from(CONFIG_FILE_PATH),
PathBuf::from(PRIVATE_KEY_FILE_PATH),
config,
private_key,
req_receiver,
resp_sender,
)
.unwrap();

std::fs::remove_file(CONFIG_FILE_PATH).unwrap();
std::fs::remove_file(PRIVATE_KEY_FILE_PATH).unwrap();
}
}

0 comments on commit f5b92af

Please sign in to comment.