Skip to content

Commit

Permalink
fix2
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Mar 25, 2024
1 parent c3fd6d1 commit b74c8ae
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 32 deletions.
16 changes: 5 additions & 11 deletions atoma-inference/src/candle/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl Default for Config {
}
}

pub fn run(prompt: String, cfg: Config) -> Result<()> {
pub fn run(prompt: String, cfg: Config) -> Result<String> {
let device = device()?;
let dtype = match cfg.dtype.as_deref() {
Some("f16") => DType::F16,
Expand Down Expand Up @@ -114,6 +114,7 @@ pub fn run(prompt: String, cfg: Config) -> Result<()> {
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
let mut res = String::new();
for index in 0..cfg.sample_len {
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
Expand Down Expand Up @@ -144,18 +145,11 @@ pub fn run(prompt: String, cfg: Config) -> Result<()> {
break;
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
res += &t;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
res += &rest;
}
let dt = start_gen.elapsed();
println!(
"\n\n{} tokens generated ({} token/s)\n",
token_generated,
token_generated as f64 / dt.as_secs_f64(),
);
Ok(())
Ok(res)
}
19 changes: 7 additions & 12 deletions atoma-inference/src/candle/stable_diffusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl Default for Config {
vae_weights: None,
tokenizer: None,
sliced_attention_size: None,
n_steps: Some(1000),
n_steps: Some(20),
num_samples: 1,
final_image: "sd_final.png".to_string(),
sd_version: StableDiffusionVersion::V1_5,
Expand Down Expand Up @@ -368,7 +368,7 @@ fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> Result<Tensor, ModelE
Ok(img)
}

pub fn run(prompt: String, uncond_prompt: String, cfg: Config) -> Result<(), ModelError> {
pub fn run(prompt: String, uncond_prompt: String, cfg: Config) -> Result<Vec<Tensor>, ModelError> {
if !(0. ..=1.).contains(&cfg.img2img_strength) {
Err(ModelError::Config(format!(
"img2img_strength must be between 0 and 1, got {}",
Expand Down Expand Up @@ -483,6 +483,8 @@ pub fn run(prompt: String, uncond_prompt: String, cfg: Config) -> Result<(), Mod
StableDiffusionVersion::Turbo => 0.13025,
};

let mut res = Vec::new();

for idx in 0..cfg.num_samples {
let timesteps = scheduler.timesteps();
let latents = match &init_latent_dist {
Expand Down Expand Up @@ -560,13 +562,7 @@ pub fn run(prompt: String, uncond_prompt: String, cfg: Config) -> Result<(), Mod
.unwrap()
.i(0)
.unwrap();
let image_filename = output_filename(
&cfg.final_image,
idx + 1,
cfg.num_samples,
Some(timestep_index + 1),
);
save_image(&image, image_filename).unwrap()
res.push(image);
}
}

Expand All @@ -581,8 +577,7 @@ pub fn run(prompt: String, uncond_prompt: String, cfg: Config) -> Result<(), Mod
.unwrap()
.i(0)
.unwrap();
let image_filename = output_filename(&cfg.final_image, idx + 1, cfg.num_samples, None);
save_image(&image, image_filename).unwrap()
res.push(image);
}
Ok(())
Ok(res)
}
19 changes: 10 additions & 9 deletions atoma-inference/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
// use models::llama::run;
use candle::llama::run;

use candle::stable_diffusion::run;
// use candle::stable_diffusion::run;

mod candle;
mod models;
mod types;

fn main() {
let x = run(
"The most important thing is ".to_string(),
Default::default(),
)
.unwrap();
println!("{}", x);
// run(
// "The most important thing is ".to_string(),
// "Green boat on ocean during storm".to_string(),
// "".to_string(),
// Default::default(),
// )
// .unwrap();
run(
"Green boat on ocean during storm".to_string(),
"".to_string(),
Default::default(),
)
.unwrap();
}

0 comments on commit b74c8ae

Please sign in to comment.