From 07ab0ebd4b511d70d8dba2cbe91280a03e82a8e2 Mon Sep 17 00:00:00 2001 From: marcus Date: Thu, 29 Feb 2024 10:12:23 -0800 Subject: [PATCH 1/3] updated llama cpp and removed cast to mut --- llama-cpp-2/src/context/kv_cache.rs | 48 ++++++++++++++++++++--------- llama-cpp-2/src/context/session.rs | 4 +-- llama-cpp-sys-2/llama.cpp | 2 +- 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/llama-cpp-2/src/context/kv_cache.rs b/llama-cpp-2/src/context/kv_cache.rs index 27b68093..9e61e980 100644 --- a/llama-cpp-2/src/context/kv_cache.rs +++ b/llama-cpp-2/src/context/kv_cache.rs @@ -1,7 +1,8 @@ //! utilities for working with the kv cache -use std::num::NonZeroU8; use crate::context::LlamaContext; +use std::ffi::c_int; +use std::num::NonZeroU8; impl LlamaContext<'_> { /// Copy the cache from one sequence to another. @@ -106,14 +107,20 @@ impl LlamaContext<'_> { /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1]. /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0]. /// * `d` - The factor to divide the positions by - pub fn kv_cache_seq_div(&mut self, seq_id: i32, p0: Option, p1: Option, d: NonZeroU8) { + pub fn kv_cache_seq_div( + &mut self, + seq_id: i32, + p0: Option, + p1: Option, + d: NonZeroU8, + ) { unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div( self.context.as_ptr(), seq_id, p0.map_or(-1, i32::from), p1.map_or(-1, i32::from), - d.get().try_into().expect("d does not fit into a i32"), + c_int::from(d.get()), ) } } @@ -154,12 +161,12 @@ impl LlamaContext<'_> { /// if there are more sequences in a cell than this value, however they will /// not be visible in the view cells_sequences. pub fn new_kv_cache_view(&self, n_max_seq: i32) -> KVCacheView { - let view = unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) }; + let view = + unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) }; KVCacheView { view, ctx: self } } } - /// Information associated with an individual cell in the KV cache view. #[derive(Debug)] pub struct KVCacheViewCell { @@ -178,7 +185,9 @@ pub struct KVCacheView<'a> { impl<'a> KVCacheView<'a> { /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) pub fn update(&mut self) { - unsafe { llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view) } + unsafe { + llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view) + } } /// Number of KV cache cells. This will be the same as the context size. @@ -210,16 +219,27 @@ impl<'a> KVCacheView<'a> { } /// Information for individual cells. - pub fn cells(&self) -> impl Iterator { - unsafe { std::slice::from_raw_parts(self.view.cells, self.view.n_cells.try_into().unwrap()) } - .iter() - .map(|&cell| KVCacheViewCell { pos: cell.pos }) + pub fn cells(&self) -> impl Iterator { + unsafe { + std::slice::from_raw_parts( + self.view.cells, + usize::try_from(self.view.n_cells).expect("failed to fit n_cells into usize"), + ) + } + .iter() + .map(|&cell| KVCacheViewCell { pos: cell.pos }) } /// The sequences for each cell. There will be n_max_seq items per cell. - pub fn cells_sequences(&self) -> impl Iterator { - unsafe { std::slice::from_raw_parts(self.view.cells_sequences, (self.view.n_cells * self.view.n_max_seq).try_into().unwrap()) } - .chunks(self.view.n_max_seq.try_into().unwrap()) + pub fn cells_sequences(&self) -> impl Iterator { + unsafe { + std::slice::from_raw_parts( + self.view.cells_sequences, + usize::try_from(self.view.n_cells * self.view.n_max_seq) + .expect("failed to fit n_cells * n_max_seq into usize"), + ) + } + .chunks(usize::try_from(self.view.n_max_seq).expect("failed to fit n_max_seq into usize")) } } @@ -229,4 +249,4 @@ impl<'a> Drop for KVCacheView<'a> { llama_cpp_sys_2::llama_kv_cache_view_free(&mut self.view); } } -} \ No newline at end of file +} diff --git a/llama-cpp-2/src/context/session.rs b/llama-cpp-2/src/context/session.rs index 61bb3825..9d76b55e 100644 --- a/llama-cpp-2/src/context/session.rs +++ b/llama-cpp-2/src/context/session.rs @@ -150,9 +150,7 @@ impl LlamaContext<'_> { /// Returns the number of bytes read pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize { unsafe { - // we don't really need a mutable pointer for `src` -- this is a llama-cpp lapse, - // so we cast away the constness - llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr() as *mut u8) + llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) } } } diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp index 87c91c07..d5ab2975 160000 --- a/llama-cpp-sys-2/llama.cpp +++ b/llama-cpp-sys-2/llama.cpp @@ -1 +1 @@ -Subproject commit 87c91c07663b707e831c59ec373b5e665ff9d64a +Subproject commit d5ab29757ebc59a30f03e408294ec20628a6374e From 0b2867c8761ed7ab17e964c1ef8393b95f6981c8 Mon Sep 17 00:00:00 2001 From: marcus Date: Thu, 29 Feb 2024 10:24:48 -0800 Subject: [PATCH 2/3] added some safety comments and decreased unsafe scope --- llama-cpp-2/src/context/kv_cache.rs | 42 ++++++++--------------- llama-cpp-2/src/context/session.rs | 52 +++++++++++++++++------------ 2 files changed, 43 insertions(+), 51 deletions(-) diff --git a/llama-cpp-2/src/context/kv_cache.rs b/llama-cpp-2/src/context/kv_cache.rs index 9e61e980..64b3ea13 100644 --- a/llama-cpp-2/src/context/kv_cache.rs +++ b/llama-cpp-2/src/context/kv_cache.rs @@ -25,14 +25,10 @@ impl LlamaContext<'_> { /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to [p1]. /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from [p0]. pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option, p1: Option) { + let p0 = p0.map_or(-1, i32::from); + let p1 = p1.map_or(-1, i32::from); unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_cp( - self.context.as_ptr(), - src, - dest, - p0.map_or(-1, i32::from), - p1.map_or(-1, i32::from), - ) + llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1) } } @@ -44,13 +40,10 @@ impl LlamaContext<'_> { /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to [p1]. /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from [p0]. pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option, p1: Option) { + let p0 = p0.map_or(-1, i32::from); + let p1 = p1.map_or(-1, i32::from); unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_rm( - self.context.as_ptr(), - src, - p0.map_or(-1, i32::from), - p1.map_or(-1, i32::from), - ); + llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1); } } @@ -85,14 +78,10 @@ impl LlamaContext<'_> { /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0]. /// * `delta` - The relative position to add to the tokens pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option, p1: Option, delta: i32) { + let p0 = p0.map_or(-1, i32::from); + let p1 = p1.map_or(-1, i32::from); unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_add( - self.context.as_ptr(), - seq_id, - p0.map_or(-1, i32::from), - p1.map_or(-1, i32::from), - delta, - ) + llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta) } } @@ -114,15 +103,10 @@ impl LlamaContext<'_> { p1: Option, d: NonZeroU8, ) { - unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_div( - self.context.as_ptr(), - seq_id, - p0.map_or(-1, i32::from), - p1.map_or(-1, i32::from), - c_int::from(d.get()), - ) - } + let p0 = p0.map_or(-1, i32::from); + let p1 = p1.map_or(-1, i32::from); + let d = c_int::from(d.get()); + unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) } } /// Returns the largest position present in the KV cache for the specified sequence diff --git a/llama-cpp-2/src/context/session.rs b/llama-cpp-2/src/context/session.rs index 9d76b55e..2a9bc0a4 100644 --- a/llama-cpp-2/src/context/session.rs +++ b/llama-cpp-2/src/context/session.rs @@ -105,29 +105,33 @@ impl LlamaContext<'_> { .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?; let cstr = CString::new(path)?; - let mut tokens = Vec::with_capacity(max_tokens); + let mut tokens: Vec = Vec::with_capacity(max_tokens); let mut n_out = 0; - unsafe { - if llama_cpp_sys_2::llama_load_session_file( + // SAFETY: cast is valid as LlamaToken is repr(transparent) + let tokens_out = tokens.as_mut_ptr().cast::(); + + let load_session_success = unsafe { + llama_cpp_sys_2::llama_load_session_file( self.context.as_ptr(), cstr.as_ptr(), - // cast is valid as LlamaToken is repr(transparent) - Vec::::as_mut_ptr(&mut tokens).cast::(), + tokens_out, max_tokens, &mut n_out, - ) { - if n_out > max_tokens { - return Err(LoadSessionError::InsufficientMaxLength { - n_out, - max_tokens, - }); - } - tokens.set_len(n_out); - Ok(tokens) + ) + }; + if load_session_success { + if n_out > max_tokens { + return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens }); } else { - Err(LoadSessionError::FailedToLoad) + // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written + unsafe { + tokens.set_len(n_out); + } } + Ok(tokens) + } else { + Err(LoadSessionError::FailedToLoad) } } @@ -138,19 +142,23 @@ impl LlamaContext<'_> { } /// Copies the state to the specified destination address. - /// Destination needs to have allocated enough memory. + /// /// Returns the number of bytes copied + /// + /// # Safety + /// + /// Destination needs to have allocated enough memory. pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize { - unsafe { - llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) - } + unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) } } /// Set the state reading from the specified address /// Returns the number of bytes read + /// + /// # Safety + /// + /// help wanted: not entirely sure what the safety requirements are here. pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize { - unsafe { - llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) - } + unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) } } } From 035cb5748a3c71d93f70a2fe4a90c4ae7c9568b2 Mon Sep 17 00:00:00 2001 From: marcus Date: Thu, 29 Feb 2024 10:32:14 -0800 Subject: [PATCH 3/3] clippy --- llama-cpp-2/src/context/kv_cache.rs | 44 ++++++++++++++++++++--------- llama-cpp-2/src/context/session.rs | 12 ++++---- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/llama-cpp-2/src/context/kv_cache.rs b/llama-cpp-2/src/context/kv_cache.rs index 64b3ea13..a2596836 100644 --- a/llama-cpp-2/src/context/kv_cache.rs +++ b/llama-cpp-2/src/context/kv_cache.rs @@ -28,7 +28,7 @@ impl LlamaContext<'_> { let p0 = p0.map_or(-1, i32::from); let p1 = p1.map_or(-1, i32::from); unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1) + llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1); } } @@ -48,6 +48,7 @@ impl LlamaContext<'_> { } /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) + #[must_use] pub fn get_kv_cache_used_cells(&self) -> i32 { unsafe { llama_cpp_sys_2::llama_get_kv_cache_used_cells(self.context.as_ptr()) } } @@ -68,8 +69,8 @@ impl LlamaContext<'_> { /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) /// If the KV cache is RoPEd, the KV data is updated accordingly: - /// - lazily on next llama_decode() - /// - explicitly with llama_kv_cache_update() + /// - lazily on next [`LlamaContext::decode`] + /// - explicitly with [`Self::kv_cache_update`] /// /// # Parameters /// @@ -81,14 +82,14 @@ impl LlamaContext<'_> { let p0 = p0.map_or(-1, i32::from); let p1 = p1.map_or(-1, i32::from); unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta) + llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta); } } /// Integer division of the positions by factor of `d > 1` - /// If the KV cache is RoPEd, the KV data is updated accordingly: - /// - lazily on next llama_decode() - /// - explicitly with llama_kv_cache_update() + /// If the KV cache is `RoPEd`, the KV data is updated accordingly: + /// - lazily on next [`LlamaContext::decode`] + /// - explicitly with [`Self::kv_cache_update`] /// /// # Parameters /// @@ -114,14 +115,15 @@ impl LlamaContext<'_> { /// # Parameters /// /// * `seq_id` - The sequence id to get the max position for + #[must_use] pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 { unsafe { llama_cpp_sys_2::llama_kv_cache_seq_pos_max(self.context.as_ptr(), seq_id) } } /// Defragment the KV cache /// This will be applied: - /// - lazily on next llama_decode() - /// - explicitly with llama_kv_cache_update() + /// - lazily on next [`LlamaContext::decode`] + /// - explicitly with [`Self::kv_cache_update`] pub fn kv_cache_defrag(&mut self) { unsafe { llama_cpp_sys_2::llama_kv_cache_defrag(self.context.as_ptr()) } } @@ -133,6 +135,7 @@ impl LlamaContext<'_> { /// Returns the number of tokens in the KV cache (slow, use only for debug) /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times + #[must_use] pub fn get_kv_cache_token_count(&self) -> i32 { unsafe { llama_cpp_sys_2::llama_get_kv_cache_token_count(self.context.as_ptr()) } } @@ -143,7 +146,8 @@ impl LlamaContext<'_> { /// /// * `n_max_seq` - Maximum number of sequences that can exist in a cell. It's not an error /// if there are more sequences in a cell than this value, however they will - /// not be visible in the view cells_sequences. + /// not be visible in the view `cells_sequences`. + #[must_use] pub fn new_kv_cache_view(&self, n_max_seq: i32) -> KVCacheView { let view = unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) }; @@ -170,11 +174,12 @@ impl<'a> KVCacheView<'a> { /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) pub fn update(&mut self) { unsafe { - llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view) + llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view); } } /// Number of KV cache cells. This will be the same as the context size. + #[must_use] pub fn n_cells(&self) -> i32 { self.view.n_cells } @@ -182,27 +187,35 @@ impl<'a> KVCacheView<'a> { /// Number of tokens in the cache. For example, if there are two populated /// cells, the first with 1 sequence id in it and the second with 2 sequence /// ids then you'll have 3 tokens. + #[must_use] pub fn token_count(&self) -> i32 { self.view.token_count } /// Number of populated cache cells. + #[must_use] pub fn used_cells(&self) -> i32 { self.view.used_cells } /// Maximum contiguous empty slots in the cache. + #[must_use] pub fn max_contiguous(&self) -> i32 { self.view.max_contiguous } - /// Index to the start of the max_contiguous slot range. Can be negative + /// Index to the start of the `max_contiguous` slot range. Can be negative /// when cache is full. + #[must_use] pub fn max_contiguous_idx(&self) -> i32 { self.view.max_contiguous_idx } /// Information for individual cells. + /// + /// # Panics + /// + /// - if `n_cells` does not fit into usize. pub fn cells(&self) -> impl Iterator { unsafe { std::slice::from_raw_parts( @@ -214,7 +227,12 @@ impl<'a> KVCacheView<'a> { .map(|&cell| KVCacheViewCell { pos: cell.pos }) } - /// The sequences for each cell. There will be n_max_seq items per cell. + /// The sequences for each cell. There will be `n_max_seq` items per cell. + /// + /// # Panics + /// + /// - if `n_cells * n_max_seq` does not fit into usize. + /// - if `n_max_seq` does not fit into usize. pub fn cells_sequences(&self) -> impl Iterator { unsafe { std::slice::from_raw_parts( diff --git a/llama-cpp-2/src/context/session.rs b/llama-cpp-2/src/context/session.rs index 2a9bc0a4..7eee031b 100644 --- a/llama-cpp-2/src/context/session.rs +++ b/llama-cpp-2/src/context/session.rs @@ -123,11 +123,10 @@ impl LlamaContext<'_> { if load_session_success { if n_out > max_tokens { return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens }); - } else { - // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written - unsafe { - tokens.set_len(n_out); - } + } + // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written + unsafe { + tokens.set_len(n_out); } Ok(tokens) } else { @@ -136,7 +135,8 @@ impl LlamaContext<'_> { } /// Returns the maximum size in bytes of the state (rng, logits, embedding - /// and kv_cache) - will often be smaller after compacting tokens + /// and `kv_cache`) - will often be smaller after compacting tokens + #[must_use] pub fn get_state_size(&self) -> usize { unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) } }