Implement new sampler API
This commit implements the new sampler API, allowing
us to build a sampler chain using a builder-like pattern.
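
For illustration, a minimal sketch of how such a chain might be assembled with the API added in this commit (the settings 40, 0.8 and the seed 1234 are made up for the example, not defaults):

use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::LlamaSamplerError;

// Build a chain: top-k filtering, then temperature scaling, then seeded sampling.
fn build_sampler() -> Result<LlamaSampler, LlamaSamplerError> {
    let sampler_params = LlamaSamplerChainParams::default();
    Ok(LlamaSampler::new(sampler_params)?
        .add_top_k(40)
        .add_temp(0.8)
        .add_dist(1234))
}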
volesen committed Nov 26, 2024
1 parent 5c27009 commit 6c5430d
Showing 3 changed files with 238 additions and 15 deletions.
27 changes: 15 additions & 12 deletions examples/simple/src/main.rs
@@ -17,7 +17,9 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::{AddBos, Special};
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
use llama_cpp_2::sampling::LlamaSampler;

use std::ffi::CString;
use std::io::Write;
use std::num::NonZeroU32;
@@ -174,9 +176,9 @@ fn main() -> Result<()> {
.with_context(|| "unable to load model")?;

// initialize the context
let mut ctx_params = LlamaContextParams::default()
.with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())))
.with_seed(seed.unwrap_or(1234));
let mut ctx_params =
LlamaContextParams::default().with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())));

if let Some(threads) = threads {
ctx_params = ctx_params.with_n_threads(threads);
}
@@ -244,31 +246,32 @@ either reduce n_len or increase n_ctx"
// The `Decoder`
let mut decoder = encoding_rs::UTF_8.new_decoder();

//
let sampler_params = LlamaSamplerChainParams::default();
let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(seed.unwrap_or(1234));

while n_cur <= n_len {
// sample the next token
{
let candidates = ctx.candidates();

let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
let token = sampler.sample(&ctx, batch.n_tokens() - 1);

// sample the most likely token
let new_token_id = ctx.sample_token_greedy(candidates_p);
sampler.accept(token);

// is it an end of stream?
if model.is_eog_token(new_token_id) {
if model.is_eog_token(token) {
eprintln!();
break;
}

let output_bytes = model.token_to_bytes(new_token_id, Special::Tokenize)?;
let output_bytes = model.token_to_bytes(token, Special::Tokenize)?;
// use `Decoder.decode_to_string()` to avoid the intermediate buffer
let mut output_string = String::with_capacity(32);
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
print!("{output_string}");
std::io::stdout().flush()?;

batch.clear();
batch.add(new_token_id, n_cur, &[0], true)?;
batch.add(token, n_cur, &[0], true)?;
}

n_cur += 1;
217 changes: 217 additions & 0 deletions llama-cpp-2/src/sampling.rs
@@ -1,9 +1,13 @@
//! Safe wrapper around `llama_sampler`.
pub mod params;

use std::ffi::CString;
use std::fmt::{Debug, Formatter};
use std::ptr::NonNull;

use crate::context::LlamaContext;
use crate::model::LlamaModel;
use crate::token::LlamaToken;
use crate::LlamaSamplerError;

/// A safe wrapper around `llama_sampler`.
@@ -18,6 +22,9 @@ impl Debug for LlamaSampler {
}

impl LlamaSampler {
/// Create a new `LlamaSampler` from the given parameters.
/// # Errors
/// Returns an error if the underlying C++ code returns a null pointer.
pub fn new(params: params::LlamaSamplerChainParams) -> Result<Self, LlamaSamplerError> {
let sampler = unsafe {
NonNull::new(llama_cpp_sys_2::llama_sampler_chain_init(
@@ -28,6 +35,216 @@

Ok(Self { sampler })
}

/// Samples the token with the largest probability.
#[must_use]
#[allow(unused_mut)]
pub fn add_greedy(mut self) -> Self {
unsafe {
let greedy_sampler = llama_cpp_sys_2::llama_sampler_init_greedy();
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), greedy_sampler);
}

self
}

/// Samples according to the probability distribution of the tokens.
#[must_use]
#[allow(unused_mut)]
pub fn add_dist(mut self, seed: u32) -> Self {
unsafe {
let dist_sampler = llama_cpp_sys_2::llama_sampler_init_dist(seed);
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), dist_sampler);
}

self
}

/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" <https://arxiv.org/abs/1904.09751>
#[must_use]
#[allow(unused_mut)]
pub fn add_top_k(mut self, k: i32) -> Self {
unsafe {
let top_k_sampler = llama_cpp_sys_2::llama_sampler_init_top_k(k);
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), top_k_sampler);
}

self
}

/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" <https://arxiv.org/abs/1904.09751>
#[must_use]
#[allow(unused_mut)]
pub fn add_top_p(mut self, p: f32, min_keep: usize) -> Self {
unsafe {
let top_p_sampler = llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep);
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), top_p_sampler);
}

self
}

/// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
#[must_use]
#[allow(unused_mut)]
pub fn add_min_p(mut self, p: f32, min_keep: usize) -> Self {
unsafe {
let min_p_sampler = llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep);
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), min_p_sampler);
}

self
}

/// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
#[must_use]
#[allow(unused_mut)]
pub fn add_typical(mut self, p: f32, min_keep: usize) -> Self {
unsafe {
let typical_sampler = llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep);
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), typical_sampler);
}

self
}

/// Updates the logits as `l_i = l_i / t`. When `t <= 0.0`, the maximum logit is kept at its original value and the rest are set to `-inf`.
#[must_use]
#[allow(unused_mut)]
pub fn add_temp(mut self, t: f32) -> Self {
unsafe {
let temp_sampler = llama_cpp_sys_2::llama_sampler_init_temp(t);
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_sampler);
}

self
}

/// Dynamic temperature implementation (a.k.a. entropy) described in the paper <https://arxiv.org/abs/2309.02772>.
#[must_use]
#[allow(unused_mut)]
pub fn add_temp_ext(mut self, t: f32, delta: f32, exponent: f32) -> Self {
unsafe {
let temp_ext_sampler = llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent);
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler);
}

self
}

/// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
///
/// # Arguments
///
/// * `n_vocab` - The size of the model's vocabulary, used to initialize the sampler's internal state.
/// * `seed` - Seed for the sampler's random number generator.
/// * `tau` - The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// * `eta` - The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// * `m` - The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
///
/// The sampler's internal `mu` value is initialized to twice the target cross-entropy (`2 * tau`) and is updated based on the error between the target and observed surprisal of the sampled token.
#[must_use]
#[allow(unused_mut)]
pub fn add_mirostat(mut self, n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
unsafe {
let mirostat_sampler =
llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m);
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), mirostat_sampler);
}

self
}

/// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
///
/// # Arguments
///
/// * `seed` - Seed for the sampler's random number generator.
/// * `tau` - The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// * `eta` - The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
///
/// The sampler's internal `mu` value is initialized to twice the target cross-entropy (`2 * tau`) and is updated based on the error between the target and observed surprisal of the sampled token.
#[must_use]
#[allow(unused_mut)]
pub fn add_mirostat_v2(mut self, seed: u32, tau: f32, eta: f32) -> Self {
unsafe {
let mirostat_v2_sampler = llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta);
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), mirostat_v2_sampler);
}

self
}

/// Samples constrained by a context-free grammar in the GGML BNF (GBNF) format.
///
/// # Panics
/// Panics if a provided string contains a null byte.
#[must_use]
#[allow(unused_mut)]
pub fn add_grammar(
mut self,
model: &LlamaModel,
grammar_str: &str,
grammar_root: &str,
) -> Self {
unsafe {
let grammar_str = CString::new(grammar_str).unwrap();
let grammar_root = CString::new(grammar_root).unwrap();
let grammar_sampler = llama_cpp_sys_2::llama_sampler_init_grammar(
model.model.as_ptr(),
grammar_str.as_ptr(),
grammar_root.as_ptr(),
);
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), grammar_sampler);
}

self
}

/// Adds penalties to the sampler. This can be used to penalize certain patterns in the generated text, such as repeating the same token multiple times or using the same token too frequently.
#[allow(unused_mut, clippy::too_many_arguments)]
#[must_use]
pub fn add_penalties(
mut self,
n_vocab: i32,
special_eos_id: i32,
linefeed_id: i32,
penalty_last_n: i32,
penalty_repeat: f32,
penalty_freq: f32,
penalty_present: f32,
penalize_nl: bool,
ignore_eos: bool,
) -> Self {
unsafe {
let penalties_sampler = llama_cpp_sys_2::llama_sampler_init_penalties(
n_vocab,
special_eos_id,
linefeed_id,
penalty_last_n,
penalty_repeat,
penalty_freq,
penalty_present,
penalize_nl,
ignore_eos,
);
llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), penalties_sampler);
}

self
}

/// Sample and accept a token from the idx-th output of the last evaluation
#[must_use]
pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
let token = unsafe {
llama_cpp_sys_2::llama_sampler_sample(self.sampler.as_ptr(), ctx.context.as_ptr(), idx)
};

LlamaToken(token)
}

/// Accepts a token from the sampler, possibly updating the internal state of certain samplers (e.g. grammar, repetition, etc.)
pub fn accept(&mut self, token: LlamaToken) {
unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler.as_ptr(), token.0) }
}
}

impl Drop for LlamaSampler {
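
As a usage sketch for the grammar sampler above (the GBNF grammar, the `model`, `ctx`, and `batch` handles are assumptions for illustration, not part of this commit):

// Hypothetical: constrain generation to "yes" or "no" via a tiny GBNF grammar,
// then pick the surviving token greedily. `model`, `ctx` and `batch` are assumed
// to be a loaded LlamaModel, its LlamaContext, and the last decoded batch.
let grammar = r#"root ::= "yes" | "no""#;
let sampler_params = LlamaSamplerChainParams::default();
let mut sampler = LlamaSampler::new(sampler_params)?
    .add_grammar(&model, grammar, "root")
    .add_greedy();

let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token); // advances the grammar sampler's internal state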
9 changes: 6 additions & 3 deletions llama-cpp-2/src/sampling/params.rs
@@ -1,7 +1,6 @@
//! Safe wrapper around `llama_sampler`.
//! Safe wrapper around `llama_sampler_chain_params`.
use std::fmt::{Debug, Formatter};
use std::ptr::NonNull;

/// A safe wrapper around `llama_sampler`.
pub struct LlamaSamplerChainParams {
@@ -25,11 +24,15 @@ impl Default for LlamaSamplerChainParams {
}

impl LlamaSamplerChainParams {
pub fn with_no_perf(&mut self, no_perf: bool) -> &mut Self {
/// Set whether to disable performance timing measurements (`no_perf`)
#[must_use]
pub fn with_no_perf(mut self, no_perf: bool) -> Self {
self.sampler_chain_params.no_perf = no_perf;
self
}

/// Get whether performance timing measurements are disabled (`no_perf`)
#[must_use]
pub fn no_perf(&self) -> bool {
self.sampler_chain_params.no_perf
}
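
With the consuming builder signature above, the chain params can now be configured inline; a small sketch (the `true` value is just an example, and the imports from the earlier sketch are assumed):

let params = LlamaSamplerChainParams::default().with_no_perf(true);
assert!(params.no_perf());
let sampler = LlamaSampler::new(params)?;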
