From d44deef3d9c5e50aaf86ca9a0e707c50389ccd2d Mon Sep 17 00:00:00 2001
From: Nathan Koppel <nathankoppel0@gmail.com>
Date: Sat, 7 Dec 2024 13:25:33 -0600
Subject: [PATCH] Add Mirostat to new API

---
 examples/simple/src/main.rs        |  4 +++-
 llama-cpp-2/src/sampling.rs        | 10 ++++++++++
 llama-cpp-2/src/sampling/params.rs | 28 +++++++++++++++++++++++++++-
 3 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/examples/simple/src/main.rs b/examples/simple/src/main.rs
index f31a83c..e13274f 100644
--- a/examples/simple/src/main.rs
+++ b/examples/simple/src/main.rs
@@ -247,7 +247,9 @@ either reduce n_len or increase n_ctx"
     let mut decoder = encoding_rs::UTF_8.new_decoder();
 
     let mut sampler = LlamaSampler::new(LlamaSamplerParams::chain(&[
-        LlamaSamplerParams::Dist { seed: seed.unwrap_or(1234) },
+        LlamaSamplerParams::Dist {
+            seed: seed.unwrap_or(1234),
+        },
         LlamaSamplerParams::Greedy,
     ]));
 
diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs
index e0313f4..2c94532 100644
--- a/llama-cpp-2/src/sampling.rs
+++ b/llama-cpp-2/src/sampling.rs
@@ -112,6 +112,16 @@ unsafe fn new_inner(params: LlamaSamplerParams) -> *mut llama_cpp_sys_2::llama_s
             penalize_nl,
             ignore_eos,
         ),
+        LlamaSamplerParams::Mirostat {
+            n_vocab,
+            tau,
+            eta,
+            m,
+            seed,
+        } => llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m),
+        LlamaSamplerParams::MirostatV2 { tau, eta, seed } => {
+            llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta)
+        }
         LlamaSamplerParams::Dist { seed } => llama_cpp_sys_2::llama_sampler_init_dist(seed),
         LlamaSamplerParams::Greedy => llama_cpp_sys_2::llama_sampler_init_greedy(),
     }
diff --git a/llama-cpp-2/src/sampling/params.rs b/llama-cpp-2/src/sampling/params.rs
index fe5d23e..84cdbc3 100644
--- a/llama-cpp-2/src/sampling/params.rs
+++ b/llama-cpp-2/src/sampling/params.rs
@@ -85,6 +85,30 @@ pub enum LlamaSamplerParams<'a> {
         ignore_eos: bool,
     },
 
+    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
+    Mirostat {
+        /// ``model.n_vocab()``
+        n_vocab: i32,
+        /// The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+        tau: f32,
+        /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+        eta: f32,
+        /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
+        m: i32,
+        /// Seed to initialize random generation with
+        seed: u32,
+    },
+
+    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
+    MirostatV2 {
+        /// The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+        tau: f32,
+        /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+        eta: f32,
+        /// Seed to initialize random generation with
+        seed: u32,
+    },
+
     /// Select a token at random based on each token's probabilities
     Dist {
         /// Seed to initialize random generation with
@@ -146,7 +170,7 @@ impl<'a> LlamaSamplerParams<'a> {
         Self::MinP { p, min_keep: 1 }
     }
 
-    /// Whether this sampler's outputs are dependent on the tokens in the model's context. 
+    /// Whether this sampler's outputs are dependent on the tokens in the model's context.
     pub(crate) fn uses_context_tokens(&self) -> bool {
         match self {
             LlamaSamplerParams::Chain { stages, .. } => {
@@ -164,6 +188,8 @@ impl<'a> LlamaSamplerParams<'a> {
            | LlamaSamplerParams::TopP { .. }
            | LlamaSamplerParams::MinP { .. }
            | LlamaSamplerParams::Xtc { .. }
+           | LlamaSamplerParams::Mirostat { .. }
+           | LlamaSamplerParams::MirostatV2 { .. }
            | LlamaSamplerParams::Dist { .. }
            | LlamaSamplerParams::Greedy => false,
        }
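For reference, a minimal usage sketch of the new Mirostat variant (not part of the patch itself). It reuses the `model` and `seed` variables already in scope in examples/simple, and the values tau = 5.0 and eta = 0.1 are illustrative assumptions (the usual llama.cpp defaults), while m = 100 is the value the field documentation cites from the paper. The chain is built exactly like the Dist/Greedy chain in the example above:

    let mut sampler = LlamaSampler::new(LlamaSamplerParams::chain(&[
        LlamaSamplerParams::Mirostat {
            // Vocabulary size, per the field documentation (``model.n_vocab()``).
            n_vocab: model.n_vocab(),
            // Target surprise (cross-entropy) of the generated text.
            tau: 5.0,
            // Learning rate used to update `mu`.
            eta: 0.1,
            // Number of tokens used to estimate `s_hat`; the paper uses 100.
            m: 100,
            // Seed for random generation, same as the Dist example above.
            seed: seed.unwrap_or(1234),
        },
    ]));

MirostatV2 is constructed the same way, minus the `n_vocab` and `m` fields.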