Merge pull request #594 from nkoppel/sampler_api
Add sampling API back to LlamaTokenDataArray; Add DRY and XTC Samplers
MarcusDunn authored Dec 9, 2024
2 parents 3d29dbf + 67ea688 commit cf69db5
Showing 8 changed files with 478 additions and 220 deletions.
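Taken together, the diffs below replace the old two-step construction (`LlamaSamplerChainParams` plus the fallible `LlamaSampler::new(params)?` builder) with free-standing sampler constructors composed through `LlamaSampler::chain_simple`. A minimal sketch of the new style, using only the `chain_simple`, `dist`, and `greedy` constructors visible in the example diffs; the commented-out `xtc` constructor and its parameters are an assumption based on the commit title and upstream llama.cpp, not on this diff:

```rust
use llama_cpp_2::sampling::LlamaSampler;

fn main() {
    // Samplers in the chain are applied in order to the candidate list.
    let _sampler = LlamaSampler::chain_simple([
        // LlamaSampler::xtc(0.5, 0.1, 1, 1234), // hypothetical XTC constructor
        LlamaSampler::dist(1234),  // seeded sampling from the distribution
        LlamaSampler::greedy(),    // pick the highest-probability candidate
    ]);
}
```

Note that construction no longer returns a `Result`, which is consistent with the deletion of `LlamaSamplerError` from lib.rs further down.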
9 changes: 4 additions & 5 deletions examples/simple/src/main.rs
@@ -17,7 +17,6 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
-use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
 use llama_cpp_2::sampling::LlamaSampler;
 
 use std::ffi::CString;
@@ -246,10 +245,10 @@ either reduce n_len or increase n_ctx"
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();
 
-    let sampler_params = LlamaSamplerChainParams::default();
-    let mut sampler = LlamaSampler::new(sampler_params)?
-        .add_dist(seed.unwrap_or(1234))
-        .add_greedy();
+    let mut sampler = LlamaSampler::chain_simple([
+        LlamaSampler::dist(seed.unwrap_or(1234)),
+        LlamaSampler::greedy(),
+    ]);
 
     while n_cur <= n_len {
         // sample the next token
8 changes: 1 addition & 7 deletions examples/usage.rs
@@ -14,9 +14,7 @@ use llama_cpp_2::llama_batch::LlamaBatch;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
-use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
 use llama_cpp_2::sampling::LlamaSampler;
-use llama_cpp_2::token::data_array::LlamaTokenDataArray;
 use std::io::Write;
 
 #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
@@ -55,11 +53,7 @@ fn main() {
 
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();
-
-    let sampler_params = LlamaSamplerChainParams::default();
-    let mut sampler = LlamaSampler::new(sampler_params)
-        .expect("Failed to create sampler")
-        .add_greedy();
+    let mut sampler = LlamaSampler::greedy();
 
     while n_cur <= n_len {
         // sample the next token
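Both example diffs break off just before the token is actually sampled. A hypothetical sketch of what the loop body might look like, under the assumption that the reworked `LlamaSampler` exposes `sample`/`accept` methods taking a context and a batch index; none of this appears in the visible diff:

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;

// Hypothetical helper: pick the next token from the logits of the last batch entry.
fn next_token(
    sampler: &mut LlamaSampler,
    ctx: &LlamaContext<'_>,
    batch: &LlamaBatch,
) -> LlamaToken {
    let token = sampler.sample(ctx, batch.n_tokens() - 1);
    sampler.accept(token); // let stateful samplers (e.g. DRY) record the choice
    token
}
```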
32 changes: 32 additions & 0 deletions llama-cpp-2/src/context.rs
@@ -9,6 +9,7 @@ use crate::llama_batch::LlamaBatch;
 use crate::model::{LlamaLoraAdapter, LlamaModel};
 use crate::timing::LlamaTimings;
 use crate::token::data::LlamaTokenData;
+use crate::token::data_array::LlamaTokenDataArray;
 use crate::token::LlamaToken;
 use crate::{
     DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
@@ -202,6 +203,21 @@ impl<'model> LlamaContext<'model> {
         })
     }
 
+    /// Get the token data array for the last token in the context.
+    ///
+    /// This is a convenience method that implements:
+    /// ```ignore
+    /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
+    /// ```
+    ///
+    /// # Panics
+    ///
+    /// - underlying logits data is null
+    #[must_use]
+    pub fn token_data_array(&self) -> LlamaTokenDataArray {
+        LlamaTokenDataArray::from_iter(self.candidates(), false)
+    }
+
     /// Token logits obtained from the last call to `decode()`.
     /// The logits for which `batch.logits[i] != 0` are stored contiguously
     /// in the order they have appeared in the batch.
@@ -217,6 +233,7 @@
     ///
     /// - `n_vocab` does not fit into a usize
     /// - token data returned is null
+    #[must_use]
     pub fn get_logits(&self) -> &[f32] {
         let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
         assert!(!data.is_null(), "logits data for last token is null");
@@ -237,6 +254,21 @@
         })
     }
 
+    /// Get the token data array for the ith token in the context.
+    ///
+    /// This is a convenience method that implements:
+    /// ```ignore
+    /// LlamaTokenDataArray::from_iter(ctx.candidates_ith(i), false)
+    /// ```
+    ///
+    /// # Panics
+    ///
+    /// - logit `i` is not initialized.
+    #[must_use]
+    pub fn token_data_array_ith(&self, i: i32) -> LlamaTokenDataArray {
+        LlamaTokenDataArray::from_iter(self.candidates_ith(i), false)
+    }
+
     /// Get the logits for the ith token in the context.
     ///
     /// # Panics
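The two new convenience methods make manual inspection of the candidates straightforward. A minimal sketch, assuming (as elsewhere in the crate) that `LlamaTokenDataArray` exposes a public `data` vector and that `LlamaTokenData` has `id()`/`logit()` accessors:

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::token::LlamaToken;

// Argmax over the last token's candidates (what a greedy sampler would pick).
// Use `ctx.token_data_array_ith(i)` instead to inspect batch index `i`.
fn argmax_token(ctx: &LlamaContext<'_>) -> Option<LlamaToken> {
    ctx.token_data_array()
        .data
        .iter()
        .max_by(|a, b| a.logit().total_cmp(&b.logit()))
        .map(|d| d.id())
}
```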
8 changes: 0 additions & 8 deletions llama-cpp-2/src/lib.rs
@@ -195,14 +195,6 @@ pub enum LlamaLoraAdapterRemoveError {
     ErrorResult(i32),
 }
 
-/// An error that can occur when initializing a sampler.
-#[derive(Debug, Eq, PartialEq, thiserror::Error)]
-pub enum LlamaSamplerError {
-    /// llama.cpp returned null
-    #[error("null reference from llama.cpp")]
-    NullReturn,
-}
-
 /// get the time (in microseconds) according to llama.cpp
 /// ```
 /// # use llama_cpp_2::llama_time_us;