diff --git a/atoma-inference/src/candle/mod.rs b/atoma-inference/src/candle/mod.rs
index 045a5a1a..7b5da9fe 100644
--- a/atoma-inference/src/candle/mod.rs
+++ b/atoma-inference/src/candle/mod.rs
@@ -75,7 +75,7 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> candle::Result<()> {
     Ok(())
 }
 
-pub fn save_tensor_to_file(tensor: &Tensor, path: &PathBuf) -> Result<(), candle::Error> {
+pub fn save_tensor_to_file(tensor: &Tensor, filename: &str) -> Result<(), candle::Error> {
     let json_output = serde_json::to_string(
         &tensor
             .to_device(&Device::Cpu)?
@@ -84,7 +84,7 @@ pub fn save_tensor_to_file(tensor: &Tensor, path: &PathBuf) -> Result<(), candle::Error> {
             .to_vec1::<f32>()?,
     )
     .unwrap();
-    let mut file = File::create(path)?;
+    let mut file = File::create(PathBuf::from(filename))?;
     file.write_all(json_output.as_bytes())?;
     Ok(())
 }
diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/candle/stable_diffusion.rs
index fa5409de..766f64dd 100644
--- a/atoma-inference/src/candle/stable_diffusion.rs
+++ b/atoma-inference/src/candle/stable_diffusion.rs
@@ -12,7 +12,7 @@ use tokenizers::Tokenizer;
 
 use crate::{candle::device, models::ModelError};
 
-use super::CandleModel;
+use super::{save_tensor_to_file, CandleModel};
 
 pub struct Input {
     prompt: String,
@@ -303,10 +303,13 @@ impl CandleModel for StableDiffusion {
                 latents = scheduler.step(&noise_pred, timestep, &latents)?;
             }
-
+            save_tensor_to_file(&latents, "tensor1");
             let image = vae.decode(&(&latents / vae_scale)?)?;
+            save_tensor_to_file(&image, "tensor2");
             let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+            save_tensor_to_file(&image, "tensor3");
             let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
+            save_tensor_to_file(&image, "tensor4");
             res.push(image);
         }
         Ok(res)
     }
diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs
index 46ae6c83..1c19d93d 100644
--- a/atoma-inference/src/main.rs
+++ b/atoma-inference/src/main.rs
@@ -27,8 +27,8 @@ fn llama1() {
 }
 
 fn main() {
-    // stable_diffusion();
-    llama1();
+    stable_diffusion();
+    // llama1();
     // let result = llama::run("One day I will").unwrap();
     // println!("{}", result);
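
A minimal sketch of exercising the new &str-based helper, assuming save_tensor_to_file is in scope (e.g. via the crate's candle module) and that the tensor ends up rank-1 by the time the helper calls to_vec1::<f32>; the function name smoke_test and the filename "tensor_debug" are illustrative, not part of this change:

use candle::{Device, Tensor};

fn smoke_test() -> Result<(), candle::Error> {
    // Build a small 1-D f32 tensor on the CPU so the helper's
    // to_vec1::<f32>() call succeeds without any reshaping.
    let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
    // The new signature takes a &str instead of &PathBuf,
    // so a string literal works directly.
    save_tensor_to_file(&t, "tensor_debug")?;
    Ok(())
}

Note that the debug calls added in stable_diffusion.rs discard the returned Result; that compiles, but Result is #[must_use], so each call raises an unused_must_use warning, and propagating with ? would surface I/O failures instead of silently dropping them.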