Commit: resolve conflicts after merging main
Showing 17 changed files with 704 additions and 336 deletions.
One of the changed files was deleted in this commit; its contents could not be rendered.
@@ -1,7 +1,6 @@
-pub mod apis;
-pub mod candle;
 pub mod model_thread;
-pub mod models;
 pub mod service;
 pub mod specs;
-pub mod types;
+
+pub mod apis;
+pub mod models;
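
The import paths used in the second diff below (inference::models::candle::mamba::MambaModel, inference::models::config::ModelConfig, inference::models::types::{TextRequest, TextResponse}) suggest that the modules dropped from the crate root now live under models. As a rough, hypothetical sketch only (the actual models/mod.rs is not part of this commit's rendered diff), the submodules implied by those paths would be declared roughly like this:

// Hypothetical models/mod.rs implied by the import paths in the next diff;
// not taken from the repository and possibly different from the real file.
pub mod candle;  // candle-based model implementations, e.g. candle::mamba::MambaModel
pub mod config;  // ModelConfig and ModelConfig::from_file_path
pub mod types;   // TextRequest and TextResponse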
@@ -1,16 +1,68 @@
-// use hf_hub::api::sync::Api;
-// use inference::service::ModelService;
+use std::time::Duration;
+
+use ed25519_consensus::SigningKey as PrivateKey;
+use hf_hub::api::sync::Api;
+use inference::{
+    models::{
+        candle::mamba::MambaModel,
+        config::ModelConfig,
+        types::{TextRequest, TextResponse},
+    },
+    service::{ModelService, ModelServiceError},
+};
 
 #[tokio::main]
-async fn main() {
+async fn main() -> Result<(), ModelServiceError> {
     tracing_subscriber::fmt::init();
 
-    // let (_, receiver) = tokio::sync::mpsc::channel(32);
+    let (req_sender, req_receiver) = tokio::sync::mpsc::channel::<TextRequest>(32);
+    let (resp_sender, mut resp_receiver) = tokio::sync::mpsc::channel::<TextResponse>(32);
+
+    let model_config = ModelConfig::from_file_path("../inference.toml".parse().unwrap());
+    let private_key_bytes =
+        std::fs::read("../private_key").map_err(ModelServiceError::PrivateKeyError)?;
+    let private_key_bytes: [u8; 32] = private_key_bytes
+        .try_into()
+        .expect("Incorrect private key bytes length");
+
+    let private_key = PrivateKey::from(private_key_bytes);
+    let mut service = ModelService::start::<MambaModel, Api>(
+        model_config,
+        private_key,
+        req_receiver,
+        resp_sender,
+    )
+    .expect("Failed to start inference service");
+
+    let pk = service.public_key();
+
+    tokio::spawn(async move {
+        service.run().await?;
+        Ok::<(), ModelServiceError>(())
+    });
+
+    tokio::time::sleep(Duration::from_millis(5000)).await;
+
+    req_sender
+        .send(TextRequest {
+            request_id: 0,
+            prompt: "Leon, the professional is a movie".to_string(),
+            model: "state-spaces/mamba-130m".to_string(),
+            max_tokens: 512,
+            temperature: Some(0.0),
+            random_seed: 42,
+            repeat_last_n: 64,
+            repeat_penalty: 1.1,
+            sampled_nodes: vec![pk],
+            top_p: Some(1.0),
+            top_k: 10,
+        })
+        .await
+        .expect("Failed to send request");
+
+    if let Some(response) = resp_receiver.recv().await {
+        println!("Got a response: {:?}", response);
+    }
 
-    // let _ = ModelService::start::<Model, Api>(
-    //     "../inference.toml".parse().unwrap(),
-    //     "../private_key".parse().unwrap(),
-    //     receiver,
-    // )
-    // .expect("Failed to start inference service");
+    Ok(())
 }
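
The new example wires two tokio mpsc channels: requests flow from the caller into the service, responses flow back out, and the service itself runs on a spawned task. Below is a minimal, self-contained sketch of that same request/response pattern with a dummy worker standing in for ModelService; the EchoRequest and EchoResponse types and the echo logic are made up for illustration and are not part of the inference crate.

use tokio::sync::mpsc;

// Hypothetical request/response types standing in for TextRequest/TextResponse.
#[derive(Debug)]
struct EchoRequest {
    request_id: u64,
    prompt: String,
}

#[derive(Debug)]
struct EchoResponse {
    request_id: u64,
    text: String,
}

#[tokio::main]
async fn main() {
    // Same shape as (req_sender, req_receiver) / (resp_sender, resp_receiver) above.
    let (req_tx, mut req_rx) = mpsc::channel::<EchoRequest>(32);
    let (resp_tx, mut resp_rx) = mpsc::channel::<EchoResponse>(32);

    // The spawned worker plays the role of service.run(): it owns the request
    // receiver and the response sender.
    tokio::spawn(async move {
        while let Some(req) = req_rx.recv().await {
            let reply = EchoResponse {
                request_id: req.request_id,
                text: format!("echo: {}", req.prompt),
            };
            if resp_tx.send(reply).await.is_err() {
                break; // response receiver dropped; stop the worker
            }
        }
    });

    req_tx
        .send(EchoRequest {
            request_id: 0,
            prompt: "hello".to_string(),
        })
        .await
        .expect("worker dropped the request channel");

    if let Some(reply) = resp_rx.recv().await {
        println!("Got a response: {:?}", reply);
    }
}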
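The example also reads a raw 32-byte Ed25519 signing key from ../private_key and converts it with PrivateKey::from. A hedged sketch of one way to produce such a file, assuming the rand 0.8 crate alongside ed25519_consensus (the file name mirrors the example; nothing here is part of the repository):

// Hypothetical key-generation helper; writes the 32 raw bytes that the
// example's std::fs::read("../private_key") call expects.
use ed25519_consensus::SigningKey;

fn main() -> std::io::Result<()> {
    // SigningKey::new draws 32 random bytes from the provided CSPRNG.
    let key = SigningKey::new(rand::thread_rng());
    std::fs::write("../private_key", key.to_bytes())?;
    Ok(())
}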