Skip to content

Commit

Permalink
f
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Mar 28, 2024
1 parent 96ada78 commit 4205697
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
4 changes: 2 additions & 2 deletions atoma-inference/src/candle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> candle::Resu
Ok(())
}

pub fn save_tensor_to_file(tensor: &Tensor, path: &PathBuf) -> Result<(), candle::Error> {
pub fn save_tensor_to_file(tensor: &Tensor, filename: &str) -> Result<(), candle::Error> {
let json_output = serde_json::to_string(
&tensor
.to_device(&Device::Cpu)?
Expand All @@ -84,7 +84,7 @@ pub fn save_tensor_to_file(tensor: &Tensor, path: &PathBuf) -> Result<(), candle
.to_vec1::<f64>()?,
)
.unwrap();
let mut file = File::create(path)?;
let mut file = File::create(PathBuf::from(filename))?;
file.write_all(json_output.as_bytes())?;
Ok(())
}
7 changes: 5 additions & 2 deletions atoma-inference/src/candle/stable_diffusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tokenizers::Tokenizer;

use crate::{candle::device, models::ModelError};

use super::CandleModel;
use super::{save_tensor_to_file, CandleModel};

pub struct Input {
prompt: String,
Expand Down Expand Up @@ -303,10 +303,13 @@ impl CandleModel for StableDiffusion {

latents = scheduler.step(&noise_pred, timestep, &latents)?;
}

save_tensor_to_file(&latents, "tensor1");
let image = vae.decode(&(&latents / vae_scale)?)?;
save_tensor_to_file(&image, "tensor2");
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
save_tensor_to_file(&image, "tensor3");
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
save_tensor_to_file(&image, "tensor4");
res.push(image);
}
Ok(res)
Expand Down
4 changes: 2 additions & 2 deletions atoma-inference/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ fn llama1() {
}

fn main() {
// stable_diffusion();
llama1();
stable_diffusion();
// llama1();
// let result = llama::run("One day I will").unwrap();
// println!("{}", result);

Expand Down

0 comments on commit 4205697

Please sign in to comment.