From 0348336d4053aeead2107170231f57e628cd850a Mon Sep 17 00:00:00 2001
From: Cifko
Date: Thu, 28 Mar 2024 09:56:42 +0100
Subject: [PATCH] add save tensor function

---
 atoma-inference/src/candle/mod.rs              | 18 ++++++++++++++++--
 atoma-inference/src/candle/stable_diffusion.rs |  2 +-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/atoma-inference/src/candle/mod.rs b/atoma-inference/src/candle/mod.rs
index d9f57acd..045a5a1a 100644
--- a/atoma-inference/src/candle/mod.rs
+++ b/atoma-inference/src/candle/mod.rs
@@ -2,11 +2,11 @@
 pub mod llama;
 pub mod stable_diffusion;
 pub mod token_output_stream;
 
-use std::path::PathBuf;
+use std::{fs::File, io::Write, path::PathBuf};
 
 use candle::{
     utils::{cuda_is_available, metal_is_available},
-    Device, Tensor,
+    DType, Device, Tensor,
 };
 use tracing::info;
@@ -74,3 +74,17 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> candle::Resu
     image.save(p).map_err(candle::Error::wrap)?;
     Ok(())
 }
+
+pub fn save_tensor_to_file(tensor: &Tensor, path: &PathBuf) -> Result<(), candle::Error> {
+    let json_output = serde_json::to_string(
+        &tensor
+            .to_device(&Device::Cpu)?
+            .flatten_all()?
+            .to_dtype(DType::F64)?
+            .to_vec1::<f64>()?,
+    )
+    .unwrap();
+    let mut file = File::create(path)?;
+    file.write_all(json_output.as_bytes())?;
+    Ok(())
+}
diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/candle/stable_diffusion.rs
index 12b50dda..fa5409de 100644
--- a/atoma-inference/src/candle/stable_diffusion.rs
+++ b/atoma-inference/src/candle/stable_diffusion.rs
@@ -497,7 +497,7 @@ impl StableDiffusion {
             clip_config,
             clip_weights,
             device,
-            DType::F32,
+            DType::F64,
         )?;
 
         let text_embeddings = text_model.forward(&tokens)?;