Make LlamaTokenDataArray::selected an Option<usize>
nkoppel committed Dec 7, 2024
1 parent 6d3dec9 commit ca07170
Showing 2 changed files with 22 additions and 13 deletions.
4 changes: 2 additions & 2 deletions llama-cpp-2/src/context.rs
@@ -206,7 +206,7 @@ impl<'model> LlamaContext<'model> {
     /// Get the token data array for the last token in the context.
     ///
     /// This is a convience method that implements:
-    /// ```no_run
+    /// ```ignore
     /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
     /// ```
     ///
@@ -257,7 +257,7 @@ impl<'model> LlamaContext<'model> {
     /// Get the token data array for the ith token in the context.
     ///
     /// This is a convience method that implements:
-    /// ```no_run
+    /// ```ignore
     /// LlamaTokenDataArray::from_iter(ctx.candidates_ith(i), false)
     /// ```
     ///
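For orientation, here is a minimal, hypothetical sketch of how these convenience methods read after this change. It assumes `ctx` is an initialized `LlamaContext` and that `candidates()` can be called through a shared reference; the helper name is made up for illustration.

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;

// Hypothetical helper mirroring the doc comment in the hunk above.
fn last_token_candidates(ctx: &LlamaContext<'_>) -> LlamaTokenDataArray {
    let candidates = LlamaTokenDataArray::from_iter(ctx.candidates(), false);
    // Nothing has been sampled yet, so no token is selected; with this
    // commit that is expressed as `None` rather than a `-1` sentinel.
    debug_assert!(candidates.selected_token().is_none());
    candidates
}
```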
31 changes: 20 additions & 11 deletions llama-cpp-2/src/token/data_array.rs
@@ -1,4 +1,4 @@
-//! an rusty equivalent of `llama_token_data`.
+//! an rusty equivalent of `llama_token_data_array`.
 use std::{ffi::CString, ptr};

 use crate::{model::LlamaModel, token::data::LlamaTokenData};
@@ -11,8 +11,8 @@ use super::LlamaToken;
 pub struct LlamaTokenDataArray {
     /// the underlying data
     pub data: Vec<LlamaTokenData>,
-    /// the selected token
-    pub selected: i64,
+    /// the index of the selected token in ``data``
+    pub selected: Option<usize>,
     /// is the data sorted?
     pub sorted: bool,
 }
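Downstream code that read the raw `selected` field directly needs a small migration. A sketch, assuming `arr` is any `LlamaTokenDataArray`; the helper name is hypothetical.

```rust
use llama_cpp_2::token::data::LlamaTokenData;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;

// Hypothetical downstream helper: the old `-1`/`i64` sentinel check becomes
// ordinary `Option` handling.
fn selected_entry(arr: &LlamaTokenDataArray) -> Option<&LlamaTokenData> {
    // Before this commit the equivalent was roughly:
    //     if arr.selected >= 0 { arr.data.get(arr.selected as usize) } else { None }
    arr.selected.and_then(|i| arr.data.get(i))
}
```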
@@ -35,7 +35,7 @@ impl LlamaTokenDataArray {
     pub fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
        Self {
             data,
-            selected: -1,
+            selected: None,
             sorted,
         }
     }
@@ -60,9 +60,7 @@ impl LlamaTokenDataArray {

     #[must_use]
     pub fn selected_token(&self) -> Option<LlamaToken> {
-        self.data
-            .get(usize::try_from(self.selected).ok()?)
-            .map(LlamaTokenData::id)
+        self.data.get(self.selected?).map(LlamaTokenData::id)
     }
 }

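The safe accessor keeps its `Option<LlamaToken>` signature, so existing callers are unchanged. A small caller-side sketch (hypothetical function; the sampler plumbing that populates `selected` appears further down in this diff, and `LlamaToken` is assumed to implement `Debug`):

```rust
use llama_cpp_2::token::data_array::LlamaTokenDataArray;

// Hypothetical caller-side view after a sampler has (or has not) run.
fn report_selection(arr: &LlamaTokenDataArray) {
    match arr.selected_token() {
        Some(token) => println!("sampler selected token {token:?}"),
        None => println!("no token selected yet"),
    }
}
```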
@@ -82,29 +80,40 @@ impl LlamaTokenDataArray {
         modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T,
     ) -> T {
         let size = self.data.len();
-        let data = self.data.as_mut_ptr().cast();
+        let data = self
+            .data
+            .as_mut_ptr()
+            .cast::<llama_cpp_sys_2::llama_token_data>();

         let mut c_llama_token_data_array = llama_cpp_sys_2::llama_token_data_array {
             data,
             size,
-            selected: self.selected,
+            selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1),
             sorted: self.sorted,
         };

         let result = modify(&mut c_llama_token_data_array);
+        assert!(
+            ptr::eq(data, c_llama_token_data_array.data),
+            "data pointer changed"
+        );
         assert!(c_llama_token_data_array.size <= size, "size increased");

         self.data.set_len(c_llama_token_data_array.size);
         self.sorted = c_llama_token_data_array.sorted;
-        self.selected = c_llama_token_data_array.selected;
+        self.selected = c_llama_token_data_array
+            .selected
+            .try_into()
+            .ok()
+            .filter(|&s| s < self.data.len());

         result
     }

     pub(crate) unsafe fn apply_sampler(&mut self, sampler: *mut llama_cpp_sys_2::llama_sampler) {
         self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
             llama_cpp_sys_2::llama_sampler_apply(sampler, c_llama_token_data_array);
-        })
+        });
     }

     pub(crate) unsafe fn apply_and_free_sampler(
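The `Option<usize>` never crosses the FFI boundary directly; `modify_as_c_llama_token_data_array` maps it to and from llama.cpp's `-1` sentinel as shown in the hunk above. A standalone, value-only illustration of that round trip (no FFI involved):

```rust
// Mirrors the conversion in the hunk above: None (or an index that does not
// fit in i64) becomes the -1 sentinel expected by llama.cpp.
fn to_c_selected(selected: Option<usize>) -> i64 {
    selected.and_then(|s| s.try_into().ok()).unwrap_or(-1)
}

// Mirrors the read-back: negative sentinels fail the conversion, and indices
// outside the (possibly shrunken) array are discarded.
fn from_c_selected(selected: i64, len: usize) -> Option<usize> {
    selected.try_into().ok().filter(|&s| s < len)
}

fn main() {
    assert_eq!(to_c_selected(None), -1);
    assert_eq!(to_c_selected(Some(3)), 3);
    assert_eq!(from_c_selected(-1, 10), None);
    assert_eq!(from_c_selected(4, 10), Some(4));
    assert_eq!(from_c_selected(12, 10), None); // stale out-of-range index
}
```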
