diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/candle/stable_diffusion.rs index a8a8b5c7..4122c260 100644 --- a/atoma-inference/src/candle/stable_diffusion.rs +++ b/atoma-inference/src/candle/stable_diffusion.rs @@ -304,6 +304,7 @@ impl CandleModel for StableDiffusion { .to_vec1::<f64>()?, ) .unwrap(); + println!("time step: {}", timestep); let mut file = File::create(format!("output-{timestep_index}.json"))?; file.write_all(json_output.as_bytes())?; @@ -315,9 +316,35 @@ impl CandleModel for StableDiffusion { let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; + let json_output = serde_json::to_string( + &latent_model_input + .clone() + .to_device(&Device::Cpu)? + .flatten_all()? + .to_dtype(DType::F64)? + .to_vec1::<f64>()?, + ) + .unwrap(); + println!("time step: {}", timestep); + let mut file = File::create(format!("latent_model_input-{timestep_index}.json"))?; + file.write_all(json_output.as_bytes())?; + let noise_pred = unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?; + let json_output = serde_json::to_string( + &noise_pred + .clone() + .to_device(&Device::Cpu)? + .flatten_all()? + .to_dtype(DType::F64)? + .to_vec1::<f64>()?, + ) + .unwrap(); + println!("time step: {}", timestep); + let mut file = File::create(format!("noise_pred-{timestep_index}.json"))?; + file.write_all(json_output.as_bytes())?; + let noise_pred = if use_guide_scale { let noise_pred = noise_pred.chunk(2, 0)?; let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);