Skip to content

Commit

Permalink
t
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Mar 28, 2024
1 parent 0348336 commit 96ada78
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
16 changes: 14 additions & 2 deletions atoma-inference/src/candle/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use std::path::PathBuf;

use anyhow::{bail, Error as E, Result};

use candle::{DType, Tensor};
Expand All @@ -17,6 +19,8 @@ use tokenizers::Tokenizer;

use crate::candle::{device, hub_load_safetensors, token_output_stream::TokenOutputStream};

use super::save_tensor_to_file;

const EOS_TOKEN: &str = "</s>";

#[derive(Clone, Debug, Copy, PartialEq, Eq)]
Expand Down Expand Up @@ -47,7 +51,7 @@ impl Default for Config {
Self {
temperature: None,
top_p: None,
seed: 299792458,
seed: 0,
sample_len: 10000,
no_kv_cache: false,
dtype: None,
Expand Down Expand Up @@ -110,6 +114,7 @@ pub fn run(prompt: String, cfg: Config) -> Result<String> {
let mut logits_processor = LogitsProcessor::new(cfg.seed, cfg.temperature, cfg.top_p);
let mut index_pos = 0;
let mut res = String::new();
let mut result = Vec::new();
for index in 0..cfg.sample_len {
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
Expand All @@ -131,8 +136,8 @@ pub fn run(prompt: String, cfg: Config) -> Result<String> {
)?
};
index_pos += ctxt.len();

let next_token = logits_processor.sample(&logits)?;
result.push(logits);
tokens.push(next_token);

if Some(next_token) == eos_token_id {
Expand All @@ -145,5 +150,12 @@ pub fn run(prompt: String, cfg: Config) -> Result<String> {
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
res += &rest;
}
for (i, tensor) in result.iter().enumerate() {
save_tensor_to_file(
tensor,
&PathBuf::from(format!("llama_logits_{}.pt", i.to_string())),
)
.unwrap();
}
Ok(res)
}
6 changes: 1 addition & 5 deletions atoma-inference/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ fn stable_diffusion() {

/// Smoke-test driver: runs the Llama text-generation pipeline on a fixed
/// prompt with the default `Config` and prints the generated completion.
///
/// NOTE(review): the pasted diff showed both the old and the new `run(...)`
/// call interleaved; this is the post-commit version with a single call.
fn llama1() {
    use crate::candle::llama::run;
    // `run` returns the full generated string; with `Config::default()` the
    // seed is fixed, so output is deterministic for a given model build.
    // TODO(review): `.unwrap()` panics on any inference error — consider
    // propagating the error once this moves past smoke-test usage.
    let x = run("If we find a solution to ".to_string(), Default::default()).unwrap();
    println!("{}", x);
}

Expand Down

0 comments on commit 96ada78

Please sign in to comment.