From b6d03125a4b819012b816686b3961511e66c4073 Mon Sep 17 00:00:00 2001 From: Martin Stefcek Date: Thu, 4 Apr 2024 09:35:53 +0400 Subject: [PATCH] fixes --- Cargo.toml | 2 -- atoma-inference/Cargo.toml | 2 -- atoma-inference/src/models/candle/mod.rs | 16 ++++++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1758ec34..50995b9b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ version = "0.1.0" [workspace.dependencies] reqwest = "0.12.1" -anyhow = "1.0.81" async-trait = "0.1.78" candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" } candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", version = "0.4.2" } @@ -26,7 +25,6 @@ dotenv = "0.15.0" ed25519-consensus = "2.1.0" futures = "0.3.30" hf-hub = "0.3.2" -clap = "4.5.3" image = { version = "0.25.0", default-features = false, features = [ "jpeg", "png", diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index c6037821..3da618eb 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -4,7 +4,6 @@ version.workspace = true edition = "2021" [dependencies] -anyhow.workspace = true async-trait.workspace = true candle.workspace = true candle-flash-attn = { workspace = true, optional = true } @@ -18,7 +17,6 @@ hf-hub.workspace = true reqwest = { workspace = true, features = ["json"] } serde = { workspace = true, features = ["derive"] } serde_json.workspace = true -clap.workspace = true image = { workspace = true } thiserror.workspace = true tokenizers = { workspace = true, features = ["onig"] } diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs index de38d19d..a10f6f6f 100644 --- a/atoma-inference/src/models/candle/mod.rs +++ b/atoma-inference/src/models/candle/mod.rs @@ -6,6 +6,10 @@ use candle::{ }; use tracing::info; +use crate::bail; + +use super::ModelError; + pub mod mamba; pub mod stable_diffusion; @@ -25,15 +29,15 @@ pub fn device() -> Result { pub fn hub_load_safetensors( repo: &hf_hub::api::sync::ApiRepo, json_file: &str, -) -> candle::Result> { +) -> Result, ModelError> { let json_file = repo.get(json_file).map_err(candle::Error::wrap)?; let json_file = std::fs::File::open(json_file)?; let json: serde_json::Value = serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; let weight_map = match json.get("weight_map") { - None => candle::bail!("no weight map in {json_file:?}"), + None => bail!("no weight map in {json_file:?}"), Some(serde_json::Value::Object(map)) => map, - Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), + Some(_) => bail!("weight map in {json_file:?} is not a map"), }; let mut safetensors_files = std::collections::HashSet::new(); for value in weight_map.values() { @@ -48,18 +52,18 @@ pub fn hub_load_safetensors( Ok(safetensors_files) } -pub fn save_image>(img: &Tensor, p: P) -> candle::Result<()> { +pub fn save_image>(img: &Tensor, p: P) -> Result<(), ModelError> { let p = p.as_ref(); let (channel, height, width) = img.dims3()?; if channel != 3 { - candle::bail!("save_image expects an input of shape (3, height, width)") + bail!("save_image expects an input of shape (3, height, width)") } let img = img.permute((1, 2, 0))?.flatten_all()?; let pixels = img.to_vec1::()?; let image: image::ImageBuffer, Vec> = match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { Some(image) => image, - None => candle::bail!("error saving image {p:?}"), + None => bail!("error saving image {p:?}"), }; image.save(p).map_err(candle::Error::wrap)?; Ok(())