Implement new sampler API
This commit implements the new sampler API from `llama.cpp`
introduced in b3680 and removes the custom sampling logic.

The new sampling API is exposed through a builder pattern.

Tests have been updated to pass with the new API.
volesen committed Nov 26, 2024
1 parent 5c27009 commit 1f41e6e
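
In short, the diffs below reduce to the following usage pattern. This is a sketch assembled only from calls that appear in this commit, with model, context (`ctx`), and batch (`batch`) setup elided as in the examples:

use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
use llama_cpp_2::sampling::LlamaSampler;

// Build a sampler chain from default chain parameters. `add_dist(seed)`
// appends a seeded sampler that draws from the token distribution;
// `add_greedy()` would always take the most probable token instead.
let sampler_params = LlamaSamplerChainParams::default();
let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(1234);

// Each decode step: sample from the logits at the last batch position,
// then feed the chosen token back so stateful samplers can update.
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token);
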
Showing 6 changed files with 251 additions and 141 deletions.
26 changes: 14 additions & 12 deletions examples/simple/src/main.rs
@@ -17,7 +17,9 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
-use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
+use llama_cpp_2::sampling::LlamaSampler;
 
 use std::ffi::CString;
 use std::io::Write;
 use std::num::NonZeroU32;
@@ -174,9 +176,9 @@ fn main() -> Result<()> {
         .with_context(|| "unable to load model")?;
 
     // initialize the context
-    let mut ctx_params = LlamaContextParams::default()
-        .with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())))
-        .with_seed(seed.unwrap_or(1234));
+    let mut ctx_params =
+        LlamaContextParams::default().with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())));
 
     if let Some(threads) = threads {
         ctx_params = ctx_params.with_n_threads(threads);
     }
@@ -244,31 +246,31 @@ either reduce n_len or increase n_ctx"
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();
 
+    let sampler_params = LlamaSamplerChainParams::default();
+    let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(seed.unwrap_or(1234));
+
     while n_cur <= n_len {
         // sample the next token
         {
-            let candidates = ctx.candidates();
-
-            let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+            let token = sampler.sample(&ctx, batch.n_tokens() - 1);
 
-            // sample the most likely token
-            let new_token_id = ctx.sample_token_greedy(candidates_p);
+            sampler.accept(token);
 
             // is it an end of stream?
-            if model.is_eog_token(new_token_id) {
+            if model.is_eog_token(token) {
                 eprintln!();
                 break;
             }
 
-            let output_bytes = model.token_to_bytes(new_token_id, Special::Tokenize)?;
+            let output_bytes = model.token_to_bytes(token, Special::Tokenize)?;
             // use `Decoder.decode_to_string()` to avoid the intermediate buffer
             let mut output_string = String::with_capacity(32);
             let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
             print!("{output_string}");
             std::io::stdout().flush()?;
 
             batch.clear();
-            batch.add(new_token_id, n_cur, &[0], true)?;
+            batch.add(token, n_cur, &[0], true)?;
         }
 
         n_cur += 1;
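
In the `simple` example, the per-token logic collapses from building a `LlamaTokenDataArray` of candidates and calling `sample_token_greedy` on the context into two calls on the chain; note this also switches the example from greedy decoding to seeded sampling from the distribution. Side by side, with both halves taken from the hunk above:

// Before: pull candidates from the context and pick the most likely token.
let candidates = ctx.candidates();
let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
let new_token_id = ctx.sample_token_greedy(candidates_p);

// After: let the sampler chain sample at the last batch position and
// record the chosen token.
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token);
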
22 changes: 12 additions & 10 deletions examples/usage.rs
@@ -14,6 +14,8 @@ use llama_cpp_2::llama_batch::LlamaBatch;
 use llama_cpp_2::model::params::LlamaModelParams;
 use llama_cpp_2::model::LlamaModel;
 use llama_cpp_2::model::{AddBos, Special};
+use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
+use llama_cpp_2::sampling::LlamaSampler;
 use llama_cpp_2::token::data_array::LlamaTokenDataArray;
 use std::io::Write;

@@ -54,33 +56,33 @@ fn main() {
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();
 
+    let sampler_params = LlamaSamplerChainParams::default();
+    let mut sampler = LlamaSampler::new(sampler_params)
+        .expect("Failed to create sampler")
+        .add_greedy();
+
     while n_cur <= n_len {
         // sample the next token
         {
-            let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
-
-            let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+            let token = sampler.sample(&ctx, batch.n_tokens() - 1);
 
-            // sample the most likely token
-            let new_token_id = ctx.sample_token_greedy(candidates_p);
+            sampler.accept(token);
 
             // is it an end of stream?
-            if new_token_id == model.token_eos() {
+            if token == model.token_eos() {
                 eprintln!();
                 break;
             }
 
-            let output_bytes = model
-                .token_to_bytes(new_token_id, Special::Tokenize)
-                .unwrap();
+            let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap();
             // use `Decoder.decode_to_string()` to avoid the intermediate buffer
             let mut output_string = String::with_capacity(32);
             let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
             print!("{output_string}");
             std::io::stdout().flush().unwrap();
 
             batch.clear();
-            batch.add(new_token_id, n_cur, &[0], true).unwrap();
+            batch.add(token, n_cur, &[0], true).unwrap();
         }
 
         n_cur += 1;
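
The `usage` example keeps deterministic output by ending its chain with `add_greedy()` rather than `add_dist(seed)`. A compact comparison of the two constructions used in this commit, with error handling as in the respective examples:

// Deterministic decoding: always take the most probable token.
let mut greedy = LlamaSampler::new(LlamaSamplerChainParams::default())
    .expect("Failed to create sampler")
    .add_greedy();

// Stochastic decoding with a fixed seed, as in examples/simple.
let mut seeded = LlamaSampler::new(LlamaSamplerChainParams::default())
    .expect("Failed to create sampler")
    .add_dist(1234);
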
6 changes: 2 additions & 4 deletions llama-cpp-2/src/context/params.rs
@@ -47,7 +47,7 @@ impl From<RopeScalingType> for i32 {
 pub enum LlamaPoolingType {
     /// The pooling type is unspecified
     Unspecified = -1,
-    /// No pooling
+    /// No pooling
     None = 0,
     /// Mean pooling
     Mean = 1,
@@ -95,10 +95,8 @@ impl From<LlamaPoolingType> for i32 {
 /// use llama_cpp_2::context::params::LlamaContextParams;
 ///
 ///let ctx_params = LlamaContextParams::default()
-///     .with_n_ctx(NonZeroU32::new(2048))
-///     .with_seed(1234);
+///     .with_n_ctx(NonZeroU32::new(2048));
 ///
-/// assert_eq!(ctx_params.seed(), 1234);
 /// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
 /// ```
 #[derive(Debug, Clone)]
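
The doc example on `LlamaContextParams` drops `.with_seed(1234)` and the `seed()` assertion, matching the change in examples/simple/src/main.rs: the RNG seed no longer lives on the context parameters but on the sampler chain. A sketch of the equivalent setup after this commit, error handling elided:

use std::num::NonZeroU32;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
use llama_cpp_2::sampling::LlamaSampler;

// The context parameters now only carry settings such as n_ctx...
let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(2048));
// ...while the seed is handed to the sampler chain instead.
let sampler = LlamaSampler::new(LlamaSamplerChainParams::default())?.add_dist(1234);
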
112 changes: 0 additions & 112 deletions llama-cpp-2/src/context/sample/sampler.rs

This file was deleted.

