Skip to content

Commit

Permalink
test2
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Mar 28, 2024
1 parent c5f3964 commit e9cd332
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
28 changes: 27 additions & 1 deletion atoma-inference/src/candle/stable_diffusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ extern crate accelerate_src;
extern crate intel_mkl_src;

use candle_transformers::models::stable_diffusion::{self, StableDiffusionConfig};
use image::open;
use std::path::PathBuf;

use candle::{DType, Device, IndexOp, Module, Tensor, D};
Expand Down Expand Up @@ -271,11 +272,36 @@ impl CandleModel for StableDiffusion {
&device,
)?;
// scale the initial noise by the standard deviation required by the scheduler
(latents * scheduler.init_noise_sigma())?
let sigma = scheduler.init_noise_sigma();
println!("sigma: {}", sigma);
(latents * sigma)?
}
};
let mut latents = latents.to_dtype(dtype)?;

use std::fs::File;
use std::io::Write;

let json_output = serde_json::to_string(
&latents
.clone()
.to_device(&Device::Cpu)?
.flatten_all()?
.to_dtype(DType::F64)?
.to_vec1::<f64>()?,
)
.unwrap();
let mut file = File::create("output.json")?;
file.write_all(json_output.as_bytes())?;
// "latents: {:?}",
// latents
// .clone()
// .to_device(&Device::Cpu)?
// .flatten_all()?
// .to_dtype(DType::F64)?
// .to_vec1::<f64>()?
// );

for (timestep_index, &timestep) in timesteps.iter().enumerate() {
if timestep_index < t_start {
continue;
Expand Down
4 changes: 2 additions & 2 deletions atoma-inference/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ fn llama1() {
}

fn main() {
// stable_diffusion();
llama1();
stable_diffusion();
// llama1();
// let result = llama::run("One day I will").unwrap();
// println!("{}", result);

Expand Down
1 change: 1 addition & 0 deletions output.json

Large diffs are not rendered by default.

0 comments on commit e9cd332

Please sign in to comment.