Skip to content

Commit

Permalink
add save tensor function
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Mar 28, 2024
1 parent c5f3964 commit 0348336
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
18 changes: 16 additions & 2 deletions atoma-inference/src/candle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ pub mod llama;
pub mod stable_diffusion;
pub mod token_output_stream;

use std::path::PathBuf;
use std::{fs::File, io::Write, path::PathBuf};

use candle::{
utils::{cuda_is_available, metal_is_available},
Device, Tensor,
DType, Device, Tensor,
};
use tracing::info;

Expand Down Expand Up @@ -74,3 +74,17 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> candle::Resu
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}

pub fn save_tensor_to_file(tensor: &Tensor, path: &PathBuf) -> Result<(), candle::Error> {
let json_output = serde_json::to_string(
&tensor
.to_device(&Device::Cpu)?
.flatten_all()?
.to_dtype(DType::F64)?
.to_vec1::<f64>()?,
)
.unwrap();
let mut file = File::create(path)?;
file.write_all(json_output.as_bytes())?;
Ok(())
}
2 changes: 1 addition & 1 deletion atoma-inference/src/candle/stable_diffusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ impl StableDiffusion {
clip_config,
clip_weights,
device,
DType::F32,
DType::F64,
)?;
let text_embeddings = text_model.forward(&tokens)?;

Expand Down

0 comments on commit 0348336

Please sign in to comment.