diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs index d7078dd..8946da2 100644 --- a/llama-cpp-2/src/context.rs +++ b/llama-cpp-2/src/context.rs @@ -206,7 +206,7 @@ impl<'model> LlamaContext<'model> { /// Get the token data array for the last token in the context. /// /// This is a convience method that implements: - /// ```no_run + /// ```ignore /// LlamaTokenDataArray::from_iter(ctx.candidates(), false) /// ``` /// @@ -257,7 +257,7 @@ impl<'model> LlamaContext<'model> { /// Get the token data array for the ith token in the context. /// /// This is a convience method that implements: - /// ```no_run + /// ```ignore /// LlamaTokenDataArray::from_iter(ctx.candidates_ith(i), false) /// ``` /// diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs index 01b2432..090c866 100644 --- a/llama-cpp-2/src/token/data_array.rs +++ b/llama-cpp-2/src/token/data_array.rs @@ -1,4 +1,4 @@ -//! an rusty equivalent of `llama_token_data`. +//! a rusty equivalent of `llama_token_data_array`. use std::{ffi::CString, ptr}; use crate::{model::LlamaModel, token::data::LlamaTokenData}; @@ -11,8 +11,8 @@ use super::LlamaToken; pub struct LlamaTokenDataArray { /// the underlying data pub data: Vec<LlamaTokenData>, - /// the selected token - pub selected: i64, + /// the index of the selected token in ``data`` + pub selected: Option<usize>, /// is the data sorted? pub sorted: bool, } @@ -35,7 +35,7 @@ impl LlamaTokenDataArray { pub fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self { Self { data, - selected: -1, + selected: None, sorted, } } @@ -60,9 +60,7 @@ impl LlamaTokenDataArray { #[must_use] pub fn selected_token(&self) -> Option<LlamaToken> { - self.data - .get(usize::try_from(self.selected).ok()?)
- .map(LlamaTokenData::id) + self.data.get(self.selected?).map(LlamaTokenData::id) } } @@ -82,29 +80,40 @@ impl LlamaTokenDataArray { modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T, ) -> T { let size = self.data.len(); - let data = self.data.as_mut_ptr().cast(); + let data = self + .data + .as_mut_ptr() + .cast::<llama_cpp_sys_2::llama_token_data>(); + let mut c_llama_token_data_array = llama_cpp_sys_2::llama_token_data_array { data, size, - selected: self.selected, + selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1), sorted: self.sorted, }; + let result = modify(&mut c_llama_token_data_array); assert!( ptr::eq(data, c_llama_token_data_array.data), "data pointer changed" ); assert!(c_llama_token_data_array.size <= size, "size increased"); + self.data.set_len(c_llama_token_data_array.size); self.sorted = c_llama_token_data_array.sorted; - self.selected = c_llama_token_data_array.selected; + self.selected = c_llama_token_data_array + .selected + .try_into() + .ok() + .filter(|&s| s < self.data.len()); + result } pub(crate) unsafe fn apply_sampler(&mut self, sampler: *mut llama_cpp_sys_2::llama_sampler) { self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { llama_cpp_sys_2::llama_sampler_apply(sampler, c_llama_token_data_array); - }) + }); } pub(crate) unsafe fn apply_and_free_sampler(