diff --git a/llama-cpp-2/src/context/kv_cache.rs b/llama-cpp-2/src/context/kv_cache.rs index 27b68093..a2596836 100644 --- a/llama-cpp-2/src/context/kv_cache.rs +++ b/llama-cpp-2/src/context/kv_cache.rs @@ -1,7 +1,8 @@ //! utilities for working with the kv cache -use std::num::NonZeroU8; use crate::context::LlamaContext; +use std::ffi::c_int; +use std::num::NonZeroU8; impl LlamaContext<'_> { /// Copy the cache from one sequence to another. @@ -24,14 +25,10 @@ impl LlamaContext<'_> { /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to [p1]. /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from [p0]. pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option, p1: Option) { + let p0 = p0.map_or(-1, i32::from); + let p1 = p1.map_or(-1, i32::from); unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_cp( - self.context.as_ptr(), - src, - dest, - p0.map_or(-1, i32::from), - p1.map_or(-1, i32::from), - ) + llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1); } } @@ -43,17 +40,15 @@ impl LlamaContext<'_> { /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to [p1]. /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from [p0]. pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option, p1: Option) { + let p0 = p0.map_or(-1, i32::from); + let p1 = p1.map_or(-1, i32::from); unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_rm( - self.context.as_ptr(), - src, - p0.map_or(-1, i32::from), - p1.map_or(-1, i32::from), - ); + llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1); } } /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) + #[must_use] pub fn get_kv_cache_used_cells(&self) -> i32 { unsafe { llama_cpp_sys_2::llama_get_kv_cache_used_cells(self.context.as_ptr()) } } @@ -74,8 +69,8 @@ impl LlamaContext<'_> { /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) /// If the KV cache is RoPEd, the KV data is updated accordingly: - /// - lazily on next llama_decode() - /// - explicitly with llama_kv_cache_update() + /// - lazily on next [`LlamaContext::decode`] + /// - explicitly with [`Self::kv_cache_update`] /// /// # Parameters /// @@ -84,21 +79,17 @@ impl LlamaContext<'_> { /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0]. /// * `delta` - The relative position to add to the tokens pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option, p1: Option, delta: i32) { + let p0 = p0.map_or(-1, i32::from); + let p1 = p1.map_or(-1, i32::from); unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_add( - self.context.as_ptr(), - seq_id, - p0.map_or(-1, i32::from), - p1.map_or(-1, i32::from), - delta, - ) + llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta); } } /// Integer division of the positions by factor of `d > 1` - /// If the KV cache is RoPEd, the KV data is updated accordingly: - /// - lazily on next llama_decode() - /// - explicitly with llama_kv_cache_update() + /// If the KV cache is `RoPEd`, the KV data is updated accordingly: + /// - lazily on next [`LlamaContext::decode`] + /// - explicitly with [`Self::kv_cache_update`] /// /// # Parameters /// @@ -106,16 +97,17 @@ impl LlamaContext<'_> { /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1]. /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0]. /// * `d` - The factor to divide the positions by - pub fn kv_cache_seq_div(&mut self, seq_id: i32, p0: Option, p1: Option, d: NonZeroU8) { - unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_div( - self.context.as_ptr(), - seq_id, - p0.map_or(-1, i32::from), - p1.map_or(-1, i32::from), - d.get().try_into().expect("d does not fit into a i32"), - ) - } + pub fn kv_cache_seq_div( + &mut self, + seq_id: i32, + p0: Option, + p1: Option, + d: NonZeroU8, + ) { + let p0 = p0.map_or(-1, i32::from); + let p1 = p1.map_or(-1, i32::from); + let d = c_int::from(d.get()); + unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) } } /// Returns the largest position present in the KV cache for the specified sequence @@ -123,14 +115,15 @@ impl LlamaContext<'_> { /// # Parameters /// /// * `seq_id` - The sequence id to get the max position for + #[must_use] pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 { unsafe { llama_cpp_sys_2::llama_kv_cache_seq_pos_max(self.context.as_ptr(), seq_id) } } /// Defragment the KV cache /// This will be applied: - /// - lazily on next llama_decode() - /// - explicitly with llama_kv_cache_update() + /// - lazily on next [`LlamaContext::decode`] + /// - explicitly with [`Self::kv_cache_update`] pub fn kv_cache_defrag(&mut self) { unsafe { llama_cpp_sys_2::llama_kv_cache_defrag(self.context.as_ptr()) } } @@ -142,6 +135,7 @@ impl LlamaContext<'_> { /// Returns the number of tokens in the KV cache (slow, use only for debug) /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times + #[must_use] pub fn get_kv_cache_token_count(&self) -> i32 { unsafe { llama_cpp_sys_2::llama_get_kv_cache_token_count(self.context.as_ptr()) } } @@ -152,14 +146,15 @@ impl LlamaContext<'_> { /// /// * `n_max_seq` - Maximum number of sequences that can exist in a cell. It's not an error /// if there are more sequences in a cell than this value, however they will - /// not be visible in the view cells_sequences. + /// not be visible in the view `cells_sequences`. + #[must_use] pub fn new_kv_cache_view(&self, n_max_seq: i32) -> KVCacheView { - let view = unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) }; + let view = + unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) }; KVCacheView { view, ctx: self } } } - /// Information associated with an individual cell in the KV cache view. #[derive(Debug)] pub struct KVCacheViewCell { @@ -178,10 +173,13 @@ pub struct KVCacheView<'a> { impl<'a> KVCacheView<'a> { /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) pub fn update(&mut self) { - unsafe { llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view) } + unsafe { + llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view); + } } /// Number of KV cache cells. This will be the same as the context size. + #[must_use] pub fn n_cells(&self) -> i32 { self.view.n_cells } @@ -189,37 +187,61 @@ impl<'a> KVCacheView<'a> { /// Number of tokens in the cache. For example, if there are two populated /// cells, the first with 1 sequence id in it and the second with 2 sequence /// ids then you'll have 3 tokens. + #[must_use] pub fn token_count(&self) -> i32 { self.view.token_count } /// Number of populated cache cells. + #[must_use] pub fn used_cells(&self) -> i32 { self.view.used_cells } /// Maximum contiguous empty slots in the cache. + #[must_use] pub fn max_contiguous(&self) -> i32 { self.view.max_contiguous } - /// Index to the start of the max_contiguous slot range. Can be negative + /// Index to the start of the `max_contiguous` slot range. Can be negative /// when cache is full. + #[must_use] pub fn max_contiguous_idx(&self) -> i32 { self.view.max_contiguous_idx } /// Information for individual cells. - pub fn cells(&self) -> impl Iterator { - unsafe { std::slice::from_raw_parts(self.view.cells, self.view.n_cells.try_into().unwrap()) } - .iter() - .map(|&cell| KVCacheViewCell { pos: cell.pos }) + /// + /// # Panics + /// + /// - if `n_cells` does not fit into usize. + pub fn cells(&self) -> impl Iterator { + unsafe { + std::slice::from_raw_parts( + self.view.cells, + usize::try_from(self.view.n_cells).expect("failed to fit n_cells into usize"), + ) + } + .iter() + .map(|&cell| KVCacheViewCell { pos: cell.pos }) } - /// The sequences for each cell. There will be n_max_seq items per cell. - pub fn cells_sequences(&self) -> impl Iterator { - unsafe { std::slice::from_raw_parts(self.view.cells_sequences, (self.view.n_cells * self.view.n_max_seq).try_into().unwrap()) } - .chunks(self.view.n_max_seq.try_into().unwrap()) + /// The sequences for each cell. There will be `n_max_seq` items per cell. + /// + /// # Panics + /// + /// - if `n_cells * n_max_seq` does not fit into usize. + /// - if `n_max_seq` does not fit into usize. + pub fn cells_sequences(&self) -> impl Iterator { + unsafe { + std::slice::from_raw_parts( + self.view.cells_sequences, + usize::try_from(self.view.n_cells * self.view.n_max_seq) + .expect("failed to fit n_cells * n_max_seq into usize"), + ) + } + .chunks(usize::try_from(self.view.n_max_seq).expect("failed to fit n_max_seq into usize")) } } @@ -229,4 +251,4 @@ impl<'a> Drop for KVCacheView<'a> { llama_cpp_sys_2::llama_kv_cache_view_free(&mut self.view); } } -} \ No newline at end of file +} diff --git a/llama-cpp-2/src/context/session.rs b/llama-cpp-2/src/context/session.rs index 61bb3825..7eee031b 100644 --- a/llama-cpp-2/src/context/session.rs +++ b/llama-cpp-2/src/context/session.rs @@ -105,54 +105,60 @@ impl LlamaContext<'_> { .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?; let cstr = CString::new(path)?; - let mut tokens = Vec::with_capacity(max_tokens); + let mut tokens: Vec = Vec::with_capacity(max_tokens); let mut n_out = 0; - unsafe { - if llama_cpp_sys_2::llama_load_session_file( + // SAFETY: cast is valid as LlamaToken is repr(transparent) + let tokens_out = tokens.as_mut_ptr().cast::(); + + let load_session_success = unsafe { + llama_cpp_sys_2::llama_load_session_file( self.context.as_ptr(), cstr.as_ptr(), - // cast is valid as LlamaToken is repr(transparent) - Vec::::as_mut_ptr(&mut tokens).cast::(), + tokens_out, max_tokens, &mut n_out, - ) { - if n_out > max_tokens { - return Err(LoadSessionError::InsufficientMaxLength { - n_out, - max_tokens, - }); - } + ) + }; + if load_session_success { + if n_out > max_tokens { + return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens }); + } + // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written + unsafe { tokens.set_len(n_out); - Ok(tokens) - } else { - Err(LoadSessionError::FailedToLoad) } + Ok(tokens) + } else { + Err(LoadSessionError::FailedToLoad) } } /// Returns the maximum size in bytes of the state (rng, logits, embedding - /// and kv_cache) - will often be smaller after compacting tokens + /// and `kv_cache`) - will often be smaller after compacting tokens + #[must_use] pub fn get_state_size(&self) -> usize { unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) } } /// Copies the state to the specified destination address. - /// Destination needs to have allocated enough memory. + /// /// Returns the number of bytes copied + /// + /// # Safety + /// + /// Destination needs to have allocated enough memory. pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize { - unsafe { - llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) - } + unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) } } /// Set the state reading from the specified address /// Returns the number of bytes read + /// + /// # Safety + /// + /// help wanted: not entirely sure what the safety requirements are here. pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize { - unsafe { - // we don't really need a mutable pointer for `src` -- this is a llama-cpp lapse, - // so we cast away the constness - llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr() as *mut u8) - } + unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) } } } diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp index 87c91c07..d5ab2975 160000 --- a/llama-cpp-sys-2/llama.cpp +++ b/llama-cpp-sys-2/llama.cpp @@ -1 +1 @@ -Subproject commit 87c91c07663b707e831c59ec373b5e665ff9d64a +Subproject commit d5ab29757ebc59a30f03e408294ec20628a6374e