From 635d621f5b7de3859ba4fb43696f64d69f4c7fd1 Mon Sep 17 00:00:00 2001
From: marcus
Date: Thu, 14 Mar 2024 09:05:02 -0700
Subject: [PATCH 1/6] moved sample softmax to token_data with optional context
 and started on sampler

---
 llama-cpp-2/Cargo.toml                    |   1 +
 llama-cpp-2/src/context/sample.rs         | 149 ++--------
 llama-cpp-2/src/context/sample/sampler.rs |  35 +++++
 llama-cpp-2/src/lib.rs                    |   1 +
 llama-cpp-2/src/token/data_array.rs       |  37 ++++++
 5 files changed, 86 insertions(+), 137 deletions(-)
 create mode 100644 llama-cpp-2/src/context/sample/sampler.rs

diff --git a/llama-cpp-2/Cargo.toml b/llama-cpp-2/Cargo.toml
index c1c0193b..0cae4872 100644
--- a/llama-cpp-2/Cargo.toml
+++ b/llama-cpp-2/Cargo.toml
@@ -28,6 +28,7 @@ harness = false

 [features]
 cublas = ["llama-cpp-sys-2/cublas"]
+sampler = []

 [lints]
 workspace = true

diff --git a/llama-cpp-2/src/context/sample.rs b/llama-cpp-2/src/context/sample.rs
index 7b1b69a5..dac62b27 100644
--- a/llama-cpp-2/src/context/sample.rs
+++ b/llama-cpp-2/src/context/sample.rs
@@ -5,130 +5,10 @@ use crate::grammar::LlamaGrammar;
 use crate::token::data_array::LlamaTokenDataArray;
 use crate::token::LlamaToken;

-/// struct to hold params for sampling
-#[derive(Debug)]
-#[deprecated(
-    since = "0.1.32",
-    note = "this does not scale well with many params and does not allow for changing of orders."
-)]
-pub struct Sampler<'grammar> {
-    token_data_array: LlamaTokenDataArray,
-    grammar: Option<&'grammar mut LlamaGrammar>,
-    temperature: Option<f32>,
-}
-
-impl<'grammar> Sampler<'grammar> {
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    fn sample(self, llama_context: &mut LlamaContext) -> LlamaToken {
-        match self {
-            Sampler {
-                token_data_array,
-                grammar: None,
-                temperature: None,
-            } => llama_context.sample_token_greedy(token_data_array),
-            Sampler {
-                mut token_data_array,
-                grammar: Some(grammar),
-                temperature: None,
-            } => {
-                llama_context.sample_grammar(&mut token_data_array, grammar);
-                let token = llama_context.sample_token_greedy(token_data_array);
-                llama_context.grammar_accept_token(grammar, token);
-                token
-            }
-            Sampler {
-                mut token_data_array,
-                grammar: None,
-                temperature: Some(temp),
-            } => {
-                llama_context.sample_temp(&mut token_data_array, temp);
-                llama_context.sample_token_softmax(&mut token_data_array);
-                token_data_array.data[0].id()
-            }
-            Sampler {
-                mut token_data_array,
-                grammar: Some(grammar),
-                temperature: Some(temperature),
-            } => {
-                llama_context.sample_grammar(&mut token_data_array, grammar);
-                llama_context.sample_temp(&mut token_data_array, temperature);
-                llama_context.sample_token_softmax(&mut token_data_array);
-                let token = llama_context.sample_token_greedy(token_data_array);
-                llama_context.grammar_accept_token(grammar, token);
-                token
-            }
-        }
-    }
-
-    /// Create a new sampler.
-    #[must_use]
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn new(llama_token_data_array: LlamaTokenDataArray) -> Self {
-        Self {
-            token_data_array: llama_token_data_array,
-            grammar: None,
-            temperature: None,
-        }
-    }
-
-    /// Set the grammar for sampling.
-    #[must_use]
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
- )] - pub fn with_grammar(mut self, grammar: &'grammar mut LlamaGrammar) -> Self { - self.grammar = Some(grammar); - self - } - - /// Set the temperature for sampling. - /// - /// ``` - /// # use llama_cpp_2::context::LlamaContext; - /// # use llama_cpp_2::context::sample::Sampler; - /// # use llama_cpp_2::grammar::LlamaGrammar; - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let _sampler = Sampler::new(LlamaTokenDataArray::new(vec![LlamaTokenData::new(LlamaToken(0), 0.0, 0.0)], false)) - /// .with_temperature(0.5); - /// ``` - #[must_use] - #[deprecated( - since = "0.1.32", - note = "this does not scale well with many params and does not allow for changing of orders." - )] - pub fn with_temperature(mut self, temperature: f32) -> Self { - if temperature == 0.0 { - return self; - } - self.temperature = Some(temperature); - self - } -} +#[cfg(feature = "sampler")] +mod sampler; impl LlamaContext<'_> { - /// Sample a token. - /// - /// # Panics - /// - /// - sampler contains no tokens - #[deprecated( - since = "0.1.32", - note = "this does not scale well with many params and does not allow for changing of orders." - )] - pub fn sample(&mut self, sampler: Sampler) -> LlamaToken { - sampler.sample(self) - } - /// Accept a token into the grammar. pub fn grammar_accept_token(&mut self, grammar: &mut LlamaGrammar, token: LlamaToken) { unsafe { @@ -162,7 +42,7 @@ impl LlamaContext<'_> { /// # Panics /// /// - [`temperature`] is not between 0.0 and 1.0 - pub fn sample_temp(&self, token_data: &mut LlamaTokenDataArray, temperature: f32) { + pub fn sample_temp(&mut self, token_data: &mut LlamaTokenDataArray, temperature: f32) { assert!( temperature >= 0.0, "temperature must be positive (was {temperature})" @@ -188,7 +68,7 @@ impl LlamaContext<'_> { /// /// - [`token_data`] is empty #[must_use] - pub fn sample_token_greedy(&self, mut token_data: LlamaTokenDataArray) -> LlamaToken { + pub fn sample_token_greedy(&mut self, mut token_data: LlamaTokenDataArray) -> LlamaToken { assert!(!token_data.data.is_empty(), "no tokens"); let mut data_arr = llama_cpp_sys_2::llama_token_data_array { data: token_data @@ -208,7 +88,7 @@ impl LlamaContext<'_> { } /// Tail Free Sampling described in [Tail-Free-Sampling](https://www.trentonbricken.com/Tail-Free-Sampling/). - pub fn sample_tail_free(&self, token_data: &mut LlamaTokenDataArray, z: f32, min_keep: usize) { + pub fn sample_tail_free(&mut self, token_data: &mut LlamaTokenDataArray, z: f32, min_keep: usize) { let ctx = self.context.as_ptr(); unsafe { token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { @@ -218,7 +98,7 @@ impl LlamaContext<'_> { } /// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666). 
-    pub fn sample_typical(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
+    pub fn sample_typical(&mut self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
         let ctx = self.context.as_ptr();
         unsafe {
             token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
@@ -228,7 +108,7 @@ impl LlamaContext<'_> {
     }

     /// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)"
-    pub fn sample_top_p(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
+    pub fn sample_top_p(&mut self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
         let ctx = self.context.as_ptr();
         unsafe {
             token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
@@ -239,7 +119,7 @@ impl LlamaContext<'_> {

     /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841)
     pub fn sample_min_p(
-        &self,
+        &mut self,
         llama_token_data: &mut LlamaTokenDataArray,
         p: f32,
         min_keep: usize,
@@ -253,7 +133,7 @@ impl LlamaContext<'_> {
     }

     /// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)
-    pub fn sample_top_k(&self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
+    pub fn sample_top_k(&mut self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
         let ctx = self.context.as_ptr();
         unsafe {
             token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
@@ -262,14 +142,9 @@ impl LlamaContext<'_> {
         }
     }

-    /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    pub fn sample_token_softmax(&self, token_data: &mut LlamaTokenDataArray) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_softmax(ctx, c_llama_token_data_array);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_softmax`]
+    pub fn sample_token_softmax(&mut self, token_data: &mut LlamaTokenDataArray) {
+        token_data.sample_softmax(Some(self));
     }

     /// See [`LlamaTokenDataArray::sample_repetition_penalty`]
diff --git a/llama-cpp-2/src/context/sample/sampler.rs b/llama-cpp-2/src/context/sample/sampler.rs
new file mode 100644
index 00000000..19dd797a
--- /dev/null
+++ b/llama-cpp-2/src/context/sample/sampler.rs
@@ -0,0 +1,35 @@
+//! A more rusty way of sampling. Allows for adding a stack of sampling steps and a `finalizer` which selects a token from the remaining candidates.
+//!
+//! # Example
+//!
+//! ```rust
+//!
+//! ```
+
+use crate::token::data_array::LlamaTokenDataArray;
+use crate::token::LlamaToken;
+
+/// A series of sampling steps that will produce a token.
+struct Sampler {
+    steps: Vec<Box<dyn FnMut(&mut LlamaTokenDataArray) -> ()>>,
+    finalizer: Box<dyn FnMut(LlamaTokenDataArray) -> Option<LlamaToken>>,
+}
+
+impl Sampler {
+    /// Create a very simple sampler that selects the token with the highest probability.
+    fn greedy() -> Self {
+        Self {
+            steps: Vec::new(),
+            finalizer: Box::new(|mut token_data| {
+                if token_data.data.is_empty() {
+                    return None;
+                }
+                if !token_data.sorted {
+                    token_data.sample_softmax(None);
+                }
+                Some(token_data.data[0].id())
+            }),
+        }
+    }
+}
diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs
index bf429364..7a478e56 100644
--- a/llama-cpp-2/src/lib.rs
+++ b/llama-cpp-2/src/lib.rs
@@ -12,6 +12,7 @@
 //! # Feature Flags
 //!
 //! - `cublas` enables CUDA gpu support.
+//! - `sampler` adds the [`Sampler`] struct for a more rusty way of sampling.
 use std::ffi::NulError;
 use std::fmt::Debug;
 use std::num::NonZeroI32;
diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs
index 66b6e583..35a241ce 100644
--- a/llama-cpp-2/src/token/data_array.rs
+++ b/llama-cpp-2/src/token/data_array.rs
@@ -155,4 +155,41 @@ impl LlamaTokenDataArray {
             });
         }
     }
+
+    /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use llama_cpp_2::token::data::LlamaTokenData;
+    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+    /// # use llama_cpp_2::token::LlamaToken;
+    ///
+    /// let lowest = LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0);
+    /// let middle = LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0);
+    /// let highest = LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0);
+    ///
+    /// let candidates = vec![lowest, middle, highest];
+    ///
+    /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false);
+    /// candidates.sample_softmax(None);
+    ///
+    /// assert!(candidates.sorted);
+    /// assert_eq!(candidates.data[0].id(), highest.id());
+    /// assert_eq!(candidates.data[0].logit(), highest.logit());
+    /// assert!(candidates.data[0].p() > candidates.data[1].p());
+    /// assert_eq!(candidates.data[1].id(), middle.id());
+    /// assert_eq!(candidates.data[1].logit(), middle.logit());
+    /// assert!(candidates.data[1].p() > candidates.data[2].p());
+    /// assert_eq!(candidates.data[2].id(), lowest.id());
+    /// assert_eq!(candidates.data[2].logit(), lowest.logit());
+    /// ```
+    pub fn sample_softmax(&mut self, ctx: Option<&mut LlamaContext>) {
+        unsafe {
+            let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr());
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_softmax(ctx, c_llama_token_data_array);
+            });
+        }
+    }
 }

From 6cfe39c7affaf445c700deb0a9fff5923a61b52e Mon Sep 17 00:00:00 2001
From: marcus
Date: Thu, 14 Mar 2024 10:56:20 -0700
Subject: [PATCH 2/6] fixed clippy

---
 .github/workflows/llama-cpp-rs-check.yml  |   6 +-
 llama-cpp-2/src/context/sample.rs         |  82 +++----
 llama-cpp-2/src/context/sample/sampler.rs | 253 ++++++++++++++++++++--
 llama-cpp-2/src/token/data_array.rs       | 170 ++++++++++++++-
 4 files changed, 427 insertions(+), 84 deletions(-)

diff --git a/.github/workflows/llama-cpp-rs-check.yml b/.github/workflows/llama-cpp-rs-check.yml
index c26572c9..da5d067f 100644
--- a/.github/workflows/llama-cpp-rs-check.yml
+++ b/.github/workflows/llama-cpp-rs-check.yml
@@ -34,7 +34,7 @@ jobs:
       - name: Fmt
         run: cargo fmt
       - name: Test
-        run: cargo test
+        run: cargo test --features sampler
   arm64:
     name: Check that it builds on various targets
     runs-on: ubuntu-latest
@@ -67,7 +67,7 @@ jobs:
       - name: Setup Rust
         uses: dtolnay/rust-toolchain@stable
       - name: Build
-        run: cargo build
+        run: cargo build --features sampler
   windows:
     name: Check that it builds on windows
     runs-on: windows-latest
@@ -79,4 +79,4 @@ jobs:
       - name: Setup Rust
         uses: dtolnay/rust-toolchain@stable
       - name: Build
-        run: cargo build
+        run: cargo build --features sampler
diff --git a/llama-cpp-2/src/context/sample.rs b/llama-cpp-2/src/context/sample.rs
index dac62b27..773675ee 100644
--- a/llama-cpp-2/src/context/sample.rs
+++ b/llama-cpp-2/src/context/sample.rs
@@ -6,7 +6,7 @@ use crate::token::data_array::LlamaTokenDataArray;
 use crate::token::LlamaToken;

 #[cfg(feature = "sampler")]
-mod sampler;
+pub mod sampler;

 impl LlamaContext<'_> {
     /// Accept a token into the grammar.
@@ -37,36 +37,18 @@ impl LlamaContext<'_> { } } - /// Modify [`token_data`] in place using temperature sampling. - /// - /// # Panics - /// - /// - [`temperature`] is not between 0.0 and 1.0 + /// See [`LlamaTokenDataArray::sample_temp`] pub fn sample_temp(&mut self, token_data: &mut LlamaTokenDataArray, temperature: f32) { - assert!( - temperature >= 0.0, - "temperature must be positive (was {temperature})" - ); - assert!( - temperature <= 1.0, - "temperature must be less than or equal to 1.0 (was {temperature})" - ); - if temperature == 0.0 { - return; - } - let ctx: *mut llama_cpp_sys_2::llama_context = self.context.as_ptr(); - unsafe { - token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_temp(ctx, c_llama_token_data_array, temperature); - }); - } + token_data.sample_temp(Some(self), temperature); } - /// Sample a token greedily. + /// Sample a token greedily. Note that this *does not* take into account anything that has modified the probabilities - it only looks at logits. + /// + /// Most of the time [`LlamaTokenDataArray::sample_softmax`] or [`LlamaTokenDataArray::sample_token`] should be used instead. /// /// # Panics /// - /// - [`token_data`] is empty + /// - if [`token_data`] is empty #[must_use] pub fn sample_token_greedy(&mut self, mut token_data: LlamaTokenDataArray) -> LlamaToken { assert!(!token_data.data.is_empty(), "no tokens"); @@ -87,34 +69,29 @@ impl LlamaContext<'_> { LlamaToken(token) } - /// Tail Free Sampling described in [Tail-Free-Sampling](https://www.trentonbricken.com/Tail-Free-Sampling/). - pub fn sample_tail_free(&mut self, token_data: &mut LlamaTokenDataArray, z: f32, min_keep: usize) { - let ctx = self.context.as_ptr(); - unsafe { - token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_tail_free(ctx, c_llama_token_data_array, z, min_keep); - }); - } + /// See [`LlamaTokenDataArray::sample_tail_free`] + pub fn sample_tail_free( + &mut self, + token_data: &mut LlamaTokenDataArray, + z: f32, + min_keep: usize, + ) { + token_data.sample_tail_free(Some(self), z, min_keep); } - /// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666). 
-    pub fn sample_typical(&mut self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_typical(ctx, c_llama_token_data_array, p, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_typical`]
+    pub fn sample_typical(
+        &mut self,
+        token_data: &mut LlamaTokenDataArray,
+        p: f32,
+        min_keep: usize,
+    ) {
+        token_data.sample_typical(Some(self), p, min_keep);
     }

-    /// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)"
+    /// See [`LlamaTokenDataArray::sample_top_p`]
     pub fn sample_top_p(&mut self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_top_p(ctx, c_llama_token_data_array, p, min_keep);
-            });
-        }
+        token_data.sample_top_p(Some(self), p, min_keep);
     }

     /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841)
     pub fn sample_min_p(
         &mut self,
         llama_token_data: &mut LlamaTokenDataArray,
         p: f32,
         min_keep: usize,
@@ -132,14 +109,9 @@ impl LlamaContext<'_> {
         }
     }

-    /// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)
+    /// See [`LlamaTokenDataArray::sample_top_k`]
     pub fn sample_top_k(&mut self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_top_k(ctx, c_llama_token_data_array, k, min_keep);
-            });
-        }
+        token_data.sample_top_k(Some(self), k, min_keep);
     }

     /// See [`LlamaTokenDataArray::sample_softmax`]
diff --git a/llama-cpp-2/src/context/sample/sampler.rs b/llama-cpp-2/src/context/sample/sampler.rs
index 19dd797a..bf3d6edf 100644
--- a/llama-cpp-2/src/context/sample/sampler.rs
+++ b/llama-cpp-2/src/context/sample/sampler.rs
@@ -3,33 +3,252 @@
 //! # Example
 //!
 //! ```rust
+//! use llama_cpp_2::context::sample::sampler::Sampler;
+//! use llama_cpp_2::token::data::LlamaTokenData;
+//! use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+//! use llama_cpp_2::token::LlamaToken;
 //!
+//! let mut history = vec![];
+//! let candidates = LlamaTokenDataArray::from_iter((0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32 / 6.0, 0.0)), false);
+//!
+//! let token = {
+//!     let mut sampler = Sampler::greedy();
+//!     sampler.push_sample_repetition_penalty_step(&history, 64, 1.1, 0.0, 0.0);
+//!     sampler.push_top_k_step(40, 1);
+//!     sampler.push_sample_tail_free_step(1.0, 1);
+//!     sampler.push_sample_typical_step(1.0, 1);
+//!     sampler.push_sample_top_p_step(0.95, 1);
+//!     sampler.push_min_p_step(0.05, 1);
+//!     sampler.push_temperature_step(0.5);
+//!     sampler.sample(candidates)
+//! };
+//! history.push(token[0].id());
+//!
+//! println!("{:?}", token);
 //! ```

+use crate::token::data::LlamaTokenData;
 use crate::token::data_array::LlamaTokenDataArray;
 use crate::token::LlamaToken;
+use std::fmt::{Debug, Formatter};
+
+/// A single step to sample tokens from the remaining candidates.
+pub type SampleStep<'a> = Box<dyn FnMut(&mut LlamaTokenDataArray) + 'a>;
+
+/// The final step to select one or more tokens from the remaining candidates.
+pub type SampleFinalizer<'a> = Box<dyn FnMut(LlamaTokenDataArray) -> Vec<LlamaTokenData> + 'a>;

-/// A series of sampling steps that will produce a token.
-struct Sampler {
-    steps: Vec<Box<dyn FnMut(&mut LlamaTokenDataArray) -> ()>>,
-    finalizer: Box<dyn FnMut(LlamaTokenDataArray) -> Option<LlamaToken>>,
+/// A series of sampling steps that will produce a vector of token data.
+///
+/// [`a`] is the lifetime of captured references in the steps and finalizer.
+#[non_exhaustive]
+pub struct Sampler<'a> {
+    /// The steps to take when sampling.
+    pub steps: Vec<SampleStep<'a>>,
+    /// The final step to select one or more tokens from the remaining candidates.
+    pub finalizer: SampleFinalizer<'a>,
+}
+
+impl Debug for Sampler<'_> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Sampler")
+            .field(
+                "steps",
+                &format!(
+                    "{} steps of Box<dyn FnMut(&mut LlamaTokenDataArray) -> ()>",
+                    &self.steps.len()
+                ),
+            )
+            .field(
+                "finalizer",
+                &"Box<dyn FnMut(LlamaTokenDataArray) -> Vec<LlamaTokenData>>",
+            )
+            .finish()
+    }
 }

-impl Sampler {
-    /// Create a very simple sampler that selects the token with the highest probability.
-    fn greedy() -> Self {
+impl<'a> Sampler<'a> {
+    /// Create a very simple sampler that selects a single token with the greatest logit (greedy sampling).
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::sample::sampler::Sampler;
+    /// use llama_cpp_2::token::data::LlamaTokenData;
+    /// use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+    /// use llama_cpp_2::token::LlamaToken;
+    ///
+    /// let mut sampler = Sampler::greedy();
+    ///
+    /// let candidates = (0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32 / 6.0, 0.0));
+    /// let tokens = sampler.sample(LlamaTokenDataArray::from_iter(candidates, false));
+    /// assert_eq!(tokens[0].id(), LlamaToken::new(3));
+    /// ```
+    #[must_use]
+    pub fn greedy() -> Self {
+        let finalizer = |mut token_data: LlamaTokenDataArray| {
+            if token_data.data.is_empty() {
+                return vec![];
+            }
+            if token_data.sorted {
+                vec![token_data.data[0]]
+            } else {
+                token_data.sample_softmax(None);
+                vec![token_data.data[0]]
+            }
+        };
+        Self::new(finalizer)
+    }
+
+    /// Adds a repetition penalty sampling step to the sampler.
+    ///
+    /// See [`LlamaTokenDataArray::sample_repetition_penalty`]
+    pub fn push_sample_repetition_penalty_step(
+        &mut self,
+        history: &'a [LlamaToken],
+        penalty_last_n: usize,
+        penalty_repeat: f32,
+        penalty_freq: f32,
+        penalty_present: f32,
+    ) {
+        self.steps
+            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
+                can.sample_repetition_penalty(
+                    None,
+                    history,
+                    penalty_last_n,
+                    penalty_repeat,
+                    penalty_freq,
+                    penalty_present,
+                );
+            }));
+    }
+
+    /// Adds a typical sampling step to the sampler.
+    ///
+    /// See [`LlamaTokenDataArray::sample_typical`]
+    pub fn push_sample_typical_step(&mut self, p: f32, min_keep: usize) {
+        self.steps
+            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
+                can.sample_typical(None, p, min_keep);
+            }));
+    }
+
+    /// Adds a Top-p sampling step to the sampler.
+    ///
+    /// See [`LlamaTokenDataArray::sample_top_p`]
+    pub fn push_sample_top_p_step(&mut self, p: f32, min_keep: usize) {
+        self.steps
+            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
+                can.sample_top_p(None, p, min_keep);
+            }));
+    }
+
+    /// Adds a tail-free sampling step to the sampler.
+    ///
+    /// See [`LlamaTokenDataArray::sample_tail_free`]
+    pub fn push_sample_tail_free_step(&mut self, z: f32, min_keep: usize) {
+        self.steps
+            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
+                can.sample_tail_free(None, z, min_keep);
+            }));
+    }
+
+    /// Adds a top-k sampling step to the sampler.
+    ///
+    /// See [`LlamaTokenDataArray::sample_top_k`]
+    pub fn push_top_k_step(&mut self, k: i32, min_keep: usize) {
+        self.steps
+            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
+                can.sample_top_k(None, k, min_keep);
+            }));
+    }
+
+    /// Adds a temperature sampling step to the sampler.
+    ///
+    /// See [`LlamaTokenDataArray::sample_temp`]
+    pub fn push_temperature_step(&mut self, temperature: f32) {
+        self.steps
+            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
+                can.sample_temp(None, temperature);
+            }));
+    }
+
+    /// Adds a minimum P sampling step to the sampler.
+    ///
+    /// See [`LlamaTokenDataArray::sample_min_p`]
+    pub fn push_min_p_step(&mut self, p: f32, min_keep: usize) {
+        self.steps
+            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
+                can.sample_min_p(None, p, min_keep);
+            }));
+    }
+
+    /// Create a new sampler with a given finalizer.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::sample::sampler::Sampler;
+    /// use llama_cpp_2::token::data::LlamaTokenData;
+    /// use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+    /// use llama_cpp_2::token::LlamaToken;
+    ///
+    /// // a very silly way to sample.
+    /// let always_0 = |can: LlamaTokenDataArray| -> Vec<LlamaTokenData> { can.data.into_iter().filter(|t| t.id() == LlamaToken::new(0)).collect::<Vec<LlamaTokenData>>() };
+    ///
+    /// let mut sampler = Sampler::new(always_0);
+    ///
+    /// let candidates = (0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32, 0.0));
+    ///
+    /// let token = sampler.sample(LlamaTokenDataArray::from_iter(candidates, false));
+    /// assert_eq!(token[0].id(), LlamaToken::new(0));
+    ///
+    /// ```
+    pub fn new(
+        finalizer: impl FnMut(LlamaTokenDataArray) -> Vec<LlamaTokenData> + 'a,
+    ) -> Self {
         Self {
             steps: Vec::new(),
-            finalizer: Box::new(|mut token_data| {
-                if token_data.data.is_empty() {
-                    return None;
-                }
-                if !token_data.sorted {
-                    token_data.sample_softmax(None);
-                }
-                Some(token_data.data[0].id())
-            }),
+            finalizer: Box::new(finalizer),
         }
     }
+
+    /// Adds a step to the sampler.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::sample::sampler::Sampler;
+    /// use llama_cpp_2::token::data::LlamaTokenData;
+    /// use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+    /// use llama_cpp_2::token::LlamaToken;
+    ///
+    /// let mut favor_even_tokens = |can: &mut LlamaTokenDataArray| {
+    ///     for token in can.data.iter_mut() {
+    ///         if token.id().0 % 2 == 0 {
+    ///             token.set_logit(token.logit() + 1.0);
+    ///         }
+    ///     }
+    /// };
+    /// let mut sampler = Sampler::greedy();
+    /// sampler.push_step(favor_even_tokens);
+    ///
+    /// let candidates = (0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32, 0.0));
+    ///
+    /// let token = sampler.sample(LlamaTokenDataArray::from_iter(candidates, false));
+    ///
+    /// assert_eq!(token[0].id(), LlamaToken::new(2));
+    /// ```
+    pub fn push_step(&mut self, step: impl FnMut(&mut LlamaTokenDataArray) + 'a) {
+        self.steps.push(Box::new(step));
+    }
+
+    /// Sample a token from the given candidates.
+    #[must_use]
+    pub fn sample(&mut self, mut candidates: LlamaTokenDataArray) -> Vec<LlamaTokenData> {
+        for step in &mut self.steps {
+            step(&mut candidates);
         }
+        (self.finalizer)(candidates)
     }
 }
diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs
index 35a241ce..0f89d59f 100644
--- a/llama-cpp-2/src/token/data_array.rs
+++ b/llama-cpp-2/src/token/data_array.rs
@@ -65,10 +65,10 @@ impl LlamaTokenDataArray {
     /// [modify] cannot change the data pointer.
     /// if the data is not sorted, sorted must be false.
     /// the size of the data can only decrease (i.e you cannot add new elements).
-    pub(crate) unsafe fn modify_as_c_llama_token_data_array(
+    pub(crate) unsafe fn modify_as_c_llama_token_data_array<T>(
         &mut self,
-        modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array),
-    ) {
+        modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T,
+    ) -> T {
         let size = self.data.len();
         let data = self.data.as_mut_ptr().cast();
         let mut c_llama_token_data_array = llama_cpp_sys_2::llama_token_data_array {
@@ -76,7 +76,7 @@ impl LlamaTokenDataArray {
             size,
             sorted: self.sorted,
         };
-        modify(&mut c_llama_token_data_array);
+        let result = modify(&mut c_llama_token_data_array);
         assert!(
             ptr::eq(data, c_llama_token_data_array.data),
             "data pointer changed"
@@ -84,6 +84,7 @@ impl LlamaTokenDataArray {
         assert!(c_llama_token_data_array.size <= size, "size increased");
         self.data.set_len(c_llama_token_data_array.size);
         self.sorted = c_llama_token_data_array.sorted;
+        result
     }

     /// Repetition penalty described in [CTRL academic paper](https://arxiv.org/abs/1909.05858), with negative logit fix.
@@ -93,7 +94,7 @@ impl LlamaTokenDataArray {
     ///
     /// * `ctx` - the context to use. May be `None` if you do not care to record the sample timings.
     /// * `last_tokens` - the last tokens in the context.
-    /// 
+    ///
     /// * `penalty_last_n` - the number of tokens back to consider for the repetition penalty. (0 for no penalty)
     /// * `penalty_repeat` - the repetition penalty. (1.0 for no penalty)
     /// * `penalty_freq` - the frequency penalty. (0.0 for no penalty)
@@ -157,9 +158,9 @@ impl LlamaTokenDataArray {
     }

     /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
-    /// 
-    /// # Example
-    /// 
+    ///
+    /// # Example
+    ///
     /// ```rust
     /// # use llama_cpp_2::token::data::LlamaTokenData;
     /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
     /// # use llama_cpp_2::token::LlamaToken;
     ///
     /// let lowest = LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0);
     /// let middle = LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0);
     /// let highest = LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0);
-    /// 
+    ///
     /// let candidates = vec![lowest, middle, highest];
     ///
     /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false);
     /// candidates.sample_softmax(None);
     ///
     /// assert!(candidates.sorted);
     /// assert_eq!(candidates.data[0].id(), highest.id());
     /// assert_eq!(candidates.data[0].logit(), highest.logit());
     /// assert!(candidates.data[0].p() > candidates.data[1].p());
     /// assert_eq!(candidates.data[1].id(), middle.id());
     /// assert_eq!(candidates.data[1].logit(), middle.logit());
     /// assert!(candidates.data[1].p() > candidates.data[2].p());
     /// assert_eq!(candidates.data[2].id(), lowest.id());
     /// assert_eq!(candidates.data[2].logit(), lowest.logit());
     /// ```
     pub fn sample_softmax(&mut self, ctx: Option<&mut LlamaContext>) {
         unsafe {
             let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr());
             self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
                 llama_cpp_sys_2::llama_sample_softmax(ctx, c_llama_token_data_array);
             });
         }
     }
+
+    /// Modify the logits of [`Self`] in place using temperature sampling.
+ /// + /// # Example + /// + /// ```rust + /// # use llama_cpp_2::token::data::LlamaTokenData; + /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; + /// # use llama_cpp_2::token::LlamaToken; + /// + /// let candidates = vec![ + /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), + /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), + /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0) + /// ]; + /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); + /// + /// candidates.sample_temp(None, 0.5); + /// + /// assert_ne!(candidates.data[0].logit(), 0.1); + /// assert_ne!(candidates.data[1].logit(), 0.2); + /// assert_ne!(candidates.data[2].logit(), 0.7); + /// ``` + pub fn sample_temp(&mut self, ctx: Option<&mut LlamaContext>, temperature: f32) { + if temperature == 0.0 { + return; + } + let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); + unsafe { + self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_temp(ctx, c_llama_token_data_array, temperature); + }); + } + } + + /// Randomly selects a token from the candidates based on their probabilities. + pub fn sample_token(&mut self, ctx: &mut LlamaContext) -> LlamaToken { + let llama_token = unsafe { + self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_token(ctx.context.as_ptr(), c_llama_token_data_array) + }) + }; + LlamaToken(llama_token) + } + + /// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) + pub fn sample_top_k(&mut self, ctx: Option<&mut LlamaContext>, k: i32, min_keep: usize) { + let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); + unsafe { + self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_top_k(ctx, c_llama_token_data_array, k, min_keep); + }); + } + } + + /// Tail Free Sampling described in [Tail-Free-Sampling](https://www.trentonbricken.com/Tail-Free-Sampling/). + pub fn sample_tail_free(&mut self, ctx: Option<&mut LlamaContext>, z: f32, min_keep: usize) { + let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); + unsafe { + self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_tail_free(ctx, c_llama_token_data_array, z, min_keep); + }); + } + } + + /// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666). 
+ /// + /// # Example + /// + /// ```rust + /// + /// # use llama_cpp_2::token::data::LlamaTokenData; + /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; + /// # use llama_cpp_2::token::LlamaToken; + /// + /// let candidates = vec![ + /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), + /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), + /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), + /// ]; + /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); + /// candidates.sample_typical(None, 0.5, 1); + /// + /// ``` + pub fn sample_typical(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) { + let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); + unsafe { + self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_typical(ctx, c_llama_token_data_array, p, min_keep); + }); + } + } + + /// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) + /// + /// # Example + /// + /// ```rust + /// + /// # use llama_cpp_2::token::data::LlamaTokenData; + /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; + /// # use llama_cpp_2::token::LlamaToken; + /// + /// let candidates = vec![ + /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), + /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), + /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), + /// ]; + /// + /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); + /// candidates.sample_top_p(None, 0.5, 1); + /// + /// assert_eq!(candidates.data.len(), 2); + /// assert_eq!(candidates.data[0].id(), LlamaToken::new(2)); + /// assert_eq!(candidates.data[1].id(), LlamaToken::new(1)); + /// ``` + pub fn sample_top_p(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) { + let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); + unsafe { + self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_top_p(ctx, c_llama_token_data_array, p, min_keep); + }); + } + } + + /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841) + /// + /// # Example + /// + /// ``` + /// # use llama_cpp_2::token::data::LlamaTokenData; + /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; + /// # use llama_cpp_2::token::LlamaToken; + /// + /// let candidates = vec![ + /// LlamaTokenData::new(LlamaToken::new(4), 0.0001, 0.0), + /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), + /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), + /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), + /// ]; + /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); + /// candidates.sample_min_p(None, 0.05, 1); + /// ``` + pub fn sample_min_p(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) { + let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); + unsafe { + self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_min_p(ctx, c_llama_token_data_array, p, min_keep); + }); + } + } } From 8f4d733fa89b2a07dfbd075486b41aa4413a7570 Mon Sep 17 00:00:00 2001 From: marcus Date: Thu, 14 Mar 2024 11:00:28 -0700 Subject: [PATCH 3/6] fixed rustdoc warnings --- llama-cpp-2/src/context.rs | 2 +- llama-cpp-2/src/context/kv_cache.rs | 18 +++++++++--------- llama-cpp-2/src/context/sample.rs | 2 +- llama-cpp-2/src/context/sample/sampler.rs | 2 +- 
 llama-cpp-2/src/lib.rs                    |  2 +-
 llama-cpp-2/src/llama_batch.rs            |  4 ++--
 llama-cpp-2/src/model.rs                  |  2 +-
 llama-cpp-2/src/model/params.rs           |  3 ---
 8 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs
index 38b2e5ff..36a0ebe2 100644
--- a/llama-cpp-2/src/context.rs
+++ b/llama-cpp-2/src/context.rs
@@ -69,7 +69,7 @@ impl<'model> LlamaContext<'model> {
     ///
     /// # Panics
     ///
-    /// - the returned [`c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
+    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into an i32 (this should never happen on most systems)
     pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
         let result =
             unsafe { llama_cpp_sys_2::llama_decode(self.context.as_ptr(), batch.llama_batch) };
diff --git a/llama-cpp-2/src/context/kv_cache.rs b/llama-cpp-2/src/context/kv_cache.rs
index 54ec27b4..313ea928 100644
--- a/llama-cpp-2/src/context/kv_cache.rs
+++ b/llama-cpp-2/src/context/kv_cache.rs
@@ -22,8 +22,8 @@ impl LlamaContext<'_> {
     ///
     /// * `src` - The sequence id to copy the cache from.
     /// * `dest` - The sequence id to copy the cache to.
-    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to [p1].
-    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from [p0].
+    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
+    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
     pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
         let p0 = p0.map_or(-1, i32::from);
         let p1 = p1.map_or(-1, i32::from);
@@ -37,8 +37,8 @@ impl LlamaContext<'_> {
     /// # Parameters
     ///
     /// * `src` - The sequence id to clear the cache for.
-    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to [p1].
-    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from [p0].
+    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
+    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
     pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
         let p0 = p0.map_or(-1, i32::from);
         let p1 = p1.map_or(-1, i32::from);
@@ -68,7 +68,7 @@ impl LlamaContext<'_> {
     }

     #[allow(clippy::doc_markdown)]
-    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
+    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
     /// If the KV cache is RoPEd, the KV data is updated accordingly:
     /// - lazily on next [`LlamaContext::decode`]
     /// - explicitly with [`Self::kv_cache_update`]
@@ -76,8 +76,8 @@ impl LlamaContext<'_> {
     /// # Parameters
     ///
     /// * `seq_id` - The sequence id to update
-    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
-    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
+    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
+    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
     /// * `delta` - The relative position to add to the tokens
     pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
         let p0 = p0.map_or(-1, i32::from);
@@ -95,8 +95,8 @@ impl LlamaContext<'_> {
     /// # Parameters
     ///
     /// * `seq_id` - The sequence id to update
-    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
-    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
+    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
+    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
     /// * `d` - The factor to divide the positions by
     pub fn kv_cache_seq_div(
         &mut self,
diff --git a/llama-cpp-2/src/context/sample.rs b/llama-cpp-2/src/context/sample.rs
index 773675ee..9704e264 100644
--- a/llama-cpp-2/src/context/sample.rs
+++ b/llama-cpp-2/src/context/sample.rs
@@ -48,7 +48,7 @@ impl LlamaContext<'_> {
     ///
     /// # Panics
     ///
-    /// - if [`token_data`] is empty
+    /// - if `token_data` is empty
     #[must_use]
     pub fn sample_token_greedy(&mut self, mut token_data: LlamaTokenDataArray) -> LlamaToken {
         assert!(!token_data.data.is_empty(), "no tokens");
diff --git a/llama-cpp-2/src/context/sample/sampler.rs b/llama-cpp-2/src/context/sample/sampler.rs
index bf3d6edf..7232b290 100644
--- a/llama-cpp-2/src/context/sample/sampler.rs
+++ b/llama-cpp-2/src/context/sample/sampler.rs
@@ -40,7 +40,7 @@ pub type SampleFinalizer<'a> = Box<dyn FnMut(LlamaTokenDataArray) -> Vec<LlamaTokenData> + 'a>;

 /// A series of sampling steps that will produce a vector of token data.
 ///
-/// [`a`] is the lifetime of captured references in the steps and finalizer.
+/// `a` is the lifetime of captured references in the steps and finalizer.
 #[non_exhaustive]
 pub struct Sampler<'a> {
     /// The steps to take when sampling.
diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs
index 7a478e56..270149c6 100644
--- a/llama-cpp-2/src/lib.rs
+++ b/llama-cpp-2/src/lib.rs
@@ -12,7 +12,7 @@
 //! # Feature Flags
 //!
 //! - `cublas` enables CUDA gpu support.
-//! - `sampler` adds the [`Sampler`] struct for a more rusty way of sampling.
+//! - `sampler` adds the [`context::sample::sampler::Sampler`] struct for a more rusty way of sampling.
 use std::ffi::NulError;
 use std::fmt::Debug;
 use std::num::NonZeroI32;
diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs
index 664d9035..e52bfa9e 100644
--- a/llama-cpp-2/src/llama_batch.rs
+++ b/llama-cpp-2/src/llama_batch.rs
@@ -30,7 +30,7 @@ impl LlamaBatch {
         self.initialized_logits.clear();
     }

-    /// add a token to the batch for sequences [`seq_ids`] at position [pos]. If [logits] is true, the
+    /// add a token to the batch for sequences `seq_ids` at position `pos`. If `logits` is true, the
     /// token will be initialized and can be read from after the next decode.
     ///
     /// # Panics
@@ -91,7 +91,7 @@ impl LlamaBatch {
         Ok(())
     }

-    /// Add a sequence of tokens to the batch for the given sequence id. If [`logits_all`] is true, the
+    /// Add a sequence of tokens to the batch for the given sequence id. If `logits_all` is true, the
     /// tokens will be initialized and can be read from after the next decode.
     ///
     /// Either way the last token in the sequence will have its logits set to `true`.
diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index e1a52d20..4bd85b83 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -201,7 +201,7 @@ impl LlamaModel {
     ///
     /// # Panics
     ///
-    /// - if [`buffer_size`] does not fit into a [`c_int`].
+    /// - if `buffer_size` does not fit into a [`c_int`].
     /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
     pub fn token_to_str_with_size(
         &self,
diff --git a/llama-cpp-2/src/model/params.rs b/llama-cpp-2/src/model/params.rs
index 69ddb004..b4d5a25e 100644
--- a/llama-cpp-2/src/model/params.rs
+++ b/llama-cpp-2/src/model/params.rs
@@ -9,9 +9,6 @@ use std::ptr::null;
 pub mod kv_overrides;

 /// A safe wrapper around `llama_model_params`.
-///
-/// [`T`] is the type of the backing storage for the key-value overrides. Generally it can be left to [`()`] which will
-/// make your life with the borrow checker much easier.
 #[allow(clippy::module_name_repetitions)]
 pub struct LlamaModelParams {
     pub(crate) params: llama_cpp_sys_2::llama_model_params,

From 0271b08cdd314a2c750b6f1248d1e14230e8f00e Mon Sep 17 00:00:00 2001
From: marcus
Date: Thu, 14 Mar 2024 12:33:43 -0700
Subject: [PATCH 4/6] v2 of the sampler

---
 llama-cpp-2/src/context/sample/sampler.rs | 238 ++++------------------
 llama-cpp-2/src/lib.rs                    |   2 +-
 2 files changed, 44 insertions(+), 196 deletions(-)

diff --git a/llama-cpp-2/src/context/sample/sampler.rs b/llama-cpp-2/src/context/sample/sampler.rs
index 7232b290..2ca808e6 100644
--- a/llama-cpp-2/src/context/sample/sampler.rs
+++ b/llama-cpp-2/src/context/sample/sampler.rs
@@ -1,55 +1,65 @@
-//! A more rusty way of sampling. Allows for adding a stack of sampling steps and a `finalizer` which selects a token from the remaining candidates.
+//! Create a sampler struct to encapsulate the sampling process.
 //!
 //! # Example
 //!
 //! ```rust
-//! use llama_cpp_2::context::sample::sampler::Sampler;
+//! use llama_cpp_2::context::sample::sampler::{Sampler, SampleStep};
 //! use llama_cpp_2::token::data::LlamaTokenData;
 //! use llama_cpp_2::token::data_array::LlamaTokenDataArray;
 //! use llama_cpp_2::token::LlamaToken;
 //!
+//! let mut finalizer = &|mut candidates: LlamaTokenDataArray, history: &mut Vec<LlamaToken>| {
+//!     candidates.sample_softmax(None);
+//!     let token = candidates.data[0];
+//!     history.push(token.id());
+//!     vec![token]
+//! };
+//!
 //! let mut history = vec![];
-//! let candidates = LlamaTokenDataArray::from_iter((0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32 / 6.0, 0.0)), false);
+//! let mut sampler = Sampler::new(finalizer);
 //!
-//! let token = {
-//!     let mut sampler = Sampler::greedy();
-//!     sampler.push_sample_repetition_penalty_step(&history, 64, 1.1, 0.0, 0.0);
-//!     sampler.push_top_k_step(40, 1);
-//!     sampler.push_sample_tail_free_step(1.0, 1);
-//!     sampler.push_sample_typical_step(1.0, 1);
-//!     sampler.push_sample_top_p_step(0.95, 1);
-//!     sampler.push_min_p_step(0.05, 1);
-//!     sampler.push_temperature_step(0.5);
-//!     sampler.sample(candidates)
-//! };
-//! history.push(token[0].id());
+//! sampler.push_step(&|c, history| c.sample_repetition_penalty(None, history, 64, 1.1, 0.0, 0.0));
+//! sampler.push_step(&|c, _| c.sample_top_k(None, 40, 1));
+//! sampler.push_step(&|c, _| c.sample_tail_free(None, 1.0, 1));
+//! sampler.push_step(&|c, _| c.sample_typical(None, 1.0, 1));
+//! sampler.push_step(&|c, _| c.sample_top_p(None, 0.95, 1));
+//! sampler.push_step(&|c, _| c.sample_min_p(None, 0.05, 1));
+//! sampler.push_step(&|c, _| c.sample_temp(None, 0.5));
+//!
+//! let candidates = LlamaTokenDataArray::from_iter((0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32 / 6.0, 0.0)), false);
+//!
+//! for _ in 0..10 {
+//!     let tokens = sampler.sample(&mut history, candidates.clone());
+//!     assert_eq!(tokens.len(), 1);
+//! }
 //!
-//! println!("{:?}", token);
+//! assert_eq!(history.len(), 10);
+//!
+//! ```

 use crate::token::data::LlamaTokenData;
 use crate::token::data_array::LlamaTokenDataArray;
-use crate::token::LlamaToken;
 use std::fmt::{Debug, Formatter};

 /// A single step to sample tokens from the remaining candidates.
-pub type SampleStep<'a> = Box<dyn FnMut(&mut LlamaTokenDataArray) + 'a>;
+pub type SampleStep<C> = dyn Fn(&mut LlamaTokenDataArray, &mut C);

 /// The final step to select one or more tokens from the remaining candidates.
-pub type SampleFinalizer<'a> = Box<dyn FnMut(LlamaTokenDataArray) -> Vec<LlamaTokenData> + 'a>;
+pub type SampleFinalizer<C> = dyn Fn(LlamaTokenDataArray, &mut C) -> Vec<LlamaTokenData>;

 /// A series of sampling steps that will produce a vector of token data.
 ///
 /// `a` is the lifetime of captured references in the steps and finalizer.
 #[non_exhaustive]
-pub struct Sampler<'a> {
+pub struct Sampler<'a, C> {
     /// The steps to take when sampling.
-    pub steps: Vec<SampleStep<'a>>,
+    pub steps: Vec<&'a SampleStep<C>>,
     /// The final step to select one or more tokens from the remaining candidates.
-    pub finalizer: SampleFinalizer<'a>,
+    pub finalizer: &'a SampleFinalizer<C>,
 }

-impl Debug for Sampler<'_> {
+impl<T> Debug for Sampler<'_, T>
+{
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("Sampler")
             .field(
                 "steps",
                 &format!(
                     "{} steps of Box<dyn FnMut(&mut LlamaTokenDataArray) -> ()>",
                     &self.steps.len()
                 ),
             )
             .field(
                 "finalizer",
                 &"Box<dyn FnMut(LlamaTokenDataArray) -> Vec<LlamaTokenData>>",
             )
             .finish()
     }
 }

-impl<'a> Sampler<'a> {
-    /// Create a very simple sampler that selects a single token with the greatest logit (greedy sampling).
-    ///
-    /// # Example
-    ///
-    /// ```rust
-    /// use llama_cpp_2::context::sample::sampler::Sampler;
-    /// use llama_cpp_2::token::data::LlamaTokenData;
-    /// use llama_cpp_2::token::data_array::LlamaTokenDataArray;
-    /// use llama_cpp_2::token::LlamaToken;
-    ///
-    /// let mut sampler = Sampler::greedy();
-    ///
-    /// let candidates = (0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32 / 6.0, 0.0));
-    /// let tokens = sampler.sample(LlamaTokenDataArray::from_iter(candidates, false));
-    /// assert_eq!(tokens[0].id(), LlamaToken::new(3));
-    /// ```
-    #[must_use]
-    pub fn greedy() -> Self {
-        let finalizer = |mut token_data: LlamaTokenDataArray| {
-            if token_data.data.is_empty() {
-                return vec![];
-            }
-            if token_data.sorted {
-                vec![token_data.data[0]]
-            } else {
-                token_data.sample_softmax(None);
-                vec![token_data.data[0]]
-            }
-        };
-        Self::new(finalizer)
-    }
-
-    /// Adds a repetition penalty sampling step to the sampler.
-    ///
-    /// See [`LlamaTokenDataArray::sample_repetition_penalty`]
-    pub fn push_sample_repetition_penalty_step(
-        &mut self,
-        history: &'a [LlamaToken],
-        penalty_last_n: usize,
-        penalty_repeat: f32,
-        penalty_freq: f32,
-        penalty_present: f32,
-    ) {
-        self.steps
-            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
-                can.sample_repetition_penalty(
-                    None,
-                    history,
-                    penalty_last_n,
-                    penalty_repeat,
-                    penalty_freq,
-                    penalty_present,
-                );
-            }));
-    }
-
-    /// Adds a typical sampling step to the sampler.
-    ///
-    /// See [`LlamaTokenDataArray::sample_typical`]
-    pub fn push_sample_typical_step(&mut self, p: f32, min_keep: usize) {
-        self.steps
-            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
-                can.sample_typical(None, p, min_keep);
-            }));
-    }
-
-    /// Adds a Top-p sampling step to the sampler.
-    ///
-    /// See [`LlamaTokenDataArray::sample_top_p`]
-    pub fn push_sample_top_p_step(&mut self, p: f32, min_keep: usize) {
-        self.steps
-            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
-                can.sample_top_p(None, p, min_keep);
-            }));
-    }
-
-    /// Adds a tail-free sampling step to the sampler.
-    ///
-    /// See [`LlamaTokenDataArray::sample_tail_free`]
-    pub fn push_sample_tail_free_step(&mut self, z: f32, min_keep: usize) {
-        self.steps
-            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
-                can.sample_tail_free(None, z, min_keep);
-            }));
-    }
-
-    /// Adds a top-k sampling step to the sampler.
-    ///
-    /// See [`LlamaTokenDataArray::sample_top_k`]
-    pub fn push_top_k_step(&mut self, k: i32, min_keep: usize) {
-        self.steps
-            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
-                can.sample_top_k(None, k, min_keep);
-            }));
-    }
-
-    /// Adds a temperature sampling step to the sampler.
-    ///
-    /// See [`LlamaTokenDataArray::sample_temp`]
-    pub fn push_temperature_step(&mut self, temperature: f32) {
-        self.steps
-            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
-                can.sample_temp(None, temperature);
-            }));
-    }
-
-    /// Adds a minimum P sampling step to the sampler.
-    ///
-    /// See [`LlamaTokenDataArray::sample_min_p`]
-    pub fn push_min_p_step(&mut self, p: f32, min_keep: usize) {
-        self.steps
-            .push(Box::new(move |can: &mut LlamaTokenDataArray| {
-                can.sample_min_p(None, p, min_keep);
-            }));
-    }
-
+impl<'a, T> Sampler<'a, T> {
     /// Create a new sampler with a given finalizer.
-    ///
-    /// # Example
-    ///
-    /// ```rust
-    /// use llama_cpp_2::context::sample::sampler::Sampler;
-    /// use llama_cpp_2::token::data::LlamaTokenData;
-    /// use llama_cpp_2::token::data_array::LlamaTokenDataArray;
-    /// use llama_cpp_2::token::LlamaToken;
-    ///
-    /// // a very silly way to sample.
-    /// let always_0 = |can: LlamaTokenDataArray| -> Vec<LlamaTokenData> { can.data.into_iter().filter(|t| t.id() == LlamaToken::new(0)).collect::<Vec<LlamaTokenData>>() };
-    ///
-    /// let mut sampler = Sampler::new(always_0);
-    ///
-    /// let candidates = (0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32, 0.0));
-    ///
-    /// let token = sampler.sample(LlamaTokenDataArray::from_iter(candidates, false));
-    /// assert_eq!(token[0].id(), LlamaToken::new(0));
-    ///
-    /// ```
-    pub fn new(
-        finalizer: impl FnMut(LlamaTokenDataArray) -> Vec<LlamaTokenData> + 'a,
-    ) -> Self {
+    pub fn new(finalizer: &'a SampleFinalizer<T>) -> Self {
         Self {
-            steps: Vec::new(),
-            finalizer: Box::new(finalizer),
+            steps: vec![],
+            finalizer,
         }
     }

     /// Adds a step to the sampler.
-    ///
-    /// # Example
-    ///
-    /// ```rust
-    /// use llama_cpp_2::context::sample::sampler::Sampler;
-    /// use llama_cpp_2::token::data::LlamaTokenData;
-    /// use llama_cpp_2::token::data_array::LlamaTokenDataArray;
-    /// use llama_cpp_2::token::LlamaToken;
-    ///
-    /// let mut favor_even_tokens = |can: &mut LlamaTokenDataArray| {
-    ///     for token in can.data.iter_mut() {
-    ///         if token.id().0 % 2 == 0 {
-    ///             token.set_logit(token.logit() + 1.0);
-    ///         }
-    ///     }
-    /// };
-    /// let mut sampler = Sampler::greedy();
-    /// sampler.push_step(favor_even_tokens);
-    ///
-    /// let candidates = (0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32, 0.0));
-    ///
-    /// let token = sampler.sample(LlamaTokenDataArray::from_iter(candidates, false));
-    ///
-    /// assert_eq!(token[0].id(), LlamaToken::new(2));
-    /// ```
-    pub fn push_step(&mut self, step: impl FnMut(&mut LlamaTokenDataArray) + 'a) {
-        self.steps.push(Box::new(step));
+    pub fn push_step(&mut self, step: &'a SampleStep<T>) {
+        self.steps.push(step);
     }

     /// Sample a token from the given candidates.
     #[must_use]
-    pub fn sample(&mut self, mut candidates: LlamaTokenDataArray) -> Vec<LlamaTokenData> {
-        for step in &mut self.steps {
-            step(&mut candidates);
+    pub fn sample(&mut self, context: &mut T, mut candidates: LlamaTokenDataArray) -> Vec<LlamaTokenData> {
+        for step in &self.steps {
+            step(&mut candidates, context);
         }
-        (self.finalizer)(candidates)
+        (self.finalizer)(candidates, context)
     }
 }
diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs
index 270149c6..4fc57d24 100644
--- a/llama-cpp-2/src/lib.rs
+++ b/llama-cpp-2/src/lib.rs
@@ -12,7 +12,7 @@
 //! # Feature Flags
 //!
 //! - `cublas` enables CUDA gpu support.
-//! - `sampler` adds the [`context::sample::sampler::Sampler`] struct for a more rusty way of sampling.
+//! - `sampler` adds the [`context::sample::sampler`] module for a more rusty way of sampling.
 use std::ffi::NulError;
 use std::fmt::Debug;
 use std::num::NonZeroI32;

From c9205137310e41e903c7f41242f7e03c6e13f5ce Mon Sep 17 00:00:00 2001
From: marcus
Date: Thu, 14 Mar 2024 12:44:04 -0700
Subject: [PATCH 5/6] updated some docs

---
 llama-cpp-2/src/context/sample/sampler.rs | 26 +++++++++++++++--------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/llama-cpp-2/src/context/sample/sampler.rs b/llama-cpp-2/src/context/sample/sampler.rs
index 2ca808e6..74906a7c 100644
--- a/llama-cpp-2/src/context/sample/sampler.rs
+++ b/llama-cpp-2/src/context/sample/sampler.rs
@@ -1,6 +1,10 @@
-//! Create a sampler struct to encapsulate the sampling process.
+//! Create a sampler struct to encapsulate the sampling process. This allows passing all the possible
+//! sampling parameters around as a single struct, and also allows late binding of expensive context
+//! like [`crate::context::LlamaContext`] or token history to the sampler.
 //!
 //! # Example
+//!
+//! **Llama.cpp default sampler**
 //!
 //! ```rust
 //! use llama_cpp_2::context::sample::sampler::{Sampler, SampleStep};
@@ -8,6 +12,7 @@
 //! use llama_cpp_2::token::data_array::LlamaTokenDataArray;
 //! use llama_cpp_2::token::LlamaToken;
 //!
+//! // Sample a token greedily and add to the history.
 //! let mut finalizer = &|mut candidates: LlamaTokenDataArray, history: &mut Vec<LlamaToken>| {
 //!     candidates.sample_softmax(None);
 //!     let token = candidates.data[0];
@@ -26,8 +31,9 @@
 //! sampler.push_step(&|c, _| c.sample_min_p(None, 0.05, 1));
 //! sampler.push_step(&|c, _| c.sample_temp(None, 0.5));
 //!
+//! // random candidates
 //! let candidates = LlamaTokenDataArray::from_iter((0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32 / 6.0, 0.0)), false);
-//! 
+//!
 //! for _ in 0..10 {
 //!     let tokens = sampler.sample(&mut history, candidates.clone());
 //!     assert_eq!(tokens.len(), 1);
@@ -43,14 +49,13 @@ use std::fmt::{Debug, Formatter};

 /// A single step to sample tokens from the remaining candidates.
 pub type SampleStep<C> = dyn Fn(&mut LlamaTokenDataArray, &mut C);
-
-/// The final step to select one or more tokens from the remaining candidates.
+/// The final step to select tokens from the remaining candidates.
 pub type SampleFinalizer<C> = dyn Fn(LlamaTokenDataArray, &mut C) -> Vec<LlamaTokenData>;

 /// A series of sampling steps that will produce a vector of token data.
 ///
-/// `a` is the lifetime of captured references in the steps and finalizer.
-#[non_exhaustive]
+/// `C` is dynamic context that will be passed to the sampling functions. I expect `C` will
+/// often be [`()`], [`crate::context::LlamaContext`] or a token history (or some combination of these).
 pub struct Sampler<'a, C> {
     /// The steps to take when sampling.
     pub steps: Vec<&'a SampleStep<C>>,
     /// The final step to select one or more tokens from the remaining candidates.
     pub finalizer: &'a SampleFinalizer<C>,
 }

-impl<T> Debug for Sampler<'_, T>
-{
+impl<T> Debug for Sampler<'_, T> {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("Sampler")
             .field(
                 "steps",
                 &format!(
                     "{} steps of Box<dyn FnMut(&mut LlamaTokenDataArray) -> ()>",
                     &self.steps.len()
                 ),
             )
             .field(
                 "finalizer",
                 &"Box<dyn FnMut(LlamaTokenDataArray) -> Vec<LlamaTokenData>>",
             )
             .finish()
     }
 }

 impl<'a, T> Sampler<'a, T> {
     /// Create a new sampler with a given finalizer.
     pub fn new(finalizer: &'a SampleFinalizer<T>) -> Self {
         Self {
             steps: vec![],
             finalizer,
         }
     }

     /// Adds a step to the sampler.
     pub fn push_step(&mut self, step: &'a SampleStep<T>) {
         self.steps.push(step);
     }

     /// Sample a token from the given candidates.
     #[must_use]
-    pub fn sample(&mut self, context: &mut T, mut candidates: LlamaTokenDataArray) -> Vec<LlamaTokenData> {
+    pub fn sample(
+        &mut self,
+        context: &mut T,
+        mut candidates: LlamaTokenDataArray,
+    ) -> Vec<LlamaTokenData> {
         for step in &self.steps {
             step(&mut candidates, context);
         }
         (self.finalizer)(candidates, context)
     }
 }

From 411a679f57328c632d5111b8fb3c0789fb1397ef Mon Sep 17 00:00:00 2001
From: marcus
Date: Thu, 14 Mar 2024 13:08:46 -0700
Subject: [PATCH 6/6] added mirostat

---
 llama-cpp-2/src/context/sample/sampler.rs |  6 +++--
 llama-cpp-2/src/token/data_array.rs       | 30 +++++++++++++++++++++
 2 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/llama-cpp-2/src/context/sample/sampler.rs b/llama-cpp-2/src/context/sample/sampler.rs
index 74906a7c..cfe90499 100644
--- a/llama-cpp-2/src/context/sample/sampler.rs
+++ b/llama-cpp-2/src/context/sample/sampler.rs
@@ -54,8 +54,10 @@ pub type SampleFinalizer<C> = dyn Fn(LlamaTokenDataArray, &mut C) -> Vec<LlamaTokenData>;
 /// A series of sampling steps that will produce a vector of token data.
 ///
-/// `C` is dynamic context that will be passed to the sampling functions. I expect `C` will
-/// often be [`()`], [`crate::context::LlamaContext`] or a token history (or some combination of these).
+/// `C` is dynamic context that will be passed to the sampling functions. I expect `C` will
+/// often be [`()`], [`crate::context::LlamaContext`] or a token history (or some combination of
+/// these).
 pub struct Sampler<'a, C> {
     /// The steps to take when sampling.
     pub steps: Vec<&'a SampleStep<C>>,
diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs
index 0f89d59f..776d222a 100644
--- a/llama-cpp-2/src/token/data_array.rs
+++ b/llama-cpp-2/src/token/data_array.rs
@@ -344,4 +344,34 @@ impl LlamaTokenDataArray {
             });
         }
     }
+
+    /// Mirostat 2.0 algorithm described in the [paper](https://arxiv.org/abs/2007.14966). Uses tokens instead of words.
+    ///
+    /// # Parameters
+    ///
+    /// * `tau` The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+    /// * `eta` The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+    /// * `mu` Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+    pub fn sample_token_mirostat_v2(
+        &mut self,
+        ctx: &mut LlamaContext,
+        tau: f32,
+        eta: f32,
+        mu: &mut f32,
+    ) -> LlamaToken {
+        let mu_ptr = ptr::from_mut(mu);
+        let token = unsafe {
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_token_mirostat_v2(
+                    ctx.context.as_ptr(),
+                    c_llama_token_data_array,
+                    tau,
+                    eta,
+                    mu_ptr,
+                )
+            })
+        };
+        *mu = unsafe { *mu_ptr };
+        LlamaToken(token)
+    }
 }
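
For downstream users, the API at the end of this series composes as below. This is a minimal sketch that mirrors the module doctest, assuming the `sampler` feature is enabled; the step values (64, 1.1, 40, 0.8) are illustrative, not library defaults.

```rust
use llama_cpp_2::context::sample::sampler::Sampler;
use llama_cpp_2::token::data::LlamaTokenData;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use llama_cpp_2::token::LlamaToken;

// Greedy finalizer that records each chosen token; here the context type `C`
// is the token history, `Vec<LlamaToken>`.
let finalizer = &|mut candidates: LlamaTokenDataArray, history: &mut Vec<LlamaToken>| {
    candidates.sample_softmax(None);
    let token = candidates.data[0];
    history.push(token.id());
    vec![token]
};

let mut sampler = Sampler::new(finalizer);
// Steps run in the order they are pushed; each receives the candidates and the context.
sampler.push_step(&|c, history| c.sample_repetition_penalty(None, history, 64, 1.1, 0.0, 0.0));
sampler.push_step(&|c, _| c.sample_top_k(None, 40, 1));
sampler.push_step(&|c, _| c.sample_temp(None, 0.8));

let mut history = vec![];
let candidates = LlamaTokenDataArray::from_iter(
    (0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32, 0.0)),
    false,
);
let tokens = sampler.sample(&mut history, candidates);
assert_eq!(tokens.len(), 1);
```

Note that mirostat v2 does not fit the `&'a SampleStep<C>` shape directly, since `dyn Fn` steps cannot mutate captured state and the algorithm needs both a `LlamaContext` and a mutable `mu`; patch 6 therefore exposes it as a method on `LlamaTokenDataArray` instead of as a pushable step.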