Commit

add llama and stable diffusion
Cifko committed Mar 25, 2024
1 parent 48e7845 commit 40e4a61
Showing 8 changed files with 918 additions and 8 deletions.
24 changes: 24 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,24 @@
{
"cSpell.words": [
"amall",
"bsize",
"Catmull",
"ctxt",
"dtype",
"endoftext",
"laion",
"logits",
"madebyollin",
"mmaped",
"Narsil",
"openai",
"runwayml",
"safetensors",
"sdxl",
"stabilityai",
"timestep",
"uncond",
"Unet",
"unsqueeze"
]
}
22 changes: 20 additions & 2 deletions Cargo.toml
@@ -1,12 +1,20 @@
[workspace]
resolver = "2"

members = ["atoma-event-subscribe", "atoma-inference", "atoma-networking", "atoma-json-rpc", "atoma-storage"]
edition = "2021"
members = [
"atoma-event-subscribe",
"atoma-inference",
"atoma-networking",
"atoma-json-rpc",
"atoma-storage",
]

[workspace.package]
version = "0.1.0"

[workspace.dependencies]
anyhow = "1.0.81"
clap = "4.5.3"
async-trait = "0.1.78"
candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" }
candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" }
@@ -17,3 +25,13 @@ thiserror = "1.0.58"
tokenizers = "0.15.2"
tokio = "1.36.0"
tracing = "0.1.40"
serde = "1.0.197"
thiserror = "1.0.58"
tokenizers = { version = "0.15.0", default-features = false }
tokio = "1.36.0"
tracing = "0.1.40"
hf-hub = "0.3.0"
image = { version = "0.25.0", default-features = false, features = [
"jpeg",
"png",
] }
17 changes: 12 additions & 5 deletions atoma-inference/Cargo.toml
@@ -1,18 +1,25 @@
[package]
name = "inference"
version = "0.1.0"
version.workspace = true
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow.workspace = true
async-trait.workspace = true
candle.workspace = true
candle-nn.workspace = true
candle-transformers.workspace = true
candle.workspace = true
clap.workspace = true
ed25519-consensus.workspace = true
hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = "1.0.114"
thiserror.workspace = true
tokenizers.workspace = true
tokenizers = { workspace = true, features = ["onig"] }
tokio = { workspace = true, features = ["full", "tracing"] }
tracing.workspace = true

[features]
cuda = ["candle/cuda", "candle-nn/cuda"]
metal = ["candle/metal", "candle-nn/metal"]
18 changes: 17 additions & 1 deletion atoma-inference/src/main.rs
@@ -1,3 +1,19 @@
// use models::llama::run;

use models::stable_diffusion::run;

mod models;

fn main() {
println!("Hello, world!");
// run(
// "The most important thing is ".to_string(),
// Default::default(),
// )
// .unwrap();
run(
"Green boat on ocean during storm".to_string(),
"".to_string(),
Default::default(),
)
.unwrap();
}
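For reference, switching the binary back to the text model is just a matter of swapping the import and the call, as the commented-out lines above hint. A minimal sketch of a llama-only main.rs, not part of the commit:

// Hypothetical llama-only entry point, mirroring the commented-out call above.
use models::llama::run;

mod models;

fn main() {
    run("The most important thing is ".to_string(), Default::default()).unwrap();
}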
162 changes: 162 additions & 0 deletions atoma-inference/src/models/llama.rs
@@ -0,0 +1,162 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{bail, Error as E, Result};

use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

use candle_transformers::models::llama as model;
use model::{Llama, LlamaConfig};
use tokenizers::Tokenizer;

use crate::models::{device, hub_load_safetensors, token_output_stream::TokenOutputStream};

const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "My favorite theorem is ";

#[derive(Clone, Debug, Copy, PartialEq, Eq)]
enum Which {
V1,
V2,
Solar10_7B,
TinyLlama1_1BChat,
}

pub struct Config {
temperature: Option<f64>,
top_p: Option<f64>,
seed: u64,
sample_len: usize,
no_kv_cache: bool,
dtype: Option<String>,
model_id: Option<String>,
revision: Option<String>,
which: Which,
use_flash_attn: bool,
repeat_penalty: f32,
repeat_last_n: usize,
}

impl Default for Config {
fn default() -> Self {
Self {
temperature: None,
top_p: None,
seed: 299792458,
sample_len: 10000,
no_kv_cache: false,
dtype: None,
model_id: None,
revision: None,
which: Which::TinyLlama1_1BChat,
use_flash_attn: false,
repeat_penalty: 1.,
repeat_last_n: 64,
}
}
}

pub fn run(prompt: String, cfg: Config) -> Result<()> {
let device = device()?;
let dtype = match cfg.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (llama, tokenizer_filename, mut cache) = {
let api = Api::new()?;
let model_id = cfg.model_id.unwrap_or_else(|| match cfg.which {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
println!("loading the model weights from {model_id}");
let revision = cfg.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(cfg.use_flash_attn);

let filenames = match cfg.which {
Which::V1 | Which::V2 | Which::Solar10_7B => {
hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
};
let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?;

let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &config)?, tokenizer_filename, cache)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
let mut tokens = tokenizer
.encode(prompt.clone(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();

let mut tokenizer = TokenOutputStream::new(tokenizer);
println!("starting the inference loop");
print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(cfg.seed, cfg.temperature, cfg.top_p);
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
for index in 0..cfg.sample_len {
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?;
let logits = logits.squeeze(0)?;
let logits = if cfg.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(cfg.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
cfg.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len();

let next_token = logits_processor.sample(&logits)?;
token_generated += 1;
tokens.push(next_token);

if Some(next_token) == eos_token_id {
break;
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
"\n\n{} tokens generated ({} token/s)\n",
token_generated,
token_generated as f64 / dt.as_secs_f64(),
);
Ok(())
}
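The defaults above pull TinyLlama-1.1B-Chat in f16, use greedy sampling (temperature is None), and generate up to 10000 tokens with the KV cache enabled and no repeat penalty. A minimal caller sketch, illustrative only and assuming the module is reachable as crate::models::llama:

fn llama_demo() -> anyhow::Result<()> {
    // Default Config: TinyLlama, f16, seed 299792458, sample_len 10000.
    crate::models::llama::run("My favorite theorem is ".to_string(), Default::default())
}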
65 changes: 65 additions & 0 deletions atoma-inference/src/models/mod.rs
@@ -0,0 +1,65 @@
pub mod llama;
pub mod stable_diffusion;
pub mod token_output_stream;

use anyhow::Result;
use candle::{
utils::{cuda_is_available, metal_is_available},
Device, Tensor,
};

pub fn device() -> Result<Device> {
if cuda_is_available() {
println!("Using CUDA");
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
println!("Using Metal");
Ok(Device::new_metal(0)?)
} else {
println!("Using Cpu");
Ok(Device::Cpu)
}
}

pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> candle::Result<Vec<std::path::PathBuf>> {
let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| repo.get(v).map_err(candle::Error::wrap))
.collect::<candle::Result<Vec<_>>>()?;
Ok(safetensors_files)
}

pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> candle::Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle::bail!("error saving image {p:?}"),
};
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}
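As a quick illustration of the helpers above (not part of the commit), the sketch below builds a solid green 3x64x64 u8 tensor in the channel-first layout save_image expects and writes it to disk; Tensor::from_vec is the candle-core constructor for host data:

fn write_test_image() -> anyhow::Result<()> {
    let device = crate::models::device()?;
    let (h, w) = (64usize, 64usize);
    // (channel, height, width) layout: fill the second (green) plane with 255.
    let mut data = vec![0u8; 3 * h * w];
    data[h * w..2 * h * w].fill(255);
    let img = candle::Tensor::from_vec(data, (3, h, w), &device)?;
    crate::models::save_image(&img, "green.png")?;
    Ok(())
}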
(Diffs for the remaining files, atoma-inference/src/models/stable_diffusion.rs and atoma-inference/src/models/token_output_stream.rs, are not shown.)
