Skip to content

Commit

Permalink
add stable diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Apr 2, 2024
1 parent 1c42945 commit 313ef85
Show file tree
Hide file tree
Showing 7 changed files with 745 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
Cargo.lock
target/
.vscode/
15 changes: 14 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
[workspace]
resolver = "2"
edition = "2021"

members = ["atoma-event-subscribe", "atoma-inference", "atoma-networking", "atoma-json-rpc", "atoma-storage"]
members = [
"atoma-event-subscribe",
"atoma-inference",
"atoma-networking",
"atoma-json-rpc",
"atoma-storage",
]

[workspace.package]
version = "0.1.0"

[workspace.dependencies]
anyhow = "1.0.81"
async-trait = "0.1.78"
candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" }
candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", version = "0.4.2" }
Expand All @@ -16,6 +24,11 @@ config = "0.14.0"
ed25519-consensus = "2.1.0"
futures = "0.3.30"
hf-hub = "0.3.2"
clap = "4.5.3"
image = { version = "0.25.0", default-features = false, features = [
"jpeg",
"png",
] }
serde = "1.0.197"
serde_json = "1.0.114"
rand = "0.8.5"
Expand Down
15 changes: 10 additions & 5 deletions atoma-inference/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
[package]
name = "inference"
version = "0.1.0"
version.workspace = true
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow.workspace = true
async-trait.workspace = true
candle.workspace = true
candle-flash-attn = { workspace = true, optional = true }
Expand All @@ -18,8 +17,10 @@ hf-hub.workspace = true
reqwest = { workspace = true, features = ["json"] }
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
clap.workspace = true
image = { workspace = true }
thiserror.workspace = true
tokenizers.workspace = true
tokenizers = { workspace = true, features = ["onig"] }
tokio = { workspace = true, features = ["full", "tracing"] }
tracing.workspace = true
tracing-subscriber.workspace = true
Expand All @@ -30,7 +31,11 @@ toml.workspace = true


[features]
accelerate = ["candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
accelerate = [
"candle/accelerate",
"candle-nn/accelerate",
"candle-transformers/accelerate",
]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
Expand Down
89 changes: 89 additions & 0 deletions atoma-inference/src/candle/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
pub mod stable_diffusion;
pub mod token_output_stream;

use std::{fs::File, io::Write, path::PathBuf};

use candle::{
utils::{cuda_is_available, metal_is_available},
DType, Device, Tensor,
};
use tracing::info;

use crate::models::ModelError;

/// Common interface for candle-backed models: a model knows how to download
/// its own artifacts (`fetch`) and run inference over a model-specific input.
pub trait CandleModel {
    /// Parameters describing what to download (e.g. repo/revision info).
    type Fetch;
    /// Model-specific inference input (prompt, image parameters, ...).
    type Input;
    /// Downloads the model artifacts described by `fetch`.
    fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError>;
    /// Runs the model on `input`, returning the produced tensors.
    fn inference(input: Self::Input) -> Result<Vec<Tensor>, ModelError>;
}

/// Picks the best available compute device, preferring CUDA, then Metal,
/// and falling back to the CPU when no accelerator is present.
pub fn device() -> Result<Device, candle::Error> {
    if cuda_is_available() {
        info!("Using CUDA");
        return Device::new_cuda(0);
    }
    if metal_is_available() {
        info!("Using Metal");
        return Device::new_metal(0);
    }
    info!("Using Cpu");
    Ok(Device::Cpu)
}

pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> candle::Result<Vec<std::path::PathBuf>> {
let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| repo.get(v).map_err(candle::Error::wrap))
.collect::<candle::Result<Vec<_>>>()?;
Ok(safetensors_files)
}

pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> candle::Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle::bail!("error saving image {p:?}"),
};
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}

pub fn save_tensor_to_file(tensor: &Tensor, filename: &str) -> Result<(), candle::Error> {
let json_output = serde_json::to_string(
&tensor
.to_device(&Device::Cpu)?
.flatten_all()?
.to_dtype(DType::F64)?
.to_vec1::<f64>()?,
)
.unwrap();
let mut file = File::create(PathBuf::from(filename))?;
file.write_all(json_output.as_bytes())?;
Ok(())
}
Loading

0 comments on commit 313ef85

Please sign in to comment.