
Commit

Merge pull request #188 from utilityai/sampler
Sampler
MarcusDunn authored Mar 14, 2024
2 parents f34c6dc + 411a679 commit e78ddc1
Showing 11 changed files with 390 additions and 213 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/llama-cpp-rs-check.yml
@@ -34,7 +34,7 @@ jobs:
- name: Fmt
run: cargo fmt
- name: Test
run: cargo test
run: cargo test --features sampler
arm64:
name: Check that it builds on various targets
runs-on: ubuntu-latest
@@ -67,7 +67,7 @@ jobs:
- name: Setup Rust
uses: dtolnay/rust-toolchain@stable
- name: Build
run: cargo build
run: cargo build --features sampler
windows:
name: Check that it builds on windows
runs-on: windows-latest
@@ -79,4 +79,4 @@ jobs:
- name: Setup Rust
uses: dtolnay/rust-toolchain@stable
- name: Build
run: cargo build
run: cargo build --features sampler
1 change: 1 addition & 0 deletions llama-cpp-2/Cargo.toml
@@ -28,6 +28,7 @@ harness = false

[features]
cublas = ["llama-cpp-sys-2/cublas"]
sampler = []

[lints]
workspace = true
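The new `sampler` feature is empty: it adds no dependencies and acts purely as a compile-time gate. The CI jobs above now build and test with `--features sampler`, and the new `sampler` module later in this diff is declared behind `#[cfg(feature = "sampler")]`. A short, hypothetical sketch of that gating pattern (the module body is illustrative, not from the crate):

```rust
// Hypothetical sketch of an empty Cargo feature used as a pure compile-time
// switch; the real gate in this diff is `#[cfg(feature = "sampler")] pub mod sampler;`.
#[cfg(feature = "sampler")]
pub mod sampler {
    /// Compiled only when the crate is built with `--features sampler`.
    pub fn enabled() -> bool {
        true
    }
}
```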
2 changes: 1 addition & 1 deletion llama-cpp-2/src/context.rs
@@ -69,7 +69,7 @@ impl<'model> LlamaContext<'model> {
///
/// # Panics
///
/// - the returned [`c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
/// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
let result =
unsafe { llama_cpp_sys_2::llama_decode(self.context.as_ptr(), batch.llama_batch) };
18 changes: 9 additions & 9 deletions llama-cpp-2/src/context/kv_cache.rs
@@ -22,8 +22,8 @@ impl LlamaContext<'_> {
///
/// * `src` - The sequence id to copy the cache from.
/// * `dest` - The sequence id to copy the cache to.
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to [p1].
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from [p0].
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
@@ -37,8 +37,8 @@ impl LlamaContext<'_> {
/// # Parameters
///
/// * `src` - The sequence id to clear the cache for.
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to [p1].
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from [p0].
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
@@ -68,16 +68,16 @@ impl LlamaContext<'_> {
}

#[allow(clippy::doc_markdown)]
/// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
/// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
/// If the KV cache is RoPEd, the KV data is updated accordingly:
/// - lazily on next [`LlamaContext::decode`]
/// - explicitly with [`Self::kv_cache_update`]
///
/// # Parameters
///
/// * `seq_id` - The sequence id to update
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
/// * `delta` - The relative position to add to the tokens
pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
let p0 = p0.map_or(-1, i32::from);
@@ -95,8 +95,8 @@ impl LlamaContext<'_> {
/// # Parameters
///
/// * `seq_id` - The sequence id to update
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
/// * `d` - The factor to divide the positions by
pub fn kv_cache_seq_div(
&mut self,
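These KV-cache helpers map `None` to `-1` via `map_or(-1, i32::from)`, which llama.cpp treats as an open bound on that side. A minimal sketch of how the `Option<u16>` range parameters might be used, based only on the signatures shown above; the helper function and the specific positions are illustrative:

```rust
use llama_cpp_2::context::LlamaContext;

/// Hypothetical helper: fork sequence 0 into sequence 1, shift part of the
/// copy, and trim the original. `None` bounds mean "from the start" / "to the end".
fn fork_and_trim(ctx: &mut LlamaContext) {
    // Copy the entire cache of sequence 0 into sequence 1 (both bounds open).
    ctx.copy_kv_cache_seq(0, 1, None, None);

    // Add a relative position delta of 8 to sequence 1, positions [16, end).
    ctx.kv_cache_seq_add(1, Some(16), None, 8);

    // Clear sequence 0 from position 32 onward.
    ctx.clear_kv_cache_seq(0, Some(32), None);
}
```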
225 changes: 36 additions & 189 deletions llama-cpp-2/src/context/sample.rs
@@ -5,130 +5,10 @@ use crate::grammar::LlamaGrammar;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::LlamaToken;

/// struct to hold params for sampling
#[derive(Debug)]
#[deprecated(
since = "0.1.32",
note = "this does not scale well with many params and does not allow for changing of orders."
)]
pub struct Sampler<'grammar> {
token_data_array: LlamaTokenDataArray,
grammar: Option<&'grammar mut LlamaGrammar>,
temperature: Option<f32>,
}

impl<'grammar> Sampler<'grammar> {
#[deprecated(
since = "0.1.32",
note = "this does not scale well with many params and does not allow for changing of orders."
)]
fn sample(self, llama_context: &mut LlamaContext) -> LlamaToken {
match self {
Sampler {
token_data_array,
grammar: None,
temperature: None,
} => llama_context.sample_token_greedy(token_data_array),
Sampler {
mut token_data_array,
grammar: Some(grammar),
temperature: None,
} => {
llama_context.sample_grammar(&mut token_data_array, grammar);
let token = llama_context.sample_token_greedy(token_data_array);
llama_context.grammar_accept_token(grammar, token);
token
}
Sampler {
mut token_data_array,
grammar: None,
temperature: Some(temp),
} => {
llama_context.sample_temp(&mut token_data_array, temp);
llama_context.sample_token_softmax(&mut token_data_array);
token_data_array.data[0].id()
}
Sampler {
mut token_data_array,
grammar: Some(grammar),
temperature: Some(temperature),
} => {
llama_context.sample_grammar(&mut token_data_array, grammar);
llama_context.sample_temp(&mut token_data_array, temperature);
llama_context.sample_token_softmax(&mut token_data_array);
let token = llama_context.sample_token_greedy(token_data_array);
llama_context.grammar_accept_token(grammar, token);
token
}
}
}

/// Create a new sampler.
#[must_use]
#[deprecated(
since = "0.1.32",
note = "this does not scale well with many params and does not allow for changing of orders."
)]
pub fn new(llama_token_data_array: LlamaTokenDataArray) -> Self {
Self {
token_data_array: llama_token_data_array,
grammar: None,
temperature: None,
}
}

/// Set the grammar for sampling.
#[must_use]
#[deprecated(
since = "0.1.32",
note = "this does not scale well with many params and does not allow for changing of orders."
)]
pub fn with_grammar(mut self, grammar: &'grammar mut LlamaGrammar) -> Self {
self.grammar = Some(grammar);
self
}

/// Set the temperature for sampling.
///
/// ```
/// # use llama_cpp_2::context::LlamaContext;
/// # use llama_cpp_2::context::sample::Sampler;
/// # use llama_cpp_2::grammar::LlamaGrammar;
/// # use llama_cpp_2::token::data::LlamaTokenData;
/// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
/// # use llama_cpp_2::token::LlamaToken;
///
/// let _sampler = Sampler::new(LlamaTokenDataArray::new(vec![LlamaTokenData::new(LlamaToken(0), 0.0, 0.0)], false))
/// .with_temperature(0.5);
/// ```
#[must_use]
#[deprecated(
since = "0.1.32",
note = "this does not scale well with many params and does not allow for changing of orders."
)]
pub fn with_temperature(mut self, temperature: f32) -> Self {
if temperature == 0.0 {
return self;
}
self.temperature = Some(temperature);
self
}
}
#[cfg(feature = "sampler")]
pub mod sampler;

impl LlamaContext<'_> {
/// Sample a token.
///
/// # Panics
///
/// - sampler contains no tokens
#[deprecated(
since = "0.1.32",
note = "this does not scale well with many params and does not allow for changing of orders."
)]
pub fn sample(&mut self, sampler: Sampler) -> LlamaToken {
sampler.sample(self)
}

/// Accept a token into the grammar.
pub fn grammar_accept_token(&mut self, grammar: &mut LlamaGrammar, token: LlamaToken) {
unsafe {
@@ -157,38 +37,20 @@ impl LlamaContext<'_> {
}
}

/// Modify [`token_data`] in place using temperature sampling.
///
/// # Panics
///
/// - [`temperature`] is not between 0.0 and 1.0
pub fn sample_temp(&self, token_data: &mut LlamaTokenDataArray, temperature: f32) {
assert!(
temperature >= 0.0,
"temperature must be positive (was {temperature})"
);
assert!(
temperature <= 1.0,
"temperature must be less than or equal to 1.0 (was {temperature})"
);
if temperature == 0.0 {
return;
}
let ctx: *mut llama_cpp_sys_2::llama_context = self.context.as_ptr();
unsafe {
token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_temp(ctx, c_llama_token_data_array, temperature);
});
}
/// See [`LlamaTokenDataArray::sample_temp`]
pub fn sample_temp(&mut self, token_data: &mut LlamaTokenDataArray, temperature: f32) {
token_data.sample_temp(Some(self), temperature);
}

/// Sample a token greedily.
/// Sample a token greedily. Note that this *does not* take into account anything that has modified the probabilities - it only looks at logits.
///
/// Most of the time [`LlamaTokenDataArray::sample_softmax`] or [`LlamaTokenDataArray::sample_token`] should be used instead.
///
/// # Panics
///
/// - [`token_data`] is empty
/// - if `token_data` is empty
#[must_use]
pub fn sample_token_greedy(&self, mut token_data: LlamaTokenDataArray) -> LlamaToken {
pub fn sample_token_greedy(&mut self, mut token_data: LlamaTokenDataArray) -> LlamaToken {
assert!(!token_data.data.is_empty(), "no tokens");
let mut data_arr = llama_cpp_sys_2::llama_token_data_array {
data: token_data
@@ -207,39 +69,34 @@ impl LlamaContext<'_> {
LlamaToken(token)
}

/// Tail Free Sampling described in [Tail-Free-Sampling](https://www.trentonbricken.com/Tail-Free-Sampling/).
pub fn sample_tail_free(&self, token_data: &mut LlamaTokenDataArray, z: f32, min_keep: usize) {
let ctx = self.context.as_ptr();
unsafe {
token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_tail_free(ctx, c_llama_token_data_array, z, min_keep);
});
}
/// See [`LlamaTokenDataArray::sample_tail_free`]
pub fn sample_tail_free(
&mut self,
token_data: &mut LlamaTokenDataArray,
z: f32,
min_keep: usize,
) {
token_data.sample_tail_free(Some(self), z, min_keep);
}

/// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666).
pub fn sample_typical(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
let ctx = self.context.as_ptr();
unsafe {
token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_typical(ctx, c_llama_token_data_array, p, min_keep);
});
}
/// See [`LlamaTokenDataArray::sample_typical`]
pub fn sample_typical(
&mut self,
token_data: &mut LlamaTokenDataArray,
p: f32,
min_keep: usize,
) {
token_data.sample_typical(Some(self), p, min_keep);
}

/// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)"
pub fn sample_top_p(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
let ctx = self.context.as_ptr();
unsafe {
token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_top_p(ctx, c_llama_token_data_array, p, min_keep);
});
}
/// See [`LlamaTokenDataArray::sample_top_p`]
pub fn sample_top_p(&mut self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
token_data.sample_top_p(Some(self), p, min_keep);
}

/// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841)
pub fn sample_min_p(
&self,
&mut self,
llama_token_data: &mut LlamaTokenDataArray,
p: f32,
min_keep: usize,
@@ -252,24 +109,14 @@ impl LlamaContext<'_> {
}
}

/// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)
pub fn sample_top_k(&self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
let ctx = self.context.as_ptr();
unsafe {
token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_top_k(ctx, c_llama_token_data_array, k, min_keep);
});
}
/// See [`LlamaTokenDataArray::sample_top_k`]
pub fn sample_top_k(&mut self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
token_data.sample_top_k(Some(self), k, min_keep);
}

/// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
pub fn sample_token_softmax(&self, token_data: &mut LlamaTokenDataArray) {
let ctx = self.context.as_ptr();
unsafe {
token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_softmax(ctx, c_llama_token_data_array);
});
}
/// See [`LlamaTokenDataArray::sample_softmax`]
pub fn sample_token_softmax(&mut self, token_data: &mut LlamaTokenDataArray) {
token_data.sample_softmax(Some(self));
}

/// See [`LlamaTokenDataArray::sample_repetition_penalty`]
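The remaining `LlamaContext` sampling methods are now thin wrappers that forward to `LlamaTokenDataArray`, passing `Some(self)` so the array can call back into the context. A minimal sketch of the wrapper-based call pattern from the caller's side, using only signatures visible in this diff; the helper function and the 40 / 0.95 / 0.8 values are illustrative, not from the crate:

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use llama_cpp_2::token::LlamaToken;

/// Hypothetical helper showing the delegating wrappers in use: narrow the
/// candidates, temper the logits, then take the most probable token.
fn pick_token(ctx: &mut LlamaContext, mut candidates: LlamaTokenDataArray) -> LlamaToken {
    // Keep the 40 highest-logit candidates (at least 1 must survive).
    ctx.sample_top_k(&mut candidates, 40, 1);
    // Nucleus filtering at p = 0.95.
    ctx.sample_top_p(&mut candidates, 0.95, 1);
    // Temperature-scale the remaining logits.
    ctx.sample_temp(&mut candidates, 0.8);
    // Sort by logit and compute probabilities, then take the top candidate,
    // mirroring what the removed `Sampler` did on its temperature path.
    ctx.sample_token_softmax(&mut candidates);
    ctx.sample_token_greedy(candidates)
}
```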