Skip to content

Commit

Permalink
Add Mirostat to new API
Browse files Browse the repository at this point in the history
  • Loading branch information
nkoppel committed Dec 7, 2024
1 parent 27ebd82 commit d44deef
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 2 deletions.
4 changes: 3 additions & 1 deletion examples/simple/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ either reduce n_len or increase n_ctx"
let mut decoder = encoding_rs::UTF_8.new_decoder();

let mut sampler = LlamaSampler::new(LlamaSamplerParams::chain(&[
LlamaSamplerParams::Dist { seed: seed.unwrap_or(1234) },
LlamaSamplerParams::Dist {
seed: seed.unwrap_or(1234),
},
LlamaSamplerParams::Greedy,
]));

Expand Down
10 changes: 10 additions & 0 deletions llama-cpp-2/src/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ unsafe fn new_inner(params: LlamaSamplerParams) -> *mut llama_cpp_sys_2::llama_s
penalize_nl,
ignore_eos,
),
LlamaSamplerParams::Mirostat {
n_vocab,
tau,
eta,
m,
seed,
} => llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m),
LlamaSamplerParams::MirostatV2 { tau, eta, seed } => {
llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta)
}
LlamaSamplerParams::Dist { seed } => llama_cpp_sys_2::llama_sampler_init_dist(seed),
LlamaSamplerParams::Greedy => llama_cpp_sys_2::llama_sampler_init_greedy(),
}
Expand Down
28 changes: 27 additions & 1 deletion llama-cpp-2/src/sampling/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,30 @@ pub enum LlamaSamplerParams<'a> {
ignore_eos: bool,
},

/// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
Mirostat {
/// The size of the model's vocabulary, i.e. ``model.n_vocab()``
n_vocab: i32,
/// The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
tau: f32,
/// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled token. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
eta: f32,
/// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
m: i32,
/// Seed to initialize random generation with
seed: u32,
},

/// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
MirostatV2 {
/// The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
tau: f32,
/// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled token. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
eta: f32,
/// Seed to initialize random generation with
seed: u32,
},

/// Select a token at random based on each token's probabilities
Dist {
/// Seed to initialize random generation with
Expand Down Expand Up @@ -146,7 +170,7 @@ impl<'a> LlamaSamplerParams<'a> {
Self::MinP { p, min_keep: 1 }
}

/// Whether this sampler's outputs are dependent on the tokens in the model's context.
/// Whether this sampler's outputs are dependent on the tokens in the model's context.
pub(crate) fn uses_context_tokens(&self) -> bool {
match self {
LlamaSamplerParams::Chain { stages, .. } => {
Expand All @@ -164,6 +188,8 @@ impl<'a> LlamaSamplerParams<'a> {
| LlamaSamplerParams::TopP { .. }
| LlamaSamplerParams::MinP { .. }
| LlamaSamplerParams::Xtc { .. }
| LlamaSamplerParams::Mirostat { .. }
| LlamaSamplerParams::MirostatV2 { .. }
| LlamaSamplerParams::Dist { .. }
| LlamaSamplerParams::Greedy => false,
}
Expand Down

0 comments on commit d44deef

Please sign in to comment.