From d44deef3d9c5e50aaf86ca9a0e707c50389ccd2d Mon Sep 17 00:00:00 2001
From: Nathan Koppel <nathankoppel0@gmail.com>
Date: Sat, 7 Dec 2024 13:25:33 -0600
Subject: [PATCH] Add Mirostat samplers to the new sampler API

---
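Usage sketch (review note, not part of the commit message): the new variants
slot into the sampler chain from examples/simple. The tau/eta/m values below
mirror llama.cpp's defaults rather than tuned recommendations, and `model` and
`seed` are assumed to be in scope as in that example:

    let mut sampler = LlamaSampler::new(LlamaSamplerParams::chain(&[
        LlamaSamplerParams::Mirostat {
            n_vocab: model.n_vocab(),
            tau: 5.0,
            eta: 0.1,
            m: 100,
            seed: seed.unwrap_or(1234),
        },
    ]));
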
 examples/simple/src/main.rs        |  4 +++-
 llama-cpp-2/src/sampling.rs        | 10 ++++++++++
 llama-cpp-2/src/sampling/params.rs | 28 +++++++++++++++++++++++++++-
 3 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/examples/simple/src/main.rs b/examples/simple/src/main.rs
index f31a83c..e13274f 100644
--- a/examples/simple/src/main.rs
+++ b/examples/simple/src/main.rs
@@ -247,7 +247,9 @@ either reduce n_len or increase n_ctx"
     let mut decoder = encoding_rs::UTF_8.new_decoder();
 
     let mut sampler = LlamaSampler::new(LlamaSamplerParams::chain(&[
-        LlamaSamplerParams::Dist { seed: seed.unwrap_or(1234) },
+        LlamaSamplerParams::Dist {
+            seed: seed.unwrap_or(1234),
+        },
         LlamaSamplerParams::Greedy,
     ]));
 
diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs
index e0313f4..2c94532 100644
--- a/llama-cpp-2/src/sampling.rs
+++ b/llama-cpp-2/src/sampling.rs
@@ -112,6 +112,16 @@ unsafe fn new_inner(params: LlamaSamplerParams) -> *mut llama_cpp_sys_2::llama_s
             penalize_nl,
             ignore_eos,
         ),
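+        // Note: llama.cpp's llama_sampler_init_mirostat takes its arguments in
+        // the order (n_vocab, seed, tau, eta, m), which differs from the field
+        // order of the Rust variant below.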
+        LlamaSamplerParams::Mirostat {
+            n_vocab,
+            tau,
+            eta,
+            m,
+            seed,
+        } => llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m),
+        LlamaSamplerParams::MirostatV2 { tau, eta, seed } => {
+            llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta)
+        }
         LlamaSamplerParams::Dist { seed } => llama_cpp_sys_2::llama_sampler_init_dist(seed),
         LlamaSamplerParams::Greedy => llama_cpp_sys_2::llama_sampler_init_greedy(),
     }
diff --git a/llama-cpp-2/src/sampling/params.rs b/llama-cpp-2/src/sampling/params.rs
index fe5d23e..84cdbc3 100644
--- a/llama-cpp-2/src/sampling/params.rs
+++ b/llama-cpp-2/src/sampling/params.rs
@@ -85,6 +85,30 @@ pub enum LlamaSamplerParams<'a> {
         ignore_eos: bool,
     },
 
+    /// Mirostat 1.0 algorithm described in the paper
+    /// <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
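+    ///
+    /// A minimal sketch of constructing these params. The values mirror
+    /// llama.cpp's defaults (`tau = 5.0`, `eta = 0.1`, `m = 100`) rather than
+    /// tuned recommendations, and `model` is assumed to be a loaded model:
+    ///
+    /// ```ignore
+    /// let params = LlamaSamplerParams::Mirostat {
+    ///     n_vocab: model.n_vocab(),
+    ///     tau: 5.0,
+    ///     eta: 0.1,
+    ///     m: 100,
+    ///     seed: 1234,
+    /// };
+    /// ```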
+    Mirostat {
+        /// ``model.n_vocab()``
+        n_vocab: i32,
+        /// The target cross-entropy (or surprise) value you want to achieve
+        /// for the generated text. A higher value corresponds to more
+        /// surprising or less predictable text, while a lower value
+        /// corresponds to less surprising or more predictable text.
+        tau: f32,
+        /// The learning rate used to update `mu` based on the error between
+        /// the target and observed surprisal of the sampled token. A larger
+        /// learning rate will cause `mu` to be updated more quickly, while a
+        /// smaller learning rate will result in slower updates.
+        eta: f32,
+        /// The number of tokens considered in the estimation of `s_hat`. This
+        /// is an arbitrary value that is used to calculate `s_hat`, which in
+        /// turn helps to calculate the value of `k`. In the paper, they use
+        /// `m = 100`, but you can experiment with different values to see how
+        /// it affects the performance of the algorithm.
+        m: i32,
+        /// Seed to initialize random generation with
+        seed: u32,
+    },
+
+    /// Mirostat 2.0 algorithm described in the paper
+    /// <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
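+    ///
+    /// A minimal sketch; as with [`LlamaSamplerParams::Mirostat`], the values
+    /// mirror llama.cpp's defaults rather than tuned recommendations:
+    ///
+    /// ```ignore
+    /// let params = LlamaSamplerParams::MirostatV2 {
+    ///     tau: 5.0,
+    ///     eta: 0.1,
+    ///     seed: 1234,
+    /// };
+    /// ```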
+    MirostatV2 {
+        /// The target cross-entropy (or surprise) value you want to achieve
+        /// for the generated text. A higher value corresponds to more
+        /// surprising or less predictable text, while a lower value
+        /// corresponds to less surprising or more predictable text.
+        tau: f32,
+        /// The learning rate used to update `mu` based on the error between
+        /// the target and observed surprisal of the sampled token. A larger
+        /// learning rate will cause `mu` to be updated more quickly, while a
+        /// smaller learning rate will result in slower updates.
+        eta: f32,
+        /// Seed to initialize random generation with
+        seed: u32,
+    },
+
     /// Select a token at random based on each token's probabilities
     Dist {
         /// Seed to initialize random generation with
@@ -146,7 +170,7 @@ impl<'a> LlamaSamplerParams<'a> {
         Self::MinP { p, min_keep: 1 }
     }
 
-    /// Whether this sampler's outputs are dependent on the tokens in the model's context. 
+    /// Whether this sampler's outputs are dependent on the tokens in the model's context.
     pub(crate) fn uses_context_tokens(&self) -> bool {
         match self {
             LlamaSamplerParams::Chain { stages, .. } => {
@@ -164,6 +188,8 @@ impl<'a> LlamaSamplerParams<'a> {
             | LlamaSamplerParams::TopP { .. }
             | LlamaSamplerParams::MinP { .. }
             | LlamaSamplerParams::Xtc { .. }
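+            // The Mirostat variants keep internal state (`mu`) across calls
+            // but never read the context tokens themselves, so they are
+            // grouped with the context-free samplers here.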
+            | LlamaSamplerParams::Mirostat { .. }
+            | LlamaSamplerParams::MirostatV2 { .. }
             | LlamaSamplerParams::Dist { .. }
             | LlamaSamplerParams::Greedy => false,
         }