Skip to content

Commit

Permalink
move
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Apr 3, 2024
1 parent fe8d06a commit 253ed4c
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 181 deletions.
89 changes: 0 additions & 89 deletions atoma-inference/src/candle/mod.rs

This file was deleted.

86 changes: 0 additions & 86 deletions atoma-inference/src/candle/token_output_stream.rs

This file was deleted.

1 change: 0 additions & 1 deletion atoma-inference/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
pub mod apis;
pub mod candle;
pub mod model_thread;
pub mod models;
pub mod service;
Expand Down
79 changes: 79 additions & 0 deletions atoma-inference/src/models/candle/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,80 @@
use std::{fs::File, io::Write, path::PathBuf};

use candle::{
utils::{cuda_is_available, metal_is_available},
DType, Device, Tensor,
};
use tracing::info;

pub mod mamba;
pub mod stable_diffusion;

/// Select the best available compute device, preferring CUDA, then Metal,
/// and falling back to the CPU when no accelerator is present.
///
/// Logs which backend was chosen at `info` level.
pub fn device() -> Result<Device, candle::Error> {
    if cuda_is_available() {
        info!("Using CUDA");
        return Device::new_cuda(0);
    }
    if metal_is_available() {
        info!("Using Metal");
        return Device::new_metal(0);
    }
    info!("Using Cpu");
    Ok(Device::Cpu)
}

pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> candle::Result<Vec<std::path::PathBuf>> {
let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| repo.get(v).map_err(candle::Error::wrap))
.collect::<candle::Result<Vec<_>>>()?;
Ok(safetensors_files)
}

pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> candle::Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle::bail!("error saving image {p:?}"),
};
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}

pub fn save_tensor_to_file(tensor: &Tensor, filename: &str) -> Result<(), candle::Error> {
let json_output = serde_json::to_string(
&tensor
.to_device(&Device::Cpu)?
.flatten_all()?
.to_dtype(DType::F64)?
.to_vec1::<f64>()?,
)
.unwrap();
let mut file = File::create(PathBuf::from(filename))?;
file.write_all(json_output.as_bytes())?;
Ok(())
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@ use candle_transformers::models::stable_diffusion::{self};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use tokenizers::Tokenizer;

use crate::{
candle::device,
models::{types::PrecisionBits, ModelError, ModelId, ModelTrait},
};
use crate::models::{types::PrecisionBits, ModelError, ModelId, ModelTrait};

use super::save_tensor_to_file;
use super::{device, save_tensor_to_file};

pub struct Input {
prompt: String,
Expand Down

0 comments on commit 253ed4c

Please sign in to comment.