Skip to content

Commit

Permalink
t
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Mar 28, 2024
1 parent c1dde56 commit edc5690
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions atoma-inference/src/candle/stable_diffusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ impl CandleModel for StableDiffusion {
.to_vec1::<f64>()?,
)
.unwrap();
println!("time step: {}", timestep);
let mut file = File::create(format!("output-{timestep_index}.json"))?;
file.write_all(json_output.as_bytes())?;

Expand All @@ -315,9 +316,35 @@ impl CandleModel for StableDiffusion {

let latent_model_input =
scheduler.scale_model_input(latent_model_input, timestep)?;
let json_output = serde_json::to_string(
&latent_model_input
.clone()
.to_device(&Device::Cpu)?
.flatten_all()?
.to_dtype(DType::F64)?
.to_vec1::<f64>()?,
)
.unwrap();
println!("time step: {}", timestep);
let mut file = File::create(format!("latent_model_input-{timestep_index}.json"))?;
file.write_all(json_output.as_bytes())?;

let noise_pred =
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;

let json_output = serde_json::to_string(
&noise_pred
.clone()
.to_device(&Device::Cpu)?
.flatten_all()?
.to_dtype(DType::F64)?
.to_vec1::<f64>()?,
)
.unwrap();
println!("time step: {}", timestep);
let mut file = File::create(format!("noise_pred-{timestep_index}.json"))?;
file.write_all(json_output.as_bytes())?;

let noise_pred = if use_guide_scale {
let noise_pred = noise_pred.chunk(2, 0)?;
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
Expand Down

0 comments on commit edc5690

Please sign in to comment.