diff --git a/.github/workflows/llama-cpp-rs-check.yml b/.github/workflows/llama-cpp-rs-check.yml
index c26572c9..da5d067f 100644
--- a/.github/workflows/llama-cpp-rs-check.yml
+++ b/.github/workflows/llama-cpp-rs-check.yml
@@ -34,7 +34,7 @@ jobs:
       - name: Fmt
         run: cargo fmt
       - name: Test
-        run: cargo test
+        run: cargo test --features sampler
   arm64:
     name: Check that it builds on various targets
     runs-on: ubuntu-latest
@@ -67,7 +67,7 @@ jobs:
       - name: Setup Rust
        uses: dtolnay/rust-toolchain@stable
      - name: Build
-        run: cargo build
+        run: cargo build --features sampler
   windows:
     name: Check that it builds on windows
     runs-on: windows-latest
@@ -79,4 +79,4 @@ jobs:
       - name: Setup Rust
        uses: dtolnay/rust-toolchain@stable
      - name: Build
-        run: cargo build
+        run: cargo build --features sampler
diff --git a/llama-cpp-2/Cargo.toml b/llama-cpp-2/Cargo.toml
index c1c0193b..0cae4872 100644
--- a/llama-cpp-2/Cargo.toml
+++ b/llama-cpp-2/Cargo.toml
@@ -28,6 +28,7 @@ harness = false
 
 [features]
 cublas = ["llama-cpp-sys-2/cublas"]
+sampler = []
 
 [lints]
 workspace = true
diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs
index 38b2e5ff..36a0ebe2 100644
--- a/llama-cpp-2/src/context.rs
+++ b/llama-cpp-2/src/context.rs
@@ -69,7 +69,7 @@ impl<'model> LlamaContext<'model> {
     ///
     /// # Panics
     ///
-    /// - the returned [`c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
+    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into an i32 (this should never happen on most systems)
     pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
         let result =
             unsafe { llama_cpp_sys_2::llama_decode(self.context.as_ptr(), batch.llama_batch) };
diff --git a/llama-cpp-2/src/context/kv_cache.rs b/llama-cpp-2/src/context/kv_cache.rs
index 54ec27b4..313ea928 100644
--- a/llama-cpp-2/src/context/kv_cache.rs
+++ b/llama-cpp-2/src/context/kv_cache.rs
@@ -22,8 +22,8 @@ impl LlamaContext<'_> {
     ///
     /// * `src` - The sequence id to copy the cache from.
     /// * `dest` - The sequence id to copy the cache to.
-    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to [p1].
-    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from [p0].
+    /// * `p0` - The start position of the cache to copy. If `None`, the entire cache is copied up to `p1`.
+    /// * `p1` - The end position of the cache to copy. If `None`, the entire cache is copied starting from `p0`.
     pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option, p1: Option) {
         let p0 = p0.map_or(-1, i32::from);
         let p1 = p1.map_or(-1, i32::from);
@@ -37,8 +37,8 @@ impl LlamaContext<'_> {
     /// # Parameters
     ///
     /// * `src` - The sequence id to clear the cache for.
-    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to [p1].
-    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from [p0].
+    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
+    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
     pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option, p1: Option) {
         let p0 = p0.map_or(-1, i32::from);
         let p1 = p1.map_or(-1, i32::from);
@@ -68,7 +68,7 @@ impl LlamaContext<'_> {
     }
 
     #[allow(clippy::doc_markdown)]
-    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
+    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
     /// If the KV cache is RoPEd, the KV data is updated accordingly:
     ///   - lazily on next [`LlamaContext::decode`]
     ///   - explicitly with [`Self::kv_cache_update`]
@@ -76,8 +76,8 @@ impl LlamaContext<'_> {
     /// # Parameters
     ///
     /// * `seq_id` - The sequence id to update
-    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
-    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
+    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
+    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
     /// * `delta` - The relative position to add to the tokens
     pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option, p1: Option, delta: i32) {
         let p0 = p0.map_or(-1, i32::from);
@@ -95,8 +95,8 @@ impl LlamaContext<'_> {
     /// # Parameters
     ///
     /// * `seq_id` - The sequence id to update
-    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
-    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
+    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
+    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
     /// * `d` - The factor to divide the positions by
     pub fn kv_cache_seq_div(
         &mut self,
diff --git a/llama-cpp-2/src/context/sample.rs b/llama-cpp-2/src/context/sample.rs
index 7b1b69a5..9704e264 100644
--- a/llama-cpp-2/src/context/sample.rs
+++ b/llama-cpp-2/src/context/sample.rs
@@ -5,130 +5,10 @@ use crate::grammar::LlamaGrammar;
 use crate::token::data_array::LlamaTokenDataArray;
 use crate::token::LlamaToken;
 
-/// struct to hold params for sampling
-#[derive(Debug)]
-#[deprecated(
-    since = "0.1.32",
-    note = "this does not scale well with many params and does not allow for changing of orders."
-)]
-pub struct Sampler<'grammar> {
-    token_data_array: LlamaTokenDataArray,
-    grammar: Option<&'grammar mut LlamaGrammar>,
-    temperature: Option<f32>,
-}
-
-impl<'grammar> Sampler<'grammar> {
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    fn sample(self, llama_context: &mut LlamaContext) -> LlamaToken {
-        match self {
-            Sampler {
-                token_data_array,
-                grammar: None,
-                temperature: None,
-            } => llama_context.sample_token_greedy(token_data_array),
-            Sampler {
-                mut token_data_array,
-                grammar: Some(grammar),
-                temperature: None,
-            } => {
-                llama_context.sample_grammar(&mut token_data_array, grammar);
-                let token = llama_context.sample_token_greedy(token_data_array);
-                llama_context.grammar_accept_token(grammar, token);
-                token
-            }
-            Sampler {
-                mut token_data_array,
-                grammar: None,
-                temperature: Some(temp),
-            } => {
-                llama_context.sample_temp(&mut token_data_array, temp);
-                llama_context.sample_token_softmax(&mut token_data_array);
-                token_data_array.data[0].id()
-            }
-            Sampler {
-                mut token_data_array,
-                grammar: Some(grammar),
-                temperature: Some(temperature),
-            } => {
-                llama_context.sample_grammar(&mut token_data_array, grammar);
-                llama_context.sample_temp(&mut token_data_array, temperature);
-                llama_context.sample_token_softmax(&mut token_data_array);
-                let token = llama_context.sample_token_greedy(token_data_array);
-                llama_context.grammar_accept_token(grammar, token);
-                token
-            }
-        }
-    }
-
-    /// Create a new sampler.
-    #[must_use]
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn new(llama_token_data_array: LlamaTokenDataArray) -> Self {
-        Self {
-            token_data_array: llama_token_data_array,
-            grammar: None,
-            temperature: None,
-        }
-    }
-
-    /// Set the grammar for sampling.
-    #[must_use]
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn with_grammar(mut self, grammar: &'grammar mut LlamaGrammar) -> Self {
-        self.grammar = Some(grammar);
-        self
-    }
-
-    /// Set the temperature for sampling.
-    ///
-    /// ```
-    /// # use llama_cpp_2::context::LlamaContext;
-    /// # use llama_cpp_2::context::sample::Sampler;
-    /// # use llama_cpp_2::grammar::LlamaGrammar;
-    /// # use llama_cpp_2::token::data::LlamaTokenData;
-    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
-    /// # use llama_cpp_2::token::LlamaToken;
-    ///
-    /// let _sampler = Sampler::new(LlamaTokenDataArray::new(vec![LlamaTokenData::new(LlamaToken(0), 0.0, 0.0)], false))
-    ///     .with_temperature(0.5);
-    /// ```
-    #[must_use]
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn with_temperature(mut self, temperature: f32) -> Self {
-        if temperature == 0.0 {
-            return self;
-        }
-        self.temperature = Some(temperature);
-        self
-    }
-}
+#[cfg(feature = "sampler")]
+pub mod sampler;
 
 impl LlamaContext<'_> {
-    /// Sample a token.
-    ///
-    /// # Panics
-    ///
-    /// - sampler contains no tokens
-    #[deprecated(
-        since = "0.1.32",
-        note = "this does not scale well with many params and does not allow for changing of orders."
-    )]
-    pub fn sample(&mut self, sampler: Sampler) -> LlamaToken {
-        sampler.sample(self)
-    }
-
     /// Accept a token into the grammar.
     pub fn grammar_accept_token(&mut self, grammar: &mut LlamaGrammar, token: LlamaToken) {
         unsafe {
@@ -157,38 +37,20 @@ impl LlamaContext<'_> {
         }
     }
 
-    /// Modify [`token_data`] in place using temperature sampling.
-    ///
-    /// # Panics
-    ///
-    /// - [`temperature`] is not between 0.0 and 1.0
-    pub fn sample_temp(&self, token_data: &mut LlamaTokenDataArray, temperature: f32) {
-        assert!(
-            temperature >= 0.0,
-            "temperature must be positive (was {temperature})"
-        );
-        assert!(
-            temperature <= 1.0,
-            "temperature must be less than or equal to 1.0 (was {temperature})"
-        );
-        if temperature == 0.0 {
-            return;
-        }
-        let ctx: *mut llama_cpp_sys_2::llama_context = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_temp(ctx, c_llama_token_data_array, temperature);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_temp`]
+    pub fn sample_temp(&mut self, token_data: &mut LlamaTokenDataArray, temperature: f32) {
+        token_data.sample_temp(Some(self), temperature);
     }
 
-    /// Sample a token greedily.
+    /// Sample a token greedily. Note that this *does not* take into account anything that has modified the probabilities - it only looks at logits.
+    ///
+    /// Most of the time [`LlamaTokenDataArray::sample_softmax`] or [`LlamaTokenDataArray::sample_token`] should be used instead.
     ///
     /// # Panics
     ///
-    /// - [`token_data`] is empty
+    /// - if `token_data` is empty
     #[must_use]
-    pub fn sample_token_greedy(&self, mut token_data: LlamaTokenDataArray) -> LlamaToken {
+    pub fn sample_token_greedy(&mut self, mut token_data: LlamaTokenDataArray) -> LlamaToken {
         assert!(!token_data.data.is_empty(), "no tokens");
         let mut data_arr = llama_cpp_sys_2::llama_token_data_array {
             data: token_data
@@ -207,39 +69,34 @@ impl LlamaContext<'_> {
         LlamaToken(token)
     }
 
-    /// Tail Free Sampling described in [Tail-Free-Sampling](https://www.trentonbricken.com/Tail-Free-Sampling/).
-    pub fn sample_tail_free(&self, token_data: &mut LlamaTokenDataArray, z: f32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_tail_free(ctx, c_llama_token_data_array, z, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_tail_free`]
+    pub fn sample_tail_free(
+        &mut self,
+        token_data: &mut LlamaTokenDataArray,
+        z: f32,
+        min_keep: usize,
+    ) {
+        token_data.sample_tail_free(Some(self), z, min_keep);
     }
 
-    /// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666).
-    pub fn sample_typical(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_typical(ctx, c_llama_token_data_array, p, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_typical`]
+    pub fn sample_typical(
+        &mut self,
+        token_data: &mut LlamaTokenDataArray,
+        p: f32,
+        min_keep: usize,
+    ) {
+        token_data.sample_typical(Some(self), p, min_keep);
     }
 
-    /// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)"
-    pub fn sample_top_p(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_top_p(ctx, c_llama_token_data_array, p, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_top_p`]
+    pub fn sample_top_p(&mut self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
+        token_data.sample_top_p(Some(self), p, min_keep);
     }
 
     /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841)
     pub fn sample_min_p(
-        &self,
+        &mut self,
         llama_token_data: &mut LlamaTokenDataArray,
         p: f32,
         min_keep: usize,
@@ -252,24 +109,14 @@ impl LlamaContext<'_> {
         }
     }
 
-    /// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)
-    pub fn sample_top_k(&self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_top_k(ctx, c_llama_token_data_array, k, min_keep);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_top_k`]
+    pub fn sample_top_k(&mut self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
+        token_data.sample_top_k(Some(self), k, min_keep);
     }
 
-    /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    pub fn sample_token_softmax(&self, token_data: &mut LlamaTokenDataArray) {
-        let ctx = self.context.as_ptr();
-        unsafe {
-            token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
-                llama_cpp_sys_2::llama_sample_softmax(ctx, c_llama_token_data_array);
-            });
-        }
+    /// See [`LlamaTokenDataArray::sample_softmax`]
+    pub fn sample_token_softmax(&mut self, token_data: &mut LlamaTokenDataArray) {
+        token_data.sample_softmax(Some(self));
    }
 
     /// See [`LlamaTokenDataArray::sample_repetition_penalty`]
diff --git a/llama-cpp-2/src/context/sample/sampler.rs b/llama-cpp-2/src/context/sample/sampler.rs
new file mode 100644
index 00000000..cfe90499
--- /dev/null
+++ b/llama-cpp-2/src/context/sample/sampler.rs
@@ -0,0 +1,112 @@
+//! Create a sampler struct to encapsulate the sampling process. This allows passing all the possible
+//! sampling parameters around as a single struct, and also allows late binding of expensive context
+//! like [`crate::context::LlamaContext`] or token history to the sampler.
+//!
+//! # Example
+//!
+//! **Llama.cpp default sampler**
+//!
+//! ```rust
+//! use llama_cpp_2::context::sample::sampler::{Sampler, SampleStep};
+//! use llama_cpp_2::token::data::LlamaTokenData;
+//! use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+//! use llama_cpp_2::token::LlamaToken;
+//!
+//! // Sample a token greedily and add to the history.
+//! let mut finalizer = &|mut candidates: LlamaTokenDataArray, history: &mut Vec<LlamaToken>| {
+//!     candidates.sample_softmax(None);
+//!     let token = candidates.data[0];
+//!     history.push(token.id());
+//!     vec![token]
+//! };
+//!
+//! let mut history = vec![];
+//! let mut sampler = Sampler::new(finalizer);
+//!
+//! sampler.push_step(&|c, history| c.sample_repetition_penalty(None, history, 64, 1.1, 0.0, 0.0));
+//! sampler.push_step(&|c, _| c.sample_top_k(None, 40, 1));
+//! sampler.push_step(&|c, _| c.sample_tail_free(None, 1.0, 1));
+//! sampler.push_step(&|c, _| c.sample_typical(None, 1.0, 1));
+//! sampler.push_step(&|c, _| c.sample_top_p(None, 0.95, 1));
+//! sampler.push_step(&|c, _| c.sample_min_p(None, 0.05, 1));
+//! sampler.push_step(&|c, _| c.sample_temp(None, 0.5));
+//!
+//! // random candidates
+//! let candidates = LlamaTokenDataArray::from_iter((0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32 / 6.0, 0.0)), false);
+//!
+//! for _ in 0..10 {
+//!     let tokens = sampler.sample(&mut history, candidates.clone());
+//!     assert_eq!(tokens.len(), 1);
+//! }
+//!
+//! assert_eq!(history.len(), 10);
+//! ```
+
+use crate::token::data::LlamaTokenData;
+use crate::token::data_array::LlamaTokenDataArray;
+use std::fmt::{Debug, Formatter};
+
+/// A single step to sample tokens from the remaining candidates.
+pub type SampleStep<C> = dyn Fn(&mut LlamaTokenDataArray, &mut C);
+
+/// The final step to select tokens from the remaining candidates.
+pub type SampleFinalizer<C> = dyn Fn(LlamaTokenDataArray, &mut C) -> Vec<LlamaTokenData>;
+
+/// A series of sampling steps that will produce a vector of token data.
+///
+/// `C` is dynamic context that will be passed to the sampling functions. Some sampling steps may
+/// require state to be maintained across multiple samples, and this context can be used to store
+/// that state. For example, [`LlamaTokenDataArray::sample_token_mirostat_v2`] requires a `mu` to be
+/// shared across multiple samples.
+pub struct Sampler<'a, C> {
+    /// The steps to take when sampling.
+    pub steps: Vec<&'a SampleStep<C>>,
+    /// The final step to select one or more tokens from the remaining candidates.
+    pub finalizer: &'a SampleFinalizer<C>,
+}
+
+impl<T> Debug for Sampler<'_, T> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Sampler")
+            .field(
+                "steps",
+                &format!(
+                    "{} steps of Box<dyn Fn(&mut LlamaTokenDataArray, &mut C) -> ()>",
+                    &self.steps.len()
+                ),
+            )
+            .field(
+                "finalizer",
+                &"Box<dyn Fn(LlamaTokenDataArray, &mut C) -> Vec<LlamaTokenData>>",
+            )
+            .finish()
+    }
+}
+
+impl<'a, T> Sampler<'a, T> {
+    /// Create a new sampler with a given finalizer.
+    pub fn new(finalizer: &'a SampleFinalizer<T>) -> Self {
+        Self {
+            steps: vec![],
+            finalizer,
+        }
+    }
+
+    /// Adds a step to the sampler.
+    pub fn push_step(&mut self, step: &'a SampleStep<T>) {
+        self.steps.push(step);
+    }
+
+    /// Sample a token from the given candidates.
+    #[must_use]
+    pub fn sample(
+        &mut self,
+        context: &mut T,
+        mut candidates: LlamaTokenDataArray,
+    ) -> Vec<LlamaTokenData> {
+        for step in &self.steps {
+            step(&mut candidates, context);
+        }
+        (self.finalizer)(candidates, context)
+    }
+}
diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs
index bf429364..4fc57d24 100644
--- a/llama-cpp-2/src/lib.rs
+++ b/llama-cpp-2/src/lib.rs
@@ -12,6 +12,7 @@ //! # Feature Flags
 //!
 //! - `cublas` enables CUDA gpu support.
+//! - `sampler` adds the [`context::sample::sampler`] module for a more rusty way of sampling.
 
 use std::ffi::NulError;
 use std::fmt::Debug;
 use std::num::NonZeroI32;
diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs
index 664d9035..e52bfa9e 100644
--- a/llama-cpp-2/src/llama_batch.rs
+++ b/llama-cpp-2/src/llama_batch.rs
@@ -30,7 +30,7 @@ impl LlamaBatch {
         self.initialized_logits.clear();
     }
 
-    /// add a token to the batch for sequences [`seq_ids`] at position [pos]. If [logits] is true, the
+    /// add a token to the batch for sequences `seq_ids` at position `pos`. If `logits` is true, the
     /// token will be initialized and can be read from after the next decode.
     ///
     /// # Panics
     ///
@@ -91,7 +91,7 @@ impl LlamaBatch {
         Ok(())
     }
 
-    /// Add a sequence of tokens to the batch for the given sequence id. If [`logits_all`] is true, the
+    /// Add a sequence of tokens to the batch for the given sequence id. If `logits_all` is true, the
     /// tokens will be initialized and can be read from after the next decode.
     ///
     /// Either way the last token in the sequence will have its logits set to `true`.
diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index e1a52d20..4bd85b83 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -201,7 +201,7 @@ impl LlamaModel {
     ///
     /// # Panics
     ///
-    /// - if [`buffer_size`] does not fit into a [`c_int`].
+    /// - if `buffer_size` does not fit into a [`c_int`].
     /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
     pub fn token_to_str_with_size(
         &self,
diff --git a/llama-cpp-2/src/model/params.rs b/llama-cpp-2/src/model/params.rs
index 69ddb004..b4d5a25e 100644
--- a/llama-cpp-2/src/model/params.rs
+++ b/llama-cpp-2/src/model/params.rs
@@ -9,9 +9,6 @@ use std::ptr::null;
 pub mod kv_overrides;
 
 /// A safe wrapper around `llama_model_params`.
-///
-/// [`T`] is the type of the backing storage for the key-value overrides. Generally it can be left to [`()`] which will
-/// make your life with the borrow checker much easier.
 #[allow(clippy::module_name_repetitions)]
 pub struct LlamaModelParams {
     pub(crate) params: llama_cpp_sys_2::llama_model_params,
diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs
index 66b6e583..776d222a 100644
--- a/llama-cpp-2/src/token/data_array.rs
+++ b/llama-cpp-2/src/token/data_array.rs
@@ -65,10 +65,10 @@ impl LlamaTokenDataArray {
     /// [modify] cannot change the data pointer.
     /// if the data is not sorted, sorted must be false.
     /// the size of the data can only decrease (i.e you cannot add new elements).
-    pub(crate) unsafe fn modify_as_c_llama_token_data_array(
+    pub(crate) unsafe fn modify_as_c_llama_token_data_array<T>(
         &mut self,
-        modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array),
-    ) {
+        modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T,
+    ) -> T {
         let size = self.data.len();
         let data = self.data.as_mut_ptr().cast();
         let mut c_llama_token_data_array = llama_cpp_sys_2::llama_token_data_array {
@@ -76,7 +76,7 @@ impl LlamaTokenDataArray {
             size,
             sorted: self.sorted,
         };
-        modify(&mut c_llama_token_data_array);
+        let result = modify(&mut c_llama_token_data_array);
         assert!(
             ptr::eq(data, c_llama_token_data_array.data),
             "data pointer changed"
@@ -84,6 +84,7 @@ impl LlamaTokenDataArray {
         assert!(c_llama_token_data_array.size <= size, "size increased");
         self.data.set_len(c_llama_token_data_array.size);
         self.sorted = c_llama_token_data_array.sorted;
+        result
     }
 
     /// Repetition penalty described in [CTRL academic paper](https://arxiv.org/abs/1909.05858), with negative logit fix.
@@ -93,7 +94,7 @@ impl LlamaTokenDataArray {
     ///
     /// * `ctx` - the context to use. May be `None` if you do not care to record the sample timings.
     /// * `last_tokens` - the last tokens in the context.
-    /// 
+    ///
     /// * `penalty_last_n` - the number of tokens back to consider for the repetition penalty. (0 for no penalty)
     /// * `penalty_repeat` - the repetition penalty. (1.0 for no penalty)
     /// * `penalty_freq` - the frequency penalty. (0.0 for no penalty)
@@ -155,4 +156,222 @@ impl LlamaTokenDataArray {
             });
         }
     }
+
+    /// Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use llama_cpp_2::token::data::LlamaTokenData;
+    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+    /// # use llama_cpp_2::token::LlamaToken;
+    ///
+    /// let lowest = LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0);
+    /// let middle = LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0);
+    /// let highest = LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0);
+    ///
+    /// let candidates = vec![lowest, middle, highest];
+    ///
+    /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false);
+    /// candidates.sample_softmax(None);
+    ///
+    /// assert!(candidates.sorted);
+    /// assert_eq!(candidates.data[0].id(), highest.id());
+    /// assert_eq!(candidates.data[0].logit(), highest.logit());
+    /// assert!(candidates.data[0].p() > candidates.data[1].p());
+    /// assert_eq!(candidates.data[1].id(), middle.id());
+    /// assert_eq!(candidates.data[1].logit(), middle.logit());
+    /// assert!(candidates.data[1].p() > candidates.data[2].p());
+    /// assert_eq!(candidates.data[2].id(), lowest.id());
+    /// assert_eq!(candidates.data[2].logit(), lowest.logit());
+    /// ```
+    pub fn sample_softmax(&mut self, ctx: Option<&mut LlamaContext>) {
+        unsafe {
+            let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr());
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_softmax(ctx, c_llama_token_data_array);
+            });
+        }
+    }
+
+    /// Modify the logits of [`Self`] in place using temperature sampling.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use llama_cpp_2::token::data::LlamaTokenData;
+    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+    /// # use llama_cpp_2::token::LlamaToken;
+    ///
+    /// let candidates = vec![
+    ///     LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0),
+    ///     LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0),
+    ///     LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0)
+    /// ];
+    /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false);
+    ///
+    /// candidates.sample_temp(None, 0.5);
+    ///
+    /// assert_ne!(candidates.data[0].logit(), 0.1);
+    /// assert_ne!(candidates.data[1].logit(), 0.2);
+    /// assert_ne!(candidates.data[2].logit(), 0.7);
+    /// ```
+    pub fn sample_temp(&mut self, ctx: Option<&mut LlamaContext>, temperature: f32) {
+        if temperature == 0.0 {
+            return;
+        }
+        let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr());
+        unsafe {
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_temp(ctx, c_llama_token_data_array, temperature);
+            });
+        }
+    }
+
+    /// Randomly selects a token from the candidates based on their probabilities.
+    pub fn sample_token(&mut self, ctx: &mut LlamaContext) -> LlamaToken {
+        let llama_token = unsafe {
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_token(ctx.context.as_ptr(), c_llama_token_data_array)
+            })
+        };
+        LlamaToken(llama_token)
+    }
+
+    /// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)
+    pub fn sample_top_k(&mut self, ctx: Option<&mut LlamaContext>, k: i32, min_keep: usize) {
+        let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr());
+        unsafe {
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_top_k(ctx, c_llama_token_data_array, k, min_keep);
+            });
+        }
+    }
+
+    /// Tail Free Sampling described in [Tail-Free-Sampling](https://www.trentonbricken.com/Tail-Free-Sampling/).
+    pub fn sample_tail_free(&mut self, ctx: Option<&mut LlamaContext>, z: f32, min_keep: usize) {
+        let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr());
+        unsafe {
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_tail_free(ctx, c_llama_token_data_array, z, min_keep);
+            });
+        }
+    }
+
+    /// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666).
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    ///
+    /// # use llama_cpp_2::token::data::LlamaTokenData;
+    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+    /// # use llama_cpp_2::token::LlamaToken;
+    ///
+    /// let candidates = vec![
+    ///     LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0),
+    ///     LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0),
+    ///     LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0),
+    /// ];
+    /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false);
+    /// candidates.sample_typical(None, 0.5, 1);
+    ///
+    /// ```
+    pub fn sample_typical(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) {
+        let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr());
+        unsafe {
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_typical(ctx, c_llama_token_data_array, p, min_keep);
+            });
+        }
+    }
+
+    /// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    ///
+    /// # use llama_cpp_2::token::data::LlamaTokenData;
+    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+    /// # use llama_cpp_2::token::LlamaToken;
+    ///
+    /// let candidates = vec![
+    ///     LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0),
+    ///     LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0),
+    ///     LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0),
+    /// ];
+    ///
+    /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false);
+    /// candidates.sample_top_p(None, 0.5, 1);
+    ///
+    /// assert_eq!(candidates.data.len(), 2);
+    /// assert_eq!(candidates.data[0].id(), LlamaToken::new(2));
+    /// assert_eq!(candidates.data[1].id(), LlamaToken::new(1));
+    /// ```
+    pub fn sample_top_p(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) {
+        let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr());
+        unsafe {
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_top_p(ctx, c_llama_token_data_array, p, min_keep);
+            });
+        }
+    }
+
+    /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841)
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use llama_cpp_2::token::data::LlamaTokenData;
+    /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+    /// # use llama_cpp_2::token::LlamaToken;
+    ///
+    /// let candidates = vec![
+    ///     LlamaTokenData::new(LlamaToken::new(4), 0.0001, 0.0),
+    ///     LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0),
+    ///     LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0),
+    ///     LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0),
+    /// ];
+    /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false);
+    /// candidates.sample_min_p(None, 0.05, 1);
+    /// ```
+    pub fn sample_min_p(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) {
+        let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr());
+        unsafe {
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_min_p(ctx, c_llama_token_data_array, p, min_keep);
+            });
+        }
+    }
+
+    /// Mirostat 2.0 algorithm described in the [paper](https://arxiv.org/abs/2007.14966). Uses tokens instead of words.
+    ///
+    /// # Parameters
+    ///
+    /// * `tau` The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+    /// * `eta` The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+    /// * `mu` Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+    pub fn sample_token_mirostat_v2(
+        &mut self,
+        ctx: &mut LlamaContext,
+        tau: f32,
+        eta: f32,
+        mu: &mut f32,
+    ) -> LlamaToken {
+        let mu_ptr = ptr::from_mut(mu);
+        let token = unsafe {
+            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
+                llama_cpp_sys_2::llama_sample_token_mirostat_v2(
+                    ctx.context.as_ptr(),
+                    c_llama_token_data_array,
+                    tau,
+                    eta,
+                    mu_ptr,
+                )
+            })
+        };
+        *mu = unsafe { *mu_ptr };
+        LlamaToken(token)
+    }
 }
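
The new `LlamaTokenDataArray` methods take an `Option<&mut LlamaContext>` so the cheap candidate-reshaping steps can run without a context (pass `None`), while the calls llama.cpp ties to a context (timings, `sample_token`, mirostat) still accept one. A minimal sketch of how a caller might combine them, using only the signatures shown in this diff; the surrounding model and context setup is elided, and the constants (k = 40, temperature = 0.8, tau = 5.0, eta = 0.1) are illustrative choices, not values taken from the PR:

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use llama_cpp_2::token::LlamaToken;

/// Shrink the candidate list with context-free steps, then let mirostat v2 pick the token.
/// `mu` must be the same variable across calls (initialized to `2.0 * tau`) so mirostat can
/// track its running surprise estimate, as described in the `sample_token_mirostat_v2` docs.
fn next_token(
    ctx: &mut LlamaContext,
    mut candidates: LlamaTokenDataArray,
    mu: &mut f32,
) -> LlamaToken {
    // `None`: no context is needed for these steps, so sample timings are not recorded.
    candidates.sample_top_k(None, 40, 1);
    candidates.sample_temp(None, 0.8);
    // Mirostat needs the real context and the shared `mu` state.
    candidates.sample_token_mirostat_v2(ctx, 5.0, 0.1, mu)
}
```

With the `sampler` feature enabled, the same pipeline can instead be packaged as `Sampler` steps plus a finalizer, as in the module-level example in `sampler.rs` above.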