Skip to content

Commit

Permalink
t
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Mar 28, 2024
1 parent edc5690 commit b9ad0da
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion atoma-inference/src/candle/stable_diffusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ extern crate intel_mkl_src;

use candle_transformers::models::stable_diffusion::{self, StableDiffusionConfig};
use image::open;
use std::path::PathBuf;
use std::{fs::File, io::Write, path::PathBuf};

use candle::{DType, Device, IndexOp, Module, Tensor, D};
use tokenizers::Tokenizer;
Expand Down Expand Up @@ -223,6 +223,17 @@ impl CandleModel for StableDiffusion {
.collect::<Result<Vec<_>, _>>()?;

let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
let json_output = serde_json::to_string(
&text_embeddings
.clone()
.to_device(&Device::Cpu)?
.flatten_all()?
.to_dtype(DType::F64)?
.to_vec1::<f64>()?,
)
.unwrap();
let mut file = File::create(format!("text_embeddings.json"))?;
file.write_all(json_output.as_bytes())?;

let vae_weights = ModelFile::Vae.get(input.vae_weights, input.sd_version, input.use_f16)?;
let vae = sd_config.build_vae(vae_weights, &device, dtype)?;
Expand Down

0 comments on commit b9ad0da

Please sign in to comment.