Skip to content

Commit

Permalink
simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeantonio21 committed Mar 31, 2024
1 parent 4a12b71 commit cddb534
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
5 changes: 4 additions & 1 deletion atoma-inference/src/model_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ where
U: Response,
{
pub(crate) fn start<M, F>(
api: F,
config: ModelConfig,
public_key: PublicKey,
) -> Result<(Self, Vec<ModelThreadHandle<T, U>>), ModelThreadError>
Expand All @@ -119,6 +118,10 @@ where
+ 'static,
{
let model_ids = config.model_ids();
let api_key = config.api_key();
let storage_path = config.storage_path();
let api = F::create(api_key, storage_path)?;

let mut handles = Vec::with_capacity(model_ids.len());
let mut model_senders = HashMap::with_capacity(model_ids.len());

Expand Down
14 changes: 7 additions & 7 deletions atoma-inference/src/models/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ use crate::models::ModelId;
pub struct ModelConfig {
api_key: String,
models: Vec<ModelId>,
storage_folder: PathBuf,
storage_path: PathBuf,
tracing: bool,
}

impl ModelConfig {
pub fn new(
api_key: String,
models: Vec<ModelId>,
storage_folder: PathBuf,
storage_path: PathBuf,
tracing: bool,
) -> Self {
Self {
api_key,
models,
storage_folder,
storage_path,
tracing,
}
}
Expand All @@ -36,8 +36,8 @@ impl ModelConfig {
self.models.clone()
}

pub fn storage_folder(&self) -> PathBuf {
self.storage_folder.clone()
pub fn storage_path(&self) -> PathBuf {
self.storage_path.clone()
}

pub fn tracing(&self) -> bool {
Expand Down Expand Up @@ -66,12 +66,12 @@ pub mod tests {
let config = ModelConfig::new(
String::from("my_key"),
vec!["Llama2_7b".to_string()],
"storage_folder".parse().unwrap(),
"storage_path".parse().unwrap(),
true,
);

let toml_str = toml::to_string(&config).unwrap();
let should_be_toml_str = "api_key = \"my_key\"\nmodels = [\"Llama2_7b\"]\nstorage_folder = \"storage_folder\"\ntracing = true\n";
let should_be_toml_str = "api_key = \"my_key\"\nmodels = [\"Llama2_7b\"]\nstorage_path = \"storage_path\"\ntracing = true\n";
assert_eq!(toml_str, should_be_toml_str);
}
}
8 changes: 2 additions & 6 deletions atoma-inference/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,9 @@ where
let private_key = PrivateKey::from(private_key_bytes);
let public_key = private_key.verification_key();
let model_config = ModelConfig::from_file_path(config_file_path);
let api_key = model_config.api_key();
let storage_folder = model_config.storage_folder();

let api = F::create(api_key, storage_folder)?;

let (dispatcher, model_thread_handle) =
ModelThreadDispatcher::start::<M, F>(api, model_config, public_key)
ModelThreadDispatcher::start::<M, F>(model_config, public_key)
.map_err(ModelServiceError::ModelThreadError)?;
let start_time = Instant::now();

Expand Down Expand Up @@ -236,7 +232,7 @@ mod tests {
let config_data = Value::Table(toml! {
api_key = "your_api_key"
models = ["Mamba3b"]
storage_folder = "./storage_folder/"
storage_path = "./storage_path/"
tokenizer_file_path = "./tokenizer_file_path/"
tracing = true
});
Expand Down

0 comments on commit cddb534

Please sign in to comment.