From 6c5430d87792cab432815b029c7956e0bdfcd555 Mon Sep 17 00:00:00 2001 From: volesen Date: Tue, 26 Nov 2024 18:06:20 +0100 Subject: [PATCH] Implement new sampler API This commit implements the new sampler API, allowing us to build a sampler chain using a builder-like pattern --- examples/simple/src/main.rs | 27 ++-- llama-cpp-2/src/sampling.rs | 217 +++++++++++++++++++++++++++++ llama-cpp-2/src/sampling/params.rs | 9 +- 3 files changed, 238 insertions(+), 15 deletions(-) diff --git a/examples/simple/src/main.rs b/examples/simple/src/main.rs index 267d686..182e6e3 100644 --- a/examples/simple/src/main.rs +++ b/examples/simple/src/main.rs @@ -17,7 +17,9 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue; use llama_cpp_2::model::params::LlamaModelParams; use llama_cpp_2::model::LlamaModel; use llama_cpp_2::model::{AddBos, Special}; -use llama_cpp_2::token::data_array::LlamaTokenDataArray; +use llama_cpp_2::sampling::params::LlamaSamplerChainParams; +use llama_cpp_2::sampling::LlamaSampler; + use std::ffi::CString; use std::io::Write; use std::num::NonZeroU32; @@ -174,9 +176,9 @@ fn main() -> Result<()> { .with_context(|| "unable to load model")?; // initialize the context - let mut ctx_params = LlamaContextParams::default() - .with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap()))) - .with_seed(seed.unwrap_or(1234)); + let mut ctx_params = + LlamaContextParams::default().with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap()))); + if let Some(threads) = threads { ctx_params = ctx_params.with_n_threads(threads); } @@ -244,23 +246,24 @@ either reduce n_len or increase n_ctx" // The `Decoder` let mut decoder = encoding_rs::UTF_8.new_decoder(); + // + let sampler_params = LlamaSamplerChainParams::default(); + let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(seed.unwrap_or(1234)); + while n_cur <= n_len { // sample the next token { - let candidates = ctx.candidates(); - - let candidates_p = LlamaTokenDataArray::from_iter(candidates, false); + let token = sampler.sample(&ctx, batch.n_tokens() - 1); - // sample the most likely token - let new_token_id = ctx.sample_token_greedy(candidates_p); + sampler.accept(token); // is it an end of stream? - if model.is_eog_token(new_token_id) { + if model.is_eog_token(token) { eprintln!(); break; } - let output_bytes = model.token_to_bytes(new_token_id, Special::Tokenize)?; + let output_bytes = model.token_to_bytes(token, Special::Tokenize)?; // use `Decoder.decode_to_string()` to avoid the intermediate buffer let mut output_string = String::with_capacity(32); let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false); @@ -268,7 +271,7 @@ either reduce n_len or increase n_ctx" std::io::stdout().flush()?; batch.clear(); - batch.add(new_token_id, n_cur, &[0], true)?; + batch.add(token, n_cur, &[0], true)?; } n_cur += 1; diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs index 9242ef0..7181e14 100644 --- a/llama-cpp-2/src/sampling.rs +++ b/llama-cpp-2/src/sampling.rs @@ -1,9 +1,13 @@ //! Safe wrapper around `llama_sampler`. pub mod params; +use std::ffi::CString; use std::fmt::{Debug, Formatter}; use std::ptr::NonNull; +use crate::context::LlamaContext; +use crate::model::LlamaModel; +use crate::token::LlamaToken; use crate::LlamaSamplerError; /// A safe wrapper around `llama_sampler`. @@ -18,6 +22,9 @@ impl Debug for LlamaSampler { } impl LlamaSampler { + /// Create a new `LlamaSampler` from the given parameters. + /// # Errors + /// Returns an error if the underlying C++ code returns a null pointer. pub fn new(params: params::LlamaSamplerChainParams) -> Result { let sampler = unsafe { NonNull::new(llama_cpp_sys_2::llama_sampler_chain_init( @@ -28,6 +35,216 @@ impl LlamaSampler { Ok(Self { sampler }) } + + /// Samples the token with the largest probability. + #[must_use] + #[allow(unused_mut)] + pub fn add_greedy(mut self) -> Self { + unsafe { + let greedy_sampler = llama_cpp_sys_2::llama_sampler_init_greedy(); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), greedy_sampler); + } + + self + } + + /// Samples according to the probability distribution of the tokens. + #[must_use] + #[allow(unused_mut)] + pub fn add_dist(mut self, seed: u32) -> Self { + unsafe { + let dist_sampler = llama_cpp_sys_2::llama_sampler_init_dist(seed); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), dist_sampler); + } + + self + } + + /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" + #[must_use] + #[allow(unused_mut)] + pub fn add_top_k(mut self, k: i32) -> Self { + unsafe { + let top_k_sampler = llama_cpp_sys_2::llama_sampler_init_top_k(k); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), top_k_sampler); + } + + self + } + + /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" + #[must_use] + #[allow(unused_mut)] + pub fn add_top_p(mut self, p: f32, min_keep: usize) -> Self { + unsafe { + let top_p_sampler = llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), top_p_sampler); + } + + self + } + + /// Minimum P sampling as described in + #[must_use] + #[allow(unused_mut)] + pub fn add_min_p(mut self, p: f32, min_keep: usize) -> Self { + unsafe { + let min_p_sampler = llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), min_p_sampler); + } + + self + } + + /// Locally Typical Sampling implementation described in the paper . + #[must_use] + #[allow(unused_mut)] + pub fn add_typical(mut self, p: f32, min_keep: usize) -> Self { + unsafe { + let typical_sampler = llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), typical_sampler); + } + + self + } + + /// Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf + #[must_use] + #[allow(unused_mut)] + pub fn add_temp(mut self, t: f32) -> Self { + unsafe { + let temp_sampler = llama_cpp_sys_2::llama_sampler_init_temp(t); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_sampler); + } + + self + } + + /// Dynamic temperature implementation (a.k.a. entropy) described in the paper . + #[must_use] + #[allow(unused_mut)] + pub fn add_temp_ext(mut self, t: f32, delta: f32, exponent: f32) -> Self { + unsafe { + let temp_ext_sampler = llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); + } + + self + } + + /// Mirostat 1.0 algorithm described in the paper . Uses tokens instead of words. + /// + /// # Arguments + /// + /// * `candidates` - A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// * `tau` - The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// * `eta` - The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// * `m` - The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + /// * `mu` - Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + #[must_use] + #[allow(unused_mut)] + pub fn add_mirostat(mut self, n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self { + unsafe { + let temp_ext_sampler = + llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); + } + + self + } + + /// Mirostat 2.0 algorithm described in the paper . Uses tokens instead of words. + /// + /// # Arguments + /// + /// * `candidates` - A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// * `tau` - The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// * `eta` - The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// * `mu` - Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + #[must_use] + #[allow(unused_mut)] + pub fn add_mirostat_v2(mut self, seed: u32, tau: f32, eta: f32) -> Self { + unsafe { + let temp_ext_sampler = llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); + } + + self + } + + /// Samples constrained by a context-free grammar in the GGML BNF (GBNF) format. + /// + /// # Panics + /// Panics if a provided string contains a null byte. + #[must_use] + #[allow(unused_mut)] + pub fn add_grammar( + mut self, + model: &LlamaModel, + grammar_str: &str, + grammar_root: &str, + ) -> Self { + unsafe { + let grammar_str = CString::new(grammar_str).unwrap(); + let grammar_root = CString::new(grammar_root).unwrap(); + let grammar_sampler = llama_cpp_sys_2::llama_sampler_init_grammar( + model.model.as_ptr(), + grammar_str.as_ptr(), + grammar_root.as_ptr(), + ); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), grammar_sampler); + } + + self + } + + /// Adds penalties to the sampler. This can be used to penalize certain patterns in the generated text, such as repeating the same token multiple times or using the same token too frequently. + #[allow(unused_mut, clippy::too_many_arguments)] + #[must_use] + pub fn add_penalties( + mut self, + n_vocab: i32, + special_eos_id: i32, + linefeed_id: i32, + penalty_last_n: i32, + penalty_repeat: f32, + penalty_freq: f32, + penalty_present: f32, + penalize_nl: bool, + ignore_eos: bool, + ) -> Self { + unsafe { + let temp_ext_sampler = llama_cpp_sys_2::llama_sampler_init_penalties( + n_vocab, + special_eos_id, + linefeed_id, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + penalize_nl, + ignore_eos, + ); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); + } + + self + } + + /// Sample and accept a token from the idx-th output of the last evaluation + #[must_use] + pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken { + let token = unsafe { + llama_cpp_sys_2::llama_sampler_sample(self.sampler.as_ptr(), ctx.context.as_ptr(), idx) + }; + + LlamaToken(token) + } + + /// Accepts a token from the sampler, possibly updating the internal state of certain samplers (e.g. grammar, repetition, etc.) + pub fn accept(&mut self, token: LlamaToken) { + unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler.as_ptr(), token.0) } + } } impl Drop for LlamaSampler { diff --git a/llama-cpp-2/src/sampling/params.rs b/llama-cpp-2/src/sampling/params.rs index 972df84..0e67c1f 100644 --- a/llama-cpp-2/src/sampling/params.rs +++ b/llama-cpp-2/src/sampling/params.rs @@ -1,7 +1,6 @@ -//! Safe wrapper around `llama_sampler`. +//! Safe wrapper around `llama_sampler_chain_params`. use std::fmt::{Debug, Formatter}; -use std::ptr::NonNull; /// A safe wrapper around `llama_sampler`. pub struct LlamaSamplerChainParams { @@ -25,11 +24,15 @@ impl Default for LlamaSamplerChainParams { } impl LlamaSamplerChainParams { - pub fn with_no_perf(&mut self, no_perf: bool) -> &mut Self { + /// Set whether to measure performance timings + #[must_use] + pub fn with_no_perf(mut self, no_perf: bool) -> Self { self.sampler_chain_params.no_perf = no_perf; self } + /// Get whether to measure performance timings + #[must_use] pub fn no_perf(&self) -> bool { self.sampler_chain_params.no_perf }