diff --git a/examples/simple/src/main.rs b/examples/simple/src/main.rs index 267d6864..ec8f87d4 100644 --- a/examples/simple/src/main.rs +++ b/examples/simple/src/main.rs @@ -17,7 +17,9 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue; use llama_cpp_2::model::params::LlamaModelParams; use llama_cpp_2::model::LlamaModel; use llama_cpp_2::model::{AddBos, Special}; -use llama_cpp_2::token::data_array::LlamaTokenDataArray; +use llama_cpp_2::sampling::params::LlamaSamplerChainParams; +use llama_cpp_2::sampling::LlamaSampler; + use std::ffi::CString; use std::io::Write; use std::num::NonZeroU32; @@ -174,9 +176,9 @@ fn main() -> Result<()> { .with_context(|| "unable to load model")?; // initialize the context - let mut ctx_params = LlamaContextParams::default() - .with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap()))) - .with_seed(seed.unwrap_or(1234)); + let mut ctx_params = + LlamaContextParams::default().with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap()))); + if let Some(threads) = threads { ctx_params = ctx_params.with_n_threads(threads); } @@ -244,23 +246,23 @@ either reduce n_len or increase n_ctx" // The `Decoder` let mut decoder = encoding_rs::UTF_8.new_decoder(); + let sampler_params = LlamaSamplerChainParams::default(); + let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(seed.unwrap_or(1234)); + while n_cur <= n_len { // sample the next token { - let candidates = ctx.candidates(); - - let candidates_p = LlamaTokenDataArray::from_iter(candidates, false); + let token = sampler.sample(&ctx, batch.n_tokens() - 1); - // sample the most likely token - let new_token_id = ctx.sample_token_greedy(candidates_p); + sampler.accept(token); // is it an end of stream? - if model.is_eog_token(new_token_id) { + if model.is_eog_token(token) { eprintln!(); break; } - let output_bytes = model.token_to_bytes(new_token_id, Special::Tokenize)?; + let output_bytes = model.token_to_bytes(token, Special::Tokenize)?; // use `Decoder.decode_to_string()` to avoid the intermediate buffer let mut output_string = String::with_capacity(32); let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false); @@ -268,7 +270,7 @@ either reduce n_len or increase n_ctx" std::io::stdout().flush()?; batch.clear(); - batch.add(new_token_id, n_cur, &[0], true)?; + batch.add(token, n_cur, &[0], true)?; } n_cur += 1; diff --git a/examples/usage.rs b/examples/usage.rs index 1b7d1f5d..2b7f1915 100644 --- a/examples/usage.rs +++ b/examples/usage.rs @@ -14,6 +14,8 @@ use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::params::LlamaModelParams; use llama_cpp_2::model::LlamaModel; use llama_cpp_2::model::{AddBos, Special}; +use llama_cpp_2::sampling::params::LlamaSamplerChainParams; +use llama_cpp_2::sampling::LlamaSampler; use llama_cpp_2::token::data_array::LlamaTokenDataArray; use std::io::Write; @@ -54,25 +56,25 @@ fn main() { // The `Decoder` let mut decoder = encoding_rs::UTF_8.new_decoder(); + let sampler_params = LlamaSamplerChainParams::default(); + let mut sampler = LlamaSampler::new(sampler_params) + .expect("Failed to create sampler") + .add_greedy(); + while n_cur <= n_len { // sample the next token { - let candidates = ctx.candidates_ith(batch.n_tokens() - 1); - - let candidates_p = LlamaTokenDataArray::from_iter(candidates, false); + let token = sampler.sample(&ctx, batch.n_tokens() - 1); - // sample the most likely token - let new_token_id = ctx.sample_token_greedy(candidates_p); + sampler.accept(token); // is it an end of stream? 
- if new_token_id == model.token_eos() { + if token == model.token_eos() { eprintln!(); break; } - let output_bytes = model - .token_to_bytes(new_token_id, Special::Tokenize) - .unwrap(); + let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap(); // use `Decoder.decode_to_string()` to avoid the intermediate buffer let mut output_string = String::with_capacity(32); let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false); @@ -80,7 +82,7 @@ fn main() { std::io::stdout().flush().unwrap(); batch.clear(); - batch.add(new_token_id, n_cur, &[0], true).unwrap(); + batch.add(token, n_cur, &[0], true).unwrap(); } n_cur += 1; diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs index 80ee8f75..cdebb88a 100644 --- a/llama-cpp-2/src/context.rs +++ b/llama-cpp-2/src/context.rs @@ -17,7 +17,6 @@ use crate::{ pub mod kv_cache; pub mod params; -pub mod sample; pub mod session; /// Safe wrapper around `llama_context`. @@ -267,12 +266,12 @@ impl<'model> LlamaContext<'model> { /// Reset the timings for the context. pub fn reset_timings(&mut self) { - unsafe { llama_cpp_sys_2::llama_reset_timings(self.context.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_perf_context_reset(self.context.as_ptr()) } } /// Returns the timings for the context. pub fn timings(&mut self) -> LlamaTimings { - let timings = unsafe { llama_cpp_sys_2::llama_get_timings(self.context.as_ptr()) }; + let timings = unsafe { llama_cpp_sys_2::llama_perf_context(self.context.as_ptr()) }; LlamaTimings { timings } } diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs index 14eca8b0..cfaf967b 100644 --- a/llama-cpp-2/src/context/params.rs +++ b/llama-cpp-2/src/context/params.rs @@ -47,7 +47,7 @@ impl From for i32 { pub enum LlamaPoolingType { /// The pooling type is unspecified Unspecified = -1, - /// No pooling + /// No pooling None = 0, /// Mean pooling Mean = 1, @@ -95,10 +95,8 @@ impl From for i32 { /// use llama_cpp_2::context::params::LlamaContextParams; /// ///let ctx_params = LlamaContextParams::default() -/// .with_n_ctx(NonZeroU32::new(2048)) -/// .with_seed(1234); +/// .with_n_ctx(NonZeroU32::new(2048)); /// -/// assert_eq!(ctx_params.seed(), 1234); /// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048)); /// ``` #[derive(Debug, Clone)] @@ -116,37 +114,6 @@ unsafe impl Send for LlamaContextParams {} unsafe impl Sync for LlamaContextParams {} impl LlamaContextParams { - /// Set the seed of the context - /// - /// # Examples - /// - /// ```rust - /// use llama_cpp_2::context::params::LlamaContextParams; - /// let params = LlamaContextParams::default(); - /// let params = params.with_seed(1234); - /// assert_eq!(params.seed(), 1234); - /// ``` - #[must_use] - pub fn with_seed(mut self, seed: u32) -> Self { - self.context_params.seed = seed; - self - } - - /// Get the seed of the context - /// - /// # Examples - /// - /// ```rust - /// use llama_cpp_2::context::params::LlamaContextParams; - /// let params = LlamaContextParams::default() - /// .with_seed(1234); - /// assert_eq!(params.seed(), 1234); - /// ``` - #[must_use] - pub fn seed(&self) -> u32 { - self.context_params.seed - } - /// Set the side of the context /// /// # Examples diff --git a/llama-cpp-2/src/context/sample.rs b/llama-cpp-2/src/context/sample.rs deleted file mode 100644 index cc0f85ee..00000000 --- a/llama-cpp-2/src/context/sample.rs +++ /dev/null @@ -1,141 +0,0 @@ -//! Sampling functions for the context. 
- -use crate::context::LlamaContext; -use crate::grammar::LlamaGrammar; -use crate::token::data_array::LlamaTokenDataArray; -use crate::token::LlamaToken; - -#[cfg(feature = "sampler")] -pub mod sampler; - -impl LlamaContext<'_> { - /// Accept a token into the grammar. - pub fn grammar_accept_token(&mut self, grammar: &mut LlamaGrammar, token: LlamaToken) { - unsafe { - llama_cpp_sys_2::llama_grammar_accept_token( - grammar.grammar.as_ptr(), - self.context.as_ptr(), - token.0, - ); - } - } - - /// Perform grammar sampling. - pub fn sample_grammar( - &mut self, - llama_token_data_array: &mut LlamaTokenDataArray, - llama_grammar: &LlamaGrammar, - ) { - unsafe { - llama_token_data_array.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_grammar( - self.context.as_ptr(), - c_llama_token_data_array, - llama_grammar.grammar.as_ptr(), - ); - }); - } - } - - /// See [`LlamaTokenDataArray::sample_temp`] - pub fn sample_temp(&mut self, token_data: &mut LlamaTokenDataArray, temperature: f32) { - token_data.sample_temp(Some(self), temperature); - } - - /// Sample a token greedily. Note that this *does not* take into account anything that has modified the probabilities - it only looks at logits. - /// - /// Most of the time [`LlamaTokenDataArray::sample_softmax`] or [`LlamaTokenDataArray::sample_token`] should be used instead. - /// - /// # Panics - /// - /// - if `token_data` is empty - #[must_use] - pub fn sample_token_greedy(&mut self, mut token_data: LlamaTokenDataArray) -> LlamaToken { - assert!(!token_data.data.is_empty(), "no tokens"); - let mut data_arr = llama_cpp_sys_2::llama_token_data_array { - data: token_data - .data - .as_mut_ptr() - .cast::(), - size: token_data.data.len(), - sorted: token_data.sorted, - }; - let token = unsafe { - llama_cpp_sys_2::llama_sample_token_greedy( - self.context.as_ptr(), - std::ptr::addr_of_mut!(data_arr), - ) - }; - LlamaToken(token) - } - - /// See [`LlamaTokenDataArray::sample_tail_free`] - pub fn sample_tail_free( - &mut self, - token_data: &mut LlamaTokenDataArray, - z: f32, - min_keep: usize, - ) { - token_data.sample_tail_free(Some(self), z, min_keep); - } - - /// See [`LlamaTokenDataArray::sample_typical`] - pub fn sample_typical( - &mut self, - token_data: &mut LlamaTokenDataArray, - p: f32, - min_keep: usize, - ) { - token_data.sample_typical(Some(self), p, min_keep); - } - - /// See [`LlamaTokenDataArray::sample_top_p`] - pub fn sample_top_p(&mut self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) { - token_data.sample_top_p(Some(self), p, min_keep); - } - - /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841) - pub fn sample_min_p( - &mut self, - llama_token_data: &mut LlamaTokenDataArray, - p: f32, - min_keep: usize, - ) { - let ctx = self.context.as_ptr(); - unsafe { - llama_token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_min_p(ctx, c_llama_token_data_array, p, min_keep); - }); - } - } - - /// See [`LlamaTokenDataArray::sample_top_k`] - pub fn sample_top_k(&mut self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) { - token_data.sample_top_k(Some(self), k, min_keep); - } - - /// See [`LlamaTokenDataArray::sample_softmax`] - pub fn sample_token_softmax(&mut self, token_data: &mut LlamaTokenDataArray) { - token_data.sample_softmax(Some(self)); - } - - /// See [`LlamaTokenDataArray::sample_repetition_penalty`] - pub fn sample_repetition_penalty( - &mut self, - 
token_data: &mut LlamaTokenDataArray, - last_tokens: &[LlamaToken], - penalty_last_n: usize, - penalty_repeat: f32, - penalty_freq: f32, - penalty_present: f32, - ) { - token_data.sample_repetition_penalty( - Some(self), - last_tokens, - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - ); - } -} diff --git a/llama-cpp-2/src/context/sample/sampler.rs b/llama-cpp-2/src/context/sample/sampler.rs deleted file mode 100644 index 948a1aa5..00000000 --- a/llama-cpp-2/src/context/sample/sampler.rs +++ /dev/null @@ -1,112 +0,0 @@ -//! Create a sampler struct to encapsulate the sampling process. This allows passing all the possible -//! sampling parameters around as a single struct, and also allow late binding of expensive context -//! like [`crate::context::LlamaContext`] or token history to the sampler. -//! -//! # Example -//! -//! **Llama.cpp default sampler** -//! -//! ```rust -//! use llama_cpp_2::context::sample::sampler::{Sampler, SampleStep}; -//! use llama_cpp_2::token::data::LlamaTokenData; -//! use llama_cpp_2::token::data_array::LlamaTokenDataArray; -//! use llama_cpp_2::token::LlamaToken; -//! -//! // Sample a token greedily and add to the history. -//! let mut finalizer = &|mut canidates: LlamaTokenDataArray, history: &mut Vec| { -//! canidates.sample_softmax(None); -//! let token = canidates.data[0]; -//! history.push(token.id()); -//! vec![token] -//! }; -//! -//! let mut history = vec![]; -//! let mut sampler = Sampler::new(finalizer); -//! -//! sampler.push_step(&|c, history| c.sample_repetition_penalty(None, history, 64, 1.1, 0.0, 0.0)); -//! sampler.push_step(&|c, _| c.sample_top_k(None, 40, 1)); -//! sampler.push_step(&|c, _| c.sample_tail_free(None, 1.0, 1)); -//! sampler.push_step(&|c, _| c.sample_typical(None, 1.0, 1)); -//! sampler.push_step(&|c, _| c.sample_top_p(None, 0.95, 1)); -//! sampler.push_step(&|c, _| c.sample_min_p(None, 0.05, 1)); -//! sampler.push_step(&|c, _| c.sample_temp(None, 0.5)); -//! -//! // random candidates -//! let candidates = LlamaTokenDataArray::from_iter((0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32 / 6.0, 0.0)), false); -//! -//! for _ in 0..10 { -//! let tokens = sampler.sample(&mut history, candidates.clone()); -//! assert_eq!(tokens.len(), 1); -//! } -//! -//! assert_eq!(history.len(), 10); -//! ``` - -use crate::token::data::LlamaTokenData; -use crate::token::data_array::LlamaTokenDataArray; -use std::fmt::{Debug, Formatter}; - -/// A single step to sample tokens from the remaining candidates. -pub type SampleStep = dyn Fn(&mut LlamaTokenDataArray, &mut C); - -/// The final step to select tokens from the remaining candidates. -pub type SampleFinalizer = dyn Fn(LlamaTokenDataArray, &mut C) -> Vec; - -/// A series of sampling steps that will produce a vector of token data. -/// -/// `C` is dynamic context that will be passed to the sampling functions. Some sampling steps may -/// require state to be maintained across multiple samples, and this context can be used to store -/// that state. For example, [`LlamaTokenDataArray::sample_token_mirostat_v2`] requires a `mu` to be -/// shared across multiple samples. -pub struct Sampler<'a, C> { - /// The steps to take when sampling. - pub steps: Vec<&'a SampleStep>, - /// The final step to select one or more tokens from the remaining candidates. 
- pub finalizer: &'a SampleFinalizer, -} - -impl Debug for Sampler<'_, T> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Sampler") - .field( - "steps", - &format!( - "{} steps of Box ()>", - &self.steps.len() - ), - ) - .field( - "finalizer", - &"Box Vec>", - ) - .finish() - } -} - -impl<'a, T> Sampler<'a, T> { - /// Create a new sampler with a given finalizer. - pub fn new(finalizer: &'a SampleFinalizer) -> Self { - Self { - steps: vec![], - finalizer, - } - } - - /// Adds a step to the sampler. - pub fn push_step(&mut self, step: &'a SampleStep) { - self.steps.push(step); - } - - /// Sample a token from the given candidates. - #[must_use] - pub fn sample( - &mut self, - context: &mut T, - mut candidates: LlamaTokenDataArray, - ) -> Vec { - for step in &self.steps { - step(&mut candidates, context); - } - (self.finalizer)(candidates, context) - } -} diff --git a/llama-cpp-2/src/grammar.rs b/llama-cpp-2/src/grammar.rs deleted file mode 100644 index 667a870b..00000000 --- a/llama-cpp-2/src/grammar.rs +++ /dev/null @@ -1,491 +0,0 @@ -//! The grammar module contains the grammar parser and the grammar struct. -//! -//! This allows creating a llama-cpp grammar. This is essentially a translation of the parser in -//! `common` to rust - -use std::collections::BTreeMap; -use std::fmt::{Debug, Formatter}; - -use llama_cpp_sys_2::{llama_grammar, llama_grammar_element, llama_gretype}; -use std::ptr::NonNull; -use std::str::FromStr; -use tracing::error; - -/// Details of extraneous characters after a rule error. -#[derive(thiserror::Error, Debug)] -#[error("Extraneous chars after rule {name:?}: {chars:?}")] -pub struct ExtraneousCharsAfterRule { - /// The name of the rule being parsed - pub name: String, - /// the extraneous characters - pub chars: String, - /// the rest of the input, this is still to be parsed. - pub rest: String, -} - -/// There was an error parsing the grammar. -#[derive(thiserror::Error, Debug)] -#[allow(clippy::module_name_repetitions)] -pub enum GrammarParseError { - /// There was an unexpected end of input. - #[error("Unexpected end of input")] - UnexpectedEndOfInput { - /// the stage of parsing that was being performed when we ran out of input. - parse_stage: &'static str, - }, - /// There was unexpected characters after a rule name but before "::=". There can only be whitespace. - #[error("Unexpected Chars after name {name:?} and before \"::=\": {chars}")] - UnexpectedCharsAfterName { - /// the name of the rule being parsed - name: String, - /// the unexpected characters - chars: String, - }, - /// There was no "::=" after a rule name. - #[error("Expected ::= after name {name:?}")] - ExpectedEqualsAfterName { - /// the name of the rule being parsed - name: String, - }, - /// There was no closing bracket in a nested rule. - #[error("Expected closing bracket in nested rule {name:?}")] - MissingClosingBracketInNestedRule { - /// the name of the rule being parsed - name: String, - }, - /// There was no rule before a postfix operator. - #[error("Missing rule before postfix operator in {name:?}")] - ExpectedRuleBeforePostfixOperator { - /// the name of the rule being parsed - name: String, - }, - /// There was an incorrect hex size. - #[error("Expected hex number with size {expected_size}, but number was {actual:?}")] - IncorrectHexSize { - /// the expected size of the hex number - expected_size: usize, - /// the actual hex number - actual: String, - }, - /// An unknown escape character was found. 
- #[error("Unknown escape {escape:?}")] - UnknownEscape { - /// the unknown character - escape: char, - }, - /// Failed to parse hex from a string. - #[error("Failed to parse hex from {string}: {error}")] - ParseHexError { - /// the error that occurred when parsing the hex - #[source] - error: std::num::ParseIntError, - /// the string that was being parsed - string: String, - }, - /// there was not space after the name - // todo: is this actually an error? - #[error("Missing space after name in {rest:?}")] - MissingSpaceAfterName { - /// the rest of the input, this is still to be parsed. - rest: String, - }, - /// There was unexpected characters after the rule. - #[error("{0}")] - ExtraneousCharsAfterRule(ExtraneousCharsAfterRule), -} - -/// A grammar for llama-cpp. -#[allow(clippy::module_name_repetitions)] -pub struct LlamaGrammar { - parse: ParseState, - pub(crate) grammar: NonNull, -} - -impl Clone for LlamaGrammar { - fn clone(&self) -> Self { - let grammar = unsafe { llama_cpp_sys_2::llama_grammar_copy(self.grammar.as_ptr()) }; - Self { - parse: self.parse.clone(), - grammar: NonNull::new(grammar).expect("copied grammar should never be null"), - } - } -} - -unsafe impl Send for LlamaGrammar {} - -unsafe impl Sync for LlamaGrammar {} - -#[allow(clippy::module_name_repetitions)] -impl Debug for LlamaGrammar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("LlamaGrammar") - .field("grammar", &self.grammar) - .field("parse", &self.parse) - .finish() - } -} - -#[derive(Debug, Clone, PartialEq)] -struct ParseState { - symbol_ids: BTreeMap, - rules: Vec>, -} - -impl ParseState { - fn new() -> Self { - Self { - symbol_ids: BTreeMap::new(), - rules: Vec::new(), - } - } - - fn get_symbol_id(&mut self, name: &str) -> u32 { - let next_id = - u32::try_from(self.symbol_ids.len()).expect("too many rules (must fit into u32)"); - let result = self.symbol_ids.entry(name.to_string()).or_insert(next_id); - *result - } - - fn generate_symbol_id(&mut self, name: &str) -> u32 { - let next_id = - u32::try_from(self.symbol_ids.len()).expect("too many rules (must fit into u32)"); - let generated_name = format!("{name}_{next_id}"); - let None = self.symbol_ids.insert(generated_name, next_id) else { - panic!("Failed to create unique name for {name}"); - }; - next_id - } - - fn parse_rule<'a>(&mut self, rest: &'a str) -> Result, GrammarParseError> { - let rest = Self::consume_whitespace_and_comments(rest, true); - if rest.is_empty() { - return Ok(None); - } - let (name, rest) = Self::parse_name(rest)?; - let rest = rest.trim_start(); - let rule_id = self.get_symbol_id(name); - - let (after_name, rest) = - rest.split_once("::=") - .ok_or_else(|| GrammarParseError::ExpectedEqualsAfterName { - name: name.to_string(), - })?; - - if !after_name.is_empty() { - return Err(GrammarParseError::UnexpectedCharsAfterName { - name: name.to_string(), - chars: after_name.to_string(), - }); - } - - let rest = self.parse_alternatives(name, rule_id, rest, false)?; - - let Some((after_rule, rest)) = rest.split_once('\n') else { - return Ok(None); - }; - - if !after_rule.chars().all(char::is_whitespace) { - return Err(GrammarParseError::ExtraneousCharsAfterRule( - ExtraneousCharsAfterRule { - name: name.to_string(), - chars: after_rule.to_string(), - rest: rest.to_string(), - }, - )); - } - - Ok(Some(rest)) - } - - fn consume_whitespace_and_comments(mut rest: &str, allow_newlines: bool) -> &str { - loop { - rest = rest.trim_start_matches( - |c: char| if allow_newlines { true } else { c != '\n' } && 
c.is_whitespace(), - ); - if rest.starts_with('#') { - rest = rest.split_once('\n').map_or("", |(_comment, rest)| rest); - } else { - break; - } - } - rest - } - - fn parse_alternatives<'a>( - &mut self, - name: &str, - id: u32, - rest: &'a str, - nested: bool, - ) -> Result<&'a str, GrammarParseError> { - let mut rule = Vec::new(); - let rest = self.parse_sequence(rest.trim_start(), name, &mut rule, nested)?; - let mut rest = Self::consume_whitespace_and_comments(rest, nested); - while rest.starts_with('|') { - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_ALT, - value: 0, - }); - rest = Self::consume_whitespace_and_comments(&rest[1..], true); - rest = self.parse_sequence(rest, name, &mut rule, nested)?; - } - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_END, - value: 0, - }); - self.add_rule(id, rule); - Ok(rest) - } - - fn add_rule(&mut self, id: u32, rule: Vec) { - let id = id as usize; - if self.rules.len() <= id { - self.rules.resize(id + 1, Vec::new()); - } - self.rules[id] = rule; - } - - #[allow(clippy::too_many_lines)] - fn parse_sequence<'a>( - &mut self, - mut rest: &'a str, - name: &str, - rule: &mut Vec, - nested: bool, - ) -> Result<&'a str, GrammarParseError> { - let mut last_sym_start = rule.len(); - while !rest.is_empty() { - let first_char = - rest.chars() - .next() - .ok_or(GrammarParseError::UnexpectedEndOfInput { - parse_stage: "sequence", - })?; - if first_char == '"' { - rest = &rest[1..]; - last_sym_start = rule.len(); - while !rest.starts_with('"') { - let (c, r) = Self::parse_char(rest)?; - rest = r; - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR, - value: c as _, - }); - } - rest = Self::consume_whitespace_and_comments(&rest[1..], nested); - } else if first_char == '[' { - rest = &rest[1..]; - let start_type = if rest.starts_with('^') { - rest = &rest[1..]; - llama_cpp_sys_2::LLAMA_GRETYPE_CHAR_NOT - } else { - llama_cpp_sys_2::LLAMA_GRETYPE_CHAR - }; - last_sym_start = rule.len(); - while !rest.starts_with(']') { - let (c, r) = Self::parse_char(rest)?; - rest = r; - let gre_type = if last_sym_start < rule.len() { - llama_cpp_sys_2::LLAMA_GRETYPE_CHAR_ALT - } else { - start_type - }; - rule.push(llama_grammar_element { - type_: gre_type, - value: c as _, - }); - if rest.starts_with('-') && rest.get(1..).is_some_and(|r| !r.starts_with(']')) { - let (c, r) = Self::parse_char(&rest[1..])?; - rest = r; - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR_RNG_UPPER, - value: c as _, - }); - } - } - rest = Self::consume_whitespace_and_comments(&rest[1..], nested); - } else if first_char.is_alphabetic() { - let (name, r) = Self::parse_name(rest)?; - rest = Self::consume_whitespace_and_comments(r, nested); - let ref_rule_id = self.get_symbol_id(name); - last_sym_start = rule.len(); - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_RULE_REF, - value: ref_rule_id, - }); - } else if first_char == '(' { - rest = rest[1..].trim_start(); - let sub_rule_id = self.generate_symbol_id(name); - rest = self.parse_alternatives(name, sub_rule_id, rest, true)?; - last_sym_start = rule.len(); - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_RULE_REF, - value: sub_rule_id, - }); - if !rest.starts_with(')') { - return Err(GrammarParseError::MissingClosingBracketInNestedRule { - name: name.to_string(), - }); - } - rest = Self::consume_whitespace_and_comments(&rest[1..], nested); - } else if first_char == '*' || 
first_char == '+' || first_char == '?' { - if last_sym_start == rule.len() { - return Err(GrammarParseError::ExpectedRuleBeforePostfixOperator { - name: name.to_string(), - }); - } - let sub_rule_id = self.generate_symbol_id(name); - let mut sub_rule: Vec = - rule.iter().skip(last_sym_start).copied().collect(); - if rest.starts_with(['*', '+']) { - sub_rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_RULE_REF, - value: sub_rule_id, - }); - } - sub_rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_ALT, - value: 0, - }); - if rest.starts_with('+') { - sub_rule.extend(rule.iter().skip(last_sym_start).copied()); - } - sub_rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_END, - value: 0, - }); - self.add_rule(sub_rule_id, sub_rule); - - rule.truncate(last_sym_start); - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_RULE_REF, - value: sub_rule_id, - }); - - rest = Self::consume_whitespace_and_comments(&rest[1..], nested); - } else { - break; - } - } - - Ok(rest) - } - - fn parse_hex(rest: &str, size: usize) -> Result<(llama_gretype, &str), GrammarParseError> { - if rest.len() < size { - return Err(GrammarParseError::IncorrectHexSize { - expected_size: size, - actual: rest.to_string(), - }); - } - - let (hex, rest) = rest.split_at(size); - let value = - u32::from_str_radix(hex, 16).map_err(|error| GrammarParseError::ParseHexError { - string: hex.to_string(), - error, - })?; - - Ok((value as llama_gretype, rest)) - } - - fn parse_char(rest: &str) -> Result<(llama_gretype, &str), GrammarParseError> { - if let Some(rest) = rest.strip_prefix('\\') { - let Some(escaped) = rest.chars().next() else { - return Err(GrammarParseError::UnexpectedEndOfInput { - parse_stage: "escape char", - }); - }; - let rest = &rest[escaped.len_utf8()..]; - match escaped { - 'x' => Self::parse_hex(rest, 2), - 'u' => Self::parse_hex(rest, 4), - 'U' => Self::parse_hex(rest, 8), - 't' => Ok((u32::from('\t') as llama_gretype, rest)), - 'r' => Ok((u32::from('\r') as llama_gretype, rest)), - 'n' => Ok((u32::from('\n') as llama_gretype, rest)), - '\\' => Ok((u32::from('\\') as llama_gretype, rest)), - '"' => Ok((u32::from('"') as llama_gretype, rest)), - '[' => Ok((u32::from('[') as llama_gretype, rest)), - ']' => Ok((u32::from(']') as llama_gretype, rest)), - c => Err(GrammarParseError::UnknownEscape { escape: c }), - } - } else if let Some(c) = rest.chars().next() { - Ok((u32::from(c) as llama_gretype, &rest[c.len_utf8()..])) - } else { - Err(GrammarParseError::UnexpectedEndOfInput { - parse_stage: "char", - }) - } - } - - fn parse_name(rest: &str) -> Result<(&str, &str), GrammarParseError> { - let name_end = rest - .find(|c: char| !c.is_alphanumeric() && c != '-' && c != '_') - .ok_or(GrammarParseError::MissingSpaceAfterName { - rest: rest.to_string(), - })?; - let name = &rest[..name_end]; - let rest = &rest[name_end..]; - Ok((name, rest)) - } -} - -/// An error that can occur creating a grammar from a string. -#[derive(thiserror::Error, Debug)] -pub enum LlamaGrammarFromStrError { - /// There was an error parsing the grammar. - #[error("Failed to parse grammar {0}")] - ParseError(#[from] GrammarParseError), - /// Llama-cpp returned null - this can occur for many reasons, but should ideally be caught on - /// the rust side beforehand. 
- #[error("llama-cpp returned null")] - LlamaCppNullError, -} - -impl FromStr for ParseState { - type Err = GrammarParseError; - - fn from_str(s: &str) -> Result { - let mut parse_state = ParseState::new(); - let mut remaining = Some(s); - while let Some(str) = remaining { - remaining = parse_state.parse_rule(str)?; - } - Ok(parse_state) - } -} - -impl FromStr for LlamaGrammar { - type Err = LlamaGrammarFromStrError; - - fn from_str(s: &str) -> Result { - let mut parse_state = ParseState::from_str(s)?; - - let n_rules = parse_state.rules.len(); - let root_id = parse_state.get_symbol_id("root"); - let mut vec = parse_state - .rules - .iter_mut() - .map(|v| v.as_ptr()) - .collect::>(); - let rules = vec.as_mut_ptr(); - - let grammar = - unsafe { llama_cpp_sys_2::llama_grammar_init(rules, n_rules, root_id as usize) }; - - Ok(Self { - parse: parse_state, - grammar: NonNull::new(grammar).ok_or(LlamaGrammarFromStrError::LlamaCppNullError)?, - }) - } -} - -impl Drop for LlamaGrammar { - fn drop(&mut self) { - unsafe { llama_cpp_sys_2::llama_grammar_free(self.grammar.as_ptr()) } - } -} - -#[cfg(test)] -mod tests; diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs index 2717c845..424572bd 100644 --- a/llama-cpp-2/src/lib.rs +++ b/llama-cpp-2/src/lib.rs @@ -23,10 +23,10 @@ use std::path::PathBuf; use std::string::FromUtf8Error; pub mod context; -pub mod grammar; pub mod llama_backend; pub mod llama_batch; pub mod model; +pub mod sampling; pub mod timing; pub mod token; pub mod token_type; @@ -62,6 +62,7 @@ pub enum LLamaCppError { /// see [`EmbeddingsError`] #[error(transparent)] EmbeddingError(#[from] EmbeddingsError), + // See [`LlamaSamplerError`] } /// There was an error while getting the chat template from a model. @@ -194,6 +195,14 @@ pub enum LlamaLoraAdapterRemoveError { ErrorResult(i32), } +/// An error that can occur when initializing a sampler. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum LlamaSamplerError { + /// llama.cpp returned null + #[error("null reference from llama.cpp")] + NullReturn, +} + /// get the time (in microseconds) according to llama.cpp /// ``` /// # use llama_cpp_2::llama_time_us; diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs new file mode 100644 index 00000000..7181e149 --- /dev/null +++ b/llama-cpp-2/src/sampling.rs @@ -0,0 +1,256 @@ +//! Safe wrapper around `llama_sampler`. +pub mod params; + +use std::ffi::CString; +use std::fmt::{Debug, Formatter}; +use std::ptr::NonNull; + +use crate::context::LlamaContext; +use crate::model::LlamaModel; +use crate::token::LlamaToken; +use crate::LlamaSamplerError; + +/// A safe wrapper around `llama_sampler`. +pub struct LlamaSampler { + pub(crate) sampler: NonNull, +} + +impl Debug for LlamaSampler { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LlamaSamplerChain").finish() + } +} + +impl LlamaSampler { + /// Create a new `LlamaSampler` from the given parameters. + /// # Errors + /// Returns an error if the underlying C++ code returns a null pointer. + pub fn new(params: params::LlamaSamplerChainParams) -> Result { + let sampler = unsafe { + NonNull::new(llama_cpp_sys_2::llama_sampler_chain_init( + params.sampler_chain_params, + )) + .ok_or(LlamaSamplerError::NullReturn) + }?; + + Ok(Self { sampler }) + } + + /// Samples the token with the largest probability. 
+ #[must_use] + #[allow(unused_mut)] + pub fn add_greedy(mut self) -> Self { + unsafe { + let greedy_sampler = llama_cpp_sys_2::llama_sampler_init_greedy(); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), greedy_sampler); + } + + self + } + + /// Samples according to the probability distribution of the tokens. + #[must_use] + #[allow(unused_mut)] + pub fn add_dist(mut self, seed: u32) -> Self { + unsafe { + let dist_sampler = llama_cpp_sys_2::llama_sampler_init_dist(seed); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), dist_sampler); + } + + self + } + + /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" + #[must_use] + #[allow(unused_mut)] + pub fn add_top_k(mut self, k: i32) -> Self { + unsafe { + let top_k_sampler = llama_cpp_sys_2::llama_sampler_init_top_k(k); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), top_k_sampler); + } + + self + } + + /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" + #[must_use] + #[allow(unused_mut)] + pub fn add_top_p(mut self, p: f32, min_keep: usize) -> Self { + unsafe { + let top_p_sampler = llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), top_p_sampler); + } + + self + } + + /// Minimum P sampling as described in + #[must_use] + #[allow(unused_mut)] + pub fn add_min_p(mut self, p: f32, min_keep: usize) -> Self { + unsafe { + let min_p_sampler = llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), min_p_sampler); + } + + self + } + + /// Locally Typical Sampling implementation described in the paper . + #[must_use] + #[allow(unused_mut)] + pub fn add_typical(mut self, p: f32, min_keep: usize) -> Self { + unsafe { + let typical_sampler = llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), typical_sampler); + } + + self + } + + /// Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf + #[must_use] + #[allow(unused_mut)] + pub fn add_temp(mut self, t: f32) -> Self { + unsafe { + let temp_sampler = llama_cpp_sys_2::llama_sampler_init_temp(t); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_sampler); + } + + self + } + + /// Dynamic temperature implementation (a.k.a. entropy) described in the paper . + #[must_use] + #[allow(unused_mut)] + pub fn add_temp_ext(mut self, t: f32, delta: f32, exponent: f32) -> Self { + unsafe { + let temp_ext_sampler = llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); + } + + self + } + + /// Mirostat 1.0 algorithm described in the paper . Uses tokens instead of words. + /// + /// # Arguments + /// + /// * `candidates` - A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// * `tau` - The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// * `eta` - The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. 
A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// * `m` - The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + /// * `mu` - Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + #[must_use] + #[allow(unused_mut)] + pub fn add_mirostat(mut self, n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self { + unsafe { + let temp_ext_sampler = + llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); + } + + self + } + + /// Mirostat 2.0 algorithm described in the paper . Uses tokens instead of words. + /// + /// # Arguments + /// + /// * `candidates` - A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// * `tau` - The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// * `eta` - The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// * `mu` - Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + #[must_use] + #[allow(unused_mut)] + pub fn add_mirostat_v2(mut self, seed: u32, tau: f32, eta: f32) -> Self { + unsafe { + let temp_ext_sampler = llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); + } + + self + } + + /// Samples constrained by a context-free grammar in the GGML BNF (GBNF) format. + /// + /// # Panics + /// Panics if a provided string contains a null byte. + #[must_use] + #[allow(unused_mut)] + pub fn add_grammar( + mut self, + model: &LlamaModel, + grammar_str: &str, + grammar_root: &str, + ) -> Self { + unsafe { + let grammar_str = CString::new(grammar_str).unwrap(); + let grammar_root = CString::new(grammar_root).unwrap(); + let grammar_sampler = llama_cpp_sys_2::llama_sampler_init_grammar( + model.model.as_ptr(), + grammar_str.as_ptr(), + grammar_root.as_ptr(), + ); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), grammar_sampler); + } + + self + } + + /// Adds penalties to the sampler. This can be used to penalize certain patterns in the generated text, such as repeating the same token multiple times or using the same token too frequently. 
+ #[allow(unused_mut, clippy::too_many_arguments)] + #[must_use] + pub fn add_penalties( + mut self, + n_vocab: i32, + special_eos_id: i32, + linefeed_id: i32, + penalty_last_n: i32, + penalty_repeat: f32, + penalty_freq: f32, + penalty_present: f32, + penalize_nl: bool, + ignore_eos: bool, + ) -> Self { + unsafe { + let temp_ext_sampler = llama_cpp_sys_2::llama_sampler_init_penalties( + n_vocab, + special_eos_id, + linefeed_id, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + penalize_nl, + ignore_eos, + ); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); + } + + self + } + + /// Sample and accept a token from the idx-th output of the last evaluation + #[must_use] + pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken { + let token = unsafe { + llama_cpp_sys_2::llama_sampler_sample(self.sampler.as_ptr(), ctx.context.as_ptr(), idx) + }; + + LlamaToken(token) + } + + /// Accepts a token from the sampler, possibly updating the internal state of certain samplers (e.g. grammar, repetition, etc.) + pub fn accept(&mut self, token: LlamaToken) { + unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler.as_ptr(), token.0) } + } +} + +impl Drop for LlamaSampler { + fn drop(&mut self) { + unsafe { + llama_cpp_sys_2::llama_sampler_free(self.sampler.as_ptr()); + } + } +} diff --git a/llama-cpp-2/src/sampling/params.rs b/llama-cpp-2/src/sampling/params.rs new file mode 100644 index 00000000..0e67c1fa --- /dev/null +++ b/llama-cpp-2/src/sampling/params.rs @@ -0,0 +1,39 @@ +//! Safe wrapper around `llama_sampler_chain_params`. + +use std::fmt::{Debug, Formatter}; + +/// A safe wrapper around `llama_sampler`. +pub struct LlamaSamplerChainParams { + pub(crate) sampler_chain_params: llama_cpp_sys_2::llama_sampler_chain_params, +} + +impl Debug for LlamaSamplerChainParams { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LlamaSamplerChainParams").finish() + } +} + +impl Default for LlamaSamplerChainParams { + fn default() -> Self { + let sampler_chain_params = unsafe { llama_cpp_sys_2::llama_sampler_chain_default_params() }; + + Self { + sampler_chain_params, + } + } +} + +impl LlamaSamplerChainParams { + /// Set whether to measure performance timings + #[must_use] + pub fn with_no_perf(mut self, no_perf: bool) -> Self { + self.sampler_chain_params.no_perf = no_perf; + self + } + + /// Get whether to measure performance timings + #[must_use] + pub fn no_perf(&self) -> bool { + self.sampler_chain_params.no_perf + } +} diff --git a/llama-cpp-2/src/timing.rs b/llama-cpp-2/src/timing.rs index 51cf682a..b45d9318 100644 --- a/llama-cpp-2/src/timing.rs +++ b/llama-cpp-2/src/timing.rs @@ -4,43 +4,35 @@ use std::fmt::{Debug, Display, Formatter}; /// A wrapper around `llama_timings`. #[derive(Clone, Copy, Debug)] pub struct LlamaTimings { - pub(crate) timings: llama_cpp_sys_2::llama_timings, + pub(crate) timings: llama_cpp_sys_2::llama_perf_context_data, } impl LlamaTimings { /// Create a new `LlamaTimings`. 
/// ``` /// # use llama_cpp_2::timing::LlamaTimings; - /// let timings = LlamaTimings::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7, 8, 9); - /// let timings_str = "load time = 3.00 ms - /// sample time = 4.00 ms / 7 runs (0.57 ms per token, 1750.00 tokens per second) - /// prompt eval time = 5.00 ms / 8 tokens (0.62 ms per token, 1600.00 tokens per second) - /// eval time = 6.00 ms / 9 runs (0.67 ms per token, 1500.00 tokens per second) - /// total time = 1.00 ms"; + /// let timings = LlamaTimings::new(1.0, 2.0, 3.0, 4.0, 5, 6); + /// let timings_str = "load time = 2.00 ms + /// prompt eval time = 3.00 ms / 5 tokens (0.60 ms per token, 1666.67 tokens per second) + /// eval time = 4.00 ms / 6 runs (0.67 ms per token, 1500.00 tokens per second)\n"; /// assert_eq!(timings_str, format!("{}", timings)); /// ``` #[allow(clippy::too_many_arguments)] #[must_use] pub fn new( t_start_ms: f64, - t_end_ms: f64, t_load_ms: f64, - t_sample_ms: f64, t_p_eval_ms: f64, t_eval_ms: f64, - n_sample: i32, n_p_eval: i32, n_eval: i32, ) -> Self { Self { - timings: llama_cpp_sys_2::llama_timings { + timings: llama_cpp_sys_2::llama_perf_context_data { t_start_ms, - t_end_ms, t_load_ms, - t_sample_ms, t_p_eval_ms, t_eval_ms, - n_sample, n_p_eval, n_eval, }, @@ -53,24 +45,12 @@ impl LlamaTimings { self.timings.t_start_ms } - /// Get the end time in milliseconds. - #[must_use] - pub fn t_end_ms(&self) -> f64 { - self.timings.t_end_ms - } - /// Get the load time in milliseconds. #[must_use] pub fn t_load_ms(&self) -> f64 { self.timings.t_load_ms } - /// Get the sample time in milliseconds. - #[must_use] - pub fn t_sample_ms(&self) -> f64 { - self.timings.t_sample_ms - } - /// Get the prompt evaluation time in milliseconds. #[must_use] pub fn t_p_eval_ms(&self) -> f64 { @@ -83,12 +63,6 @@ impl LlamaTimings { self.timings.t_eval_ms } - /// Get the number of samples. - #[must_use] - pub fn n_sample(&self) -> i32 { - self.timings.n_sample - } - /// Get the number of prompt evaluations. #[must_use] pub fn n_p_eval(&self) -> i32 { @@ -106,21 +80,11 @@ impl LlamaTimings { self.timings.t_start_ms = t_start_ms; } - /// Set the end time in milliseconds. - pub fn set_t_end_ms(&mut self, t_end_ms: f64) { - self.timings.t_end_ms = t_end_ms; - } - /// Set the load time in milliseconds. pub fn set_t_load_ms(&mut self, t_load_ms: f64) { self.timings.t_load_ms = t_load_ms; } - /// Set the sample time in milliseconds. - pub fn set_t_sample_ms(&mut self, t_sample_ms: f64) { - self.timings.t_sample_ms = t_sample_ms; - } - /// Set the prompt evaluation time in milliseconds. pub fn set_t_p_eval_ms(&mut self, t_p_eval_ms: f64) { self.timings.t_p_eval_ms = t_p_eval_ms; @@ -131,11 +95,6 @@ impl LlamaTimings { self.timings.t_eval_ms = t_eval_ms; } - /// Set the number of samples. - pub fn set_n_sample(&mut self, n_sample: i32) { - self.timings.n_sample = n_sample; - } - /// Set the number of prompt evaluations. 
pub fn set_n_p_eval(&mut self, n_p_eval: i32) { self.timings.n_p_eval = n_p_eval; @@ -150,14 +109,6 @@ impl LlamaTimings { impl Display for LlamaTimings { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { writeln!(f, "load time = {:.2} ms", self.t_load_ms())?; - writeln!( - f, - "sample time = {:.2} ms / {} runs ({:.2} ms per token, {:.2} tokens per second)", - self.t_sample_ms(), - self.n_sample(), - self.t_sample_ms() / f64::from(self.n_sample()), - 1e3 / self.t_sample_ms() * f64::from(self.n_sample()) - )?; writeln!( f, "prompt eval time = {:.2} ms / {} tokens ({:.2} ms per token, {:.2} tokens per second)", @@ -174,10 +125,6 @@ impl Display for LlamaTimings { self.t_eval_ms() / f64::from(self.n_eval()), 1e3 / self.t_eval_ms() * f64::from(self.n_eval()) )?; - write!( - f, - "total time = {:.2} ms", - self.t_end_ms() - self.t_start_ms() - ) + Ok(()) } } diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs index e81ab336..d9693049 100644 --- a/llama-cpp-2/src/token/data_array.rs +++ b/llama-cpp-2/src/token/data_array.rs @@ -1,10 +1,5 @@ //! an rusty equivalent of `llama_token_data`. -use crate::context::LlamaContext; use crate::token::data::LlamaTokenData; -use crate::token::LlamaToken; -use llama_cpp_sys_2::llama_token; -use std::cmp::min; -use std::ptr; /// a safe wrapper around `llama_token_data_array`. #[derive(Debug, Clone, PartialEq)] @@ -53,358 +48,3 @@ impl LlamaTokenDataArray { Self::new(data.into_iter().collect(), sorted) } } - -impl LlamaTokenDataArray { - /// Modify the underlying data as a `llama_token_data_array`. and reconstruct the `LlamaTokenDataArray`. - /// - /// # Panics - /// - /// Panics if some of the safety conditions are not met. (we cannot check all of them at runtime so breaking them is UB) - /// - /// SAFETY: - /// [modify] cannot change the data pointer. - /// if the data is not sorted, sorted must be false. - /// the size of the data can only decrease (i.e you cannot add new elements). - pub(crate) unsafe fn modify_as_c_llama_token_data_array( - &mut self, - modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T, - ) -> T { - let size = self.data.len(); - let data = self.data.as_mut_ptr().cast(); - let mut c_llama_token_data_array = llama_cpp_sys_2::llama_token_data_array { - data, - size, - sorted: self.sorted, - }; - let result = modify(&mut c_llama_token_data_array); - assert!( - ptr::eq(data, c_llama_token_data_array.data), - "data pointer changed" - ); - assert!(c_llama_token_data_array.size <= size, "size increased"); - self.data.set_len(c_llama_token_data_array.size); - self.sorted = c_llama_token_data_array.sorted; - result - } - - /// Repetition penalty described in [CTRL academic paper](https://arxiv.org/abs/1909.05858), with negative logit fix. - /// Frequency and presence penalties described in [OpenAI API](https://platform.openai.com/docs/api-reference/parameter-details). - /// - /// # Parameters - /// - /// * `ctx` - the context to use. May be `None` if you do not care to record the sample timings. - /// * `last_tokens` - the last tokens in the context. - /// - /// * `penalty_last_n` - the number of tokens back to consider for the repetition penalty. (0 for no penalty) - /// * `penalty_repeat` - the repetition penalty. (1.0 for no penalty) - /// * `penalty_freq` - the frequency penalty. (0.0 for no penalty) - /// * `penalty_present` - the presence penalty. 
(0.0 for no penalty) - /// - /// # Example - /// - /// ```rust - /// # use std::collections::BTreeMap; - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// let history = vec![ - /// LlamaToken::new(2), - /// LlamaToken::new(1), - /// LlamaToken::new(0), - /// ]; - /// - /// let candidates = vec![ - /// LlamaToken::new(0), - /// LlamaToken::new(1), - /// LlamaToken::new(2), - /// LlamaToken::new(3), - /// ]; - /// - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates.iter().map(|&token| LlamaTokenData::new(token, 0.0, 0.0)), false); - /// - /// candidates.sample_repetition_penalty(None, &history, 2, 1.1, 0.1, 0.1); - /// - /// let token_logits = candidates.data.into_iter().map(|token_data| (token_data.id(), token_data.logit())).collect::>(); - /// assert_eq!(token_logits[&LlamaToken(0)], 0.0, "expected no penalty as it is out of `penalty_last_n`"); - /// assert!(token_logits[&LlamaToken(1)] < 0.0, "expected penalty as it is in `penalty_last_n`"); - /// assert!(token_logits[&LlamaToken(2)] < 0.0, "expected penalty as it is in `penalty_last_n`"); - /// assert_eq!(token_logits[&LlamaToken(3)], 0.0, "expected no penalty as it is not in `history`"); - /// ``` - pub fn sample_repetition_penalty( - &mut self, - ctx: Option<&mut LlamaContext>, - last_tokens: &[LlamaToken], - penalty_last_n: usize, - penalty_repeat: f32, - penalty_freq: f32, - penalty_present: f32, - ) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - let penalty_last_n = min(penalty_last_n, last_tokens.len().saturating_sub(1)); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_repetition_penalties( - ctx, - c_llama_token_data_array, - // safe cast as LlamaToken is repr(transparent) - last_tokens.as_ptr().cast::(), - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - ); - }); - } - } - - /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. 
- /// - /// # Example - /// - /// ```rust - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let lowest = LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0); - /// let middle = LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0); - /// let highest = LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0); - /// - /// let candidates = vec![lowest, middle, highest]; - /// - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// candidates.sample_softmax(None); - /// - /// assert!(candidates.sorted); - /// assert_eq!(candidates.data[0].id(), highest.id()); - /// assert_eq!(candidates.data[0].logit(), highest.logit()); - /// assert!(candidates.data[0].p() > candidates.data[1].p()); - /// assert_eq!(candidates.data[1].id(), middle.id()); - /// assert_eq!(candidates.data[1].logit(), middle.logit()); - /// assert!(candidates.data[1].p() > candidates.data[2].p()); - /// assert_eq!(candidates.data[2].id(), lowest.id()); - /// assert_eq!(candidates.data[2].logit(), lowest.logit()); - /// ``` - pub fn sample_softmax(&mut self, ctx: Option<&mut LlamaContext>) { - unsafe { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_softmax(ctx, c_llama_token_data_array); - }); - } - } - - /// Modify the logits of [`Self`] in place using temperature sampling. - /// - /// # Example - /// - /// ```rust - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0) - /// ]; - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// - /// candidates.sample_temp(None, 0.5); - /// - /// assert_ne!(candidates.data[0].logit(), 0.1); - /// assert_ne!(candidates.data[1].logit(), 0.2); - /// assert_ne!(candidates.data[2].logit(), 0.7); - /// ``` - pub fn sample_temp(&mut self, ctx: Option<&mut LlamaContext>, temperature: f32) { - if temperature == 0.0 { - return; - } - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_temp(ctx, c_llama_token_data_array, temperature); - }); - } - } - - /// Randomly selects a token from the candidates based on their probabilities. 
- pub fn sample_token(&mut self, ctx: &mut LlamaContext) -> LlamaToken { - let llama_token = unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_token(ctx.context.as_ptr(), c_llama_token_data_array) - }) - }; - LlamaToken(llama_token) - } - - /// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) - pub fn sample_top_k(&mut self, ctx: Option<&mut LlamaContext>, k: i32, min_keep: usize) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_top_k(ctx, c_llama_token_data_array, k, min_keep); - }); - } - } - - /// Tail Free Sampling described in [Tail-Free-Sampling](https://www.trentonbricken.com/Tail-Free-Sampling/). - pub fn sample_tail_free(&mut self, ctx: Option<&mut LlamaContext>, z: f32, min_keep: usize) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_tail_free(ctx, c_llama_token_data_array, z, min_keep); - }); - } - } - - /// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666). - /// - /// # Example - /// - /// ```rust - /// - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), - /// ]; - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// candidates.sample_typical(None, 0.5, 1); - /// - /// ``` - pub fn sample_typical(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_typical(ctx, c_llama_token_data_array, p, min_keep); - }); - } - } - - /// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) - /// - /// # Example - /// - /// ```rust - /// - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), - /// ]; - /// - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// candidates.sample_top_p(None, 0.5, 1); - /// - /// assert_eq!(candidates.data.len(), 2); - /// assert_eq!(candidates.data[0].id(), LlamaToken::new(2)); - /// assert_eq!(candidates.data[1].id(), LlamaToken::new(1)); - /// ``` - pub fn sample_top_p(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_top_p(ctx, c_llama_token_data_array, p, min_keep); - }); - } - } - - /// Minimum P sampling as described in 
[#3841](https://github.com/ggerganov/llama.cpp/pull/3841) - /// - /// # Example - /// - /// ``` - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(4), 0.0001, 0.0), - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), - /// ]; - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// candidates.sample_min_p(None, 0.05, 1); - /// ``` - pub fn sample_min_p(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_min_p(ctx, c_llama_token_data_array, p, min_keep); - }); - } - } - - /// Mirostat 2.0 algorithm described in the [paper](https://arxiv.org/abs/2007.14966). Uses tokens instead of words. - /// - /// # Parameters - /// - /// * `tau` The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// * `eta` The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// * `mu` Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - pub fn sample_token_mirostat_v2( - &mut self, - ctx: &mut LlamaContext, - tau: f32, - eta: f32, - mu: &mut f32, - ) -> LlamaToken { - let mu_ptr = ptr::from_mut(mu); - let token = unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_token_mirostat_v2( - ctx.context.as_ptr(), - c_llama_token_data_array, - tau, - eta, - mu_ptr, - ) - }) - }; - *mu = unsafe { *mu_ptr }; - LlamaToken(token) - } - - /// Mirostat 1.0 algorithm described in the [paper](https://arxiv.org/abs/2007.14966). Uses tokens instead of words. - /// - /// # Parameters - /// - /// * `tau` The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// * `eta` The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// * `m` The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. - /// * `mu` Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
- pub fn sample_token_mirostat_v1( - &mut self, - ctx: &mut LlamaContext, - tau: f32, - eta: f32, - m: i32, - mu: &mut f32, - ) -> LlamaToken { - let mu_ptr = ptr::from_mut(mu); - let token = unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_token_mirostat( - ctx.context.as_ptr(), - c_llama_token_data_array, - tau, - eta, - m, - mu_ptr, - ) - }) - }; - *mu = unsafe { *mu_ptr }; - LlamaToken(token) - } -} diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp index 8f1d81a0..0abc6a2c 160000 --- a/llama-cpp-sys-2/llama.cpp +++ b/llama-cpp-sys-2/llama.cpp @@ -1 +1 @@ -Subproject commit 8f1d81a0b6f50b9bad72db0b6fcd299ad9ecd48c +Subproject commit 0abc6a2c25272d5cf01384dda8ee8bfec4ba8745
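
For orientation, a minimal sketch of a generation loop against the sampler-chain API introduced above. The `generate` helper, the chain settings (top-k 40, top-p 0.95, temperature 0.8, seed 1234), and the use of `anyhow` for error plumbing are illustrative assumptions rather than values taken from this diff; the decode/batch plumbing mirrors the crate's existing examples.

```rust
use anyhow::Result;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{LlamaModel, Special};
use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
use llama_cpp_2::sampling::LlamaSampler;

// Sketch of a decode-and-sample loop; assumes `batch` already holds the decoded
// prompt and `n_cur` points just past it. Chain values are illustrative.
fn generate(
    model: &LlamaModel,
    ctx: &mut LlamaContext,
    batch: &mut LlamaBatch,
    mut n_cur: i32,
    n_len: i32,
) -> Result<String> {
    // Build a chain: filtering stages first, then a seeded `dist` stage that
    // actually draws the token.
    let sampler_params = LlamaSamplerChainParams::default();
    let mut sampler = LlamaSampler::new(sampler_params)?
        .add_top_k(40)
        .add_top_p(0.95, 1)
        .add_temp(0.8)
        .add_dist(1234);

    let mut out = String::new();
    while n_cur <= n_len {
        // Sample from the logits of the last token in the batch, then feed the
        // choice back so stateful stages (grammar, penalties, ...) can update.
        let token = sampler.sample(ctx, batch.n_tokens() - 1);
        sampler.accept(token);

        if model.is_eog_token(token) {
            break;
        }
        // Lossy UTF-8 conversion keeps the sketch short; the examples use a
        // streaming `encoding_rs` decoder instead.
        out.push_str(&String::from_utf8_lossy(
            &model.token_to_bytes(token, Special::Tokenize)?,
        ));

        batch.clear();
        batch.add(token, n_cur, &[0], true)?;
        n_cur += 1;
        ctx.decode(batch)?;
    }
    Ok(out)
}
```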
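The grammar module removed above is superseded by the grammar stage of the chain. A short sketch of constraining output with `add_grammar`; the GBNF string and the yes/no use case are illustrative assumptions, not part of this diff.

```rust
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::LlamaSamplerError;

// Builds a chain that only emits "yes" or "no", then picks greedily.
// The GBNF text is a stand-in; any grammar accepted by llama.cpp works here.
fn yes_no_sampler(model: &LlamaModel) -> Result<LlamaSampler, LlamaSamplerError> {
    let gbnf = r#"root ::= "yes" | "no""#;
    let sampler = LlamaSampler::new(LlamaSamplerChainParams::default())?
        .add_grammar(model, gbnf, "root")
        .add_greedy();
    Ok(sampler)
}
```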