Merge pull request #580 from nobodywho-ooo/bump-llama-cpp
Implement new sampler API and bump llama.cpp
MarcusDunn authored Nov 27, 2024
2 parents 77af620 + 1f41e6e commit 42aaeeb
Showing 13 changed files with 343 additions and 1,226 deletions.
26 changes: 14 additions & 12 deletions examples/simple/src/main.rs
@@ -17,7 +17,9 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::{AddBos, Special};
-use llama_cpp_2::token::data_array::LlamaTokenDataArray;
+use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
+use llama_cpp_2::sampling::LlamaSampler;

use std::ffi::CString;
use std::io::Write;
use std::num::NonZeroU32;
@@ -174,9 +176,9 @@ fn main() -> Result<()> {
.with_context(|| "unable to load model")?;

// initialize the context
-let mut ctx_params = LlamaContextParams::default()
-.with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())))
-.with_seed(seed.unwrap_or(1234));
+let mut ctx_params =
+LlamaContextParams::default().with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())));

if let Some(threads) = threads {
ctx_params = ctx_params.with_n_threads(threads);
}
@@ -244,31 +246,31 @@ either reduce n_len or increase n_ctx"
// The `Decoder`
let mut decoder = encoding_rs::UTF_8.new_decoder();

+let sampler_params = LlamaSamplerChainParams::default();
+let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(seed.unwrap_or(1234));

while n_cur <= n_len {
// sample the next token
{
-let candidates = ctx.candidates();

-let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+let token = sampler.sample(&ctx, batch.n_tokens() - 1);

-// sample the most likely token
-let new_token_id = ctx.sample_token_greedy(candidates_p);
+sampler.accept(token);

// is it an end of stream?
-if model.is_eog_token(new_token_id) {
+if model.is_eog_token(token) {
eprintln!();
break;
}

-let output_bytes = model.token_to_bytes(new_token_id, Special::Tokenize)?;
+let output_bytes = model.token_to_bytes(token, Special::Tokenize)?;
// use `Decoder.decode_to_string()` to avoid the intermediate buffer
let mut output_string = String::with_capacity(32);
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
print!("{output_string}");
std::io::stdout().flush()?;

batch.clear();
-batch.add(new_token_id, n_cur, &[0], true)?;
+batch.add(token, n_cur, &[0], true)?;
}

n_cur += 1;
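Note: taken together, the changes to examples/simple replace the per-context calls (ctx.candidates() / ctx.sample_token_greedy()) with a sampler chain that is built once before the loop and then drives both token selection and acceptance. A minimal sketch of the resulting loop, assuming ctx, model, batch, decoder, n_cur, n_len, and seed are set up as in the rest of the example (the batch decode step is elided):

    // Build the sampler chain once; the RNG seed now lives here rather than
    // on LlamaContextParams.
    let sampler_params = LlamaSamplerChainParams::default();
    let mut sampler = LlamaSampler::new(sampler_params)?.add_dist(seed.unwrap_or(1234));

    while n_cur <= n_len {
        // Ask the chain for the next token, then feed it back so any
        // stateful samplers in the chain can observe it.
        let token = sampler.sample(&ctx, batch.n_tokens() - 1);
        sampler.accept(token);

        if model.is_eog_token(token) {
            break;
        }

        let output_bytes = model.token_to_bytes(token, Special::Tokenize)?;
        let mut output = String::with_capacity(32);
        let _ = decoder.decode_to_string(&output_bytes, &mut output, false);
        print!("{output}");

        batch.clear();
        batch.add(token, n_cur, &[0], true)?;
        n_cur += 1;
        // ... decode the batch with the context, as in the full example ...
    }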
22 changes: 12 additions & 10 deletions examples/usage.rs
@@ -14,6 +14,8 @@ use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::{AddBos, Special};
+use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
+use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use std::io::Write;

@@ -54,33 +56,33 @@ fn main() {
// The `Decoder`
let mut decoder = encoding_rs::UTF_8.new_decoder();

+let sampler_params = LlamaSamplerChainParams::default();
+let mut sampler = LlamaSampler::new(sampler_params)
+.expect("Failed to create sampler")
+.add_greedy();

while n_cur <= n_len {
// sample the next token
{
-let candidates = ctx.candidates_ith(batch.n_tokens() - 1);

-let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
+let token = sampler.sample(&ctx, batch.n_tokens() - 1);

-// sample the most likely token
-let new_token_id = ctx.sample_token_greedy(candidates_p);
+sampler.accept(token);

// is it an end of stream?
-if new_token_id == model.token_eos() {
+if token == model.token_eos() {
eprintln!();
break;
}

-let output_bytes = model
-.token_to_bytes(new_token_id, Special::Tokenize)
-.unwrap();
+let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap();
// use `Decoder.decode_to_string()` to avoid the intermediate buffer
let mut output_string = String::with_capacity(32);
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
print!("{output_string}");
std::io::stdout().flush().unwrap();

batch.clear();
-batch.add(new_token_id, n_cur, &[0], true).unwrap();
+batch.add(token, n_cur, &[0], true).unwrap();
}

n_cur += 1;
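Note: after this change the two examples differ only in how the chain is assembled — examples/usage.rs picks the most likely token with a greedy stage, while examples/simple draws from the distribution with a fixed seed. A small sketch of both constructions (1234 is an arbitrary seed used for illustration):

    // Deterministic: always take the most likely token (as in examples/usage.rs).
    let mut greedy_sampler = LlamaSampler::new(LlamaSamplerChainParams::default())
        .expect("Failed to create sampler")
        .add_greedy();

    // Seeded sampling from the token distribution (as in examples/simple).
    let mut seeded_sampler = LlamaSampler::new(LlamaSamplerChainParams::default())
        .expect("Failed to create sampler")
        .add_dist(1234);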
5 changes: 2 additions & 3 deletions llama-cpp-2/src/context.rs
@@ -17,7 +17,6 @@ use crate::{

pub mod kv_cache;
pub mod params;
-pub mod sample;
pub mod session;

/// Safe wrapper around `llama_context`.
@@ -267,12 +266,12 @@ impl<'model> LlamaContext<'model> {

/// Reset the timings for the context.
pub fn reset_timings(&mut self) {
-unsafe { llama_cpp_sys_2::llama_reset_timings(self.context.as_ptr()) }
+unsafe { llama_cpp_sys_2::llama_perf_context_reset(self.context.as_ptr()) }
}

/// Returns the timings for the context.
pub fn timings(&mut self) -> LlamaTimings {
-let timings = unsafe { llama_cpp_sys_2::llama_get_timings(self.context.as_ptr()) };
+let timings = unsafe { llama_cpp_sys_2::llama_perf_context(self.context.as_ptr()) };
LlamaTimings { timings }
}

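Note: only the underlying llama.cpp symbols change here — llama_reset_timings becomes llama_perf_context_reset and llama_get_timings becomes llama_perf_context — while the safe wrapper methods keep their names and signatures, so existing callers are unaffected. A sketch, assuming ctx is a LlamaContext:

    let timings = ctx.timings(); // now backed by llama_perf_context
    // ... inspect or log the timings ...
    ctx.reset_timings();         // now backed by llama_perf_context_reset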
37 changes: 2 additions & 35 deletions llama-cpp-2/src/context/params.rs
@@ -47,7 +47,7 @@ impl From<RopeScalingType> for i32 {
pub enum LlamaPoolingType {
/// The pooling type is unspecified
Unspecified = -1,
-/// No pooling
+/// No pooling
None = 0,
/// Mean pooling
Mean = 1,
@@ -95,10 +95,8 @@ impl From<LlamaPoolingType> for i32 {
/// use llama_cpp_2::context::params::LlamaContextParams;
///
///let ctx_params = LlamaContextParams::default()
-/// .with_n_ctx(NonZeroU32::new(2048))
-/// .with_seed(1234);
+/// .with_n_ctx(NonZeroU32::new(2048));
///
-/// assert_eq!(ctx_params.seed(), 1234);
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
@@ -116,37 +114,6 @@ unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}

impl LlamaContextParams {
-/// Set the seed of the context
-///
-/// # Examples
-///
-/// ```rust
-/// use llama_cpp_2::context::params::LlamaContextParams;
-/// let params = LlamaContextParams::default();
-/// let params = params.with_seed(1234);
-/// assert_eq!(params.seed(), 1234);
-/// ```
-#[must_use]
-pub fn with_seed(mut self, seed: u32) -> Self {
-self.context_params.seed = seed;
-self
-}
-
-/// Get the seed of the context
-///
-/// # Examples
-///
-/// ```rust
-/// use llama_cpp_2::context::params::LlamaContextParams;
-/// let params = LlamaContextParams::default()
-/// .with_seed(1234);
-/// assert_eq!(params.seed(), 1234);
-/// ```
-#[must_use]
-pub fn seed(&self) -> u32 {
-self.context_params.seed
-}

/// Set the side of the context
///
/// # Examples
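Note: with with_seed/seed removed from LlamaContextParams, the seed is supplied to the sampler chain instead. A before/after sketch, assuming the imports used in the examples above:

    // Before this change (no longer compiles):
    // let ctx_params = LlamaContextParams::default()
    //     .with_n_ctx(NonZeroU32::new(2048))
    //     .with_seed(1234);

    // After: context params carry only context settings ...
    let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(2048));
    // ... and the seed is given to the sampler chain.
    let sampler = LlamaSampler::new(LlamaSamplerChainParams::default())?.add_dist(1234);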
141 changes: 0 additions & 141 deletions llama-cpp-2/src/context/sample.rs

This file was deleted.

(The remaining 8 changed files in this commit are not shown here.)
