From 411a679f57328c632d5111b8fb3c0789fb1397ef Mon Sep 17 00:00:00 2001 From: marcus Date: Thu, 14 Mar 2024 13:08:46 -0700 Subject: [PATCH] added microstat --- llama-cpp-2/src/context/sample/sampler.rs | 6 +++-- llama-cpp-2/src/token/data_array.rs | 30 +++++++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/llama-cpp-2/src/context/sample/sampler.rs b/llama-cpp-2/src/context/sample/sampler.rs index 74906a7c..cfe90499 100644 --- a/llama-cpp-2/src/context/sample/sampler.rs +++ b/llama-cpp-2/src/context/sample/sampler.rs @@ -54,8 +54,10 @@ pub type SampleFinalizer = dyn Fn(LlamaTokenDataArray, &mut C) -> Vec { /// The steps to take when sampling. pub steps: Vec<&'a SampleStep>, diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs index 0f89d59f..776d222a 100644 --- a/llama-cpp-2/src/token/data_array.rs +++ b/llama-cpp-2/src/token/data_array.rs @@ -344,4 +344,34 @@ impl LlamaTokenDataArray { }); } } + + /// Mirostat 2.0 algorithm described in the [paper](https://arxiv.org/abs/2007.14966). Uses tokens instead of words. + /// + /// # Parameters + /// + /// * `tau` The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// * `eta` The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// * `mu` Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + pub fn sample_token_mirostat_v2( + &mut self, + ctx: &mut LlamaContext, + tau: f32, + eta: f32, + mu: &mut f32, + ) -> LlamaToken { + let mu_ptr = ptr::from_mut(mu); + let token = unsafe { + self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_token_mirostat_v2( + ctx.context.as_ptr(), + c_llama_token_data_array, + tau, + eta, + mu_ptr, + ) + }) + }; + *mu = unsafe { *mu_ptr }; + LlamaToken(token) + } }