Skip to content

Commit

Permalink
added microstat
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusDunn committed Mar 14, 2024
1 parent c920513 commit 411a679
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
6 changes: 4 additions & 2 deletions llama-cpp-2/src/context/sample/sampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ pub type SampleFinalizer<C> = dyn Fn(LlamaTokenDataArray, &mut C) -> Vec<LlamaTo

/// A series of sampling steps that will produce a vector of token data.
///
/// `C` is dynamic context that will be passed to the sampling functions. I expect `C` will
/// often be [`()`], [`crate::context::LlamaContext`] or a token history (or some combination of these).
/// `C` is dynamic context that will be passed to the sampling functions. Some sampling steps may
/// require state to be maintained across multiple samples, and this context can be used to store
/// that state. For example, [`LlamaTokenDataArray::sample_token_mirostat_v2`] requires a `mu` to be
/// shared across multiple samples.
pub struct Sampler<'a, C> {
/// The steps to take when sampling.
pub steps: Vec<&'a SampleStep<C>>,
Expand Down
30 changes: 30 additions & 0 deletions llama-cpp-2/src/token/data_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,4 +344,34 @@ impl LlamaTokenDataArray {
});
}
}

/// Mirostat 2.0 algorithm described in the [paper](https://arxiv.org/abs/2007.14966). Uses tokens instead of words.
///
/// # Parameters
///
/// * `tau` The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// * `eta` The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// * `mu` Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
pub fn sample_token_mirostat_v2(
&mut self,
ctx: &mut LlamaContext,
tau: f32,
eta: f32,
mu: &mut f32,
) -> LlamaToken {
let mu_ptr = ptr::from_mut(mu);
let token = unsafe {
self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_token_mirostat_v2(
ctx.context.as_ptr(),
c_llama_token_data_array,
tau,
eta,
mu_ptr,
)
})
};
*mu = unsafe { *mu_ptr };
LlamaToken(token)
}
}

0 comments on commit 411a679

Please sign in to comment.