From b9ad0da88e5d492fa810ca58fc7ea42839061b73 Mon Sep 17 00:00:00 2001 From: Cifko Date: Thu, 28 Mar 2024 09:31:52 +0100 Subject: [PATCH] t --- atoma-inference/src/candle/stable_diffusion.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/candle/stable_diffusion.rs index 4122c260..f93a3714 100644 --- a/atoma-inference/src/candle/stable_diffusion.rs +++ b/atoma-inference/src/candle/stable_diffusion.rs @@ -6,7 +6,7 @@ extern crate intel_mkl_src; use candle_transformers::models::stable_diffusion::{self, StableDiffusionConfig}; use image::open; -use std::path::PathBuf; +use std::{fs::File, io::Write, path::PathBuf}; use candle::{DType, Device, IndexOp, Module, Tensor, D}; use tokenizers::Tokenizer; @@ -223,6 +223,17 @@ impl CandleModel for StableDiffusion { .collect::, _>>()?; let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?; + let json_output = serde_json::to_string( + &text_embeddings + .clone() + .to_device(&Device::Cpu)? + .flatten_all()? + .to_dtype(DType::F64)? + .to_vec1::()?, + ) + .unwrap(); + let mut file = File::create(format!("text_embeddings.json"))?; + file.write_all(json_output.as_bytes())?; let vae_weights = ModelFile::Vae.get(input.vae_weights, input.sd_version, input.use_f16)?; let vae = sd_config.build_vae(vae_weights, &device, dtype)?;